Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
82bda2d6
"...src/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "9f53922a9b4ef33e74367cc466384c98e4504ad7"
Commit
82bda2d6
authored
Aug 03, 2023
by
Christina Floristean
Browse files
Refactored multimer config update
parent
30764cf9
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
119 additions
and
225 deletions
+119
-225
openfold/config.py
openfold/config.py
+82
-190
openfold/model/embedders.py
openfold/model/embedders.py
+15
-12
openfold/model/model.py
openfold/model/model.py
+4
-2
openfold/model/template.py
openfold/model/template.py
+2
-2
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+2
-2
openfold/utils/loss.py
openfold/utils/loss.py
+1
-4
tests/config.py
tests/config.py
+2
-1
tests/data_utils.py
tests/data_utils.py
+1
-1
tests/test_embedders.py
tests/test_embedders.py
+2
-2
tests/test_multimer_datamodule.py
tests/test_multimer_datamodule.py
+6
-9
tests/test_template.py
tests/test_template.py
+2
-0
No files found.
openfold/config.py
View file @
82bda2d6
...
@@ -154,50 +154,37 @@ def model_config(
...
@@ -154,50 +154,37 @@ def model_config(
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
c
.
loss
.
tm
.
weight
=
0.1
elif
"multimer"
in
name
:
elif
"multimer"
in
name
:
c
.
globals
.
is_multimer
=
True
c
.
update
(
multimer_config_update
.
copy_and_resolve_references
())
c
.
globals
.
bfloat16
=
False
del
c
.
model
.
template
.
template_pointwise_attention
c
.
globals
.
bfloat16_output
=
False
del
c
.
loss
.
fape
.
backbone
c
.
loss
.
masked_msa
.
num_classes
=
22
c
.
data
.
common
.
max_recycling_iters
=
20
for
k
,
v
in
multimer_model_config_update
[
'model'
].
items
():
c
.
model
[
k
]
=
v
for
k
,
v
in
multimer_model_config_update
[
'loss'
].
items
():
c
.
loss
[
k
]
=
v
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
#c.model.input_embedder.num_msa = 252
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
crop_size
=
384
c
.
data
.
train
.
max_msa_clusters
=
252
c
.
data
.
train
.
max_msa_clusters
=
252
c
.
data
.
eval
.
max_msa_clusters
=
252
c
.
data
.
predict
.
max_msa_clusters
=
252
c
.
data
.
predict
.
max_msa_clusters
=
252
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
eval
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
c
.
model
.
evoformer_stack
.
fuse_projection_weights
=
False
c
.
model
.
evoformer_stack
.
fuse_projection_weights
=
False
c
.
model
.
extra_msa
.
extra_msa_stack
.
fuse_projection_weights
=
False
c
.
model
.
extra_msa
.
extra_msa_stack
.
fuse_projection_weights
=
False
c
.
model
.
template
.
template_pair_stack
.
fuse_projection_weights
=
False
c
.
model
.
template
.
template_pair_stack
.
fuse_projection_weights
=
False
elif
name
==
'model_4_multimer_v3'
:
elif
name
==
'model_4_multimer_v3'
:
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
eval
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
elif
name
==
'model_5_multimer_v3'
:
elif
name
==
'model_5_multimer_v3'
:
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
eval
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
else
:
c
.
data
.
train
.
max_msa_clusters
=
508
c
.
data
.
predict
.
max_msa_clusters
=
508
c
.
data
.
train
.
max_extra_msa
=
2048
c
.
data
.
predict
.
max_extra_msa
=
2048
c
.
data
.
common
.
unsupervised_features
.
extend
([
"msa_mask"
,
"seq_mask"
,
"asym_id"
,
"entity_id"
,
"sym_id"
,
])
else
:
else
:
raise
ValueError
(
"Invalid model name"
)
raise
ValueError
(
"Invalid model name"
)
...
@@ -451,7 +438,7 @@ config = mlc.ConfigDict(
...
@@ -451,7 +438,7 @@ config = mlc.ConfigDict(
"max_bin"
:
50.75
,
"max_bin"
:
50.75
,
"no_bins"
:
39
,
"no_bins"
:
39
,
},
},
"template_
a
ngle_embedder"
:
{
"template_
si
ngle_embedder"
:
{
# DISCREPANCY: c_in is supposed to be 51.
# DISCREPANCY: c_in is supposed to be 51.
"c_in"
:
57
,
"c_in"
:
57
,
"c_out"
:
c_m
,
"c_out"
:
c_m
,
...
@@ -682,226 +669,131 @@ config = mlc.ConfigDict(
...
@@ -682,226 +669,131 @@ config = mlc.ConfigDict(
}
}
)
)
multimer_model_config_update
=
{
multimer_config_update
=
mlc
.
ConfigDict
({
'model'
:
{
"globals"
:
{
"is_multimer"
:
True
,
"bfloat16"
:
False
,
# TODO: Change to True when implemented
"bfloat16_output"
:
False
},
"data"
:
{
"common"
:
{
"max_recycling_iters"
:
20
,
"unsupervised_features"
:
[
"aatype"
,
"residue_index"
,
"msa"
,
"num_alignments"
,
"seq_length"
,
"between_segment_residues"
,
"deletion_matrix"
,
"no_recycling_iters"
,
# Additional multimer features
"msa_mask"
,
"seq_mask"
,
"asym_id"
,
"entity_id"
,
"sym_id"
,
]
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
"predict"
:
{
"max_msa_clusters"
:
508
,
"max_extra_msa"
:
2048
},
"eval"
:
{
"max_msa_clusters"
:
508
,
"max_extra_msa"
:
2048
},
"train"
:
{
"max_msa_clusters"
:
508
,
"max_extra_msa"
:
2048
,
"crop_size"
:
640
},
},
"model"
:
{
"input_embedder"
:
{
"input_embedder"
:
{
"tf_dim"
:
21
,
"tf_dim"
:
21
,
"msa_dim"
:
49
,
#"num_msa": 508,
#"num_msa": 508,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"relpos_k"
:
32
,
"max_relative_chain"
:
2
,
"max_relative_chain"
:
2
,
"max_relative_idx"
:
32
,
"max_relative_idx"
:
32
,
"use_chain_relative"
:
True
,
"use_chain_relative"
:
True
},
},
"template"
:
{
"template"
:
{
"distogram"
:
{
"template_single_embedder"
:
{
"min_bin"
:
3.25
,
"c_in"
:
34
,
"max_bin"
:
50.75
,
"c_out"
:
c_m
"no_bins"
:
39
,
},
},
"template_pair_embedder"
:
{
"template_pair_embedder"
:
{
"c_
z
"
:
c_z
,
"c_
in
"
:
c_z
,
"c_out"
:
64
,
"c_out"
:
c_t
,
"c_dgram"
:
39
,
"c_dgram"
:
39
,
"c_aatype"
:
22
,
"c_aatype"
:
22
},
"template_single_embedder"
:
{
"c_in"
:
34
,
"c_m"
:
c_m
,
},
},
"template_pair_stack"
:
{
"template_pair_stack"
:
{
"c_t"
:
c_t
,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att"
:
16
,
"c_hidden_tri_mul"
:
64
,
"no_blocks"
:
2
,
"no_heads"
:
4
,
"pair_transition_n"
:
2
,
"dropout_rate"
:
0.25
,
"tri_mul_first"
:
True
,
"tri_mul_first"
:
True
,
"fuse_projection_weights"
:
True
,
"fuse_projection_weights"
:
True
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"inf"
:
1e9
,
},
},
"c_t"
:
c_t
,
"c_t"
:
c_t
,
"c_z"
:
c_z
,
"c_z"
:
c_z
,
"inf"
:
1e5
,
# 1e9,
"eps"
:
eps
,
# 1e-6,
"enabled"
:
templates_enabled
,
"embed_angles"
:
embed_template_torsion_angles
,
"use_unit_vector"
:
True
"use_unit_vector"
:
True
},
},
"extra_msa"
:
{
"extra_msa"
:
{
"extra_msa_embedder"
:
{
# "extra_msa_embedder": {
"c_in"
:
25
,
# "num_extra_msa": 2048
"c_out"
:
c_e
,
# },
#"num_extra_msa": 2048
},
"extra_msa_stack"
:
{
"extra_msa_stack"
:
{
"c_m"
:
c_e
,
"c_z"
:
c_z
,
"c_hidden_msa_att"
:
8
,
"c_hidden_opm"
:
32
,
"c_hidden_mul"
:
128
,
"c_hidden_pair_att"
:
32
,
"no_heads_msa"
:
8
,
"no_heads_pair"
:
4
,
"no_blocks"
:
4
,
"transition_n"
:
4
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"opm_first"
:
True
,
"opm_first"
:
True
,
"fuse_projection_weights"
:
True
,
"fuse_projection_weights"
:
True
"clear_cache_between_blocks"
:
True
,
}
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
"ckpt"
:
blocks_per_ckpt
is
not
None
,
},
"enabled"
:
True
,
},
},
"evoformer_stack"
:
{
"evoformer_stack"
:
{
"c_m"
:
c_m
,
"c_z"
:
c_z
,
"c_hidden_msa_att"
:
32
,
"c_hidden_opm"
:
32
,
"c_hidden_mul"
:
128
,
"c_hidden_pair_att"
:
32
,
"c_s"
:
c_s
,
"no_heads_msa"
:
8
,
"no_heads_pair"
:
4
,
"no_blocks"
:
48
,
"transition_n"
:
4
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"opm_first"
:
True
,
"opm_first"
:
True
,
"fuse_projection_weights"
:
True
,
"fuse_projection_weights"
:
True
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"clear_cache_between_blocks"
:
False
,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
},
},
"structure_module"
:
{
"structure_module"
:
{
"c_s"
:
c_s
,
"trans_scale_factor"
:
20
"c_z"
:
c_z
,
"c_ipa"
:
16
,
"c_resnet"
:
128
,
"no_heads_ipa"
:
12
,
"no_qk_points"
:
4
,
"no_v_points"
:
8
,
"dropout_rate"
:
0.1
,
"no_blocks"
:
8
,
"no_transition_layers"
:
1
,
"no_resnet_blocks"
:
2
,
"no_angles"
:
7
,
"trans_scale_factor"
:
20
,
"epsilon"
:
eps
,
# 1e-12,
"inf"
:
1e5
,
},
},
"heads"
:
{
"heads"
:
{
"lddt"
:
{
"no_bins"
:
50
,
"c_in"
:
c_s
,
"c_hidden"
:
128
,
},
"distogram"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
},
"tm"
:
{
"tm"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"ptm_weight"
:
0.2
,
"ptm_weight"
:
0.2
,
"iptm_weight"
:
0.8
,
"iptm_weight"
:
0.8
,
"enabled"
:
True
,
"enabled"
:
True
},
},
"masked_msa"
:
{
"masked_msa"
:
{
"c_m"
:
c_m
,
"c_out"
:
22
"c_out"
:
22
,
},
"experimentally_resolved"
:
{
"c_s"
:
c_s
,
"c_out"
:
37
,
},
},
},
},
"recycle_early_stop_tolerance"
:
0.5
"recycle_early_stop_tolerance"
:
0.5
},
},
"loss"
:
{
"loss"
:
{
"distogram"
:
{
"min_bin"
:
2.3125
,
"max_bin"
:
21.6875
,
"no_bins"
:
64
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.3
,
},
"experimentally_resolved"
:
{
"eps"
:
eps
,
# 1e-8,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"weight"
:
0.0
,
},
"fape"
:
{
"fape"
:
{
"intra_chain_backbone"
:
{
"intra_chain_backbone"
:
{
"clamp_distance"
:
10.0
,
"clamp_distance"
:
10.0
,
"loss_unit_distance"
:
10.0
,
"loss_unit_distance"
:
10.0
,
"weight"
:
0.5
,
"weight"
:
0.5
},
},
"interface_backbone"
:
{
"interface_backbone"
:
{
"clamp_distance"
:
30.0
,
"clamp_distance"
:
30.0
,
"loss_unit_distance"
:
20.0
,
"loss_unit_distance"
:
20.0
,
"weight"
:
0.5
,
"weight"
:
0.5
},
}
"sidechain"
:
{
"clamp_distance"
:
10.0
,
"length_scale"
:
10.0
,
"weight"
:
0.5
,
},
"eps"
:
1e-4
,
"weight"
:
1.0
,
},
"plddt_loss"
:
{
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"cutoff"
:
15.0
,
"no_bins"
:
50
,
"eps"
:
eps
,
# 1e-10,
"weight"
:
0.01
,
},
},
"masked_msa"
:
{
"masked_msa"
:
{
"num_classes"
:
23
,
"num_classes"
:
22
"eps"
:
eps
,
# 1e-8,
"weight"
:
2.0
,
},
"supervised_chi"
:
{
"chi_weight"
:
0.5
,
"angle_norm_weight"
:
0.01
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
1.0
,
},
},
"violation"
:
{
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
"average_clashes"
:
True
,
"average_clashes"
:
True
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.03
# Not finetuning
"weight"
:
0.03
,
# Not finetuning
},
},
"tm"
:
{
"tm"
:
{
"max_bin"
:
31
,
"no_bins"
:
64
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
0.1
,
"weight"
:
0.1
,
"enabled"
:
True
,
"enabled"
:
True
},
},
"chain_center_of_mass"
:
{
"chain_center_of_mass"
:
{
"clamp_distance"
:
-
4.0
,
"weight"
:
0.05
,
"weight"
:
0.05
,
"eps"
:
eps
,
"enabled"
:
True
"enabled"
:
True
,
}
},
"eps"
:
eps
,
}
}
}
}
)
openfold/model/embedders.py
View file @
82bda2d6
...
@@ -412,7 +412,7 @@ class RecyclingEmbedder(nn.Module):
...
@@ -412,7 +412,7 @@ class RecyclingEmbedder(nn.Module):
return
m_update
,
z_update
return
m_update
,
z_update
class
Template
A
ngleEmbedder
(
nn
.
Module
):
class
Template
Si
ngleEmbedder
(
nn
.
Module
):
"""
"""
Embeds the "template_angle_feat" feature.
Embeds the "template_angle_feat" feature.
...
@@ -432,7 +432,7 @@ class TemplateAngleEmbedder(nn.Module):
...
@@ -432,7 +432,7 @@ class TemplateAngleEmbedder(nn.Module):
c_out:
c_out:
Output channel dimension
Output channel dimension
"""
"""
super
(
Template
A
ngleEmbedder
,
self
).
__init__
()
super
(
Template
Si
ngleEmbedder
,
self
).
__init__
()
self
.
c_out
=
c_out
self
.
c_out
=
c_out
self
.
c_in
=
c_in
self
.
c_in
=
c_in
...
@@ -543,8 +543,8 @@ class TemplateEmbedder(nn.Module):
...
@@ -543,8 +543,8 @@ class TemplateEmbedder(nn.Module):
super
(
TemplateEmbedder
,
self
).
__init__
()
super
(
TemplateEmbedder
,
self
).
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
template_
a
ngle_embedder
=
Template
A
ngleEmbedder
(
self
.
template_
si
ngle_embedder
=
Template
Si
ngleEmbedder
(
**
config
[
"template_
a
ngle_embedder"
],
**
config
[
"template_
si
ngle_embedder"
],
)
)
self
.
template_pair_embedder
=
TemplatePairEmbedder
(
self
.
template_pair_embedder
=
TemplatePairEmbedder
(
**
config
[
"template_pair_embedder"
],
**
config
[
"template_pair_embedder"
],
...
@@ -651,7 +651,7 @@ class TemplateEmbedder(nn.Module):
...
@@ -651,7 +651,7 @@ class TemplateEmbedder(nn.Module):
)
)
# [*, S_t, N, C_m]
# [*, S_t, N, C_m]
a
=
self
.
template_
a
ngle_embedder
(
template_angle_feat
)
a
=
self
.
template_
si
ngle_embedder
(
template_angle_feat
)
ret
[
"template_single_embedding"
]
=
a
ret
[
"template_single_embedding"
]
=
a
...
@@ -660,7 +660,7 @@ class TemplateEmbedder(nn.Module):
...
@@ -660,7 +660,7 @@ class TemplateEmbedder(nn.Module):
class
TemplatePairEmbedderMultimer
(
nn
.
Module
):
class
TemplatePairEmbedderMultimer
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
c_
z
:
int
,
c_
in
:
int
,
c_out
:
int
,
c_out
:
int
,
c_dgram
:
int
,
c_dgram
:
int
,
c_aatype
:
int
,
c_aatype
:
int
,
...
@@ -670,8 +670,8 @@ class TemplatePairEmbedderMultimer(nn.Module):
...
@@ -670,8 +670,8 @@ class TemplatePairEmbedderMultimer(nn.Module):
self
.
dgram_linear
=
Linear
(
c_dgram
,
c_out
,
init
=
'relu'
)
self
.
dgram_linear
=
Linear
(
c_dgram
,
c_out
,
init
=
'relu'
)
self
.
aatype_linear_1
=
Linear
(
c_aatype
,
c_out
,
init
=
'relu'
)
self
.
aatype_linear_1
=
Linear
(
c_aatype
,
c_out
,
init
=
'relu'
)
self
.
aatype_linear_2
=
Linear
(
c_aatype
,
c_out
,
init
=
'relu'
)
self
.
aatype_linear_2
=
Linear
(
c_aatype
,
c_out
,
init
=
'relu'
)
self
.
query_embedding_layer_norm
=
LayerNorm
(
c_
z
)
self
.
query_embedding_layer_norm
=
LayerNorm
(
c_
in
)
self
.
query_embedding_linear
=
Linear
(
c_
z
,
c_out
,
init
=
'relu'
)
self
.
query_embedding_linear
=
Linear
(
c_
in
,
c_out
,
init
=
'relu'
)
self
.
pseudo_beta_mask_linear
=
Linear
(
1
,
c_out
,
init
=
'relu'
)
self
.
pseudo_beta_mask_linear
=
Linear
(
1
,
c_out
,
init
=
'relu'
)
self
.
x_linear
=
Linear
(
1
,
c_out
,
init
=
'relu'
)
self
.
x_linear
=
Linear
(
1
,
c_out
,
init
=
'relu'
)
...
@@ -722,11 +722,11 @@ class TemplatePairEmbedderMultimer(nn.Module):
...
@@ -722,11 +722,11 @@ class TemplatePairEmbedderMultimer(nn.Module):
class
TemplateSingleEmbedderMultimer
(
nn
.
Module
):
class
TemplateSingleEmbedderMultimer
(
nn
.
Module
):
def
__init__
(
self
,
def
__init__
(
self
,
c_in
:
int
,
c_in
:
int
,
c_
m
:
int
,
c_
out
:
int
,
):
):
super
(
TemplateSingleEmbedderMultimer
,
self
).
__init__
()
super
(
TemplateSingleEmbedderMultimer
,
self
).
__init__
()
self
.
template_single_embedder
=
Linear
(
c_in
,
c_
m
)
self
.
template_single_embedder
=
Linear
(
c_in
,
c_
out
)
self
.
template_projector
=
Linear
(
c_
m
,
c_
m
)
self
.
template_projector
=
Linear
(
c_
out
,
c_
out
)
def
forward
(
self
,
def
forward
(
self
,
batch
,
batch
,
...
@@ -797,6 +797,7 @@ class TemplateEmbedderMultimer(nn.Module):
...
@@ -797,6 +797,7 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim
,
templ_dim
,
chunk_size
,
chunk_size
,
multichain_mask_2d
,
multichain_mask_2d
,
_mask_trans
=
True
,
use_lma
=
False
,
use_lma
=
False
,
inplace_safe
=
False
inplace_safe
=
False
):
):
...
@@ -869,7 +870,9 @@ class TemplateEmbedderMultimer(nn.Module):
...
@@ -869,7 +870,9 @@ class TemplateEmbedderMultimer(nn.Module):
template_embeds
[
"template_pair_embedding"
],
template_embeds
[
"template_pair_embedding"
],
padding_mask_2d
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
padding_mask_2d
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
_mask_trans
=
False
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
)
# [*, N, N, C_z]
# [*, N, N, C_z]
t
=
torch
.
sum
(
t
,
dim
=-
4
)
/
n_templ
t
=
torch
.
sum
(
t
,
dim
=-
4
)
/
n_templ
...
...
openfold/model/model.py
View file @
82bda2d6
...
@@ -139,7 +139,8 @@ class AlphaFold(nn.Module):
...
@@ -139,7 +139,8 @@ class AlphaFold(nn.Module):
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
multichain_mask_2d
=
multichain_mask_2d
,
use_lma
=
self
.
globals
.
use_lma
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
)
)
feats
[
"template_torsion_angles_mask"
]
=
(
feats
[
"template_torsion_angles_mask"
]
=
(
template_embeds
[
"template_mask"
]
template_embeds
[
"template_mask"
]
...
@@ -161,7 +162,8 @@ class AlphaFold(nn.Module):
...
@@ -161,7 +162,8 @@ class AlphaFold(nn.Module):
templ_dim
,
templ_dim
,
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
)
)
return
template_embeds
return
template_embeds
...
...
openfold/model/template.py
View file @
82bda2d6
...
@@ -552,7 +552,7 @@ def embed_templates_offload(
...
@@ -552,7 +552,7 @@ def embed_templates_offload(
)
)
# [*, N, C_m]
# [*, N, C_m]
a
=
model
.
template_
a
ngle_embedder
(
template_angle_feat
)
a
=
model
.
template_
si
ngle_embedder
(
template_angle_feat
)
ret
[
"template_single_embedding"
]
=
a
ret
[
"template_single_embedding"
]
=
a
...
@@ -663,7 +663,7 @@ def embed_templates_average(
...
@@ -663,7 +663,7 @@ def embed_templates_average(
)
)
# [*, N, C_m]
# [*, N, C_m]
a
=
model
.
template_
a
ngle_embedder
(
template_angle_feat
)
a
=
model
.
template_
si
ngle_embedder
(
template_angle_feat
)
ret
[
"template_single_embedding"
]
=
a
ret
[
"template_single_embedding"
]
=
a
...
...
openfold/utils/import_weights.py
View file @
82bda2d6
...
@@ -577,10 +577,10 @@ def generate_translation_dict(model, version, is_multimer=False):
...
@@ -577,10 +577,10 @@ def generate_translation_dict(model, version, is_multimer=False):
"attention"
:
AttentionParams
(
model
.
template_embedder
.
template_pointwise_att
.
mha
),
"attention"
:
AttentionParams
(
model
.
template_embedder
.
template_pointwise_att
.
mha
),
},
},
"template_single_embedding"
:
LinearParams
(
"template_single_embedding"
:
LinearParams
(
model
.
template_embedder
.
template_
a
ngle_embedder
.
linear_1
model
.
template_embedder
.
template_
si
ngle_embedder
.
linear_1
),
),
"template_projection"
:
LinearParams
(
"template_projection"
:
LinearParams
(
model
.
template_embedder
.
template_
a
ngle_embedder
.
linear_2
model
.
template_embedder
.
template_
si
ngle_embedder
.
linear_2
),
),
}
}
else
:
else
:
...
...
openfold/utils/loss.py
View file @
82bda2d6
...
@@ -1668,11 +1668,8 @@ def chain_center_of_mass_loss(
...
@@ -1668,11 +1668,8 @@ def chain_center_of_mass_loss(
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
chains
=
asym_id
.
unique
()
# Reduce asym_id by one because class values must be smaller than num_classes and asym_ids start at 1
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
.
long
()).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
.
long
()
-
1
,
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
one_hot
*
all_atom_mask
one_hot
=
one_hot
*
all_atom_mask
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
...
...
tests/config.py
View file @
82bda2d6
...
@@ -19,7 +19,8 @@ consts = mlc.ConfigDict(
...
@@ -19,7 +19,8 @@ consts = mlc.ConfigDict(
"c_s"
:
384
,
"c_s"
:
384
,
"c_t"
:
64
,
"c_t"
:
64
,
"c_e"
:
64
,
"c_e"
:
64
,
"msa_logits"
:
22
# monomer: 23, multimer: 22
"msa_logits"
:
22
,
# monomer: 23, multimer: 22
"template_mmcif_dir"
:
None
# Set for test_multimer_datamodule
}
}
)
)
...
...
tests/data_utils.py
View file @
82bda2d6
...
@@ -40,7 +40,7 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
...
@@ -40,7 +40,7 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
final_idx
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
final_idx
])
return
np
.
array
(
asym_ids
).
astype
(
np
.
int64
)
return
np
.
array
(
asym_ids
).
astype
(
np
.
float32
)
+
1
def
random_template_feats
(
n_templ
,
n
,
batch_size
=
None
):
def
random_template_feats
(
n_templ
,
n
,
batch_size
=
None
):
...
...
tests/test_embedders.py
View file @
82bda2d6
...
@@ -21,7 +21,7 @@ from openfold.model.embedders import (
...
@@ -21,7 +21,7 @@ from openfold.model.embedders import (
InputEmbedder
,
InputEmbedder
,
InputEmbedderMultimer
,
InputEmbedderMultimer
,
RecyclingEmbedder
,
RecyclingEmbedder
,
Template
A
ngleEmbedder
,
Template
Si
ngleEmbedder
,
TemplatePairEmbedder
TemplatePairEmbedder
)
)
...
@@ -96,7 +96,7 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
...
@@ -96,7 +96,7 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
n_templ
=
4
n_templ
=
4
n_res
=
256
n_res
=
256
tae
=
Template
A
ngleEmbedder
(
tae
=
Template
Si
ngleEmbedder
(
template_angle_dim
,
template_angle_dim
,
c_m
,
c_m
,
)
)
...
...
tests/test_multimer_datamodule.py
View file @
82bda2d6
...
@@ -12,24 +12,21 @@
...
@@ -12,24 +12,21 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
pathlib
import
Path
import
os
import
shutil
import
shutil
import
pickle
import
torch
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
functools
import
partial
import
unittest
import
unittest
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.config
import
model_config
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
OpenFoldMultimerDataModule
,
OpenFoldDataModule
from
openfold.data.data_modules
import
OpenFoldMultimerDataModule
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
tests.config
import
consts
from
tests.config
import
consts
import
logging
import
logging
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
import
os
@
unittest
.
skipIf
(
not
consts
.
is_multimer
or
consts
.
template_mmcif_dir
is
None
,
"Template mmcif dir required."
)
class
TestMultimerDataModule
(
unittest
.
TestCase
):
class
TestMultimerDataModule
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
"""
"""
...
@@ -38,14 +35,14 @@ class TestMultimerDataModule(unittest.TestCase):
...
@@ -38,14 +35,14 @@ class TestMultimerDataModule(unittest.TestCase):
use model_1_multimer_v3 for now
use model_1_multimer_v3 for now
"""
"""
self
.
config
=
model_config
(
self
.
config
=
model_config
(
"model_1_multimer_v3"
,
consts
.
model
,
train
=
True
,
train
=
True
,
low_prec
=
True
)
low_prec
=
True
)
self
.
data_module
=
OpenFoldMultimerDataModule
(
self
.
data_module
=
OpenFoldMultimerDataModule
(
config
=
self
.
config
.
data
,
config
=
self
.
config
.
data
,
batch_seed
=
42
,
batch_seed
=
42
,
train_epoch_len
=
100
,
train_epoch_len
=
100
,
template_mmcif_dir
=
"/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/"
,
template_mmcif_dir
=
consts
.
template_mmcif_dir
,
template_release_dates_cache_path
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data/mmcif_cache.json"
),
template_release_dates_cache_path
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data/mmcif_cache.json"
),
max_template_date
=
"2500-01-01"
,
max_template_date
=
"2500-01-01"
,
train_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data/mmcifs"
),
train_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data/mmcifs"
),
...
...
tests/test_template.py
View file @
82bda2d6
...
@@ -263,6 +263,7 @@ class Template(unittest.TestCase):
...
@@ -263,6 +263,7 @@ class Template(unittest.TestCase):
templ_dim
=
0
,
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
chunk_size
=
consts
.
chunk_size
,
multichain_mask_2d
=
torch
.
as_tensor
(
multichain_mask_2d
).
cuda
(),
multichain_mask_2d
=
torch
.
as_tensor
(
multichain_mask_2d
).
cuda
(),
_mask_trans
=
False
,
use_lma
=
False
,
use_lma
=
False
,
inplace_safe
=
False
inplace_safe
=
False
)
)
...
@@ -273,6 +274,7 @@ class Template(unittest.TestCase):
...
@@ -273,6 +274,7 @@ class Template(unittest.TestCase):
torch
.
as_tensor
(
pair_mask
).
cuda
(),
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
chunk_size
=
consts
.
chunk_size
,
mask_trans
=
False
,
use_lma
=
False
,
use_lma
=
False
,
inplace_safe
=
False
inplace_safe
=
False
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment