Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
c1129bef
"lib/bindings/python/vscode:/vscode.git/clone" did not exist on "1c67be28d50f5e1a53e08e471bdabf72aea8f7b5"
Commit
c1129bef
authored
Jun 02, 2023
by
Christina Floristean
Browse files
Fixed bug in triangle multiplicative update and added early stop recycling.
parent
425bdb5e
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
173 additions
and
79 deletions
+173
-79
openfold/config.py
openfold/config.py
+34
-9
openfold/data/data_pipeline.py
openfold/data/data_pipeline.py
+18
-29
openfold/data/templates.py
openfold/data/templates.py
+2
-2
openfold/model/evoformer.py
openfold/model/evoformer.py
+5
-9
openfold/model/model.py
openfold/model/model.py
+47
-5
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+16
-8
scripts/data_dir_to_fasta.py
scripts/data_dir_to_fasta.py
+22
-2
scripts/flatten_roda.sh
scripts/flatten_roda.sh
+6
-1
scripts/generate_chain_data_cache.py
scripts/generate_chain_data_cache.py
+21
-13
scripts/precompute_alignments.py
scripts/precompute_alignments.py
+2
-1
No files found.
openfold/config.py
View file @
c1129bef
...
@@ -155,21 +155,38 @@ def model_config(
...
@@ -155,21 +155,38 @@ def model_config(
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
"multimer"
in
name
:
elif
"multimer"
in
name
:
c
.
globals
.
is_multimer
=
True
c
.
globals
.
is_multimer
=
True
c
.
globals
.
bfloat16
=
True
c
.
globals
.
bfloat16_output
=
False
c
.
loss
.
masked_msa
.
num_classes
=
22
c
.
loss
.
masked_msa
.
num_classes
=
22
c
.
data
.
common
.
max_recycling_iters
=
20
for
k
,
v
in
multimer_model_config_update
.
items
():
c
.
model
[
k
]
=
v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
c
.
model
.
evoformer
.
num_msa
=
252
#c.model.input_embedder.num_msa = 252
c
.
model
.
evoformer
.
num_extra_msa
=
1152
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
model
.
evoformer
.
fuse_projection_weights
=
False
c
.
data
.
train
.
max_msa_clusters
=
252
c
.
data
.
predict
.
max_msa_clusters
=
252
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
c
.
model
.
evoformer_stack
.
fuse_projection_weights
=
False
c
.
model
.
extra_msa
.
extra_msa_stack
.
fuse_projection_weights
=
False
c
.
model
.
extra_msa
.
extra_msa_stack
.
fuse_projection_weights
=
False
c
.
model
.
template
.
template_pair_stack
.
fuse_projection_weights
=
False
c
.
model
.
template
.
template_pair_stack
.
fuse_projection_weights
=
False
elif
name
==
'model_4_multimer_v3'
:
elif
name
==
'model_4_multimer_v3'
:
c
.
model
.
evoformer
.
num_extra_msa
=
1152
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
elif
name
==
'model_5_multimer_v3'
:
elif
name
==
'model_5_multimer_v3'
:
c
.
model
.
evoformer
.
num_extra_msa
=
1152
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_extra_msa
=
1152
for
k
,
v
in
multimer_model_config_update
.
items
():
c
.
data
.
predict
.
max_extra_msa
=
1152
c
.
model
[
k
]
=
v
else
:
c
.
data
.
train
.
max_msa_clusters
=
508
c
.
data
.
predict
.
max_msa_clusters
=
508
c
.
data
.
train
.
max_extra_msa
=
2048
c
.
data
.
predict
.
max_extra_msa
=
2048
c
.
data
.
common
.
unsupervised_features
.
extend
([
c
.
data
.
common
.
unsupervised_features
.
extend
([
"msa_mask"
,
"msa_mask"
,
...
@@ -646,6 +663,12 @@ config = mlc.ConfigDict(
...
@@ -646,6 +663,12 @@ config = mlc.ConfigDict(
"eps"
:
eps
,
"eps"
:
eps
,
},
},
"ema"
:
{
"decay"
:
0.999
},
"ema"
:
{
"decay"
:
0.999
},
# A negative value indicates that no early stopping will occur, i.e.
# the model will always run `max_recycling_iters` number of recycling
# iterations. A positive value will enable early stopping if the
# difference in pairwise distances is less than the tolerance between
# recycling steps.
"recycle_early_stop_tolerance"
:
-
1
}
}
)
)
...
@@ -653,6 +676,7 @@ multimer_model_config_update = {
...
@@ -653,6 +676,7 @@ multimer_model_config_update = {
"input_embedder"
:
{
"input_embedder"
:
{
"tf_dim"
:
21
,
"tf_dim"
:
21
,
"msa_dim"
:
49
,
"msa_dim"
:
49
,
#"num_msa": 508,
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"c_m"
:
c_m
,
"relpos_k"
:
32
,
"relpos_k"
:
32
,
...
@@ -703,6 +727,7 @@ multimer_model_config_update = {
...
@@ -703,6 +727,7 @@ multimer_model_config_update = {
"extra_msa_embedder"
:
{
"extra_msa_embedder"
:
{
"c_in"
:
25
,
"c_in"
:
25
,
"c_out"
:
c_e
,
"c_out"
:
c_e
,
#"num_extra_msa": 2048
},
},
"extra_msa_stack"
:
{
"extra_msa_stack"
:
{
"c_m"
:
c_e
,
"c_m"
:
c_e
,
...
@@ -788,5 +813,5 @@ multimer_model_config_update = {
...
@@ -788,5 +813,5 @@ multimer_model_config_update = {
"c_out"
:
37
,
"c_out"
:
37
,
},
},
},
},
"recycle_early_stop_tolerance"
:
0.5
}
}
openfold/data/data_pipeline.py
View file @
c1129bef
...
@@ -280,6 +280,7 @@ def run_msa_tool(
...
@@ -280,6 +280,7 @@ def run_msa_tool(
else
:
else
:
result
=
msa_runner
.
query
(
fasta_path
)[
0
]
result
=
msa_runner
.
query
(
fasta_path
)[
0
]
assert
msa_out_path
.
split
(
'.'
)[
-
1
]
==
msa_format
with
open
(
msa_out_path
,
"w"
)
as
f
:
with
open
(
msa_out_path
,
"w"
)
as
f
:
f
.
write
(
result
[
msa_format
])
f
.
write
(
result
[
msa_format
])
...
@@ -321,6 +322,7 @@ def make_sequence_features_with_custom_template(
...
@@ -321,6 +322,7 @@ def make_sequence_features_with_custom_template(
**
template_features
.
features
**
template_features
.
features
}
}
class
AlignmentRunner
:
class
AlignmentRunner
:
"""Runs alignment tools and saves the results"""
"""Runs alignment tools and saves the results"""
def
__init__
(
def
__init__
(
...
@@ -372,6 +374,8 @@ class AlignmentRunner:
...
@@ -372,6 +374,8 @@ class AlignmentRunner:
Max number of uniref hits
Max number of uniref hits
mgnify_max_hits:
mgnify_max_hits:
Max number of mgnify hits
Max number of mgnify hits
uniprot_max_hits:
Max number of uniprot hits
"""
"""
db_map
=
{
db_map
=
{
"jackhmmer"
:
{
"jackhmmer"
:
{
...
@@ -468,7 +472,7 @@ class AlignmentRunner:
...
@@ -468,7 +472,7 @@ class AlignmentRunner:
):
):
"""Runs alignment tools on a sequence"""
"""Runs alignment tools on a sequence"""
if
(
self
.
jackhmmer_uniref90_runner
is
not
None
):
if
(
self
.
jackhmmer_uniref90_runner
is
not
None
):
uniref90_out_path
=
os
.
path
.
join
(
output_dir
,
"uniref90_hits.
a3m
"
)
uniref90_out_path
=
os
.
path
.
join
(
output_dir
,
"uniref90_hits.
sto
"
)
jackhmmer_uniref90_result
=
run_msa_tool
(
jackhmmer_uniref90_result
=
run_msa_tool
(
msa_runner
=
self
.
jackhmmer_uniref90_runner
,
msa_runner
=
self
.
jackhmmer_uniref90_runner
,
...
@@ -505,7 +509,7 @@ class AlignmentRunner:
...
@@ -505,7 +509,7 @@ class AlignmentRunner:
)
)
if
(
self
.
jackhmmer_mgnify_runner
is
not
None
):
if
(
self
.
jackhmmer_mgnify_runner
is
not
None
):
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"mgnify_hits.
a3m
"
)
mgnify_out_path
=
os
.
path
.
join
(
output_dir
,
"mgnify_hits.
sto
"
)
jackhmmer_mgnify_result
=
run_msa_tool
(
jackhmmer_mgnify_result
=
run_msa_tool
(
msa_runner
=
self
.
jackhmmer_mgnify_runner
,
msa_runner
=
self
.
jackhmmer_mgnify_runner
,
fasta_path
=
fasta_path
,
fasta_path
=
fasta_path
,
...
@@ -719,16 +723,14 @@ class DataPipeline:
...
@@ -719,16 +723,14 @@ class DataPipeline:
msa
=
parsers
.
parse_a3m
(
msa
=
parsers
.
parse_a3m
(
read_msa
(
start
,
size
)
read_msa
(
start
,
size
)
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
msa
.
deletion_matrix
}
# The "hmm_output" exception is a crude way to exclude
# The "hmm_output" exception is a crude way to exclude
# multimer template hits.
# multimer template hits.
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
msa
=
parsers
.
parse_stockholm
(
read_msa
(
start
,
size
))
msa
=
parsers
.
parse_stockholm
(
read_msa
(
start
,
size
))
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
msa
.
deletion_matrix
}
else
:
else
:
continue
continue
msa_data
[
name
]
=
dat
a
msa_data
[
name
]
=
ms
a
fp
.
close
()
fp
.
close
()
else
:
else
:
...
@@ -739,17 +741,15 @@ class DataPipeline:
...
@@ -739,17 +741,15 @@ class DataPipeline:
if
(
ext
==
".a3m"
):
if
(
ext
==
".a3m"
):
with
open
(
path
,
"r"
)
as
fp
:
with
open
(
path
,
"r"
)
as
fp
:
msa
=
parsers
.
parse_a3m
(
fp
.
read
())
msa
=
parsers
.
parse_a3m
(
fp
.
read
())
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
msa
.
deletion_matrix
}
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
elif
(
ext
==
".sto"
and
not
"hmm_output"
==
filename
):
with
open
(
path
,
"r"
)
as
fp
:
with
open
(
path
,
"r"
)
as
fp
:
msa
=
parsers
.
parse_stockholm
(
msa
=
parsers
.
parse_stockholm
(
fp
.
read
()
fp
.
read
()
)
)
data
=
{
"msa"
:
msa
,
"deletion_matrix"
:
msa
.
deletion_matrix
}
else
:
else
:
continue
continue
msa_data
[
f
]
=
dat
a
msa_data
[
f
]
=
ms
a
return
msa_data
return
msa_data
...
@@ -831,8 +831,6 @@ class DataPipeline:
...
@@ -831,8 +831,6 @@ class DataPipeline:
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
hits
=
parsers
.
parse_hhr
(
fp
.
read
())
all_hits
[
f
]
=
hits
all_hits
[
f
]
=
hits
return
def
_get_msas
(
self
,
def
_get_msas
(
self
,
alignment_dir
:
str
,
alignment_dir
:
str
,
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
...
@@ -849,16 +847,11 @@ class DataPipeline:
...
@@ -849,16 +847,11 @@ class DataPipeline:
)
)
deletion_matrix
=
[[
0
for
_
in
input_sequence
]]
deletion_matrix
=
[[
0
for
_
in
input_sequence
]]
msa_data
[
"dummy"
]
=
{
msa_data
[
"dummy"
]
=
parsers
.
Msa
(
sequences
=
input_sequence
,
"msa"
:
parsers
.
Msa
(
sequences
=
input_sequence
,
deletion_matrix
=
deletion_matrix
,
descriptions
=
None
),
deletion_matrix
=
deletion_matrix
,
"deletion_matrix"
:
deletion_matrix
,
descriptions
=
None
)
}
msas
,
deletion_matrices
=
zip
(
*
[
(
v
[
"msa"
],
v
[
"deletion_matrix"
])
for
v
in
msa_data
.
values
()
])
return
msas
,
deletion_matrices
return
list
(
msa_data
.
values
())
def
_process_msa_feats
(
def
_process_msa_feats
(
self
,
self
,
...
@@ -866,7 +859,7 @@ class DataPipeline:
...
@@ -866,7 +859,7 @@ class DataPipeline:
input_sequence
:
Optional
[
str
]
=
None
,
input_sequence
:
Optional
[
str
]
=
None
,
alignment_index
:
Optional
[
str
]
=
None
alignment_index
:
Optional
[
str
]
=
None
)
->
Mapping
[
str
,
Any
]:
)
->
Mapping
[
str
,
Any
]:
msas
,
deletion_matrices
=
self
.
_get_msas
(
msas
=
self
.
_get_msas
(
alignment_dir
,
input_sequence
,
alignment_index
alignment_dir
,
input_sequence
,
alignment_index
)
)
msa_features
=
make_msa_features
(
msa_features
=
make_msa_features
(
...
@@ -944,7 +937,6 @@ class DataPipeline:
...
@@ -944,7 +937,6 @@ class DataPipeline:
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
input_sequence
=
mmcif
.
chain_to_seqres
[
chain_id
]
hits
=
self
.
_parse_template_hits
(
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_dir
,
input_sequence
,
alignment_index
)
alignment_index
)
template_features
=
make_template_features
(
template_features
=
make_template_features
(
...
@@ -994,7 +986,6 @@ class DataPipeline:
...
@@ -994,7 +986,6 @@ class DataPipeline:
hits
=
self
.
_parse_template_hits
(
hits
=
self
.
_parse_template_hits
(
alignment_dir
,
alignment_dir
,
input_sequence
,
alignment_index
alignment_index
)
)
...
@@ -1080,11 +1071,11 @@ class DataPipeline:
...
@@ -1080,11 +1071,11 @@ class DataPipeline:
alignment_dir
=
os
.
path
.
join
(
alignment_dir
=
os
.
path
.
join
(
super_alignment_dir
,
desc
super_alignment_dir
,
desc
)
)
msas
,
deletion_mats
=
self
.
_get_msas
(
msas
=
self
.
_get_msas
(
alignment_dir
,
seq
,
None
alignment_dir
,
seq
,
None
)
)
msa_list
.
append
(
msas
)
msa_list
.
append
(
[
m
.
sequences
for
m
in
msas
]
)
deletion_mat_list
.
append
(
deletion_mat
s
)
deletion_mat_list
.
append
(
[
m
.
deletion_mat
rix
for
m
in
msas
]
)
final_msa
=
[]
final_msa
=
[]
final_deletion_mat
=
[]
final_deletion_mat
=
[]
...
@@ -1181,12 +1172,10 @@ class DataPipelineMultimer:
...
@@ -1181,12 +1172,10 @@ class DataPipelineMultimer:
def
_all_seq_msa_features
(
self
,
fasta_path
,
alignment_dir
):
def
_all_seq_msa_features
(
self
,
fasta_path
,
alignment_dir
):
"""Get MSA features for unclustered uniprot, for pairing."""
"""Get MSA features for unclustered uniprot, for pairing."""
#TODO: Quick fix, change back to .sto after parsing fixed
uniprot_msa_path
=
os
.
path
.
join
(
alignment_dir
,
"uniprot_hits.sto"
)
uniprot_msa_path
=
os
.
path
.
join
(
alignment_dir
,
"uniprot_hits.a3m"
)
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
with
open
(
uniprot_msa_path
,
"r"
)
as
fp
:
uniprot_msa_string
=
fp
.
read
()
uniprot_msa_string
=
fp
.
read
()
msa
=
parsers
.
parse_a3m
(
uniprot_msa_string
)
msa
=
parsers
.
parse_stockholm
(
uniprot_msa_string
)
#msa = parsers.parse_stockholm(uniprot_msa_string)
all_seq_features
=
make_msa_features
([
msa
])
all_seq_features
=
make_msa_features
([
msa
])
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
valid_feats
=
msa_pairing
.
MSA_FEATURES
+
(
'msa_species_identifiers'
,
'msa_species_identifiers'
,
...
...
openfold/data/templates.py
View file @
c1129bef
...
@@ -902,7 +902,7 @@ def _process_single_hit(
...
@@ -902,7 +902,7 @@ def _process_single_hit(
%
(
%
(
hit_pdb_code
,
hit_pdb_code
,
hit_chain_id
,
hit_chain_id
,
hit
.
sum_probs
,
hit
.
sum_probs
if
hit
.
sum_probs
else
0.
,
hit
.
index
,
hit
.
index
,
str
(
e
),
str
(
e
),
parsing_result
.
errors
,
parsing_result
.
errors
,
...
@@ -919,7 +919,7 @@ def _process_single_hit(
...
@@ -919,7 +919,7 @@ def _process_single_hit(
%
(
%
(
hit_pdb_code
,
hit_pdb_code
,
hit_chain_id
,
hit_chain_id
,
hit
.
sum_probs
,
hit
.
sum_probs
if
hit
.
sum_probs
else
0.
,
hit
.
index
,
hit
.
index
,
str
(
e
),
str
(
e
),
parsing_result
.
errors
,
parsing_result
.
errors
,
...
...
openfold/model/evoformer.py
View file @
c1129bef
...
@@ -525,16 +525,14 @@ class EvoformerBlock(MSABlock):
...
@@ -525,16 +525,14 @@ class EvoformerBlock(MSABlock):
_attn_chunk_size
=
_attn_chunk_size
_attn_chunk_size
=
_attn_chunk_size
)
)
m
=
input_tensors
[
0
]
if
(
_offload_inference
and
inplace_safe
):
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: GPU
# m: GPU, z: GPU
device
=
z
.
device
device
=
z
.
device
del
m
,
z
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
device
)
m
,
_
=
input_tensors
m
,
z
=
input_tensors
else
:
m
=
input_tensors
[
0
]
return
m
,
z
return
m
,
z
...
@@ -713,12 +711,10 @@ class ExtraMSABlock(MSABlock):
...
@@ -713,12 +711,10 @@ class ExtraMSABlock(MSABlock):
if
(
_offload_inference
and
inplace_safe
):
if
(
_offload_inference
and
inplace_safe
):
# m: GPU, z: GPU
# m: GPU, z: GPU
device
=
z
.
device
device
=
z
.
device
del
m
,
z
del
m
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
assert
(
sys
.
getrefcount
(
input_tensors
[
0
])
==
2
)
assert
(
sys
.
getrefcount
(
input_tensors
[
1
])
==
2
)
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
device
)
m
,
_
=
input_tensors
m
,
z
=
input_tensors
return
m
,
z
return
m
,
z
...
...
openfold/model/model.py
View file @
c1129bef
...
@@ -25,6 +25,7 @@ from openfold.utils.feats import (
...
@@ -25,6 +25,7 @@ from openfold.utils.feats import (
dgram_from_positions
,
dgram_from_positions
,
atom14_to_atom37
,
atom14_to_atom37
,
)
)
from
openfold.utils.tensor_utils
import
masked_mean
from
openfold.model.embedders
import
(
from
openfold.model.embedders
import
(
InputEmbedder
,
InputEmbedder
,
InputEmbedderMultimer
,
InputEmbedderMultimer
,
...
@@ -165,6 +166,38 @@ class AlphaFold(nn.Module):
...
@@ -165,6 +166,38 @@ class AlphaFold(nn.Module):
return
template_embeds
return
template_embeds
def tolerance_reached(self, prev_pos, next_pos, mask, no_batch_dims, eps=1e-8) -> bool:
    """
    Early stopping criteria based on criteria used in
    AF2Complex: https://www.nature.com/articles/s41467-022-29394-2

    Compares the pairwise CA-CA distance maps of two consecutive recycling
    iterations and reports whether their masked root-mean-square difference
    is within `self.config.recycle_early_stop_tolerance`.

    Args:
        prev_pos: Previous atom positions in atom37/14 representation
        next_pos: Current atom positions in atom37/14 representation
        mask: 1-D sequence mask
        no_batch_dims: Number of leading batch dimensions of the inputs.
            When it is 0, a leading dimension of size 1 is added to
            `prev_pos`, `next_pos` and `mask` before the comparison.
        eps: Epsilon used in square root calculation
    Returns:
        Whether to stop recycling early based on the desired tolerance.
        Always False when the configured tolerance is negative.
    """

    def distances(points):
        """Compute all pairwise distances for a set of points."""
        d = points[..., None, :] - points[..., None, :, :]
        return torch.sqrt(torch.sum(d ** 2, dim=-1))

    tolerance = self.config.recycle_early_stop_tolerance

    # A negative tolerance disables early stopping altogether.
    if tolerance < 0:
        return False

    if no_batch_dims == 0:
        prev_pos = prev_pos.unsqueeze(dim=0)
        next_pos = next_pos.unsqueeze(dim=0)
        mask = mask.unsqueeze(dim=0)

    ca_idx = residue_constants.atom_order['CA']
    sq_diff = (
        distances(prev_pos[..., ca_idx, :])
        - distances(next_pos[..., ca_idx, :])
    ) ** 2

    # Pairwise mask: an entry counts only if both residues are unmasked.
    pair_mask = mask[..., None] * mask[..., None, :]

    # Reduce over every dimension so that a single scalar remains.
    sq_diff = masked_mean(
        mask=pair_mask,
        value=sq_diff,
        dim=list(range(len(pair_mask.shape))),
    )

    # Convert the scalar tensor to a Python float so that the comparison
    # yields a real `bool`, as the return annotation promises, instead of a
    # 0-dim tensor that would leak into boolean expressions at the call site.
    diff = torch.sqrt(sq_diff + eps).item()
    return diff <= tolerance
def
iteration
(
self
,
feats
,
prevs
,
_recycle
=
True
):
def
iteration
(
self
,
feats
,
prevs
,
_recycle
=
True
):
# Primary output dictionary
# Primary output dictionary
outputs
=
{}
outputs
=
{}
...
@@ -263,7 +296,7 @@ class AlphaFold(nn.Module):
...
@@ -263,7 +296,7 @@ class AlphaFold(nn.Module):
# Deletions like these become significant for inference with large N,
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# where they free unused tensors and remove references to others such
# that they can be offloaded later
# that they can be offloaded later
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
del
m_1_prev
,
z_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
# Embed the templates + merge with MSA/pair embeddings
if
self
.
config
.
template
.
enabled
:
if
self
.
config
.
template
.
enabled
:
...
@@ -406,10 +439,16 @@ class AlphaFold(nn.Module):
...
@@ -406,10 +439,16 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
# [*, N, N, C_z]
z_prev
=
outputs
[
"pair"
]
z_prev
=
outputs
[
"pair"
]
early_stop
=
False
if
self
.
globals
.
is_multimer
:
early_stop
=
self
.
tolerance_reached
(
x_prev
,
outputs
[
"final_atom_positions"
],
seq_mask
,
no_batch_dims
)
del
x_prev
# [*, N, 3]
# [*, N, 3]
x_prev
=
outputs
[
"final_atom_positions"
]
x_prev
=
outputs
[
"final_atom_positions"
]
return
outputs
,
m_1_prev
,
z_prev
,
x_prev
return
outputs
,
m_1_prev
,
z_prev
,
x_prev
,
early_stop
def
_disable_activation_checkpointing
(
self
):
def
_disable_activation_checkpointing
(
self
):
self
.
template_embedder
.
template_pair_stack
.
blocks_per_ckpt
=
None
self
.
template_embedder
.
template_pair_stack
.
blocks_per_ckpt
=
None
...
@@ -488,13 +527,14 @@ class AlphaFold(nn.Module):
...
@@ -488,13 +527,14 @@ class AlphaFold(nn.Module):
# Main recycling loop
# Main recycling loop
num_iters
=
batch
[
"aatype"
].
shape
[
-
1
]
num_iters
=
batch
[
"aatype"
].
shape
[
-
1
]
early_stop
=
False
for
cycle_no
in
range
(
num_iters
):
for
cycle_no
in
range
(
num_iters
):
# Select the features for the current recycling cycle
# Select the features for the current recycling cycle
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
fetch_cur_batch
=
lambda
t
:
t
[...,
cycle_no
]
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
feats
=
tensor_tree_map
(
fetch_cur_batch
,
batch
)
# Enable grad iff we're training and it's the final recycling layer
# Enable grad iff we're training and it's the final recycling layer
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
is_final_iter
=
cycle_no
==
(
num_iters
-
1
)
or
early_stop
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
with
torch
.
set_grad_enabled
(
is_grad_enabled
and
is_final_iter
):
if
is_final_iter
:
if
is_final_iter
:
# Sidestep AMP bug (PyTorch issue #65766)
# Sidestep AMP bug (PyTorch issue #65766)
...
@@ -502,16 +542,18 @@ class AlphaFold(nn.Module):
...
@@ -502,16 +542,18 @@ class AlphaFold(nn.Module):
torch
.
clear_autocast_cache
()
torch
.
clear_autocast_cache
()
# Run the next iteration of the model
# Run the next iteration of the model
outputs
,
m_1_prev
,
z_prev
,
x_prev
=
self
.
iteration
(
outputs
,
m_1_prev
,
z_prev
,
x_prev
,
early_stop
=
self
.
iteration
(
feats
,
feats
,
prevs
,
prevs
,
_recycle
=
(
num_iters
>
1
)
_recycle
=
(
num_iters
>
1
)
)
)
if
(
not
is_final_iter
)
:
if
not
is_final_iter
:
del
outputs
del
outputs
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
prevs
=
[
m_1_prev
,
z_prev
,
x_prev
]
del
m_1_prev
,
z_prev
,
x_prev
del
m_1_prev
,
z_prev
,
x_prev
else
:
break
# Run auxiliary heads
# Run auxiliary heads
outputs
.
update
(
self
.
aux_heads
(
outputs
))
outputs
.
update
(
self
.
aux_heads
(
outputs
))
...
...
openfold/model/triangular_multiplicative_update.py
View file @
c1129bef
...
@@ -509,7 +509,6 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
...
@@ -509,7 +509,6 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
mask
=
mask
.
unsqueeze
(
-
1
)
mask
=
mask
.
unsqueeze
(
-
1
)
def
compute_projection_helper
(
pair
,
mask
):
def
compute_projection_helper
(
pair
,
mask
):
pair
=
self
.
layer_norm_in
(
pair
)
p
=
self
.
linear_ab_g
(
pair
)
p
=
self
.
linear_ab_g
(
pair
)
p
.
sigmoid_
()
p
.
sigmoid_
()
p
*=
self
.
linear_ab_p
(
pair
)
p
*=
self
.
linear_ab_p
(
pair
)
...
@@ -519,16 +518,21 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
...
@@ -519,16 +518,21 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
def
compute_projection
(
pair
,
mask
):
def
compute_projection
(
pair
,
mask
):
p
=
compute_projection_helper
(
pair
,
mask
)
p
=
compute_projection_helper
(
pair
,
mask
)
a
=
p
[...,
:
self
.
c_hidden
]
if
self
.
_outgoing
:
b
=
p
[...,
self
.
c_hidden
:]
left
=
p
[...,
:
self
.
c_hidden
]
right
=
p
[...,
self
.
c_hidden
:]
else
:
left
=
p
[...,
self
.
c_hidden
:]
right
=
p
[...,
:
self
.
c_hidden
]
return
a
,
b
return
left
,
right
a
,
b
=
compute_projection
(
z
,
mask
)
z_norm_in
=
self
.
layer_norm_in
(
z
)
a
,
b
=
compute_projection
(
z_norm_in
,
mask
)
x
=
self
.
_combine_projections
(
a
,
b
,
_inplace_chunk_size
=
_inplace_chunk_size
)
x
=
self
.
_combine_projections
(
a
,
b
,
_inplace_chunk_size
=
_inplace_chunk_size
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
layer_norm_out
(
x
)
x
=
self
.
linear_z
(
x
)
x
=
self
.
linear_z
(
x
)
g
=
self
.
linear_g
(
z
)
g
=
self
.
linear_g
(
z
_norm_in
)
g
.
sigmoid_
()
g
.
sigmoid_
()
x
*=
g
x
*=
g
if
(
with_add
):
if
(
with_add
):
...
@@ -573,8 +577,12 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
...
@@ -573,8 +577,12 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
ab
=
ab
*
self
.
sigmoid
(
self
.
linear_ab_g
(
z
))
ab
=
ab
*
self
.
sigmoid
(
self
.
linear_ab_g
(
z
))
ab
=
ab
*
self
.
linear_ab_p
(
z
)
ab
=
ab
*
self
.
linear_ab_p
(
z
)
a
=
ab
[...,
:
self
.
c_hidden
]
if
self
.
_outgoing
:
b
=
ab
[...,
self
.
c_hidden
:]
a
=
ab
[...,
:
self
.
c_hidden
]
b
=
ab
[...,
self
.
c_hidden
:]
else
:
b
=
ab
[...,
:
self
.
c_hidden
]
a
=
ab
[...,
self
.
c_hidden
:]
# Prevents overflow of torch.matmul in combine projections in
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
# reduced-precision modes
...
...
scripts/data_dir_to_fasta.py
View file @
c1129bef
import
argparse
import
argparse
import
logging
import
logging
import
os
import
os
import
string
from
collections
import
defaultdict
from
openfold.data
import
mmcif_parsing
from
openfold.data
import
mmcif_parsing
from
openfold.np
import
protein
,
residue_constants
from
openfold.np
import
protein
,
residue_constants
...
@@ -22,7 +23,7 @@ def main(args):
...
@@ -22,7 +23,7 @@ def main(args):
if
(
mmcif
.
mmcif_object
is
None
):
if
(
mmcif
.
mmcif_object
is
None
):
logging
.
warning
(
f
'Failed to parse
{
fname
}
...'
)
logging
.
warning
(
f
'Failed to parse
{
fname
}
...'
)
if
(
args
.
raise_errors
):
if
(
args
.
raise_errors
):
raise
list
(
mmcif
.
errors
.
values
())[
0
]
raise
Exception
(
list
(
mmcif
.
errors
.
values
())[
0
]
)
else
:
else
:
continue
continue
...
@@ -31,6 +32,25 @@ def main(args):
...
@@ -31,6 +32,25 @@ def main(args):
chain_id
=
'_'
.
join
([
basename
,
chain
])
chain_id
=
'_'
.
join
([
basename
,
chain
])
fasta
.
append
(
f
">
{
chain_id
}
"
)
fasta
.
append
(
f
">
{
chain_id
}
"
)
fasta
.
append
(
seq
)
fasta
.
append
(
seq
)
elif
(
ext
==
".pdb"
):
with
open
(
fpath
,
'r'
)
as
fp
:
pdb_str
=
fp
.
read
()
protein_object
=
protein
.
from_pdb_string
(
pdb_str
)
aatype
=
protein_object
.
aatype
chain_index
=
protein_object
.
chain_index
last_chain_index
=
chain_index
[
0
]
chain_dict
=
defaultdict
(
list
)
for
i
in
range
(
aatype
.
shape
[
0
]):
chain_dict
[
chain_index
[
i
]].
append
(
residue_constants
.
restypes_with_x
[
aatype
[
i
]])
chain_tags
=
string
.
ascii_uppercase
for
chain
,
seq
in
chain_dict
.
items
():
chain_id
=
'_'
.
join
([
basename
,
chain_tags
[
chain
]])
fasta
.
append
(
f
">
{
chain_id
}
"
)
fasta
.
append
(
''
.
join
(
seq
))
elif
(
ext
==
".core"
):
elif
(
ext
==
".core"
):
with
open
(
fpath
,
'r'
)
as
fp
:
with
open
(
fpath
,
'r'
)
as
fp
:
core_str
=
fp
.
read
()
core_str
=
fp
.
read
()
...
...
scripts/flatten_roda.sh
View file @
c1129bef
...
@@ -26,7 +26,12 @@ mkdir -p "${ALIGNMENT_DIR}"
...
@@ -26,7 +26,12 @@ mkdir -p "${ALIGNMENT_DIR}"
for
chain_dir
in
$(
ls
"
${
RODA_DIR
}
"
)
;
do
for
chain_dir
in
$(
ls
"
${
RODA_DIR
}
"
)
;
do
CHAIN_DIR_PATH
=
"
${
RODA_DIR
}
/
${
chain_dir
}
"
CHAIN_DIR_PATH
=
"
${
RODA_DIR
}
/
${
chain_dir
}
"
for
subdir
in
$(
ls
"
${
CHAIN_DIR_PATH
}
"
)
;
do
for
subdir
in
$(
ls
"
${
CHAIN_DIR_PATH
}
"
)
;
do
if
[[
$subdir
=
"pdb"
]]
||
[[
$subdir
=
"cif"
]]
;
then
if
[[
!
-d
"
$subdir
"
]]
;
then
echo
"
$subdir
is not directory"
continue
elif
[[
-z
$(
ls
"
${
subdir
}
"
)
]]
;
then
continue
elif
[[
$subdir
=
"pdb"
]]
||
[[
$subdir
=
"cif"
]]
;
then
mv
"
${
CHAIN_DIR_PATH
}
/
${
subdir
}
"
/
*
"
${
DATA_DIR
}
"
mv
"
${
CHAIN_DIR_PATH
}
/
${
subdir
}
"
/
*
"
${
DATA_DIR
}
"
else
else
CHAIN_ALIGNMENT_DIR
=
"
${
ALIGNMENT_DIR
}
/
${
chain_dir
}
"
CHAIN_ALIGNMENT_DIR
=
"
${
ALIGNMENT_DIR
}
/
${
chain_dir
}
"
...
...
scripts/generate_chain_data_cache.py
View file @
c1129bef
...
@@ -4,10 +4,11 @@ import json
...
@@ -4,10 +4,11 @@ import json
import
logging
import
logging
from
multiprocessing
import
Pool
from
multiprocessing
import
Pool
import
os
import
os
import
string
import
sys
import
sys
sys
.
path
.
append
(
"."
)
# an innocent hack to get this to run from the top level
sys
.
path
.
append
(
"."
)
# an innocent hack to get this to run from the top level
from
collections
import
defaultdict
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
openfold.data.mmcif_parsing
import
parse
from
openfold.data.mmcif_parsing
import
parse
...
@@ -49,20 +50,27 @@ def parse_file(
...
@@ -49,20 +50,27 @@ def parse_file(
pdb_string
=
fp
.
read
()
pdb_string
=
fp
.
read
()
protein_object
=
protein
.
from_pdb_string
(
pdb_string
,
None
)
protein_object
=
protein
.
from_pdb_string
(
pdb_string
,
None
)
aatype
=
protein_object
.
aatype
chain_index
=
protein_object
.
chain_index
chain_dict
=
{}
chain_dict
=
defaultdict
(
list
)
chain_dict
[
"seq"
]
=
residue_constants
.
aatype_to_str_sequence
(
for
i
in
range
(
aatype
.
shape
[
0
]):
protein_object
.
aatype
,
chain_dict
[
chain_index
[
i
]].
append
(
residue_constants
.
restypes_with_x
[
aatype
[
i
]])
)
chain_dict
[
"resolution"
]
=
0.
if
(
chain_cluster_size_dict
is
not
None
):
cluster_size
=
chain_cluster_size_dict
.
get
(
full_name
.
upper
(),
-
1
)
chain_dict
[
"cluster_size"
]
=
cluster_size
out
=
{
file_id
:
chain_dict
}
out
=
{}
chain_tags
=
string
.
ascii_uppercase
for
chain
,
seq
in
chain_dict
.
items
():
full_name
=
"_"
.
join
([
file_id
,
chain_tags
[
chain
]])
out
[
full_name
]
=
{}
local_data
=
out
[
full_name
]
local_data
[
"resolution"
]
=
0.
local_data
[
"seq"
]
=
''
.
join
(
seq
)
if
(
chain_cluster_size_dict
is
not
None
):
cluster_size
=
chain_cluster_size_dict
.
get
(
full_name
.
upper
(),
-
1
)
local_data
[
"cluster_size"
]
=
cluster_size
return
out
return
out
...
...
scripts/precompute_alignments.py
View file @
c1129bef
...
@@ -40,7 +40,8 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
...
@@ -40,7 +40,8 @@ def run_seq_group_alignments(seq_groups, alignment_runner, args):
alignment_runner
.
run
(
alignment_runner
.
run
(
fasta_path
,
alignment_dir
fasta_path
,
alignment_dir
)
)
except
:
except
Exception
as
e
:
logging
.
warning
(
e
)
logging
.
warning
(
f
"Failed to run alignments for
{
first_name
}
. Skipping..."
)
logging
.
warning
(
f
"Failed to run alignments for
{
first_name
}
. Skipping..."
)
os
.
remove
(
fasta_path
)
os
.
remove
(
fasta_path
)
os
.
rmdir
(
alignment_dir
)
os
.
rmdir
(
alignment_dir
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment