OpenDAS / OpenFold / Commits / 5aa54958

Commit 5aa54958, authored Nov 08, 2023 by Christina Floristean

    Merge branch 'main' into deepspeed-evo-attention

Parents: f545323c, 099769d2
Changes: 25
Showing 5 changed files with 336 additions and 12 deletions:

    scripts/precompute_embeddings.py   +200  -0
    tests/test_embedders.py            +23   -0
    tests/test_evoformer.py            +59   -2
    tests/test_model.py                +50   -10
    train_openfold.py                  +4    -0
scripts/precompute_embeddings.py (new file, 0 → 100644)
# Some functions borrowed from [ESM](https://www.github.com/facebookresearch/esm)
import argparse
import logging
import os

import torch

from openfold.data import parsers

logging.basicConfig(level=logging.INFO)


class SequenceDataset(object):
    def __init__(self, labels, sequences) -> None:
        self.labels = labels
        self.sequences = sequences

    @classmethod
    def from_file(cls, fasta_file):
        labels, sequences = [], []
        with open(fasta_file, "r") as infile:
            fasta_str = infile.read()
            sequences, labels = parsers.parse_fasta(fasta_str)

        assert len(set(labels)) == len(labels), \
            "Sequence labels need to be unique. Duplicates found!"

        return cls(labels, sequences)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.labels[idx], self.sequences[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq):
        sizes = [(len(s), i) for i, s in enumerate(self.sequences)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches


class EmbeddingGenerator:
    """Generates the ESM-1b embeddings for the single sequence model"""
    def __init__(
        self,
        toks_per_batch: int = 4096,
        truncate: bool = True,
        use_local_esm: str = None,
        nogpu: bool = False,
    ):
        self.toks_per_batch = toks_per_batch
        self.truncate = truncate
        self.use_local_esm = use_local_esm
        self.nogpu = nogpu

        # Generate embeddings in bulk
        if self.use_local_esm:
            self.model, self.alphabet = torch.hub.load(
                self.use_local_esm, "esm1b_t33_650M_UR50S", source='local'
            )
        else:
            self.model, self.alphabet = torch.hub.load(
                "facebookresearch/esm:main", "esm1b_t33_650M_UR50S"
            )

        if torch.cuda.is_available() and not self.nogpu:
            self.model = self.model.to(device="cuda")

    def parse_sequences(self, fasta_dir, output_dir):
        labels = []
        seqs = []

        # Generate a single bulk file
        for f in os.listdir(fasta_dir):
            f_name, ext = os.path.splitext(f)
            if ext != '.fasta' and ext != '.fa':
                logging.warning(f"Ignoring non-FASTA file: {f}")
                continue
            with open(os.path.join(fasta_dir, f), 'r') as infile:
                seq = infile.readlines()[1].strip()
            labels.append(f_name)
            seqs.append(seq)

        lines = []
        for label, seq in zip(labels, seqs):
            lines += f'>{label}\n'
            lines += f'{seq}\n'

        os.makedirs(output_dir, exist_ok=True)
        temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
        with open(temp_fasta_file, 'w') as outfile:
            outfile.writelines(lines)

        return temp_fasta_file

    def run(
        self,
        fasta_file,
        output_dir,
    ):
        dataset = SequenceDataset.from_file(fasta_file)
        batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            collate_fn=self.alphabet.get_batch_converter(),
            batch_sampler=batches,
        )
        logging.info("Loaded all sequences")

        repr_layers = [33]
        with torch.no_grad():
            for batch_idx, (labels, strs, toks) in enumerate(data_loader):
                logging.info(
                    f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
                )
                if torch.cuda.is_available() and not self.nogpu:
                    toks = toks.to(device="cuda", non_blocking=True)
                if self.truncate:
                    toks = toks[:1022]

                out = self.model(toks, repr_layers=repr_layers, return_contacts=False)
                representations = {33: out["representations"][33].to(device="cpu")}

                for i, label in enumerate(labels):
                    os.makedirs(os.path.join(output_dir, label), exist_ok=True)
                    result = {"label": label}
                    result["representations"] = {
                        33: representations[33][i, 1: len(strs[i]) + 1].clone()
                    }
                    torch.save(result, os.path.join(output_dir, label, label + ".pt"))


def main(args):
    logging.info("Loading the model...")
    embedding_generator = EmbeddingGenerator(
        args.toks_per_batch,
        args.truncate,
        args.use_local_esm,
        args.nogpu,
    )

    logging.info("Loading the sequences and running the inference...")
    temp_fasta_file = embedding_generator.parse_sequences(args.fasta_dir, args.output_dir)
    embedding_generator.run(temp_fasta_file, args.output_dir)
    os.remove(temp_fasta_file)
    logging.info("Completed.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="""Path to directory containing FASTA files."""
    )
    parser.add_argument(
        "output_dir", type=str,
        help="Directory in which to output embeddings"
    )
    parser.add_argument(
        "--toks_per_batch", type=int, default=4096,
        help="maximum tokens in a batch"
    )
    parser.add_argument(
        "--truncate", action="store_true", default=True,
        help="Truncate sequences longer than 1022 (ESM restriction). Default: True"
    )
    parser.add_argument(
        "--use_local_esm", type=str, default=None,
        help="Use a local ESM repository instead of cloning from Github"
    )
    parser.add_argument(
        "--nogpu", action="store_true",
        help="Do not use GPU"
    )
    args = parser.parse_args()
    main(args)
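As a hedged usage note (not part of the commit): the script writes one small PyTorch dictionary per sequence to <output_dir>/<label>/<label>.pt, so a downstream consumer could reload it roughly as below. The path and label are hypothetical, and the 1280-wide layer-33 representation is the ESM-1b width this script assumes.

# Sketch only: loading one embedding file written by precompute_embeddings.py.
# "embeddings/seq1/seq1.pt" stands in for a real <output_dir>/<label>/<label>.pt path.
import torch

saved = torch.load("embeddings/seq1/seq1.pt")
label = saved["label"]                    # FASTA label the embedding belongs to
esm_repr = saved["representations"][33]   # per-residue features, shape [seq_len, 1280] for ESM-1b
print(label, esm_repr.shape)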
tests/test_embedders.py
...
@@ -17,6 +17,7 @@ import numpy as np
 import unittest
 from openfold.model.embedders import (
     InputEmbedder,
+    PreembeddingEmbedder,
     RecyclingEmbedder,
     TemplateAngleEmbedder,
     TemplatePairEmbedder,
...
@@ -46,6 +47,28 @@ class TestInputEmbedder(unittest.TestCase):
         self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))


+class TestPreembeddingEmbedder(unittest.TestCase):
+    def test_shape(self):
+        tf_dim = 22
+        preembedding_dim = 1280
+        c_z = 4
+        c_m = 6
+        relpos_k = 10
+        batch_size = 4
+        num_res = 20
+
+        tf = torch.rand((batch_size, num_res, tf_dim))
+        ri = torch.rand((batch_size, num_res))
+        preemb = torch.rand((batch_size, num_res, preembedding_dim))
+        pe = PreembeddingEmbedder(
+            tf_dim, preembedding_dim, c_z, c_m, relpos_k
+        )
+
+        seq_emb, pair_emb = pe(tf, ri, preemb)
+        self.assertTrue(seq_emb.shape == (batch_size, 1, num_res, c_m))
+        self.assertTrue(pair_emb.shape == (batch_size, num_res, num_res, c_z))
+
+
 class TestRecyclingEmbedder(unittest.TestCase):
     def test_shape(self):
         batch_size = 2
...
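A hedged sketch (not part of the diff) of how an embedding produced by scripts/precompute_embeddings.py could be passed to PreembeddingEmbedder, reusing the toy channel sizes from the test above; the file path and the random target features are placeholders, not OpenFold defaults.

# Sketch only: wiring a precomputed ESM-1b embedding into PreembeddingEmbedder.
import torch
from openfold.model.embedders import PreembeddingEmbedder

preemb = torch.load("embeddings/seq1/seq1.pt")["representations"][33]  # [n_res, 1280], hypothetical path
n_res = preemb.shape[0]

tf = torch.rand((1, n_res, 22))                 # placeholder target features (tf_dim = 22)
ri = torch.arange(n_res).float().unsqueeze(0)   # residue indices as a batch of one
pe = PreembeddingEmbedder(22, 1280, 4, 6, 10)   # tf_dim, preembedding_dim, c_z, c_m, relpos_k

seq_emb, pair_emb = pe(tf, ri, preemb.unsqueeze(0))
# seq_emb: [1, 1, n_res, 6]; pair_emb: [1, n_res, n_res, 4]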
tests/test_evoformer.py
...
@@ -66,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
             msa_dropout,
             pair_stack_dropout,
             blocks_per_ckpt=None,
+            no_column_attention=False,
             inf=inf,
             eps=eps,
         ).eval()
...
@@ -86,6 +87,62 @@ class TestEvoformerStack(unittest.TestCase):
         self.assertTrue(z.shape == shape_z_before)
         self.assertTrue(s.shape == (batch_size, n_res, c_s))

+    def test_shape_without_column_attention(self):
+        batch_size = consts.batch_size
+        n_seq = consts.n_seq
+        n_res = consts.n_res
+        c_m = consts.c_m
+        c_z = consts.c_z
+        c_hidden_msa_att = 12
+        c_hidden_opm = 17
+        c_hidden_mul = 19
+        c_hidden_pair_att = 14
+        c_s = consts.c_s
+        no_heads_msa = 3
+        no_heads_pair = 7
+        no_blocks = 2
+        transition_n = 2
+        msa_dropout = 0.15
+        pair_stack_dropout = 0.25
+        inf = 1e9
+        eps = 1e-10
+
+        es = EvoformerStack(
+            c_m,
+            c_z,
+            c_hidden_msa_att,
+            c_hidden_opm,
+            c_hidden_mul,
+            c_hidden_pair_att,
+            c_s,
+            no_heads_msa,
+            no_heads_pair,
+            no_blocks,
+            transition_n,
+            msa_dropout,
+            pair_stack_dropout,
+            blocks_per_ckpt=None,
+            no_column_attention=True,
+            inf=inf,
+            eps=eps,
+        ).eval()
+
+        m_init = torch.rand((batch_size, n_seq, n_res, c_m))
+        z_init = torch.rand((batch_size, n_res, n_res, c_z))
+        msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
+        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
+
+        shape_m_before = m_init.shape
+        shape_z_before = z_init.shape
+
+        m, z, s = es(
+            m_init,
+            z_init,
+            chunk_size=4,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+        )
+
+        self.assertTrue(m.shape == shape_m_before)
+        self.assertTrue(z.shape == shape_z_before)
+        self.assertTrue(s.shape == (batch_size, n_res, c_s))
+
     @compare_utils.skip_unless_alphafold_installed()
     def test_compare(self):
         def run_ei(activations, masks):
...
@@ -206,7 +263,7 @@ class TestExtraMSAStack(unittest.TestCase):
                 n_res,
             ),
             device="cuda",
-        )
+        ).float()
         pair_mask = torch.randint(
             0,
             2,
...
@@ -216,7 +273,7 @@ class TestExtraMSAStack(unittest.TestCase):
                 n_res,
             ),
             device="cuda",
-        )
+        ).float()
         shape_z_before = z.shape
...
tests/test_model.py
...
@@ -47,33 +47,73 @@ class TestModel(unittest.TestCase):
         c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
                                                         # deepspeed for this test

-        model = AlphaFold(c)
+        model = AlphaFold(c).cuda()
         model.eval()

         batch = {}
-        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
+        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
         batch["target_feat"] = nn.functional.one_hot(
             tf, c.model.input_embedder.tf_dim
-        ).float()
-        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
-        batch["residue_index"] = torch.arange(n_res)
-        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
+        ).float().cuda()
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
+        batch["residue_index"] = torch.arange(n_res).cuda()
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
         t_feats = random_template_feats(n_templ, n_res)
-        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
         extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
-        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
         batch["msa_mask"] = torch.randint(
             low=0, high=2, size=(n_seq, n_res)
-        ).float()
+        ).float().cuda()
         batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
         batch.update(data_transforms.make_atom14_masks(batch))
         batch["no_recycling_iters"] = torch.tensor(2.).cuda()

         add_recycling_dims = lambda t: (
             t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
         )
         batch = tensor_tree_map(add_recycling_dims, batch)

         with torch.no_grad():
             out = model(batch)

+    def test_dry_run_seqemb_mode(self):
+        n_seq = 1
+        n_templ = consts.n_templ
+        n_res = consts.n_res
+        msa_dim = 49
+
+        c = model_config("seq_model_esm1b")
+        c.model.evoformer_stack.no_blocks = 2
+        c.model.evoformer_stack.blocks_per_ckpt = None
+        model = AlphaFold(c)
+        model.to(torch.device('cuda'))
+        model.eval()
+
+        batch = {}
+        tf = torch.randint(c.model.preembedding_embedder.tf_dim - 1, size=(n_res,))
+        batch["target_feat"] = nn.functional.one_hot(
+            tf, c.model.preembedding_embedder.tf_dim
+        ).float()
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
+        batch["residue_index"] = torch.arange(n_res)
+        batch["msa_feat"] = torch.rand((n_seq, n_res, msa_dim))
+        batch["seq_embedding"] = torch.rand(
+            (n_res, c.model.preembedding_embedder.preembedding_dim)
+        )
+        t_feats = random_template_feats(n_templ, n_res)
+        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
+        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
+        batch.update(data_transforms.make_atom14_masks(batch))
+        batch["msa_mask"] = torch.randint(low=0, high=2, size=(n_seq, n_res)).float()
+        batch["no_recycling_iters"] = torch.tensor(2.)
+
+        add_recycling_dims = lambda t: (
+            t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
+        )
+        batch = tensor_tree_map(add_recycling_dims, batch)
+
+        to_cuda_device = lambda t: t.to(torch.device("cuda"))
+        batch = tensor_tree_map(to_cuda_device, batch)
+
+        with torch.no_grad():
+            out = model(batch)
...
train_openfold.py
...
@@ -416,6 +416,10 @@ if __name__ == "__main__":
         help='''Cutoff for all templates. In training mode, templates are also
             filtered by the release date of the target'''
     )
+    parser.add_argument(
+        "--use_single_seq_mode", type=str, default=False,
+        help="Use single sequence embeddings instead of MSAs."
+    )
     parser.add_argument(
         "--distillation_data_dir", type=str, default=None,
         help="Directory containing training PDB files"
...