OpenDAS / OpenFold · Commits

Commit 8185c307
Authored Oct 24, 2023 by Sachin Kadyan
Parent: 4c8e3764

    Just-in-time embedding generation for the SoloSeq model

Showing 2 changed files with 101 additions and 66 deletions:

    run_pretrained_openfold.py          +3   -0
    scripts/precompute_embeddings.py    +98  -66
run_pretrained_openfold.py

@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import (
     pad_feature_dict_seq,
     trace_model_,
 )
+from scripts.precompute_embeddings import EmbeddingGenerator
 from scripts.utils import add_data_args

@@ -82,6 +83,8 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
             pdb70_database_path=args.pdb70_database_path,
             no_cpus=args.cpus,
         )
+        embedding_generator = EmbeddingGenerator()
+        embedding_generator.run(args.fasta_dir, local_alignment_dir)
     else:
         alignment_runner = data_pipeline.AlignmentRunner(
             jackhmmer_binary_path=args.jackhmmer_binary_path,
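The effect of this change is that SoloSeq inference no longer requires a separate embedding-precomputation step: in what appears to be the single-sequence branch of precompute_alignments, the script now instantiates the generator with its defaults and writes the ESM-1b embeddings straight into the per-sequence alignment directory. A minimal sketch of the same flow, with hypothetical directory names:

    from scripts.precompute_embeddings import EmbeddingGenerator

    # Hypothetical directories; the defaults (4096 tokens per batch, truncation
    # on, hub-hosted ESM-1b, GPU when available) come from the class defined in
    # the second file below.
    embedding_generator = EmbeddingGenerator()
    embedding_generator.run("fasta_dir/", "alignment_dir/")  # writes <label>/<label>.pt per input FASTA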
scripts/precompute_embeddings.py

@@ -58,17 +58,43 @@ class SequenceDataset(object):
         _flush_current_buf()
         return batches


-def main(args):
-    labels = []
-    seqs = []
-    # Generate a single bulk file
-    for f in os.listdir(args.fasta_dir):
-        f_name, ext = os.path.splitext(f)
-        if ext != '.fasta' and ext != '.fa':
-            logging.warning(f"Ignoring non-FASTA file: {f}")
-            continue
-        with open(os.path.join(args.fasta_dir, f), 'r') as infile:
-            seq = infile.readlines()[1].strip()
-        labels.append(f_name)
-        seqs.append(seq)
+class EmbeddingGenerator:
+    """Generates the ESM-1b embeddings for the single sequence model"""
+    def __init__(
+        self,
+        toks_per_batch: int = 4096,
+        truncate: bool = True,
+        use_local_esm: str = None,
+        nogpu: bool = False,
+    ):
+        self.toks_per_batch = toks_per_batch
+        self.truncate = truncate
+        self.use_local_esm = use_local_esm
+        self.nogpu = nogpu
+
+        # Generate embeddings in bulk
+        if self.use_local_esm:
+            self.model, self.alphabet = torch.hub.load(self.use_local_esm, "esm1b_t33_650M_UR50S", source='local')
+        else:
+            self.model, self.alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
+
+        if torch.cuda.is_available() and not self.nogpu:
+            self.model = self.model.to(device="cuda")
+
+    def run(
+        self,
+        fasta_dir,
+        output_dir,
+    ):
+        labels = []
+        seqs = []
+        # Generate a single bulk file
+        for f in os.listdir(fasta_dir):
+            f_name, ext = os.path.splitext(f)
+            if ext != '.fasta' and ext != '.fa':
+                logging.warning(f"Ignoring non-FASTA file: {f}")
+                continue
+            with open(os.path.join(fasta_dir, f), 'r') as infile:
+                seq = infile.readlines()[1].strip()
+            labels.append(f_name)
+            seqs.append(seq)
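Construction cost note: torch.hub.load for the facebookresearch/esm repo returns a (model, alphabet) pair, and the checkpoint download now happens in the constructor, so ESM-1b is fetched once per EmbeddingGenerator rather than once per run() call. A small sketch of the three construction paths; all keyword values here are illustrative, not from the commit:

    gen = EmbeddingGenerator()                                    # fetch ESM-1b via torch.hub
    gen_local = EmbeddingGenerator(use_local_esm="/path/to/esm")  # hypothetical local clone of facebookresearch/esm
    gen_cpu = EmbeddingGenerator(nogpu=True)                      # keep the model on CPU even when CUDA is present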
@@ -77,22 +103,15 @@ def main(args):
-    for label, seq in zip(labels, seqs):
-        lines += f'>{label}\n'
-        lines += f'{seq}\n'
-    os.makedirs(args.output_dir, exist_ok=True)
-    temp_fasta_file = os.path.join(args.output_dir, 'temp.fasta')
-    with open(temp_fasta_file, 'w') as outfile:
-        outfile.writelines(lines)
-
-    # Generate embeddings in bulk
-    if args.use_local_esm:
-        model, alphabet = torch.hub.load(args.use_local_esm, "esm1b_t33_650M_UR50S", source='local')
-    else:
-        model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
-    if torch.cuda.is_available() and not args.nogpu:
-        model = model.to(device="cuda")
-
-    dataset = SequenceDataset.from_file(temp_fasta_file)
-    batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
-    data_loader = torch.utils.data.DataLoader(
-        dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
-    )
-    logging.info("Loaded all sequences")
-    repr_layers = [33]
+        for label, seq in zip(labels, seqs):
+            lines += f'>{label}\n'
+            lines += f'{seq}\n'
+        os.makedirs(output_dir, exist_ok=True)
+        temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
+        with open(temp_fasta_file, 'w') as outfile:
+            outfile.writelines(lines)
+
+        dataset = SequenceDataset.from_file(temp_fasta_file)
+        batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
+        data_loader = torch.utils.data.DataLoader(
+            dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
+        )
+        logging.info("Loaded all sequences")
+        repr_layers = [33]
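run() first flattens every input into one temporary bulk FASTA (temp.fasta in the output directory), which SequenceDataset.from_file then tokenizes; get_batch_indices packs sequences so each batch stays under toks_per_batch tokens, with extra_toks_per_seq=1 presumably reserving room for one special token per sequence. Note also that the loop above reads only the second line of each input file (readlines()[1]), so each FASTA is assumed to hold a single-line sequence. The bulk file is plain two-line records; labels here are hypothetical:

    >seq_0001
    MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ
    >seq_0002
    GSHMSLFDFFKNKGSAATATDRLKLILAKER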
@@ -100,21 +119,20 @@ def main(args):
-    with torch.no_grad():
-        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
-            logging.info(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
-            if torch.cuda.is_available() and not args.nogpu:
-                toks = toks.to(device="cuda", non_blocking=True)
-            if args.truncate:
-                toks = toks[:1022]
-            out = model(toks, repr_layers=repr_layers, return_contacts=False)
-            logits = out["logits"].to(device="cpu")
-            representations = {
-                33: out["representations"][33].to(device="cpu")
-            }
-            for i, label in enumerate(labels):
-                os.makedirs(os.path.join(args.output_dir, label), exist_ok=True)
-                result = {"label": label}
-                result["representations"] = {
+        with torch.no_grad():
+            for batch_idx, (labels, strs, toks) in enumerate(data_loader):
+                logging.info(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
+                if torch.cuda.is_available() and not self.nogpu:
+                    toks = toks.to(device="cuda", non_blocking=True)
+                if self.truncate:
+                    toks = toks[:1022]
+                out = self.model(toks, repr_layers=repr_layers, return_contacts=False)
+                representations = {
+                    33: out["representations"][33].to(device="cpu")
+                }
+                for i, label in enumerate(labels):
+                    os.makedirs(os.path.join(output_dir, label), exist_ok=True)
+                    result = {"label": label}
+                    result["representations"] = {
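Each saved .pt file holds a dict with the sequence label and its layer-33 ESM-1b representation (the per-sequence slicing inside result["representations"] is elided from this hunk, so the exact tensor shape is not shown here). A hedged sketch of reading one back, with a hypothetical output path and label:

    import torch

    # "T1026" is a hypothetical label; the keys follow the dict built above.
    result = torch.load("embeddings/T1026/T1026.pt")
    print(result["label"])               # -> "T1026"
    emb = result["representations"][33]  # final-layer (33) ESM-1b embedding tensor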
@@ -122,10 +140,24 @@ def main(args):
-                }
-                torch.save(
-                    result,
-                    os.path.join(args.output_dir, label, label + ".pt")
-                )
-    os.remove(temp_fasta_file)
+                    }
+                    torch.save(
+                        result,
+                        os.path.join(output_dir, label, label + ".pt")
+                    )
+        os.remove(temp_fasta_file)
+
+
+def main(args):
+    logging.info("Loading the model...")
+    embedding_generator = EmbeddingGenerator(args.toks_per_batch, args.truncate, args.use_local_esm, args.nogpu)
+    logging.info("Loading the sequences and running the inference...")
+    embedding_generator.run(args.fasta_dir, args.output_dir)
     logging.info("Completed.")
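With the logic moved into EmbeddingGenerator, main() reduces to a thin wrapper, so the same code path serves both the standalone script and the just-in-time import in run_pretrained_openfold.py. A sketch of equivalent programmatic use of the wrapper, passing only the six attributes main() actually reads (values illustrative):

    from types import SimpleNamespace
    from scripts.precompute_embeddings import main

    # Hypothetical stand-in for the script's parsed argparse namespace.
    main(SimpleNamespace(
        toks_per_batch=4096, truncate=True, use_local_esm=None, nogpu=False,
        fasta_dir="fasta_dir/", output_dir="embeddings_out/",
    ))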