Commit 82bda2d6 authored by Christina Floristean's avatar Christina Floristean
Browse files

Refactored multimer config update

parent 30764cf9
......@@ -154,50 +154,37 @@ def model_config(
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif "multimer" in name:
c.globals.is_multimer = True
c.globals.bfloat16 = False
c.globals.bfloat16_output = False
c.loss.masked_msa.num_classes = 22
c.data.common.max_recycling_iters = 20
for k, v in multimer_model_config_update['model'].items():
c.model[k] = v
for k, v in multimer_model_config_update['loss'].items():
c.loss[k] = v
c.update(multimer_config_update.copy_and_resolve_references())
del c.model.template.template_pointwise_attention
del c.loss.fape.backbone
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
#c.model.input_embedder.num_msa = 252
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.crop_size = 384
c.data.train.max_msa_clusters = 252
c.data.eval.max_msa_clusters = 252
c.data.predict.max_msa_clusters = 252
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
c.model.evoformer_stack.fuse_projection_weights = False
c.model.extra_msa.extra_msa_stack.fuse_projection_weights = False
c.model.template.template_pair_stack.fuse_projection_weights = False
elif name == 'model_4_multimer_v3':
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
elif name == 'model_5_multimer_v3':
#c.model.extra_msa.extra_msa_embedder.num_extra_msa = 1152
c.data.train.max_extra_msa = 1152
c.data.eval.max_extra_msa = 1152
c.data.predict.max_extra_msa = 1152
else:
c.data.train.max_msa_clusters = 508
c.data.predict.max_msa_clusters = 508
c.data.train.max_extra_msa = 2048
c.data.predict.max_extra_msa = 2048
c.data.common.unsupervised_features.extend([
"msa_mask",
"seq_mask",
"asym_id",
"entity_id",
"sym_id",
])
else:
raise ValueError("Invalid model name")
......@@ -451,7 +438,7 @@ config = mlc.ConfigDict(
"max_bin": 50.75,
"no_bins": 39,
},
"template_angle_embedder": {
"template_single_embedder": {
# DISCREPANCY: c_in is supposed to be 51.
"c_in": 57,
"c_out": c_m,
......@@ -682,226 +669,131 @@ config = mlc.ConfigDict(
}
)
multimer_model_config_update = {
'model': {
multimer_config_update = mlc.ConfigDict({
"globals": {
"is_multimer": True,
"bfloat16": False, # TODO: Change to True when implemented
"bfloat16_output": False
},
"data": {
"common": {
"max_recycling_iters": 20,
"unsupervised_features": [
"aatype",
"residue_index",
"msa",
"num_alignments",
"seq_length",
"between_segment_residues",
"deletion_matrix",
"no_recycling_iters",
# Additional multimer features
"msa_mask",
"seq_mask",
"asym_id",
"entity_id",
"sym_id",
]
},
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model:
# c.model.input_embedder.num_msa = 508
# c.model.extra_msa.extra_msa_embedder.num_extra_msa = 2048
"predict": {
"max_msa_clusters": 508,
"max_extra_msa": 2048
},
"eval": {
"max_msa_clusters": 508,
"max_extra_msa": 2048
},
"train": {
"max_msa_clusters": 508,
"max_extra_msa": 2048,
"crop_size": 640
},
},
"model": {
"input_embedder": {
"tf_dim": 21,
"msa_dim": 49,
#"num_msa": 508,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
"max_relative_chain": 2,
"max_relative_idx": 32,
"use_chain_relative": True,
"use_chain_relative": True
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
"template_single_embedder": {
"c_in": 34,
"c_out": c_m
},
"template_pair_embedder": {
"c_z": c_z,
"c_out": 64,
"c_in": c_z,
"c_out": c_t,
"c_dgram": 39,
"c_aatype": 22,
},
"template_single_embedder": {
"c_in": 34,
"c_m": c_m,
"c_aatype": 22
},
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"tri_mul_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"inf": 1e9,
"fuse_projection_weights": True
},
"c_t": c_t,
"c_z": c_z,
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
"use_unit_vector": True
},
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
#"num_extra_msa": 2048
},
# "extra_msa_embedder": {
# "num_extra_msa": 2048
# },
"extra_msa_stack": {
"c_m": c_e,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"clear_cache_between_blocks": True,
"inf": 1e9,
"eps": eps, # 1e-10,
"ckpt": blocks_per_ckpt is not None,
},
"enabled": True,
"fuse_projection_weights": True
}
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"opm_first": True,
"fuse_projection_weights": True,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"inf": 1e9,
"eps": eps, # 1e-10,
"fuse_projection_weights": True
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 20,
"epsilon": eps, # 1e-12,
"inf": 1e5,
"trans_scale_factor": 20
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"ptm_weight": 0.2,
"iptm_weight": 0.8,
"enabled": True,
"enabled": True
},
"masked_msa": {
"c_m": c_m,
"c_out": 22,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
"c_out": 22
},
},
"recycle_early_stop_tolerance": 0.5
},
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": eps, # 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": eps, # 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.0,
},
"fape": {
"intra_chain_backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5,
"weight": 0.5
},
"interface_backbone": {
"clamp_distance": 30.0,
"loss_unit_distance": 20.0,
"weight": 0.5,
},
"sidechain": {
"clamp_distance": 10.0,
"length_scale": 10.0,
"weight": 0.5,
},
"eps": 1e-4,
"weight": 1.0,
},
"plddt_loss": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.0,
"no_bins": 50,
"eps": eps, # 1e-10,
"weight": 0.01,
"weight": 0.5
}
},
"masked_msa": {
"num_classes": 23,
"eps": eps, # 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": eps, # 1e-6,
"weight": 1.0,
"num_classes": 22
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"average_clashes": True,
"eps": eps, # 1e-6,
"weight": 0.03, # Not finetuning
"weight": 0.03 # Not finetuning
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.1,
"enabled": True,
"enabled": True
},
"chain_center_of_mass": {
"clamp_distance": -4.0,
"weight": 0.05,
"eps": eps,
"enabled": True,
},
"eps": eps,
"enabled": True
}
}
}
})
......@@ -412,7 +412,7 @@ class RecyclingEmbedder(nn.Module):
return m_update, z_update
class TemplateAngleEmbedder(nn.Module):
class TemplateSingleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
......@@ -432,7 +432,7 @@ class TemplateAngleEmbedder(nn.Module):
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
super(TemplateSingleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
......@@ -543,8 +543,8 @@ class TemplateEmbedder(nn.Module):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
self.template_single_embedder = TemplateSingleEmbedder(
**config["template_single_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
......@@ -651,7 +651,7 @@ class TemplateEmbedder(nn.Module):
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
a = self.template_single_embedder(template_angle_feat)
ret["template_single_embedding"] = a
......@@ -660,7 +660,7 @@ class TemplateEmbedder(nn.Module):
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_in: int,
c_out: int,
c_dgram: int,
c_aatype: int,
......@@ -670,8 +670,8 @@ class TemplatePairEmbedderMultimer(nn.Module):
self.dgram_linear = Linear(c_dgram, c_out, init='relu')
self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu')
self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu')
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out, init='relu')
self.query_embedding_layer_norm = LayerNorm(c_in)
self.query_embedding_linear = Linear(c_in, c_out, init='relu')
self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu')
self.x_linear = Linear(1, c_out, init='relu')
......@@ -722,11 +722,11 @@ class TemplatePairEmbedderMultimer(nn.Module):
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
c_out: int,
):
super(TemplateSingleEmbedderMultimer, self).__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
self.template_single_embedder = Linear(c_in, c_out)
self.template_projector = Linear(c_out, c_out)
def forward(self,
batch,
......@@ -797,6 +797,7 @@ class TemplateEmbedderMultimer(nn.Module):
templ_dim,
chunk_size,
multichain_mask_2d,
_mask_trans=True,
use_lma=False,
inplace_safe=False
):
......@@ -869,7 +870,9 @@ class TemplateEmbedderMultimer(nn.Module):
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
......
......@@ -139,7 +139,8 @@ class AlphaFold(nn.Module):
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
inplace_safe=inplace_safe,
_mask_trans = self.config._mask_trans
)
feats["template_torsion_angles_mask"] = (
template_embeds["template_mask"]
......@@ -161,7 +162,8 @@ class AlphaFold(nn.Module):
templ_dim,
chunk_size=self.globals.chunk_size,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans
)
return template_embeds
......
......@@ -552,7 +552,7 @@ def embed_templates_offload(
)
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
a = model.template_single_embedder(template_angle_feat)
ret["template_single_embedding"] = a
......@@ -663,7 +663,7 @@ def embed_templates_average(
)
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
a = model.template_single_embedder(template_angle_feat)
ret["template_single_embedding"] = a
......
......@@ -577,10 +577,10 @@ def generate_translation_dict(model, version, is_multimer=False):
"attention": AttentionParams(model.template_embedder.template_pointwise_att.mha),
},
"template_single_embedding": LinearParams(
model.template_embedder.template_angle_embedder.linear_1
model.template_embedder.template_single_embedder.linear_1
),
"template_projection": LinearParams(
model.template_embedder.template_angle_embedder.linear_2
model.template_embedder.template_single_embedder.linear_2
),
}
else:
......
......@@ -1668,11 +1668,8 @@ def chain_center_of_mass_loss(
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
chains = asym_id.unique()
# Reduce asym_id by one because class values must be smaller than num_classes and asym_ids start at 1
one_hot = torch.nn.functional.one_hot(asym_id.long() - 1,
num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
one_hot = torch.nn.functional.one_hot(asym_id.long()).to(dtype=all_atom_mask.dtype)
one_hot = one_hot * all_atom_mask
chain_pos_mask = one_hot.transpose(-2, -1)
chain_exists = torch.any(chain_pos_mask, dim=-1).float()
......
......@@ -19,7 +19,8 @@ consts = mlc.ConfigDict(
"c_s": 384,
"c_t": 64,
"c_e": 64,
"msa_logits": 22 # monomer: 23, multimer: 22
"msa_logits": 22, # monomer: 23, multimer: 22
"template_mmcif_dir": None # Set for test_multimer_datamodule
}
)
......
......@@ -40,7 +40,7 @@ def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
asym_ids.extend(piece * [idx])
asym_ids.extend((n_res - sum(pieces)) * [final_idx])
return np.array(asym_ids).astype(np.int64)
return np.array(asym_ids).astype(np.float32) + 1
def random_template_feats(n_templ, n, batch_size=None):
......
......@@ -21,7 +21,7 @@ from openfold.model.embedders import (
InputEmbedder,
InputEmbedderMultimer,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplateSingleEmbedder,
TemplatePairEmbedder
)
......@@ -96,7 +96,7 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
n_templ = 4
n_res = 256
tae = TemplateAngleEmbedder(
tae = TemplateSingleEmbedder(
template_angle_dim,
c_m,
)
......
......@@ -12,24 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import os
import shutil
import pickle
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import unittest
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldMultimerDataModule,OpenFoldDataModule
from openfold.data.data_modules import OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.utils.loss import AlphaFoldMultimerLoss
from tests.config import consts
import logging
logger = logging.getLogger(__name__)
import os
@unittest.skipIf(not consts.is_multimer or consts.template_mmcif_dir is None, "Template mmcif dir required.")
class TestMultimerDataModule(unittest.TestCase):
def setUp(self):
"""
......@@ -38,14 +35,14 @@ class TestMultimerDataModule(unittest.TestCase):
use model_1_multimer_v3 for now
"""
self.config = model_config(
"model_1_multimer_v3",
consts.model,
train=True,
low_prec=True)
self.data_module = OpenFoldMultimerDataModule(
config=self.config.data,
batch_seed=42,
train_epoch_len=100,
template_mmcif_dir = "/g/alphafold/AlphaFold_DBs/2.3.0/pdb_mmcif/mmcif_files/",
template_mmcif_dir= consts.template_mmcif_dir,
template_release_dates_cache_path=os.path.join(os.getcwd(),"tests/test_data/mmcif_cache.json"),
max_template_date="2500-01-01",
train_data_dir=os.path.join(os.getcwd(),"tests/test_data/mmcifs"),
......
......@@ -263,6 +263,7 @@ class Template(unittest.TestCase):
templ_dim=0,
chunk_size=consts.chunk_size,
multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
_mask_trans=False,
use_lma=False,
inplace_safe=False
)
......@@ -273,6 +274,7 @@ class Template(unittest.TestCase):
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size,
mask_trans=False,
use_lma=False,
inplace_safe=False
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment