"...src/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "9f53922a9b4ef33e74367cc466384c98e4504ad7"
Commit 82bda2d6 authored by Christina Floristean

Refactored multimer config update

parent 30764cf9
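This commit collapses the field-by-field multimer overrides in `model_config` into a single `ml_collections` ConfigDict, `multimer_config_update`, merged with one `update()` call. A minimal sketch of the pattern, with illustrative field names rather than the real override set:

```python
import ml_collections as mlc

base = mlc.ConfigDict({
    "globals": {"is_multimer": False},
    "loss": {
        "masked_msa": {"num_classes": 23},
        "fape": {"backbone": {"weight": 0.5}},
    },
})

overrides = mlc.ConfigDict({
    "globals": {"is_multimer": True},
    "loss": {"masked_msa": {"num_classes": 22}},
})

# update() merges recursively into nested ConfigDicts, and
# copy_and_resolve_references() first freezes any FieldReference values
# so the merged config does not alias the override object.
base.update(overrides.copy_and_resolve_references())

# Monomer-only keys are then dropped outright, as the commit does for
# template_pointwise_attention and loss.fape.backbone:
del base.loss.fape.backbone
assert base.globals.is_multimer and base.loss.masked_msa.num_classes == 22
```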
@@ -154,50 +154,37 @@ def model_config(
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
     elif "multimer" in name:
-        c.globals.is_multimer = True
-        c.globals.bfloat16 = False
-        c.globals.bfloat16_output = False
-        c.loss.masked_msa.num_classes = 22
-        c.data.common.max_recycling_iters = 20
-        for k, v in multimer_model_config_update['model'].items():
-            c.model[k] = v
-        for k, v in multimer_model_config_update['loss'].items():
-            c.loss[k] = v
+        c.update(multimer_config_update.copy_and_resolve_references())
+        del c.model.template.template_pointwise_attention
+        del c.loss.fape.backbone
         # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
         if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
             #c.model.input_embedder.num_msa = 252
             #c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
-            c.data.train.crop_size = 384
             c.data.train.max_msa_clusters = 252
-            c.data.eval.max_msa_clusters = 252
             c.data.predict.max_msa_clusters = 252
             c.data.train.max_extra_msa = 1152
-            c.data.eval.max_extra_msa = 1152
             c.data.predict.max_extra_msa = 1152
             c.model.evoformer_stack.fuse_projection_weights = False
             c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
             c.model.template.template_pair_stack.fuse_projection_weights = False
         elif name == 'model_4_multimer_v3':
             #c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
             c.data.train.max_extra_msa = 1152
-            c.data.eval.max_extra_msa = 1152
             c.data.predict.max_extra_msa = 1152
         elif name == 'model_5_multimer_v3':
             #c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
             c.data.train.max_extra_msa = 1152
-            c.data.eval.max_extra_msa = 1152
             c.data.predict.max_extra_msa = 1152
-        else:
-            c.data.train.max_msa_clusters = 508
-            c.data.predict.max_msa_clusters = 508
-            c.data.train.max_extra_msa = 2048
-            c.data.predict.max_extra_msa = 2048
-        c.data.common.unsupervised_features.extend([
-            "msa_mask",
-            "seq_mask",
-            "asym_id",
-            "entity_id",
-            "sym_id",
-        ])
     else:
         raise ValueError("Invalid model name")
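For reference, the `re.fullmatch` gate above only catches the v1/v2 multimer presets; the `^`/`$` anchors are redundant under `fullmatch` but harmless. A quick check:

```python
import re

pattern = "^model_[1-5]_multimer(_v2)?$"
for name in ["model_1_multimer", "model_3_multimer_v2",
             "model_1_multimer_v3", "model_4_multimer_v3"]:
    # Prints True, True, False, False: the v3 names fall through to the
    # explicit elif branches, or keep the 508/2048 MSA sizes now supplied
    # by multimer_config_update.
    print(name, bool(re.fullmatch(pattern, name)))
```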
@@ -451,7 +438,7 @@ config = mlc.ConfigDict(
                 "max_bin": 50.75,
                 "no_bins": 39,
             },
-            "template_angle_embedder": {
+            "template_single_embedder": {
                 # DISCREPANCY: c_in is supposed to be 51.
                 "c_in": 57,
                 "c_out": c_m,
@@ -682,226 +669,131 @@ config = mlc.ConfigDict(
     }
 )
-multimer_model_config_update = {
-    'model': {
+multimer_config_update = mlc.ConfigDict({
+    "globals": {
+        "is_multimer": True,
+        "bfloat16": False, # TODO: Change to True when implemented
+        "bfloat16_output": False
+    },
+    "data": {
+        "common": {
+            "max_recycling_iters": 20,
+            "unsupervised_features": [
+                "aatype",
+                "residue_index",
+                "msa",
+                "num_alignments",
+                "seq_length",
+                "between_segment_residues",
+                "deletion_matrix",
+                "no_recycling_iters",
+                # Additional multimer features
+                "msa_mask",
+                "seq_mask",
+                "asym_id",
+                "entity_id",
+                "sym_id",
+            ]
+        },
+        # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
+        # c.model.input_embedder.num_msa = 508
+        # c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
+        "predict": {
+            "max_msa_clusters": 508,
+            "max_extra_msa": 2048
+        },
+        "eval": {
+            "max_msa_clusters": 508,
+            "max_extra_msa": 2048
+        },
+        "train": {
+            "max_msa_clusters": 508,
+            "max_extra_msa": 2048,
+            "crop_size": 640
+        },
+    },
+    "model": {
         "input_embedder": {
             "tf_dim": 21,
-            "msa_dim": 49,
             #"num_msa": 508,
-            "c_z": c_z,
-            "c_m": c_m,
-            "relpos_k": 32,
             "max_relative_chain": 2,
             "max_relative_idx": 32,
-            "use_chain_relative": True,
+            "use_chain_relative": True
         },
         "template": {
-            "distogram": {
-                "min_bin": 3.25,
-                "max_bin": 50.75,
-                "no_bins": 39,
+            "template_single_embedder": {
+                "c_in": 34,
+                "c_out": c_m
             },
             "template_pair_embedder": {
-                "c_z": c_z,
-                "c_out": 64,
+                "c_in": c_z,
+                "c_out": c_t,
                 "c_dgram": 39,
-                "c_aatype": 22,
-            },
-            "template_single_embedder": {
-                "c_in": 34,
-                "c_m": c_m,
+                "c_aatype": 22
             },
             "template_pair_stack": {
-                "c_t": c_t,
-                # DISCREPANCY: c_hidden_tri_att here is given in the supplement
-                # as 64. In the code, it's 16.
-                "c_hidden_tri_att": 16,
-                "c_hidden_tri_mul": 64,
-                "no_blocks": 2,
-                "no_heads": 4,
-                "pair_transition_n": 2,
-                "dropout_rate": 0.25,
                 "tri_mul_first": True,
-                "fuse_projection_weights": True,
-                "blocks_per_ckpt": blocks_per_ckpt,
-                "inf": 1e9,
+                "fuse_projection_weights": True
             },
             "c_t": c_t,
             "c_z": c_z,
-            "inf": 1e5, # 1e9,
-            "eps": eps, # 1e-6,
-            "enabled": templates_enabled,
-            "embed_angles": embed_template_torsion_angles,
             "use_unit_vector": True
         },
         "extra_msa": {
-            "extra_msa_embedder": {
-                "c_in": 25,
-                "c_out": c_e,
-                #"num_extra_msa": 2048
-            },
+            # "extra_msa_embedder": {
+            # "num_extra_msa": 2048
+            # },
             "extra_msa_stack": {
-                "c_m": c_e,
-                "c_z": c_z,
-                "c_hidden_msa_att": 8,
-                "c_hidden_opm": 32,
-                "c_hidden_mul": 128,
-                "c_hidden_pair_att": 32,
-                "no_heads_msa": 8,
-                "no_heads_pair": 4,
-                "no_blocks": 4,
-                "transition_n": 4,
-                "msa_dropout": 0.15,
-                "pair_dropout": 0.25,
                 "opm_first": True,
-                "fuse_projection_weights": True,
-                "clear_cache_between_blocks": True,
-                "inf": 1e9,
-                "eps": eps, # 1e-10,
-                "ckpt": blocks_per_ckpt is not None,
-            },
-            "enabled": True,
+                "fuse_projection_weights": True
+            }
         },
         "evoformer_stack": {
-            "c_m": c_m,
-            "c_z": c_z,
-            "c_hidden_msa_att": 32,
-            "c_hidden_opm": 32,
-            "c_hidden_mul": 128,
-            "c_hidden_pair_att": 32,
-            "c_s": c_s,
-            "no_heads_msa": 8,
-            "no_heads_pair": 4,
-            "no_blocks": 48,
-            "transition_n": 4,
-            "msa_dropout": 0.15,
-            "pair_dropout": 0.25,
            "opm_first": True,
-            "fuse_projection_weights": True,
-            "blocks_per_ckpt": blocks_per_ckpt,
-            "clear_cache_between_blocks": False,
-            "inf": 1e9,
-            "eps": eps, # 1e-10,
+            "fuse_projection_weights": True
         },
         "structure_module": {
-            "c_s": c_s,
-            "c_z": c_z,
-            "c_ipa": 16,
-            "c_resnet": 128,
-            "no_heads_ipa": 12,
-            "no_qk_points": 4,
-            "no_v_points": 8,
-            "dropout_rate": 0.1,
-            "no_blocks": 8,
-            "no_transition_layers": 1,
-            "no_resnet_blocks": 2,
-            "no_angles": 7,
-            "trans_scale_factor": 20,
-            "epsilon": eps, # 1e-12,
-            "inf": 1e5,
+            "trans_scale_factor": 20
         },
         "heads": {
-            "lddt": {
-                "no_bins": 50,
-                "c_in": c_s,
-                "c_hidden": 128,
-            },
-            "distogram": {
-                "c_z": c_z,
-                "no_bins": aux_distogram_bins,
-            },
             "tm": {
-                "c_z": c_z,
-                "no_bins": aux_distogram_bins,
                 "ptm_weight": 0.2,
                 "iptm_weight": 0.8,
-                "enabled": True,
+                "enabled": True
             },
             "masked_msa": {
-                "c_m": c_m,
-                "c_out": 22,
-            },
-            "experimentally_resolved": {
-                "c_s": c_s,
-                "c_out": 37,
+                "c_out": 22
             },
         },
         "recycle_early_stop_tolerance": 0.5
     },
     "loss": {
-        "distogram": {
-            "min_bin": 2.3125,
-            "max_bin": 21.6875,
-            "no_bins": 64,
-            "eps": eps, # 1e-6,
-            "weight": 0.3,
-        },
-        "experimentally_resolved": {
-            "eps": eps, # 1e-8,
-            "min_resolution": 0.1,
-            "max_resolution": 3.0,
-            "weight": 0.0,
-        },
         "fape": {
             "intra_chain_backbone": {
                 "clamp_distance": 10.0,
                 "loss_unit_distance": 10.0,
-                "weight": 0.5,
+                "weight": 0.5
             },
             "interface_backbone": {
                 "clamp_distance": 30.0,
                 "loss_unit_distance": 20.0,
-                "weight": 0.5,
-            },
-            "sidechain": {
-                "clamp_distance": 10.0,
-                "length_scale": 10.0,
-                "weight": 0.5,
-            },
-            "eps": 1e-4,
-            "weight": 1.0,
-        },
-        "plddt_loss": {
-            "min_resolution": 0.1,
-            "max_resolution": 3.0,
-            "cutoff": 15.0,
-            "no_bins": 50,
-            "eps": eps, # 1e-10,
-            "weight": 0.01,
+                "weight": 0.5
+            }
         },
         "masked_msa": {
-            "num_classes": 23,
-            "eps": eps, # 1e-8,
-            "weight": 2.0,
-        },
-        "supervised_chi": {
-            "chi_weight": 0.5,
-            "angle_norm_weight": 0.01,
-            "eps": eps, # 1e-6,
-            "weight": 1.0,
+            "num_classes": 22
         },
         "violation": {
-            "violation_tolerance_factor": 12.0,
-            "clash_overlap_tolerance": 1.5,
             "average_clashes": True,
-            "eps": eps, # 1e-6,
-            "weight": 0.03, # Not finetuning
+            "weight": 0.03 # Not finetuning
         },
         "tm": {
-            "max_bin": 31,
-            "no_bins": 64,
-            "min_resolution": 0.1,
-            "max_resolution": 3.0,
-            "eps": eps, # 1e-8,
             "weight": 0.1,
-            "enabled": True,
+            "enabled": True
         },
         "chain_center_of_mass": {
-            "clamp_distance": -4.0,
             "weight": 0.05,
-            "eps": eps,
-            "enabled": True,
-        },
-        "eps": eps,
+            "enabled": True
+        }
     }
-}
+})
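Channel sizes such as `c_m`, `c_z`, and `c_t` in this config are `mlc.FieldReference` objects shared across the ConfigDict, which is why the merge in `model_config` goes through `copy_and_resolve_references()`. A small sketch of what that step does (the value 256 is illustrative):

```python
import ml_collections as mlc

c_m = mlc.FieldReference(256, field_type=int)
update = mlc.ConfigDict({
    "model": {"template": {"template_single_embedder": {"c_out": c_m}}}
})

# copy_and_resolve_references() returns a copy in which every
# FieldReference has been replaced by its current concrete value.
resolved = update.copy_and_resolve_references()
print(resolved.model.template.template_single_embedder.c_out)  # 256
```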
@@ -412,7 +412,7 @@ class RecyclingEmbedder(nn.Module):
         return m_update, z_update
-class TemplateAngleEmbedder(nn.Module):
+class TemplateSingleEmbedder(nn.Module):
     """
     Embeds the "template_angle_feat" feature.
@@ -432,7 +432,7 @@ class TemplateAngleEmbedder(nn.Module):
         c_out:
             Output channel dimension
         """
-        super(TemplateAngleEmbedder, self).__init__()
+        super(TemplateSingleEmbedder, self).__init__()
         self.c_out = c_out
         self.c_in = c_in
@@ -543,8 +543,8 @@ class TemplateEmbedder(nn.Module):
         super(TemplateEmbedder, self).__init__()
         self.config = config
-        self.template_angle_embedder = TemplateAngleEmbedder(
-            **config["template_angle_embedder"],
+        self.template_single_embedder = TemplateSingleEmbedder(
+            **config["template_single_embedder"],
         )
         self.template_pair_embedder = TemplatePairEmbedder(
             **config["template_pair_embedder"],
@@ -651,7 +651,7 @@ class TemplateEmbedder(nn.Module):
         )
         # [*, S_t, N, C_m]
-        a = self.template_angle_embedder(template_angle_feat)
+        a = self.template_single_embedder(template_angle_feat)
         ret["template_single_embedding"] = a
@@ -660,7 +660,7 @@ class TemplateEmbedder(nn.Module):
 class TemplatePairEmbedderMultimer(nn.Module):
     def __init__(self,
-        c_z: int,
+        c_in: int,
         c_out: int,
         c_dgram: int,
         c_aatype: int,
@@ -670,8 +670,8 @@ class TemplatePairEmbedderMultimer(nn.Module):
         self.dgram_linear = Linear(c_dgram, c_out, init='relu')
         self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu')
         self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu')
-        self.query_embedding_layer_norm = LayerNorm(c_z)
-        self.query_embedding_linear = Linear(c_z, c_out, init='relu')
+        self.query_embedding_layer_norm = LayerNorm(c_in)
+        self.query_embedding_linear = Linear(c_in, c_out, init='relu')
         self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu')
         self.x_linear = Linear(1, c_out, init='relu')
@@ -722,11 +722,11 @@ class TemplatePairEmbedderMultimer(nn.Module):
 class TemplateSingleEmbedderMultimer(nn.Module):
     def __init__(self,
         c_in: int,
-        c_m: int,
+        c_out: int,
     ):
         super(TemplateSingleEmbedderMultimer, self).__init__()
-        self.template_single_embedder = Linear(c_in, c_m)
-        self.template_projector = Linear(c_m, c_m)
+        self.template_single_embedder = Linear(c_in, c_out)
+        self.template_projector = Linear(c_out, c_out)
     def forward(self,
         batch,
@@ -797,6 +797,7 @@ class TemplateEmbedderMultimer(nn.Module):
         templ_dim,
         chunk_size,
         multichain_mask_2d,
+        _mask_trans=True,
         use_lma=False,
         inplace_safe=False
     ):
@@ -869,7 +870,9 @@ class TemplateEmbedderMultimer(nn.Module):
             template_embeds["template_pair_embedding"],
             padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
             chunk_size=chunk_size,
-            _mask_trans=False,
+            use_lma=use_lma,
+            inplace_safe=inplace_safe,
+            _mask_trans=_mask_trans,
         )
         # [*, N, N, C_z]
         t = torch.sum(t, dim=-4) / n_templ
......
@@ -139,7 +139,8 @@ class AlphaFold(nn.Module):
                 chunk_size=self.globals.chunk_size,
                 multichain_mask_2d=multichain_mask_2d,
                 use_lma=self.globals.use_lma,
-                inplace_safe=inplace_safe
+                inplace_safe=inplace_safe,
+                _mask_trans=self.config._mask_trans
             )
             feats["template_torsion_angles_mask"] = (
                 template_embeds["template_mask"]
@@ -161,7 +162,8 @@ class AlphaFold(nn.Module):
                 templ_dim,
                 chunk_size=self.globals.chunk_size,
                 use_lma=self.globals.use_lma,
-                inplace_safe=inplace_safe
+                inplace_safe=inplace_safe,
+                _mask_trans=self.config._mask_trans
             )
         return template_embeds
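Previously `_mask_trans` was hardcoded to `False` inside `TemplateEmbedderMultimer`; the hunks above thread the model-level `config._mask_trans` flag down through the template embedder into the pair stack. A hypothetical reduction of the plumbing (class and attribute names simplified, not the real modules):

```python
# Hypothetical sketch: the flag is read from the config once at the top
# and forwarded explicitly instead of being hardcoded mid-stack.
class PairStack:
    def __call__(self, t, _mask_trans=True):
        # _mask_trans gates whether the transition layers see the pair mask
        return t

class TemplateEmbedder:
    def forward(self, t, _mask_trans=True):
        return PairStack()(t, _mask_trans=_mask_trans)

config = {"_mask_trans": False}
TemplateEmbedder().forward("t", _mask_trans=config["_mask_trans"])
```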
......
@@ -552,7 +552,7 @@ def embed_templates_offload(
     )
     # [*, N, C_m]
-    a = model.template_angle_embedder(template_angle_feat)
+    a = model.template_single_embedder(template_angle_feat)
     ret["template_single_embedding"] = a
@@ -663,7 +663,7 @@ def embed_templates_average(
     )
     # [*, N, C_m]
-    a = model.template_angle_embedder(template_angle_feat)
+    a = model.template_single_embedder(template_angle_feat)
     ret["template_single_embedding"] = a
......
@@ -577,10 +577,10 @@ def generate_translation_dict(model, version, is_multimer=False):
                 "attention": AttentionParams(model.template_embedder.template_pointwise_att.mha),
             },
             "template_single_embedding": LinearParams(
-                model.template_embedder.template_angle_embedder.linear_1
+                model.template_embedder.template_single_embedder.linear_1
             ),
             "template_projection": LinearParams(
-                model.template_embedder.template_angle_embedder.linear_2
+                model.template_embedder.template_single_embedder.linear_2
             ),
         }
     else:
......
@@ -1668,11 +1668,8 @@ def chain_center_of_mass_loss(
     all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
     all_atom_positions = all_atom_positions[..., ca_pos, :]
     all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
-    chains = asym_id.unique()
-    # Reduce asym_id by one because class values must be smaller than num_classes and asym_ids start at 1
-    one_hot = torch.nn.functional.one_hot(asym_id.long() - 1,
-                                          num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
+    one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
     one_hot = one_hot * all_atom_mask
     chain_pos_mask = one_hot.transpose(-2, -1)
     chain_exists = torch.any(chain_pos_mask, dim=-1).float()
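The rewritten loss relies on `torch.nn.functional.one_hot` inferring `num_classes` from the largest id, so `asym_id` values are expected to start at 1: slot 0 becomes an always-empty padding class instead of something that has to be shifted away. A toy sketch of the masking steps, with made-up values:

```python
import torch

# 5 residues over 2 chains; asym_id starts at 1 (0 is reserved for padding)
asym_id = torch.tensor([1.0, 1.0, 2.0, 2.0, 2.0])
all_atom_mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 1.0]).unsqueeze(-1)

# [N, max(asym_id) + 1] -- num_classes is inferred as 3
one_hot = torch.nn.functional.one_hot(asym_id.long()).to(all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask           # zero out unresolved residues
chain_pos_mask = one_hot.transpose(-2, -1)  # [num_chains + 1, N]
chain_exists = torch.any(chain_pos_mask.bool(), dim=-1).float()
print(chain_exists)  # tensor([0., 1., 1.]) -- slot 0 stays empty
```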
......
@@ -19,7 +19,8 @@ consts = mlc.ConfigDict(
         "c_s": 384,
         "c_t": 64,
         "c_e": 64,
-        "msa_logits": 22 # monomer: 23, multimer: 22
+        "msa_logits": 22, # monomer: 23, multimer: 22
+        "template_mmcif_dir": None # Set for test_multimer_datamodule
     }
 )
......
@@ -40,7 +40,7 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
         asym_ids.extend(piece * [idx])
     asym_ids.extend((n_res - sum(pieces)) * [final_idx])
-    return np.array(asym_ids).astype(np.int64)
+    return np.array(asym_ids).astype(np.float32) + 1
 def random_template_feats(n_templ, n, batch_size=None):
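This matches the convention the loss change above now assumes: an asym_id of 0 would land in the one-hot's padding slot, so the test helper emits ids starting at 1, as float32 to match the multimer feature pipeline. A quick check of the invariant, with made-up chain pieces:

```python
import numpy as np

asym_ids = np.array([0, 0, 1, 1, 2]).astype(np.float32) + 1  # mirrors the new return
assert asym_ids.dtype == np.float32
assert asym_ids.min() >= 1  # id 0 stays reserved for padding
```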
......
@@ -21,7 +21,7 @@ from openfold.model.embedders import (
     InputEmbedder,
     InputEmbedderMultimer,
     RecyclingEmbedder,
-    TemplateAngleEmbedder,
+    TemplateSingleEmbedder,
     TemplatePairEmbedder
 )
@@ -96,7 +96,7 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
         n_templ = 4
         n_res = 256
-        tae = TemplateAngleEmbedder(
+        tae = TemplateSingleEmbedder(
             template_angle_dim,
             c_m,
         )
......
@@ -12,24 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pathlib import Path
+import os
 import shutil
-import pickle
 import torch
-import torch.nn as nn
-import numpy as np
-from functools import partial
 import unittest
 from openfold.utils.tensor_utils import tensor_tree_map
 from openfold.config import model_config
-from openfold.data.data_modules import OpenFoldMultimerDataModule,OpenFoldDataModule
+from openfold.data.data_modules import OpenFoldMultimerDataModule
 from openfold.model.model import AlphaFold
 from openfold.utils.loss import AlphaFoldMultimerLoss
 from tests.config import consts
 import logging
 logger = logging.getLogger(__name__)
-import os
+@unittest.skipIf(not consts.is_multimer or consts.template_mmcif_dir is None, "Template mmcif dir required.")
 class TestMultimerDataModule(unittest.TestCase):
     def setUp(self):
         """
@@ -38,14 +35,14 @@ class TestMultimerDataModule(unittest.TestCase):
         use model_1_multimer_v3 for now
         """
         self.config = model_config(
-            "model_1_multimer_v3",
+            consts.model,
             train=True,
             low_prec=True)
         self.data_module = OpenFoldMultimerDataModule(
             config=self.config.data,
             batch_seed=42,
             train_epoch_len=100,
-            template_mmcif_dir="/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/",
+            template_mmcif_dir=consts.template_mmcif_dir,
             template_release_dates_cache_path=os.path.join(os.getcwd(),"tests/test_data/mmcif_cache.json"),
             max_template_date="2500-01-01",
             train_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"),
......
@@ -263,6 +263,7 @@ class Template(unittest.TestCase):
             templ_dim=0,
             chunk_size=consts.chunk_size,
             multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
+            _mask_trans=False,
             use_lma=False,
             inplace_safe=False
         )
@@ -273,6 +274,7 @@ class Template(unittest.TestCase):
             torch.as_tensor(pair_mask).cuda(),
             templ_dim=0,
             chunk_size=consts.chunk_size,
+            mask_trans=False,
             use_lma=False,
             inplace_safe=False
         )
......