OpenDAS / OpenFold · Commits

Commit bb3f51e5 (unverified), authored Feb 07, 2024 by Christina Floristean, committed by GitHub on Feb 07, 2024

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge

Parents: ce211367, c33a0bd6

Changes: 106 files in the full commit; this page shows 6 changed files, with 274 additions and 86 deletions (+274, -86).
Files shown:
- tests/test_primitives.py (+4, -4)
- tests/test_structure_module.py (+81, -25)
- tests/test_template.py (+123, -26)
- tests/test_triangular_attention.py (+4, -4)
- tests/test_triangular_multiplicative_update.py (+19, -12)
- train_openfold.py (+43, -15)
tests/test_primitives.py (view file @ bb3f51e5)
@@ -30,7 +30,7 @@ class TestLMA(unittest.TestCase):
         q, kv, _, biases = random_attention_inputs(
             batch_size=consts.batch_size,
             n_seq=consts.n_seq,
             n=2**12,
             no_heads=no_heads,
             c_hidden=c_hidden)
@@ -44,10 +44,10 @@ class TestLMA(unittest.TestCase):
         l = a(q, kv, biases=biases, use_lma=True).cpu()
         real = a(q, kv, biases=biases).cpu()

         err = torch.max(torch.abs(l - real))
         self.assertTrue(err < consts.eps, f'Error: {err}')

 if __name__ == "__main__":
     unittest.main()
\ No newline at end of file
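Both hunks leave this test's comparison pattern intact: run attention with use_lma=True and again along the reference path, then bound the worst-case elementwise deviation by consts.eps. A minimal, self-contained sketch of that pattern (the random tensors below are stand-ins for the two attention outputs, and eps for consts.eps):

    import torch

    eps = 2e-2                              # stand-in for consts.eps
    l = torch.rand(4, 8, 16)                # stand-in for the use_lma=True output
    real = l + 1e-4 * torch.rand(4, 8, 16)  # stand-in for the reference output

    err = torch.max(torch.abs(l - real))    # worst-case elementwise deviation
    assert err < eps, f"Error: {err}"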
tests/test_structure_module.py (view file @ bb3f51e5)
@@ -18,21 +18,19 @@ import unittest
 from openfold.data.data_transforms import make_atom14_masks_np
 from openfold.np.residue_constants import (
-    restype_rigid_group_default_frame,
-    restype_atom14_to_rigid_group,
     restype_atom14_mask,
-    restype_atom14_rigid_group_positions,
     restype_atom37_mask,
 )
 from openfold.model.structure_module import (
     StructureModule,
     StructureModuleTransition,
-    BackboneUpdate,
     AngleResnet,
     InvariantPointAttention,
 )
-import openfold.utils.feats as feats
 from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
+from openfold.utils.geometry.rotation_matrix import Rot3Array
+from openfold.utils.geometry.vector import Vec3Array
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import (
@@ -46,6 +44,20 @@ if compare_utils.alphafold_is_installed():
 class TestStructureModule(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if compare_utils.alphafold_is_installed():
+            if consts.is_multimer:
+                cls.am_atom = alphafold.model.all_atom_multimer
+                cls.am_fold = alphafold.model.folding_multimer
+                cls.am_modules = alphafold.model.modules_multimer
+                cls.am_rigid = alphafold.model.geometry
+            else:
+                cls.am_atom = alphafold.model.all_atom
+                cls.am_fold = alphafold.model.folding
+                cls.am_modules = alphafold.model.modules
+                cls.am_rigid = alphafold.model.r3
+
     def test_structure_module_shape(self):
         batch_size = consts.batch_size
         n = consts.n_res
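This setUpClass block recurs in every test class this merge touches: it binds either the monomer or the multimer AlphaFold reference modules to class-level aliases (am_atom, am_fold, am_modules, am_rigid) once, so the test bodies can stay branch-free wherever the two APIs agree. A minimal sketch of the same dispatch idea, with math/cmath standing in for the two backend modules:

    import math
    import cmath
    import unittest

    class TestBackendDispatch(unittest.TestCase):
        use_complex = False  # stand-in for consts.is_multimer

        @classmethod
        def setUpClass(cls):
            # Bind the backend once; every test then uses the alias.
            cls.backend = cmath if cls.use_complex else math

        def test_exp(self):
            self.assertAlmostEqual(self.backend.exp(0), 1.0)

    if __name__ == "__main__":
        unittest.main()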
@@ -81,6 +93,7 @@ class TestStructureModule(unittest.TestCase):
             trans_scale_factor,
             ar_epsilon,
             inf,
+            is_multimer=consts.is_multimer
         )

         s = torch.rand((batch_size, n, c_s))
@@ -89,7 +102,11 @@ class TestStructureModule(unittest.TestCase):
         out = sm({"single": s, "pair": z}, f)

-        self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
+        if consts.is_multimer:
+            self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
+        else:
+            self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
         self.assertTrue(
             out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
         )
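The branch reflects a genuine representation change: the monomer structure module emits frames as 7-vectors (a unit quaternion plus a translation), while the multimer module emits 4x4 homogeneous transform matrices. A sketch of how the two encodings relate; the conversion below is standard quaternion algebra, not OpenFold code:

    import torch

    def quat_to_rot(q):
        # Unit quaternion (w, x, y, z) -> 3x3 rotation matrix.
        w, x, y, z = (q / torch.linalg.norm(q)).tolist()
        return torch.tensor([
            [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
            [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
            [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
        ])

    frame7 = torch.tensor([1.0, 0.0, 0.0, 0.0, 0.5, -1.0, 2.0])  # quat + translation
    T = torch.eye(4)              # the equivalent (..., 4, 4) multimer-style frame
    T[:3, :3] = quat_to_rot(frame7[:4])
    T[:3, 3] = frame7[4:]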
@@ -121,11 +138,14 @@ class TestStructureModule(unittest.TestCase):
         c_global = config.model.global_config

         def run_sm(representations, batch):
-            sm = alphafold.model.folding.StructureModule(c_sm, c_global)
+            sm = self.am_fold.StructureModule(c_sm, c_global)
             representations = {
                 k: jax.lax.stop_gradient(v) for k, v in representations.items()
             }
             batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
+            if consts.is_multimer:
+                return sm(representations, batch, is_training=False, compute_loss=True)
             return sm(representations, batch, is_training=False)

         f = hk.transform(run_sm)
@@ -177,10 +197,24 @@ class TestStructureModule(unittest.TestCase):
         # The structure module, thanks to angle normalization, is very volatile
         # We only assess the mean here. Heuristically speaking, it seems to
         # have lower error in general on real rather than synthetic data.
-        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)


 class TestInvariantPointAttention(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if compare_utils.alphafold_is_installed():
+            if consts.is_multimer:
+                cls.am_atom = alphafold.model.all_atom_multimer
+                cls.am_fold = alphafold.model.folding_multimer
+                cls.am_modules = alphafold.model.modules_multimer
+                cls.am_rigid = alphafold.model.geometry
+            else:
+                cls.am_atom = alphafold.model.all_atom
+                cls.am_fold = alphafold.model.folding
+                cls.am_modules = alphafold.model.modules
+                cls.am_rigid = alphafold.model.r3
+
     def test_shape(self):
         c_m = 13
         c_z = 17
@@ -197,13 +231,18 @@ class TestInvariantPointAttention(unittest.TestCase):
         mask = torch.ones((batch_size, n_res))

         rot_mats = torch.rand((batch_size, n_res, 3, 3))
-        rots = Rotation(rot_mats=rot_mats, quats=None)
         trans = torch.rand((batch_size, n_res, 3))

-        r = Rigid(rots, trans)
+        if consts.is_multimer:
+            rotation = Rot3Array.from_array(rot_mats)
+            translation = Vec3Array.from_array(trans)
+            r = Rigid3Array(rotation, translation)
+        else:
+            rots = Rotation(rot_mats=rot_mats, quats=None)
+            r = Rigid(rots, trans)

         ipa = InvariantPointAttention(
-            c_m, c_z, c_hidden, no_heads, no_qp, no_vp
+            c_m, c_z, c_hidden, no_heads, no_qp, no_vp,
+            is_multimer=consts.is_multimer
         )

         shape_before = s.shape
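Rot3Array/Vec3Array/Rigid3Array mirror what Rotation/Rigid provide on the monomer side: a rotation R and a translation t wrapped as one transform. For reference, the underlying operation on plain tensors (illustration only, not the OpenFold classes):

    import torch

    R = torch.eye(3)                   # stand-in rotation matrix
    t = torch.tensor([1.0, 0.0, 0.0])  # stand-in translation
    pts = torch.rand(5, 3)
    moved = pts @ R.T + t              # x -> R x + t, applied to a batch of points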
@@ -215,16 +254,26 @@ class TestInvariantPointAttention(unittest.TestCase):
     def test_ipa_compare(self):
         def run_ipa(act, static_feat_2d, mask, affine):
             config = compare_utils.get_alphafold_config()
-            ipa = alphafold.model.folding.InvariantPointAttention(
+            ipa = self.am_fold.InvariantPointAttention(
                 config.model.heads.structure_module,
                 config.model.global_config,
             )
-            attn = ipa(
-                inputs_1d=act,
-                inputs_2d=static_feat_2d,
-                mask=mask,
-                affine=affine,
-            )
+            if consts.is_multimer:
+                attn = ipa(
+                    inputs_1d=act,
+                    inputs_2d=static_feat_2d,
+                    mask=mask,
+                    rigid=affine
+                )
+            else:
+                attn = ipa(
+                    inputs_1d=act,
+                    inputs_2d=static_feat_2d,
+                    mask=mask,
+                    affine=affine
+                )
             return attn

         f = hk.transform(run_ipa)
@@ -238,13 +287,20 @@ class TestInvariantPointAttention(unittest.TestCase):
         sample_mask = np.ones((n_res, 1))

         affines = random_affines_4x4((n_res,))
-        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
-        quats = alphafold.model.r3.rigids_to_quataffine(rigids)
-        transformations = Rigid.from_tensor_4x4(
-            torch.as_tensor(affines).float().cuda()
-        )
-        sample_affine = quats
+        if consts.is_multimer:
+            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
+            transformations = Rigid3Array.from_tensor_4x4(
+                torch.as_tensor(affines).float().cuda()
+            )
+            sample_affine = rigids
+        else:
+            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
+            quats = self.am_rigid.rigids_to_quataffine(rigids)
+            transformations = Rigid.from_tensor_4x4(
+                torch.as_tensor(affines).float().cuda()
+            )
+            sample_affine = quats

         ipa_params = compare_utils.fetch_alphafold_module_weights(
             "alphafold/alphafold_iteration/structure_module/"
@@ -265,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
             torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
         ).cpu()

-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)


 class TestAngleResnet(unittest.TestCase):
tests/test_template.py (view file @ bb3f51e5)
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 import torch
 import numpy as np
 import unittest
@@ -19,7 +20,6 @@ from openfold.model.template import (
     TemplatePointwiseAttention,
     TemplatePairStack,
 )
-from openfold.utils.tensor_utils import tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import random_template_feats
@@ -54,6 +54,20 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
 class TestTemplatePairStack(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if compare_utils.alphafold_is_installed():
+            if consts.is_multimer:
+                cls.am_atom = alphafold.model.all_atom_multimer
+                cls.am_fold = alphafold.model.folding_multimer
+                cls.am_modules = alphafold.model.modules_multimer
+                cls.am_rigid = alphafold.model.geometry
+            else:
+                cls.am_atom = alphafold.model.all_atom
+                cls.am_fold = alphafold.model.folding
+                cls.am_modules = alphafold.model.modules
+                cls.am_rigid = alphafold.model.r3
+
     def test_shape(self):
         batch_size = consts.batch_size
         c_t = consts.c_t
@@ -65,6 +79,8 @@ class TestTemplatePairStack(unittest.TestCase):
         dropout = 0.25
         n_templ = consts.n_templ
         n_res = consts.n_res
+        tri_mul_first = consts.is_multimer
+        fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
         blocks_per_ckpt = None
         chunk_size = 4
         inf = 1e7
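fuse_projection_weights is keyed off the model preset name because only the v3 multimer presets ship fused triangle-multiplication weights. The regex check, run standalone:

    import re

    pattern = "^model_[1-5]_multimer_v3$"
    for name in ["model_1", "model_3_multimer_v3", "model_3_multimer_v2"]:
        print(name, bool(re.fullmatch(pattern, name)))
    # model_1 False
    # model_3_multimer_v3 True
    # model_3_multimer_v2 False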
@@ -78,6 +94,8 @@ class TestTemplatePairStack(unittest.TestCase):
             no_heads=no_heads,
             pair_transition_n=pt_inner_dim,
             dropout_rate=dropout,
+            tri_mul_first=tri_mul_first,
+            fuse_projection_weights=fuse_projection_weights,
             blocks_per_ckpt=None,
             inf=inf,
             eps=eps,
@@ -96,12 +114,40 @@ class TestTemplatePairStack(unittest.TestCase):
         def run_template_pair_stack(pair_act, pair_mask):
             config = compare_utils.get_alphafold_config()
             c_ee = config.model.embeddings_and_evoformer
-            tps = alphafold.model.modules.TemplatePairStack(
-                c_ee.template.template_pair_stack,
-                config.model.global_config,
-                name="template_pair_stack",
-            )
-            act = tps(pair_act, pair_mask, is_training=False)
+            if consts.is_multimer:
+                safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
+                template_iteration = self.am_modules.TemplateEmbeddingIteration(
+                    c_ee.template.template_pair_stack,
+                    config.model.global_config,
+                    name='template_embedding_iteration'
+                )
+
+                def template_iteration_fn(x):
+                    act, safe_key = x
+                    safe_key, safe_subkey = safe_key.split()
+                    act = template_iteration(
+                        act=act,
+                        pair_mask=pair_mask,
+                        is_training=False,
+                        safe_key=safe_subkey
+                    )
+                    return (act, safe_key)
+
+                if config.model.global_config.use_remat:
+                    template_iteration_fn = hk.remat(template_iteration_fn)
+
+                safe_key, safe_subkey = safe_key.split()
+                template_stack = alphafold.model.layer_stack.layer_stack(
+                    c_ee.template.template_pair_stack.num_block
+                )(template_iteration_fn)
+                act, _ = template_stack((pair_act, safe_subkey))
+            else:
+                tps = self.am_modules.TemplatePairStack(
+                    c_ee.template.template_pair_stack,
+                    config.model.global_config,
+                    name="template_pair_stack",
+                )
+                act = tps(pair_act, pair_mask, is_training=False)
             ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
             act = ln(act)
             return act
@@ -115,10 +161,16 @@ class TestTemplatePairStack(unittest.TestCase):
                 low=0, high=2, size=(n_res, n_res)
             ).astype(np.float32)

-        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/template_embedding/"
-            + "single_template_embedding/template_pair_stack"
-        )
+        if consts.is_multimer:
+            params = compare_utils.fetch_alphafold_module_weights(
+                "alphafold/alphafold_iteration/evoformer/template_embedding/"
+                + "single_template_embedding/template_embedding_iteration"
+            )
+        else:
+            params = compare_utils.fetch_alphafold_module_weights(
+                "alphafold/alphafold_iteration/evoformer/template_embedding/"
+                + "single_template_embedding/template_pair_stack"
+            )
         params.update(
             compare_utils.fetch_alphafold_module_weights(
                 "alphafold/alphafold_iteration/evoformer/template_embedding/"
@@ -132,26 +184,44 @@ class TestTemplatePairStack(unittest.TestCase):
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.template_pair_stack(
+        out_repro = model.template_embedder.template_pair_stack(
             torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
             torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
             chunk_size=None,
             _mask_trans=False,
         ).cpu()

-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)


 class Template(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if compare_utils.alphafold_is_installed():
+            if consts.is_multimer:
+                cls.am_atom = alphafold.model.all_atom_multimer
+                cls.am_fold = alphafold.model.folding_multimer
+                cls.am_modules = alphafold.model.modules_multimer
+                cls.am_rigid = alphafold.model.geometry
+            else:
+                cls.am_atom = alphafold.model.all_atom
+                cls.am_fold = alphafold.model.folding
+                cls.am_modules = alphafold.model.modules
+                cls.am_rigid = alphafold.model.r3
+
     @compare_utils.skip_unless_alphafold_installed()
     def test_compare(self):
-        def test_template_embedding(pair, batch, mask_2d):
+        def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
             config = compare_utils.get_alphafold_config()
-            te = alphafold.model.modules.TemplateEmbedding(
+            te = self.am_modules.TemplateEmbedding(
                 config.model.embeddings_and_evoformer.template,
                 config.model.global_config,
             )
-            act = te(pair, batch, mask_2d, is_training=False)
+            if consts.is_multimer:
+                act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
+            else:
+                act = te(pair, batch, mask_2d, is_training=False)
             return act

         f = hk.transform(test_template_embedding)
@@ -162,6 +232,14 @@ class Template(unittest.TestCase):
         pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
         batch = random_template_feats(n_templ, n_res)
         batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
+        multichain_mask_2d = None
+        if consts.is_multimer:
+            asym_id = batch['asym_id'][0]
+            multichain_mask_2d = (
+                asym_id[..., None] == asym_id[..., None, :]
+            ).astype(np.float32)
+
         pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

         # Fetch pretrained parameters (but only from one block)]
         params = compare_utils.fetch_alphafold_module_weights(
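The multichain mask is built by broadcasting the per-residue chain labels against themselves: entry (i, j) is 1 exactly when residues i and j come from the same chain. A tiny standalone check of that idiom:

    import numpy as np

    asym_id = np.array([0, 0, 1, 1, 1])  # two chains: 2 and 3 residues
    multichain_mask_2d = (
        asym_id[..., None] == asym_id[..., None, :]
    ).astype(np.float32)
    print(multichain_mask_2d)
    # [[1. 1. 0. 0. 0.]
    #  [1. 1. 0. 0. 0.]
    #  [0. 0. 1. 1. 1.]
    #  [0. 0. 1. 1. 1.]
    #  [0. 0. 1. 1. 1.]]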
@@ -169,7 +247,7 @@ class Template(unittest.TestCase):
         )

         out_gt = f.apply(
-            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
+            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
         ).block_until_ready()

         out_gt = torch.as_tensor(np.array(out_gt))
@@ -177,17 +255,36 @@ class Template(unittest.TestCase):
         batch["target_feat"] = np.eye(22)[inds]

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.embed_templates(
-            {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
-            torch.as_tensor(pair_act).cuda(),
-            torch.as_tensor(pair_mask).cuda(),
-            templ_dim=0,
-            inplace_safe=False
-        )
-
-        out_repro = out_repro["template_pair_embedding"]
+        template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+        if consts.is_multimer:
+            out_repro_all = model.template_embedder(
+                template_feats,
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(pair_mask).cuda(),
+                templ_dim=0,
+                chunk_size=consts.chunk_size,
+                multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
+                _mask_trans=False,
+                use_lma=False,
+                inplace_safe=False
+            )
+        else:
+            out_repro_all = model.template_embedder(
+                template_feats,
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(pair_mask).cuda(),
+                templ_dim=0,
+                chunk_size=consts.chunk_size,
+                mask_trans=False,
+                use_lma=False,
+                inplace_safe=False
+            )
+
+        out_repro = out_repro_all["template_pair_embedding"]
         out_repro = out_repro.cpu()

-        self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)


 if __name__ == "__main__":
tests/test_triangular_attention.py (view file @ bb3f51e5)
@@ -79,16 +79,16 @@ class TestTriangularAttention(unittest.TestCase):
             "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
             + name
         )
-        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+        params = tree_map(lambda n: n[0], params, jax.Array)

         out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].core.tri_att_start
+            model.evoformer.blocks[0].pair_stack.tri_att_start
             if starting
-            else model.evoformer.blocks[0].core.tri_att_end
+            else model.evoformer.blocks[0].pair_stack.tri_att_end
         )

         # To save memory, the full model transposes inputs outside of the
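The tree_map change tracks a JAX API migration: jax.numpy.DeviceArray was removed in newer JAX releases, and jax.Array is the unified array type. The fetched AlphaFold weights stack several models per leaf, so mapping n[0] over the tree selects the first model's copy everywhere. A sketch using jax.tree_util.tree_map (OpenFold's own tree_map helper additionally takes the leaf type to match, as in the diff):

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((2,))}    # 2 stacked models
    params = jax.tree_util.tree_map(lambda n: n[0], params)   # keep the first
    print(jax.tree_util.tree_map(jnp.shape, params))          # {'w': (3,), 'b': ()}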
@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
             chunk_size=None,
         ).cpu()

-        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
     def test_tri_att_end_compare(self):
tests/test_triangular_multiplicative_update.py (view file @ bb3f51e5)
@@ -13,6 +13,7 @@
 # limitations under the License.
 import torch
+import re
 import numpy as np
 import unittest
 from openfold.model.triangular_multiplicative_update import *
@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         c_z = consts.c_z
         c = 11

-        tm = TriangleMultiplicationOutgoing(
-            c_z,
-            c,
-        )
+        if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model):
+            tm = FusedTriangleMultiplicationOutgoing(
+                c_z,
+                c,
+            )
+        else:
+            tm = TriangleMultiplicationOutgoing(
+                c_z,
+                c,
+            )

         n_res = consts.c_z
         batch_size = consts.batch_size
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
                 config.model.global_config,
                 name=name,
             )
-            act = tri_mul(act=pair_act, mask=pair_mask)
+            act = tri_mul(pair_act, pair_mask)
             return act

         f = hk.transform(run_tri_mul)
@@ -78,24 +85,25 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
             "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
             + name
         )
-        params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
+        params = tree_map(lambda n: n[0], params, jax.Array)

         out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].core.tri_mul_in
+            model.evoformer.blocks[0].pair_stack.tri_mul_in
             if incoming
-            else model.evoformer.blocks[0].core.tri_mul_out
+            else model.evoformer.blocks[0].pair_stack.tri_mul_out
         )

         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
             mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
             inplace_safe=True,
             _inplace_chunk_size=4,
         ).cpu()

-        self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
+        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)

     @compare_utils.skip_unless_alphafold_installed()
     def test_tri_mul_out_compare(self):
@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
         pair_mask = pair_mask.astype(np.float32)

         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].core.tri_mul_in
+            model.evoformer.blocks[0].pair_stack.tri_mul_in
             if incoming
-            else model.evoformer.blocks[0].core.tri_mul_out
+            else model.evoformer.blocks[0].pair_stack.tri_mul_out
         )

         out_stock = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
train_openfold.py (view file @ bb3f51e5)
 import argparse
 import logging
 import os
-import random
 import sys
-import time
-import numpy as np
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
-from pytorch_lightning.plugins.environments import SLURMEnvironment
 import torch

 from openfold.config import model_config
-from openfold.data.data_modules import (
-    OpenFoldDataModule,
-    DummyDataLoader,
-)
+from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
 from openfold.np import residue_constants
-from openfold.utils.argparse import remove_arguments
+from openfold.utils.argparse_utils import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca
 from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
+from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
 from openfold.utils.seed import seed_everything
 from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
@@ -39,6 +33,7 @@ from openfold.utils.validation_metrics import (
 )
 from openfold.utils.import_weights import (
     import_jax_weights_,
+    import_openfold_weights_
 )
 from scripts.zero_to_fp32 import (
     get_fp32_state_dict_from_zero_checkpoint,
@@ -53,7 +48,10 @@ class OpenFoldWrapper(pl.LightningModule):
         super(OpenFoldWrapper, self).__init__()
         self.config = config
         self.model = AlphaFold(config)
+        self.is_multimer = self.config.globals.is_multimer
+
         self.loss = AlphaFoldLoss(config.loss)
+
         self.ema = ExponentialMovingAverage(
             model=self.model, decay=config.ema.decay
         )
@@ -98,12 +96,19 @@ class OpenFoldWrapper(pl.LightningModule):
         if(self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)

+        ground_truth = batch.pop('gt_features', None)
+
         # Run the model
         outputs = self(batch)

         # Remove the recycling dimension
         batch = tensor_tree_map(lambda t: t[..., -1], batch)

+        if self.is_multimer:
+            batch = multi_chain_permutation_align(out=outputs,
+                                                  features=batch,
+                                                  ground_truth=ground_truth)
+
         # Compute loss
         loss, loss_breakdown = self.loss(
             outputs, batch, _return_breakdown=True
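The ordering here matters: gt_features is popped off the batch before the forward pass, so the model never sees the ground truth, and it is consumed only by multi_chain_permutation_align, which permutes chain labels so the loss compares like with like across identical chains. The pop-with-default idiom, on plain dicts:

    batch = {"aatype": [0, 1, 2], "gt_features": {"all_atom_positions": "..."}}
    ground_truth = batch.pop("gt_features", None)  # removed before the forward pass
    assert "gt_features" not in batch              # the model input is now clean
    assert {"aatype": []}.pop("gt_features", None) is None  # None when absent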
@@ -126,13 +131,21 @@ class OpenFoldWrapper(pl.LightningModule):
             clone_param = lambda t: t.detach().clone()
             self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
             self.model.load_state_dict(self.ema.state_dict()["params"])

+        ground_truth = batch.pop('gt_features', None)
+
         # Run the model
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)

-        # Compute loss and other metrics
         batch["use_clamped_fape"] = 0.
+
+        if self.is_multimer:
+            batch = multi_chain_permutation_align(out=outputs,
+                                                  features=batch,
+                                                  ground_truth=ground_truth)
+
+        # Compute loss and other metrics
         _, loss_breakdown = self.loss(
             outputs, batch, _return_breakdown=True
         )
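Validation runs against the EMA weights: the live weights are cloned into a cache, the EMA parameters are loaded in their place, and the cache is restored after the epoch (the restore lives outside this hunk). A minimal sketch of that swap on a stand-in module, where the scaled copy merely fakes an EMA state:

    import torch

    model = torch.nn.Linear(4, 4)
    clone_param = lambda t: t.detach().clone()
    cached_weights = {k: clone_param(v) for k, v in model.state_dict().items()}
    ema_params = {k: v * 0.99 for k, v in model.state_dict().items()}  # fake EMA
    model.load_state_dict(ema_params)      # validate with EMA weights ...
    model.load_state_dict(cached_weights)  # ... then restore the live weights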
@@ -221,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule):
         lr_scheduler = AlphaFoldLRScheduler(
             optimizer,
+            last_epoch=self.last_lr_step
         )

         return {
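Passing last_epoch lets a freshly constructed scheduler resume its schedule at the saved step instead of restarting warmup/decay from zero, which is what makes checkpoint resumption reproduce the same learning rates. A sketch with a stock scheduler (LambdaLR stands in for AlphaFoldLRScheduler; the warmup function is a toy):

    import torch

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    for g in opt.param_groups:
        g.setdefault("initial_lr", g["lr"])       # required when last_epoch >= 0
    warmup = lambda step: min(1.0, step / 1000)   # toy linear warmup
    sched = torch.optim.lr_scheduler.LambdaLR(opt, warmup, last_epoch=499)
    print(opt.param_groups[0]["lr"])  # 0.05: halfway through warmup, not step 0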
@@ -265,8 +279,8 @@ def main(args):
         train=True,
         low_prec=(str(args.precision) == "16")
     )
     model_module = OpenFoldWrapper(config)
     if(args.resume_from_ckpt):
         if(os.path.isdir(args.resume_from_ckpt)):
             last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
@@ -281,7 +295,7 @@ def main(args):
         else:
             sd = torch.load(args.resume_from_ckpt)
             sd = {k[len("module."):]: v for k, v in sd.items()}
-            model_module.load_state_dict(sd)
+            import_openfold_weights_(model=model_module, state_dict=sd)
             logging.info("Successfully loaded model weights...")
     if(args.resume_from_jax_params):
         model_module.load_from_jax(args.resume_from_jax_params)
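The comprehension above strips the "module." prefix that DeepSpeed/Lightning checkpoints put on every key, recovering the plain state dict that import_openfold_weights_ expects. Standalone:

    sd = {"module.evoformer.linear.weight": 1, "module.evoformer.linear.bias": 2}
    sd = {k[len("module."):]: v for k, v in sd.items()}
    print(list(sd))  # ['evoformer.linear.weight', 'evoformer.linear.bias']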
@@ -291,12 +305,18 @@ def main(args):
     if(args.script_modules):
         script_preset_(model_module)

-    #data_module = DummyDataLoader("new_batch.pickle")
-    data_module = OpenFoldDataModule(
-        config=config.data,
-        batch_seed=args.seed,
-        **vars(args)
-    )
+    if "multimer" in args.config_preset:
+        data_module = OpenFoldMultimerDataModule(
+            config=config.data,
+            batch_seed=args.seed,
+            **vars(args)
+        )
+    else:
+        data_module = OpenFoldDataModule(
+            config=config.data,
+            batch_seed=args.seed,
+            **vars(args)
+        )

     data_module.prepare_data()
     data_module.setup()
@@ -416,6 +436,10 @@ if __name__ == "__main__":
         help='''Cutoff for all templates. In training mode, templates are also
             filtered by the release date of the target'''
     )
+    parser.add_argument(
+        "--train_mmcif_data_cache_path", type=str, default=None,
+        help="Path to the json file which records all the information of mmcif structures used during training"
+    )
     parser.add_argument(
         "--use_single_seq_mode", type=str, default=False,
         help="Use single sequence embeddings instead of MSAs."

@@ -436,6 +460,10 @@ if __name__ == "__main__":
         "--val_alignment_dir", type=str, default=None,
         help="Directory containing precomputed validation alignments"
     )
+    parser.add_argument(
+        "--val_mmcif_data_cache_path", type=str, default=None,
+        help="path to the json file which records all the information of mmcif structures used during validation"
+    )
     parser.add_argument(
         "--kalign_binary_path", type=str, default='/usr/bin/kalign',
         help="Path to the kalign binary"
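Both new flags follow the same optional-path pattern as the existing cache arguments. A reduced sketch of how they parse (the real parser defines many more arguments):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--train_mmcif_data_cache_path", type=str, default=None)
    parser.add_argument("--val_mmcif_data_cache_path", type=str, default=None)
    args = parser.parse_args(["--train_mmcif_data_cache_path", "train_cache.json"])
    print(args.train_mmcif_data_cache_path)  # train_cache.json
    print(args.val_mmcif_data_cache_path)    # None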