Commit 07e64267 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Standardize code style

parent de07730f
...@@ -4,55 +4,55 @@ import ml_collections as mlc ...@@ -4,55 +4,55 @@ import ml_collections as mlc
def set_inf(c, inf):
    """Recursively overwrite every key named "inf" in a ConfigDict tree.

    Args:
        c: An ``mlc.ConfigDict`` (possibly containing nested ConfigDicts).
        inf: The value to write into every ``"inf"`` entry, in place.
    """
    for k, v in c.items():
        if isinstance(v, mlc.ConfigDict):
            # Descend into nested sub-configs.
            set_inf(v, inf)
        elif k == "inf":
            c[k] = inf
def model_config(name, train=False, low_prec=False):
    """Return a deep copy of the base ``config`` specialized for a model preset.

    Args:
        name: Preset name, one of ``model_1`` .. ``model_5`` or their
            ``*_ptm`` variants.
        train: If True, enable per-block activation checkpointing and disable
            chunking (``blocks_per_ckpt = 1``, ``chunk_size = None``).
        low_prec: If True, adjust ``eps`` and every ``inf`` constant to values
            safe for low-precision arithmetic.

    Returns:
        A customized deep copy of the module-level ``config``.

    Raises:
        ValueError: If ``name`` is not a recognized preset.
    """
    c = copy.deepcopy(config)

    valid_names = {
        "model_1", "model_2", "model_3", "model_4", "model_5",
        "model_1_ptm", "model_2_ptm", "model_3_ptm", "model_4_ptm",
        "model_5_ptm",
    }
    if name not in valid_names:
        raise ValueError("Invalid model name")

    # Presets 3-5 (and their pTM variants) run without template features.
    if name in {
        "model_3", "model_4", "model_5",
        "model_3_ptm", "model_4_ptm", "model_5_ptm",
    }:
        c.model.template.enabled = False

    # The *_ptm presets add the TM-score prediction head and its loss term.
    if name.endswith("_ptm"):
        c.model.heads.tm.enabled = True
        c.loss.tm.weight = 0.1

    if train:
        c.globals.blocks_per_ckpt = 1
        c.globals.chunk_size = None
    if low_prec:
        c.globals.eps = 1e-4
        # If we want exact numerical parity with the original, inf can't be
        # a global constant
        set_inf(c, 1e4)

    return c
...@@ -69,370 +69,384 @@ num_recycle = mlc.FieldReference(3, field_type=int) ...@@ -69,370 +69,384 @@ num_recycle = mlc.FieldReference(3, field_type=int)
# Shared FieldReferences: flipping one of these flips every config entry
# that references it (e.g. model.template.enabled / data.common.use_templates).
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
# Sentinel strings used in data.common.feat to mark which axis of each
# feature corresponds to which (run-time-determined) dimension.
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
# Base configuration. model_config() deep-copies and specializes this.
config = mlc.ConfigDict(
    {
        "data": {
            "common": {
                "batch_modes": [("clamped", 0.9), ("unclamped", 0.1)],
                # Per-feature shape specs; placeholder strings mark dynamic
                # axes, None marks fixed axes whose size comes from the data.
                "feat": {
                    "aatype": [NUM_RES],
                    "all_atom_mask": [NUM_RES, None],
                    "all_atom_positions": [NUM_RES, None, None],
                    "alt_chi_angles": [NUM_RES, None],
                    "atom14_alt_gt_exists": [NUM_RES, None],
                    "atom14_alt_gt_positions": [NUM_RES, None, None],
                    "atom14_atom_exists": [NUM_RES, None],
                    "atom14_atom_is_ambiguous": [NUM_RES, None],
                    "atom14_gt_exists": [NUM_RES, None],
                    "atom14_gt_positions": [NUM_RES, None, None],
                    "atom37_atom_exists": [NUM_RES, None],
                    "backbone_affine_mask": [NUM_RES],
                    "backbone_affine_tensor": [NUM_RES, None, None],
                    "bert_mask": [NUM_MSA_SEQ, NUM_RES],
                    "chi_angles": [NUM_RES, None],
                    "chi_mask": [NUM_RES, None],
                    "extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
                    "extra_msa_row_mask": [NUM_EXTRA_SEQ],
                    "is_distillation": [],
                    "msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
                    "msa_mask": [NUM_MSA_SEQ, NUM_RES],
                    "msa_row_mask": [NUM_MSA_SEQ],
                    "pseudo_beta": [NUM_RES, None],
                    "pseudo_beta_mask": [NUM_RES],
                    "residue_index": [NUM_RES],
                    "residx_atom14_to_atom37": [NUM_RES, None],
                    "residx_atom37_to_atom14": [NUM_RES, None],
                    "resolution": [],
                    "rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
                    "rigidgroups_group_exists": [NUM_RES, None],
                    "rigidgroups_group_is_ambiguous": [NUM_RES, None],
                    "rigidgroups_gt_exists": [NUM_RES, None],
                    "rigidgroups_gt_frames": [NUM_RES, None, None, None],
                    "seq_length": [],
                    "seq_mask": [NUM_RES],
                    "target_feat": [NUM_RES, None],
                    "template_aatype": [NUM_TEMPLATES, NUM_RES],
                    "template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
                    "template_all_atom_positions": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_alt_torsion_angles_sin_cos": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_backbone_affine_mask": [NUM_TEMPLATES, NUM_RES],
                    "template_backbone_affine_tensor": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "template_mask": [NUM_TEMPLATES],
                    "template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
                    "template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
                    "template_sum_probs": [NUM_TEMPLATES, None],
                    "template_torsion_angles_mask": [
                        NUM_TEMPLATES, NUM_RES, None,
                    ],
                    "template_torsion_angles_sin_cos": [
                        NUM_TEMPLATES, NUM_RES, None, None,
                    ],
                    "true_msa": [NUM_MSA_SEQ, NUM_RES],
                    "use_clamped_fape": [],
                },
                "masked_msa": {
                    "profile_prob": 0.1,
                    "same_prob": 0.1,
                    "uniform_prob": 0.1,
                },
                "max_extra_msa": 1024,
                "msa_cluster_features": True,
                "num_recycle": num_recycle,
                "reduce_msa_clusters_by_max_templates": False,
                "resample_msa_in_recycling": True,
                "template_features": [
                    "template_all_atom_positions",
                    "template_sum_probs",
                    "template_aatype",
                    "template_all_atom_mask",
                ],
                "unsupervised_features": [
                    "aatype",
                    "residue_index",
                    "msa",
                    "num_alignments",
                    "seq_length",
                    "between_segment_residues",
                    "deletion_matrix",
                ],
                "use_templates": templates_enabled,
                "use_template_torsion_angles": embed_template_torsion_angles,
                "supervised_features": [
                    "all_atom_mask",
                    "all_atom_positions",
                    "resolution",
                    "use_clamped_fape",
                ],
            },
            "predict": {
                "fixed_size": True,
                "subsample_templates": False,  # We want top templates.
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 512,
                "max_templates": 4,
                "num_ensemble": 1,
                "crop": False,
                "crop_size": None,
                "supervised": False,
            },
            "eval": {
                "fixed_size": True,
                "subsample_templates": False,  # We want top templates.
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 512,
                "max_templates": 4,
                "num_ensemble": 1,
                "crop": False,
                "crop_size": None,
                "supervised": True,
            },
            "train": {
                "fixed_size": True,
                "subsample_templates": True,
                "masked_msa_replace_fraction": 0.15,
                "max_msa_clusters": 512,
                "max_templates": 4,
                "num_ensemble": 1,
                "crop": True,
                "crop_size": 256,
                "supervised": True,
            },
            "data_module": {
                "use_small_bfd": False,
                "data_loaders": {
                    "batch_size": 1,
                    "num_workers": 1,
                },
            },
        },
        # Recurring FieldReferences that can be changed globally here
        "globals": {
            "blocks_per_ckpt": blocks_per_ckpt,
            "chunk_size": chunk_size,
            "c_z": c_z,
            "c_m": c_m,
            "c_t": c_t,
            "c_e": c_e,
            "c_s": c_s,
            "eps": eps,
        },
        "model": {
            "num_recycle": num_recycle,
            "_mask_trans": False,
            "input_embedder": {
                "tf_dim": 22,
                "msa_dim": 49,
                "c_z": c_z,
                "c_m": c_m,
                "relpos_k": 32,
            },
            "recycling_embedder": {
                "c_z": c_z,
                "c_m": c_m,
                "min_bin": 3.25,
                "max_bin": 20.75,
                "no_bins": 15,
                "inf": 1e8,
            },
            "template": {
                "distogram": {
                    "min_bin": 3.25,
                    "max_bin": 50.75,
                    "no_bins": 39,
                },
                "template_angle_embedder": {
                    # DISCREPANCY: c_in is supposed to be 51.
                    "c_in": 57,
                    "c_out": c_m,
                },
                "template_pair_embedder": {
                    "c_in": 88,
                    "c_out": c_t,
                },
                "template_pair_stack": {
                    "c_t": c_t,
                    # DISCREPANCY: c_hidden_tri_att here is given in the
                    # supplement as 64. In the code, it's 16.
                    "c_hidden_tri_att": 16,
                    "c_hidden_tri_mul": 64,
                    "no_blocks": 2,
                    "no_heads": 4,
                    "pair_transition_n": 2,
                    "dropout_rate": 0.25,
                    "blocks_per_ckpt": blocks_per_ckpt,
                    "chunk_size": chunk_size,
                    "inf": 1e5,  # 1e9,
                },
                "template_pointwise_attention": {
                    "c_t": c_t,
                    "c_z": c_z,
                    # DISCREPANCY: c_hidden here is given in the supplement
                    # as 64. It's actually 16.
                    "c_hidden": 16,
                    "no_heads": 4,
                    "chunk_size": chunk_size,
                    "inf": 1e5,  # 1e9,
                },
                "inf": 1e5,  # 1e9,
                "eps": eps,  # 1e-6,
                "enabled": templates_enabled,
                "embed_angles": embed_template_torsion_angles,
            },
            "extra_msa": {
                "extra_msa_embedder": {
                    "c_in": 25,
                    "c_out": c_e,
                },
                "extra_msa_stack": {
                    "c_m": c_e,
                    "c_z": c_z,
                    "c_hidden_msa_att": 8,
                    "c_hidden_opm": 32,
                    "c_hidden_mul": 128,
                    "c_hidden_pair_att": 32,
                    "no_heads_msa": 8,
                    "no_heads_pair": 4,
                    "no_blocks": 4,
                    "transition_n": 4,
                    "msa_dropout": 0.15,
                    "pair_dropout": 0.25,
                    "blocks_per_ckpt": blocks_per_ckpt,
                    "chunk_size": chunk_size,
                    "inf": 1e5,  # 1e9,
                    "eps": eps,  # 1e-10,
                },
                "enabled": True,
            },
            "evoformer_stack": {
                "c_m": c_m,
                "c_z": c_z,
                "c_hidden_msa_att": 32,
                "c_hidden_opm": 32,
                "c_hidden_mul": 128,
                "c_hidden_pair_att": 32,
                "c_s": c_s,
                "no_heads_msa": 8,
                "no_heads_pair": 4,
                "no_blocks": 48,
                "transition_n": 4,
                "msa_dropout": 0.15,
                "pair_dropout": 0.25,
                "blocks_per_ckpt": blocks_per_ckpt,
                "chunk_size": chunk_size,
                "inf": 1e5,  # 1e9,
                "eps": eps,  # 1e-10,
            },
            "structure_module": {
                "c_s": c_s,
                "c_z": c_z,
                "c_ipa": 16,
                "c_resnet": 128,
                "no_heads_ipa": 12,
                "no_qk_points": 4,
                "no_v_points": 8,
                "dropout_rate": 0.1,
                "no_blocks": 8,
                "no_transition_layers": 1,
                "no_resnet_blocks": 2,
                "no_angles": 7,
                "trans_scale_factor": 10,
                "epsilon": eps,  # 1e-12,
                "inf": 1e5,
            },
            "heads": {
                "lddt": {
                    "no_bins": 50,
                    "c_in": c_s,
                    "c_hidden": 128,
                },
                "distogram": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                },
                "tm": {
                    "c_z": c_z,
                    "no_bins": aux_distogram_bins,
                    "enabled": False,
                },
                "masked_msa": {
                    "c_m": c_m,
                    "c_out": 23,
                },
                "experimentally_resolved": {
                    "c_s": c_s,
                    "c_out": 37,
                },
            },
        },
        "relax": {
            "max_iterations": 0,  # no max
            "tolerance": 2.39,
            "stiffness": 10.0,
            "max_outer_iterations": 20,
            "exclude_residues": [],
        },
        "loss": {
            "distogram": {
                "min_bin": 2.3125,
                "max_bin": 21.6875,
                "no_bins": 64,
                "eps": eps,  # 1e-6,
                "weight": 0.3,
            },
            "experimentally_resolved": {
                "eps": eps,  # 1e-8,
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "weight": 0.0,
            },
            "fape": {
                "backbone": {
                    "clamp_distance": 10.0,
                    "loss_unit_distance": 10.0,
                    "weight": 0.5,
                },
                "sidechain": {
                    "clamp_distance": 10.0,
                    "length_scale": 10.0,
                    "weight": 0.5,
                },
                "eps": 1e-4,
                "weight": 1.0,
            },
            "lddt": {
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "cutoff": 15.0,
                "no_bins": 50,
                "eps": eps,  # 1e-10,
                "weight": 0.01,
            },
            "masked_msa": {
                "eps": eps,  # 1e-8,
                "weight": 2.0,
            },
            "supervised_chi": {
                "chi_weight": 0.5,
                "angle_norm_weight": 0.01,
                "eps": eps,  # 1e-6,
                "weight": 1.0,
            },
            "violation": {
                "violation_tolerance_factor": 12.0,
                "clash_overlap_tolerance": 1.5,
                "eps": eps,  # 1e-6,
                "weight": 0.0,
            },
            "tm": {
                "max_bin": 31,
                "no_bins": 64,
                "min_resolution": 0.1,
                "max_resolution": 3.0,
                "eps": eps,  # 1e-8,
                "weight": 0.0,
            },
            "eps": eps,
        },
        "ema": {"decay": 0.999},
    }
)
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -27,76 +27,79 @@ from openfold.np import residue_constants ...@@ -27,76 +27,79 @@ from openfold.np import residue_constants
FeatureDict = Mapping[str, np.ndarray] FeatureDict = Mapping[str, np.ndarray]
def make_sequence_features(
    sequence: str, description: str, num_res: int
) -> FeatureDict:
    """Construct a feature dict of sequence features.

    Args:
        sequence: Amino-acid sequence as a one-letter-code string.
        description: Identifier for the sequence (stored as ``domain_name``).
        num_res: Number of residues; assumed equal to ``len(sequence)`` —
            callers pass it explicitly.

    Returns:
        A dict of numpy arrays (one-hot aatype, indices, lengths, and the
        raw sequence/description as object arrays).
    """
    features = {}
    features["aatype"] = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True,
    )
    features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
    features["domain_name"] = np.array(
        [description.encode("utf-8")], dtype=np.object_
    )
    # np.arange is the direct equivalent of np.array(range(num_res)).
    features["residue_index"] = np.arange(num_res, dtype=np.int32)
    # seq_length is replicated per residue, matching its [NUM_RES] feat spec.
    features["seq_length"] = np.full((num_res,), num_res, dtype=np.int32)
    features["sequence"] = np.array(
        [sequence.encode("utf-8")], dtype=np.object_
    )
    return features
def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
    """Assemble features for one chain of a parsed mmCIF structure.

    Combines the generic sequence features for the chain with its atom
    coordinates/mask and mmCIF header metadata (resolution, release date).

    Args:
        mmcif_object: Parsed mmCIF structure.
        chain_id: Chain to featurize; must be a key of
            ``mmcif_object.chain_to_seqres``.

    Returns:
        A FeatureDict of numpy arrays.
    """
    input_sequence = mmcif_object.chain_to_seqres[chain_id]
    description = "_".join([mmcif_object.file_id, chain_id])
    num_res = len(input_sequence)

    mmcif_feats = {}
    mmcif_feats.update(
        make_sequence_features(
            sequence=input_sequence,
            description=description,
            num_res=num_res,
        )
    )

    all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    mmcif_feats["all_atom_positions"] = all_atom_positions
    mmcif_feats["all_atom_mask"] = all_atom_mask

    mmcif_feats["resolution"] = np.array(
        [mmcif_object.header["resolution"]], dtype=np.float32
    )
    mmcif_feats["release_date"] = np.array(
        [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
    )

    return mmcif_feats
def make_msa_features( def make_msa_features(
msas: Sequence[Sequence[str]], msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict: deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
"""Constructs a feature dict of MSA features.""" """Constructs a feature dict of MSA features."""
if not msas: if not msas:
raise ValueError('At least one MSA must be provided.') raise ValueError("At least one MSA must be provided.")
int_msa = [] int_msa = []
deletion_matrix = [] deletion_matrix = []
seen_sequences = set() seen_sequences = set()
for msa_index, msa in enumerate(msas): for msa_index, msa in enumerate(msas):
if not msa: if not msa:
raise ValueError(f'MSA {msa_index} must contain at least one sequence.') raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa): for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences: if sequence in seen_sequences:
continue continue
...@@ -109,30 +112,32 @@ def make_msa_features( ...@@ -109,30 +112,32 @@ def make_msa_features(
num_res = len(msas[0][0]) num_res = len(msas[0][0])
num_alignments = len(int_msa) num_alignments = len(int_msa)
features = {} features = {}
features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
features['msa'] = np.array(int_msa, dtype=np.int32) features["msa"] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array( features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32 [num_alignments] * num_res, dtype=np.int32
) )
return features return features
class AlignmentRunner: class AlignmentRunner:
""" Runs alignment tools and saves the results """ """Runs alignment tools and saves the results"""
def __init__(self,
jackhmmer_binary_path: str, def __init__(
hhblits_binary_path: str, self,
hhsearch_binary_path: str, jackhmmer_binary_path: str,
uniref90_database_path: str, hhblits_binary_path: str,
mgnify_database_path: str, hhsearch_binary_path: str,
bfd_database_path: Optional[str], uniref90_database_path: str,
uniclust30_database_path: Optional[str], mgnify_database_path: str,
small_bfd_database_path: Optional[str], bfd_database_path: Optional[str],
pdb70_database_path: str, uniclust30_database_path: Optional[str],
use_small_bfd: bool, small_bfd_database_path: Optional[str],
no_cpus: int, pdb70_database_path: str,
uniref_max_hits: int = 10000, use_small_bfd: bool,
mgnify_max_hits: int = 5000, no_cpus: int,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
): ):
self._use_small_bfd = use_small_bfd self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
...@@ -161,115 +166,120 @@ class AlignmentRunner: ...@@ -161,115 +166,120 @@ class AlignmentRunner:
) )
self.hhsearch_pdb70_runner = hhsearch.HHSearch( self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path, binary_path=hhsearch_binary_path, databases=[pdb70_database_path]
databases=[pdb70_database_path]
) )
self.uniref_max_hits = uniref_max_hits self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits self.mgnify_max_hits = mgnify_max_hits
def run(self, def run(
self,
fasta_path: str, fasta_path: str,
output_dir: str, output_dir: str,
): ):
"""Runs alignment tools on a sequence""" """Runs alignment tools on a sequence"""
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(fasta_path)[0] jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits jackhmmer_uniref90_result["sto"], max_sequences=self.uniref_max_hits
) )
uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m') uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, 'w') as f: with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m) f.write(uniref90_msa_as_a3m)
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(fasta_path)[0] jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m( mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits jackhmmer_mgnify_result["sto"], max_sequences=self.mgnify_max_hits
) )
mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m') mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, 'w') as f: with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m) f.write(mgnify_msa_as_a3m)
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m) hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr') pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, 'w') as f: with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result) f.write(hhsearch_result)
if self._use_small_bfd: if self._use_small_bfd:
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(fasta_path)[0] jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto') fasta_path
with open(bfd_out_path, 'w') as f: )[0]
f.write(jackhmmer_small_bfd_result['sto']) bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
else: else:
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path) hhblits_bfd_uniclust_result = (
if(output_dir is not None): self.hhblits_bfd_uniclust_runner.query(fasta_path)
bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m') )
with open(bfd_out_path, 'w') as f: if output_dir is not None:
f.write(hhblits_bfd_uniclust_result['a3m']) bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniclust_result["a3m"])
class DataPipeline: class DataPipeline:
"""Assembles input features.""" """Assembles input features."""
def __init__(self,
template_featurizer: templates.TemplateHitFeaturizer, def __init__(
use_small_bfd: bool, self,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
): ):
self.template_featurizer = template_featurizer self.template_featurizer = template_featurizer
self.use_small_bfd = use_small_bfd self.use_small_bfd = use_small_bfd
def _parse_alignment_output(self, def _parse_alignment_output(
self,
alignment_dir: str, alignment_dir: str,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m') uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, 'r') as f: with open(uniref90_out_path, "r") as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m( uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())
f.read()
)
mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m') mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, 'r') as f: with open(mgnify_out_path, "r") as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m( mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())
f.read()
)
pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr') pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, 'r') as f: with open(pdb70_out_path, "r") as f:
hhsearch_hits = parsers.parse_hhr( hhsearch_hits = parsers.parse_hhr(f.read())
f.read()
)
if(self.use_small_bfd): if self.use_small_bfd:
bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto') bfd_out_path = os.path.join(alignment_dir, "small_bfd_hits.sto")
with open(bfd_out_path, 'r') as f: with open(bfd_out_path, "r") as f:
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm( bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
f.read() f.read()
) )
else: else:
bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m') bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, 'r') as f: with open(bfd_out_path, "r") as f:
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m( bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())
f.read()
)
return { return {
'uniref90_msa': uniref90_msa, "uniref90_msa": uniref90_msa,
'uniref90_deletion_matrix': uniref90_deletion_matrix, "uniref90_deletion_matrix": uniref90_deletion_matrix,
'mgnify_msa': mgnify_msa, "mgnify_msa": mgnify_msa,
'mgnify_deletion_matrix': mgnify_deletion_matrix, "mgnify_deletion_matrix": mgnify_deletion_matrix,
'hhsearch_hits': hhsearch_hits, "hhsearch_hits": hhsearch_hits,
'bfd_msa': bfd_msa, "bfd_msa": bfd_msa,
'bfd_deletion_matrix': bfd_deletion_matrix, "bfd_deletion_matrix": bfd_deletion_matrix,
} }
def process_fasta(self, def process_fasta(
self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
) -> FeatureDict: ) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file""" """Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f: with open(fasta_path) as f:
fasta_str = f.read() fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str) input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1: if len(input_seqs) != 1:
raise ValueError( raise ValueError(
f'More than one input sequence found in {fasta_path}.') f"More than one input sequence found in {fasta_path}."
)
input_sequence = input_seqs[0] input_sequence = input_seqs[0]
input_description = input_descs[0] input_description = input_descs[0]
num_res = len(input_sequence) num_res = len(input_sequence)
...@@ -280,47 +290,46 @@ class DataPipeline: ...@@ -280,47 +290,46 @@ class DataPipeline:
query_sequence=input_sequence, query_sequence=input_sequence,
query_pdb_code=None, query_pdb_code=None,
query_release_date=None, query_release_date=None,
hits=alignments['hhsearch_hits'] hits=alignments["hhsearch_hits"],
) )
sequence_features = make_sequence_features( sequence_features = make_sequence_features(
sequence=input_sequence, sequence=input_sequence,
description=input_description, description=input_description,
num_res=num_res num_res=num_res,
) )
msa_features = make_msa_features( msa_features = make_msa_features(
msas=( msas=(
alignments['uniref90_msa'], alignments["uniref90_msa"],
alignments['bfd_msa'], alignments["bfd_msa"],
alignments['mgnify_msa'] alignments["mgnify_msa"],
), ),
deletion_matrices=( deletion_matrices=(
alignments['uniref90_deletion_matrix'], alignments["uniref90_deletion_matrix"],
alignments['bfd_deletion_matrix'], alignments["bfd_deletion_matrix"],
alignments['mgnify_deletion_matrix'] alignments["mgnify_deletion_matrix"],
) ),
) )
return {**sequence_features, **msa_features, **templates_result.data} return {**sequence_features, **msa_features, **templates_result.data}
def process_mmcif(self, def process_mmcif(
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a specific chain in an mmCIF object. Assembles features for a specific chain in an mmCIF object.
If chain_id is None, it is assumed that there is only one chain If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown. in the object. Otherwise, a ValueError is thrown.
""" """
if(chain_id is None): if chain_id is None:
chains = mmcif.structure.get_chains() chains = mmcif.structure.get_chains()
chain = next(chains, None) chain = next(chains, None)
if(chain is None): if chain is None:
raise ValueError( raise ValueError("No chains in mmCIF file")
'No chains in mmCIF file'
)
chain_id = chain.id chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id) mmcif_feats = make_mmcif_features(mmcif, chain_id)
...@@ -332,20 +341,20 @@ class DataPipeline: ...@@ -332,20 +341,20 @@ class DataPipeline:
query_sequence=input_sequence, query_sequence=input_sequence,
query_pdb_code=None, query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]), query_release_date=to_date(mmcif.header["release_date"]),
hits=alignments['hhsearch_hits'] hits=alignments["hhsearch_hits"],
) )
msa_features = make_msa_features( msa_features = make_msa_features(
msas=( msas=(
alignments['uniref90_msa'], alignments["uniref90_msa"],
alignments['bfd_msa'], alignments["bfd_msa"],
alignments['mgnify_msa'] alignments["mgnify_msa"],
),
deletion_matrices=(
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
), ),
deletion_matrices = (
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
) )
return {**mmcif_feats, **templates_result.data, **msa_features} return {**mmcif_feats, **templates_result.data, **msa_features}
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -23,13 +23,23 @@ import torch ...@@ -23,13 +23,23 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.tools import residue_constants as rc from openfold.tools import residue_constants as rc
from openfold.utils.affine_utils import T from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import tree_map, tensor_tree_map, batched_gather from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
batched_gather,
)
# Feature keys holding one entry per MSA row; transforms that shuffle,
# crop, or split MSA rows must apply the same row selection to all of them.
MSA_FEATURE_NAMES = [
    "msa",
    "deletion_matrix",
    "msa_mask",
    "msa_row_mask",
    "bert_mask",
    "true_msa",
]
def cast_to_64bit_ints(protein): def cast_to_64bit_ints(protein):
# We keep all ints as int64 # We keep all ints as int64
for k, v in protein.items(): for k, v in protein.items():
...@@ -37,160 +47,196 @@ def cast_to_64bit_ints(protein): ...@@ -37,160 +47,196 @@ def cast_to_64bit_ints(protein):
protein[k] = v.type(torch.int64) protein[k] = v.type(torch.int64)
return protein return protein
def make_one_hot(x, num_classes):
    """One-hot encode an integer tensor.

    Args:
        x: Tensor of integer class indices (any shape).
        num_classes: Size of the one-hot dimension appended at the end.

    Returns:
        float32 tensor of shape ``(*x.shape, num_classes)``.
    """
    # Allocate on x's device so the function also works for CUDA inputs;
    # the original always allocated on the CPU and scatter_ would fail
    # with a device mismatch for GPU-resident indices.
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
    return x_one_hot
def make_seq_mask(protein):
    """Attach an all-ones per-residue mask under ``"seq_mask"``."""
    shape = protein["aatype"].shape
    protein["seq_mask"] = torch.ones(shape, dtype=torch.float32)
    return protein
def make_template_mask(protein):
    """Mark every template as present with a ones mask of length n_templ."""
    n_templ = protein["template_aatype"].shape[0]
    protein["template_mask"] = torch.ones(n_templ, dtype=torch.float32)
    return protein
def curry1(f):
    """Supply all arguments but the first.

    ``curry1(f)(*args, **kwargs)`` yields a one-argument callable that
    forwards its argument as ``f``'s first positional parameter.
    """

    def fc(*args, **kwargs):
        def apply_first(x):
            return f(x, *args, **kwargs)

        return apply_first

    return fc
@curry1
def add_distillation_flag(protein, distillation):
    """Record whether this example comes from the distillation set."""
    flag = torch.tensor(float(distillation), dtype=torch.float32)
    protein["is_distillation"] = flag
    return protein
def make_all_atom_aatype(protein):
    """Alias the residue types under the all-atom feature name."""
    aatype = protein["aatype"]
    protein["all_atom_aatype"] = aatype
    return protein
def fix_templates_aatype(protein):
    """Convert one-hot template aatypes to indices in our residue order.

    ``template_aatype`` arrives one-hot in the HHblits alphabet; this
    collapses it to indices and remaps them through
    ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``.
    """
    one_hot = protein["template_aatype"]
    n_templ = one_hot.shape[0]
    # One-hot -> integer indices (still in HHblits ordering).
    indices = torch.argmax(one_hot, dim=-1)
    # Broadcast the remapping table across templates and look up each index.
    mapping = torch.tensor(
        rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, dtype=torch.int64
    ).expand(n_templ, -1)
    protein["template_aatype"] = torch.gather(mapping, 1, index=indices)
    return protein
def correct_msa_restypes(protein):
    """Correct MSA restype to have the same order as rc.

    Remaps ``protein["msa"]`` from the HHblits amino-acid ordering to this
    project's ordering and permutes any ``*profile*`` feature's last axis
    to match.
    """
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    # One copy of the remapping table per residue column, shape (22, num_res),
    # so gather along dim 0 can translate every MSA entry in one call.
    new_order = torch.tensor(
        [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
    ).transpose(0, 1)
    protein["msa"] = torch.gather(new_order, 0, protein["msa"])

    # Permutation matrix carrying profile columns to the new ordering.
    perm_matrix = np.zeros((22, 22), dtype=np.float32)
    perm_matrix[range(len(new_order_list)), new_order_list] = 1.0

    for k in protein:
        if "profile" in k:
            # Bug fix: the original used leftover TensorFlow idioms --
            # torch tensors have no ``shape.as_list()``, and ``torch.dot``
            # accepts only 1-D torch tensors (``perm_matrix`` is numpy).
            num_dim = protein[k].shape[-1]
            assert num_dim in [20, 21, 22], (
                "num_dim for %s out of expected range: %s" % (k, num_dim)
            )
            protein[k] = torch.matmul(
                protein[k],
                torch.from_numpy(perm_matrix[:num_dim, :num_dim]),
            )

    return protein
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features."""
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
    squeezable = (
        "domain_name",
        "msa",
        "num_alignments",
        "seq_length",
        "sequence",
        "superfamily",
        "deletion_matrix",
        "resolution",
        "between_segment_residues",
        "residue_index",
        "template_all_atom_mask",
    )
    for key in squeezable:
        if key not in protein:
            continue
        last_dim = protein[key].shape[-1]
        if isinstance(last_dim, int) and last_dim == 1:
            protein[key] = torch.squeeze(protein[key], dim=-1)

    # These were stored once per alignment row; keep a single scalar.
    for key in ("seq_length", "num_alignments"):
        if key in protein:
            protein[key] = protein[key][0]

    return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a portion of the MSA with 'X'."""
    x_idx = 20  # index of the unknown residue 'X'
    gap_idx = 21  # gap positions are never replaced
    msa = protein["msa"]
    # NOTE: the two torch.rand calls stay in this order to preserve the
    # RNG stream of the original implementation.
    replace_msa = torch.rand(msa.shape) < replace_proportion
    replace_msa = torch.logical_and(replace_msa, msa != gap_idx)
    protein["msa"] = torch.where(
        replace_msa, torch.full_like(msa, x_idx), msa
    )

    aatype = protein["aatype"]
    replace_aa = torch.rand(aatype.shape) < replace_proportion
    protein["aatype"] = torch.where(
        replace_aa, torch.full_like(aatype, x_idx), aatype
    )
    return protein
@curry1
def sample_msa(protein, max_seq, keep_extra):
    """Sample MSA randomly; remaining sequences are stored as ``extra_*``.

    Row 0 (the query sequence) is always kept; the other rows are shuffled
    and the first ``max_seq`` rows of the permutation become the main MSA.
    """
    num_seq = protein["msa"].shape[0]
    # Shuffle rows 1..N-1, keeping the query sequence in front.
    perm = torch.cat(
        (torch.tensor([0]), torch.randperm(num_seq - 1) + 1), dim=0
    )
    num_sel = min(max_seq, num_seq)
    sel_seq = perm[:num_sel]
    not_sel_seq = perm[num_sel:]

    for k in MSA_FEATURE_NAMES:
        if k not in protein:
            continue
        if keep_extra:
            protein["extra_" + k] = torch.index_select(
                protein[k], 0, not_sel_seq
            )
        protein[k] = torch.index_select(protein[k], 0, sel_seq)

    return protein
@curry1
def crop_extra_msa(protein, max_extra_msa):
    """Keep at most ``max_extra_msa`` randomly chosen extra-MSA rows."""
    num_seq = protein["extra_msa"].shape[0]
    keep = torch.randperm(num_seq)[: min(max_extra_msa, num_seq)]
    for k in MSA_FEATURE_NAMES:
        extra_key = "extra_" + k
        if extra_key in protein:
            protein[extra_key] = torch.index_select(
                protein[extra_key], 0, keep
            )
    return protein
def delete_extra_msa(protein):
    """Drop every ``extra_*`` MSA feature from the example."""
    for k in MSA_FEATURE_NAMES:
        protein.pop("extra_" + k, None)
    return protein
# Not used in inference # Not used in inference
@curry1 @curry1
def block_delete_msa(protein, config): def block_delete_msa(protein, config):
num_seq = protein['msa'].shape[0] num_seq = protein["msa"].shape[0]
block_num_seq = torch.floor( block_num_seq = torch.floor(
torch.tensor( torch.tensor(num_seq, dtype=torch.float32)
num_seq, dtype=torch.float32 * config.msa_fraction_per_block
) * config.msa_fraction_per_block
).to(torch.int32) ).to(torch.int32)
if config.randomize_num_blocks: if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(0, config.num_blocks+1).sample() nb = torch.distributions.uniform.Uniform(
0, config.num_blocks + 1
).sample()
else: else:
nb = config.num_blocks nb = config.num_blocks
del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb) del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
del_blocks = del_block_starts[:, None] + torch.range(block_num_seq) del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
del_blocks = torch.clip(del_blocks, 0, num_seq-1) del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0] del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
# Make sure we keep the original sequence # Make sure we keep the original sequence
...@@ -206,19 +252,19 @@ def block_delete_msa(protein, config): ...@@ -206,19 +252,19 @@ def block_delete_msa(protein, config):
return protein return protein
@curry1 @curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.): def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
weights = torch.cat([ weights = torch.cat(
torch.ones(21), [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
gap_agreement_weight * torch.ones(1), 0,
torch.zeros(1) )
], 0)
# Make agreement score as weighted Hamming distance # Make agreement score as weighted Hamming distance
msa_one_hot = make_one_hot(protein['msa'], 23) msa_one_hot = make_one_hot(protein["msa"], 23)
sample_one_hot = (protein['msa_mask'][:,:,None] * msa_one_hot) sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
extra_msa_one_hot = make_one_hot(protein['extra_msa'], 23) extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
extra_one_hot = (protein['extra_msa_mask'][:,:,None] * extra_msa_one_hot) extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot
num_seq, num_res, _ = sample_one_hot.shape num_seq, num_res, _ = sample_one_hot.shape
extra_num_seq, _, _ = extra_one_hot.shape extra_num_seq, _, _ = extra_one_hot.shape
...@@ -226,17 +272,20 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.): ...@@ -226,17 +272,20 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights) # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup. # in an optimized fashion to avoid possible memory or computation blowup.
agreement = torch.matmul( agreement = torch.matmul(
torch.reshape(extra_one_hot, [extra_num_seq, num_res*23]), torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
torch.reshape( torch.reshape(
sample_one_hot * weights, [num_seq, num_res * 23] sample_one_hot * weights, [num_seq, num_res * 23]
).transpose(0, 1), ).transpose(0, 1),
) )
# Assign each sequence in the extra sequences to the closest MSA sample # Assign each sequence in the extra sequences to the closest MSA sample
protein['extra_cluster_assignment'] = torch.argmax(agreement, dim=1).to(torch.int64) protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
torch.int64
)
return protein return protein
def unsorted_segment_sum(data, segment_ids, num_segments): def unsorted_segment_sum(data, segment_ids, num_segments):
""" """
Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum. Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
...@@ -264,131 +313,153 @@ def unsorted_segment_sum(data, segment_ids, num_segments): ...@@ -264,131 +313,153 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
tensor = tensor.type(data.dtype) tensor = tensor.type(data.dtype)
return tensor return tensor
@curry1
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster."""
    num_seq = protein["msa"].shape[0]
    assignment = protein["extra_cluster_assignment"]

    def cluster_sum(x):
        # Sum x over the extra sequences assigned to each cluster center.
        return unsorted_segment_sum(x, assignment, num_seq)

    extra_mask = protein["extra_msa_mask"]
    # Contributing rows per position; the center row always counts.
    counts = 1e-6 + protein["msa_mask"] + cluster_sum(extra_mask)

    profile = cluster_sum(
        extra_mask[:, :, None] * make_one_hot(protein["extra_msa"], 23)
    )
    profile += make_one_hot(protein["msa"], 23)  # center sequence
    protein["cluster_profile"] = profile / counts[:, :, None]
    del profile

    deletions = cluster_sum(extra_mask * protein["extra_deletion_matrix"])
    deletions += protein["deletion_matrix"]  # center sequence
    protein["cluster_deletion_mean"] = deletions / counts
    del deletions

    return protein
def make_msa_mask(protein):
    """Mask features are all ones, but will later be zero-padded."""
    msa_shape = protein["msa"].shape
    protein["msa_mask"] = torch.ones(msa_shape, dtype=torch.float32)
    protein["msa_row_mask"] = torch.ones(msa_shape[0], dtype=torch.float32)
    return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    Uses the CB atom position for every residue except glycine, which has
    no CB and falls back to CA. If ``all_atom_mask`` is None only the
    positions are returned, otherwise a (positions, mask) pair.
    """
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    # Broadcast the glycine flag over the trailing xyz dimension.
    gly_tiled = torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3])
    pseudo_beta = torch.where(
        gly_tiled,
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
    )
    return pseudo_beta, pseudo_beta_mask
@curry1
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask."""
    assert prefix in ["", "template_"]
    aatype_key = "template_aatype" if prefix else "aatype"
    mask_key = "template_all_atom_mask" if prefix else "all_atom_mask"
    beta, beta_mask = pseudo_beta_fn(
        protein[aatype_key],
        protein[prefix + "all_atom_positions"],
        protein[mask_key],
    )
    protein[prefix + "pseudo_beta"] = beta
    protein[prefix + "pseudo_beta_mask"] = beta_mask
    return protein
@curry1
def add_constant_field(protein, key, value):
    """Attach a constant feature: ``protein[key] = tensor(value)``."""
    constant = torch.tensor(value)
    protein[key] = constant
    return protein
def shaped_categorical(probs, epsilon=1e-10):
    """Sample a categorical distribution along the last axis of ``probs``.

    Returns integer samples of shape ``probs.shape[:-1]``; ``epsilon``
    keeps all-zero rows from producing an invalid distribution.
    """
    out_shape = probs.shape[:-1]
    num_classes = probs.shape[-1]
    flat = torch.reshape(probs + epsilon, [-1, num_classes])
    samples = torch.distributions.categorical.Categorical(flat).sample()
    return torch.reshape(samples, out_shape)
def make_hhblits_profile(protein):
    """Compute the HHblits MSA profile if not already present."""
    if "hhblits_profile" in protein:
        return protein
    # Per-residue mean of the one-hot MSA over all sequences.
    one_hot_msa = make_one_hot(protein["msa"], 22)
    protein["hhblits_profile"] = torch.mean(one_hot_msa, dim=0)
    return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA."""
    # Uniform probability over the 20 standard amino acids ('X' and the
    # gap symbol get zero weight).
    random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)

    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(protein["msa"], 22)
    )

    # Put all remaining probability on [MASK], appended as a new column.
    mask_prob = (
        1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    )
    assert mask_prob >= 0.0
    # F.pad takes two entries per dimension; entry 1 pads the last axis
    # by one on the right (equivalent to the original reduce(add, ...)).
    pad_shapes = [0, 0] * len(categorical_probs.shape)
    pad_shapes[1] = 1
    categorical_probs = torch.nn.functional.pad(
        categorical_probs, pad_shapes, value=mask_prob
    )

    mask_position = torch.rand(protein["msa"].shape) < replace_fraction
    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, bert_msa, protein["msa"])

    # Mix real and masked MSA
    protein["bert_mask"] = mask_position.to(torch.float32)
    protein["true_msa"] = protein["msa"]
    protein["msa"] = bert_msa

    return protein
@curry1 @curry1
def make_fixed_size( def make_fixed_size(
protein, protein,
shape_schema, shape_schema,
msa_cluster_size, msa_cluster_size,
extra_msa_size, extra_msa_size,
num_res=0, num_res=0,
num_templates=0 num_templates=0,
): ):
"""Guess at the MSA and sequence dimension to make fixed size.""" """Guess at the MSA and sequence dimension to make fixed size."""
...@@ -401,14 +472,12 @@ def make_fixed_size( ...@@ -401,14 +472,12 @@ def make_fixed_size(
for k, v in protein.items(): for k, v in protein.items():
# Don't transfer this to the accelerator. # Don't transfer this to the accelerator.
if k == 'extra_cluster_assignment': if k == "extra_cluster_assignment":
continue continue
shape = list(v.shape) shape = list(v.shape)
schema = shape_schema[k] schema = shape_schema[k]
msg = "Rank mismatch between shape and shape schema for" msg = "Rank mismatch between shape and shape schema for"
assert len(shape) == len(schema), ( assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
f'{msg} {k}: {shape} vs {schema}'
)
pad_size = [ pad_size = [
pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema) pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
] ]
...@@ -422,24 +491,27 @@ def make_fixed_size( ...@@ -422,24 +491,27 @@ def make_fixed_size(
return protein return protein
@curry1 @curry1
def make_msa_feat(protein): def make_msa_feat(protein):
"""Create and concatenate MSA features.""" """Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping for # Whether there is a domain break. Always zero for chains, but keeping for
# compatibility with domain datasets. # compatibility with domain datasets.
has_break = torch.clip( has_break = torch.clip(
protein['between_segment_residues'].to(torch.float32), 0, 1 protein["between_segment_residues"].to(torch.float32), 0, 1
) )
aatype_1hot = make_one_hot(protein['aatype'], 21) aatype_1hot = make_one_hot(protein["aatype"], 21)
target_feat = [ target_feat = [
torch.unsqueeze(has_break, dim=-1), torch.unsqueeze(has_break, dim=-1),
aatype_1hot, # Everyone gets the original sequence. aatype_1hot, # Everyone gets the original sequence.
] ]
msa_1hot = make_one_hot(protein['msa'], 23) msa_1hot = make_one_hot(protein["msa"], 23)
has_deletion = torch.clip(protein['deletion_matrix'], 0., 1.) has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
deletion_value = torch.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi) deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
2.0 / np.pi
)
msa_feat = [ msa_feat = [
msa_1hot, msa_1hot,
...@@ -447,24 +519,27 @@ def make_msa_feat(protein): ...@@ -447,24 +519,27 @@ def make_msa_feat(protein):
torch.unsqueeze(deletion_value, dim=-1), torch.unsqueeze(deletion_value, dim=-1),
] ]
if 'cluster_profile' in protein: if "cluster_profile" in protein:
deletion_mean_value = ( deletion_mean_value = torch.atan(
torch.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi) protein["cluster_deletion_mean"] / 3.0
) * (2.0 / np.pi)
msa_feat.extend(
[
protein["cluster_profile"],
torch.unsqueeze(deletion_mean_value, dim=-1),
]
) )
msa_feat.extend([protein['cluster_profile'],
torch.unsqueeze(deletion_mean_value, dim=-1),
])
if 'extra_deletion_matrix' in protein: if "extra_deletion_matrix" in protein:
protein['extra_has_deletion'] = torch.clip( protein["extra_has_deletion"] = torch.clip(
protein['extra_deletion_matrix'], 0., 1. protein["extra_deletion_matrix"], 0.0, 1.0
) )
protein['extra_deletion_value'] = torch.atan( protein["extra_deletion_value"] = torch.atan(
protein['extra_deletion_matrix'] / 3. protein["extra_deletion_matrix"] / 3.0
) * (2. / np.pi) ) * (2.0 / np.pi)
protein['msa_feat'] = torch.cat(msa_feat, dim=-1) protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
protein['target_feat'] = torch.cat(target_feat, dim=-1) protein["target_feat"] = torch.cat(target_feat, dim=-1)
return protein return protein
...@@ -476,7 +551,7 @@ def select_feat(protein, feature_list): ...@@ -476,7 +551,7 @@ def select_feat(protein, feature_list):
@curry1
def crop_templates(protein, max_templates):
    """Truncate every ``template_*`` feature to the first ``max_templates``."""
    for key, value in protein.items():
        if key.startswith("template_"):
            protein[key] = value[:max_templates]
    return protein
...@@ -488,57 +563,58 @@ def make_atom14_masks(protein): ...@@ -488,57 +563,58 @@ def make_atom14_masks(protein):
restype_atom14_mask = [] restype_atom14_mask = []
for rt in rc.restypes: for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[ atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
rc.restype_1to3[rt] restype_atom14_to_atom37.append(
] [(rc.atom_order[name] if name else 0) for name in atom_names]
restype_atom14_to_atom37.append([ )
(rc.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)} atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([ restype_atom37_to_atom14.append(
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0) [
for name in rc.atom_types (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
]) for name in rc.atom_types
]
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names]) )
restype_atom14_mask.append(
[(1.0 if name else 0.0) for name in atom_names]
)
# Add dummy mapping for restype 'UNK' # Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14) restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37) restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14) restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor( restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37, restype_atom14_to_atom37,
dtype=torch.int32, dtype=torch.int32,
device=protein['aatype'].device, device=protein["aatype"].device,
) )
restype_atom37_to_atom14 = torch.tensor( restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14, restype_atom37_to_atom14,
dtype=torch.int32, dtype=torch.int32,
device=protein['aatype'].device, device=protein["aatype"].device,
) )
restype_atom14_mask = torch.tensor( restype_atom14_mask = torch.tensor(
restype_atom14_mask, restype_atom14_mask,
dtype=torch.float32, dtype=torch.float32,
device=protein['aatype'].device, device=protein["aatype"].device,
) )
# create the mapping for (residx, atom14) --> atom37, i.e. an array # create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein # with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein['aatype']] residx_atom14_to_atom37 = restype_atom14_to_atom37[protein["aatype"]]
residx_atom14_mask = restype_atom14_mask[protein['aatype']] residx_atom14_mask = restype_atom14_mask[protein["aatype"]]
protein['atom14_atom_exists'] = residx_atom14_mask protein["atom14_atom_exists"] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37.long() protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back # create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein['aatype']] residx_atom37_to_atom14 = restype_atom37_to_atom14[protein["aatype"]]
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14.long() protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
# create the corresponding mask # create the corresponding mask
restype_atom37_mask = torch.zeros( restype_atom37_mask = torch.zeros(
[21, 37], dtype=torch.float32, device=protein['aatype'].device [21, 37], dtype=torch.float32, device=protein["aatype"].device
) )
for restype, restype_letter in enumerate(rc.restypes): for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter] restype_name = rc.restype_1to3[restype_letter]
...@@ -547,8 +623,8 @@ def make_atom14_masks(protein): ...@@ -547,8 +623,8 @@ def make_atom14_masks(protein):
atom_type = rc.atom_order[atom_name] atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1 restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein['aatype']] residx_atom37_mask = restype_atom37_mask[protein["aatype"]]
protein['atom37_atom_exists'] = residx_atom37_mask protein["atom37_atom_exists"] = residx_atom37_mask
return protein return protein
...@@ -564,13 +640,13 @@ def make_atom14_positions(protein): ...@@ -564,13 +640,13 @@ def make_atom14_positions(protein):
"""Constructs denser atom positions (14 dimensions instead of 37).""" """Constructs denser atom positions (14 dimensions instead of 37)."""
residx_atom14_mask = protein["atom14_atom_exists"] residx_atom14_mask = protein["atom14_atom_exists"]
residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"] residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
# Create a mask for known ground truth positions. # Create a mask for known ground truth positions.
residx_atom14_gt_mask = residx_atom14_mask * batched_gather( residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
protein["all_atom_mask"], protein["all_atom_mask"],
residx_atom14_to_atom37, residx_atom14_to_atom37,
dim=-1, dim=-1,
no_batch_dims=len(protein["all_atom_mask"].shape[:-1]) no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
) )
# Gather the ground truth positions. # Gather the ground truth positions.
...@@ -579,86 +655,86 @@ def make_atom14_positions(protein): ...@@ -579,86 +655,86 @@ def make_atom14_positions(protein):
protein["all_atom_positions"], protein["all_atom_positions"],
residx_atom14_to_atom37, residx_atom14_to_atom37,
dim=-2, dim=-2,
no_batch_dims=len(protein["all_atom_positions"].shape[:-2]) no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
) )
) )
protein["atom14_atom_exists"] = residx_atom14_mask protein["atom14_atom_exists"] = residx_atom14_mask
protein["atom14_gt_exists"] = residx_atom14_gt_mask protein["atom14_gt_exists"] = residx_atom14_gt_mask
protein["atom14_gt_positions"] = residx_atom14_gt_positions protein["atom14_gt_positions"] = residx_atom14_gt_positions
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped # alternative ground truth coordinates where the naming is swapped
restype_3 = [ restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
rc.restype_1to3[res] for res in rc.restypes
]
restype_3 += ["UNK"] restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms. # Matrices for renaming ambiguous atoms.
all_matrices = { all_matrices = {
res: torch.eye( res: torch.eye(
14, 14,
dtype=protein["all_atom_mask"].dtype, dtype=protein["all_atom_mask"].dtype,
device=protein["all_atom_mask"].device device=protein["all_atom_mask"].device,
) for res in restype_3 )
for res in restype_3
} }
for resname, swap in rc.residue_atom_renaming_swaps.items(): for resname, swap in rc.residue_atom_renaming_swaps.items():
correspondences = torch.arange(14, device=protein["all_atom_mask"].device) correspondences = torch.arange(
for source_atom_swap, target_atom_swap in swap.items(): 14, device=protein["all_atom_mask"].device
source_index = rc.restype_name_to_atom14_names[ )
resname].index(source_atom_swap) for source_atom_swap, target_atom_swap in swap.items():
target_index = rc.restype_name_to_atom14_names[ source_index = rc.restype_name_to_atom14_names[resname].index(
resname].index(target_atom_swap) source_atom_swap
correspondences[source_index] = target_index )
correspondences[target_index] = source_index target_index = rc.restype_name_to_atom14_names[resname].index(
renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14)) target_atom_swap
for index, correspondence in enumerate(correspondences): )
renaming_matrix[index, correspondence] = 1. correspondences[source_index] = target_index
all_matrices[resname] = renaming_matrix correspondences[target_index] = source_index
renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack( renaming_matrices = torch.stack(
[all_matrices[restype] for restype in restype_3] [all_matrices[restype] for restype in restype_3]
) )
# Pick the transformation matrices for the given residue sequence # Pick the transformation matrices for the given residue sequence
# shape (num_res, 14, 14). # shape (num_res, 14, 14).
renaming_transform = renaming_matrices[protein["aatype"]] renaming_transform = renaming_matrices[protein["aatype"]]
# Apply it to the ground truth positions. shape (num_res, 14, 3). # Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions = torch.einsum( alternative_gt_positions = torch.einsum(
"...rac,...rab->...rbc", "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
residx_atom14_gt_positions,
renaming_transform
) )
protein["atom14_alt_gt_positions"] = alternative_gt_positions protein["atom14_alt_gt_positions"] = alternative_gt_positions
# Create the mask for the alternative ground truth (differs from the # Create the mask for the alternative ground truth (differs from the
# ground truth mask, if only one of the atoms in an ambiguous pair has a # ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position). # ground truth position).
alternative_gt_mask = torch.einsum( alternative_gt_mask = torch.einsum(
"...ra,...rab->...rb", "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
residx_atom14_gt_mask, )
renaming_transform
)
protein["atom14_alt_gt_exists"] = alternative_gt_mask protein["atom14_alt_gt_exists"] = alternative_gt_mask
# Create an ambiguous atoms mask. shape: (21, 14). # Create an ambiguous atoms mask. shape: (21, 14).
restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14)) restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
for resname, swap in rc.residue_atom_renaming_swaps.items(): for resname, swap in rc.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items(): for atom_name1, atom_name2 in swap.items():
restype = rc.restype_order[ restype = rc.restype_order[rc.restype_3to1[resname]]
rc.restype_3to1[resname]] atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
atom_idx1 = rc.restype_name_to_atom14_names[resname].index( atom_name1
atom_name1) )
atom_idx2 = rc.restype_name_to_atom14_names[resname].index( atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
atom_name2) atom_name2
restype_atom14_is_ambiguous[restype, atom_idx1] = 1 )
restype_atom14_is_ambiguous[restype, atom_idx2] = 1 restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
# From this create an ambiguous_mask for the given sequence. # From this create an ambiguous_mask for the given sequence.
protein["atom14_atom_is_ambiguous"] = ( protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
restype_atom14_is_ambiguous[protein["aatype"]] protein["aatype"]
) ]
return protein return protein
...@@ -669,30 +745,30 @@ def atom37_to_frames(protein): ...@@ -669,30 +745,30 @@ def atom37_to_frames(protein):
batch_dims = len(aatype.shape[:-1]) batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object) restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N'] restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O'] restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
for restype, restype_letter in enumerate(rc.restypes): for restype, restype_letter in enumerate(rc.restypes):
resname = rc.restype_1to3[restype_letter] resname = rc.restype_1to3[restype_letter]
for chi_idx in range(4): for chi_idx in range(4):
if(rc.chi_angles_mask[restype][chi_idx]): if rc.chi_angles_mask[restype][chi_idx]:
names = rc.chi_angles_atoms[resname][chi_idx] names = rc.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[ restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, : restype, chi_idx + 4, :
] = names[1:] ] = names[1:]
restype_rigidgroup_mask = all_atom_mask.new_zeros( restype_rigidgroup_mask = all_atom_mask.new_zeros(
(*aatype.shape[:-1], 21, 8), (*aatype.shape[:-1], 21, 8),
) )
restype_rigidgroup_mask[..., 0] = 1 restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1 restype_rigidgroup_mask[..., 3] = 1
restype_rigidgroup_mask[..., :20, 4:] = ( restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
all_atom_mask.new_tensor(rc.chi_angles_mask) rc.chi_angles_mask
) )
lookuptable = rc.atom_order.copy() lookuptable = rc.atom_order.copy()
lookuptable[''] = 0 lookuptable[""] = 0
lookup = np.vectorize(lambda x: lookuptable[x]) lookup = np.vectorize(lambda x: lookuptable[x])
restype_rigidgroup_base_atom37_idx = lookup( restype_rigidgroup_base_atom37_idx = lookup(
restype_rigidgroup_base_atom_names, restype_rigidgroup_base_atom_names,
...@@ -702,8 +778,7 @@ def atom37_to_frames(protein): ...@@ -702,8 +778,7 @@ def atom37_to_frames(protein):
) )
restype_rigidgroup_base_atom37_idx = ( restype_rigidgroup_base_atom37_idx = (
restype_rigidgroup_base_atom37_idx.view( restype_rigidgroup_base_atom37_idx.view(
*((1,) * batch_dims), *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
*restype_rigidgroup_base_atom37_idx.shape
) )
) )
...@@ -713,7 +788,7 @@ def atom37_to_frames(protein): ...@@ -713,7 +788,7 @@ def atom37_to_frames(protein):
dim=-3, dim=-3,
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
base_atom_pos = batched_gather( base_atom_pos = batched_gather(
all_atom_positions, all_atom_positions,
residx_rigidgroup_base_atom37_idx, residx_rigidgroup_base_atom37_idx,
...@@ -729,9 +804,9 @@ def atom37_to_frames(protein): ...@@ -729,9 +804,9 @@ def atom37_to_frames(protein):
) )
group_exists = batched_gather( group_exists = batched_gather(
restype_rigidgroup_mask, restype_rigidgroup_mask,
aatype, aatype,
dim=-2, dim=-2,
no_batch_dims=batch_dims, no_batch_dims=batch_dims,
) )
...@@ -739,19 +814,17 @@ def atom37_to_frames(protein): ...@@ -739,19 +814,17 @@ def atom37_to_frames(protein):
all_atom_mask, all_atom_mask,
residx_rigidgroup_base_atom37_idx, residx_rigidgroup_base_atom37_idx,
dim=-1, dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1]) no_batch_dims=len(all_atom_mask.shape[:-1]),
) )
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye( rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
3, dtype=all_atom_mask.dtype, device=aatype.device
)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1)) rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1 rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1 rots[..., 0, 2, 2] = -1
gt_frames = gt_frames.compose(T(rots, None)) gt_frames = gt_frames.compose(T(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros( restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8 *((1,) * batch_dims), 21, 8
) )
...@@ -764,12 +837,10 @@ def atom37_to_frames(protein): ...@@ -764,12 +837,10 @@ def atom37_to_frames(protein):
) )
for resname, _ in rc.residue_atom_renaming_swaps.items(): for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[ restype = rc.restype_order[rc.restype_3to1[resname]]
rc.restype_3to1[resname]
]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1) chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1 restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1 restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1 restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
residx_rigidgroup_is_ambiguous = batched_gather( residx_rigidgroup_is_ambiguous = batched_gather(
...@@ -791,18 +862,18 @@ def atom37_to_frames(protein): ...@@ -791,18 +862,18 @@ def atom37_to_frames(protein):
gt_frames_tensor = gt_frames.to_4x4() gt_frames_tensor = gt_frames.to_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4() alt_gt_frames_tensor = alt_gt_frames.to_4x4()
protein['rigidgroups_gt_frames'] = gt_frames_tensor protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein['rigidgroups_gt_exists'] = gt_exists protein["rigidgroups_gt_exists"] = gt_exists
protein['rigidgroups_group_exists'] = group_exists protein["rigidgroups_group_exists"] = group_exists
protein['rigidgroups_group_is_ambiguous'] = residx_rigidgroup_is_ambiguous protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
protein['rigidgroups_alt_gt_frames'] = alt_gt_frames_tensor protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
return protein return protein
def get_chi_atom_indices():
    """Return the atom37 indices needed to compute chi angles per residue type.

    Returns:
        A nested list of shape [residue_types=21, chis=4, atoms=4]. Rows follow
        the order of rc.restypes, with one trailing all-zero row for the
        unknown residue type. Chi angles that are not defined for a residue
        are padded with [0, 0, 0, 0] entries.
    """
    chi_atom_indices = []
    for one_letter in rc.restypes:
        resname = rc.restype_1to3[one_letter]
        # One [4]-list of atom37 indices per defined chi angle.
        per_residue = [
            [rc.atom_order[atom] for atom in chi_atoms]
            for chi_atoms in rc.chi_angles_atoms[resname]
        ]
        # Pad with zero rows so every residue contributes exactly 4 chis.
        per_residue += [[0, 0, 0, 0] for _ in range(4 - len(per_residue))]
        chi_atom_indices.append(per_residue)

    # Dummy entry for the UNKNOWN residue type.
    chi_atom_indices.append([[0, 0, 0, 0] for _ in range(4)])

    return chi_atom_indices
@curry1 @curry1
def atom37_to_torsion_angles( def atom37_to_torsion_angles(
protein, protein,
prefix='', prefix="",
): ):
""" """
Convert coordinates to torsion angles. Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible. and should be run with double precision whenever possible.
Args: Args:
Dict containing: Dict containing:
* (prefix)aatype: * (prefix)aatype:
[*, N_res] residue indices [*, N_res] residue indices
* (prefix)all_atom_positions: * (prefix)all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37 [*, N_res, 37, 3] atom positions (in atom37
format) format)
* (prefix)all_atom_mask: * (prefix)all_atom_mask:
[*, N_res, 37] atom position mask [*, N_res, 37] atom position mask
Returns: Returns:
The same dictionary updated with the following features: The same dictionary updated with the following features:
"(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2]) "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles Torsion angles
"(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2]) "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry) Alternate torsion angles (accounting for 180-degree symmetry)
"(prefix)torsion_angles_mask" ([*, N_res, 7]) "(prefix)torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask Torsion angles mask
""" """
aatype = protein[prefix + "aatype"] aatype = protein[prefix + "aatype"]
all_atom_positions = protein[prefix + "all_atom_positions"] all_atom_positions = protein[prefix + "all_atom_positions"]
all_atom_mask = protein[prefix + "all_atom_mask"] all_atom_mask = protein[prefix + "all_atom_mask"]
aatype = torch.clamp(aatype, max=20) aatype = torch.clamp(aatype, max=20)
pad = all_atom_positions.new_zeros( pad = all_atom_positions.new_zeros(
[*all_atom_positions.shape[:-3], 1, 37, 3] [*all_atom_positions.shape[:-3], 1, 37, 3]
) )
...@@ -873,35 +945,27 @@ def atom37_to_torsion_angles( ...@@ -873,35 +945,27 @@ def atom37_to_torsion_angles(
prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2) prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
pre_omega_atom_pos = torch.cat( pre_omega_atom_pos = torch.cat(
[ [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
prev_all_atom_positions[..., 1:3, :], dim=-2,
all_atom_positions[..., :2, :]
], dim=-2
) )
phi_atom_pos = torch.cat( phi_atom_pos = torch.cat(
[ [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
prev_all_atom_positions[..., 2:3, :], dim=-2,
all_atom_positions[..., :3, :]
], dim=-2
) )
psi_atom_pos = torch.cat( psi_atom_pos = torch.cat(
[ [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
all_atom_positions[..., :3, :], dim=-2,
all_atom_positions[..., 4:5, :]
], dim=-2
) )
pre_omega_mask = ( pre_omega_mask = torch.prod(
torch.prod(prev_all_atom_mask[..., 1:3], dim=-1) * prev_all_atom_mask[..., 1:3], dim=-1
torch.prod(all_atom_mask[..., :2], dim=-1) ) * torch.prod(all_atom_mask[..., :2], dim=-1)
) phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
phi_mask = ( all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
prev_all_atom_mask[..., 2] *
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
) )
psi_mask = ( psi_mask = (
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) * torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
all_atom_mask[..., 4] * all_atom_mask[..., 4]
) )
chi_atom_indices = torch.as_tensor( chi_atom_indices = torch.as_tensor(
...@@ -914,16 +978,16 @@ def atom37_to_torsion_angles( ...@@ -914,16 +978,16 @@ def atom37_to_torsion_angles(
) )
chi_angles_mask = list(rc.chi_angles_mask) chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0., 0., 0., 0.]) chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask) chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
chis_mask = chi_angles_mask[aatype, :] chis_mask = chi_angles_mask[aatype, :]
chi_angle_atoms_mask = batched_gather( chi_angle_atoms_mask = batched_gather(
all_atom_mask, all_atom_mask,
atom_indices, atom_indices,
dim=-1, dim=-1,
no_batch_dims=len(atom_indices.shape[:-2]) no_batch_dims=len(atom_indices.shape[:-2]),
) )
chi_angle_atoms_mask = torch.prod( chi_angle_atoms_mask = torch.prod(
chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
...@@ -936,7 +1000,8 @@ def atom37_to_torsion_angles( ...@@ -936,7 +1000,8 @@ def atom37_to_torsion_angles(
phi_atom_pos[..., None, :, :], phi_atom_pos[..., None, :, :],
psi_atom_pos[..., None, :, :], psi_atom_pos[..., None, :, :],
chis_atom_pos, chis_atom_pos,
], dim=-3 ],
dim=-3,
) )
torsion_angles_mask = torch.cat( torsion_angles_mask = torch.cat(
...@@ -945,7 +1010,8 @@ def atom37_to_torsion_angles( ...@@ -945,7 +1010,8 @@ def atom37_to_torsion_angles(
phi_mask[..., None], phi_mask[..., None],
psi_mask[..., None], psi_mask[..., None],
chis_mask, chis_mask,
], dim=-1 ],
dim=-1,
) )
torsion_frames = T.from_3_points( torsion_frames = T.from_3_points(
...@@ -965,16 +1031,17 @@ def atom37_to_torsion_angles( ...@@ -965,16 +1031,17 @@ def atom37_to_torsion_angles(
denom = torch.sqrt( denom = torch.sqrt(
torch.sum( torch.sum(
torch.square(torsion_angles_sin_cos), torch.square(torsion_angles_sin_cos),
dim=-1, dim=-1,
dtype=torsion_angles_sin_cos.dtype, dtype=torsion_angles_sin_cos.dtype,
keepdims=True keepdims=True,
) + 1e-8 )
+ 1e-8
) )
torsion_angles_sin_cos = torsion_angles_sin_cos / denom torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor( torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
[1., 1., -1., 1., 1., 1., 1.], [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)] )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
chi_is_ambiguous = torsion_angles_sin_cos.new_tensor( chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
...@@ -984,8 +1051,9 @@ def atom37_to_torsion_angles( ...@@ -984,8 +1051,9 @@ def atom37_to_torsion_angles(
mirror_torsion_angles = torch.cat( mirror_torsion_angles = torch.cat(
[ [
all_atom_mask.new_ones(*aatype.shape, 3), all_atom_mask.new_ones(*aatype.shape, 3),
1. - 2. * chi_is_ambiguous 1.0 - 2.0 * chi_is_ambiguous,
], dim=-1 ],
dim=-1,
) )
alt_torsion_angles_sin_cos = ( alt_torsion_angles_sin_cos = (
...@@ -995,18 +1063,16 @@ def atom37_to_torsion_angles( ...@@ -995,18 +1063,16 @@ def atom37_to_torsion_angles(
protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
protein[prefix + "torsion_angles_mask"] = torsion_angles_mask protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
return protein return protein
def get_backbone_frames(protein):
    """Expose the backbone rigid frames (group 0) under their own keys.

    Selects index 0 along the rigid-group axis of the ground-truth frame
    tensors and stores the results as "backbone_affine_tensor" and
    "backbone_affine_mask". Mutates and returns `protein`.
    """
    # TODO: Verify that this is correct
    gt_frames = protein["rigidgroups_gt_frames"]
    gt_exists = protein["rigidgroups_gt_exists"]
    # Rigid group 0 holds the backbone frame.
    protein["backbone_affine_tensor"] = gt_frames[..., 0, :, :]
    protein["backbone_affine_mask"] = gt_exists[..., 0]
    return protein
...@@ -1023,38 +1089,43 @@ def get_chi_angles(protein): ...@@ -1023,38 +1089,43 @@ def get_chi_angles(protein):
@curry1 @curry1
def random_crop_to_size( def random_crop_to_size(
protein, protein,
crop_size, crop_size,
max_templates, max_templates,
shape_schema, shape_schema,
subsample_templates=False, subsample_templates=False,
seed=None, seed=None,
batch_mode='clamped' batch_mode="clamped",
): ):
"""Crop randomly to `crop_size`, or keep as is if shorter than that.""" """Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length = protein['seq_length'] seq_length = protein["seq_length"]
if 'template_mask' in protein: if "template_mask" in protein:
num_templates = protein['template_mask'].shape[-1] num_templates = protein["template_mask"].shape[-1]
else: else:
num_templates = protein['aatype'].new_zeros((1,)) num_templates = protein["aatype"].new_zeros((1,))
num_res_crop_size = min(seq_length, crop_size) num_res_crop_size = min(seq_length, crop_size)
# We want each ensemble to be cropped the same way # We want each ensemble to be cropped the same way
g = torch.Generator(device=protein['seq_length'].device) g = torch.Generator(device=protein["seq_length"].device)
if(seed is not None): if seed is not None:
g.manual_seed(seed) g.manual_seed(seed)
def _randint(lower, upper): def _randint(lower, upper):
return int(torch.randint( return int(
lower, upper, (1,), torch.randint(
device=protein['seq_length'].device, generator=g lower,
)[0]) upper,
(1,),
device=protein["seq_length"].device,
generator=g,
)[0]
)
if subsample_templates: if subsample_templates:
templates_crop_start = _randint(0, num_templates + 1) templates_crop_start = _randint(0, num_templates + 1)
templates_select_indices = torch.randperm( templates_select_indices = torch.randperm(
num_templates, device=protein['seq_length'].device, generator=g num_templates, device=protein["seq_length"].device, generator=g
) )
num_templates_crop_size = min( num_templates_crop_size = min(
num_templates - templates_crop_start, max_templates num_templates - templates_crop_start, max_templates
...@@ -1062,11 +1133,11 @@ def random_crop_to_size( ...@@ -1062,11 +1133,11 @@ def random_crop_to_size(
else: else:
templates_crop_start = 0 templates_crop_start = 0
num_templates_crop_size = num_templates num_templates_crop_size = num_templates
n = seq_length - num_res_crop_size n = seq_length - num_res_crop_size
if(batch_mode == 'clamped'): if batch_mode == "clamped":
right_anchor = n + 1 right_anchor = n + 1
elif(batch_mode == 'unclamped'): elif batch_mode == "unclamped":
x = _randint(0, n) x = _randint(0, n)
right_anchor = n - x + 1 right_anchor = n - x + 1
else: else:
...@@ -1075,29 +1146,26 @@ def random_crop_to_size( ...@@ -1075,29 +1146,26 @@ def random_crop_to_size(
num_res_crop_start = _randint(0, right_anchor) num_res_crop_start = _randint(0, right_anchor)
for k, v in protein.items(): for k, v in protein.items():
if (k not in shape_schema or if k not in shape_schema or (
('template' not in k and NUM_RES not in shape_schema[k]) "template" not in k and NUM_RES not in shape_schema[k]
): ):
continue continue
# randomly permute the templates before cropping them. # randomly permute the templates before cropping them.
if k.startswith('template') and subsample_templates: if k.startswith("template") and subsample_templates:
v = v[templates_select_indices] v = v[templates_select_indices]
slices = [] slices = []
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
v.shape)): is_num_res = dim_size == NUM_RES
is_num_res = (dim_size == NUM_RES) if i == 0 and k.startswith("template"):
if i == 0 and k.startswith('template'):
crop_size = num_templates_crop_size crop_size = num_templates_crop_size
crop_start = templates_crop_start crop_start = templates_crop_start
else: else:
crop_start = num_res_crop_start if is_num_res else 0 crop_start = num_res_crop_start if is_num_res else 0
crop_size = num_res_crop_size if is_num_res else dim crop_size = num_res_crop_size if is_num_res else dim
slices.append(slice(crop_start, crop_start + crop_size)) slices.append(slice(crop_start, crop_start + crop_size))
protein[k] = v[slices] protein[k] = v[slices]
protein['seq_length'] = ( protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
protein['seq_length'].new_tensor(num_res_crop_size)
)
return protein return protein
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -26,10 +26,11 @@ from openfold.data import input_pipeline ...@@ -26,10 +26,11 @@ from openfold.data import input_pipeline
FeatureDict = Mapping[str, np.ndarray] FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor] TensorDict = Dict[str, torch.Tensor]
def np_to_tensor_dict( def np_to_tensor_dict(
np_example: Mapping[str, np.ndarray], np_example: Mapping[str, np.ndarray],
features: Sequence[str], features: Sequence[str],
) -> TensorDict: ) -> TensorDict:
"""Creates dict of tensors from a dict of NumPy arrays. """Creates dict of tensors from a dict of NumPy arrays.
Args: Args:
...@@ -47,14 +48,14 @@ def np_to_tensor_dict( ...@@ -47,14 +48,14 @@ def np_to_tensor_dict(
def make_data_config( def make_data_config(
config: ml_collections.ConfigDict, config: ml_collections.ConfigDict,
mode: str, mode: str,
num_res: int, num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]: ) -> Tuple[ml_collections.ConfigDict, List[str]]:
cfg = copy.deepcopy(config) cfg = copy.deepcopy(config)
mode_cfg = cfg[mode] mode_cfg = cfg[mode]
with cfg.unlocked(): with cfg.unlocked():
if(mode_cfg.crop_size is None): if mode_cfg.crop_size is None:
mode_cfg.crop_size = num_res mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features feature_names = cfg.common.unsupervised_features
...@@ -62,7 +63,7 @@ def make_data_config( ...@@ -62,7 +63,7 @@ def make_data_config(
if cfg.common.use_templates: if cfg.common.use_templates:
feature_names += cfg.common.template_features feature_names += cfg.common.template_features
if(cfg[mode].supervised): if cfg[mode].supervised:
feature_names += cfg.common.supervised_features feature_names += cfg.common.supervised_features
return cfg, feature_names return cfg, feature_names
...@@ -75,47 +76,47 @@ def np_example_to_features( ...@@ -75,47 +76,47 @@ def np_example_to_features(
batch_mode: str, batch_mode: str,
): ):
np_example = dict(np_example) np_example = dict(np_example)
num_res = int(np_example['seq_length'][0]) num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config( cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
config, mode=mode, num_res=num_res
)
if 'deletion_matrix_int' in np_example: if "deletion_matrix_int" in np_example:
np_example['deletion_matrix'] = ( np_example["deletion_matrix"] = np_example.pop(
np_example.pop('deletion_matrix_int').astype(np.float32) "deletion_matrix_int"
) ).astype(np.float32)
if batch_mode == 'clamped': if batch_mode == "clamped":
np_example['use_clamped_fape'] = ( np_example["use_clamped_fape"] = np.array(1.0).astype(np.float32)
np.array(1.).astype(np.float32) elif batch_mode == "unclamped":
) np_example["use_clamped_fape"] = np.array(0.0).astype(np.float32)
elif batch_mode == 'unclamped':
np_example['use_clamped_fape'] = (
np.array(0.).astype(np.float32)
)
tensor_dict = np_to_tensor_dict( tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names np_example=np_example, features=feature_names
) )
with torch.no_grad(): with torch.no_grad():
features = input_pipeline.process_tensors_from_config( features = input_pipeline.process_tensors_from_config(
tensor_dict, cfg.common, cfg[mode], batch_mode=batch_mode, tensor_dict,
cfg.common,
cfg[mode],
batch_mode=batch_mode,
) )
return {k: v for k, v in features.items()} return {k: v for k, v in features.items()}
class FeaturePipeline: class FeaturePipeline:
def __init__(self, def __init__(
config: ml_collections.ConfigDict, self,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None): config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
):
self.config = config self.config = config
self.params = params self.params = params
def process_features(self, def process_features(
self,
raw_features: FeatureDict, raw_features: FeatureDict,
mode: str = 'train', mode: str = "train",
batch_mode: str = 'clamped', batch_mode: str = "clamped",
) -> FeatureDict: ) -> FeatureDict:
return np_example_to_features( return np_example_to_features(
np_example=raw_features, np_example=raw_features,
......
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_hhblits_profile, data_transforms.make_hhblits_profile,
] ]
if common_cfg.use_templates: if common_cfg.use_templates:
transforms.extend([ transforms.extend(
data_transforms.fix_templates_aatype, [
data_transforms.make_template_mask, data_transforms.fix_templates_aatype,
data_transforms.make_pseudo_beta('template_') data_transforms.make_template_mask,
]) data_transforms.make_pseudo_beta("template_"),
if(common_cfg.use_template_torsion_angles): ]
transforms.extend([ )
data_transforms.atom37_to_torsion_angles('template_'), if common_cfg.use_template_torsion_angles:
]) transforms.extend(
[
transforms.extend([ data_transforms.atom37_to_torsion_angles("template_"),
data_transforms.make_atom14_masks, ]
]) )
if(mode_cfg.supervised): transforms.extend(
transforms.extend([ [
data_transforms.make_atom14_positions, data_transforms.make_atom14_masks,
data_transforms.atom37_to_frames, ]
data_transforms.atom37_to_torsion_angles(''), )
data_transforms.make_pseudo_beta(''),
data_transforms.get_backbone_frames, if mode_cfg.supervised:
data_transforms.get_chi_angles, transforms.extend(
]) [
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
)
return transforms return transforms
...@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode): ...@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
data_transforms.sample_msa(max_msa_clusters, keep_extra=True) data_transforms.sample_msa(max_msa_clusters, keep_extra=True)
) )
if 'masked_msa' in common_cfg: if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that # Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about # the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations. # the masked locations and secret corrupted locations.
transforms.append( transforms.append(
data_transforms.make_masked_msa( data_transforms.make_masked_msa(
common_cfg.masked_msa, common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
mode_cfg.masked_msa_replace_fraction
) )
) )
...@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode): ...@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
if mode_cfg.fixed_size: if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats))) transforms.append(data_transforms.select_feat(list(crop_feats)))
transforms.append(data_transforms.random_crop_to_size( transforms.append(
mode_cfg.crop_size, data_transforms.random_crop_to_size(
mode_cfg.max_templates, mode_cfg.crop_size,
crop_feats, mode_cfg.max_templates,
mode_cfg.subsample_templates, crop_feats,
batch_mode=batch_mode, mode_cfg.subsample_templates,
seed=torch.Generator().seed() batch_mode=batch_mode,
)) seed=torch.Generator().seed(),
transforms.append(data_transforms.make_fixed_size( )
crop_feats, )
pad_msa_clusters, transforms.append(
common_cfg.max_extra_msa, data_transforms.make_fixed_size(
mode_cfg.crop_size, crop_feats,
mode_cfg.max_templates pad_msa_clusters,
)) common_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates,
)
)
else: else:
transforms.append( transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates) data_transforms.crop_templates(mode_cfg.max_templates)
...@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode): ...@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
def process_tensors_from_config( def process_tensors_from_config(
tensors, common_cfg, mode_cfg, batch_mode='clamped' tensors, common_cfg, mode_cfg, batch_mode="clamped"
): ):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
...@@ -136,12 +147,10 @@ def process_tensors_from_config( ...@@ -136,12 +147,10 @@ def process_tensors_from_config(
d = data.copy() d = data.copy()
fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode) fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
fn = compose(fns) fn = compose(fns)
d['ensemble_index'] = i d["ensemble_index"] = i
return fn(d) return fn(d)
tensors = compose( tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
nonensembled_transform_fns(common_cfg, mode_cfg)
)(tensors)
tensors_0 = wrap_ensemble_fn(tensors, 0) tensors_0 = wrap_ensemble_fn(tensors, 0)
num_ensemble = mode_cfg.num_ensemble num_ensemble = mode_cfg.num_ensemble
...@@ -150,8 +159,9 @@ def process_tensors_from_config( ...@@ -150,8 +159,9 @@ def process_tensors_from_config(
num_ensemble *= common_cfg.num_recycle + 1 num_ensemble *= common_cfg.num_recycle + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1: if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn(lambda x: wrap_ensemble_fn(tensors, x), tensors = map_fn(
torch.arange(num_ensemble)) lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
)
else: else:
tensors = tree.map_structure(lambda x: x[None], tensors_0) tensors = tree.map_structure(lambda x: x[None], tensors_0)
......
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -39,368 +39,404 @@ MmCIFDict = Mapping[str, Sequence[str]] ...@@ -39,368 +39,404 @@ MmCIFDict = Mapping[str, Sequence[str]]
@dataclasses.dataclass(frozen=True)
class Monomer:
    """One entry of an mmCIF ``_entity_poly_seq`` loop (a residue in SEQRES)."""

    # Three-letter chemical component id (later mapped to a one-letter code
    # via SCOPData.protein_letters_3to1 in parse()).
    id: str
    # Position of the monomer within the entity's sequence.
    num: int
# Note - mmCIF format provides no guarantees on the type of author-assigned
# sequence numbers. They need not be integers.
@dataclasses.dataclass(frozen=True)
class AtomSite:
    """One row of the mmCIF ``_atom_site`` loop (per-atom metadata)."""

    residue_name: str  # _atom_site.label_comp_id
    author_chain_id: str  # _atom_site.auth_asym_id
    mmcif_chain_id: str  # _atom_site.label_asym_id
    author_seq_num: str  # _atom_site.auth_seq_id (see note above on its type)
    # Declared int, but _get_atom_site_list populates it with the raw string
    # from the parsed loop; parse() converts it with int().
    mmcif_seq_num: int
    insertion_code: str  # _atom_site.pdbx_PDB_ins_code
    hetatm_atom: str  # _atom_site.group_PDB; parse() compares it to "HETATM"
    # Also populated as a string; parse() compares it to the string "1".
    model_num: int
# Used to map SEQRES index to a residue in the structure.
@dataclasses.dataclass(frozen=True)
class ResiduePosition:
    """Author-assigned location of a residue within the structure."""

    chain_id: str
    residue_number: int
    insertion_code: str
@dataclasses.dataclass(frozen=True)
class ResidueAtPosition:
    """A residue keyed by SEQRES index, with its location in the structure.

    ``position`` is None and ``is_missing`` is True when the residue has no
    resolved coordinates (see how parse() fills in missing residues).
    """

    position: Optional[ResiduePosition]
    name: str
    is_missing: bool
    # Biopython-style het flag: " ", "W" (water), or "H_<residue name>".
    hetflag: str
@dataclasses.dataclass(frozen=True)
class MmcifObject:
    """Representation of a parsed mmCIF file.

    Contains:
      file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all
        files being processed.
      header: Biopython header.
      structure: Biopython structure.
      chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence.
        E.g. {'A': 'ABCDEFG'}
      seqres_to_structure: Dict; for each chain_id contains a mapping between
        SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition,
                                                          1: ResidueAtPosition,
                                                          ...}}
      raw_string: The raw string used to construct the MmcifObject.
    """

    file_id: str
    header: PdbHeader
    structure: PdbStructure
    chain_to_seqres: Mapping[ChainId, SeqRes]
    seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]]
    raw_string: Any
@dataclasses.dataclass(frozen=True)
class ParsingResult:
    """Returned by the parse function.

    Contains:
      mmcif_object: A MmcifObject, may be None if no chain could be
        successfully parsed.
      errors: A dict mapping (file_id, chain_id) to any exception generated.
    """

    mmcif_object: Optional[MmcifObject]
    errors: Mapping[Tuple[str, str], Any]
class ParseError(Exception):
    """An error indicating that an mmCIF file could not be parsed."""
def mmcif_loop_to_list(
    prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a list.

    Reference for loop_ in mmCIF:
      http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html

    Args:
      prefix: Prefix shared by each of the data items in the loop.
        e.g. '_entity_poly_seq.', where the data items are
        _entity_poly_seq.num, _entity_poly_seq.mon_id. Should include the
        trailing period.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a
        Biopython parser.

    Returns:
      Returns a list of dicts; each dict represents 1 entry from an mmCIF loop.
    """
    # Collect the matching column names and their (column-major) values.
    cols = [key for key in parsed_info if key.startswith(prefix)]
    data = [parsed_info[key] for key in cols]

    # Every column of a loop must contain the same number of rows.
    assert all([len(xs) == len(data[0]) for xs in data]), (
        "mmCIF error: Not all loops are the same length: %s" % cols
    )

    # Transpose the column-major data into one dict per row.
    return [dict(zip(cols, xs)) for xs in zip(*data)]
def mmcif_loop_to_dict(
    prefix: str,
    index: str,
    parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a dictionary.

    Args:
      prefix: Prefix shared by each of the data items in the loop.
        e.g. '_entity_poly_seq.', where the data items are
        _entity_poly_seq.num, _entity_poly_seq.mon_id. Should include the
        trailing period.
      index: Which item of loop data should serve as the key.
      parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a
        Biopython parser.

    Returns:
      Returns a dict of dicts; each dict represents 1 entry from an mmCIF
      loop, indexed by the index column.
    """
    # Re-key each loop entry by the value in its `index` column.
    return {
        entry[index]: entry
        for entry in mmcif_loop_to_list(prefix, parsed_info)
    }
def parse(
    *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
    """Entry point, parses an mmcif_string.

    Args:
      file_id: A string identifier for this file. Should be unique within the
        collection of files being processed.
      mmcif_string: Contents of an mmCIF file.
      catch_all_errors: If True, all exceptions are caught and error messages
        are returned as part of the ParsingResult. If False exceptions will be
        allowed to propagate.

    Returns:
      A ParsingResult.
    """
    errors = {}
    try:
        parser = PDB.MMCIFParser(QUIET=True)
        handle = io.StringIO(mmcif_string)
        full_structure = parser.get_structure("", handle)
        first_model_structure = _get_first_model(full_structure)
        # Extract the _mmcif_dict from the parser, which contains useful
        # fields not reflected in the Biopython structure.
        parsed_info = parser._mmcif_dict  # pylint:disable=protected-access

        # Ensure all values are lists, even if singletons.
        for key, value in parsed_info.items():
            if not isinstance(value, list):
                parsed_info[key] = [value]

        header = _get_header(parsed_info)

        # Determine the protein chains, and their start numbers according to
        # the internal mmCIF numbering scheme (likely but not guaranteed to
        # be 1).
        valid_chains = _get_protein_chains(parsed_info=parsed_info)
        if not valid_chains:
            return ParsingResult(
                None, {(file_id, ""): "No protein chains found in this file."}
            )
        seq_start_num = {
            chain_id: min([monomer.num for monomer in seq])
            for chain_id, seq in valid_chains.items()
        }

        # Loop over the atoms for which we have coordinates. Populate two
        # mappings:
        # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain
        #  ids used by the authors / Biopython).
        # -seq_to_structure_mappings (maps idx into sequence to
        #  ResidueAtPosition).
        mmcif_to_author_chain_id = {}
        seq_to_structure_mappings = {}
        for atom in _get_atom_site_list(parsed_info):
            if atom.model_num != "1":
                # We only process the first model at the moment.
                continue

            mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

            if atom.mmcif_chain_id in valid_chains:
                hetflag = " "
                if atom.hetatm_atom == "HETATM":
                    # Water atoms are assigned a special hetflag of W in
                    # Biopython. We need to do the same, so that this hetflag
                    # can be used to fetch a residue from the Biopython
                    # structure by id.
                    if atom.residue_name in ("HOH", "WAT"):
                        hetflag = "W"
                    else:
                        hetflag = "H_" + atom.residue_name
                insertion_code = atom.insertion_code
                if not _is_set(atom.insertion_code):
                    insertion_code = " "
                position = ResiduePosition(
                    chain_id=atom.author_chain_id,
                    residue_number=int(atom.author_seq_num),
                    insertion_code=insertion_code,
                )
                # 0-based index into the SEQRES sequence of this chain.
                seq_idx = (
                    int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
                )
                current = seq_to_structure_mappings.get(
                    atom.author_chain_id, {}
                )
                current[seq_idx] = ResidueAtPosition(
                    position=position,
                    name=atom.residue_name,
                    is_missing=False,
                    hetflag=hetflag,
                )
                seq_to_structure_mappings[atom.author_chain_id] = current

        # Add missing residue information to seq_to_structure_mappings.
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            current_mapping = seq_to_structure_mappings[author_chain]
            for idx, monomer in enumerate(seq_info):
                if idx not in current_mapping:
                    current_mapping[idx] = ResidueAtPosition(
                        position=None,
                        name=monomer.id,
                        is_missing=True,
                        hetflag=" ",
                    )

        # Build the one-letter SEQRES string for each author chain; unknown
        # or multi-letter codes become 'X'.
        author_chain_to_sequence = {}
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            seq = []
            for monomer in seq_info:
                code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
                seq.append(code if len(code) == 1 else "X")
            seq = "".join(seq)
            author_chain_to_sequence[author_chain] = seq

        mmcif_object = MmcifObject(
            file_id=file_id,
            header=header,
            structure=first_model_structure,
            chain_to_seqres=author_chain_to_sequence,
            seqres_to_structure=seq_to_structure_mappings,
            raw_string=parsed_info,
        )

        return ParsingResult(mmcif_object=mmcif_object, errors=errors)
    except Exception as e:  # pylint:disable=broad-except
        errors[(file_id, "")] = e
        if not catch_all_errors:
            raise
        return ParsingResult(mmcif_object=None, errors=errors)
def _get_first_model(structure: PdbStructure) -> PdbStructure:
    """Returns the first model in a Biopython structure."""
    models = structure.get_models()
    return next(models)
# Presumably the residue-count threshold below which a peptide-like chain is
# not counted as a peptide; not referenced in the visible portion of this
# file — TODO confirm against its callers.
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
def get_release_date(parsed_info: MmCIFDict) -> str:
    """Returns the oldest revision date."""
    # Revision dates are ISO strings, so lexicographic min is the oldest.
    return min(parsed_info["_pdbx_audit_revision_history.revision_date"])
def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
    """Returns a basic header containing method, release date and resolution."""
    header = {}

    experiments = mmcif_loop_to_list("_exptl.", parsed_info)
    methods = [exp["_exptl.method"].lower() for exp in experiments]
    header["structure_method"] = ",".join(methods)

    # Note: The release_date here corresponds to the oldest revision. We
    # prefer to use this for dataset filtering over the deposition_date.
    if "_pdbx_audit_revision_history.revision_date" in parsed_info:
        header["release_date"] = get_release_date(parsed_info)
    else:
        logging.warning(
            "Could not determine release_date: %s", parsed_info["_entry.id"]
        )

    # Default resolution; later keys in the tuple overwrite earlier ones
    # when several are present.
    header["resolution"] = 0.00
    resolution_keys = (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    )
    for res_key in resolution_keys:
        if res_key in parsed_info:
            try:
                header["resolution"] = float(parsed_info[res_key][0])
            except ValueError:
                logging.warning(
                    "Invalid resolution format: %s", parsed_info[res_key]
                )

    return header
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
    """Returns list of atom sites; contains data not present in the structure."""
    # Column order must match the field order of AtomSite.
    columns = (
        "_atom_site.label_comp_id",
        "_atom_site.auth_asym_id",
        "_atom_site.label_asym_id",
        "_atom_site.auth_seq_id",
        "_atom_site.label_seq_id",
        "_atom_site.pdbx_PDB_ins_code",
        "_atom_site.group_PDB",
        "_atom_site.pdbx_PDB_model_num",
    )
    rows = zip(*(parsed_info[col] for col in columns))
    return [AtomSite(*row) for row in rows]
def _get_protein_chains(
    *, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
    """Extracts polymer information for protein chains only.

    Args:
      parsed_info: _mmcif_dict produced by the Biopython parser.

    Returns:
      A dict mapping mmcif chain id to a list of Monomers.
    """
    # Get polymer information for each entity in the structure.
    entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)

    polymers = collections.defaultdict(list)
    for entry in entity_poly_seqs:
        entity_id = entry["_entity_poly_seq.entity_id"]
        monomer = Monomer(
            id=entry["_entity_poly_seq.mon_id"],
            num=int(entry["_entity_poly_seq.num"]),
        )
        polymers[entity_id].append(monomer)

    # Get chemical compositions. Will allow us to identify which of these
    # polymers are proteins.
    chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)

    # Get chains information for each entity. Necessary so that we can return
    # a dict keyed on chain id rather than entity.
    entity_to_mmcif_chains = collections.defaultdict(list)
    for struct_asym in mmcif_loop_to_list("_struct_asym.", parsed_info):
        asym_entity_id = struct_asym["_struct_asym.entity_id"]
        entity_to_mmcif_chains[asym_entity_id].append(
            struct_asym["_struct_asym.id"]
        )

    # Identify and return the valid protein chains.
    valid_chains = {}
    for entity_id, seq_info in polymers.items():
        # Reject polymers without any peptide-like components, such as
        # DNA/RNA. (The list is evaluated fully, so a missing chem_comp entry
        # raises for every monomer, exactly as before.)
        is_protein = any(
            [
                "peptide" in chem_comps[monomer.id]["_chem_comp.type"]
                for monomer in seq_info
            ]
        )
        if is_protein:
            for chain_id in entity_to_mmcif_chains[entity_id]:
                valid_chains[chain_id] = seq_info
    return valid_chains
def _is_set(data: str) -> bool: def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'.""" """Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?') return data not in (".", "?")
def get_atom_coords( def get_atom_coords(
mmcif_object: MmcifObject, mmcif_object: MmcifObject, chain_id: str
chain_id: str
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain # Locate the right chain
chains = list(mmcif_object.structure.get_chains()) chains = list(mmcif_object.structure.get_chains())
relevant_chains = [c for c in chains if c.id == chain_id] relevant_chains = [c for c in chains if c.id == chain_id]
if len(relevant_chains) != 1: if len(relevant_chains) != 1:
raise MultipleChainsError( raise MultipleChainsError(
f'Expected exactly one chain in structure with id {chain_id}.' f"Expected exactly one chain in structure with id {chain_id}."
) )
chain = relevant_chains[0] chain = relevant_chains[0]
...@@ -417,19 +453,23 @@ def get_atom_coords( ...@@ -417,19 +453,23 @@ def get_atom_coords(
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32) mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index] res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
if not res_at_position.is_missing: if not res_at_position.is_missing:
res = chain[(res_at_position.hetflag, res = chain[
res_at_position.position.residue_number, (
res_at_position.position.insertion_code)] res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code,
)
]
for atom in res.get_atoms(): for atom in res.get_atoms():
atom_name = atom.get_name() atom_name = atom.get_name()
x, y, z = atom.get_coord() x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys(): if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z] pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0 mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE': elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
# Put the coords of the selenium atom in the sulphur column # Put the coords of the selenium atom in the sulphur column
pos[residue_constants.atom_order['SD']] = [x, y, z] pos[residue_constants.atom_order["SD"]] = [x, y, z]
mask[residue_constants.atom_order['SD']] = 1.0 mask[residue_constants.atom_order["SD"]] = 1.0
all_atom_positions[res_index] = pos all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask all_atom_mask[res_index] = mask
...@@ -440,22 +480,22 @@ def get_atom_coords( ...@@ -440,22 +480,22 @@ def get_atom_coords(
def generate_mmcif_cache(mmcif_dir: str, out_path: str): def generate_mmcif_cache(mmcif_dir: str, out_path: str):
data = {} data = {}
for f in os.listdir(mmcif_dir): for f in os.listdir(mmcif_dir):
if(f.endswith('.cif')): if f.endswith(".cif"):
with open(os.path.join(mmcif_dir, f), 'r') as fp: with open(os.path.join(mmcif_dir, f), "r") as fp:
mmcif_string = fp.read() mmcif_string = fp.read()
file_id = os.path.splitext(f)[0] file_id = os.path.splitext(f)[0]
mmcif = parse(file_id=file_id, mmcif_string=mmcif_string) mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
if(mmcif.mmcif_object is None): if mmcif.mmcif_object is None:
logging.warning(f'Could not parse {f}. Skipping...') logging.warning(f"Could not parse {f}. Skipping...")
continue continue
else: else:
mmcif = mmcif.mmcif_object mmcif = mmcif.mmcif_object
local_data = {} local_data = {}
local_data['release_date'] = mmcif.header["release_date"] local_data["release_date"] = mmcif.header["release_date"]
local_data['no_chains'] = len(list(mmcif.structure.get_chains())) local_data["no_chains"] = len(list(mmcif.structure.get_chains()))
data[file_id] = local_data data[file_id] = local_data
with open(out_path, 'w') as fp: with open(out_path, "w") as fp:
fp.write(json.dumps(data)) fp.write(json.dumps(data))
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple ...@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
DeletionMatrix = Sequence[Sequence[int]] DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TemplateHit: class TemplateHit:
"""Class representing a template hit.""" """Class representing a template hit."""
index: int index: int
name: str name: str
aligned_cols: int aligned_cols: int
...@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: ...@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
index = -1 index = -1
for line in fasta_string.splitlines(): for line in fasta_string.splitlines():
line = line.strip() line = line.strip()
if line.startswith('>'): if line.startswith(">"):
index += 1 index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning. descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append('') sequences.append("")
continue continue
elif not line: elif not line:
continue # Skip blank lines. continue # Skip blank lines.
...@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: ...@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return sequences, descriptions return sequences, descriptions
def parse_stockholm(stockholm_string: str def parse_stockholm(
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]: stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
"""Parses sequences and deletion matrix from stockholm format alignment. """Parses sequences and deletion matrix from stockholm format alignment.
Args: Args:
...@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str ...@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str
name_to_sequence = collections.OrderedDict() name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines(): for line in stockholm_string.splitlines():
line = line.strip() line = line.strip()
if not line or line.startswith(('#', '//')): if not line or line.startswith(("#", "//")):
continue continue
name, sequence = line.split() name, sequence = line.split()
if name not in name_to_sequence: if name not in name_to_sequence:
name_to_sequence[name] = '' name_to_sequence[name] = ""
name_to_sequence[name] += sequence name_to_sequence[name] += sequence
msa = [] msa = []
deletion_matrix = [] deletion_matrix = []
query = '' query = ""
keep_columns = [] keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()): for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0: if seq_index == 0:
# Gather the columns with gaps from the query # Gather the columns with gaps from the query
query = sequence query = sequence
keep_columns = [i for i, res in enumerate(query) if res != '-'] keep_columns = [i for i, res in enumerate(query) if res != "-"]
# Remove the columns with gaps in the query from all sequences. # Remove the columns with gaps in the query from all sequences.
aligned_sequence = ''.join([sequence[c] for c in keep_columns]) aligned_sequence = "".join([sequence[c] for c in keep_columns])
msa.append(aligned_sequence) msa.append(aligned_sequence)
...@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str ...@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str
deletion_vec = [] deletion_vec = []
deletion_count = 0 deletion_count = 0
for seq_res, query_res in zip(sequence, query): for seq_res, query_res in zip(sequence, query):
if seq_res != '-' or query_res != '-': if seq_res != "-" or query_res != "-":
if query_res == '-': if query_res == "-":
deletion_count += 1 deletion_count += 1
else: else:
deletion_vec.append(deletion_count) deletion_vec.append(deletion_count)
...@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: ...@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
deletion_matrix.append(deletion_vec) deletion_matrix.append(deletion_vec)
# Make the MSA matrix out of aligned (deletion-free) sequences. # Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase) deletion_table = str.maketrans("", "", string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences] aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix return aligned_sequences, deletion_matrix
def _convert_sto_seq_to_a3m( def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]: query_non_gaps: Sequence[bool], sto_seq: str
) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap: if is_query_res_non_gap:
yield sequence_res yield sequence_res
elif sequence_res != '-': elif sequence_res != "-":
yield sequence_res.lower() yield sequence_res.lower()
def convert_stockholm_to_a3m(stockholm_format: str, def convert_stockholm_to_a3m(
max_sequences: Optional[int] = None) -> str: stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
"""Converts MSA in Stockholm format to the A3M format.""" """Converts MSA in Stockholm format to the A3M format."""
descriptions = {} descriptions = {}
sequences = {} sequences = {}
reached_max_sequences = False reached_max_sequences = False
for line in stockholm_format.splitlines(): for line in stockholm_format.splitlines():
reached_max_sequences = max_sequences and len(sequences) >= max_sequences reached_max_sequences = (
if line.strip() and not line.startswith(('#', '//')): max_sequences and len(sequences) >= max_sequences
)
if line.strip() and not line.startswith(("#", "//")):
# Ignore blank lines, markup and end symbols - remainder are alignment # Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts. # sequence parts.
seqname, aligned_seq = line.split(maxsplit=1) seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences: if seqname not in sequences:
if reached_max_sequences: if reached_max_sequences:
continue continue
sequences[seqname] = '' sequences[seqname] = ""
sequences[seqname] += aligned_seq sequences[seqname] += aligned_seq
for line in stockholm_format.splitlines(): for line in stockholm_format.splitlines():
if line[:4] == '#=GS': if line[:4] == "#=GS":
# Description row - example format is: # Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3) columns = line.split(maxsplit=3)
seqname, feature = columns[1:3] seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else '' value = columns[3] if len(columns) == 4 else ""
if feature != 'DE': if feature != "DE":
continue continue
if reached_max_sequences and seqname not in sequences: if reached_max_sequences and seqname not in sequences:
continue continue
...@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str, ...@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str,
a3m_sequences = {} a3m_sequences = {}
# query_sequence is assumed to be the first sequence # query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values())) query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence] query_non_gaps = [res != "-" for res in query_sequence]
for seqname, sto_sequence in sequences.items(): for seqname, sto_sequence in sequences.items():
a3m_sequences[seqname] = ''.join( a3m_sequences[seqname] = "".join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)) _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
)
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" fasta_chunks = (
for k in a3m_sequences) f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. for k in a3m_sequences
)
return "\n".join(fasta_chunks) + "\n" # Include terminating newline.
def _get_hhr_line_regex_groups( def _get_hhr_line_regex_groups(
regex_pattern: str, line: str) -> Sequence[Optional[str]]: regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line) match = re.match(regex_pattern, line)
if match is None: if match is None:
raise RuntimeError(f'Could not parse query line {line}') raise RuntimeError(f"Could not parse query line {line}")
return match.groups() return match.groups()
def _update_hhr_residue_indices_list( def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]): sequence: str, start_index: int, indices_list: List[int]
):
"""Computes the relative indices for each residue with respect to the original sequence.""" """Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index counter = start_index
for symbol in sequence: for symbol in sequence:
if symbol == '-': if symbol == "-":
indices_list.append(-1) indices_list.append(-1)
else: else:
indices_list.append(counter) indices_list.append(counter)
...@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Parse the summary line. # Parse the summary line.
pattern = ( pattern = (
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
']*Template_Neff=(.*)') "]*Template_Neff=(.*)"
)
match = re.match(pattern, detailed_lines[2]) match = re.match(pattern, detailed_lines[2])
if match is None: if match is None:
raise RuntimeError( raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.' % "Could not parse section: %s. Expected this: \n%s to contain summary."
(detailed_lines, detailed_lines[2])) % (detailed_lines, detailed_lines[2])
(prob_true, e_value, _, aligned_cols, _, _, sum_probs, )
neff) = [float(x) for x in match.groups()] (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
float(x) for x in match.groups()
]
# The next section reads the detailed comparisons. These are in a 'human # The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to # readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse # assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block. # that with a regexp in order to deduce the fixed length used for that block.
query = '' query = ""
hit_sequence = '' hit_sequence = ""
indices_query = [] indices_query = []
indices_hit = [] indices_hit = []
length_block = None length_block = None
for line in detailed_lines[3:]: for line in detailed_lines[3:]:
# Parse the query sequence line # Parse the query sequence line
if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and if (
not line.startswith('Q ss_pred') and line.startswith("Q ")
not line.startswith('Q Consensus')): and not line.startswith("Q ss_dssp")
and not line.startswith("Q ss_pred")
and not line.startswith("Q Consensus")
):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse # Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that. # everything after that.
# start sequence end total_sequence_length # start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:]) groups = _get_hhr_line_regex_groups(patt, line[17:])
# Get the length of the parsed block using the start and finish indices, # Get the length of the parsed block using the start and finish indices,
...@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
start = int(groups[0]) - 1 # Make index zero based. start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1] delta_query = groups[1]
end = int(groups[2]) end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-']) num_insertions = len([x for x in delta_query if x == "-"])
length_block = end - start + num_insertions length_block = end - start + num_insertions
assert length_block == len(delta_query) assert length_block == len(delta_query)
...@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
query += delta_query query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query) _update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith('T '): elif line.startswith("T "):
# Parse the hit sequence. # Parse the hit sequence.
if (not line.startswith('T ss_dssp') and if (
not line.startswith('T ss_pred') and not line.startswith("T ss_dssp")
not line.startswith('T Consensus')): and not line.startswith("T ss_pred")
and not line.startswith("T Consensus")
):
# Thus the first 17 characters must be 'T <hit_name> ', and we can # Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that. # parse everything after that.
# start sequence end total_sequence_length # start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:]) groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based. start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1] delta_hit_sequence = groups[1]
...@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Update the hit sequence and indices list. # Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit) _update_hhr_residue_indices_list(
delta_hit_sequence, start, indices_hit
)
return TemplateHit( return TemplateHit(
index=number_of_hit, index=number_of_hit,
...@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: ...@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We # "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit. # iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')] block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]
hits = [] hits = []
if block_starts: if block_starts:
block_starts.append(len(lines)) # Add the end of the final block. block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1): for i in range(len(block_starts) - 1):
hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) hits.append(
_parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
)
return hits return hits
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string.""" """Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {'query': 0} e_values = {"query": 0}
lines = [line for line in tblout.splitlines() if line[0] != '#'] lines = [line for line in tblout.splitlines() if line[0] != "#"]
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and # space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1). # (5) E-value (full sequence) (numbering from 1).
......
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -89,47 +89,50 @@ class LengthError(PrefilterError): ...@@ -89,47 +89,50 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES = { TEMPLATE_FEATURES = {
'template_aatype': np.int64, "template_aatype": np.int64,
'template_all_atom_mask': np.float32, "template_all_atom_mask": np.float32,
'template_all_atom_positions': np.float32, "template_all_atom_positions": np.float32,
'template_domain_names': np.object, "template_domain_names": np.object,
'template_sequence': np.object, "template_sequence": np.object,
'template_sum_probs': np.float32, "template_sum_probs": np.float32,
} }
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]: def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
"""Returns PDB id and chain id for an HHSearch Hit.""" """Returns PDB id and chain id for an HHSearch Hit."""
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown. # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name) id_match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
if not id_match: if not id_match:
raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}') raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
pdb_id, chain_id = id_match.group(0).split('_') pdb_id, chain_id = id_match.group(0).split("_")
return pdb_id.lower(), chain_id return pdb_id.lower(), chain_id
def _is_after_cutoff( def _is_after_cutoff(
pdb_id: str, pdb_id: str,
release_dates: Mapping[str, datetime.datetime], release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: Optional[datetime.datetime]) -> bool: release_date_cutoff: Optional[datetime.datetime],
) -> bool:
"""Checks if the template date is after the release date cutoff. """Checks if the template date is after the release date cutoff.
Args: Args:
pdb_id: 4 letter pdb code. pdb_id: 4 letter pdb code.
release_dates: Dictionary mapping PDB ids to their structure release dates. release_dates: Dictionary mapping PDB ids to their structure release dates.
release_date_cutoff: Max release date that is valid for this query. release_date_cutoff: Max release date that is valid for this query.
Returns: Returns:
True if the template release date is after the cutoff, False otherwise. True if the template release date is after the cutoff, False otherwise.
""" """
if release_date_cutoff is None: if release_date_cutoff is None:
raise ValueError('The release_date_cutoff must not be None.') raise ValueError("The release_date_cutoff must not be None.")
if pdb_id in release_dates: if pdb_id in release_dates:
return release_dates[pdb_id] > release_date_cutoff return release_dates[pdb_id] > release_date_cutoff
else: else:
# Since this is just a quick prefilter to reduce the number of mmCIF files # Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here. # we need to parse, we don't have to worry about returning True here.
logging.warning('Template structure not in release dates dict: %s', pdb_id) logging.warning(
"Template structure not in release dates dict: %s", pdb_id
)
return False return False
...@@ -140,7 +143,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: ...@@ -140,7 +143,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
for line in f: for line in f:
line = line.strip() line = line.strip()
# We skip obsolete entries that don't contain a mapping to a new entry. # We skip obsolete entries that don't contain a mapping to a new entry.
if line.startswith('OBSLTE') and len(line) > 30: if line.startswith("OBSLTE") and len(line) > 30:
# Format: Date From To # Format: Date From To
# 'OBSLTE 31-JUL-94 116L 216L' # 'OBSLTE 31-JUL-94 116L 216L'
from_id = line[20:24].lower() from_id = line[20:24].lower()
...@@ -152,47 +155,51 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]: ...@@ -152,47 +155,51 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
def generate_release_dates_cache(mmcif_dir: str, out_path: str): def generate_release_dates_cache(mmcif_dir: str, out_path: str):
dates = {} dates = {}
for f in os.listdir(mmcif_dir): for f in os.listdir(mmcif_dir):
if(f.endswith('.cif')): if f.endswith(".cif"):
path = os.path.join(mmcif_dir, f) path = os.path.join(mmcif_dir, f)
with open(path, 'r') as fp: with open(path, "r") as fp:
mmcif_string = fp.read() mmcif_string = fp.read()
file_id = os.path.splitext(f)[0] file_id = os.path.splitext(f)[0]
mmcif = mmcif_parsing.parse( mmcif = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string file_id=file_id, mmcif_string=mmcif_string
) )
if(mmcif.mmcif_object is None): if mmcif.mmcif_object is None:
logging.warning(f'Failed to parse {f}. Skipping...') logging.warning(f"Failed to parse {f}. Skipping...")
continue continue
mmcif = mmcif.mmcif_object mmcif = mmcif.mmcif_object
release_date = mmcif.header['release_date'] release_date = mmcif.header["release_date"]
dates[file_id] = release_date dates[file_id] = release_date
with open(out_path, 'r') as fp: with open(out_path, "r") as fp:
fp.write(json.dumps(dates)) fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]: def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
"""Parses release dates file, returns a mapping from PDBs to release dates.""" """Parses release dates file, returns a mapping from PDBs to release dates."""
with open(path, 'r') as fp: with open(path, "r") as fp:
data = json.load(fp) data = json.load(fp)
return { return {
pdb:to_date(v) for pdb,d in data.items() for k,v in d.items() pdb: to_date(v)
for pdb, d in data.items()
for k, v in d.items()
if k == "release_date" if k == "release_date"
} }
def _assess_hhsearch_hit( def _assess_hhsearch_hit(
hit: parsers.TemplateHit, hit: parsers.TemplateHit,
hit_pdb_code: str, hit_pdb_code: str,
query_sequence: str, query_sequence: str,
query_pdb_code: Optional[str], query_pdb_code: Optional[str],
release_dates: Mapping[str, datetime.datetime], release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime, release_date_cutoff: datetime.datetime,
max_subsequence_ratio: float = 0.95, max_subsequence_ratio: float = 0.95,
min_align_ratio: float = 0.1) -> bool: min_align_ratio: float = 0.1,
) -> bool:
"""Determines if template is valid (without parsing the template mmcif file). """Determines if template is valid (without parsing the template mmcif file).
Args: Args:
...@@ -221,40 +228,51 @@ def _assess_hhsearch_hit( ...@@ -221,40 +228,51 @@ def _assess_hhsearch_hit(
aligned_cols = hit.aligned_cols aligned_cols = hit.aligned_cols
align_ratio = aligned_cols / len(query_sequence) align_ratio = aligned_cols / len(query_sequence)
template_sequence = hit.hit_sequence.replace('-', '') template_sequence = hit.hit_sequence.replace("-", "")
length_ratio = float(len(template_sequence)) / len(query_sequence) length_ratio = float(len(template_sequence)) / len(query_sequence)
# Check whether the template is a large subsequence or duplicate of original # Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database. # query. This can happen due to duplicate entries in the PDB database.
duplicate = (template_sequence in query_sequence and duplicate = (
length_ratio > max_subsequence_ratio) template_sequence in query_sequence
and length_ratio > max_subsequence_ratio
)
if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff): if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date ' raise DateError(
f'({release_date_cutoff}).') f"Date ({release_dates[hit_pdb_code]}) > max template date "
f"({release_date_cutoff})."
)
if query_pdb_code is not None: if query_pdb_code is not None:
if query_pdb_code.lower() == hit_pdb_code.lower(): if query_pdb_code.lower() == hit_pdb_code.lower():
raise PdbIdError('PDB code identical to Query PDB code.') raise PdbIdError("PDB code identical to Query PDB code.")
if align_ratio <= min_align_ratio: if align_ratio <= min_align_ratio:
raise AlignRatioError('Proportion of residues aligned to query too small. ' raise AlignRatioError(
f'Align ratio: {align_ratio}.') "Proportion of residues aligned to query too small. "
f"Align ratio: {align_ratio}."
)
if duplicate: if duplicate:
raise DuplicateError('Template is an exact subsequence of query with large ' raise DuplicateError(
f'coverage. Length ratio: {length_ratio}.') "Template is an exact subsequence of query with large "
f"coverage. Length ratio: {length_ratio}."
)
if len(template_sequence) < 10: if len(template_sequence) < 10:
raise LengthError(f'Template too short. Length: {len(template_sequence)}.') raise LengthError(
f"Template too short. Length: {len(template_sequence)}."
)
return True return True
def _find_template_in_pdb( def _find_template_in_pdb(
template_chain_id: str, template_chain_id: str,
template_sequence: str, template_sequence: str,
mmcif_object: mmcif_parsing.MmcifObject) -> Tuple[str, str, int]: mmcif_object: mmcif_parsing.MmcifObject,
) -> Tuple[str, str, int]:
"""Tries to find the template chain in the given pdb file. """Tries to find the template chain in the given pdb file.
This method tries the three following things in order: This method tries the three following things in order:
...@@ -286,41 +304,51 @@ def _find_template_in_pdb( ...@@ -286,41 +304,51 @@ def _find_template_in_pdb(
chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id) chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
if chain_sequence and (template_sequence in chain_sequence): if chain_sequence and (template_sequence in chain_sequence):
logging.info( logging.info(
'Found an exact template match %s_%s.', pdb_id, template_chain_id) "Found an exact template match %s_%s.", pdb_id, template_chain_id
)
mapping_offset = chain_sequence.find(template_sequence) mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, template_chain_id, mapping_offset return chain_sequence, template_chain_id, mapping_offset
# Try if there is an exact match in the (sub)sequence only. # Try if there is an exact match in the (sub)sequence only.
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
if chain_sequence and (template_sequence in chain_sequence): if chain_sequence and (template_sequence in chain_sequence):
logging.info('Found a sequence-only match %s_%s.', pdb_id, chain_id) logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
mapping_offset = chain_sequence.find(template_sequence) mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, chain_id, mapping_offset return chain_sequence, chain_id, mapping_offset
# Return a chain sequence that fuzzy matches (X = wildcard) the template. # Return a chain sequence that fuzzy matches (X = wildcard) the template.
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit. # Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence] regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence]
regex = re.compile(''.join(regex)) regex = re.compile("".join(regex))
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items(): for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
match = re.search(regex, chain_sequence) match = re.search(regex, chain_sequence)
if match: if match:
logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, chain_id) logging.info(
"Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
)
mapping_offset = match.start() mapping_offset = match.start()
return chain_sequence, chain_id, mapping_offset return chain_sequence, chain_id, mapping_offset
# No hits, raise an error. # No hits, raise an error.
raise SequenceNotInTemplateError( raise SequenceNotInTemplateError(
'Could not find the template sequence in %s_%s. Template sequence: %s, ' "Could not find the template sequence in %s_%s. Template sequence: %s, "
'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence, "chain_to_seqres: %s"
mmcif_object.chain_to_seqres)) % (
pdb_id,
template_chain_id,
template_sequence,
mmcif_object.chain_to_seqres,
)
)
def _realign_pdb_template_to_query(
    old_template_sequence: str,
    template_chain_id: str,
    mmcif_object: mmcif_parsing.MmcifObject,
    old_mapping: Mapping[int, int],
    kalign_binary_path: str,
) -> Tuple[str, Mapping[int, int]]:
    """Realigns a (possibly stale) template sequence to the mmCIF sequence.

    The hit database (PDB70) may hold an older version of the template
    sequence than the mmCIF file does. This helper aligns the two versions
    with kalign and re-derives the query-to-template residue mapping against
    the sequence actually stored in the mmCIF object.

    Args:
        old_template_sequence: Template sequence as reported by the hit
            database.
        template_chain_id: Chain id of the template within the mmCIF object.
        mmcif_object: Parsed mmCIF structure providing ``chain_to_seqres``.
        old_mapping: Mapping from query index to index in the old template
            sequence, to be remapped onto the new sequence.
        kalign_binary_path: Path to the kalign executable.

    Returns:
        A tuple of the new template sequence (gap characters stripped) and a
        mapping from query index to index in the new template sequence;
        positions without a counterpart map to -1.

    Raises:
        QueryToTemplateAlignError: If the chain cannot be located, the
            alignment itself fails, or the two template versions are less
            than 90% identical over the shorter of the two.
    """
    aligner = kalign.Kalign(binary_path=kalign_binary_path)
    new_template_sequence = mmcif_object.chain_to_seqres.get(
        template_chain_id, ""
    )

    # The template chain id is sometimes unknown. When the mmCIF holds
    # exactly one sequence it is safe to assume that one is the template.
    if not new_template_sequence:
        if len(mmcif_object.chain_to_seqres) == 1:
            logging.info(
                "Could not find %s in %s, but there is only 1 sequence, so "
                "using that one.",
                template_chain_id,
                mmcif_object.file_id,
            )
            (new_template_sequence,) = mmcif_object.chain_to_seqres.values()
        else:
            raise QueryToTemplateAlignError(
                f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
                "If there are no mmCIF parsing errors, it is possible it was not a "
                "protein chain."
            )

    try:
        alignment = aligner.align(
            [old_template_sequence, new_template_sequence]
        )
        (old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
            alignment
        )
    except Exception as e:
        raise QueryToTemplateAlignError(
            "Could not align old template %s to template %s (%s_%s). Error: %s"
            % (
                old_template_sequence,
                new_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                str(e),
            )
        )

    logging.info(
        "Old aligned template: %s\nNew aligned template: %s",
        old_aligned_template,
        new_aligned_template,
    )

    # Walk the pairwise alignment, tracking the ungapped residue index on
    # each side; record index pairs where both columns carry a residue.
    old_to_new_template_mapping = {}
    old_idx = -1
    new_idx = -1
    num_same = 0
    for old_aa, new_aa in zip(old_aligned_template, new_aligned_template):
        old_idx += old_aa != "-"
        new_idx += new_aa != "-"
        if old_aa != "-" and new_aa != "-":
            old_to_new_template_mapping[old_idx] = new_idx
            num_same += old_aa == new_aa

    # Require at least 90% sequence identity wrt the shorter of the two
    # sequences; anything lower means the database entry diverged too far.
    shorter_len = min(
        len(old_template_sequence), len(new_template_sequence)
    )
    if float(num_same) / shorter_len < 0.9:
        raise QueryToTemplateAlignError(
            "Insufficient similarity of the sequence in the database: %s to the "
            "actual sequence in the mmCIF file %s_%s: %s. We require at least "
            "90 %% similarity wrt to the shorter of the sequences. This is not a "
            "problem unless you think this is a template that should be included."
            % (
                old_template_sequence,
                mmcif_object.file_id,
                template_chain_id,
                new_template_sequence,
            )
        )

    # Compose the query->old mapping with the old->new mapping; unmatched
    # old-template positions fall through to -1.
    new_query_to_template_mapping = {
        query_idx: old_to_new_template_mapping.get(old_idx, -1)
        for query_idx, old_idx in old_mapping.items()
    }

    return (
        new_template_sequence.replace("-", ""),
        new_query_to_template_mapping,
    )
def _check_residue_distances(all_positions: np.ndarray, def _check_residue_distances(
all_positions_mask: np.ndarray, all_positions: np.ndarray,
max_ca_ca_distance: float): all_positions_mask: np.ndarray,
max_ca_ca_distance: float,
):
"""Checks if the distance between unmasked neighbor residues is ok.""" """Checks if the distance between unmasked neighbor residues is ok."""
ca_position = residue_constants.atom_order['CA'] ca_position = residue_constants.atom_order["CA"]
prev_is_unmasked = False prev_is_unmasked = False
prev_calpha = None prev_calpha = None
for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)): for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
...@@ -441,16 +497,18 @@ def _check_residue_distances(all_positions: np.ndarray, ...@@ -441,16 +497,18 @@ def _check_residue_distances(all_positions: np.ndarray,
distance = np.linalg.norm(this_calpha - prev_calpha) distance = np.linalg.norm(this_calpha - prev_calpha)
if distance > max_ca_ca_distance: if distance > max_ca_ca_distance:
raise CaDistanceError( raise CaDistanceError(
'The distance between residues %d and %d is %f > limit %f.' % ( "The distance between residues %d and %d is %f > limit %f."
i, i + 1, distance, max_ca_ca_distance)) % (i, i + 1, distance, max_ca_ca_distance)
)
prev_calpha = this_calpha prev_calpha = this_calpha
prev_is_unmasked = this_is_unmasked prev_is_unmasked = this_is_unmasked
def _get_atom_positions( def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject, mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str, auth_chain_id: str,
max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]: max_ca_ca_distance: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues.""" """Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords( coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id mmcif_object=mmcif_object, chain_id=auth_chain_id
...@@ -463,13 +521,14 @@ def _get_atom_positions( ...@@ -463,13 +521,14 @@ def _get_atom_positions(
def _extract_template_features( def _extract_template_features(
mmcif_object: mmcif_parsing.MmcifObject, mmcif_object: mmcif_parsing.MmcifObject,
pdb_id: str, pdb_id: str,
mapping: Mapping[int, int], mapping: Mapping[int, int],
template_sequence: str, template_sequence: str,
query_sequence: str, query_sequence: str,
template_chain_id: str, template_chain_id: str,
kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]: kalign_binary_path: str,
) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query. """Parses atom positions in the target structure and aligns with the query.
Atoms for each residue in the template structure are indexed to coincide Atoms for each residue in the template structure are indexed to coincide
...@@ -509,21 +568,25 @@ def _extract_template_features( ...@@ -509,21 +568,25 @@ def _extract_template_features(
unmasked residues. unmasked residues.
""" """
if mmcif_object is None or not mmcif_object.chain_to_seqres: if mmcif_object is None or not mmcif_object.chain_to_seqres:
raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id)) raise NoChainsError(
"No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
)
warning = None warning = None
try: try:
seqres, chain_id, mapping_offset = _find_template_in_pdb( seqres, chain_id, mapping_offset = _find_template_in_pdb(
template_chain_id=template_chain_id, template_chain_id=template_chain_id,
template_sequence=template_sequence, template_sequence=template_sequence,
mmcif_object=mmcif_object) mmcif_object=mmcif_object,
)
except SequenceNotInTemplateError: except SequenceNotInTemplateError:
# If PDB70 contains a different version of the template, we use the sequence # If PDB70 contains a different version of the template, we use the sequence
# from the mmcif_object. # from the mmcif_object.
chain_id = template_chain_id chain_id = template_chain_id
warning = ( warning = (
f'The exact sequence {template_sequence} was not found in ' f"The exact sequence {template_sequence} was not found in "
f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.') f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
)
logging.warning(warning) logging.warning(warning)
# This throws an exception if it fails to realign the hit. # This throws an exception if it fails to realign the hit.
seqres, mapping = _realign_pdb_template_to_query( seqres, mapping = _realign_pdb_template_to_query(
...@@ -531,9 +594,15 @@ def _extract_template_features( ...@@ -531,9 +594,15 @@ def _extract_template_features(
template_chain_id=template_chain_id, template_chain_id=template_chain_id,
mmcif_object=mmcif_object, mmcif_object=mmcif_object,
old_mapping=mapping, old_mapping=mapping,
kalign_binary_path=kalign_binary_path) kalign_binary_path=kalign_binary_path,
logging.info('Sequence in %s_%s: %s successfully realigned to %s', )
pdb_id, chain_id, template_sequence, seqres) logging.info(
"Sequence in %s_%s: %s successfully realigned to %s",
pdb_id,
chain_id,
template_sequence,
seqres,
)
# The template sequence changed. # The template sequence changed.
template_sequence = seqres template_sequence = seqres
# No mapping offset, the query is aligned to the actual sequence. # No mapping offset, the query is aligned to the actual sequence.
...@@ -543,13 +612,16 @@ def _extract_template_features( ...@@ -543,13 +612,16 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless # Essentially set to infinity - we don't want to reject templates unless
# they're really really bad. # they're really really bad.
all_atom_positions, all_atom_mask = _get_atom_positions( all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object, chain_id, max_ca_ca_distance=150.0) mmcif_object, chain_id, max_ca_ca_distance=150.0
)
except (CaDistanceError, KeyError) as ex: except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError( raise NoAtomDataInTemplateError(
'Could not get atom data (%s_%s): %s' % (pdb_id, chain_id, str(ex)) "Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
) from ex ) from ex
all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0]) all_atom_positions = np.split(
all_atom_positions, all_atom_positions.shape[0]
)
all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0]) all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])
output_templates_sequence = [] output_templates_sequence = []
...@@ -559,9 +631,12 @@ def _extract_template_features( ...@@ -559,9 +631,12 @@ def _extract_template_features(
for _ in query_sequence: for _ in query_sequence:
# Residues in the query_sequence that are not in the template_sequence: # Residues in the query_sequence that are not in the template_sequence:
templates_all_atom_positions.append( templates_all_atom_positions.append(
np.zeros((residue_constants.atom_type_num, 3))) np.zeros((residue_constants.atom_type_num, 3))
templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num)) )
output_templates_sequence.append('-') templates_all_atom_masks.append(
np.zeros(residue_constants.atom_type_num)
)
output_templates_sequence.append("-")
for k, v in mapping.items(): for k, v in mapping.items():
template_index = v + mapping_offset template_index = v + mapping_offset
...@@ -572,32 +647,42 @@ def _extract_template_features( ...@@ -572,32 +647,42 @@ def _extract_template_features(
# Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O). # Alanine (AA with the lowest number of atoms) has 5 atoms (C, CA, CB, N, O).
if np.sum(templates_all_atom_masks) < 5: if np.sum(templates_all_atom_masks) < 5:
raise TemplateAtomMaskAllZerosError( raise TemplateAtomMaskAllZerosError(
'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' % "Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
(pdb_id, chain_id, min(mapping.values()) + mapping_offset, % (
max(mapping.values()) + mapping_offset)) pdb_id,
chain_id,
min(mapping.values()) + mapping_offset,
max(mapping.values()) + mapping_offset,
)
)
output_templates_sequence = ''.join(output_templates_sequence) output_templates_sequence = "".join(output_templates_sequence)
templates_aatype = residue_constants.sequence_to_onehot( templates_aatype = residue_constants.sequence_to_onehot(
output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID) output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
)
return ( return (
{ {
'template_all_atom_positions': np.array(templates_all_atom_positions), "template_all_atom_positions": np.array(
'template_all_atom_mask': np.array(templates_all_atom_masks), templates_all_atom_positions
'template_sequence': output_templates_sequence.encode(), ),
'template_aatype': np.array(templates_aatype), "template_all_atom_mask": np.array(templates_all_atom_masks),
'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(), "template_sequence": output_templates_sequence.encode(),
"template_aatype": np.array(templates_aatype),
"template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
}, },
warning) warning,
)
def _build_query_to_hit_index_mapping( def _build_query_to_hit_index_mapping(
hit_query_sequence: str, hit_query_sequence: str,
hit_sequence: str, hit_sequence: str,
indices_hit: Sequence[int], indices_hit: Sequence[int],
indices_query: Sequence[int], indices_query: Sequence[int],
original_query_sequence: str) -> Mapping[int, int]: original_query_sequence: str,
) -> Mapping[int, int]:
"""Gets mapping from indices in original query sequence to indices in the hit. """Gets mapping from indices in original query sequence to indices in the hit.
hit_query_sequence and hit_sequence are two aligned sequences containing gap hit_query_sequence and hit_sequence are two aligned sequences containing gap
...@@ -624,15 +709,15 @@ def _build_query_to_hit_index_mapping( ...@@ -624,15 +709,15 @@ def _build_query_to_hit_index_mapping(
return {} return {}
# Remove gaps and find the offset of hit.query relative to original query. # Remove gaps and find the offset of hit.query relative to original query.
hhsearch_query_sequence = hit_query_sequence.replace('-', '') hhsearch_query_sequence = hit_query_sequence.replace("-", "")
hit_sequence = hit_sequence.replace('-', '') hit_sequence = hit_sequence.replace("-", "")
hhsearch_query_offset = original_query_sequence.find(hhsearch_query_sequence) hhsearch_query_offset = original_query_sequence.find(
hhsearch_query_sequence
)
# Index of -1 used for gap characters. Subtract the min index ignoring gaps. # Index of -1 used for gap characters. Subtract the min index ignoring gaps.
min_idx = min(x for x in indices_hit if x > -1) min_idx = min(x for x in indices_hit if x > -1)
fixed_indices_hit = [ fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
x - min_idx if x > -1 else -1 for x in indices_hit
]
min_idx = min(x for x in indices_query if x > -1) min_idx = min(x for x in indices_query if x > -1)
fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query] fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
...@@ -641,8 +726,9 @@ def _build_query_to_hit_index_mapping( ...@@ -641,8 +726,9 @@ def _build_query_to_hit_index_mapping(
mapping = {} mapping = {}
for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit): for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
if q_t != -1 and q_i != -1: if q_t != -1 and q_i != -1:
if (q_t >= len(hit_sequence) or if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
q_i + hhsearch_query_offset >= len(original_query_sequence)): original_query_sequence
):
continue continue
mapping[q_i + hhsearch_query_offset] = q_t mapping[q_i + hhsearch_query_offset] = q_t
...@@ -657,15 +743,16 @@ class SingleHitResult: ...@@ -657,15 +743,16 @@ class SingleHitResult:
def _process_single_hit( def _process_single_hit(
query_sequence: str, query_sequence: str,
query_pdb_code: Optional[str], query_pdb_code: Optional[str],
hit: parsers.TemplateHit, hit: parsers.TemplateHit,
mmcif_dir: str, mmcif_dir: str,
max_template_date: datetime.datetime, max_template_date: datetime.datetime,
release_dates: Mapping[str, datetime.datetime], release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str], obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str, kalign_binary_path: str,
strict_error_check: bool = False) -> SingleHitResult: strict_error_check: bool = False,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit.""" """Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit. # Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit) hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
...@@ -682,41 +769,56 @@ def _process_single_hit( ...@@ -682,41 +769,56 @@ def _process_single_hit(
query_sequence=query_sequence, query_sequence=query_sequence,
query_pdb_code=query_pdb_code, query_pdb_code=query_pdb_code,
release_dates=release_dates, release_dates=release_dates,
release_date_cutoff=max_template_date) release_date_cutoff=max_template_date,
)
except PrefilterError as e: except PrefilterError as e:
msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}' msg = f"hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}"
logging.info('%s: %s', query_pdb_code, msg) logging.info("%s: %s", query_pdb_code, msg)
if strict_error_check and isinstance( if strict_error_check and isinstance(
e, (DateError, PdbIdError, DuplicateError)): e, (DateError, PdbIdError, DuplicateError)
):
# In strict mode we treat some prefilter cases as errors. # In strict mode we treat some prefilter cases as errors.
return SingleHitResult(features=None, error=msg, warning=None) return SingleHitResult(features=None, error=msg, warning=None)
return SingleHitResult(features=None, error=None, warning=None) return SingleHitResult(features=None, error=None, warning=None)
mapping = _build_query_to_hit_index_mapping( mapping = _build_query_to_hit_index_mapping(
hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query, hit.query,
query_sequence) hit.hit_sequence,
hit.indices_hit,
hit.indices_query,
query_sequence,
)
# The mapping is from the query to the actual hit sequence, so we need to # The mapping is from the query to the actual hit sequence, so we need to
# remove gaps (which regardless have a missing confidence score). # remove gaps (which regardless have a missing confidence score).
template_sequence = hit.hit_sequence.replace('-', '') template_sequence = hit.hit_sequence.replace("-", "")
cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif') cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
logging.info('Reading PDB entry from %s. Query: %s, template: %s', logging.info(
cif_path, query_sequence, template_sequence) "Reading PDB entry from %s. Query: %s, template: %s",
cif_path,
query_sequence,
template_sequence,
)
# Fail if we can't find the mmCIF file. # Fail if we can't find the mmCIF file.
with open(cif_path, 'r') as cif_file: with open(cif_path, "r") as cif_file:
cif_string = cif_file.read() cif_string = cif_file.read()
parsing_result = mmcif_parsing.parse( parsing_result = mmcif_parsing.parse(
file_id=hit_pdb_code, mmcif_string=cif_string) file_id=hit_pdb_code, mmcif_string=cif_string
)
if parsing_result.mmcif_object is not None: if parsing_result.mmcif_object is not None:
hit_release_date = datetime.datetime.strptime( hit_release_date = datetime.datetime.strptime(
parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d') parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
)
if hit_release_date > max_template_date: if hit_release_date > max_template_date:
error = ('Template %s date (%s) > max template date (%s).' % error = "Template %s date (%s) > max template date (%s)." % (
(hit_pdb_code, hit_release_date, max_template_date)) hit_pdb_code,
hit_release_date,
max_template_date,
)
if strict_error_check: if strict_error_check:
return SingleHitResult(features=None, error=error, warning=None) return SingleHitResult(features=None, error=error, warning=None)
else: else:
...@@ -731,31 +833,52 @@ def _process_single_hit( ...@@ -731,31 +833,52 @@ def _process_single_hit(
template_sequence=template_sequence, template_sequence=template_sequence,
query_sequence=query_sequence, query_sequence=query_sequence,
template_chain_id=hit_chain_id, template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path) kalign_binary_path=kalign_binary_path,
features['template_sum_probs'] = [hit.sum_probs] )
features["template_sum_probs"] = [hit.sum_probs]
# It is possible there were some errors when parsing the other chains in the # It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still # mmCIF file, but the template features for the chain we want were still
# computed. In such case the mmCIF parsing errors are not relevant. # computed. In such case the mmCIF parsing errors are not relevant.
return SingleHitResult( return SingleHitResult(
features=features, error=None, warning=realign_warning) features=features, error=None, warning=realign_warning
except (NoChainsError, NoAtomDataInTemplateError, )
TemplateAtomMaskAllZerosError) as e: except (
NoChainsError,
NoAtomDataInTemplateError,
TemplateAtomMaskAllZerosError,
) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a # These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings. # problem with the template search, so turn them into warnings.
warning = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' warning = (
'%s, mmCIF parsing errors: %s' "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index, "%s, mmCIF parsing errors: %s"
str(e), parsing_result.errors)) % (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
if strict_error_check: if strict_error_check:
return SingleHitResult(features=None, error=warning, warning=None) return SingleHitResult(features=None, error=warning, warning=None)
else: else:
return SingleHitResult(features=None, error=None, warning=warning) return SingleHitResult(features=None, error=None, warning=warning)
except Error as e: except Error as e:
error = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: ' error = (
'%s, mmCIF parsing errors: %s' "%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index, "%s, mmCIF parsing errors: %s"
str(e), parsing_result.errors)) % (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
return SingleHitResult(features=None, error=error, warning=None) return SingleHitResult(features=None, error=error, warning=None)
...@@ -770,14 +893,15 @@ class TemplateHitFeaturizer: ...@@ -770,14 +893,15 @@ class TemplateHitFeaturizer:
"""A class for turning hhr hits to template features.""" """A class for turning hhr hits to template features."""
def __init__( def __init__(
self, self,
mmcif_dir: str, mmcif_dir: str,
max_template_date: str, max_template_date: str,
max_hits: int, max_hits: int,
kalign_binary_path: str, kalign_binary_path: str,
release_dates_path: Optional[str], release_dates_path: Optional[str],
obsolete_pdbs_path: Optional[str], obsolete_pdbs_path: Optional[str],
strict_error_check: bool = False): strict_error_check: bool = False,
):
"""Initializes the Template Search. """Initializes the Template Search.
Args: Args:
...@@ -800,42 +924,49 @@ class TemplateHitFeaturizer: ...@@ -800,42 +924,49 @@ class TemplateHitFeaturizer:
* If any template has identical PDB ID to the query. * If any template has identical PDB ID to the query.
* If any template is a duplicate of the query. * If any template is a duplicate of the query.
* Any feature computation errors. * Any feature computation errors.
""" """
self._mmcif_dir = mmcif_dir self._mmcif_dir = mmcif_dir
if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')): if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")):
logging.error('Could not find CIFs in %s', self._mmcif_dir) logging.error("Could not find CIFs in %s", self._mmcif_dir)
raise ValueError(f'Could not find CIFs in {self._mmcif_dir}') raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")
try: try:
self._max_template_date = datetime.datetime.strptime( self._max_template_date = datetime.datetime.strptime(
max_template_date, '%Y-%m-%d') max_template_date, "%Y-%m-%d"
)
except ValueError: except ValueError:
raise ValueError( raise ValueError(
'max_template_date must be set and have format YYYY-MM-DD.') "max_template_date must be set and have format YYYY-MM-DD."
)
self._max_hits = max_hits self._max_hits = max_hits
self._kalign_binary_path = kalign_binary_path self._kalign_binary_path = kalign_binary_path
self._strict_error_check = strict_error_check self._strict_error_check = strict_error_check
if release_dates_path: if release_dates_path:
logging.info('Using precomputed release dates %s.', release_dates_path) logging.info(
"Using precomputed release dates %s.", release_dates_path
)
self._release_dates = _parse_release_dates(release_dates_path) self._release_dates = _parse_release_dates(release_dates_path)
else: else:
self._release_dates = {} self._release_dates = {}
if obsolete_pdbs_path: if obsolete_pdbs_path:
logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path) logging.info(
"Using precomputed obsolete pdbs %s.", obsolete_pdbs_path
)
self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path) self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
else: else:
self._obsolete_pdbs = {} self._obsolete_pdbs = {}
def get_templates( def get_templates(
self, self,
query_sequence: str, query_sequence: str,
query_pdb_code: Optional[str], query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime], query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult: hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above).""" """Computes the templates for given query sequence (more details above)."""
logging.info('Searching for template for: %s', query_pdb_code) logging.info("Searching for template for: %s", query_pdb_code)
template_features = {} template_features = {}
for template_feature_name in TEMPLATE_FEATURES: for template_feature_name in TEMPLATE_FEATURES:
...@@ -869,7 +1000,8 @@ class TemplateHitFeaturizer: ...@@ -869,7 +1000,8 @@ class TemplateHitFeaturizer:
release_dates=self._release_dates, release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs, obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check, strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path) kalign_binary_path=self._kalign_binary_path,
)
if result.error: if result.error:
errors.append(result.error) errors.append(result.error)
...@@ -880,8 +1012,12 @@ class TemplateHitFeaturizer: ...@@ -880,8 +1012,12 @@ class TemplateHitFeaturizer:
warnings.append(result.warning) warnings.append(result.warning)
if result.features is None: if result.features is None:
logging.info('Skipped invalid hit %s, error: %s, warning: %s', logging.info(
hit.name, result.error, result.warning) "Skipped invalid hit %s, error: %s, warning: %s",
hit.name,
result.error,
result.warning,
)
else: else:
# Increment the hit counter, since we got features out of this hit. # Increment the hit counter, since we got features out of this hit.
num_hits += 1 num_hits += 1
...@@ -891,10 +1027,14 @@ class TemplateHitFeaturizer: ...@@ -891,10 +1027,14 @@ class TemplateHitFeaturizer:
for name in template_features: for name in template_features:
if num_hits > 0: if num_hits > 0:
template_features[name] = np.stack( template_features[name] = np.stack(
template_features[name], axis=0).astype(TEMPLATE_FEATURES[name]) template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
else: else:
# Make sure the feature has correct dtype even if empty. # Make sure the feature has correct dtype even if empty.
template_features[name] = np.array([], dtype=TEMPLATE_FEATURES[name]) template_features[name] = np.array(
[], dtype=TEMPLATE_FEATURES[name]
)
return TemplateSearchResult( return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings) features=template_features, errors=errors, warnings=warnings
)
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -28,127 +28,148 @@ _HHBLITS_DEFAULT_Z = 500 ...@@ -28,127 +28,148 @@ _HHBLITS_DEFAULT_Z = 500
class HHBlits:
    """Python wrapper of the HHblits binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 4,
        n_iter: int = 3,
        e_value: float = 0.001,
        maxseq: int = 1_000_000,
        realign_max: int = 100_000,
        maxfilt: int = 100_000,
        min_prefilter_hits: int = 1000,
        all_seqs: bool = False,
        alt: Optional[int] = None,
        p: int = _HHBLITS_DEFAULT_P,
        z: int = _HHBLITS_DEFAULT_Z,
    ):
        """Initializes the Python HHblits wrapper.

        Args:
            binary_path: The path to the HHblits executable.
            databases: A sequence of HHblits database path prefixes, i.e.
                the common prefix of the database files (up to but not
                including _hhm.ffindex etc.).
            n_cpu: The number of CPUs to give HHblits.
            n_iter: The number of HHblits iterations.
            e_value: The E-value, see HHblits docs for more details.
            maxseq: The maximum number of rows in an input alignment. Only
                supported in HHBlits version 3.1 and higher.
            realign_max: Max number of HMM-HMM hits to realign.
                HHblits default: 500.
            maxfilt: Max number of hits allowed to pass the 2nd prefilter.
                HHblits default: 20000.
            min_prefilter_hits: Min number of hits to pass prefilter.
                HHblits default: 100.
            all_seqs: Return all sequences in the MSA / do not filter the
                result MSA. HHblits default: False.
            alt: Show up to this many alternative alignments.
            p: Minimum Prob for a hit to be included in the output hhr
                file. HHblits default: 20.
            z: Hard cap on number of hits reported in the hhr file.
                HHblits default: 500. NB: the relevant HHblits flag is -Z,
                not -z.

        Raises:
            ValueError: If any database prefix matches no files on disk.
        """
        self.binary_path = binary_path
        self.databases = databases

        # Fail fast: every database prefix must match at least one file.
        for db_prefix in self.databases:
            if not glob.glob(db_prefix + "_*"):
                logging.error(
                    "Could not find HHBlits database %s", db_prefix
                )
                raise ValueError(
                    f"Could not find HHBlits database {db_prefix}"
                )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.maxseq = maxseq
        self.realign_max = realign_max
        self.maxfilt = maxfilt
        self.min_prefilter_hits = min_prefilter_hits
        self.all_seqs = all_seqs
        self.alt = alt
        self.p = p
        self.z = z

    def query(self, input_fasta_path: str) -> Mapping[str, Any]:
        """Queries the database using HHblits."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            db_cmd = []
            for db_path in self.databases:
                db_cmd.extend(["-d", db_path])

            cmd = [self.binary_path]
            cmd.extend(["-i", input_fasta_path])
            cmd.extend(["-cpu", str(self.n_cpu)])
            cmd.extend(["-oa3m", a3m_path])
            # Discard the hhr report; only the a3m alignment is consumed.
            cmd.extend(["-o", "/dev/null"])
            cmd.extend(["-n", str(self.n_iter)])
            cmd.extend(["-e", str(self.e_value)])
            cmd.extend(["-maxseq", str(self.maxseq)])
            cmd.extend(["-realign_max", str(self.realign_max)])
            cmd.extend(["-maxfilt", str(self.maxfilt)])
            cmd.extend(["-min_prefilter_hits", str(self.min_prefilter_hits)])
            if self.all_seqs:
                cmd.append("-all")
            if self.alt:
                cmd.extend(["-alt", str(self.alt)])
            # Only pass -p / -Z when they differ from the HHblits defaults.
            if self.p != _HHBLITS_DEFAULT_P:
                cmd.extend(["-p", str(self.p)])
            if self.z != _HHBLITS_DEFAULT_Z:
                cmd.extend(["-Z", str(self.z)])
            cmd.extend(db_cmd)

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("HHblits query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Logs have a 15k character limit, so log HHblits stderr
                # line by line.
                logging.error("HHblits failed. HHblits stderr begin:")
                for error_line in stderr.decode("utf-8").splitlines():
                    if error_line.strip():
                        logging.error(error_line.strip())
                logging.error("HHblits stderr end")
                raise RuntimeError(
                    "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
                )

            with open(a3m_path) as f:
                a3m = f.read()

            raw_output = dict(
                a3m=a3m,
                output=stdout,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )
            return raw_output
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -24,70 +24,83 @@ from openfold.data.np import utils ...@@ -24,70 +24,83 @@ from openfold.data.np import utils
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 2,
        maxseq: int = 1_000_000,
    ):
        """Initializes the Python HHsearch wrapper.

        Args:
            binary_path: The path to the HHsearch executable.
            databases: A sequence of HHsearch database path prefixes, i.e.
                the common prefix of the database files (up to but not
                including _hhm.ffindex etc.).
            n_cpu: The number of CPUs to use.
            maxseq: The maximum number of rows in an input alignment. Only
                supported in HHBlits version 3.1 and higher.

        Raises:
            ValueError: If any database prefix matches no files on disk.
        """
        self.binary_path = binary_path
        self.databases = databases
        self.n_cpu = n_cpu
        self.maxseq = maxseq

        # Fail fast: every database prefix must match at least one file.
        for db_prefix in self.databases:
            if not glob.glob(db_prefix + "_*"):
                logging.error(
                    "Could not find HHsearch database %s", db_prefix
                )
                raise ValueError(
                    f"Could not find HHsearch database {db_prefix}"
                )

    def query(self, a3m: str) -> str:
        """Queries the database using HHsearch using a given a3m."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            input_path = os.path.join(query_tmp_dir, "query.a3m")
            hhr_path = os.path.join(query_tmp_dir, "output.hhr")
            with open(input_path, "w") as f:
                f.write(a3m)

            db_cmd = []
            for db_path in self.databases:
                db_cmd.extend(["-d", db_path])

            cmd = [self.binary_path]
            cmd.extend(["-i", input_path])
            cmd.extend(["-o", hhr_path])
            cmd.extend(["-maxseq", str(self.maxseq)])
            cmd.extend(["-cpu", str(self.n_cpu)])
            cmd.extend(db_cmd)

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing("HHsearch query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                raise RuntimeError(
                    "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
                )

            with open(hhr_path) as f:
                hhr = f.read()
        return hhr
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -27,171 +27,202 @@ from openfold.data.tools import utils ...@@ -27,171 +27,202 @@ from openfold.data.tools import utils
class Jackhmmer:
    """Python wrapper of the Jackhmmer binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        database_path: str,
        n_cpu: int = 8,
        n_iter: int = 1,
        e_value: float = 0.0001,
        z_value: Optional[int] = None,
        get_tblout: bool = False,
        filter_f1: float = 0.0005,
        filter_f2: float = 0.00005,
        filter_f3: float = 0.0000005,
        incdom_e: Optional[float] = None,
        dom_e: Optional[float] = None,
        num_streamed_chunks: Optional[int] = None,
        streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Jackhmmer wrapper.

        Args:
            binary_path: The path to the jackhmmer executable.
            database_path: The path to the jackhmmer database (FASTA
                format).
            n_cpu: The number of CPUs to give Jackhmmer.
            n_iter: The number of Jackhmmer iterations.
            e_value: The E-value, see Jackhmmer docs for more details.
            z_value: The Z-value, see Jackhmmer docs for more details.
            get_tblout: Whether to save the tblout string.
            filter_f1: MSV and biased composition pre-filter; set to >1.0
                to turn off.
            filter_f2: Viterbi pre-filter; set to >1.0 to turn off.
            filter_f3: Forward pre-filter; set to >1.0 to turn off.
            incdom_e: Domain e-value criteria for inclusion of domains in
                the MSA/next round.
            dom_e: Domain e-value criteria for inclusion in tblout.
            num_streamed_chunks: Number of database chunks to stream over.
            streaming_callback: Callback function run after each chunk
                iteration, with the iteration number as argument.

        Raises:
            ValueError: If the database is missing and no streamed chunks
                were requested.
        """
        self.binary_path = binary_path
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        # A local database must exist unless chunks are streamed remotely.
        if (
            not os.path.exists(self.database_path)
            and num_streamed_chunks is None
        ):
            logging.error(
                "Could not find Jackhmmer database %s", database_path
            )
            raise ValueError(
                f"Could not find Jackhmmer database {database_path}"
            )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
        self.e_value = e_value
        self.z_value = z_value
        self.filter_f1 = filter_f1
        self.filter_f2 = filter_f2
        self.filter_f3 = filter_f3
        self.incdom_e = incdom_e
        self.dom_e = dom_e
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(
        self, input_fasta_path: str, database_path: str
    ) -> Mapping[str, Any]:
        """Queries the database chunk using Jackhmmer."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, "output.sto")

            # The F1/F2/F3 are the expected proportion to pass each of the
            # filtering stages (which get progressively more expensive);
            # reducing these speeds up the pipeline at the expense of
            # sensitivity. They are currently set very low to make querying
            # Mgnify run in a reasonable amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                "-o", "/dev/null",
                "-A", sto_path,
                "--noali",
                "--F1", str(self.filter_f1),
                "--F2", str(self.filter_f2),
                "--F3", str(self.filter_f3),
                "--incE", str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence
                # output.
                "-E", str(self.e_value),
                "--cpu", str(self.n_cpu),
                "-N", str(self.n_iter),
            ]
            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
                cmd_flags.extend(["--tblout", tblout_path])

            if self.z_value:
                cmd_flags.extend(["-Z", str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(["--domE", str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

            cmd = (
                [self.binary_path]
                + cmd_flags
                + [input_fasta_path, database_path]
            )

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"Jackhmmer ({os.path.basename(database_path)}) query"
            ):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
                )

            # Get e-values for each target name
            tbl = ""
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()

            with open(sto_path) as f:
                sto = f.read()

            raw_output = dict(
                sto=sto,
                tbl=tbl,
                stderr=stderr,
                n_iter=self.n_iter,
                e_value=self.e_value,
            )

        return raw_output

    def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]:
        """Queries the database using Jackhmmer."""
        if self.num_streamed_chunks is None:
            return [self._query_chunk(input_fasta_path, self.database_path)]

        db_basename = os.path.basename(self.database_path)

        def db_remote_chunk(db_idx):
            return f"{self.database_path}.{db_idx}"

        def db_local_chunk(db_idx):
            return f"/tmp/ramdisk/{db_basename}.{db_idx}"

        # Remove existing files to prevent OOM
        for stale in glob.glob(db_local_chunk("[0-9]*")):
            try:
                os.remove(stale)
            except OSError:
                print(f"OSError while deleting {stale}")

        # Download the (i+1)-th chunk while Jackhmmer is running on the
        # i-th chunk.
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
            chunked_output = []
            for i in range(1, self.num_streamed_chunks + 1):
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i),
                        db_local_chunk(i),
                    )
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Run Jackhmmer with the chunk
                future.result()
                chunked_output.append(
                    self._query_chunk(input_fasta_path, db_local_chunk(i))
                )

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
                # NOTE(review): when num_streamed_chunks == 1 the branch
                # above never binds next_future, so this line would raise
                # NameError -- confirm callers never stream a single chunk.
                future = next_future
                if self.streaming_callback:
                    self.streaming_callback(i)
        return chunked_output
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -24,81 +24,92 @@ from openfold.data.tools import utils ...@@ -24,81 +24,92 @@ from openfold.data.tools import utils
def _to_a3m(sequences: Sequence[str]) -> str: def _to_a3m(sequences: Sequence[str]) -> str:
"""Converts sequences to an a3m file.""" """Converts sequences to an a3m file."""
names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
a3m = [] a3m = []
for sequence, name in zip(sequences, names): for sequence, name in zip(sequences, names):
a3m.append(u'>' + name + u'\n') a3m.append(u">" + name + u"\n")
a3m.append(sequence + u'\n') a3m.append(sequence + u"\n")
return ''.join(a3m) return "".join(a3m)
class Kalign:
    """Python wrapper of the Kalign binary."""

    def __init__(self, *, binary_path: str):
        """Initializes the Python Kalign wrapper.

        Args:
            binary_path: The path to the Kalign binary.
        """
        self.binary_path = binary_path

    def align(self, sequences: Sequence[str]) -> str:
        """Aligns the sequences and returns the alignment in A3M string.

        Args:
            sequences: A list of query sequence strings. The sequences have
                to be at least 6 residues long (Kalign requires this). Note
                that the order in which you give the sequences might alter
                the output slightly as different alignment tree might get
                constructed.

        Returns:
            A string with the alignment in a3m format.

        Raises:
            RuntimeError: If Kalign fails.
            ValueError: If any of the sequences is less than 6 residues
                long.
        """
        logging.info("Aligning %d sequences", len(sequences))

        # Validate before touching the filesystem or the binary.
        for seq in sequences:
            if len(seq) < 6:
                raise ValueError(
                    "Kalign requires all sequences to be at least 6 "
                    "residues long. Got %s (%d residues)." % (seq, len(seq))
                )

        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
            output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            with open(input_fasta_path, "w") as f:
                f.write(_to_a3m(sequences))

            cmd = [self.binary_path]
            cmd.extend(["-i", input_fasta_path])
            cmd.extend(["-o", output_a3m_path])
            cmd.extend(["-format", "fasta"])

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info(
                    "Kalign stdout:\n%s\n\nstderr:\n%s\n",
                    stdout.decode("utf-8"),
                    stderr.decode("utf-8"),
                )

            if retcode:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
                )

            with open(output_a3m_path) as f:
                a3m = f.read()

            return a3m
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -25,21 +25,21 @@ from typing import Optional ...@@ -25,21 +25,21 @@ from typing import Optional
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Context manager that deletes a temporary directory on exit.

    Args:
        base_dir: Directory in which to create the temporary directory;
            defaults to the system temp location.

    Yields:
        The path of the newly created temporary directory.
    """
    path = tempfile.mkdtemp(dir=base_dir)
    try:
        yield path
    finally:
        # ignore_errors: cleanup is best-effort and must never raise.
        shutil.rmtree(path, ignore_errors=True)
@contextlib.contextmanager @contextlib.contextmanager
def timing(msg: str): def timing(msg: str):
logging.info('Started %s', msg) logging.info("Started %s", msg)
tic = time.time() tic = time.time()
yield yield
toc = time.time() toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic) logging.info("Finished %s in %.3f seconds", msg, toc - tic)
def to_date(s: str): def to_date(s: str):
......
...@@ -3,13 +3,14 @@ import glob ...@@ -3,13 +3,14 @@ import glob
import importlib as importlib import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py")) _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [os.path.basename(f)[:-3] for f in _files if os.path.isfile(f) and not f.endswith("__init__.py")] __all__ = [
_modules = [(m, importlib.import_module('.' + m, __name__)) for m in __all__] os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules: for _m in _modules:
globals()[_m[0]] = _m[1] globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace # Avoid needlessly cluttering the global namespace
del _files, _m, _modules del _files, _m, _modules
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -21,36 +21,37 @@ from typing import Union, List ...@@ -21,36 +21,37 @@ from typing import Union, List
class Dropout(nn.Module): class Dropout(nn.Module):
""" """
Implementation of dropout with the ability to share the dropout mask Implementation of dropout with the ability to share the dropout mask
along a particular dimension. along a particular dimension.
If not in training mode, this module computes the identity function. If not in training mode, this module computes the identity function.
""" """
def __init__(self, r: float, batch_dim: Union[int, List[int]]): def __init__(self, r: float, batch_dim: Union[int, List[int]]):
""" """
Args: Args:
r: r:
Dropout rate Dropout rate
batch_dim: batch_dim:
Dimension(s) along which the dropout mask is shared Dimension(s) along which the dropout mask is shared
""" """
super(Dropout, self).__init__() super(Dropout, self).__init__()
self.r = r self.r = r
if(type(batch_dim) == int): if type(batch_dim) == int:
batch_dim = [batch_dim] batch_dim = [batch_dim]
self.batch_dim = batch_dim self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r) self.dropout = nn.Dropout(self.r)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
""" """
Args: Args:
x: x:
Tensor to which dropout is applied. Can have any shape Tensor to which dropout is applied. Can have any shape
compatible with self.batch_dim compatible with self.batch_dim
""" """
shape = list(x.shape) shape = list(x.shape)
if(self.batch_dim is not None): if self.batch_dim is not None:
for bd in self.batch_dim: for bd in self.batch_dim:
shape[bd] = 1 shape[bd] = 1
mask = x.new_ones(shape) mask = x.new_ones(shape)
...@@ -60,16 +61,18 @@ class Dropout(nn.Module): ...@@ -60,16 +61,18 @@ class Dropout(nn.Module):
class DropoutRowwise(Dropout): class DropoutRowwise(Dropout):
"""
Convenience class for rowwise dropout as described in subsection
1.11.6.
""" """
Convenience class for rowwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-3) __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
class DropoutColumnwise(Dropout): class DropoutColumnwise(Dropout):
"""
Convenience class for columnwise dropout as described in subsection
1.11.6.
""" """
Convenience class for columnwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-2) __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -22,11 +22,12 @@ from openfold.utils.tensor_utils import one_hot ...@@ -22,11 +22,12 @@ from openfold.utils.tensor_utils import one_hot
class InputEmbedder(nn.Module): class InputEmbedder(nn.Module):
""" """
Embeds a subset of the input features. Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos). Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
""" """
def __init__( def __init__(
self, self,
tf_dim: int, tf_dim: int,
...@@ -37,18 +38,18 @@ class InputEmbedder(nn.Module): ...@@ -37,18 +38,18 @@ class InputEmbedder(nn.Module):
**kwargs, **kwargs,
): ):
""" """
Args: Args:
tf_dim: tf_dim:
Final dimension of the target features Final dimension of the target features
msa_dim: msa_dim:
Final dimension of the MSA features Final dimension of the MSA features
c_z: c_z:
Pair embedding dimension Pair embedding dimension
c_m: c_m:
MSA embedding dimension MSA embedding dimension
relpos_k: relpos_k:
Window size used in relative positional encoding Window size used in relative positional encoding
""" """
super(InputEmbedder, self).__init__() super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim self.tf_dim = tf_dim
...@@ -67,43 +68,42 @@ class InputEmbedder(nn.Module): ...@@ -67,43 +68,42 @@ class InputEmbedder(nn.Module):
self.no_bins = 2 * relpos_k + 1 self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z) self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, def relpos(self, ri: torch.Tensor):
ri: torch.Tensor
):
""" """
Computes relative positional encodings Computes relative positional encodings
Implements Algorithm 4. Implements Algorithm 4.
Args: Args:
ri: ri:
"residue_index" features of shape [*, N] "residue_index" features of shape [*, N]
""" """
d = ri[..., None] - ri[..., None, :] d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange( boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
) )
oh = one_hot(d, boundaries).type(ri.dtype) oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh) return self.linear_relpos(oh)
def forward(self, def forward(
tf: torch.Tensor, self,
ri: torch.Tensor, tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor, msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
tf: tf:
"target_feat" features of shape [*, N_res, tf_dim] "target_feat" features of shape [*, N_res, tf_dim]
ri: ri:
"residue_index" features of shape [*, N_res] "residue_index" features of shape [*, N_res]
msa: msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim] "msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns: Returns:
msa_emb: msa_emb:
[*, N_clust, N_res, C_m] MSA embedding [*, N_clust, N_res, C_m] MSA embedding
pair_emb: pair_emb:
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
""" """
# [*, N_res, c_z] # [*, N_res, c_z]
...@@ -128,31 +128,33 @@ class InputEmbedder(nn.Module): ...@@ -128,31 +128,33 @@ class InputEmbedder(nn.Module):
class RecyclingEmbedder(nn.Module): class RecyclingEmbedder(nn.Module):
""" """
Embeds the output of an iteration of the model for recycling. Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32. Implements Algorithm 32.
""" """
def __init__(self,
c_m: int, def __init__(
c_z: int, self,
c_m: int,
c_z: int,
min_bin: float, min_bin: float,
max_bin: float, max_bin: float,
no_bins: int, no_bins: int,
inf: float = 1e8, inf: float = 1e8,
**kwargs **kwargs,
): ):
""" """
Args: Args:
c_m: c_m:
MSA channel dimension MSA channel dimension
c_z: c_z:
Pair embedding channel dimension Pair embedding channel dimension
min_bin: min_bin:
Smallest distogram bin (Angstroms) Smallest distogram bin (Angstroms)
max_bin: max_bin:
Largest distogram bin (Angstroms) Largest distogram bin (Angstroms)
no_bins: no_bins:
Number of distogram bins Number of distogram bins
""" """
super(RecyclingEmbedder, self).__init__() super(RecyclingEmbedder, self).__init__()
...@@ -162,58 +164,54 @@ class RecyclingEmbedder(nn.Module): ...@@ -162,58 +164,54 @@ class RecyclingEmbedder(nn.Module):
self.max_bin = max_bin self.max_bin = max_bin
self.no_bins = no_bins self.no_bins = no_bins
self.inf = inf self.inf = inf
self.bins = None self.bins = None
self.linear = Linear(self.no_bins, self.c_z) self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = nn.LayerNorm(self.c_m) self.layer_norm_m = nn.LayerNorm(self.c_m)
self.layer_norm_z = nn.LayerNorm(self.c_z) self.layer_norm_z = nn.LayerNorm(self.c_z)
def forward(self, def forward(
m: torch.Tensor, self,
z: torch.Tensor, m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor, x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
m: m:
First row of the MSA embedding. [*, N_res, C_m] First row of the MSA embedding. [*, N_res, C_m]
z: z:
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
x: x:
[*, N_res, 3] predicted C_beta coordinates [*, N_res, 3] predicted C_beta coordinates
Returns: Returns:
m: m:
[*, N_res, C_m] MSA embedding update [*, N_res, C_m] MSA embedding update
z: z:
[*, N_res, N_res, C_z] pair embedding update [*, N_res, N_res, C_z] pair embedding update
""" """
if(self.bins is None): if self.bins is None:
self.bins = torch.linspace( self.bins = torch.linspace(
self.min_bin, self.min_bin,
self.max_bin, self.max_bin,
self.no_bins, self.no_bins,
dtype=x.dtype, dtype=x.dtype,
device=x.device device=x.device,
) )
# [*, N, C_m] # [*, N, C_m]
m_update = self.layer_norm_m(m) m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode. # This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I # I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time. # couldn't find in time.
squared_bins = self.bins ** 2 squared_bins = self.bins ** 2
upper = torch.cat( upper = torch.cat(
[ [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
squared_bins[1:],
squared_bins.new_tensor([self.inf])
], dim=-1
) )
d = torch.sum( d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
dim=-1,
keepdims=True
) )
# [*, N, N, no_bins] # [*, N, N, no_bins]
...@@ -228,21 +226,23 @@ class RecyclingEmbedder(nn.Module): ...@@ -228,21 +226,23 @@ class RecyclingEmbedder(nn.Module):
class TemplateAngleEmbedder(nn.Module): class TemplateAngleEmbedder(nn.Module):
""" """
Embeds the "template_angle_feat" feature. Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7. Implements Algorithm 2, line 7.
""" """
def __init__(self,
def __init__(
self,
c_in: int, c_in: int,
c_out: int, c_out: int,
**kwargs, **kwargs,
): ):
""" """
Args: Args:
c_in: c_in:
Final dimension of "template_angle_feat" Final dimension of "template_angle_feat"
c_out: c_out:
Output channel dimension Output channel dimension
""" """
super(TemplateAngleEmbedder, self).__init__() super(TemplateAngleEmbedder, self).__init__()
...@@ -253,14 +253,12 @@ class TemplateAngleEmbedder(nn.Module): ...@@ -253,14 +253,12 @@ class TemplateAngleEmbedder(nn.Module):
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, init="relu") self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
def forward(self, def forward(self, x: torch.Tensor) -> torch.Tensor:
x: torch.Tensor
) -> torch.Tensor:
""" """
Args: Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns: Returns:
x: [*, N_templ, N_res, C_out] embedding x: [*, N_templ, N_res, C_out] embedding
""" """
x = self.linear_1(x) x = self.linear_1(x)
x = self.relu(x) x = self.relu(x)
...@@ -271,21 +269,23 @@ class TemplateAngleEmbedder(nn.Module): ...@@ -271,21 +269,23 @@ class TemplateAngleEmbedder(nn.Module):
class TemplatePairEmbedder(nn.Module): class TemplatePairEmbedder(nn.Module):
""" """
Embeds "template_pair_feat" features. Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9. Implements Algorithm 2, line 9.
""" """
def __init__(self,
def __init__(
self,
c_in: int, c_in: int,
c_out: int, c_out: int,
**kwargs, **kwargs,
): ):
""" """
Args: Args:
c_in: c_in:
c_out: c_out:
Output channel dimension Output channel dimension
""" """
super(TemplatePairEmbedder, self).__init__() super(TemplatePairEmbedder, self).__init__()
...@@ -294,16 +294,17 @@ class TemplatePairEmbedder(nn.Module): ...@@ -294,16 +294,17 @@ class TemplatePairEmbedder(nn.Module):
# Despite there being no relu nearby, the source uses that initializer # Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, init="relu") self.linear = Linear(self.c_in, self.c_out, init="relu")
def forward(self, def forward(
self,
x: torch.Tensor, x: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
x: x:
[*, C_in] input tensor [*, C_in] input tensor
Returns: Returns:
[*, C_out] output tensor [*, C_out] output tensor
""" """
x = self.linear(x) x = self.linear(x)
...@@ -312,21 +313,23 @@ class TemplatePairEmbedder(nn.Module): ...@@ -312,21 +313,23 @@ class TemplatePairEmbedder(nn.Module):
class ExtraMSAEmbedder(nn.Module): class ExtraMSAEmbedder(nn.Module):
""" """
Embeds unclustered MSA sequences. Embeds unclustered MSA sequences.
Implements Algorithm 2, line 15 Implements Algorithm 2, line 15
""" """
def __init__(self,
def __init__(
self,
c_in: int, c_in: int,
c_out: int, c_out: int,
**kwargs, **kwargs,
): ):
""" """
Args: Args:
c_in: c_in:
Input channel dimension Input channel dimension
c_out: c_out:
Output channel dimension Output channel dimension
""" """
super(ExtraMSAEmbedder, self).__init__() super(ExtraMSAEmbedder, self).__init__()
...@@ -335,15 +338,13 @@ class ExtraMSAEmbedder(nn.Module): ...@@ -335,15 +338,13 @@ class ExtraMSAEmbedder(nn.Module):
self.linear = Linear(self.c_in, self.c_out) self.linear = Linear(self.c_in, self.c_out)
def forward(self, def forward(self, x: torch.Tensor) -> torch.Tensor:
x: torch.Tensor
) -> torch.Tensor:
""" """
Args: Args:
x: x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features [*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns: Returns:
[*, N_extra_seq, N_res, C_out] embedding [*, N_extra_seq, N_res, C_out] embedding
""" """
x = self.linear(x) x = self.linear(x)
......
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -19,7 +19,7 @@ from typing import Tuple, Optional ...@@ -19,7 +19,7 @@ from typing import Tuple, Optional
from functools import partial from functools import partial
from openfold.model.primitives import Linear from openfold.model.primitives import Linear
from openfold.utils.deepspeed import checkpoint_blocks from openfold.utils.deepspeed import checkpoint_blocks
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
from openfold.model.msa import ( from openfold.model.msa import (
MSARowAttentionWithPairBias, MSARowAttentionWithPairBias,
...@@ -41,18 +41,19 @@ from openfold.utils.tensor_utils import chunk_layer ...@@ -41,18 +41,19 @@ from openfold.utils.tensor_utils import chunk_layer
class MSATransition(nn.Module): class MSATransition(nn.Module):
""" """
Feed-forward network applied to MSA activations after attention. Feed-forward network applied to MSA activations after attention.
Implements Algorithm 9 Implements Algorithm 9
""" """
def __init__(self, c_m, n, chunk_size): def __init__(self, c_m, n, chunk_size):
""" """
Args: Args:
c_m: c_m:
MSA channel dimension MSA channel dimension
n: n:
Factor multiplied to c_m to obtain the hidden channel Factor multiplied to c_m to obtain the hidden channel
dimension dimension
""" """
super(MSATransition, self).__init__() super(MSATransition, self).__init__()
...@@ -64,29 +65,30 @@ class MSATransition(nn.Module): ...@@ -64,29 +65,30 @@ class MSATransition(nn.Module):
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu") self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final") self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
def _transition(self, m, mask): def _transition(self, m, mask):
m = self.linear_1(m) m = self.linear_1(m)
m = self.relu(m) m = self.relu(m)
m = self.linear_2(m) * mask m = self.linear_2(m) * mask
return m return m
def forward(self, def forward(
self,
m: torch.Tensor, m: torch.Tensor,
mask: torch.Tensor = None, mask: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
m: m:
[*, N_seq, N_res, C_m] MSA activation [*, N_seq, N_res, C_m] MSA activation
mask: mask:
[*, N_seq, N_res, C_m] MSA mask [*, N_seq, N_res, C_m] MSA mask
Returns: Returns:
m: m:
[*, N_seq, N_res, C_m] MSA activation update [*, N_seq, N_res, C_m] MSA activation update
""" """
# DISCREPANCY: DeepMind forgets to apply the MSA mask here. # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
if(mask is None): if mask is None:
mask = m.new_ones(m.shape[:-1]) mask = m.new_ones(m.shape[:-1])
mask = mask.unsqueeze(-1) mask = mask.unsqueeze(-1)
...@@ -94,7 +96,7 @@ class MSATransition(nn.Module): ...@@ -94,7 +96,7 @@ class MSATransition(nn.Module):
m = self.layer_norm(m) m = self.layer_norm(m)
inp = {"m": m, "mask": mask} inp = {"m": m, "mask": mask}
if(self.chunk_size is not None): if self.chunk_size is not None:
m = chunk_layer( m = chunk_layer(
self._transition, self._transition,
inp, inp,
...@@ -108,7 +110,8 @@ class MSATransition(nn.Module): ...@@ -108,7 +110,8 @@ class MSATransition(nn.Module):
class EvoformerBlock(nn.Module): class EvoformerBlock(nn.Module):
def __init__(self, def __init__(
self,
c_m: int, c_m: int,
c_z: int, c_z: int,
c_hidden_msa_att: int, c_hidden_msa_att: int,
...@@ -126,7 +129,7 @@ class EvoformerBlock(nn.Module): ...@@ -126,7 +129,7 @@ class EvoformerBlock(nn.Module):
_is_extra_msa_stack: bool = False, _is_extra_msa_stack: bool = False,
): ):
super(EvoformerBlock, self).__init__() super(EvoformerBlock, self).__init__()
self.msa_att_row = MSARowAttentionWithPairBias( self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m, c_m=c_m,
c_z=c_z, c_z=c_z,
...@@ -136,7 +139,7 @@ class EvoformerBlock(nn.Module): ...@@ -136,7 +139,7 @@ class EvoformerBlock(nn.Module):
inf=inf, inf=inf,
) )
if(_is_extra_msa_stack): if _is_extra_msa_stack:
self.msa_att_col = MSAColumnGlobalAttention( self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m, c_in=c_m,
c_hidden=c_hidden_msa_att, c_hidden=c_hidden_msa_att,
...@@ -196,16 +199,17 @@ class EvoformerBlock(nn.Module): ...@@ -196,16 +199,17 @@ class EvoformerBlock(nn.Module):
transition_n, transition_n,
chunk_size=chunk_size, chunk_size=chunk_size,
) )
self.msa_dropout_layer = DropoutRowwise(msa_dropout) self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout) self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout) self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(self, def forward(
m: torch.Tensor, self,
z: torch.Tensor, m: torch.Tensor,
msa_mask: torch.Tensor, z: torch.Tensor,
pair_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
_mask_trans: bool = True, _mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans # DeepMind doesn't mask these transitions in the source, so _mask_trans
...@@ -229,11 +233,13 @@ class EvoformerBlock(nn.Module): ...@@ -229,11 +233,13 @@ class EvoformerBlock(nn.Module):
class EvoformerStack(nn.Module): class EvoformerStack(nn.Module):
""" """
Main Evoformer trunk. Main Evoformer trunk.
Implements Algorithm 6. Implements Algorithm 6.
""" """
def __init__(self,
def __init__(
self,
c_m: int, c_m: int,
c_z: int, c_z: int,
c_hidden_msa_att: int, c_hidden_msa_att: int,
...@@ -248,43 +254,43 @@ class EvoformerStack(nn.Module): ...@@ -248,43 +254,43 @@ class EvoformerStack(nn.Module):
msa_dropout: float, msa_dropout: float,
pair_dropout: float, pair_dropout: float,
blocks_per_ckpt: int, blocks_per_ckpt: int,
chunk_size: int, chunk_size: int,
inf: float, inf: float,
eps: float, eps: float,
_is_extra_msa_stack: bool = False, _is_extra_msa_stack: bool = False,
**kwargs, **kwargs,
): ):
""" """
Args: Args:
c_m: c_m:
MSA channel dimension MSA channel dimension
c_z: c_z:
Pair channel dimension Pair channel dimension
c_hidden_msa_att: c_hidden_msa_att:
Hidden dimension in MSA attention Hidden dimension in MSA attention
c_hidden_opm: c_hidden_opm:
Hidden dimension in outer product mean module Hidden dimension in outer product mean module
c_hidden_mul: c_hidden_mul:
Hidden dimension in multiplicative updates Hidden dimension in multiplicative updates
c_hidden_pair_att: c_hidden_pair_att:
Hidden dimension in triangular attention Hidden dimension in triangular attention
c_s: c_s:
Channel dimension of the output "single" embedding Channel dimension of the output "single" embedding
no_heads_msa: no_heads_msa:
Number of heads used for MSA attention Number of heads used for MSA attention
no_heads_pair: no_heads_pair:
Number of heads used for pair attention Number of heads used for pair attention
no_blocks: no_blocks:
Number of Evoformer blocks in the stack Number of Evoformer blocks in the stack
transition_n: transition_n:
Factor by which to multiply c_m to obtain the MSATransition Factor by which to multiply c_m to obtain the MSATransition
hidden dimension hidden dimension
msa_dropout: msa_dropout:
Dropout rate for MSA activations Dropout rate for MSA activations
pair_dropout: pair_dropout:
Dropout used for pair activations Dropout used for pair activations
blocks_per_ckpt: blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint Number of Evoformer blocks in each activation checkpoint
""" """
super(EvoformerStack, self).__init__() super(EvoformerStack, self).__init__()
...@@ -313,49 +319,51 @@ class EvoformerStack(nn.Module): ...@@ -313,49 +319,51 @@ class EvoformerStack(nn.Module):
) )
self.blocks.append(block) self.blocks.append(block)
if(not self._is_extra_msa_stack): if not self._is_extra_msa_stack:
self.linear = Linear(c_m, c_s) self.linear = Linear(c_m, c_s)
def forward(self, def forward(
m: torch.Tensor, self,
m: torch.Tensor,
z: torch.Tensor, z: torch.Tensor,
msa_mask: torch.Tensor, msa_mask: torch.Tensor,
pair_mask: torch.Tensor, pair_mask: torch.Tensor,
_mask_trans: bool = True, _mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
""" """
Args: Args:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
z: z:
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
msa_mask: msa_mask:
[*, N_seq, N_res] MSA mask [*, N_seq, N_res] MSA mask
pair_mask: pair_mask:
[*, N_res, N_res] pair mask [*, N_res, N_res] pair mask
Returns: Returns:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
z: z:
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
s: s:
[*, N_res, C_s] single embedding [*, N_res, C_s] single embedding
""" """
m, z = checkpoint_blocks( m, z = checkpoint_blocks(
blocks=[ blocks=[
partial( partial(
b, b,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
_mask_trans=_mask_trans, _mask_trans=_mask_trans,
) for b in self.blocks )
], for b in self.blocks
],
args=(m, z), args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None, blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
) )
s = None s = None
if(not self._is_extra_msa_stack): if not self._is_extra_msa_stack:
seq_dim = -3 seq_dim = -3
index = torch.tensor([0], device=m.device) index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index)) s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
...@@ -365,10 +373,12 @@ class EvoformerStack(nn.Module): ...@@ -365,10 +373,12 @@ class EvoformerStack(nn.Module):
class ExtraMSAStack(nn.Module): class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
""" """
def __init__(self, Implements Algorithm 18.
"""
def __init__(
self,
c_m: int, c_m: int,
c_z: int, c_z: int,
c_hidden_msa_att: int, c_hidden_msa_att: int,
...@@ -408,34 +418,35 @@ class ExtraMSAStack(nn.Module): ...@@ -408,34 +418,35 @@ class ExtraMSAStack(nn.Module):
chunk_size=chunk_size, chunk_size=chunk_size,
inf=inf, inf=inf,
eps=eps, eps=eps,
_is_extra_msa_stack=True, _is_extra_msa_stack=True,
) )
def forward(self, def forward(
m: torch.Tensor, self,
z: torch.Tensor, m: torch.Tensor,
msa_mask: Optional[torch.Tensor] = None, z: torch.Tensor,
pair_mask: Optional[torch.Tensor] = None, msa_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Args: Args:
m: m:
[*, N_extra, N_res, C_m] extra MSA embedding [*, N_extra, N_res, C_m] extra MSA embedding
z: z:
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
msa_mask: msa_mask:
Optional [*, N_extra, N_res] MSA mask Optional [*, N_extra, N_res] MSA mask
pair_mask: pair_mask:
Optional [*, N_res, N_res] pair mask Optional [*, N_res, N_res] pair mask
Returns: Returns:
[*, N_res, N_res, C_z] pair update [*, N_res, N_res, C_z] pair update
""" """
_, z, _ = self.stack( _, z, _ = self.stack(
m, m,
z, z,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
_mask_trans=_mask_trans _mask_trans=_mask_trans,
) )
return z return z
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -18,8 +18,8 @@ import torch.nn as nn ...@@ -18,8 +18,8 @@ import torch.nn as nn
from openfold.model.primitives import Linear from openfold.model.primitives import Linear
from openfold.utils.loss import ( from openfold.utils.loss import (
compute_plddt, compute_plddt,
compute_tm, compute_tm,
compute_predicted_aligned_error, compute_predicted_aligned_error,
) )
...@@ -44,7 +44,7 @@ class AuxiliaryHeads(nn.Module): ...@@ -44,7 +44,7 @@ class AuxiliaryHeads(nn.Module):
**config["experimentally_resolved"], **config["experimentally_resolved"],
) )
if(config.tm.enabled): if config.tm.enabled:
self.tm = TMScoreHead( self.tm = TMScoreHead(
**config.tm, **config.tm,
) )
...@@ -68,20 +68,23 @@ class AuxiliaryHeads(nn.Module): ...@@ -68,20 +68,23 @@ class AuxiliaryHeads(nn.Module):
experimentally_resolved_logits = self.experimentally_resolved( experimentally_resolved_logits = self.experimentally_resolved(
outputs["single"] outputs["single"]
) )
aux_out["experimentally_resolved_logits"] = ( aux_out[
experimentally_resolved_logits "experimentally_resolved_logits"
) ] = experimentally_resolved_logits
if(self.config.tm.enabled): if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"]) tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm( aux_out["predicted_tm_score"] = compute_tm(
tm_logits, **self.config.tm tm_logits, **self.config.tm
) )
aux_out.update(compute_predicted_aligned_error( aux_out.update(
tm_logits, **self.config.tm, compute_predicted_aligned_error(
)) tm_logits,
**self.config.tm,
)
)
return aux_out return aux_out
...@@ -114,17 +117,18 @@ class PerResidueLDDTCaPredictor(nn.Module): ...@@ -114,17 +117,18 @@ class PerResidueLDDTCaPredictor(nn.Module):
class DistogramHead(nn.Module): class DistogramHead(nn.Module):
""" """
Computes a distogram probability distribution. Computes a distogram probability distribution.
For use in computation of distogram loss, subsection 1.9.8 For use in computation of distogram loss, subsection 1.9.8
""" """
def __init__(self, c_z, no_bins, **kwargs): def __init__(self, c_z, no_bins, **kwargs):
""" """
Args: Args:
c_z: c_z:
Input channel dimension Input channel dimension
no_bins: no_bins:
Number of distogram bins Number of distogram bins
""" """
super(DistogramHead, self).__init__() super(DistogramHead, self).__init__()
...@@ -133,15 +137,13 @@ class DistogramHead(nn.Module): ...@@ -133,15 +137,13 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final") self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, def forward(self, z): # [*, N, N, C_z]
z # [*, N, N, C_z]
):
""" """
Args: Args:
z: z:
[*, N_res, N_res, C_z] pair embedding [*, N_res, N_res, C_z] pair embedding
Returns: Returns:
[*, N, N, no_bins] distogram probability distribution [*, N, N, no_bins] distogram probability distribution
""" """
# [*, N, N, no_bins] # [*, N, N, no_bins]
logits = self.linear(z) logits = self.linear(z)
...@@ -151,15 +153,16 @@ class DistogramHead(nn.Module): ...@@ -151,15 +153,16 @@ class DistogramHead(nn.Module):
class TMScoreHead(nn.Module): class TMScoreHead(nn.Module):
""" """
For use in computation of TM-score, subsection 1.9.7 For use in computation of TM-score, subsection 1.9.7
""" """
def __init__(self, c_z, no_bins, **kwargs): def __init__(self, c_z, no_bins, **kwargs):
""" """
Args: Args:
c_z: c_z:
Input channel dimension Input channel dimension
no_bins: no_bins:
Number of bins Number of bins
""" """
super(TMScoreHead, self).__init__() super(TMScoreHead, self).__init__()
...@@ -170,11 +173,11 @@ class TMScoreHead(nn.Module): ...@@ -170,11 +173,11 @@ class TMScoreHead(nn.Module):
def forward(self, z): def forward(self, z):
""" """
Args: Args:
z: z:
[*, N_res, N_res, C_z] pairwise embedding [*, N_res, N_res, C_z] pairwise embedding
Returns: Returns:
[*, N_res, N_res, no_bins] prediction [*, N_res, N_res, no_bins] prediction
""" """
# [*, N, N, no_bins] # [*, N, N, no_bins]
logits = self.linear(z) logits = self.linear(z)
...@@ -183,15 +186,16 @@ class TMScoreHead(nn.Module): ...@@ -183,15 +186,16 @@ class TMScoreHead(nn.Module):
class MaskedMSAHead(nn.Module): class MaskedMSAHead(nn.Module):
""" """
For use in computation of masked MSA loss, subsection 1.9.9 For use in computation of masked MSA loss, subsection 1.9.9
""" """
def __init__(self, c_m, c_out, **kwargs): def __init__(self, c_m, c_out, **kwargs):
""" """
Args: Args:
c_m: c_m:
MSA channel dimension MSA channel dimension
c_out: c_out:
Output channel dimension Output channel dimension
""" """
super(MaskedMSAHead, self).__init__() super(MaskedMSAHead, self).__init__()
...@@ -202,11 +206,11 @@ class MaskedMSAHead(nn.Module): ...@@ -202,11 +206,11 @@ class MaskedMSAHead(nn.Module):
def forward(self, m): def forward(self, m):
""" """
Args: Args:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
Returns: Returns:
[*, N_seq, N_res, C_out] reconstruction [*, N_seq, N_res, C_out] reconstruction
""" """
# [*, N_seq, N_res, C_out] # [*, N_seq, N_res, C_out]
logits = self.linear(m) logits = self.linear(m)
...@@ -215,16 +219,17 @@ class MaskedMSAHead(nn.Module): ...@@ -215,16 +219,17 @@ class MaskedMSAHead(nn.Module):
class ExperimentallyResolvedHead(nn.Module): class ExperimentallyResolvedHead(nn.Module):
""" """
For use in computation of "experimentally resolved" loss, subsection For use in computation of "experimentally resolved" loss, subsection
1.9.10 1.9.10
""" """
def __init__(self, c_s, c_out, **kwargs): def __init__(self, c_s, c_out, **kwargs):
""" """
Args: Args:
c_s: c_s:
Input channel dimension Input channel dimension
c_out: c_out:
Number of distogram bins Number of distogram bins
""" """
super(ExperimentallyResolvedHead, self).__init__() super(ExperimentallyResolvedHead, self).__init__()
...@@ -235,11 +240,11 @@ class ExperimentallyResolvedHead(nn.Module): ...@@ -235,11 +240,11 @@ class ExperimentallyResolvedHead(nn.Module):
def forward(self, s): def forward(self, s):
""" """
Args: Args:
s: s:
[*, N_res, C_s] single embedding [*, N_res, C_s] single embedding
Returns: Returns:
[*, N, C_out] logits [*, N, C_out] logits
""" """
# [*, N, C_out] # [*, N, C_out]
logits = self.linear(s) logits = self.linear(s)
......
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -25,7 +25,7 @@ from openfold.utils.feats import ( ...@@ -25,7 +25,7 @@ from openfold.utils.feats import (
atom14_to_atom37, atom14_to_atom37,
) )
from openfold.model.embedders import ( from openfold.model.embedders import (
InputEmbedder, InputEmbedder,
RecyclingEmbedder, RecyclingEmbedder,
TemplateAngleEmbedder, TemplateAngleEmbedder,
TemplatePairEmbedder, TemplatePairEmbedder,
...@@ -36,7 +36,7 @@ from openfold.model.heads import AuxiliaryHeads ...@@ -36,7 +36,7 @@ from openfold.model.heads import AuxiliaryHeads
import openfold.np.residue_constants as residue_constants import openfold.np.residue_constants as residue_constants
from openfold.model.structure_module import StructureModule from openfold.model.structure_module import StructureModule
from openfold.model.template import ( from openfold.model.template import (
TemplatePairStack, TemplatePairStack,
TemplatePointwiseAttention, TemplatePointwiseAttention,
) )
from openfold.utils.loss import ( from openfold.utils.loss import (
...@@ -46,19 +46,20 @@ from openfold.utils.tensor_utils import ( ...@@ -46,19 +46,20 @@ from openfold.utils.tensor_utils import (
dict_multimap, dict_multimap,
tensor_tree_map, tensor_tree_map,
) )
class AlphaFold(nn.Module): class AlphaFold(nn.Module):
""" """
Alphafold 2. Alphafold 2.
Implements Algorithm 2 (but with training). Implements Algorithm 2 (but with training).
""" """
def __init__(self, config): def __init__(self, config):
""" """
Args: Args:
config: config:
A dict-like config object (like the one in config.py) A dict-like config object (like the one in config.py)
""" """
super(AlphaFold, self).__init__() super(AlphaFold, self).__init__()
...@@ -107,7 +108,7 @@ class AlphaFold(nn.Module): ...@@ -107,7 +108,7 @@ class AlphaFold(nn.Module):
# Embed the templates one at a time (with a poor man's vmap) # Embed the templates one at a time (with a poor man's vmap)
template_embeds = [] template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ): for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i) idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map( single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx), lambda t: torch.index_select(t, templ_dim, idx),
...@@ -115,11 +116,11 @@ class AlphaFold(nn.Module): ...@@ -115,11 +116,11 @@ class AlphaFold(nn.Module):
) )
single_template_embeds = {} single_template_embeds = {}
if(self.config.template.embed_angles): if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat( template_angle_feat = build_template_angle_feat(
single_template_feats, single_template_feats,
) )
# [*, S_t, N, C_m] # [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat) a = self.template_angle_embedder(template_angle_feat)
...@@ -130,19 +131,19 @@ class AlphaFold(nn.Module): ...@@ -130,19 +131,19 @@ class AlphaFold(nn.Module):
single_template_feats, single_template_feats,
inf=self.config.template.inf, inf=self.config.template.inf,
eps=self.config.template.eps, eps=self.config.template.eps,
**self.config.template.distogram **self.config.template.distogram,
) )
t = self.template_pair_embedder(t) t = self.template_pair_embedder(t)
t = self.template_pair_stack( t = self.template_pair_stack(
t, t, pair_mask.unsqueeze(-3), _mask_trans=self.config._mask_trans
pair_mask.unsqueeze(-3), )
_mask_trans=self.config._mask_trans
single_template_embeds.update(
{
"pair": t,
}
) )
single_template_embeds.update({
"pair": t,
})
template_embeds.append(single_template_embeds) template_embeds.append(single_template_embeds)
template_embeds = dict_multimap( template_embeds = dict_multimap(
...@@ -152,19 +153,19 @@ class AlphaFold(nn.Module): ...@@ -152,19 +153,19 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z] # [*, N, N, C_z]
t = self.template_pointwise_att( t = self.template_pointwise_att(
template_embeds["pair"], template_embeds["pair"], z, template_mask=batch["template_mask"]
z,
template_mask=batch["template_mask"]
) )
t = t * (torch.sum(batch["template_mask"]) > 0) t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {} ret = {}
if(self.config.template.embed_angles): if self.config.template.embed_angles:
ret["template_angle_embedding"] = template_embeds["angle"] ret["template_angle_embedding"] = template_embeds["angle"]
ret.update({ ret.update(
"template_pair_embedding": t, {
}) "template_pair_embedding": t,
}
)
return ret return ret
...@@ -189,18 +190,18 @@ class AlphaFold(nn.Module): ...@@ -189,18 +190,18 @@ class AlphaFold(nn.Module):
# m: [*, S_c, N, C_m] # m: [*, S_c, N, C_m]
# z: [*, N, N, C_z] # z: [*, N, N, C_z]
m, z = self.input_embedder( m, z = self.input_embedder(
feats["target_feat"], feats["target_feat"],
feats["residue_index"], feats["residue_index"],
feats["msa_feat"], feats["msa_feat"],
) )
# Inject information from previous recycling iterations # Inject information from previous recycling iterations
if(self.config.num_recycle > 0): if self.config.num_recycle > 0:
# Initialize the recycling embeddings, if needs be # Initialize the recycling embeddings, if needs be
if(None in [m_1_prev, z_prev, x_prev]): if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m] # [*, N, C_m]
m_1_prev = m.new_zeros( m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m), (*batch_dims, n, self.config.input_embedder.c_m),
) )
# [*, N, N, C_z] # [*, N, N, C_z]
...@@ -213,17 +214,13 @@ class AlphaFold(nn.Module): ...@@ -213,17 +214,13 @@ class AlphaFold(nn.Module):
(*batch_dims, n, residue_constants.atom_type_num, 3), (*batch_dims, n, residue_constants.atom_type_num, 3),
) )
x_prev = pseudo_beta_fn( x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
feats["aatype"],
x_prev,
None
)
# m_1_prev_emb: [*, N, C_m] # m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z] # z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder( m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev, m_1_prev,
z_prev, z_prev,
x_prev, x_prev,
) )
...@@ -237,9 +234,9 @@ class AlphaFold(nn.Module): ...@@ -237,9 +234,9 @@ class AlphaFold(nn.Module):
del m_1_prev_emb, z_prev_emb del m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings # Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled): if self.config.template.enabled:
template_feats = { template_feats = {
k:v for k,v in feats.items() if k.startswith("template_") k: v for k, v in feats.items() if k.startswith("template_")
} }
template_embeds = self.embed_templates( template_embeds = self.embed_templates(
template_feats, template_feats,
...@@ -251,28 +248,27 @@ class AlphaFold(nn.Module): ...@@ -251,28 +248,27 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z] # [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"] z = z + template_embeds["template_pair_embedding"]
if(self.config.template.embed_angles): if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m] # [*, S = S_c + S_t, N, C_m]
m = torch.cat( m = torch.cat(
[m, template_embeds["template_angle_embedding"]], [m, template_embeds["template_angle_embedding"]], dim=-3
dim=-3
) )
# [*, S, N] # [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"] torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat( msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2 [feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
) )
# Embed extra MSA features + merge with pairwise embeddings # Embed extra MSA features + merge with pairwise embeddings
if(self.config.extra_msa.enabled): if self.config.extra_msa.enabled:
# [*, S_e, N, C_e] # [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats)) a = self.extra_msa_embedder(build_extra_msa_feat(feats))
# [*, N, N, C_z] # [*, N, N, C_z]
z = self.extra_msa_stack( z = self.extra_msa_stack(
a, a,
z, z,
msa_mask=feats["extra_msa_mask"], msa_mask=feats["extra_msa_mask"],
pair_mask=pair_mask, pair_mask=pair_mask,
_mask_trans=self.config._mask_trans, _mask_trans=self.config._mask_trans,
...@@ -283,11 +279,11 @@ class AlphaFold(nn.Module): ...@@ -283,11 +279,11 @@ class AlphaFold(nn.Module):
# z: [*, N, N, C_z] # z: [*, N, N, C_z]
# s: [*, N, C_s] # s: [*, N, C_s]
m, z, s = self.evoformer( m, z, s = self.evoformer(
m, m,
z, z,
msa_mask=msa_mask, msa_mask=msa_mask,
pair_mask=pair_mask, pair_mask=pair_mask,
_mask_trans=self.config._mask_trans _mask_trans=self.config._mask_trans,
) )
outputs["msa"] = m[..., :n_seq, :, :] outputs["msa"] = m[..., :n_seq, :, :]
...@@ -296,15 +292,18 @@ class AlphaFold(nn.Module): ...@@ -296,15 +292,18 @@ class AlphaFold(nn.Module):
# Predict 3D structure # Predict 3D structure
outputs["sm"] = self.structure_module( outputs["sm"] = self.structure_module(
s, z, feats["aatype"], mask=feats["seq_mask"], s,
) z,
feats["aatype"],
mask=feats["seq_mask"],
)
outputs["final_atom_positions"] = atom14_to_atom37( outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats outputs["sm"]["positions"][-1], feats
) )
outputs["final_atom_mask"] = feats["atom37_atom_exists"] outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1] outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
# Save embeddings for use during the next recycling iteration # Save embeddings for use during the next recycling iteration
# [*, N, C_m] # [*, N, C_m]
m_1_prev = m[..., 0, :, :] m_1_prev = m[..., 0, :, :]
...@@ -335,81 +334,84 @@ class AlphaFold(nn.Module): ...@@ -335,81 +334,84 @@ class AlphaFold(nn.Module):
def forward(self, batch): def forward(self, batch):
""" """
Args: Args:
batch: batch:
Dictionary of arguments outlined in Algorithm 2. Keys must Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the include the official names of the features in the
supplement subsection 1.2.9. supplement subsection 1.2.9.
The final dimension of each input must have length equal to The final dimension of each input must have length equal to
the number of recycling iterations. the number of recycling iterations.
Features (without the recycling dimension): Features (without the recycling dimension):
"aatype" ([*, N_res]): "aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue Contrary to the supplement, this tensor of residue
indices is not one-hot. indices is not one-hot.
"target_feat" ([*, N_res, C_tf]) "target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim. config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res]) "residue_index" ([*, N_res])
Tensor whose final dimension consists of Tensor whose final dimension consists of
consecutive indices from 0 to N_res. consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa]) "msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement. MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim. C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res]) "seq_mask" ([*, N_res])
1-D sequence mask 1-D sequence mask
"msa_mask" ([*, N_seq, N_res]) "msa_mask" ([*, N_seq, N_res])
MSA mask MSA mask
"pair_mask" ([*, N_res, N_res]) "pair_mask" ([*, N_res, N_res])
2-D pair mask 2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res]) "extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask Extra MSA mask
"template_mask" ([*, N_templ]) "template_mask" ([*, N_templ])
Template mask (on the level of templates, not Template mask (on the level of templates, not
residues) residues)
"template_aatype" ([*, N_templ, N_res]) "template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown)) than 19 are clamped to 20 (Unknown))
"template_all_atom_positions" "template_all_atom_positions"
([*, N_templ, N_res, 37, 3]) ([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37]) "template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3]) "template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine, for (i.e. C_beta for all residues but glycine, for
for which C_alpha is used instead) for which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res]) "template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask Pseudo-beta mask
""" """
# Initialize recycling embeddings # Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None m_1_prev, z_prev, x_prev = None, None, None
is_grad_enabled = torch.is_grad_enabled() is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing() self._disable_activation_checkpointing()
# Main recycling loop # Main recycling loop
for cycle_no in range(self.config.num_recycle + 1): for cycle_no in range(self.config.num_recycle + 1):
# Select the features for the current recycling cycle # Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no] fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch) feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer # Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == self.config.num_recycle) is_final_iter = cycle_no == self.config.num_recycle
with torch.set_grad_enabled(is_grad_enabled and is_final_iter): with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug discussed in pytorch issue #65766 # Sidestep AMP bug discussed in pytorch issue #65766
if(is_final_iter): if is_final_iter:
self._enable_activation_checkpointing() self._enable_activation_checkpointing()
if(torch.is_autocast_enabled()): if torch.is_autocast_enabled():
torch.clear_autocast_cache() torch.clear_autocast_cache()
# Run the next iteration of the model # Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration( outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev, feats,
m_1_prev,
z_prev,
x_prev,
) )
# Run auxiliary heads # Run auxiliary heads
outputs.update(self.aux_heads(outputs)) outputs.update(self.aux_heads(outputs))
return outputs return outputs
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -18,39 +18,40 @@ import torch ...@@ -18,39 +18,40 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional from typing import Optional
from openfold.model.primitives import Linear, Attention, GlobalAttention from openfold.model.primitives import Linear, Attention, GlobalAttention
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
flatten_final_dims, flatten_final_dims,
) )
class MSAAttention(nn.Module): class MSAAttention(nn.Module):
def __init__(self, def __init__(
c_in, self,
c_hidden, c_in,
no_heads, c_hidden,
pair_bias=False, no_heads,
c_z=None, pair_bias=False,
c_z=None,
chunk_size=4, chunk_size=4,
inf=1e9, inf=1e9,
): ):
""" """
Args: Args:
c_in: c_in:
Input channel dimension Input channel dimension
c_hidden: c_hidden:
Per-head hidden channel dimension Per-head hidden channel dimension
no_heads: no_heads:
Number of attention heads Number of attention heads
pair_bias: pair_bias:
Whether to use pair embedding bias Whether to use pair embedding bias
c_z: c_z:
Pair embedding channel dimension. Ignored unless pair_bias Pair embedding channel dimension. Ignored unless pair_bias
is true is true
inf: inf:
A large number to be used in computing the attention mask A large number to be used in computing the attention mask
""" """
super(MSAAttention, self).__init__() super(MSAAttention, self).__init__()
...@@ -64,49 +65,46 @@ class MSAAttention(nn.Module): ...@@ -64,49 +65,46 @@ class MSAAttention(nn.Module):
self.layer_norm_m = nn.LayerNorm(self.c_in) self.layer_norm_m = nn.LayerNorm(self.c_in)
if(self.pair_bias): if self.pair_bias:
self.layer_norm_z = nn.LayerNorm(self.c_z) self.layer_norm_z = nn.LayerNorm(self.c_z)
self.linear_z = Linear( self.linear_z = Linear(
self.c_z, self.no_heads, bias=False, init="normal" self.c_z, self.no_heads, bias=False, init="normal"
) )
self.mha = Attention( self.mha = Attention(
self.c_in, self.c_in, self.c_in, self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
self.c_hidden,
self.no_heads
) )
def forward(self, m, z=None, mask=None): def forward(self, m, z=None, mask=None):
""" """
Args: Args:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
z: z:
[*, N_res, N_res, C_z] pair embedding. Required only if [*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True pair_bias is True
mask: mask:
[*, N_seq, N_res] MSA mask [*, N_seq, N_res] MSA mask
""" """
# [*, N_seq, N_res, C_m] # [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m) m = self.layer_norm_m(m)
n_seq, n_res = m.shape[-3:-1] n_seq, n_res = m.shape[-3:-1]
if(mask is None): if mask is None:
# [*, N_seq, N_res] # [*, N_seq, N_res]
mask = m.new_ones( mask = m.new_ones(
m.shape[:-3] + (n_seq, n_res), m.shape[:-3] + (n_seq, n_res),
) )
# [*, N_seq, 1, 1, N_res] # [*, N_seq, 1, 1, N_res]
bias = (self.inf * (mask - 1))[..., :, None, None, :] bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, N_seq, no_heads, N_res, N_res] # [*, N_seq, no_heads, N_res, N_res]
bias = bias.expand( bias = bias.expand(
((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1) ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
) )
biases = [bias] biases = [bias]
if(self.pair_bias): if self.pair_bias:
# [*, N_res, N_res, C_z] # [*, N_res, N_res, C_z]
z = self.layer_norm_z(z) z = self.layer_norm_z(z)
...@@ -118,18 +116,13 @@ class MSAAttention(nn.Module): ...@@ -118,18 +116,13 @@ class MSAAttention(nn.Module):
biases.append(z) biases.append(z)
mha_inputs = { mha_inputs = {"q_x": m, "k_x": m, "v_x": m, "biases": biases}
"q_x": m, if self.chunk_size is not None:
"k_x": m,
"v_x": m,
"biases": biases
}
if(self.chunk_size is not None):
m = chunk_layer( m = chunk_layer(
self.mha, self.mha,
mha_inputs, mha_inputs,
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
no_batch_dims=len(m.shape[:-2]) no_batch_dims=len(m.shape[:-2]),
) )
else: else:
m = self.mha(**mha_inputs) m = self.mha(**mha_inputs)
...@@ -139,27 +132,28 @@ class MSAAttention(nn.Module): ...@@ -139,27 +132,28 @@ class MSAAttention(nn.Module):
class MSARowAttentionWithPairBias(MSAAttention): class MSARowAttentionWithPairBias(MSAAttention):
""" """
Implements Algorithm 7. Implements Algorithm 7.
""" """
def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9): def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9):
""" """
Args: Args:
c_m: c_m:
Input channel dimension Input channel dimension
c_z: c_z:
Pair embedding channel dimension Pair embedding channel dimension
c_hidden: c_hidden:
Per-head hidden channel dimension Per-head hidden channel dimension
no_heads: no_heads:
Number of attention heads Number of attention heads
inf: inf:
Large number used to construct attention masks Large number used to construct attention masks
""" """
super(MSARowAttentionWithPairBias, self).__init__( super(MSARowAttentionWithPairBias, self).__init__(
c_m, c_m,
c_hidden, c_hidden,
no_heads, no_heads,
pair_bias=True, pair_bias=True,
c_z=c_z, c_z=c_z,
chunk_size=chunk_size, chunk_size=chunk_size,
inf=inf, inf=inf,
...@@ -168,19 +162,20 @@ class MSARowAttentionWithPairBias(MSAAttention): ...@@ -168,19 +162,20 @@ class MSARowAttentionWithPairBias(MSAAttention):
class MSAColumnAttention(MSAAttention): class MSAColumnAttention(MSAAttention):
""" """
Implements Algorithm 8. Implements Algorithm 8.
""" """
def __init__(self, c_m, c_hidden, no_heads, chunk_size=4, inf=1e9): def __init__(self, c_m, c_hidden, no_heads, chunk_size=4, inf=1e9):
""" """
Args: Args:
c_m: c_m:
MSA channel dimension MSA channel dimension
c_hidden: c_hidden:
Per-head hidden channel dimension Per-head hidden channel dimension
no_heads: no_heads:
Number of attention heads Number of attention heads
inf: inf:
Large number used to construct attention masks Large number used to construct attention masks
""" """
super(MSAColumnAttention, self).__init__( super(MSAColumnAttention, self).__init__(
c_in=c_m, c_in=c_m,
...@@ -192,37 +187,31 @@ class MSAColumnAttention(MSAAttention): ...@@ -192,37 +187,31 @@ class MSAColumnAttention(MSAAttention):
inf=inf, inf=inf,
) )
def forward(self, m, mask=None): def forward(self, m, mask=None):
""" """
Args: Args:
m: m:
[*, N_seq, N_res, C_m] MSA embedding [*, N_seq, N_res, C_m] MSA embedding
mask: mask:
[*, N_seq, N_res] MSA mask [*, N_seq, N_res] MSA mask
""" """
# [*, N_res, N_seq, C_in] # [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
if(mask is not None): if mask is not None:
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
m = super().forward(m, mask=mask) m = super().forward(m, mask=mask)
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
if(mask is not None): if mask is not None:
mask = mask.transpose(-1, -2) mask = mask.transpose(-1, -2)
return m return m
class MSAColumnGlobalAttention(nn.Module): class MSAColumnGlobalAttention(nn.Module):
def __init__(self, def __init__(
c_in, self, c_in, c_hidden, no_heads, chunk_size=4, inf=1e9, eps=1e-10
c_hidden,
no_heads,
chunk_size=4,
inf=1e9,
eps=1e-10
): ):
super(MSAColumnGlobalAttention, self).__init__() super(MSAColumnGlobalAttention, self).__init__()
...@@ -243,13 +232,12 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -243,13 +232,12 @@ class MSAColumnGlobalAttention(nn.Module):
eps=eps, eps=eps,
) )
def forward(self, def forward(
m: torch.Tensor, self, m: torch.Tensor, mask: Optional[torch.Tensor] = None
mask: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:] n_seq, n_res, c_in = m.shape[-3:]
if(mask is None): if mask is None:
# [*, N_seq, N_res] # [*, N_seq, N_res]
mask = torch.ones( mask = torch.ones(
m.shape[:-1], m.shape[:-1],
...@@ -268,16 +256,16 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -268,16 +256,16 @@ class MSAColumnGlobalAttention(nn.Module):
"m": m, "m": m,
"mask": mask, "mask": mask,
} }
if(self.chunk_size is not None): if self.chunk_size is not None:
m = chunk_layer( m = chunk_layer(
self.global_attention, self.global_attention,
mha_input, mha_input,
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
no_batch_dims=len(m.shape[:-2]) no_batch_dims=len(m.shape[:-2]),
) )
else: else:
m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"]) m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment