Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
bb3f51e5
Unverified
Commit
bb3f51e5
authored
Feb 07, 2024
by
Christina Floristean
Committed by
GitHub
Feb 07, 2024
Browse files
Merge pull request #405 from aqlaboratory/multimer
Full multimer merge
parents
ce211367
c33a0bd6
Changes
106
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
274 additions
and
86 deletions
+274
-86
tests/test_primitives.py
tests/test_primitives.py
+4
-4
tests/test_structure_module.py
tests/test_structure_module.py
+81
-25
tests/test_template.py
tests/test_template.py
+123
-26
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+4
-4
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+19
-12
train_openfold.py
train_openfold.py
+43
-15
No files found.
tests/test_primitives.py
View file @
bb3f51e5
...
...
@@ -30,7 +30,7 @@ class TestLMA(unittest.TestCase):
q
,
kv
,
_
,
biases
=
random_attention_inputs
(
batch_size
=
consts
.
batch_size
,
n_seq
=
consts
.
n_seq
,
n
=
2
**
12
,
n
=
2
**
12
,
no_heads
=
no_heads
,
c_hidden
=
c_hidden
)
...
...
tests/test_structure_module.py
View file @
bb3f51e5
...
...
@@ -18,21 +18,19 @@ import unittest
from
openfold.data.data_transforms
import
make_atom14_masks_np
from
openfold.np.residue_constants
import
(
restype_rigid_group_default_frame
,
restype_atom14_to_rigid_group
,
restype_atom14_mask
,
restype_atom14_rigid_group_positions
,
restype_atom37_mask
,
)
from
openfold.model.structure_module
import
(
StructureModule
,
StructureModuleTransition
,
BackboneUpdate
,
AngleResnet
,
InvariantPointAttention
,
)
import
openfold.utils.feats
as
feats
from
openfold.utils.rigid_utils
import
Rotation
,
Rigid
from
openfold.utils.geometry.rigid_matrix_vector
import
Rigid3Array
from
openfold.utils.geometry.rotation_matrix
import
Rot3Array
from
openfold.utils.geometry.vector
import
Vec3Array
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
(
...
...
@@ -46,6 +44,20 @@ if compare_utils.alphafold_is_installed():
class
TestStructureModule
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_structure_module_shape
(
self
):
batch_size
=
consts
.
batch_size
n
=
consts
.
n_res
...
...
@@ -81,6 +93,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor
,
ar_epsilon
,
inf
,
is_multimer
=
consts
.
is_multimer
)
s
=
torch
.
rand
((
batch_size
,
n
,
c_s
))
...
...
@@ -89,7 +102,11 @@ class TestStructureModule(unittest.TestCase):
out
=
sm
({
"single"
:
s
,
"pair"
:
z
},
f
)
if
consts
.
is_multimer
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
4
,
4
))
else
:
self
.
assertTrue
(
out
[
"frames"
].
shape
==
(
no_layers
,
batch_size
,
n
,
7
))
self
.
assertTrue
(
out
[
"angles"
].
shape
==
(
no_layers
,
batch_size
,
n
,
no_angles
,
2
)
)
...
...
@@ -121,11 +138,14 @@ class TestStructureModule(unittest.TestCase):
c_global
=
config
.
model
.
global_config
def
run_sm
(
representations
,
batch
):
sm
=
alphafold
.
model
.
fold
ing
.
StructureModule
(
c_sm
,
c_global
)
sm
=
self
.
am_
fold
.
StructureModule
(
c_sm
,
c_global
)
representations
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
representations
.
items
()
}
batch
=
{
k
:
jax
.
lax
.
stop_gradient
(
v
)
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
return
sm
(
representations
,
batch
,
is_training
=
False
,
compute_loss
=
True
)
return
sm
(
representations
,
batch
,
is_training
=
False
)
f
=
hk
.
transform
(
run_sm
)
...
...
@@ -177,10 +197,24 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.05
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
0.05
)
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
c_m
=
13
c_z
=
17
...
...
@@ -197,13 +231,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask
=
torch
.
ones
((
batch_size
,
n_res
))
rot_mats
=
torch
.
rand
((
batch_size
,
n_res
,
3
,
3
))
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
trans
=
torch
.
rand
((
batch_size
,
n_res
,
3
))
if
consts
.
is_multimer
:
rotation
=
Rot3Array
.
from_array
(
rot_mats
)
translation
=
Vec3Array
.
from_array
(
trans
)
r
=
Rigid3Array
(
rotation
,
translation
)
else
:
rots
=
Rotation
(
rot_mats
=
rot_mats
,
quats
=
None
)
r
=
Rigid
(
rots
,
trans
)
ipa
=
InvariantPointAttention
(
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
c_m
,
c_z
,
c_hidden
,
no_heads
,
no_qp
,
no_vp
,
is_multimer
=
consts
.
is_multimer
)
shape_before
=
s
.
shape
...
...
@@ -215,16 +254,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def
test_ipa_compare
(
self
):
def
run_ipa
(
act
,
static_feat_2d
,
mask
,
affine
):
config
=
compare_utils
.
get_alphafold_config
()
ipa
=
alphafold
.
model
.
fold
ing
.
InvariantPointAttention
(
ipa
=
self
.
am_
fold
.
InvariantPointAttention
(
config
.
model
.
heads
.
structure_module
,
config
.
model
.
global_config
,
)
if
consts
.
is_multimer
:
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
affine
=
affine
,
rigid
=
affine
)
else
:
attn
=
ipa
(
inputs_1d
=
act
,
inputs_2d
=
static_feat_2d
,
mask
=
mask
,
affine
=
affine
)
return
attn
f
=
hk
.
transform
(
run_ipa
)
...
...
@@ -238,12 +287,19 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask
=
np
.
ones
((
n_res
,
1
))
affines
=
random_affines_4x4
((
n_res
,))
rigids
=
alphafold
.
model
.
r3
.
rigids_from_tensor4x4
(
affines
)
quats
=
alphafold
.
model
.
r3
.
rigids_to_quataffine
(
rigids
)
if
consts
.
is_multimer
:
rigids
=
self
.
am_rigid
.
Rigid3Array
.
from_array4x4
(
affines
)
transformations
=
Rigid3Array
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
rigids
else
:
rigids
=
self
.
am_rigid
.
rigids_from_tensor4x4
(
affines
)
quats
=
self
.
am_rigid
.
rigids_to_quataffine
(
rigids
)
transformations
=
Rigid
.
from_tensor_4x4
(
torch
.
as_tensor
(
affines
).
float
().
cuda
()
)
sample_affine
=
quats
ipa_params
=
compare_utils
.
fetch_alphafold_module_weights
(
...
...
@@ -265,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch
.
as_tensor
(
sample_mask
.
squeeze
(
-
1
)).
float
().
cuda
(),
).
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
class
TestAngleResnet
(
unittest
.
TestCase
):
...
...
tests/test_template.py
View file @
bb3f51e5
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
re
import
torch
import
numpy
as
np
import
unittest
...
...
@@ -19,7 +20,6 @@ from openfold.model.template import (
TemplatePointwiseAttention
,
TemplatePairStack
,
)
from
openfold.utils.tensor_utils
import
tree_map
import
tests.compare_utils
as
compare_utils
from
tests.config
import
consts
from
tests.data_utils
import
random_template_feats
...
...
@@ -54,6 +54,20 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class
TestTemplatePairStack
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
def
test_shape
(
self
):
batch_size
=
consts
.
batch_size
c_t
=
consts
.
c_t
...
...
@@ -65,6 +79,8 @@ class TestTemplatePairStack(unittest.TestCase):
dropout
=
0.25
n_templ
=
consts
.
n_templ
n_res
=
consts
.
n_res
tri_mul_first
=
consts
.
is_multimer
fuse_projection_weights
=
True
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
)
else
False
blocks_per_ckpt
=
None
chunk_size
=
4
inf
=
1e7
...
...
@@ -78,6 +94,8 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads
=
no_heads
,
pair_transition_n
=
pt_inner_dim
,
dropout_rate
=
dropout
,
tri_mul_first
=
tri_mul_first
,
fuse_projection_weights
=
fuse_projection_weights
,
blocks_per_ckpt
=
None
,
inf
=
inf
,
eps
=
eps
,
...
...
@@ -96,7 +114,35 @@ class TestTemplatePairStack(unittest.TestCase):
def
run_template_pair_stack
(
pair_act
,
pair_mask
):
config
=
compare_utils
.
get_alphafold_config
()
c_ee
=
config
.
model
.
embeddings_and_evoformer
tps
=
alphafold
.
model
.
modules
.
TemplatePairStack
(
if
consts
.
is_multimer
:
safe_key
=
alphafold
.
model
.
prng
.
SafeKey
(
hk
.
next_rng_key
())
template_iteration
=
self
.
am_modules
.
TemplateEmbeddingIteration
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
'template_embedding_iteration'
)
def
template_iteration_fn
(
x
):
act
,
safe_key
=
x
safe_key
,
safe_subkey
=
safe_key
.
split
()
act
=
template_iteration
(
act
=
act
,
pair_mask
=
pair_mask
,
is_training
=
False
,
safe_key
=
safe_subkey
)
return
(
act
,
safe_key
)
if
config
.
model
.
global_config
.
use_remat
:
template_iteration_fn
=
hk
.
remat
(
template_iteration_fn
)
safe_key
,
safe_subkey
=
safe_key
.
split
()
template_stack
=
alphafold
.
model
.
layer_stack
.
layer_stack
(
c_ee
.
template
.
template_pair_stack
.
num_block
)(
template_iteration_fn
)
act
,
_
=
template_stack
((
pair_act
,
safe_subkey
))
else
:
tps
=
self
.
am_modules
.
TemplatePairStack
(
c_ee
.
template
.
template_pair_stack
,
config
.
model
.
global_config
,
name
=
"template_pair_stack"
,
...
...
@@ -115,6 +161,12 @@ class TestTemplatePairStack(unittest.TestCase):
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
)
).
astype
(
np
.
float32
)
if
consts
.
is_multimer
:
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_embedding_iteration"
)
else
:
params
=
compare_utils
.
fetch_alphafold_module_weights
(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+
"single_template_embedding/template_pair_stack"
...
...
@@ -132,25 +184,43 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
template_pair_stack
(
out_repro
=
model
.
template_
embedder
.
template_
pair_stack
(
torch
.
as_tensor
(
pair_act
).
unsqueeze
(
-
4
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
unsqueeze
(
-
3
).
cuda
(),
chunk_size
=
None
,
_mask_trans
=
False
,
).
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
class
Template
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
if
compare_utils
.
alphafold_is_installed
():
if
consts
.
is_multimer
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom_multimer
cls
.
am_fold
=
alphafold
.
model
.
folding_multimer
cls
.
am_modules
=
alphafold
.
model
.
modules_multimer
cls
.
am_rigid
=
alphafold
.
model
.
geometry
else
:
cls
.
am_atom
=
alphafold
.
model
.
all_atom
cls
.
am_fold
=
alphafold
.
model
.
folding
cls
.
am_modules
=
alphafold
.
model
.
modules
cls
.
am_rigid
=
alphafold
.
model
.
r3
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_compare
(
self
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
):
def
test_template_embedding
(
pair
,
batch
,
mask_2d
,
mc_
mask_2d
):
config
=
compare_utils
.
get_alphafold_config
()
te
=
alphafold
.
model
.
modules
.
TemplateEmbedding
(
te
=
self
.
am_
modules
.
TemplateEmbedding
(
config
.
model
.
embeddings_and_evoformer
.
template
,
config
.
model
.
global_config
,
)
if
consts
.
is_multimer
:
act
=
te
(
pair
,
batch
,
mask_2d
,
multichain_mask_2d
=
mc_mask_2d
,
is_training
=
False
)
else
:
act
=
te
(
pair
,
batch
,
mask_2d
,
is_training
=
False
)
return
act
...
...
@@ -162,6 +232,14 @@ class Template(unittest.TestCase):
pair_act
=
np
.
random
.
rand
(
n_res
,
n_res
,
consts
.
c_z
).
astype
(
np
.
float32
)
batch
=
random_template_feats
(
n_templ
,
n_res
)
batch
[
"template_all_atom_masks"
]
=
batch
[
"template_all_atom_mask"
]
multichain_mask_2d
=
None
if
consts
.
is_multimer
:
asym_id
=
batch
[
'asym_id'
][
0
]
multichain_mask_2d
=
(
asym_id
[...,
None
]
==
asym_id
[...,
None
,
:]
).
astype
(
np
.
float32
)
pair_mask
=
np
.
random
.
randint
(
0
,
2
,
(
n_res
,
n_res
)).
astype
(
np
.
float32
)
# Fetch pretrained parameters (but only from one block)]
params
=
compare_utils
.
fetch_alphafold_module_weights
(
...
...
@@ -169,7 +247,7 @@ class Template(unittest.TestCase):
)
out_gt
=
f
.
apply
(
params
,
jax
.
random
.
PRNGKey
(
42
),
pair_act
,
batch
,
pair_mask
params
,
jax
.
random
.
PRNGKey
(
42
),
pair_act
,
batch
,
pair_mask
,
multichain_mask_2d
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
...
...
@@ -177,17 +255,36 @@ class Template(unittest.TestCase):
batch
[
"target_feat"
]
=
np
.
eye
(
22
)[
inds
]
model
=
compare_utils
.
get_global_pretrained_openfold
()
out_repro
=
model
.
embed_templates
(
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()},
template_feats
=
{
k
:
torch
.
as_tensor
(
v
).
cuda
()
for
k
,
v
in
batch
.
items
()}
if
consts
.
is_multimer
:
out_repro_all
=
model
.
template_embedder
(
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
multichain_mask_2d
=
torch
.
as_tensor
(
multichain_mask_2d
).
cuda
(),
_mask_trans
=
False
,
use_lma
=
False
,
inplace_safe
=
False
)
else
:
out_repro_all
=
model
.
template_embedder
(
template_feats
,
torch
.
as_tensor
(
pair_act
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
mask_trans
=
False
,
use_lma
=
False
,
inplace_safe
=
False
)
out_repro
=
out_repro
[
"template_pair_embedding"
]
out_repro
=
out_repro_all
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
if
__name__
==
"__main__"
:
...
...
tests/test_triangular_attention.py
View file @
bb3f51e5
...
...
@@ -79,16 +79,16 @@ class TestTriangularAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+
name
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
Device
Array
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
pair_act
,
pair_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_start
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_start
if
starting
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_att_end
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_att_end
)
# To save memory, the full model transposes inputs outside of the
...
...
@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size
=
None
,
).
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_att_end_compare
(
self
):
...
...
tests/test_triangular_multiplicative_update.py
View file @
bb3f51e5
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
torch
import
re
import
numpy
as
np
import
unittest
from
openfold.model.triangular_multiplicative_update
import
*
...
...
@@ -31,6 +32,12 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z
=
consts
.
c_z
c
=
11
if
re
.
fullmatch
(
"^model_[1-5]_multimer_v3$"
,
consts
.
model
):
tm
=
FusedTriangleMultiplicationOutgoing
(
c_z
,
c
,
)
else
:
tm
=
TriangleMultiplicationOutgoing
(
c_z
,
c
,
...
...
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config
.
model
.
global_config
,
name
=
name
,
)
act
=
tri_mul
(
act
=
pair_act
,
mask
=
pair_mask
)
act
=
tri_mul
(
pair_act
,
pair_mask
)
return
act
f
=
hk
.
transform
(
run_tri_mul
)
...
...
@@ -78,24 +85,25 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+
name
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
numpy
.
Device
Array
)
params
=
tree_map
(
lambda
n
:
n
[
0
],
params
,
jax
.
Array
)
out_gt
=
f
.
apply
(
params
,
None
,
pair_act
,
pair_mask
).
block_until_ready
()
out_gt
=
torch
.
as_tensor
(
np
.
array
(
out_gt
))
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
out_repro
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
mask
=
torch
.
as_tensor
(
pair_mask
,
dtype
=
torch
.
float32
).
cuda
(),
inplace_safe
=
True
,
_inplace_chunk_size
=
4
,
).
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_mul_out_compare
(
self
):
...
...
@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
pair_mask
=
np
.
random
.
randint
(
low
=
0
,
high
=
2
,
size
=
(
n_res
,
n_res
))
pair_mask
=
pair_mask
.
astype
(
np
.
float32
)
model
=
compare_utils
.
get_global_pretrained_openfold
()
module
=
(
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_in
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_in
if
incoming
else
model
.
evoformer
.
blocks
[
0
].
core
.
tri_mul_out
else
model
.
evoformer
.
blocks
[
0
].
pair_stack
.
tri_mul_out
)
out_stock
=
module
(
torch
.
as_tensor
(
pair_act
,
dtype
=
torch
.
float32
).
cuda
(),
...
...
train_openfold.py
View file @
bb3f51e5
import
argparse
import
logging
import
os
import
random
import
sys
import
time
import
numpy
as
np
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.lr_monitor
import
LearningRateMonitor
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.loggers
import
WandbLogger
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.plugins.environments
import
SLURMEnvironment
import
torch
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
(
OpenFoldDataModule
,
DummyDataLoader
,
)
from
openfold.data.data_modules
import
OpenFoldDataModule
,
OpenFoldMultimerDataModule
from
openfold.model.model
import
AlphaFold
from
openfold.model.torchscript
import
script_preset_
from
openfold.np
import
residue_constants
from
openfold.utils.argparse
import
remove_arguments
from
openfold.utils.argparse
_utils
import
remove_arguments
from
openfold.utils.callbacks
import
(
EarlyStoppingVerbose
,
)
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.loss
import
AlphaFoldLoss
,
lddt_ca
from
openfold.utils.lr_schedulers
import
AlphaFoldLRScheduler
from
openfold.utils.multi_chain_permutation
import
multi_chain_permutation_align
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.superimposition
import
superimpose
from
openfold.utils.tensor_utils
import
tensor_tree_map
...
...
@@ -39,6 +33,7 @@ from openfold.utils.validation_metrics import (
)
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
import_openfold_weights_
)
from
scripts.zero_to_fp32
import
(
get_fp32_state_dict_from_zero_checkpoint
,
...
...
@@ -53,7 +48,10 @@ class OpenFoldWrapper(pl.LightningModule):
super
(
OpenFoldWrapper
,
self
).
__init__
()
self
.
config
=
config
self
.
model
=
AlphaFold
(
config
)
self
.
is_multimer
=
self
.
config
.
globals
.
is_multimer
self
.
loss
=
AlphaFoldLoss
(
config
.
loss
)
self
.
ema
=
ExponentialMovingAverage
(
model
=
self
.
model
,
decay
=
config
.
ema
.
decay
)
...
...
@@ -98,12 +96,19 @@ class OpenFoldWrapper(pl.LightningModule):
if
(
self
.
ema
.
device
!=
batch
[
"aatype"
].
device
):
self
.
ema
.
to
(
batch
[
"aatype"
].
device
)
ground_truth
=
batch
.
pop
(
'gt_features'
,
None
)
# Run the model
outputs
=
self
(
batch
)
# Remove the recycling dimension
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
if
self
.
is_multimer
:
batch
=
multi_chain_permutation_align
(
out
=
outputs
,
features
=
batch
,
ground_truth
=
ground_truth
)
# Compute loss
loss
,
loss_breakdown
=
self
.
loss
(
outputs
,
batch
,
_return_breakdown
=
True
...
...
@@ -127,12 +132,20 @@ class OpenFoldWrapper(pl.LightningModule):
self
.
cached_weights
=
tensor_tree_map
(
clone_param
,
self
.
model
.
state_dict
())
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
ground_truth
=
batch
.
pop
(
'gt_features'
,
None
)
# Run the model
outputs
=
self
(
batch
)
batch
=
tensor_tree_map
(
lambda
t
:
t
[...,
-
1
],
batch
)
# Compute loss and other metrics
batch
[
"use_clamped_fape"
]
=
0.
if
self
.
is_multimer
:
batch
=
multi_chain_permutation_align
(
out
=
outputs
,
features
=
batch
,
ground_truth
=
ground_truth
)
# Compute loss and other metrics
_
,
loss_breakdown
=
self
.
loss
(
outputs
,
batch
,
_return_breakdown
=
True
)
...
...
@@ -221,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule):
lr_scheduler
=
AlphaFoldLRScheduler
(
optimizer
,
last_epoch
=
self
.
last_lr_step
)
return
{
...
...
@@ -265,8 +279,8 @@ def main(args):
train
=
True
,
low_prec
=
(
str
(
args
.
precision
)
==
"16"
)
)
model_module
=
OpenFoldWrapper
(
config
)
if
(
args
.
resume_from_ckpt
):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
last_global_step
=
get_global_step_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
...
...
@@ -281,7 +295,7 @@ def main(args):
else
:
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
model_module
.
load_
state_dict
(
sd
)
import_openfold_weights_
(
model
=
model_module
,
state_dict
=
sd
)
logging
.
info
(
"Successfully loaded model weights..."
)
if
(
args
.
resume_from_jax_params
):
model_module
.
load_from_jax
(
args
.
resume_from_jax_params
)
...
...
@@ -291,7 +305,13 @@ def main(args):
if
(
args
.
script_modules
):
script_preset_
(
model_module
)
#data_module = DummyDataLoader("new_batch.pickle")
if
"multimer"
in
args
.
config_preset
:
data_module
=
OpenFoldMultimerDataModule
(
config
=
config
.
data
,
batch_seed
=
args
.
seed
,
**
vars
(
args
)
)
else
:
data_module
=
OpenFoldDataModule
(
config
=
config
.
data
,
batch_seed
=
args
.
seed
,
...
...
@@ -416,6 +436,10 @@ if __name__ == "__main__":
help
=
'''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target'''
)
parser
.
add_argument
(
"--train_mmcif_data_cache_path"
,
type
=
str
,
default
=
None
,
help
=
"Path to the json file which records all the information of mmcif structures used during training"
)
parser
.
add_argument
(
"--use_single_seq_mode"
,
type
=
str
,
default
=
False
,
help
=
"Use single sequence embeddings instead of MSAs."
...
...
@@ -436,6 +460,10 @@ if __name__ == "__main__":
"--val_alignment_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory containing precomputed validation alignments"
)
parser
.
add_argument
(
"--val_mmcif_data_cache_path"
,
type
=
str
,
default
=
None
,
help
=
"path to the json file which records all the information of mmcif structures used during validation"
)
parser
.
add_argument
(
"--kalign_binary_path"
,
type
=
str
,
default
=
'/usr/bin/kalign'
,
help
=
"Path to the kalign binary"
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment