Commit da5d0e7d authored by Christina Floristean
Browse files

Fixes for multimer config features and cropping

parent 3de188e9
...@@ -156,6 +156,10 @@ def model_config( ...@@ -156,6 +156,10 @@ def model_config(
elif "multimer" in name: elif "multimer" in name:
c.update(multimer_config_update.copy_and_resolve_references()) c.update(multimer_config_update.copy_and_resolve_references())
# Not used in multimer
del c.model.template.template_pointwise_attention
del c.loss.fape.backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name): if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
#c.model.input_embedder.num_msa = 252 #c.model.input_embedder.num_msa = 252
...@@ -681,6 +685,52 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -681,6 +685,52 @@ multimer_config_update = mlc.ConfigDict({
}, },
"data": { "data": {
"common": { "common": {
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
# "all_chains_entity_ids": [], # TODO: Resolve missing features, remove processed msa feats
# "all_crops_all_chains_mask": [],
# "all_crops_all_chains_positions": [],
# "all_crops_all_chains_residue_ids": [],
"assembly_num_chains": [],
"asym_id": [NUM_RES],
"atom14_atom_exists": [NUM_RES, None],
"atom37_atom_exists": [NUM_RES, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"cluster_bias_mask": [NUM_MSA_SEQ],
"cluster_profile": [NUM_MSA_SEQ, NUM_RES, None],
"cluster_deletion_mean": [NUM_MSA_SEQ, NUM_RES],
"deletion_matrix": [NUM_MSA_SEQ, NUM_RES],
"deletion_mean": [NUM_RES],
"entity_id": [NUM_RES],
"entity_mask": [NUM_RES],
"extra_deletion_matrix": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
# "mem_peak": [],
"msa": [NUM_MSA_SEQ, NUM_RES],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_profile": [NUM_RES, None],
"num_alignments": [],
"num_templates": [],
# "queue_size": [],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"seq_length": [],
"seq_mask": [NUM_RES],
"sym_id": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"true_msa": [NUM_MSA_SEQ, NUM_RES]
},
"max_recycling_iters": 20, "max_recycling_iters": 20,
"unsupervised_features": [ "unsupervised_features": [
"aatype", "aatype",
...@@ -741,7 +791,6 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -741,7 +791,6 @@ multimer_config_update = mlc.ConfigDict({
"tri_mul_first": True, "tri_mul_first": True,
"fuse_projection_weights": True "fuse_projection_weights": True
}, },
"template_pointwise_attention": None, # Not used in Multimer
"c_t": c_t, "c_t": c_t,
"c_z": c_z, "c_z": c_z,
"use_unit_vector": True "use_unit_vector": True
...@@ -785,8 +834,7 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -785,8 +834,7 @@ multimer_config_update = mlc.ConfigDict({
"clamp_distance": 30.0, "clamp_distance": 30.0,
"loss_unit_distance": 20.0, "loss_unit_distance": 20.0,
"weight": 0.5 "weight": 0.5
}, }
"backbone": None # Not used in Multimer
}, },
"masked_msa": { "masked_msa": {
"num_classes": 22 "num_classes": 22
......
...@@ -78,7 +78,8 @@ def np_example_to_features( ...@@ -78,7 +78,8 @@ def np_example_to_features(
): ):
np_example = dict(np_example) np_example = dict(np_example)
num_res = int(np_example["seq_length"][0]) seq_length = np_example["seq_length"]
num_res = int(seq_length[0]) if seq_length.ndim != 0 else int(seq_length)
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if "deletion_matrix_int" in np_example: if "deletion_matrix_int" in np_example:
......
...@@ -31,11 +31,6 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -31,11 +31,6 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_atom14_masks, data_transforms.make_atom14_masks,
] ]
if(common_cfg.use_templates):
transforms.extend([
data_transforms.make_pseudo_beta("template_"),
])
return transforms return transforms
......
...@@ -274,9 +274,8 @@ def _correct_post_merged_feats( ...@@ -274,9 +274,8 @@ def _correct_post_merged_feats(
) -> Mapping[str, np.ndarray]: ) -> Mapping[str, np.ndarray]:
"""Adds features that need to be computed/recomputed post merging.""" """Adds features that need to be computed/recomputed post merging."""
num_res = np_example['aatype'].shape[0]
np_example['seq_length'] = np.asarray( np_example['seq_length'] = np.asarray(
[num_res] * num_res, np_example['aatype'].shape[0],
dtype=np.int32 dtype=np.int32
) )
np_example['num_alignments'] = np.asarray( np_example['num_alignments'] = np.asarray(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment