import ml_collections as mlc import copy from typing import Any N_RES = "number of residues" N_MSA = "number of MSA sequences" N_EXTRA_MSA = "number of extra MSA sequences" N_TPL = "number of templates" d_pair = mlc.FieldReference(128, field_type=int) d_msa = mlc.FieldReference(256, field_type=int) d_template = mlc.FieldReference(64, field_type=int) d_extra_msa = mlc.FieldReference(64, field_type=int) d_single = mlc.FieldReference(384, field_type=int) max_recycling_iters = mlc.FieldReference(3, field_type=int) chunk_size = mlc.FieldReference(4, field_type=int) aux_distogram_bins = mlc.FieldReference(64, field_type=int) eps = mlc.FieldReference(1e-8, field_type=float) inf = mlc.FieldReference(3e4, field_type=float) use_templates = mlc.FieldReference(True, field_type=bool) is_multimer = mlc.FieldReference(False, field_type=bool) def base_config(): return mlc.ConfigDict( { "data": { "common": { "features": { "aatype": [N_RES], "all_atom_mask": [N_RES, None], "all_atom_positions": [N_RES, None, None], "alt_chi_angles": [N_RES, None], "atom14_alt_gt_exists": [N_RES, None], "atom14_alt_gt_positions": [N_RES, None, None], "atom14_atom_exists": [N_RES, None], "atom14_atom_is_ambiguous": [N_RES, None], "atom14_gt_exists": [N_RES, None], "atom14_gt_positions": [N_RES, None, None], "atom37_atom_exists": [N_RES, None], "frame_mask": [N_RES], "true_frame_tensor": [N_RES, None, None], "bert_mask": [N_MSA, N_RES], "chi_angles_sin_cos": [N_RES, None, None], "chi_mask": [N_RES, None], "extra_msa_deletion_value": [N_EXTRA_MSA, N_RES], "extra_msa_has_deletion": [N_EXTRA_MSA, N_RES], "extra_msa": [N_EXTRA_MSA, N_RES], "extra_msa_mask": [N_EXTRA_MSA, N_RES], "extra_msa_row_mask": [N_EXTRA_MSA], "is_distillation": [], "msa_feat": [N_MSA, N_RES, None], "msa_mask": [N_MSA, N_RES], "msa_chains": [N_MSA, None], "msa_row_mask": [N_MSA], "num_recycling_iters": [], "pseudo_beta": [N_RES, None], "pseudo_beta_mask": [N_RES], "residue_index": [N_RES], "residx_atom14_to_atom37": [N_RES, None], "residx_atom37_to_atom14": [N_RES, None], "resolution": [], "rigidgroups_alt_gt_frames": [N_RES, None, None, None], "rigidgroups_group_exists": [N_RES, None], "rigidgroups_group_is_ambiguous": [N_RES, None], "rigidgroups_gt_exists": [N_RES, None], "rigidgroups_gt_frames": [N_RES, None, None, None], "seq_length": [], "seq_mask": [N_RES], "target_feat": [N_RES, None], "template_aatype": [N_TPL, N_RES], "template_all_atom_mask": [N_TPL, N_RES, None], "template_all_atom_positions": [N_TPL, N_RES, None, None], "template_alt_torsion_angles_sin_cos": [ N_TPL, N_RES, None, None, ], "template_frame_mask": [N_TPL, N_RES], "template_frame_tensor": [N_TPL, N_RES, None, None], "template_mask": [N_TPL], "template_pseudo_beta": [N_TPL, N_RES, None], "template_pseudo_beta_mask": [N_TPL, N_RES], "template_sum_probs": [N_TPL, None], "template_torsion_angles_mask": [N_TPL, N_RES, None], "template_torsion_angles_sin_cos": [N_TPL, N_RES, None, None], "true_msa": [N_MSA, N_RES], "use_clamped_fape": [], "assembly_num_chains": [1], "asym_id": [N_RES], "sym_id": [N_RES], "entity_id": [N_RES], "num_sym": [N_RES], "asym_len": [None], "cluster_bias_mask": [N_MSA], }, "masked_msa": { "profile_prob": 0.1, "same_prob": 0.1, "uniform_prob": 0.1, }, "block_delete_msa": { "msa_fraction_per_block": 0.3, "randomize_num_blocks": False, "num_blocks": 5, "min_num_msa": 16, }, "random_delete_msa": { "max_msa_entry": 1 << 25, # := 33554432 }, "v2_feature": False, "gumbel_sample": False, "max_extra_msa": 1024, "msa_cluster_features": True, "reduce_msa_clusters_by_max_templates": True, "resample_msa_in_recycling": True, "template_features": [ "template_all_atom_positions", "template_sum_probs", "template_aatype", "template_all_atom_mask", ], "unsupervised_features": [ "aatype", "residue_index", "msa", "msa_chains", "num_alignments", "seq_length", "between_segment_residues", "deletion_matrix", "num_recycling_iters", "crop_and_fix_size_seed", ], "recycling_features": [ "msa_chains", "msa_mask", "msa_row_mask", "bert_mask", "true_msa", "msa_feat", "extra_msa_deletion_value", "extra_msa_has_deletion", "extra_msa", "extra_msa_mask", "extra_msa_row_mask", "is_distillation", ], "multimer_features": [ "assembly_num_chains", "asym_id", "sym_id", "num_sym", "entity_id", "asym_len", "cluster_bias_mask", ], "use_templates": use_templates, "is_multimer": is_multimer, "use_template_torsion_angles": use_templates, "max_recycling_iters": max_recycling_iters, }, "supervised": { "use_clamped_fape_prob": 1.0, "supervised_features": [ "all_atom_mask", "all_atom_positions", "resolution", "use_clamped_fape", "is_distillation", ], }, "predict": { "fixed_size": True, "subsample_templates": False, "block_delete_msa": False, "random_delete_msa": True, "masked_msa_replace_fraction": 0.15, "max_msa_clusters": 128, "max_templates": 4, "num_ensembles": 2, "crop": False, "crop_size": None, "supervised": False, "biased_msa_by_chain": False, "share_mask": False, }, "eval": { "fixed_size": True, "subsample_templates": False, "block_delete_msa": False, "random_delete_msa": True, "masked_msa_replace_fraction": 0.15, "max_msa_clusters": 128, "max_templates": 4, "num_ensembles": 1, "crop": False, "crop_size": None, "spatial_crop_prob": 0.5, "ca_ca_threshold": 10.0, "supervised": True, "biased_msa_by_chain": False, "share_mask": False, }, "train": { "fixed_size": True, "subsample_templates": True, "block_delete_msa": True, "random_delete_msa": True, "masked_msa_replace_fraction": 0.15, "max_msa_clusters": 128, "max_templates": 4, "num_ensembles": 1, "crop": True, "crop_size": 256, "spatial_crop_prob": 0.5, "ca_ca_threshold": 10.0, "supervised": True, "use_clamped_fape_prob": 1.0, "max_distillation_msa_clusters": 1000, "biased_msa_by_chain": True, "share_mask": True, }, }, "globals": { "chunk_size": chunk_size, "block_size": None, "d_pair": d_pair, "d_msa": d_msa, "d_template": d_template, "d_extra_msa": d_extra_msa, "d_single": d_single, "eps": eps, "inf": inf, "max_recycling_iters": max_recycling_iters, "alphafold_original_mode": False, }, "model": { "is_multimer": is_multimer, "input_embedder": { "tf_dim": 22, "msa_dim": 49, "d_pair": d_pair, "d_msa": d_msa, "relpos_k": 32, "max_relative_chain": 2, }, "recycling_embedder": { "d_pair": d_pair, "d_msa": d_msa, "min_bin": 3.25, "max_bin": 20.75, "num_bins": 15, "inf": 1e8, }, "template": { "distogram": { "min_bin": 3.25, "max_bin": 50.75, "num_bins": 39, }, "template_angle_embedder": { "d_in": 57, "d_out": d_msa, }, "template_pair_embedder": { "d_in": 88, "v2_d_in": [39, 1, 22, 22, 1, 1, 1, 1], "d_pair": d_pair, "d_out": d_template, "v2_feature": False, }, "template_pair_stack": { "d_template": d_template, "d_hid_tri_att": 16, "d_hid_tri_mul": 64, "num_blocks": 2, "num_heads": 4, "pair_transition_n": 2, "dropout_rate": 0.25, "inf": 1e9, "tri_attn_first": True, }, "template_pointwise_attention": { "enabled": True, "d_template": d_template, "d_pair": d_pair, "d_hid": 16, "num_heads": 4, "inf": 1e5, }, "inf": 1e5, "eps": 1e-6, "enabled": use_templates, "embed_angles": use_templates, }, "extra_msa": { "extra_msa_embedder": { "d_in": 25, "d_out": d_extra_msa, }, "extra_msa_stack": { "d_msa": d_extra_msa, "d_pair": d_pair, "d_hid_msa_att": 8, "d_hid_opm": 32, "d_hid_mul": 128, "d_hid_pair_att": 32, "num_heads_msa": 8, "num_heads_pair": 4, "num_blocks": 4, "transition_n": 4, "msa_dropout": 0.15, "pair_dropout": 0.25, "inf": 1e9, "eps": 1e-10, "outer_product_mean_first": False, }, "enabled": True, }, "evoformer_stack": { "d_msa": d_msa, "d_pair": d_pair, "d_hid_msa_att": 32, "d_hid_opm": 32, "d_hid_mul": 128, "d_hid_pair_att": 32, "d_single": d_single, "num_heads_msa": 8, "num_heads_pair": 4, "num_blocks": 48, "transition_n": 4, "msa_dropout": 0.15, "pair_dropout": 0.25, "inf": 1e9, "eps": 1e-10, "outer_product_mean_first": False, }, "structure_module": { "d_single": d_single, "d_pair": d_pair, "d_ipa": 16, "d_angle": 128, "num_heads_ipa": 12, "num_qk_points": 4, "num_v_points": 8, "dropout_rate": 0.1, "num_blocks": 8, "no_transition_layers": 1, "num_resnet_blocks": 2, "num_angles": 7, "trans_scale_factor": 10, "epsilon": 1e-12, "inf": 1e5, "separate_kv": False, "ipa_bias": True, }, "heads": { "plddt": { "num_bins": 50, "d_in": d_single, "d_hid": 128, }, "distogram": { "d_pair": d_pair, "num_bins": aux_distogram_bins, "disable_enhance_head": False, }, "pae": { "d_pair": d_pair, "num_bins": aux_distogram_bins, "enabled": False, "iptm_weight": 0.8, "disable_enhance_head": False, }, "masked_msa": { "d_msa": d_msa, "d_out": 23, "disable_enhance_head": False, }, "experimentally_resolved": { "d_single": d_single, "d_out": 37, "enabled": False, "disable_enhance_head": False, }, }, }, "loss": { "distogram": { "min_bin": 2.3125, "max_bin": 21.6875, "num_bins": 64, "eps": 1e-6, "weight": 0.3, }, "experimentally_resolved": { "eps": 1e-8, "min_resolution": 0.1, "max_resolution": 3.0, "weight": 0.0, }, "fape": { "backbone": { "clamp_distance": 10.0, "clamp_distance_between_chains": 30.0, "loss_unit_distance": 10.0, "loss_unit_distance_between_chains": 20.0, "weight": 0.5, "eps": 1e-4, }, "sidechain": { "clamp_distance": 10.0, "length_scale": 10.0, "weight": 0.5, "eps": 1e-4, }, "weight": 1.0, }, "plddt": { "min_resolution": 0.1, "max_resolution": 3.0, "cutoff": 15.0, "num_bins": 50, "eps": 1e-10, "weight": 0.01, }, "masked_msa": { "eps": 1e-8, "weight": 2.0, }, "supervised_chi": { "chi_weight": 0.5, "angle_norm_weight": 0.01, "eps": 1e-6, "weight": 1.0, }, "violation": { "violation_tolerance_factor": 12.0, "clash_overlap_tolerance": 1.5, "bond_angle_loss_weight": 0.3, "eps": 1e-6, "weight": 0.0, }, "pae": { "max_bin": 31, "num_bins": 64, "min_resolution": 0.1, "max_resolution": 3.0, "eps": 1e-8, "weight": 0.0, }, "repr_norm": { "weight": 0.01, "tolerance": 1.0, }, "chain_centre_mass": { "weight": 0.0, "eps": 1e-8, }, }, } ) def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None): with c.unlocked(): for k, v in c.items(): if ignore is not None and k == ignore: continue if isinstance(v, mlc.ConfigDict): recursive_set(v, key, value) elif k == key: c[k] = value def model_config(name, train=False): c = copy.deepcopy(base_config()) def model_2_v2(c): recursive_set(c, "v2_feature", True) recursive_set(c, "gumbel_sample", True) c.model.heads.masked_msa.d_out = 22 c.model.structure_module.separate_kv = True c.model.structure_module.ipa_bias = False c.model.template.template_angle_embedder.d_in = 34 return c def multimer(c): recursive_set(c, "is_multimer", True) recursive_set(c, "max_extra_msa", 1152) recursive_set(c, "max_msa_clusters", 128) recursive_set(c, "v2_feature", True) recursive_set(c, "gumbel_sample", True) c.model.template.template_angle_embedder.d_in = 34 c.model.template.template_pair_stack.tri_attn_first = False c.model.template.template_pointwise_attention.enabled = False c.model.heads.pae.enabled = True # we forget to enable it in our training, so disable it here c.model.heads.pae.disable_enhance_head = True c.model.heads.masked_msa.d_out = 22 c.model.structure_module.separate_kv = True c.model.structure_module.ipa_bias = False c.model.structure_module.trans_scale_factor = 20 c.loss.pae.weight = 0.1 c.model.input_embedder.tf_dim = 21 c.data.train.crop_size = 384 c.loss.violation.weight = 0.02 c.loss.chain_centre_mass.weight = 1.0 return c if name == "model_1": pass elif name == "model_1_ft": recursive_set(c, "max_extra_msa", 5120) recursive_set(c, "max_msa_clusters", 512) c.data.train.crop_size = 384 c.loss.violation.weight = 0.02 elif name == "model_1_af2": recursive_set(c, "max_extra_msa", 5120) recursive_set(c, "max_msa_clusters", 512) c.data.train.crop_size = 384 c.loss.violation.weight = 0.02 c.loss.repr_norm.weight = 0 c.model.heads.experimentally_resolved.enabled = True c.loss.experimentally_resolved.weight = 0.01 c.globals.alphafold_original_mode = True elif name == "model_2": pass elif name == "model_init": pass elif name == "model_init_af2": c.globals.alphafold_original_mode = True pass elif name == "model_2_ft": recursive_set(c, "max_extra_msa", 1024) recursive_set(c, "max_msa_clusters", 512) c.data.train.crop_size = 384 c.loss.violation.weight = 0.02 elif name == "model_2_af2": recursive_set(c, "max_extra_msa", 1024) recursive_set(c, "max_msa_clusters", 512) c.data.train.crop_size = 384 c.loss.violation.weight = 0.02 c.loss.repr_norm.weight = 0 c.model.heads.experimentally_resolved.enabled = True c.loss.experimentally_resolved.weight = 0.01 c.globals.alphafold_original_mode = True elif name == "model_2_v2": c = model_2_v2(c) elif name == "model_2_v2_ft": c = model_2_v2(c) recursive_set(c, "max_extra_msa", 1024) recursive_set(c, "max_msa_clusters", 512) c.data.train.crop_size = 384 c.loss.violation.weight = 0.02 elif name == "model_3_af2" or name == "model_4_af2": recursive_set(c, "max_extra_msa", 5120) recursive_set(c, "max_msa_clusters", 512) c.data.train.crop_size = 384 c.loss.violation.weight = 0.02 c.loss.repr_norm.weight = 0 c.model.heads.experimentally_resolved.enabled = True c.loss.experimentally_resolved.weight = 0.01 c.globals.alphafold_original_mode = True c.model.template.enabled = False c.model.template.embed_angles = False recursive_set(c, "use_templates", False) recursive_set(c, "use_template_torsion_angles", False) elif name == "model_5_af2": recursive_set(c, "max_extra_msa", 1024) recursive_set(c, "max_msa_clusters", 512) c.data.train.crop_size = 384 c.loss.violation.weight = 0.02 c.loss.repr_norm.weight = 0 c.model.heads.experimentally_resolved.enabled = True c.loss.experimentally_resolved.weight = 0.01 c.globals.alphafold_original_mode = True c.model.template.enabled = False c.model.template.embed_angles = False recursive_set(c, "use_templates", False) recursive_set(c, "use_template_torsion_angles", False) elif name == "multimer": c = multimer(c) elif name == "multimer_ft": c = multimer(c) recursive_set(c, "max_extra_msa", 1152) recursive_set(c, "max_msa_clusters", 256) c.data.train.crop_size = 384 c.loss.violation.weight = 0.5 elif name == "multimer_af2": recursive_set(c, "max_extra_msa", 1152) recursive_set(c, "max_msa_clusters", 256) recursive_set(c, "is_multimer", True) recursive_set(c, "v2_feature", True) recursive_set(c, "gumbel_sample", True) c.model.template.template_angle_embedder.d_in = 34 c.model.template.template_pair_stack.tri_attn_first = False c.model.template.template_pointwise_attention.enabled = False c.model.heads.pae.enabled = True c.model.heads.experimentally_resolved.enabled = True c.model.heads.masked_msa.d_out = 22 c.model.structure_module.separate_kv = True c.model.structure_module.ipa_bias = False c.model.structure_module.trans_scale_factor = 20 c.loss.pae.weight = 0.1 c.loss.violation.weight = 0.5 c.loss.experimentally_resolved.weight = 0.01 c.model.input_embedder.tf_dim = 21 c.globals.alphafold_original_mode = True c.data.train.crop_size = 384 c.loss.repr_norm.weight = 0 c.loss.chain_centre_mass.weight = 1.0 recursive_set(c, "outer_product_mean_first", True) else: raise ValueError(f"invalid --model-name: {name}.") if train: c.globals.chunk_size = None recursive_set(c, "inf", 3e4) recursive_set(c, "eps", 1e-5, "loss") return c