Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
82bda2d6
Commit
82bda2d6
authored
Aug 03, 2023
by
Christina Floristean
Browse files
Refactored multimer config update
parent
30764cf9
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
119 additions
and
225 deletions
+119
-225
openfold/config.py
openfold/config.py
+82
-190
openfold/model/embedders.py
openfold/model/embedders.py
+15
-12
openfold/model/model.py
openfold/model/model.py
+4
-2
openfold/model/template.py
openfold/model/template.py
+2
-2
openfold/utils/import_weights.py
openfold/utils/import_weights.py
+2
-2
openfold/utils/loss.py
openfold/utils/loss.py
+1
-4
tests/config.py
tests/config.py
+2
-1
tests/data_utils.py
tests/data_utils.py
+1
-1
tests/test_embedders.py
tests/test_embedders.py
+2
-2
tests/test_multimer_datamodule.py
tests/test_multimer_datamodule.py
+6
-9
tests/test_template.py
tests/test_template.py
+2
-0
No files found.
openfold/config.py
View file @
82bda2d6
...
...
@@ -154,50 +154,37 @@ def model_config(
c
.
model
.
heads
.
tm
.
enabled
=
True
c
.
loss
.
tm
.
weight
=
0.1
elif
"multimer"
in
name
:
c
.
globals
.
is_multimer
=
True
c
.
globals
.
bfloat16
=
False
c
.
globals
.
bfloat16_output
=
False
c
.
loss
.
masked_msa
.
num_classes
=
22
c
.
data
.
common
.
max_recycling_iters
=
20
for
k
,
v
in
multimer_model_config_update
[
'model'
].
items
():
c
.
model
[
k
]
=
v
for
k
,
v
in
multimer_model_config_update
[
'loss'
].
items
():
c
.
loss
[
k
]
=
v
c
.
update
(
multimer_config_update
.
copy_and_resolve_references
())
del
c
.
model
.
template
.
template_pointwise_attention
del
c
.
loss
.
fape
.
backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if
re
.
fullmatch
(
"^model_[1-5]_multimer(_v2)?$"
,
name
):
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
crop_size
=
384
c
.
data
.
train
.
max_msa_clusters
=
252
c
.
data
.
eval
.
max_msa_clusters
=
252
c
.
data
.
predict
.
max_msa_clusters
=
252
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
eval
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
c
.
model
.
evoformer_stack
.
fuse_projection_weights
=
False
c
.
model
.
extra_msa
.
extra_msa_stack
.
fuse_projection_weights
=
False
c
.
model
.
template
.
template_pair_stack
.
fuse_projection_weights
=
False
elif
name
==
'model_4_multimer_v3'
:
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
eval
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
elif
name
==
'model_5_multimer_v3'
:
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c
.
data
.
train
.
max_extra_msa
=
1152
c
.
data
.
eval
.
max_extra_msa
=
1152
c
.
data
.
predict
.
max_extra_msa
=
1152
else
:
c
.
data
.
train
.
max_msa_clusters
=
508
c
.
data
.
predict
.
max_msa_clusters
=
508
c
.
data
.
train
.
max_extra_msa
=
2048
c
.
data
.
predict
.
max_extra_msa
=
2048
c
.
data
.
common
.
unsupervised_features
.
extend
([
"msa_mask"
,
"seq_mask"
,
"asym_id"
,
"entity_id"
,
"sym_id"
,
])
else
:
raise
ValueError
(
"Invalid model name"
)
...
...
@@ -451,7 +438,7 @@ config = mlc.ConfigDict(
"max_bin"
:
50.75
,
"no_bins"
:
39
,
},
"template_
a
ngle_embedder"
:
{
"template_
si
ngle_embedder"
:
{
# DISCREPANCY: c_in is supposed to be 51.
"c_in"
:
57
,
"c_out"
:
c_m
,
...
...
@@ -682,226 +669,131 @@ config = mlc.ConfigDict(
}
)
multimer_model_config_update
=
{
'model'
:
{
multimer_config_update
=
mlc
.
ConfigDict
({
"globals"
:
{
"is_multimer"
:
True
,
"bfloat16"
:
False
,
# TODO: Change to True when implemented
"bfloat16_output"
:
False
},
"data"
:
{
"common"
:
{
"max_recycling_iters"
:
20
,
"unsupervised_features"
:
[
"aatype"
,
"residue_index"
,
"msa"
,
"num_alignments"
,
"seq_length"
,
"between_segment_residues"
,
"deletion_matrix"
,
"no_recycling_iters"
,
# Additional multimer features
"msa_mask"
,
"seq_mask"
,
"asym_id"
,
"entity_id"
,
"sym_id"
,
]
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
"predict"
:
{
"max_msa_clusters"
:
508
,
"max_extra_msa"
:
2048
},
"eval"
:
{
"max_msa_clusters"
:
508
,
"max_extra_msa"
:
2048
},
"train"
:
{
"max_msa_clusters"
:
508
,
"max_extra_msa"
:
2048
,
"crop_size"
:
640
},
},
"model"
:
{
"input_embedder"
:
{
"tf_dim"
:
21
,
"msa_dim"
:
49
,
#"num_msa": 508,
"c_z"
:
c_z
,
"c_m"
:
c_m
,
"relpos_k"
:
32
,
"max_relative_chain"
:
2
,
"max_relative_idx"
:
32
,
"use_chain_relative"
:
True
,
"use_chain_relative"
:
True
},
"template"
:
{
"distogram"
:
{
"min_bin"
:
3.25
,
"max_bin"
:
50.75
,
"no_bins"
:
39
,
"template_single_embedder"
:
{
"c_in"
:
34
,
"c_out"
:
c_m
},
"template_pair_embedder"
:
{
"c_
z
"
:
c_z
,
"c_out"
:
64
,
"c_
in
"
:
c_z
,
"c_out"
:
c_t
,
"c_dgram"
:
39
,
"c_aatype"
:
22
,
},
"template_single_embedder"
:
{
"c_in"
:
34
,
"c_m"
:
c_m
,
"c_aatype"
:
22
},
"template_pair_stack"
:
{
"c_t"
:
c_t
,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att"
:
16
,
"c_hidden_tri_mul"
:
64
,
"no_blocks"
:
2
,
"no_heads"
:
4
,
"pair_transition_n"
:
2
,
"dropout_rate"
:
0.25
,
"tri_mul_first"
:
True
,
"fuse_projection_weights"
:
True
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"inf"
:
1e9
,
"fuse_projection_weights"
:
True
},
"c_t"
:
c_t
,
"c_z"
:
c_z
,
"inf"
:
1e5
,
# 1e9,
"eps"
:
eps
,
# 1e-6,
"enabled"
:
templates_enabled
,
"embed_angles"
:
embed_template_torsion_angles
,
"use_unit_vector"
:
True
},
"extra_msa"
:
{
"extra_msa_embedder"
:
{
"c_in"
:
25
,
"c_out"
:
c_e
,
#"num_extra_msa": 2048
},
# "extra_msa_embedder": {
# "num_extra_msa": 2048
# },
"extra_msa_stack"
:
{
"c_m"
:
c_e
,
"c_z"
:
c_z
,
"c_hidden_msa_att"
:
8
,
"c_hidden_opm"
:
32
,
"c_hidden_mul"
:
128
,
"c_hidden_pair_att"
:
32
,
"no_heads_msa"
:
8
,
"no_heads_pair"
:
4
,
"no_blocks"
:
4
,
"transition_n"
:
4
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"opm_first"
:
True
,
"fuse_projection_weights"
:
True
,
"clear_cache_between_blocks"
:
True
,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
"ckpt"
:
blocks_per_ckpt
is
not
None
,
},
"enabled"
:
True
,
"fuse_projection_weights"
:
True
}
},
"evoformer_stack"
:
{
"c_m"
:
c_m
,
"c_z"
:
c_z
,
"c_hidden_msa_att"
:
32
,
"c_hidden_opm"
:
32
,
"c_hidden_mul"
:
128
,
"c_hidden_pair_att"
:
32
,
"c_s"
:
c_s
,
"no_heads_msa"
:
8
,
"no_heads_pair"
:
4
,
"no_blocks"
:
48
,
"transition_n"
:
4
,
"msa_dropout"
:
0.15
,
"pair_dropout"
:
0.25
,
"opm_first"
:
True
,
"fuse_projection_weights"
:
True
,
"blocks_per_ckpt"
:
blocks_per_ckpt
,
"clear_cache_between_blocks"
:
False
,
"inf"
:
1e9
,
"eps"
:
eps
,
# 1e-10,
"fuse_projection_weights"
:
True
},
"structure_module"
:
{
"c_s"
:
c_s
,
"c_z"
:
c_z
,
"c_ipa"
:
16
,
"c_resnet"
:
128
,
"no_heads_ipa"
:
12
,
"no_qk_points"
:
4
,
"no_v_points"
:
8
,
"dropout_rate"
:
0.1
,
"no_blocks"
:
8
,
"no_transition_layers"
:
1
,
"no_resnet_blocks"
:
2
,
"no_angles"
:
7
,
"trans_scale_factor"
:
20
,
"epsilon"
:
eps
,
# 1e-12,
"inf"
:
1e5
,
"trans_scale_factor"
:
20
},
"heads"
:
{
"lddt"
:
{
"no_bins"
:
50
,
"c_in"
:
c_s
,
"c_hidden"
:
128
,
},
"distogram"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
},
"tm"
:
{
"c_z"
:
c_z
,
"no_bins"
:
aux_distogram_bins
,
"ptm_weight"
:
0.2
,
"iptm_weight"
:
0.8
,
"enabled"
:
True
,
"enabled"
:
True
},
"masked_msa"
:
{
"c_m"
:
c_m
,
"c_out"
:
22
,
},
"experimentally_resolved"
:
{
"c_s"
:
c_s
,
"c_out"
:
37
,
"c_out"
:
22
},
},
"recycle_early_stop_tolerance"
:
0.5
},
"loss"
:
{
"distogram"
:
{
"min_bin"
:
2.3125
,
"max_bin"
:
21.6875
,
"no_bins"
:
64
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.3
,
},
"experimentally_resolved"
:
{
"eps"
:
eps
,
# 1e-8,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"weight"
:
0.0
,
},
"fape"
:
{
"intra_chain_backbone"
:
{
"clamp_distance"
:
10.0
,
"loss_unit_distance"
:
10.0
,
"weight"
:
0.5
,
"weight"
:
0.5
},
"interface_backbone"
:
{
"clamp_distance"
:
30.0
,
"loss_unit_distance"
:
20.0
,
"weight"
:
0.5
,
},
"sidechain"
:
{
"clamp_distance"
:
10.0
,
"length_scale"
:
10.0
,
"weight"
:
0.5
,
},
"eps"
:
1e-4
,
"weight"
:
1.0
,
},
"plddt_loss"
:
{
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"cutoff"
:
15.0
,
"no_bins"
:
50
,
"eps"
:
eps
,
# 1e-10,
"weight"
:
0.01
,
"weight"
:
0.5
}
},
"masked_msa"
:
{
"num_classes"
:
23
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
2.0
,
},
"supervised_chi"
:
{
"chi_weight"
:
0.5
,
"angle_norm_weight"
:
0.01
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
1.0
,
"num_classes"
:
22
},
"violation"
:
{
"violation_tolerance_factor"
:
12.0
,
"clash_overlap_tolerance"
:
1.5
,
"average_clashes"
:
True
,
"eps"
:
eps
,
# 1e-6,
"weight"
:
0.03
,
# Not finetuning
"weight"
:
0.03
# Not finetuning
},
"tm"
:
{
"max_bin"
:
31
,
"no_bins"
:
64
,
"min_resolution"
:
0.1
,
"max_resolution"
:
3.0
,
"eps"
:
eps
,
# 1e-8,
"weight"
:
0.1
,
"enabled"
:
True
,
"enabled"
:
True
},
"chain_center_of_mass"
:
{
"clamp_distance"
:
-
4.0
,
"weight"
:
0.05
,
"eps"
:
eps
,
"enabled"
:
True
,
},
"eps"
:
eps
,
"enabled"
:
True
}
}
}
}
)
openfold/model/embedders.py
View file @
82bda2d6
...
...
@@ -412,7 +412,7 @@ class RecyclingEmbedder(nn.Module):
return
m_update
,
z_update
class
Template
A
ngleEmbedder
(
nn
.
Module
):
class
Template
Si
ngleEmbedder
(
nn
.
Module
):
"""
Embeds the "template_angle_feat" feature.
...
...
@@ -432,7 +432,7 @@ class TemplateAngleEmbedder(nn.Module):
c_out:
Output channel dimension
"""
super
(
Template
A
ngleEmbedder
,
self
).
__init__
()
super
(
Template
Si
ngleEmbedder
,
self
).
__init__
()
self
.
c_out
=
c_out
self
.
c_in
=
c_in
...
...
@@ -543,8 +543,8 @@ class TemplateEmbedder(nn.Module):
super
(
TemplateEmbedder
,
self
).
__init__
()
self
.
config
=
config
self
.
template_
a
ngle_embedder
=
Template
A
ngleEmbedder
(
**
config
[
"template_
a
ngle_embedder"
],
self
.
template_
si
ngle_embedder
=
Template
Si
ngleEmbedder
(
**
config
[
"template_
si
ngle_embedder"
],
)
self
.
template_pair_embedder
=
TemplatePairEmbedder
(
**
config
[
"template_pair_embedder"
],
...
...
@@ -651,7 +651,7 @@ class TemplateEmbedder(nn.Module):
)
# [*, S_t, N, C_m]
a
=
self
.
template_
a
ngle_embedder
(
template_angle_feat
)
a
=
self
.
template_
si
ngle_embedder
(
template_angle_feat
)
ret
[
"template_single_embedding"
]
=
a
...
...
@@ -660,7 +660,7 @@ class TemplateEmbedder(nn.Module):
class
TemplatePairEmbedderMultimer
(
nn
.
Module
):
def
__init__
(
self
,
c_
z
:
int
,
c_
in
:
int
,
c_out
:
int
,
c_dgram
:
int
,
c_aatype
:
int
,
...
...
@@ -670,8 +670,8 @@ class TemplatePairEmbedderMultimer(nn.Module):
self
.
dgram_linear
=
Linear
(
c_dgram
,
c_out
,
init
=
'relu'
)
self
.
aatype_linear_1
=
Linear
(
c_aatype
,
c_out
,
init
=
'relu'
)
self
.
aatype_linear_2
=
Linear
(
c_aatype
,
c_out
,
init
=
'relu'
)
self
.
query_embedding_layer_norm
=
LayerNorm
(
c_
z
)
self
.
query_embedding_linear
=
Linear
(
c_
z
,
c_out
,
init
=
'relu'
)
self
.
query_embedding_layer_norm
=
LayerNorm
(
c_
in
)
self
.
query_embedding_linear
=
Linear
(
c_
in
,
c_out
,
init
=
'relu'
)
self
.
pseudo_beta_mask_linear
=
Linear
(
1
,
c_out
,
init
=
'relu'
)
self
.
x_linear
=
Linear
(
1
,
c_out
,
init
=
'relu'
)
...
...
@@ -722,11 +722,11 @@ class TemplatePairEmbedderMultimer(nn.Module):
class
TemplateSingleEmbedderMultimer
(
nn
.
Module
):
def
__init__
(
self
,
c_in
:
int
,
c_
m
:
int
,
c_
out
:
int
,
):
super
(
TemplateSingleEmbedderMultimer
,
self
).
__init__
()
self
.
template_single_embedder
=
Linear
(
c_in
,
c_
m
)
self
.
template_projector
=
Linear
(
c_
m
,
c_
m
)
self
.
template_single_embedder
=
Linear
(
c_in
,
c_
out
)
self
.
template_projector
=
Linear
(
c_
out
,
c_
out
)
def
forward
(
self
,
batch
,
...
...
@@ -797,6 +797,7 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim
,
chunk_size
,
multichain_mask_2d
,
_mask_trans
=
True
,
use_lma
=
False
,
inplace_safe
=
False
):
...
...
@@ -869,7 +870,9 @@ class TemplateEmbedderMultimer(nn.Module):
template_embeds
[
"template_pair_embedding"
],
padding_mask_2d
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
chunk_size
,
_mask_trans
=
False
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
# [*, N, N, C_z]
t
=
torch
.
sum
(
t
,
dim
=-
4
)
/
n_templ
...
...
openfold/model/model.py
View file @
82bda2d6
...
...
@@ -139,7 +139,8 @@ class AlphaFold(nn.Module):
chunk_size
=
self
.
globals
.
chunk_size
,
multichain_mask_2d
=
multichain_mask_2d
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
)
feats
[
"template_torsion_angles_mask"
]
=
(
template_embeds
[
"template_mask"
]
...
...
@@ -161,7 +162,8 @@ class AlphaFold(nn.Module):
templ_dim
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
)
return
template_embeds
...
...
openfold/model/template.py
View file @
82bda2d6
...
...
@@ -552,7 +552,7 @@ def embed_templates_offload(
)
# [*, N, C_m]
a
=
model
.
template_
a
ngle_embedder
(
template_angle_feat
)
a
=
model
.
template_
si
ngle_embedder
(
template_angle_feat
)
ret
[
"template_single_embedding"
]
=
a
...
...
@@ -663,7 +663,7 @@ def embed_templates_average(
)
# [*, N, C_m]
a
=
model
.
template_
a
ngle_embedder
(
template_angle_feat
)
a
=
model
.
template_
si
ngle_embedder
(
template_angle_feat
)
ret
[
"template_single_embedding"
]
=
a
...
...
openfold/utils/import_weights.py
View file @
82bda2d6
...
...
@@ -577,10 +577,10 @@ def generate_translation_dict(model, version, is_multimer=False):
"attention"
:
AttentionParams
(
model
.
template_embedder
.
template_pointwise_att
.
mha
),
},
"template_single_embedding"
:
LinearParams
(
model
.
template_embedder
.
template_
a
ngle_embedder
.
linear_1
model
.
template_embedder
.
template_
si
ngle_embedder
.
linear_1
),
"template_projection"
:
LinearParams
(
model
.
template_embedder
.
template_
a
ngle_embedder
.
linear_2
model
.
template_embedder
.
template_
si
ngle_embedder
.
linear_2
),
}
else
:
...
...
openfold/utils/loss.py
View file @
82bda2d6
...
...
@@ -1668,11 +1668,8 @@ def chain_center_of_mass_loss(
all_atom_pred_pos
=
all_atom_pred_pos
[...,
ca_pos
,
:]
all_atom_positions
=
all_atom_positions
[...,
ca_pos
,
:]
all_atom_mask
=
all_atom_mask
[...,
ca_pos
:
(
ca_pos
+
1
)]
# keep dim
chains
=
asym_id
.
unique
()
# Reduce asym_id by one because class values must be smaller than num_classes and asym_ids start at 1
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
.
long
()
-
1
,
num_classes
=
chains
.
shape
[
0
]).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
torch
.
nn
.
functional
.
one_hot
(
asym_id
.
long
()).
to
(
dtype
=
all_atom_mask
.
dtype
)
one_hot
=
one_hot
*
all_atom_mask
chain_pos_mask
=
one_hot
.
transpose
(
-
2
,
-
1
)
chain_exists
=
torch
.
any
(
chain_pos_mask
,
dim
=-
1
).
float
()
...
...
tests/config.py
View file @
82bda2d6
...
...
@@ -19,7 +19,8 @@ consts = mlc.ConfigDict(
"c_s"
:
384
,
"c_t"
:
64
,
"c_e"
:
64
,
"msa_logits"
:
22
# monomer: 23, multimer: 22
"msa_logits"
:
22
,
# monomer: 23, multimer: 22
"template_mmcif_dir"
:
None
# Set for test_multimer_datamodule
}
)
...
...
tests/data_utils.py
View file @
82bda2d6
...
...
@@ -40,7 +40,7 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
asym_ids
.
extend
(
piece
*
[
idx
])
asym_ids
.
extend
((
n_res
-
sum
(
pieces
))
*
[
final_idx
])
return
np
.
array
(
asym_ids
).
astype
(
np
.
int64
)
return
np
.
array
(
asym_ids
).
astype
(
np
.
float32
)
+
1
def
random_template_feats
(
n_templ
,
n
,
batch_size
=
None
):
...
...
tests/test_embedders.py
View file @
82bda2d6
...
...
@@ -21,7 +21,7 @@ from openfold.model.embedders import (
InputEmbedder
,
InputEmbedderMultimer
,
RecyclingEmbedder
,
Template
A
ngleEmbedder
,
Template
Si
ngleEmbedder
,
TemplatePairEmbedder
)
...
...
@@ -96,7 +96,7 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
n_templ
=
4
n_res
=
256
tae
=
Template
A
ngleEmbedder
(
tae
=
Template
Si
ngleEmbedder
(
template_angle_dim
,
c_m
,
)
...
...
tests/test_multimer_datamodule.py
View file @
82bda2d6
...
...
@@ -12,24 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from
pathlib
import
Path
import
os
import
shutil
import
pickle
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
functools
import
partial
import
unittest
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
OpenFoldMultimerDataModule
,
OpenFoldDataModule
from
openfold.data.data_modules
import
OpenFoldMultimerDataModule
from
openfold.model.model
import
AlphaFold
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
from
tests.config
import
consts
import
logging
logger
=
logging
.
getLogger
(
__name__
)
import
os
@
unittest
.
skipIf
(
not
consts
.
is_multimer
or
consts
.
template_mmcif_dir
is
None
,
"Template mmcif dir required."
)
class
TestMultimerDataModule
(
unittest
.
TestCase
):
def
setUp
(
self
):
"""
...
...
@@ -38,14 +35,14 @@ class TestMultimerDataModule(unittest.TestCase):
use model_1_multimer_v3 for now
"""
self
.
config
=
model_config
(
"model_1_multimer_v3"
,
consts
.
model
,
train
=
True
,
low_prec
=
True
)
self
.
data_module
=
OpenFoldMultimerDataModule
(
config
=
self
.
config
.
data
,
batch_seed
=
42
,
train_epoch_len
=
100
,
template_mmcif_dir
=
"/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/"
,
template_mmcif_dir
=
consts
.
template_mmcif_dir
,
template_release_dates_cache_path
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data/mmcif_cache.json"
),
max_template_date
=
"2500-01-01"
,
train_data_dir
=
os
.
path
.
join
(
os
.
getcwd
(),
"tests/test_data/mmcifs"
),
...
...
tests/test_template.py
View file @
82bda2d6
...
...
@@ -263,6 +263,7 @@ class Template(unittest.TestCase):
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
multichain_mask_2d
=
torch
.
as_tensor
(
multichain_mask_2d
).
cuda
(),
_mask_trans
=
False
,
use_lma
=
False
,
inplace_safe
=
False
)
...
...
@@ -273,6 +274,7 @@ class Template(unittest.TestCase):
torch
.
as_tensor
(
pair_mask
).
cuda
(),
templ_dim
=
0
,
chunk_size
=
consts
.
chunk_size
,
mask_trans
=
False
,
use_lma
=
False
,
inplace_safe
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment