Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
8185c307
Commit
8185c307
authored
Oct 24, 2023
by
Sachin Kadyan
Browse files
Just-in-time embedding generation for the SoloSeq model
parent
4c8e3764
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
101 additions
and
66 deletions
+101
-66
run_pretrained_openfold.py
run_pretrained_openfold.py
+3
-0
scripts/precompute_embeddings.py
scripts/precompute_embeddings.py
+98
-66
No files found.
run_pretrained_openfold.py
View file @
8185c307
...
@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import (
...
@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import (
pad_feature_dict_seq
,
pad_feature_dict_seq
,
trace_model_
,
trace_model_
,
)
)
from
scripts.precompute_embeddings
import
EmbeddingGenerator
from
scripts.utils
import
add_data_args
from
scripts.utils
import
add_data_args
...
@@ -82,6 +83,8 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
...
@@ -82,6 +83,8 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
pdb70_database_path
=
args
.
pdb70_database_path
,
pdb70_database_path
=
args
.
pdb70_database_path
,
no_cpus
=
args
.
cpus
,
no_cpus
=
args
.
cpus
,
)
)
embedding_generator
=
EmbeddingGenerator
()
embedding_generator
.
run
(
args
.
fasta_dir
,
local_alignment_dir
)
else
:
else
:
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
alignment_runner
=
data_pipeline
.
AlignmentRunner
(
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
jackhmmer_binary_path
=
args
.
jackhmmer_binary_path
,
...
...
scripts/precompute_embeddings.py
View file @
8185c307
...
@@ -58,74 +58,106 @@ class SequenceDataset(object):
...
@@ -58,74 +58,106 @@ class SequenceDataset(object):
_flush_current_buf
()
_flush_current_buf
()
return
batches
return
batches
class EmbeddingGenerator:
    """Generates the ESM-1b embeddings for the single sequence model.

    Loads the pretrained ``esm1b_t33_650M_UR50S`` model once at construction
    (from a local ESM checkout or from torch hub) and, via :meth:`run`,
    writes one ``<label>.pt`` file of layer-33 representations per input
    FASTA sequence.
    """

    def __init__(
        self,
        toks_per_batch: int = 4096,   # max tokens per inference batch
        truncate: bool = True,        # clip sequences to the ESM-1b limit
        use_local_esm: str = None,    # path to a local ESM repo; None -> fetch from GitHub
        nogpu: bool = False,          # force CPU even if CUDA is available
    ):
        self.toks_per_batch = toks_per_batch
        self.truncate = truncate
        self.use_local_esm = use_local_esm
        self.nogpu = nogpu
        if self.use_local_esm:
            self.model, self.alphabet = torch.hub.load(
                self.use_local_esm, "esm1b_t33_650M_UR50S", source='local'
            )
        else:
            self.model, self.alphabet = torch.hub.load(
                "facebookresearch/esm:main", "esm1b_t33_650M_UR50S"
            )
        if torch.cuda.is_available() and not self.nogpu:
            self.model = self.model.to(device="cuda")

    def run(self, fasta_dir, output_dir):
        """Embed every FASTA file in *fasta_dir* into *output_dir*.

        For each ``<label>.fasta`` / ``<label>.fa`` file, saves
        ``<output_dir>/<label>/<label>.pt`` containing the label and the
        per-residue layer-33 representations (special tokens stripped).

        NOTE(review): assumes each FASTA holds a single record with the
        whole sequence on its second line — multi-line or multi-record
        FASTAs would be read incorrectly; confirm against the data layout.
        """
        labels = []
        seqs = []
        for f in os.listdir(fasta_dir):
            f_name, ext = os.path.splitext(f)
            if ext != '.fasta' and ext != '.fa':
                logging.warning(f"Ignoring non-FASTA file: {f}")
                continue
            with open(os.path.join(fasta_dir, f), 'r') as infile:
                seq = infile.readlines()[1].strip()
            labels.append(f_name)
            seqs.append(seq)

        # Collate everything into one temporary bulk FASTA so the dataset /
        # batching machinery can process all sequences in a single pass.
        lines = []
        for label, seq in zip(labels, seqs):
            # BUG FIX: the original used ``lines += f'...'``, which extends
            # the list character-by-character; append whole lines instead.
            lines.append(f'>{label}\n')
            lines.append(f'{seq}\n')
        os.makedirs(output_dir, exist_ok=True)
        temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
        with open(temp_fasta_file, 'w') as outfile:
            outfile.writelines(lines)

        try:
            dataset = SequenceDataset.from_file(temp_fasta_file)
            batches = dataset.get_batch_indices(
                self.toks_per_batch, extra_toks_per_seq=1
            )
            data_loader = torch.utils.data.DataLoader(
                dataset,
                collate_fn=self.alphabet.get_batch_converter(),
                batch_sampler=batches,
            )
            logging.info("Loaded all sequences")

            repr_layers = [33]  # final transformer layer of ESM-1b
            with torch.no_grad():
                for batch_idx, (labels, strs, toks) in enumerate(data_loader):
                    logging.info(
                        f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
                    )
                    if torch.cuda.is_available() and not self.nogpu:
                        toks = toks.to(device="cuda", non_blocking=True)
                    if self.truncate:
                        # BUG FIX: truncate the *sequence* dimension to the
                        # ESM-1b limit of 1022 tokens (as in upstream
                        # esm/extract.py); the original ``toks[:1022]``
                        # sliced the batch dimension instead.
                        toks = toks[:, :1022]
                    out = self.model(
                        toks, repr_layers=repr_layers, return_contacts=False
                    )
                    representations = {
                        33: out["representations"][33].to(device="cpu")
                    }
                    for i, label in enumerate(labels):
                        os.makedirs(
                            os.path.join(output_dir, label), exist_ok=True
                        )
                        result = {"label": label}
                        # Drop the BOS token (position 0) and anything past
                        # the true sequence length; clone so the saved slice
                        # does not keep the whole batch tensor alive.
                        result["representations"] = {
                            33: representations[33][i, 1: len(strs[i]) + 1].clone()
                        }
                        torch.save(
                            result,
                            os.path.join(output_dir, label, label + ".pt"),
                        )
        finally:
            # Always remove the temporary bulk FASTA, even if inference fails.
            os.remove(temp_fasta_file)
def main(args):
    """CLI entry point: build an EmbeddingGenerator from parsed args and
    embed every FASTA file in ``args.fasta_dir`` into ``args.output_dir``."""
    logging.info("Loading the model...")
    generator = EmbeddingGenerator(
        args.toks_per_batch,
        args.truncate,
        args.use_local_esm,
        args.nogpu,
    )
    logging.info("Loading the sequences and running the inference...")
    generator.run(args.fasta_dir, args.output_dir)
    logging.info("Completed.")
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment