Commit 07e64267 authored by Gustaf Ahdritz

Standardize code style

parent de07730f
@@ -4,50 +4,50 @@ import ml_collections as mlc
def set_inf(c, inf):
for k, v in c.items():
if(isinstance(v, mlc.ConfigDict)):
if isinstance(v, mlc.ConfigDict):
set_inf(v, inf)
elif(k == 'inf'):
elif k == "inf":
c[k] = inf
def model_config(name, train=False, low_prec=False):
c = copy.deepcopy(config)
if(name == 'model_1'):
if name == "model_1":
pass
elif(name == 'model_2'):
elif name == "model_2":
pass
elif(name == 'model_3'):
elif name == "model_3":
c.model.template.enabled = False
elif(name == 'model_4'):
elif name == "model_4":
c.model.template.enabled = False
elif(name == 'model_5'):
elif name == "model_5":
c.model.template.enabled = False
elif(name == 'model_1_ptm'):
elif name == "model_1_ptm":
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif(name == 'model_2_ptm'):
elif name == "model_2_ptm":
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif(name == 'model_3_ptm'):
elif name == "model_3_ptm":
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif(name == 'model_4_ptm'):
elif name == "model_4_ptm":
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif(name == 'model_5_ptm'):
elif name == "model_5_ptm":
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
else:
raise ValueError('Invalid model name')
raise ValueError("Invalid model name")
if(train):
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
if(low_prec):
if low_prec:
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
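# Illustrative sketch (not part of the commit): set_inf recursively walks the
# ConfigDict and overwrites every key literally named "inf", so a hypothetical
#   set_inf(c, 1e9)
# would restore the supplement's value throughout the config tree.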
@@ -69,370 +69,384 @@ num_recycle = mlc.FieldReference(3, field_type=int)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
NUM_EXTRA_SEQ = 'extra msa placeholder'
NUM_TEMPLATES = 'num templates placeholder'
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"
config = mlc.ConfigDict({
'data': {
'common': {
'batch_modes': [('clamped', 0.9), ('unclamped', 0.1)],
'feat': {
'aatype': [NUM_RES],
'all_atom_mask': [NUM_RES, None],
'all_atom_positions': [NUM_RES, None, None],
'alt_chi_angles': [NUM_RES, None],
'atom14_alt_gt_exists': [NUM_RES, None],
'atom14_alt_gt_positions': [NUM_RES, None, None],
'atom14_atom_exists': [NUM_RES, None],
'atom14_atom_is_ambiguous': [NUM_RES, None],
'atom14_gt_exists': [NUM_RES, None],
'atom14_gt_positions': [NUM_RES, None, None],
'atom37_atom_exists': [NUM_RES, None],
'backbone_affine_mask': [NUM_RES],
'backbone_affine_tensor': [NUM_RES, None, None],
'bert_mask': [NUM_MSA_SEQ, NUM_RES],
'chi_angles': [NUM_RES, None],
'chi_mask': [NUM_RES, None],
'extra_deletion_value': [NUM_EXTRA_SEQ, NUM_RES],
'extra_has_deletion': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa_mask': [NUM_EXTRA_SEQ, NUM_RES],
'extra_msa_row_mask': [NUM_EXTRA_SEQ],
'is_distillation': [],
'msa_feat': [NUM_MSA_SEQ, NUM_RES, None],
'msa_mask': [NUM_MSA_SEQ, NUM_RES],
'msa_row_mask': [NUM_MSA_SEQ],
'pseudo_beta': [NUM_RES, None],
'pseudo_beta_mask': [NUM_RES],
'residue_index': [NUM_RES],
'residx_atom14_to_atom37': [NUM_RES, None],
'residx_atom37_to_atom14': [NUM_RES, None],
'resolution': [],
'rigidgroups_alt_gt_frames': [NUM_RES, None, None, None],
'rigidgroups_group_exists': [NUM_RES, None],
'rigidgroups_group_is_ambiguous': [NUM_RES, None],
'rigidgroups_gt_exists': [NUM_RES, None],
'rigidgroups_gt_frames': [NUM_RES, None, None, None],
'seq_length': [],
'seq_mask': [NUM_RES],
'target_feat': [NUM_RES, None],
'template_aatype': [NUM_TEMPLATES, NUM_RES],
'template_all_atom_mask': [NUM_TEMPLATES, NUM_RES, None],
'template_all_atom_positions':
[NUM_TEMPLATES, NUM_RES, None, None],
'template_alt_torsion_angles_sin_cos':
[NUM_TEMPLATES, NUM_RES, None, None],
'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
'template_backbone_affine_tensor': [
NUM_TEMPLATES, NUM_RES, None, None],
'template_mask': [NUM_TEMPLATES],
'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
'template_sum_probs': [NUM_TEMPLATES, None],
'template_torsion_angles_mask': [NUM_TEMPLATES, NUM_RES, None],
'template_torsion_angles_sin_cos':
[NUM_TEMPLATES, NUM_RES, None, None],
'true_msa': [NUM_MSA_SEQ, NUM_RES],
'use_clamped_fape': [],
},
'masked_msa': {
'profile_prob': 0.1,
'same_prob': 0.1,
'uniform_prob': 0.1
},
'max_extra_msa': 1024,
'msa_cluster_features': True,
'num_recycle': num_recycle,
'reduce_msa_clusters_by_max_templates': False,
'resample_msa_in_recycling': True,
'template_features': [
'template_all_atom_positions', 'template_sum_probs',
'template_aatype', 'template_all_atom_mask',
config = mlc.ConfigDict(
{
"data": {
"common": {
"batch_modes": [("clamped", 0.9), ("unclamped", 0.1)],
"feat": {
"aatype": [NUM_RES],
"all_atom_mask": [NUM_RES, None],
"all_atom_positions": [NUM_RES, None, None],
"alt_chi_angles": [NUM_RES, None],
"atom14_alt_gt_exists": [NUM_RES, None],
"atom14_alt_gt_positions": [NUM_RES, None, None],
"atom14_atom_exists": [NUM_RES, None],
"atom14_atom_is_ambiguous": [NUM_RES, None],
"atom14_gt_exists": [NUM_RES, None],
"atom14_gt_positions": [NUM_RES, None, None],
"atom37_atom_exists": [NUM_RES, None],
"backbone_affine_mask": [NUM_RES],
"backbone_affine_tensor": [NUM_RES, None, None],
"bert_mask": [NUM_MSA_SEQ, NUM_RES],
"chi_angles": [NUM_RES, None],
"chi_mask": [NUM_RES, None],
"extra_deletion_value": [NUM_EXTRA_SEQ, NUM_RES],
"extra_has_deletion": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_mask": [NUM_EXTRA_SEQ, NUM_RES],
"extra_msa_row_mask": [NUM_EXTRA_SEQ],
"is_distillation": [],
"msa_feat": [NUM_MSA_SEQ, NUM_RES, None],
"msa_mask": [NUM_MSA_SEQ, NUM_RES],
"msa_row_mask": [NUM_MSA_SEQ],
"pseudo_beta": [NUM_RES, None],
"pseudo_beta_mask": [NUM_RES],
"residue_index": [NUM_RES],
"residx_atom14_to_atom37": [NUM_RES, None],
"residx_atom37_to_atom14": [NUM_RES, None],
"resolution": [],
"rigidgroups_alt_gt_frames": [NUM_RES, None, None, None],
"rigidgroups_group_exists": [NUM_RES, None],
"rigidgroups_group_is_ambiguous": [NUM_RES, None],
"rigidgroups_gt_exists": [NUM_RES, None],
"rigidgroups_gt_frames": [NUM_RES, None, None, None],
"seq_length": [],
"seq_mask": [NUM_RES],
"target_feat": [NUM_RES, None],
"template_aatype": [NUM_TEMPLATES, NUM_RES],
"template_all_atom_mask": [NUM_TEMPLATES, NUM_RES, None],
"template_all_atom_positions": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_alt_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_backbone_affine_mask": [NUM_TEMPLATES, NUM_RES],
"template_backbone_affine_tensor": [
NUM_TEMPLATES, NUM_RES, None, None,
],
"template_mask": [NUM_TEMPLATES],
"template_pseudo_beta": [NUM_TEMPLATES, NUM_RES, None],
"template_pseudo_beta_mask": [NUM_TEMPLATES, NUM_RES],
"template_sum_probs": [NUM_TEMPLATES, None],
"template_torsion_angles_mask": [
NUM_TEMPLATES, NUM_RES, None,
],
'unsupervised_features': [
'aatype', 'residue_index', 'msa', 'num_alignments',
'seq_length', 'between_segment_residues', 'deletion_matrix'
"template_torsion_angles_sin_cos": [
NUM_TEMPLATES, NUM_RES, None, None,
],
'use_templates': templates_enabled,
'use_template_torsion_angles': embed_template_torsion_angles,
'supervised_features': [
'all_atom_mask', 'all_atom_positions', 'resolution',
'use_clamped_fape',
"true_msa": [NUM_MSA_SEQ, NUM_RES],
"use_clamped_fape": [],
},
"masked_msa": {
"profile_prob": 0.1,
"same_prob": 0.1,
"uniform_prob": 0.1,
},
"max_extra_msa": 1024,
"msa_cluster_features": True,
"num_recycle": num_recycle,
"reduce_msa_clusters_by_max_templates": False,
"resample_msa_in_recycling": True,
"template_features": [
"template_all_atom_positions",
"template_sum_probs",
"template_aatype",
"template_all_atom_mask",
],
"unsupervised_features": [
"aatype",
"residue_index",
"msa",
"num_alignments",
"seq_length",
"between_segment_residues",
"deletion_matrix",
],
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
"supervised_features": [
"all_atom_mask",
"all_atom_positions",
"resolution",
"use_clamped_fape",
],
},
'predict': {
'fixed_size': True,
'subsample_templates': False, # We want top templates.
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
'crop': False,
'crop_size': None,
'supervised': False,
},
'eval': {
'fixed_size': True,
'subsample_templates': False, # We want top templates.
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
'crop': False,
'crop_size': None,
'supervised': True,
},
'train': {
'fixed_size': True,
'subsample_templates': True,
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
'crop': True,
'crop_size': 256,
'supervised': True,
},
'data_module': {
'use_small_bfd': False,
'data_loaders': {
'batch_size': 1,
'num_workers': 1,
"predict": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_templates": 4,
"num_ensemble": 1,
"crop": False,
"crop_size": None,
"supervised": False,
},
"eval": {
"fixed_size": True,
"subsample_templates": False, # We want top templates.
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_templates": 4,
"num_ensemble": 1,
"crop": False,
"crop_size": None,
"supervised": True,
},
"train": {
"fixed_size": True,
"subsample_templates": True,
"masked_msa_replace_fraction": 0.15,
"max_msa_clusters": 512,
"max_templates": 4,
"num_ensemble": 1,
"crop": True,
"crop_size": 256,
"supervised": True,
},
"data_module": {
"use_small_bfd": False,
"data_loaders": {
"batch_size": 1,
"num_workers": 1,
},
},
}
},
# Recurring FieldReferences that can be changed globally here
'globals': {
'blocks_per_ckpt': blocks_per_ckpt,
'chunk_size': chunk_size,
'c_z': c_z,
'c_m': c_m,
'c_t': c_t,
'c_e': c_e,
'c_s': c_s,
'eps': eps,
},
'model': {
'num_recycle': num_recycle,
'_mask_trans': False,
'input_embedder': {
'tf_dim': 22,
'msa_dim': 49,
'c_z': c_z,
'c_m': c_m,
'relpos_k': 32,
},
'recycling_embedder': {
'c_z': c_z,
'c_m': c_m,
'min_bin': 3.25,
'max_bin': 20.75,
'no_bins': 15,
'inf': 1e8,
},
'template': {
'distogram': {
'min_bin': 3.25,
'max_bin': 50.75,
'no_bins': 39,
},
'template_angle_embedder': {
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
"c_e": c_e,
"c_s": c_s,
"eps": eps,
},
"model": {
"num_recycle": num_recycle,
"_mask_trans": False,
"input_embedder": {
"tf_dim": 22,
"msa_dim": 49,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
},
"recycling_embedder": {
"c_z": c_z,
"c_m": c_m,
"min_bin": 3.25,
"max_bin": 20.75,
"no_bins": 15,
"inf": 1e8,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
},
"template_angle_embedder": {
# DISCREPANCY: c_in is supposed to be 51.
'c_in': 57,
'c_out': c_m,
"c_in": 57,
"c_out": c_m,
},
'template_pair_embedder': {
'c_in': 88,
'c_out': c_t,
"template_pair_embedder": {
"c_in": 88,
"c_out": c_t,
},
'template_pair_stack': {
'c_t': c_t,
"template_pair_stack": {
"c_t": c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
'c_hidden_tri_att': 16,
'c_hidden_tri_mul': 64,
'no_blocks': 2,
'no_heads': 4,
'pair_transition_n': 2,
'dropout_rate': 0.25,
'blocks_per_ckpt': blocks_per_ckpt,
'chunk_size': chunk_size,
'inf': 1e5,#1e9,
},
'template_pointwise_attention': {
'c_t': c_t,
'c_z': c_z,
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5, # 1e9,
},
"template_pointwise_attention": {
"c_t": c_t,
"c_z": c_z,
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# It's actually 16.
'c_hidden': 16,
'no_heads': 4,
'chunk_size': chunk_size,
'inf': 1e5,#1e9,
},
'inf': 1e5,#1e9,
'eps': eps,#1e-6,
'enabled': templates_enabled,
'embed_angles': embed_template_torsion_angles,
},
'extra_msa': {
'extra_msa_embedder': {
'c_in': 25,
'c_out': c_e,
},
'extra_msa_stack': {
'c_m': c_e,
'c_z': c_z,
'c_hidden_msa_att': 8,
'c_hidden_opm': 32,
'c_hidden_mul': 128,
'c_hidden_pair_att': 32,
'no_heads_msa': 8,
'no_heads_pair': 4,
'no_blocks': 4,
'transition_n': 4,
'msa_dropout': 0.15,
'pair_dropout': 0.25,
'blocks_per_ckpt': blocks_per_ckpt,
'chunk_size': chunk_size,
'inf': 1e5,#1e9,
'eps': eps,#1e-10,
},
'enabled': True,
},
'evoformer_stack': {
'c_m': c_m,
'c_z': c_z,
'c_hidden_msa_att': 32,
'c_hidden_opm': 32,
'c_hidden_mul': 128,
'c_hidden_pair_att': 32,
'c_s': c_s,
'no_heads_msa': 8,
'no_heads_pair': 4,
'no_blocks': 48,
'transition_n': 4,
'msa_dropout': 0.15,
'pair_dropout': 0.25,
'blocks_per_ckpt': blocks_per_ckpt,
'chunk_size': chunk_size,
'inf': 1e5,#1e9,
'eps': eps,#1e-10,
},
'structure_module': {
'c_s': c_s,
'c_z': c_z,
'c_ipa': 16,
'c_resnet': 128,
'no_heads_ipa': 12,
'no_qk_points': 4,
'no_v_points': 8,
'dropout_rate': 0.1,
'no_blocks': 8,
'no_transition_layers': 1,
'no_resnet_blocks': 2,
'no_angles': 7,
'trans_scale_factor': 10,
'epsilon': eps,#1e-12,
'inf': 1e5,
},
'heads': {
'lddt': {
'no_bins': 50,
'c_in': c_s,
'c_hidden': 128,
},
'distogram': {
'c_z': c_z,
'no_bins': aux_distogram_bins,
},
'tm': {
'c_z': c_z,
'no_bins': aux_distogram_bins,
'enabled': False,
},
'masked_msa': {
'c_m': c_m,
'c_out': 23,
},
'experimentally_resolved': {
'c_s': c_s,
'c_out': 37,
},
},
},
'relax': {
'max_iterations': 0, # no max
'tolerance': 2.39,
'stiffness': 10.0,
'max_outer_iterations': 20,
'exclude_residues': [],
},
'loss': {
'distogram': {
'min_bin': 2.3125,
'max_bin': 21.6875,
'no_bins': 64,
'eps': eps,#1e-6,
'weight': 0.3,
},
'experimentally_resolved': {
'eps': eps,#1e-8,
'min_resolution': 0.1,
'max_resolution': 3.0,
'weight': 0.,
},
'fape': {
'backbone': {
'clamp_distance': 10.,
'loss_unit_distance': 10.,
'weight': 0.5,
},
'sidechain': {
'clamp_distance': 10.,
'length_scale': 10.,
'weight': 0.5,
},
'eps': 1e-4,
'weight': 1.0,
},
'lddt': {
'min_resolution': 0.1,
'max_resolution': 3.0,
'cutoff': 15.,
'no_bins': 50,
'eps': eps,#1e-10,
'weight': 0.01,
},
'masked_msa': {
'eps': eps,#1e-8,
'weight': 2.0,
},
'supervised_chi': {
'chi_weight': 0.5,
'angle_norm_weight': 0.01,
'eps': eps,#1e-6,
'weight': 1.0,
},
'violation': {
'violation_tolerance_factor': 12.0,
'clash_overlap_tolerance': 1.5,
'eps': eps,#1e-6,
'weight': 0.,
},
'tm': {
'max_bin': 31,
'no_bins': 64,
'min_resolution': 0.1,
'max_resolution': 3.0,
'eps': eps,#1e-8,
'weight': 0.,
},
'eps': eps,
},
'ema': {
'decay': 0.999
},
})
"c_hidden": 16,
"no_heads": 4,
"chunk_size": chunk_size,
"inf": 1e5, # 1e9,
},
"inf": 1e5, # 1e9,
"eps": eps, # 1e-6,
"enabled": templates_enabled,
"embed_angles": embed_template_torsion_angles,
},
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
},
"extra_msa_stack": {
"c_m": c_e,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5, # 1e9,
"eps": eps, # 1e-10,
},
"enabled": True,
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5, # 1e9,
"eps": eps, # 1e-10,
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 10,
"epsilon": eps, # 1e-12,
"inf": 1e5,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": False,
},
"masked_msa": {
"c_m": c_m,
"c_out": 23,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
},
},
},
"relax": {
"max_iterations": 0, # no max
"tolerance": 2.39,
"stiffness": 10.0,
"max_outer_iterations": 20,
"exclude_residues": [],
},
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": eps, # 1e-6,
"weight": 0.3,
},
"experimentally_resolved": {
"eps": eps, # 1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.0,
},
"fape": {
"backbone": {
"clamp_distance": 10.0,
"loss_unit_distance": 10.0,
"weight": 0.5,
},
"sidechain": {
"clamp_distance": 10.0,
"length_scale": 10.0,
"weight": 0.5,
},
"eps": 1e-4,
"weight": 1.0,
},
"lddt": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.0,
"no_bins": 50,
"eps": eps, # 1e-10,
"weight": 0.01,
},
"masked_msa": {
"eps": eps, # 1e-8,
"weight": 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": eps, # 1e-6,
"weight": 1.0,
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"eps": eps, # 1e-6,
"weight": 0.0,
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps, # 1e-8,
"weight": 0.0,
},
"eps": eps,
},
"ema": {"decay": 0.999},
}
)
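# Illustrative note (not part of the commit): the mlc.FieldReference objects
# (chunk_size, blocks_per_ckpt, eps, ...) are shared by reference, so a single
# assignment such as c.globals.chunk_size = None in model_config propagates to
# every sub-config that reuses the same reference.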
@@ -27,45 +27,45 @@ from openfold.np import residue_constants
FeatureDict = Mapping[str, np.ndarray]
def make_sequence_features(
sequence: str,
description: str,
num_res: int
sequence: str, description: str, num_res: int
) -> FeatureDict:
"""Construct a feature dict of sequence features."""
features = {}
features['aatype'] = residue_constants.sequence_to_onehot(
features["aatype"] = residue_constants.sequence_to_onehot(
sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True
map_unknown_to_x=True,
)
features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
features['domain_name'] = np.array(
[description.encode('utf-8')], dtype=np.object_
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
)
features['residue_index'] = np.array(range(num_res), dtype=np.int32)
features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
features['sequence'] = np.array(
[sequence.encode('utf-8')], dtype=np.object_
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
)
return features
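# Illustrative sketch (not part of the commit): a hypothetical call
#   feats = make_sequence_features(sequence="MKV", description="query", num_res=3)
# yields a one-hot feats["aatype"] (unknown residues mapped to X via
# map_unknown_to_x=True), feats["residue_index"] == [0, 1, 2], and
# byte-encoded metadata fields.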
def make_mmcif_features(
mmcif_object: mmcif_parsing.MmcifObject,
chain_id: str
mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
input_sequence = mmcif_object.chain_to_seqres[chain_id]
description = '_'.join([mmcif_object.file_id, chain_id])
description = "_".join([mmcif_object.file_id, chain_id])
num_res = len(input_sequence)
mmcif_feats = {}
mmcif_feats.update(make_sequence_features(
mmcif_feats.update(
make_sequence_features(
sequence=input_sequence,
description=description,
num_res=num_res,
))
)
)
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
@@ -78,7 +78,7 @@ def make_mmcif_features(
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode('utf-8')], dtype=np.object_
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
return mmcif_feats
@@ -86,17 +86,20 @@ def make_mmcif_features(
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError('At least one MSA must be provided.')
raise ValueError("At least one MSA must be provided.")
int_msa = []
deletion_matrix = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences:
continue
@@ -109,17 +112,19 @@ def make_msa_features(
num_res = len(msas[0][0])
num_alignments = len(int_msa)
features = {}
features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
features['msa'] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array(
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
features["msa"] = np.array(int_msa, dtype=np.int32)
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
return features
class AlignmentRunner:
""" Runs alignment tools and saves the results """
def __init__(self,
"""Runs alignment tools and saves the results"""
def __init__(
self,
jackhmmer_binary_path: str,
hhblits_binary_path: str,
hhsearch_binary_path: str,
@@ -161,105 +166,109 @@ class AlignmentRunner:
)
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path]
binary_path=hhsearch_binary_path, databases=[pdb70_database_path]
)
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
def run(self,
def run(
self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(fasta_path)[0]
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits
jackhmmer_uniref90_result["sto"], max_sequences=self.uniref_max_hits
)
uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'w') as f:
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m)
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(fasta_path)[0]
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits
jackhmmer_mgnify_result["sto"], max_sequences=self.mgnify_max_hits
)
mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'w') as f:
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m)
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'w') as f:
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result)
if self._use_small_bfd:
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(fasta_path)[0]
bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'w') as f:
f.write(jackhmmer_small_bfd_result['sto'])
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
fasta_path
)[0]
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
else:
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path)
if(output_dir is not None):
bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'w') as f:
f.write(hhblits_bfd_uniclust_result['a3m'])
hhblits_bfd_uniclust_result = (
self.hhblits_bfd_uniclust_runner.query(fasta_path)
)
if output_dir is not None:
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniclust_result["a3m"])
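# Illustrative sketch (not part of the commit): a hypothetical run
#   AlignmentRunner(...).run("query.fasta", output_dir="alignments")
# leaves uniref90_hits.a3m, mgnify_hits.a3m, pdb70_hits.hhr, and the BFD hits
# in output_dir, the same filenames DataPipeline._parse_alignment_output reads.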
class DataPipeline:
"""Assembles input features."""
def __init__(self,
def __init__(
self,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
):
self.template_featurizer = template_featurizer
self.use_small_bfd = use_small_bfd
def _parse_alignment_output(self,
def _parse_alignment_output(
self,
alignment_dir: str,
) -> Mapping[str, Any]:
uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'r') as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(
f.read()
)
uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "r") as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())
mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'r') as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(
f.read()
)
mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "r") as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())
pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'r') as f:
hhsearch_hits = parsers.parse_hhr(
f.read()
)
pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "r") as f:
hhsearch_hits = parsers.parse_hhr(f.read())
if(self.use_small_bfd):
bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'r') as f:
if self.use_small_bfd:
bfd_out_path = os.path.join(alignment_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "r") as f:
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
f.read()
)
else:
bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'r') as f:
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
f.read()
)
bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "r") as f:
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())
return {
'uniref90_msa': uniref90_msa,
'uniref90_deletion_matrix': uniref90_deletion_matrix,
'mgnify_msa': mgnify_msa,
'mgnify_deletion_matrix': mgnify_deletion_matrix,
'hhsearch_hits': hhsearch_hits,
'bfd_msa': bfd_msa,
'bfd_deletion_matrix': bfd_deletion_matrix,
"uniref90_msa": uniref90_msa,
"uniref90_deletion_matrix": uniref90_deletion_matrix,
"mgnify_msa": mgnify_msa,
"mgnify_deletion_matrix": mgnify_deletion_matrix,
"hhsearch_hits": hhsearch_hits,
"bfd_msa": bfd_msa,
"bfd_deletion_matrix": bfd_deletion_matrix,
}
def process_fasta(self,
def process_fasta(
self,
fasta_path: str,
alignment_dir: str,
) -> FeatureDict:
@@ -269,7 +278,8 @@ class DataPipeline:
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {fasta_path}.')
f"More than one input sequence found in {fasta_path}."
)
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
@@ -280,30 +290,31 @@ class DataPipeline:
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=None,
hits=alignments['hhsearch_hits']
hits=alignments["hhsearch_hits"],
)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res
num_res=num_res,
)
msa_features = make_msa_features(
msas=(
alignments['uniref90_msa'],
alignments['bfd_msa'],
alignments['mgnify_msa']
alignments["uniref90_msa"],
alignments["bfd_msa"],
alignments["mgnify_msa"],
),
deletion_matrices=(
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
),
)
return {**sequence_features, **msa_features, **templates_result.data}
def process_mmcif(self,
def process_mmcif(
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
@@ -314,13 +325,11 @@ class DataPipeline:
If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown.
"""
if(chain_id is None):
if chain_id is None:
chains = mmcif.structure.get_chains()
chain = next(chains, None)
if(chain is None):
raise ValueError(
'No chains in mmCIF file'
)
if chain is None:
raise ValueError("No chains in mmCIF file")
chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id)
@@ -332,20 +341,20 @@ class DataPipeline:
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]),
hits=alignments['hhsearch_hits']
hits=alignments["hhsearch_hits"],
)
msa_features = make_msa_features(
msas=(
alignments['uniref90_msa'],
alignments['bfd_msa'],
alignments['mgnify_msa']
alignments["uniref90_msa"],
alignments["bfd_msa"],
alignments["mgnify_msa"],
),
deletion_matrices=(
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
),
deletion_matrices = (
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
)
return {**mmcif_feats, **templates_result.data, **msa_features}
@@ -23,13 +23,23 @@ import torch
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.tools import residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import tree_map, tensor_tree_map, batched_gather
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
batched_gather,
)
MSA_FEATURE_NAMES = [
'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa'
"msa",
"deletion_matrix",
"msa_mask",
"msa_row_mask",
"bert_mask",
"true_msa",
]
def cast_to_64bit_ints(protein):
# We keep all ints as int64
for k, v in protein.items():
@@ -37,21 +47,27 @@ def cast_to_64bit_ints(protein):
protein[k] = v.type(torch.int64)
return protein
def make_one_hot(x, num_classes):
x_one_hot = torch.zeros(*x.shape, num_classes)
x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
return x_one_hot
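# Illustrative sketch (not part of the commit): make_one_hot scatters ones
# along a new trailing class dimension, e.g. for a hypothetical input
#   make_one_hot(torch.tensor([0, 2]), 3)  # -> [[1., 0., 0.], [0., 0., 1.]]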
def make_seq_mask(protein):
protein['seq_mask'] = torch.ones(protein['aatype'].shape, dtype=torch.float32)
protein["seq_mask"] = torch.ones(
protein["aatype"].shape, dtype=torch.float32
)
return protein
def make_template_mask(protein):
protein['template_mask'] = torch.ones(
protein['template_aatype'].shape[0], dtype=torch.float32
protein["template_mask"] = torch.ones(
protein["template_aatype"].shape[0], dtype=torch.float32
)
return protein
def curry1(f):
"""Supply all arguments but the first."""
@@ -60,137 +76,167 @@ def curry1(f):
return fc
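# Illustrative sketch (not part of the commit, assuming the elided body matches
# the docstring): curry1 supplies every argument except the first, so decorated
# transforms can be configured up front:
#   crop = crop_templates(4)  # returns a function of `protein` alone
#   protein = crop(protein)   # equivalent to crop_templates(protein, 4)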
@curry1
def add_distillation_flag(protein, distillation):
protein['is_distillation'] = torch.tensor(
protein["is_distillation"] = torch.tensor(
float(distillation), dtype=torch.float32
)
return protein
def make_all_atom_aatype(protein):
protein['all_atom_aatype'] = protein['aatype']
protein["all_atom_aatype"] = protein["aatype"]
return protein
def fix_templates_aatype(protein):
# Map one-hot to indices
num_templates = protein['template_aatype'].shape[0]
protein['template_aatype'] = torch.argmax(protein['template_aatype'], dim=-1)
num_templates = protein["template_aatype"].shape[0]
protein["template_aatype"] = torch.argmax(
protein["template_aatype"], dim=-1
)
# Map hhsearch-aatype to our aatype.
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
new_order_list, dtype=torch.int64
).expand(num_templates, -1)
protein['template_aatype'] = torch.gather(
new_order, 1, index=protein['template_aatype']
new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
num_templates, -1
)
protein["template_aatype"] = torch.gather(
new_order, 1, index=protein["template_aatype"]
)
return protein
def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as rc."""
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
[new_order_list]*protein['msa'].shape[1], dtype=protein['msa'].dtype
).transpose(0,1)
protein['msa'] = torch.gather(new_order, 0, protein['msa'])
[new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
).transpose(0, 1)
protein["msa"] = torch.gather(new_order, 0, protein["msa"])
perm_matrix = np.zeros((22, 22), dtype=np.float32)
perm_matrix[range(len(new_order_list)), new_order_list] = 1.
perm_matrix[range(len(new_order_list)), new_order_list] = 1.0
for k in protein:
if 'profile' in k:
if "profile" in k:
num_dim = protein[k].shape.as_list()[-1]
assert num_dim in [20,21,22], (
'num_dim for %s out of expected range: %s' % (k, num_dim))
assert num_dim in [
20,
21,
22,
], "num_dim for %s out of expected range: %s" % (k, num_dim)
protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
return protein
def squeeze_features(protein):
"""Remove singleton and repeated dimensions in protein features."""
protein['aatype'] = torch.argmax(protein['aatype'], dim=-1)
protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
for k in [
'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
'superfamily', 'deletion_matrix', 'resolution',
'between_segment_residues', 'residue_index', 'template_all_atom_mask']:
"domain_name",
"msa",
"num_alignments",
"seq_length",
"sequence",
"superfamily",
"deletion_matrix",
"resolution",
"between_segment_residues",
"residue_index",
"template_all_atom_mask",
]:
if k in protein:
final_dim = protein[k].shape[-1]
if isinstance(final_dim, int) and final_dim == 1:
protein[k] = torch.squeeze(protein[k], dim=-1)
for k in ['seq_length', 'num_alignments']:
for k in ["seq_length", "num_alignments"]:
if k in protein:
protein[k] = protein[k][0]
return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
"""Replace a portion of the MSA with 'X'."""
msa_mask = (torch.rand(protein['msa'].shape) < replace_proportion)
msa_mask = torch.rand(protein["msa"].shape) < replace_proportion
x_idx = 20
gap_idx = 21
msa_mask = torch.logical_and(msa_mask, protein['msa'] != gap_idx)
protein['msa'] = torch.where(msa_mask, torch.ones_like(protein['msa'])*x_idx,
protein['msa'])
aatype_mask = (
torch.rand(protein['aatype'].shape) < replace_proportion
msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
protein["msa"] = torch.where(
msa_mask, torch.ones_like(protein["msa"]) * x_idx, protein["msa"]
)
aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion
protein['aatype'] = torch.where(
aatype_mask, torch.ones_like(protein['aatype']) * x_idx,
protein['aatype']
protein["aatype"] = torch.where(
aatype_mask,
torch.ones_like(protein["aatype"]) * x_idx,
protein["aatype"],
)
return protein
@curry1
def sample_msa(protein, max_seq, keep_extra):
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`.
"""
num_seq = protein['msa'].shape[0]
shuffled = torch.randperm(num_seq-1)+1
"""Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
num_seq = protein["msa"].shape[0]
shuffled = torch.randperm(num_seq - 1) + 1
index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
num_sel = min(max_seq, num_seq)
sel_seq, not_sel_seq = torch.split(index_order, [num_sel, num_seq-num_sel])
sel_seq, not_sel_seq = torch.split(
index_order, [num_sel, num_seq - num_sel]
)
for k in MSA_FEATURE_NAMES:
if k in protein:
if keep_extra:
protein['extra_'+k] = torch.index_select(protein[k], 0, not_sel_seq)
protein["extra_" + k] = torch.index_select(
protein[k], 0, not_sel_seq
)
protein[k] = torch.index_select(protein[k], 0, sel_seq)
return protein
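# Illustrative note (not part of the commit): index 0 (the query sequence) is
# prepended to the shuffled order above, so the query always survives the
# subsampling; only rows 1..num_seq-1 are permuted.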
@curry1
def crop_extra_msa(protein, max_extra_msa):
num_seq = protein['extra_msa'].shape[0]
num_seq = protein["extra_msa"].shape[0]
num_sel = min(max_extra_msa, num_seq)
select_indices = torch.randperm(num_seq)[:num_sel]
for k in MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
protein['extra_'+k] = torch.index_select(protein['extra_'+k], 0, select_indices)
if "extra_" + k in protein:
protein["extra_" + k] = torch.index_select(
protein["extra_" + k], 0, select_indices
)
return protein
def delete_extra_msa(protein):
for k in MSA_FEATURE_NAMES:
if 'extra_' + k in protein:
del protein['extra_' + k]
if "extra_" + k in protein:
del protein["extra_" + k]
return protein
# Not used in inference
@curry1
def block_delete_msa(protein, config):
num_seq = protein['msa'].shape[0]
num_seq = protein["msa"].shape[0]
block_num_seq = torch.floor(
torch.tensor(
num_seq, dtype=torch.float32
) * config.msa_fraction_per_block
torch.tensor(num_seq, dtype=torch.float32)
* config.msa_fraction_per_block
).to(torch.int32)
if config.randomize_num_blocks:
nb = torch.distributions.uniform.Uniform(0, config.num_blocks+1).sample()
nb = torch.distributions.uniform.Uniform(
0, config.num_blocks + 1
).sample()
else:
nb = config.num_blocks
del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
del_blocks = torch.clip(del_blocks, 0, num_seq-1)
del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
# Make sure we keep the original sequence
@@ -206,19 +252,19 @@ def block_delete_msa(protein, config):
return protein
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
weights = torch.cat([
torch.ones(21),
gap_agreement_weight * torch.ones(1),
torch.zeros(1)
], 0)
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
weights = torch.cat(
[torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
0,
)
# Make agreement score as weighted Hamming distance
msa_one_hot = make_one_hot(protein['msa'], 23)
sample_one_hot = (protein['msa_mask'][:,:,None] * msa_one_hot)
extra_msa_one_hot = make_one_hot(protein['extra_msa'], 23)
extra_one_hot = (protein['extra_msa_mask'][:,:,None] * extra_msa_one_hot)
msa_one_hot = make_one_hot(protein["msa"], 23)
sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot
num_seq, num_res, _ = sample_one_hot.shape
extra_num_seq, _, _ = extra_one_hot.shape
@@ -226,17 +272,20 @@ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.):
# Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
# in an optimized fashion to avoid possible memory or computation blowup.
agreement = torch.matmul(
torch.reshape(extra_one_hot, [extra_num_seq, num_res*23]),
torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
torch.reshape(
sample_one_hot * weights, [num_seq, num_res * 23]
).transpose(0, 1),
)
# Assign each sequence in the extra sequences to the closest MSA sample
protein['extra_cluster_assignment'] = torch.argmax(agreement, dim=1).to(torch.int64)
protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
torch.int64
)
return protein
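# Illustrative note (not part of the commit): the matmul above flattens the
# residue and class axes into one, so a single [extra_num_seq, num_seq] product
# replaces the three-way einsum and avoids materializing a
# [num_seq, extra_num_seq, num_res, 23] intermediate.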
def unsorted_segment_sum(data, segment_ids, num_segments):
"""
Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.
@@ -264,123 +313,145 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
tensor = tensor.type(data.dtype)
return tensor
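# Illustrative sketch (not part of the commit): summing rows by segment id,
# analogous to tf.math.unsorted_segment_sum:
#   data = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])
#   ids = torch.tensor([0, 1, 0])
#   unsorted_segment_sum(data, ids, 2)  # -> [[4., 4.], [2., 2.]]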
@curry1
def summarize_clusters(protein):
"""Produce profile and deletion_matrix_mean within each cluster."""
num_seq = protein['msa'].shape[0]
num_seq = protein["msa"].shape[0]
def csum(x):
return unsorted_segment_sum(
x, protein['extra_cluster_assignment'], num_seq
x, protein["extra_cluster_assignment"], num_seq
)
mask = protein['extra_msa_mask']
mask_counts = 1e-6 + protein['msa_mask'] + csum(mask) # Include center
mask = protein["extra_msa_mask"]
mask_counts = 1e-6 + protein["msa_mask"] + csum(mask) # Include center
msa_sum = csum(mask[:, :, None] * make_one_hot(protein['extra_msa'], 23))
msa_sum += make_one_hot(protein['msa'], 23) # Original sequence
protein['cluster_profile'] = msa_sum / mask_counts[:, :, None]
msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
msa_sum += make_one_hot(protein["msa"], 23) # Original sequence
protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
del msa_sum
del_sum = csum(mask * protein['extra_deletion_matrix'])
del_sum += protein['deletion_matrix'] # Original sequence
protein['cluster_deletion_mean'] = del_sum / mask_counts
del_sum = csum(mask * protein["extra_deletion_matrix"])
del_sum += protein["deletion_matrix"] # Original sequence
protein["cluster_deletion_mean"] = del_sum / mask_counts
del del_sum
return protein
def make_msa_mask(protein):
"""Mask features are all ones, but will later be zero-padded."""
protein['msa_mask'] = torch.ones(protein['msa'].shape, dtype=torch.float32)
protein['msa_row_mask'] = torch.ones(protein['msa'].shape[0], dtype=torch.float32)
protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
protein["msa_row_mask"] = torch.ones(
protein["msa"].shape[0], dtype=torch.float32
)
return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
"""Create pseudo beta features."""
is_gly = torch.eq(aatype, rc.restype_order['G'])
ca_idx = rc.atom_order['CA']
cb_idx = rc.atom_order['CB']
is_gly = torch.eq(aatype, rc.restype_order["G"])
ca_idx = rc.atom_order["CA"]
cb_idx = rc.atom_order["CB"]
pseudo_beta = torch.where(
torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :])
all_atom_positions[..., cb_idx, :],
)
if all_atom_mask is not None:
pseudo_beta_mask = torch.where(
is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx])
is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
)
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
@curry1
def make_pseudo_beta(protein, prefix=''):
def make_pseudo_beta(protein, prefix=""):
"""Create pseudo-beta (alpha for glycine) position and mask."""
assert prefix in ['', 'template_']
protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
pseudo_beta_fn(
protein['template_aatype' if prefix else 'aatype'],
protein[prefix + 'all_atom_positions'],
protein['template_all_atom_mask' if prefix else 'all_atom_mask']))
assert prefix in ["", "template_"]
(
protein[prefix + "pseudo_beta"],
protein[prefix + "pseudo_beta_mask"],
) = pseudo_beta_fn(
protein["template_aatype" if prefix else "aatype"],
protein[prefix + "all_atom_positions"],
protein["template_all_atom_mask" if prefix else "all_atom_mask"],
)
return protein
@curry1
def add_constant_field(protein, key, value):
protein[key] = torch.tensor(value)
return protein
def shaped_categorical(probs, epsilon=1e-10):
ds = probs.shape
num_classes = ds[-1]
distribution = torch.distributions.categorical.Categorical(
torch.reshape(probs+epsilon,[-1, num_classes])
torch.reshape(probs + epsilon, [-1, num_classes])
)
counts = distribution.sample()
return torch.reshape(counts, ds[:-1])
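# Illustrative sketch (not part of the commit): shaped_categorical draws one
# class index per leading position of an arbitrarily shaped probability tensor:
#   probs = torch.full((5, 3, 4), 0.25)  # hypothetical (..., num_classes)
#   shaped_categorical(probs).shape      # -> torch.Size([5, 3])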
def make_hhblits_profile(protein):
"""Compute the HHblits MSA profile if not already present."""
if 'hhblits_profile' in protein:
if "hhblits_profile" in protein:
return protein
# Compute the profile for every residue (over all MSA sequences).
msa_one_hot = make_one_hot(protein['msa'], 22)
msa_one_hot = make_one_hot(protein["msa"], 22)
protein['hhblits_profile'] = torch.mean(msa_one_hot, dim=0)
protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0)
return protein
@curry1
def make_masked_msa(protein, config, replace_fraction):
"""Create data for BERT on raw MSA."""
# Add a random amino acid uniformly.
random_aa = torch.tensor([0.05] * 20 + [0., 0.], dtype=torch.float32)
random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
categorical_probs = (
config.uniform_prob * random_aa +
config.profile_prob * protein['hhblits_profile'] +
config.same_prob * make_one_hot(protein['msa'], 22))
config.uniform_prob * random_aa
+ config.profile_prob * protein["hhblits_profile"]
+ config.same_prob * make_one_hot(protein["msa"], 22)
)
# Put all remaining probability on [MASK] which is a new column
pad_shapes = list(reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))]))
pad_shapes = list(
reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])
)
pad_shapes[1] = 1
mask_prob = 1. - config.profile_prob - config.same_prob - config.uniform_prob
assert mask_prob >= 0.
mask_prob = (
1.0 - config.profile_prob - config.same_prob - config.uniform_prob
)
assert mask_prob >= 0.0
categorical_probs = torch.nn.functional.pad(
categorical_probs, pad_shapes, value=mask_prob
)
sh = protein['msa'].shape
sh = protein["msa"].shape
mask_position = torch.rand(sh) < replace_fraction
bert_msa = shaped_categorical(categorical_probs)
bert_msa = torch.where(mask_position, bert_msa, protein['msa'])
bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
# Mix real and masked MSA
protein['bert_mask'] = mask_position.to(torch.float32)
protein['true_msa'] = protein['msa']
protein['msa'] = bert_msa
protein["bert_mask"] = mask_position.to(torch.float32)
protein["true_msa"] = protein["msa"]
protein["msa"] = bert_msa
return protein
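# Illustrative note (not part of the commit): with the defaults in this commit
# (profile_prob = same_prob = uniform_prob = 0.1) the leftover 0.7 lands on the
# appended [MASK] column, and only `replace_fraction` (0.15 in the data
# configs) of MSA positions are corrupted at all.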
@curry1
def make_fixed_size(
protein,
@@ -388,7 +459,7 @@ def make_fixed_size(
msa_cluster_size,
extra_msa_size,
num_res=0,
num_templates=0
num_templates=0,
):
"""Guess at the MSA and sequence dimension to make fixed size."""
@@ -401,14 +472,12 @@ def make_fixed_size(
for k, v in protein.items():
# Don't transfer this to the accelerator.
if k == 'extra_cluster_assignment':
if k == "extra_cluster_assignment":
continue
shape = list(v.shape)
schema = shape_schema[k]
msg = "Rank mismatch between shape and shape schema for"
assert len(shape) == len(schema), (
f'{msg} {k}: {shape} vs {schema}'
)
assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
pad_size = [
pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
]
@@ -422,24 +491,27 @@ def make_fixed_size(
return protein
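# Illustrative note (not part of the commit): pad_size_map is keyed by the
# schema placeholders (NUM_RES, NUM_MSA_SEQ, ...), so an axis whose schema
# entry is None keeps its current size while placeholder axes are padded up to
# the fixed crop/cluster sizes.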
@curry1
def make_msa_feat(protein):
"""Create and concatenate MSA features."""
# Whether there is a domain break. Always zero for chains, but keeping for
# compatibility with domain datasets.
has_break = torch.clip(
protein['between_segment_residues'].to(torch.float32), 0, 1
protein["between_segment_residues"].to(torch.float32), 0, 1
)
aatype_1hot = make_one_hot(protein['aatype'], 21)
aatype_1hot = make_one_hot(protein["aatype"], 21)
target_feat = [
torch.unsqueeze(has_break, dim=-1),
aatype_1hot, # Everyone gets the original sequence.
]
msa_1hot = make_one_hot(protein['msa'], 23)
has_deletion = torch.clip(protein['deletion_matrix'], 0., 1.)
deletion_value = torch.atan(protein['deletion_matrix'] / 3.) * (2. / np.pi)
msa_1hot = make_one_hot(protein["msa"], 23)
has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
2.0 / np.pi
)
msa_feat = [
msa_1hot,
@@ -447,24 +519,27 @@ def make_msa_feat(protein):
torch.unsqueeze(deletion_value, dim=-1),
]
if 'cluster_profile' in protein:
deletion_mean_value = (
torch.atan(protein['cluster_deletion_mean'] / 3.) * (2. / np.pi)
)
msa_feat.extend([protein['cluster_profile'],
if "cluster_profile" in protein:
deletion_mean_value = torch.atan(
protein["cluster_deletion_mean"] / 3.0
) * (2.0 / np.pi)
msa_feat.extend(
[
protein["cluster_profile"],
torch.unsqueeze(deletion_mean_value, dim=-1),
])
]
)
if 'extra_deletion_matrix' in protein:
protein['extra_has_deletion'] = torch.clip(
protein['extra_deletion_matrix'], 0., 1.
if "extra_deletion_matrix" in protein:
protein["extra_has_deletion"] = torch.clip(
protein["extra_deletion_matrix"], 0.0, 1.0
)
protein['extra_deletion_value'] = torch.atan(
protein['extra_deletion_matrix'] / 3.
) * (2. / np.pi)
protein["extra_deletion_value"] = torch.atan(
protein["extra_deletion_matrix"] / 3.0
) * (2.0 / np.pi)
protein['msa_feat'] = torch.cat(msa_feat, dim=-1)
protein['target_feat'] = torch.cat(target_feat, dim=-1)
protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
protein["target_feat"] = torch.cat(target_feat, dim=-1)
return protein
@@ -476,7 +551,7 @@ def select_feat(protein, feature_list):
@curry1
def crop_templates(protein, max_templates):
for k, v in protein.items():
if k.startswith('template_'):
if k.startswith("template_"):
protein[k] = v[:max_templates]
return protein
@@ -488,57 +563,58 @@ def make_atom14_masks(protein):
restype_atom14_mask = []
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[
rc.restype_1to3[rt]
]
restype_atom14_to_atom37.append([
(rc.atom_order[name] if name else 0)
for name in atom_names
])
atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
restype_atom14_to_atom37.append(
[(rc.atom_order[name] if name else 0) for name in atom_names]
)
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
restype_atom37_to_atom14.append(
[
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in rc.atom_types
])
]
)
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
restype_atom14_mask.append(
[(1.0 if name else 0.0) for name in atom_names]
)
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)
restype_atom14_mask.append([0.0] * 14)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37,
dtype=torch.int32,
device=protein['aatype'].device,
device=protein["aatype"].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14,
dtype=torch.int32,
device=protein['aatype'].device,
device=protein["aatype"].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask,
dtype=torch.float32,
device=protein['aatype'].device,
device=protein["aatype"].device,
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein['aatype']]
residx_atom14_mask = restype_atom14_mask[protein['aatype']]
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein["aatype"]]
residx_atom14_mask = restype_atom14_mask[protein["aatype"]]
protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37.long()
protein["atom14_atom_exists"] = residx_atom14_mask
protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein['aatype']]
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14.long()
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein["aatype"]]
protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
# create the corresponding mask
restype_atom37_mask = torch.zeros(
[21, 37], dtype=torch.float32, device=protein['aatype'].device
[21, 37], dtype=torch.float32, device=protein["aatype"].device
)
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
@@ -547,8 +623,8 @@ def make_atom14_masks(protein):
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = restype_atom37_mask[protein['aatype']]
protein['atom37_atom_exists'] = residx_atom37_mask
residx_atom37_mask = restype_atom37_mask[protein["aatype"]]
protein["atom37_atom_exists"] = residx_atom37_mask
return protein
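# Illustrative note (not part of the commit): atom14 is the compact per-residue
# layout (at most 14 heavy atoms for any residue type) and atom37 the fixed
# name-indexed layout; the residx_* tensors built above are the per-sequence
# gather maps that convert between the two.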
@@ -570,7 +646,7 @@ def make_atom14_positions(protein):
protein["all_atom_mask"],
residx_atom14_to_atom37,
dim=-1,
no_batch_dims=len(protein["all_atom_mask"].shape[:-1])
no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
)
# Gather the ground truth positions.
@@ -579,7 +655,7 @@ def make_atom14_positions(protein):
protein["all_atom_positions"],
residx_atom14_to_atom37,
dim=-2,
no_batch_dims=len(protein["all_atom_positions"].shape[:-2])
no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
)
)
@@ -589,9 +665,7 @@ def make_atom14_positions(protein):
# As the atom naming is ambiguous for 7 of the 20 amino acids, provide
# alternative ground truth coordinates where the naming is swapped
restype_3 = [
rc.restype_1to3[res] for res in rc.restypes
]
restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
restype_3 += ["UNK"]
# Matrices for renaming ambiguous atoms.
@@ -599,21 +673,26 @@ def make_atom14_positions(protein):
res: torch.eye(
14,
dtype=protein["all_atom_mask"].dtype,
device=protein["all_atom_mask"].device
) for res in restype_3
device=protein["all_atom_mask"].device,
)
for res in restype_3
}
for resname, swap in rc.residue_atom_renaming_swaps.items():
correspondences = torch.arange(14, device=protein["all_atom_mask"].device)
correspondences = torch.arange(
14, device=protein["all_atom_mask"].device
)
for source_atom_swap, target_atom_swap in swap.items():
source_index = rc.restype_name_to_atom14_names[
resname].index(source_atom_swap)
target_index = rc.restype_name_to_atom14_names[
resname].index(target_atom_swap)
source_index = rc.restype_name_to_atom14_names[resname].index(
source_atom_swap
)
target_index = rc.restype_name_to_atom14_names[resname].index(
target_atom_swap
)
correspondences[source_index] = target_index
correspondences[target_index] = source_index
renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
for index, correspondence in enumerate(correspondences):
renaming_matrix[index, correspondence] = 1.
renaming_matrix[index, correspondence] = 1.0
all_matrices[resname] = renaming_matrix
renaming_matrices = torch.stack(
[all_matrices[restype] for restype in restype_3]
@@ -625,9 +704,7 @@ def make_atom14_positions(protein):
# Apply it to the ground truth positions. shape (num_res, 14, 3).
alternative_gt_positions = torch.einsum(
"...rac,...rab->...rbc",
residx_atom14_gt_positions,
renaming_transform
"...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
)
protein["atom14_alt_gt_positions"] = alternative_gt_positions
@@ -635,9 +712,7 @@ def make_atom14_positions(protein):
# ground truth mask, if only one of the atoms in an ambiguous pair has a
# ground truth position).
alternative_gt_mask = torch.einsum(
"...ra,...rab->...rb",
residx_atom14_gt_mask,
renaming_transform
"...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
)
protein["atom14_alt_gt_exists"] = alternative_gt_mask
@@ -645,19 +720,20 @@ def make_atom14_positions(protein):
restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
for resname, swap in rc.residue_atom_renaming_swaps.items():
for atom_name1, atom_name2 in swap.items():
restype = rc.restype_order[
rc.restype_3to1[resname]]
restype = rc.restype_order[rc.restype_3to1[resname]]
atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
atom_name1)
atom_name1
)
atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
atom_name2)
atom_name2
)
restype_atom14_is_ambiguous[restype, atom_idx1] = 1
restype_atom14_is_ambiguous[restype, atom_idx2] = 1
# From this create an ambiguous_mask for the given sequence.
protein["atom14_atom_is_ambiguous"] = (
restype_atom14_is_ambiguous[protein["aatype"]]
)
protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
protein["aatype"]
]
return protein
@@ -669,14 +745,14 @@ def atom37_to_frames(protein):
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
for restype, restype_letter in enumerate(rc.restypes):
resname = rc.restype_1to3[restype_letter]
for chi_idx in range(4):
if(rc.chi_angles_mask[restype][chi_idx]):
if rc.chi_angles_mask[restype][chi_idx]:
names = rc.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :
@@ -687,12 +763,12 @@ def atom37_to_frames(protein):
)
restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1
restype_rigidgroup_mask[..., :20, 4:] = (
all_atom_mask.new_tensor(rc.chi_angles_mask)
restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
rc.chi_angles_mask
)
lookuptable = rc.atom_order.copy()
lookuptable[''] = 0
lookuptable[""] = 0
lookup = np.vectorize(lambda x: lookuptable[x])
restype_rigidgroup_base_atom37_idx = lookup(
restype_rigidgroup_base_atom_names,
@@ -702,8 +778,7 @@ def atom37_to_frames(protein):
)
restype_rigidgroup_base_atom37_idx = (
restype_rigidgroup_base_atom37_idx.view(
*((1,) * batch_dims),
*restype_rigidgroup_base_atom37_idx.shape
*((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
)
)
......@@ -739,13 +814,11 @@ def atom37_to_frames(protein):
all_atom_mask,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1])
no_batch_dims=len(all_atom_mask.shape[:-1]),
)
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
......@@ -764,9 +837,7 @@ def atom37_to_frames(protein):
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[
rc.restype_3to1[resname]
]
restype = rc.restype_order[rc.restype_3to1[resname]]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
......@@ -791,11 +862,11 @@ def atom37_to_frames(protein):
gt_frames_tensor = gt_frames.to_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4()
protein['rigidgroups_gt_frames'] = gt_frames_tensor
protein['rigidgroups_gt_exists'] = gt_exists
protein['rigidgroups_group_exists'] = group_exists
protein['rigidgroups_group_is_ambiguous'] = residx_rigidgroup_is_ambiguous
protein['rigidgroups_alt_gt_frames'] = alt_gt_frames_tensor
protein["rigidgroups_gt_frames"] = gt_frames_tensor
protein["rigidgroups_gt_exists"] = gt_exists
protein["rigidgroups_group_exists"] = group_exists
protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
return protein
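# The frames above are built from atom triples. A NumPy sketch of the
# standard Gram-Schmidt construction that a from_3_points-style helper
# performs; the exact axis conventions of T.from_3_points are an assumption
# here, not verified against the library:
import numpy as np

def frame_from_3_points(x1, x2, x3):
    e1 = (x3 - x2) / np.linalg.norm(x3 - x2)       # first basis vector
    v2 = x1 - x2
    u2 = v2 - e1 * np.dot(e1, v2)                  # remove the e1 component
    e2 = u2 / np.linalg.norm(u2)
    e3 = np.cross(e1, e2)                          # right-handed completion
    return np.stack([e1, e2, e3], axis=-1), x2     # rotation (columns) + origin

rot, trans = frame_from_3_points(
    np.array([1.0, 0.0, 0.0]), np.zeros(3), np.array([0.0, 1.0, 0.0])
)
assert np.allclose(rot @ rot.T, np.eye(3), atol=1e-6)  # orthonormal rotation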
......@@ -815,10 +886,11 @@ def get_chi_atom_indices():
residue_chi_angles = rc.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[rc.atom_order[atom] for atom in chi_angle])
atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
atom_indices.append(
[0, 0, 0, 0]
) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
......@@ -829,7 +901,7 @@ def get_chi_atom_indices():
@curry1
def atom37_to_torsion_angles(
protein,
prefix='',
prefix="",
):
"""
Convert coordinates to torsion angles.
......@@ -873,35 +945,27 @@ def atom37_to_torsion_angles(
prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
pre_omega_atom_pos = torch.cat(
[
prev_all_atom_positions[..., 1:3, :],
all_atom_positions[..., :2, :]
], dim=-2
[prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
dim=-2,
)
phi_atom_pos = torch.cat(
[
prev_all_atom_positions[..., 2:3, :],
all_atom_positions[..., :3, :]
], dim=-2
[prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
dim=-2,
)
psi_atom_pos = torch.cat(
[
all_atom_positions[..., :3, :],
all_atom_positions[..., 4:5, :]
], dim=-2
[all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
dim=-2,
)
pre_omega_mask = (
torch.prod(prev_all_atom_mask[..., 1:3], dim=-1) *
torch.prod(all_atom_mask[..., :2], dim=-1)
)
phi_mask = (
prev_all_atom_mask[..., 2] *
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
pre_omega_mask = torch.prod(
prev_all_atom_mask[..., 1:3], dim=-1
) * torch.prod(all_atom_mask[..., :2], dim=-1)
phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
)
psi_mask = (
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) *
all_atom_mask[..., 4]
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
* all_atom_mask[..., 4]
)
chi_atom_indices = torch.as_tensor(
......@@ -914,7 +978,7 @@ def atom37_to_torsion_angles(
)
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0., 0., 0., 0.])
chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
chis_mask = chi_angles_mask[aatype, :]
......@@ -923,7 +987,7 @@ def atom37_to_torsion_angles(
all_atom_mask,
atom_indices,
dim=-1,
no_batch_dims=len(atom_indices.shape[:-2])
no_batch_dims=len(atom_indices.shape[:-2]),
)
chi_angle_atoms_mask = torch.prod(
chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
......@@ -936,7 +1000,8 @@ def atom37_to_torsion_angles(
phi_atom_pos[..., None, :, :],
psi_atom_pos[..., None, :, :],
chis_atom_pos,
], dim=-3
],
dim=-3,
)
torsion_angles_mask = torch.cat(
......@@ -945,7 +1010,8 @@ def atom37_to_torsion_angles(
phi_mask[..., None],
psi_mask[..., None],
chis_mask,
], dim=-1
],
dim=-1,
)
torsion_frames = T.from_3_points(
......@@ -968,13 +1034,14 @@ def atom37_to_torsion_angles(
torch.square(torsion_angles_sin_cos),
dim=-1,
dtype=torsion_angles_sin_cos.dtype,
keepdims=True
) + 1e-8
keepdims=True,
)
+ 1e-8
)
torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
[1., 1., -1., 1., 1., 1., 1.],
[1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
......@@ -984,8 +1051,9 @@ def atom37_to_torsion_angles(
mirror_torsion_angles = torch.cat(
[
all_atom_mask.new_ones(*aatype.shape, 3),
1. - 2. * chi_is_ambiguous
], dim=-1
1.0 - 2.0 * chi_is_ambiguous,
],
dim=-1,
)
alt_torsion_angles_sin_cos = (
......@@ -1001,12 +1069,10 @@ def atom37_to_torsion_angles(
def get_backbone_frames(protein):
# TODO: Verify that this is correct
protein["backbone_affine_tensor"] = (
protein["rigidgroups_gt_frames"][..., 0, :, :]
)
protein["backbone_affine_mask"] = (
protein["rigidgroups_gt_exists"][..., 0]
)
protein["backbone_affine_tensor"] = protein["rigidgroups_gt_frames"][
..., 0, :, :
]
protein["backbone_affine_mask"] = protein["rigidgroups_gt_exists"][..., 0]
return protein
......@@ -1029,32 +1095,37 @@ def random_crop_to_size(
shape_schema,
subsample_templates=False,
seed=None,
batch_mode='clamped'
batch_mode="clamped",
):
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
seq_length = protein['seq_length']
if 'template_mask' in protein:
num_templates = protein['template_mask'].shape[-1]
seq_length = protein["seq_length"]
if "template_mask" in protein:
num_templates = protein["template_mask"].shape[-1]
else:
num_templates = protein['aatype'].new_zeros((1,))
num_templates = protein["aatype"].new_zeros((1,))
num_res_crop_size = min(seq_length, crop_size)
# We want each ensemble to be cropped the same way
g = torch.Generator(device=protein['seq_length'].device)
if(seed is not None):
g = torch.Generator(device=protein["seq_length"].device)
if seed is not None:
g.manual_seed(seed)
def _randint(lower, upper):
return int(torch.randint(
lower, upper, (1,),
device=protein['seq_length'].device, generator=g
)[0])
return int(
torch.randint(
lower,
upper,
(1,),
device=protein["seq_length"].device,
generator=g,
)[0]
)
if subsample_templates:
templates_crop_start = _randint(0, num_templates + 1)
templates_select_indices = torch.randperm(
num_templates, device=protein['seq_length'].device, generator=g
num_templates, device=protein["seq_length"].device, generator=g
)
num_templates_crop_size = min(
num_templates - templates_crop_start, max_templates
......@@ -1064,9 +1135,9 @@ def random_crop_to_size(
num_templates_crop_size = num_templates
n = seq_length - num_res_crop_size
if(batch_mode == 'clamped'):
if batch_mode == "clamped":
right_anchor = n + 1
elif(batch_mode == 'unclamped'):
elif batch_mode == "unclamped":
x = _randint(0, n)
right_anchor = n - x + 1
else:
......@@ -1075,20 +1146,19 @@ def random_crop_to_size(
num_res_crop_start = _randint(0, right_anchor)
for k, v in protein.items():
if (k not in shape_schema or
('template' not in k and NUM_RES not in shape_schema[k])
if k not in shape_schema or (
"template" not in k and NUM_RES not in shape_schema[k]
):
continue
# randomly permute the templates before cropping them.
if k.startswith('template') and subsample_templates:
if k.startswith("template") and subsample_templates:
v = v[templates_select_indices]
slices = []
for i, (dim_size, dim) in enumerate(zip(shape_schema[k],
v.shape)):
is_num_res = (dim_size == NUM_RES)
if i == 0 and k.startswith('template'):
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
is_num_res = dim_size == NUM_RES
if i == 0 and k.startswith("template"):
crop_size = num_templates_crop_size
crop_start = templates_crop_start
else:
......@@ -1097,7 +1167,5 @@ def random_crop_to_size(
slices.append(slice(crop_start, crop_start + crop_size))
protein[k] = v[slices]
protein['seq_length'] = (
protein['seq_length'].new_tensor(num_res_crop_size)
)
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
return protein
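# Cropping above reduces to seeded index generation plus slicing. A tiny
# sketch of the same pattern (feature name and sizes are illustrative
# assumptions); note a tuple of slices is used for the multi-dim index:
import torch

g = torch.Generator()
g.manual_seed(0)                   # identical seeds give identical crops across an ensemble
seq_len, crop_size = 10, 4
start = int(torch.randint(0, seq_len - crop_size + 1, (1,), generator=g)[0])
msa = torch.zeros(8, seq_len)      # (NUM_MSA_SEQ, NUM_RES)
cropped = msa[(slice(None), slice(start, start + crop_size))]
assert cropped.shape == (8, crop_size)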
......@@ -26,10 +26,11 @@ from openfold.data import input_pipeline
FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor]
def np_to_tensor_dict(
np_example: Mapping[str, np.ndarray],
features: Sequence[str],
) -> TensorDict:
) -> TensorDict:
"""Creates dict of tensors from a dict of NumPy arrays.
Args:
......@@ -54,7 +55,7 @@ def make_data_config(
cfg = copy.deepcopy(config)
mode_cfg = cfg[mode]
with cfg.unlocked():
if(mode_cfg.crop_size is None):
if mode_cfg.crop_size is None:
mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features
......@@ -62,7 +63,7 @@ def make_data_config(
if cfg.common.use_templates:
feature_names += cfg.common.template_features
if(cfg[mode].supervised):
if cfg[mode].supervised:
feature_names += cfg.common.supervised_features
return cfg, feature_names
......@@ -75,47 +76,47 @@ def np_example_to_features(
batch_mode: str,
):
np_example = dict(np_example)
num_res = int(np_example['seq_length'][0])
cfg, feature_names = make_data_config(
config, mode=mode, num_res=num_res
)
num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if 'deletion_matrix_int' in np_example:
np_example['deletion_matrix'] = (
np_example.pop('deletion_matrix_int').astype(np.float32)
)
if "deletion_matrix_int" in np_example:
np_example["deletion_matrix"] = np_example.pop(
"deletion_matrix_int"
).astype(np.float32)
if batch_mode == 'clamped':
np_example['use_clamped_fape'] = (
np.array(1.).astype(np.float32)
)
elif batch_mode == 'unclamped':
np_example['use_clamped_fape'] = (
np.array(0.).astype(np.float32)
)
if batch_mode == "clamped":
np_example["use_clamped_fape"] = np.array(1.0).astype(np.float32)
elif batch_mode == "unclamped":
np_example["use_clamped_fape"] = np.array(0.0).astype(np.float32)
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)
with torch.no_grad():
features = input_pipeline.process_tensors_from_config(
tensor_dict, cfg.common, cfg[mode], batch_mode=batch_mode,
tensor_dict,
cfg.common,
cfg[mode],
batch_mode=batch_mode,
)
return {k: v for k, v in features.items()}
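# np_to_tensor_dict boils down to filtering by feature name and wrapping
# arrays as tensors. A minimal sketch of that conversion (key names are
# illustrative assumptions):
import numpy as np
import torch

np_example = {"aatype": np.zeros(7, dtype=np.int64), "unused": np.ones(3)}
feature_names = ["aatype"]
tensor_dict = {k: torch.tensor(v) for k, v in np_example.items() if k in feature_names}
assert set(tensor_dict) == {"aatype"} and tensor_dict["aatype"].shape == (7,)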
class FeaturePipeline:
def __init__(self,
def __init__(
self,
config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
):
self.config = config
self.params = params
def process_features(self,
def process_features(
self,
raw_features: FeatureDict,
mode: str = 'train',
batch_mode: str = 'clamped',
mode: str = "train",
batch_mode: str = "clamped",
) -> FeatureDict:
return np_example_to_features(
np_example=raw_features,
......
......@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_hhblits_profile,
]
if common_cfg.use_templates:
transforms.extend([
transforms.extend(
[
data_transforms.fix_templates_aatype,
data_transforms.make_template_mask,
data_transforms.make_pseudo_beta('template_')
])
if(common_cfg.use_template_torsion_angles):
transforms.extend([
data_transforms.atom37_to_torsion_angles('template_'),
])
transforms.extend([
data_transforms.make_pseudo_beta("template_"),
]
)
if common_cfg.use_template_torsion_angles:
transforms.extend(
[
data_transforms.atom37_to_torsion_angles("template_"),
]
)
transforms.extend(
[
data_transforms.make_atom14_masks,
])
]
)
if(mode_cfg.supervised):
transforms.extend([
if mode_cfg.supervised:
transforms.extend(
[
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(''),
data_transforms.make_pseudo_beta(''),
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
])
]
)
return transforms
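# The transform lists assembled above are applied with compose, i.e. a left
# fold of single-argument functions. A minimal reimplementation sketch (the
# real compose lives elsewhere in the codebase; this stand-in assumes plain
# left-to-right application):
import functools

def compose(fns):
    return lambda x: functools.reduce(lambda acc, f: f(acc), fns, x)

pipeline = compose([lambda d: {**d, "a": 1}, lambda d: {**d, "b": d["a"] + 1}])
assert pipeline({}) == {"a": 1, "b": 2}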
......@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
data_transforms.sample_msa(max_msa_clusters, keep_extra=True)
)
if 'masked_msa' in common_cfg:
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
)
)
......@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
transforms.append(data_transforms.random_crop_to_size(
transforms.append(
data_transforms.random_crop_to_size(
mode_cfg.crop_size,
mode_cfg.max_templates,
crop_feats,
mode_cfg.subsample_templates,
batch_mode=batch_mode,
seed=torch.Generator().seed()
))
transforms.append(data_transforms.make_fixed_size(
seed=torch.Generator().seed(),
)
)
transforms.append(
data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates
))
mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
......@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
def process_tensors_from_config(
tensors, common_cfg, mode_cfg, batch_mode='clamped'
tensors, common_cfg, mode_cfg, batch_mode="clamped"
):
"""Based on the config, apply filters and transformations to the data."""
......@@ -136,12 +147,10 @@ def process_tensors_from_config(
d = data.copy()
fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
fn = compose(fns)
d['ensemble_index'] = i
d["ensemble_index"] = i
return fn(d)
tensors = compose(
nonensembled_transform_fns(common_cfg, mode_cfg)
)(tensors)
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
tensors_0 = wrap_ensemble_fn(tensors, 0)
num_ensemble = mode_cfg.num_ensemble
......@@ -150,8 +159,9 @@ def process_tensors_from_config(
num_ensemble *= common_cfg.num_recycle + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn(lambda x: wrap_ensemble_fn(tensors, x),
torch.arange(num_ensemble))
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
)
else:
tensors = tree.map_structure(lambda x: x[None], tensors_0)
......
......@@ -90,6 +90,7 @@ class MmcifObject:
...}}
raw_string: The raw string used to construct the MmcifObject.
"""
file_id: str
header: PdbHeader
structure: PdbStructure
......@@ -107,6 +108,7 @@ class ParsingResult:
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
mmcif_object: Optional[MmcifObject]
errors: Mapping[Tuple[str, str], Any]
......@@ -115,8 +117,9 @@ class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""
def mmcif_loop_to_list(prefix: str,
parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
def mmcif_loop_to_list(
prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.
Reference for loop_ in mmCIF:
......@@ -140,15 +143,17 @@ def mmcif_loop_to_list(prefix: str,
data.append(value)
assert all([len(xs) == len(data[0]) for xs in data]), (
'mmCIF error: Not all loops are the same length: %s' % cols)
"mmCIF error: Not all loops are the same length: %s" % cols
)
return [dict(zip(cols, xs)) for xs in zip(*data)]
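# mmcif_loop_to_list pivots column-oriented mmCIF loop data into row dicts
# via zip(*data). A toy sketch (keys and values are illustrative assumptions):
cols = ["_entity_poly_seq.mon_id", "_entity_poly_seq.num"]
data = [["ALA", "GLY"], ["1", "2"]]  # one list of values per column
rows = [dict(zip(cols, xs)) for xs in zip(*data)]
assert rows[0] == {"_entity_poly_seq.mon_id": "ALA", "_entity_poly_seq.num": "1"}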
def mmcif_loop_to_dict(prefix: str,
def mmcif_loop_to_dict(
prefix: str,
index: str,
parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
) -> Mapping[str, Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
Args:
......@@ -167,10 +172,9 @@ def mmcif_loop_to_dict(prefix: str,
return {entry[index]: entry for entry in entries}
def parse(*,
file_id: str,
mmcif_string: str,
catch_all_errors: bool = True) -> ParsingResult:
def parse(
*, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
"""Entry point, parses an mmcif_string.
Args:
......@@ -188,7 +192,7 @@ def parse(*,
try:
parser = PDB.MMCIFParser(QUIET=True)
handle = io.StringIO(mmcif_string)
full_structure = parser.get_structure('', handle)
full_structure = parser.get_structure("", handle)
first_model_structure = _get_first_model(full_structure)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
......@@ -206,9 +210,12 @@ def parse(*,
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ''): 'No protein chains found in this file.'})
seq_start_num = {chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()}
None, {(file_id, ""): "No protein chains found in this file."}
)
seq_start_num = {
chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()
}
# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
......@@ -217,34 +224,42 @@ def parse(*,
mmcif_to_author_chain_id = {}
seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != '1':
if atom.model_num != "1":
# We only process the first model at the moment.
continue
mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id
if atom.mmcif_chain_id in valid_chains:
hetflag = ' '
if atom.hetatm_atom == 'HETATM':
hetflag = " "
if atom.hetatm_atom == "HETATM":
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if atom.residue_name in ('HOH', 'WAT'):
hetflag = 'W'
if atom.residue_name in ("HOH", "WAT"):
hetflag = "W"
else:
hetflag = 'H_' + atom.residue_name
hetflag = "H_" + atom.residue_name
insertion_code = atom.insertion_code
if not _is_set(atom.insertion_code):
insertion_code = ' '
position = ResiduePosition(chain_id=atom.author_chain_id,
insertion_code = " "
position = ResiduePosition(
chain_id=atom.author_chain_id,
residue_number=int(atom.author_seq_num),
insertion_code=insertion_code)
seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
current = seq_to_structure_mappings.get(atom.author_chain_id, {})
current[seq_idx] = ResidueAtPosition(position=position,
insertion_code=insertion_code,
)
seq_idx = (
int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
)
current = seq_to_structure_mappings.get(
atom.author_chain_id, {}
)
current[seq_idx] = ResidueAtPosition(
position=position,
name=atom.residue_name,
is_missing=False,
hetflag=hetflag)
hetflag=hetflag,
)
seq_to_structure_mappings[atom.author_chain_id] = current
# Add missing residue information to seq_to_structure_mappings.
......@@ -253,19 +268,21 @@ def parse(*,
current_mapping = seq_to_structure_mappings[author_chain]
for idx, monomer in enumerate(seq_info):
if idx not in current_mapping:
current_mapping[idx] = ResidueAtPosition(position=None,
current_mapping[idx] = ResidueAtPosition(
position=None,
name=monomer.id,
is_missing=True,
hetflag=' ')
hetflag=" ",
)
author_chain_to_sequence = {}
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
seq = []
for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
seq.append(code if len(code) == 1 else 'X')
seq = ''.join(seq)
code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
seq.append(code if len(code) == 1 else "X")
seq = "".join(seq)
author_chain_to_sequence[author_chain] = seq
mmcif_object = MmcifObject(
......@@ -274,11 +291,12 @@ def parse(*,
structure=first_model_structure,
chain_to_seqres=author_chain_to_sequence,
seqres_to_structure=seq_to_structure_mappings,
raw_string=parsed_info)
raw_string=parsed_info,
)
return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, '')] = e
errors[(file_id, "")] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)
......@@ -288,12 +306,13 @@ def _get_first_model(structure: PdbStructure) -> PdbStructure:
"""Returns the first model in a Biopython structure."""
return next(structure.get_models())
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
def get_release_date(parsed_info: MmCIFDict) -> str:
"""Returns the oldest revision date."""
revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
revision_dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
return min(revision_dates)
......@@ -301,47 +320,58 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
header = {}
experiments = mmcif_loop_to_list('_exptl.', parsed_info)
header['structure_method'] = ','.join([
experiment['_exptl.method'].lower() for experiment in experiments])
experiments = mmcif_loop_to_list("_exptl.", parsed_info)
header["structure_method"] = ",".join(
[experiment["_exptl.method"].lower() for experiment in experiments]
)
# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if '_pdbx_audit_revision_history.revision_date' in parsed_info:
header['release_date'] = get_release_date(parsed_info)
if "_pdbx_audit_revision_history.revision_date" in parsed_info:
header["release_date"] = get_release_date(parsed_info)
else:
logging.warning('Could not determine release_date: %s',
parsed_info['_entry.id'])
logging.warning(
"Could not determine release_date: %s", parsed_info["_entry.id"]
)
header['resolution'] = 0.00
for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution',
'_reflns.d_resolution_high'):
header["resolution"] = 0.00
for res_key in (
"_refine.ls_d_res_high",
"_em_3d_reconstruction.resolution",
"_reflns.d_resolution_high",
):
if res_key in parsed_info:
try:
raw_resolution = parsed_info[res_key][0]
header['resolution'] = float(raw_resolution)
header["resolution"] = float(raw_resolution)
except ValueError:
logging.warning('Invalid resolution format: %s', parsed_info[res_key])
logging.warning(
"Invalid resolution format: %s", parsed_info[res_key]
)
return header
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
"""Returns list of atom sites; contains data not present in the structure."""
return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension
parsed_info['_atom_site.label_comp_id'],
parsed_info['_atom_site.auth_asym_id'],
parsed_info['_atom_site.label_asym_id'],
parsed_info['_atom_site.auth_seq_id'],
parsed_info['_atom_site.label_seq_id'],
parsed_info['_atom_site.pdbx_PDB_ins_code'],
parsed_info['_atom_site.group_PDB'],
parsed_info['_atom_site.pdbx_PDB_model_num'],
)]
return [
AtomSite(*site)
for site in zip( # pylint:disable=g-complex-comprehension
parsed_info["_atom_site.label_comp_id"],
parsed_info["_atom_site.auth_asym_id"],
parsed_info["_atom_site.label_asym_id"],
parsed_info["_atom_site.auth_seq_id"],
parsed_info["_atom_site.label_seq_id"],
parsed_info["_atom_site.pdbx_PDB_ins_code"],
parsed_info["_atom_site.group_PDB"],
parsed_info["_atom_site.pdbx_PDB_model_num"],
)
]
def _get_protein_chains(
*, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]:
*, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
"""Extracts polymer information for protein chains only.
Args:
......@@ -351,26 +381,29 @@ def _get_protein_chains(
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)
entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)
polymers = collections.defaultdict(list)
for entity_poly_seq in entity_poly_seqs:
polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'],
num=int(entity_poly_seq['_entity_poly_seq.num'])))
polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
Monomer(
id=entity_poly_seq["_entity_poly_seq.mon_id"],
num=int(entity_poly_seq["_entity_poly_seq.num"]),
)
)
# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info)
chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)
# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)
struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)
entity_to_mmcif_chains = collections.defaultdict(list)
for struct_asym in struct_asyms:
chain_id = struct_asym['_struct_asym.id']
entity_id = struct_asym['_struct_asym.entity_id']
chain_id = struct_asym["_struct_asym.id"]
entity_id = struct_asym["_struct_asym.entity_id"]
entity_to_mmcif_chains[entity_id].append(chain_id)
# Identify and return the valid protein chains.
......@@ -379,8 +412,12 @@ def _get_protein_chains(
chain_ids = entity_to_mmcif_chains[entity_id]
# Reject polymers without any peptide-like components, such as DNA/RNA.
if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type']
for monomer in seq_info]):
if any(
[
"peptide" in chem_comps[monomer.id]["_chem_comp.type"]
for monomer in seq_info
]
):
for chain_id in chain_ids:
valid_chains[chain_id] = seq_info
return valid_chains
......@@ -388,19 +425,18 @@ def _get_protein_chains(
def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?')
return data not in (".", "?")
def get_atom_coords(
mmcif_object: MmcifObject,
chain_id: str
mmcif_object: MmcifObject, chain_id: str
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
relevant_chains = [c for c in chains if c.id == chain_id]
if len(relevant_chains) != 1:
raise MultipleChainsError(
f'Expected exactly one chain in structure with id {chain_id}.'
f"Expected exactly one chain in structure with id {chain_id}."
)
chain = relevant_chains[0]
......@@ -417,19 +453,23 @@ def get_atom_coords(
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
if not res_at_position.is_missing:
res = chain[(res_at_position.hetflag,
res = chain[
(
res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code)]
res_at_position.position.insertion_code,
)
]
for atom in res.get_atoms():
atom_name = atom.get_name()
x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
# Put the coords of the selenium atom in the sulphur column
pos[residue_constants.atom_order['SD']] = [x, y, z]
mask[residue_constants.atom_order['SD']] = 1.0
pos[residue_constants.atom_order["SD"]] = [x, y, z]
mask[residue_constants.atom_order["SD"]] = 1.0
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
......@@ -440,22 +480,22 @@ def get_atom_coords(
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
data = {}
for f in os.listdir(mmcif_dir):
if(f.endswith('.cif')):
with open(os.path.join(mmcif_dir, f), 'r') as fp:
if f.endswith(".cif"):
with open(os.path.join(mmcif_dir, f), "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
if(mmcif.mmcif_object is None):
logging.warning(f'Could not parse {f}. Skipping...')
if mmcif.mmcif_object is None:
logging.warning(f"Could not parse {f}. Skipping...")
continue
else:
mmcif = mmcif.mmcif_object
local_data = {}
local_data['release_date'] = mmcif.header["release_date"]
local_data['no_chains'] = len(list(mmcif.structure.get_chains()))
local_data["release_date"] = mmcif.header["release_date"]
local_data["no_chains"] = len(list(mmcif.structure.get_chains()))
data[file_id] = local_data
with open(out_path, 'w') as fp:
with open(out_path, "w") as fp:
fp.write(json.dumps(data))
......@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class TemplateHit:
"""Class representing a template hit."""
index: int
name: str
aligned_cols: int
......@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
index = -1
for line in fasta_string.splitlines():
line = line.strip()
if line.startswith('>'):
if line.startswith(">"):
index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append('')
sequences.append("")
continue
elif not line:
continue # Skip blank lines.
......@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return sequences, descriptions
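# parse_fasta in action -- a quick usage sketch (toy input is an assumption;
# the accumulation of sequence lines onto the last record is truncated from
# this hunk but implied):
sequences, descriptions = parse_fasta(">seq1 first\nMKV\nLLT\n>seq2\nGGG\n")
assert sequences == ["MKVLLT", "GGG"]
assert descriptions == ["seq1 first", "seq2"]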
def parse_stockholm(stockholm_string: str
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
def parse_stockholm(
stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
......@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str
name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines():
line = line.strip()
if not line or line.startswith(('#', '//')):
if not line or line.startswith(("#", "//")):
continue
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence[name] = ''
name_to_sequence[name] = ""
name_to_sequence[name] += sequence
msa = []
deletion_matrix = []
query = ''
query = ""
keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0:
# Gather the columns with gaps from the query
query = sequence
keep_columns = [i for i, res in enumerate(query) if res != '-']
keep_columns = [i for i, res in enumerate(query) if res != "-"]
# Remove the columns with gaps in the query from all sequences.
aligned_sequence = ''.join([sequence[c] for c in keep_columns])
aligned_sequence = "".join([sequence[c] for c in keep_columns])
msa.append(aligned_sequence)
......@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str
deletion_vec = []
deletion_count = 0
for seq_res, query_res in zip(sequence, query):
if seq_res != '-' or query_res != '-':
if query_res == '-':
if seq_res != "-" or query_res != "-":
if query_res == "-":
deletion_count += 1
else:
deletion_vec.append(deletion_count)
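# Worked example of the deletion bookkeeping above: a hit residue aligned to
# a query gap counts as a deletion and is charged to the next kept column.
# The reset of deletion_count after appending is cut off by this hunk but
# implied. A compact stand-alone sketch (toy sequences are assumptions):
def deletion_vector(query, sequence):
    vec, count = [], 0
    for seq_res, query_res in zip(sequence, query):
        if seq_res != "-" or query_res != "-":
            if query_res == "-":
                count += 1
            else:
                vec.append(count)
                count = 0
    return vec

assert deletion_vector("A-C", "ABC") == [0, 1]  # 'B' sits in a query gap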
......@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
deletion_matrix.append(deletion_vec)
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase)
deletion_table = str.maketrans("", "", string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix
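# In A3M, insertions relative to the query are lowercase, so deleting every
# lowercase letter recovers the aligned sequence. str.maketrans with a third
# argument builds that deletion table in one pass:
import string

deletion_table = str.maketrans("", "", string.ascii_lowercase)
assert "AxxBC".translate(deletion_table) == "ABC"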
def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]:
query_non_gaps: Sequence[bool], sto_seq: str
) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap:
yield sequence_res
elif sequence_res != '-':
elif sequence_res != "-":
yield sequence_res.lower()
def convert_stockholm_to_a3m(stockholm_format: str,
max_sequences: Optional[int] = None) -> str:
def convert_stockholm_to_a3m(
stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False
for line in stockholm_format.splitlines():
reached_max_sequences = max_sequences and len(sequences) >= max_sequences
if line.strip() and not line.startswith(('#', '//')):
reached_max_sequences = (
max_sequences and len(sequences) >= max_sequences
)
if line.strip() and not line.startswith(("#", "//")):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ''
sequences[seqname] = ""
sequences[seqname] += aligned_seq
for line in stockholm_format.splitlines():
if line[:4] == '#=GS':
if line[:4] == "#=GS":
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ''
if feature != 'DE':
value = columns[3] if len(columns) == 4 else ""
if feature != "DE":
continue
if reached_max_sequences and seqname not in sequences:
continue
......@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str,
a3m_sequences = {}
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence]
query_non_gaps = [res != "-" for res in query_sequence]
for seqname, sto_sequence in sequences.items():
a3m_sequences[seqname] = ''.join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence))
a3m_sequences[seqname] = "".join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
)
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences)
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.
fasta_chunks = (
f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences
)
return "\n".join(fasta_chunks) + "\n" # Include terminating newline.
def _get_hhr_line_regex_groups(
regex_pattern: str, line: str) -> Sequence[Optional[str]]:
regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
if match is None:
raise RuntimeError(f'Could not parse query line {line}')
raise RuntimeError(f"Could not parse query line {line}")
return match.groups()
def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]):
sequence: str, start_index: int, indices_list: List[int]
):
"""Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index
for symbol in sequence:
if symbol == '-':
if symbol == "-":
indices_list.append(-1)
else:
indices_list.append(counter)
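# Worked example of the index bookkeeping above (the counter increment for
# non-gap residues is cut off by this hunk but implied): gaps record -1,
# residues record consecutive indices starting at start_index.
indices = []
counter = 5  # start_index
for symbol in "AB-C":
    if symbol == "-":
        indices.append(-1)
    else:
        indices.append(counter)
        counter += 1
assert indices == [5, 6, -1, 7]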
......@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Parse the summary line.
pattern = (
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
']*Template_Neff=(.*)')
"Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
" ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
"]*Template_Neff=(.*)"
)
match = re.match(pattern, detailed_lines[2])
if match is None:
raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.' %
(detailed_lines, detailed_lines[2]))
(prob_true, e_value, _, aligned_cols, _, _, sum_probs,
neff) = [float(x) for x in match.groups()]
"Could not parse section: %s. Expected this: \n%s to contain summary."
% (detailed_lines, detailed_lines[2])
)
(prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
float(x) for x in match.groups()
]
# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block.
query = ''
hit_sequence = ''
query = ""
hit_sequence = ""
indices_query = []
indices_hit = []
length_block = None
for line in detailed_lines[3:]:
# Parse the query sequence line
if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and
not line.startswith('Q ss_pred') and
not line.startswith('Q Consensus')):
if (
line.startswith("Q ")
and not line.startswith("Q ss_dssp")
and not line.startswith("Q ss_pred")
and not line.startswith("Q Consensus")
):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:])
# Get the length of the parsed block using the start and finish indices,
......@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1]
end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-'])
num_insertions = len([x for x in delta_query if x == "-"])
length_block = end - start + num_insertions
assert length_block == len(delta_query)
......@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith('T '):
elif line.startswith("T "):
# Parse the hit sequence.
if (not line.startswith('T ss_dssp') and
not line.startswith('T ss_pred') and
not line.startswith('T Consensus')):
if (
not line.startswith("T ss_dssp")
and not line.startswith("T ss_pred")
and not line.startswith("T Consensus")
):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1]
......@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit)
_update_hhr_residue_indices_list(
delta_hit_sequence, start, indices_hit
)
return TemplateHit(
index=number_of_hit,
......@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')]
block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]
hits = []
if block_starts:
block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1):
hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
hits.append(
_parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
)
return hits
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {'query': 0}
lines = [line for line in tblout.splitlines() if line[0] != '#']
e_values = {"query": 0}
lines = [line for line in tblout.splitlines() if line[0] != "#"]
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1).
......
......@@ -89,29 +89,30 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES = {
'template_aatype': np.int64,
'template_all_atom_mask': np.float32,
'template_all_atom_positions': np.float32,
'template_domain_names': np.object,
'template_sequence': np.object,
'template_sum_probs': np.float32,
"template_aatype": np.int64,
"template_all_atom_mask": np.float32,
"template_all_atom_positions": np.float32,
"template_domain_names": np.object,
"template_sequence": np.object,
"template_sum_probs": np.float32,
}
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
"""Returns PDB id and chain id for an HHSearch Hit."""
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
id_match = re.match(r'[a-zA-Z\d]{4}_[a-zA-Z0-9.]+', hit.name)
id_match = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", hit.name)
if not id_match:
raise ValueError(f'hit.name did not start with PDBID_chain: {hit.name}')
pdb_id, chain_id = id_match.group(0).split('_')
raise ValueError(f"hit.name did not start with PDBID_chain: {hit.name}")
pdb_id, chain_id = id_match.group(0).split("_")
return pdb_id.lower(), chain_id
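# The PDB-id regex above in action -- a quick sketch with a made-up hit name:
import re

m = re.match(r"[a-zA-Z\d]{4}_[a-zA-Z0-9.]+", "1abc_A some description")
assert m is not None
assert m.group(0).split("_") == ["1abc", "A"]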
def _is_after_cutoff(
pdb_id: str,
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: Optional[datetime.datetime]) -> bool:
release_date_cutoff: Optional[datetime.datetime],
) -> bool:
"""Checks if the template date is after the release date cutoff.
Args:
......@@ -123,13 +124,15 @@ def _is_after_cutoff(
True if the template release date is after the cutoff, False otherwise.
"""
if release_date_cutoff is None:
raise ValueError('The release_date_cutoff must not be None.')
raise ValueError("The release_date_cutoff must not be None.")
if pdb_id in release_dates:
return release_dates[pdb_id] > release_date_cutoff
else:
# Since this is just a quick prefilter to reduce the number of mmCIF files
# we need to parse, we don't have to worry about returning True here.
logging.warning('Template structure not in release dates dict: %s', pdb_id)
logging.warning(
"Template structure not in release dates dict: %s", pdb_id
)
return False
......@@ -140,7 +143,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
for line in f:
line = line.strip()
# We skip obsolete entries that don't contain a mapping to a new entry.
if line.startswith('OBSLTE') and len(line) > 30:
if line.startswith("OBSLTE") and len(line) > 30:
# Format: Date From To
# 'OBSLTE 31-JUL-94 116L 216L'
from_id = line[20:24].lower()
......@@ -152,38 +155,41 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
dates = {}
for f in os.listdir(mmcif_dir):
if(f.endswith('.cif')):
if f.endswith(".cif"):
path = os.path.join(mmcif_dir, f)
with open(path, 'r') as fp:
with open(path, "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
if(mmcif.mmcif_object is None):
logging.warning(f'Failed to parse {f}. Skipping...')
if mmcif.mmcif_object is None:
logging.warning(f"Failed to parse {f}. Skipping...")
continue
mmcif = mmcif.mmcif_object
release_date = mmcif.header['release_date']
release_date = mmcif.header["release_date"]
dates[file_id] = release_date
with open(out_path, 'r') as fp:
with open(out_path, "r") as fp:
fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
"""Parses release dates file, returns a mapping from PDBs to release dates."""
with open(path, 'r') as fp:
with open(path, "r") as fp:
data = json.load(fp)
return {
pdb:to_date(v) for pdb,d in data.items() for k,v in d.items()
pdb: to_date(v)
for pdb, d in data.items()
for k, v in d.items()
if k == "release_date"
}
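# The comprehension above is a nested filter-and-flatten: the outer loop runs
# over PDB ids, the inner loop over each entry's fields, keeping only
# "release_date". A toy sketch (to_date is stubbed as an identity here; the
# real helper, defined elsewhere, parses the date string):
data = {"1abc": {"release_date": "2000-01-01", "other": "x"}}
to_date = lambda v: v  # stand-in for the real date parser
parsed = {
    pdb: to_date(v)
    for pdb, d in data.items()
    for k, v in d.items()
    if k == "release_date"
}
assert parsed == {"1abc": "2000-01-01"}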
def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
hit_pdb_code: str,
......@@ -192,7 +198,8 @@ def _assess_hhsearch_hit(
release_dates: Mapping[str, datetime.datetime],
release_date_cutoff: datetime.datetime,
max_subsequence_ratio: float = 0.95,
min_align_ratio: float = 0.1) -> bool:
min_align_ratio: float = 0.1,
) -> bool:
"""Determines if template is valid (without parsing the template mmcif file).
Args:
......@@ -221,32 +228,42 @@ def _assess_hhsearch_hit(
aligned_cols = hit.aligned_cols
align_ratio = aligned_cols / len(query_sequence)
template_sequence = hit.hit_sequence.replace('-', '')
template_sequence = hit.hit_sequence.replace("-", "")
length_ratio = float(len(template_sequence)) / len(query_sequence)
# Check whether the template is a large subsequence or duplicate of original
# query. This can happen due to duplicate entries in the PDB database.
duplicate = (template_sequence in query_sequence and
length_ratio > max_subsequence_ratio)
duplicate = (
template_sequence in query_sequence
and length_ratio > max_subsequence_ratio
)
if _is_after_cutoff(hit_pdb_code, release_dates, release_date_cutoff):
raise DateError(f'Date ({release_dates[hit_pdb_code]}) > max template date '
f'({release_date_cutoff}).')
raise DateError(
f"Date ({release_dates[hit_pdb_code]}) > max template date "
f"({release_date_cutoff})."
)
if query_pdb_code is not None:
if query_pdb_code.lower() == hit_pdb_code.lower():
raise PdbIdError('PDB code identical to Query PDB code.')
raise PdbIdError("PDB code identical to Query PDB code.")
if align_ratio <= min_align_ratio:
raise AlignRatioError('Proportion of residues aligned to query too small. '
f'Align ratio: {align_ratio}.')
raise AlignRatioError(
"Proportion of residues aligned to query too small. "
f"Align ratio: {align_ratio}."
)
if duplicate:
raise DuplicateError('Template is an exact subsequence of query with large '
f'coverage. Length ratio: {length_ratio}.')
raise DuplicateError(
"Template is an exact subsequence of query with large "
f"coverage. Length ratio: {length_ratio}."
)
if len(template_sequence) < 10:
raise LengthError(f'Template too short. Length: {len(template_sequence)}.')
raise LengthError(
f"Template too short. Length: {len(template_sequence)}."
)
return True
......@@ -254,7 +271,8 @@ def _assess_hhsearch_hit(
def _find_template_in_pdb(
template_chain_id: str,
template_sequence: str,
mmcif_object: mmcif_parsing.MmcifObject) -> Tuple[str, str, int]:
mmcif_object: mmcif_parsing.MmcifObject,
) -> Tuple[str, str, int]:
"""Tries to find the template chain in the given pdb file.
This method tries the three following things in order:
......@@ -286,33 +304,42 @@ def _find_template_in_pdb(
chain_sequence = mmcif_object.chain_to_seqres.get(template_chain_id)
if chain_sequence and (template_sequence in chain_sequence):
logging.info(
'Found an exact template match %s_%s.', pdb_id, template_chain_id)
"Found an exact template match %s_%s.", pdb_id, template_chain_id
)
mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, template_chain_id, mapping_offset
# Try if there is an exact match in the (sub)sequence only.
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
if chain_sequence and (template_sequence in chain_sequence):
logging.info('Found a sequence-only match %s_%s.', pdb_id, chain_id)
logging.info("Found a sequence-only match %s_%s.", pdb_id, chain_id)
mapping_offset = chain_sequence.find(template_sequence)
return chain_sequence, chain_id, mapping_offset
# Return a chain sequence that fuzzy matches (X = wildcard) the template.
# Make parentheses unnamed groups (?:_) to avoid the 100 named groups limit.
regex = ['.' if aa == 'X' else '(?:%s|X)' % aa for aa in template_sequence]
regex = re.compile(''.join(regex))
regex = ["." if aa == "X" else "(?:%s|X)" % aa for aa in template_sequence]
regex = re.compile("".join(regex))
for chain_id, chain_sequence in mmcif_object.chain_to_seqres.items():
match = re.search(regex, chain_sequence)
if match:
logging.info('Found a fuzzy sequence-only match %s_%s.', pdb_id, chain_id)
logging.info(
"Found a fuzzy sequence-only match %s_%s.", pdb_id, chain_id
)
mapping_offset = match.start()
return chain_sequence, chain_id, mapping_offset
# No hits, raise an error.
raise SequenceNotInTemplateError(
'Could not find the template sequence in %s_%s. Template sequence: %s, '
'chain_to_seqres: %s' % (pdb_id, template_chain_id, template_sequence,
mmcif_object.chain_to_seqres))
"Could not find the template sequence in %s_%s. Template sequence: %s, "
"chain_to_seqres: %s"
% (
pdb_id,
template_chain_id,
template_sequence,
mmcif_object.chain_to_seqres,
)
)
def _realign_pdb_template_to_query(
......@@ -320,7 +347,8 @@ def _realign_pdb_template_to_query(
template_chain_id: str,
mmcif_object: mmcif_parsing.MmcifObject,
old_mapping: Mapping[int, int],
kalign_binary_path: str) -> Tuple[str, Mapping[int, int]]:
kalign_binary_path: str,
) -> Tuple[str, Mapping[int, int]]:
"""Aligns template from the mmcif_object to the query.
In case PDB70 contains a different version of the template sequence, we need
......@@ -361,76 +389,104 @@ def _realign_pdb_template_to_query(
"""
aligner = kalign.Kalign(binary_path=kalign_binary_path)
new_template_sequence = mmcif_object.chain_to_seqres.get(
template_chain_id, '')
template_chain_id, ""
)
# Sometimes the template chain id is unknown. But if there is only a single
# sequence within the mmcif_object, it is safe to assume it is that one.
if not new_template_sequence:
if len(mmcif_object.chain_to_seqres) == 1:
logging.info('Could not find %s in %s, but there is only 1 sequence, so '
'using that one.',
logging.info(
"Could not find %s in %s, but there is only 1 sequence, so "
"using that one.",
template_chain_id,
mmcif_object.file_id)
new_template_sequence = list(mmcif_object.chain_to_seqres.values())[0]
mmcif_object.file_id,
)
new_template_sequence = list(mmcif_object.chain_to_seqres.values())[
0
]
else:
raise QueryToTemplateAlignError(
f'Could not find chain {template_chain_id} in {mmcif_object.file_id}. '
'If there are no mmCIF parsing errors, it is possible it was not a '
'protein chain.')
f"Could not find chain {template_chain_id} in {mmcif_object.file_id}. "
"If there are no mmCIF parsing errors, it is possible it was not a "
"protein chain."
)
try:
(old_aligned_template, new_aligned_template), _ = parsers.parse_a3m(
aligner.align([old_template_sequence, new_template_sequence]))
aligner.align([old_template_sequence, new_template_sequence])
)
except Exception as e:
raise QueryToTemplateAlignError(
'Could not align old template %s to template %s (%s_%s). Error: %s' %
(old_template_sequence, new_template_sequence, mmcif_object.file_id,
template_chain_id, str(e)))
"Could not align old template %s to template %s (%s_%s). Error: %s"
% (
old_template_sequence,
new_template_sequence,
mmcif_object.file_id,
template_chain_id,
str(e),
)
)
logging.info('Old aligned template: %s\nNew aligned template: %s',
old_aligned_template, new_aligned_template)
logging.info(
"Old aligned template: %s\nNew aligned template: %s",
old_aligned_template,
new_aligned_template,
)
old_to_new_template_mapping = {}
old_template_index = -1
new_template_index = -1
num_same = 0
for old_template_aa, new_template_aa in zip(
old_aligned_template, new_aligned_template):
if old_template_aa != '-':
old_aligned_template, new_aligned_template
):
if old_template_aa != "-":
old_template_index += 1
if new_template_aa != '-':
if new_template_aa != "-":
new_template_index += 1
if old_template_aa != '-' and new_template_aa != '-':
if old_template_aa != "-" and new_template_aa != "-":
old_to_new_template_mapping[old_template_index] = new_template_index
if old_template_aa == new_template_aa:
num_same += 1
  # Require at least 90 % sequence identity w.r.t. the shorter of the sequences.
if float(num_same) / min(
len(old_template_sequence), len(new_template_sequence)) < 0.9:
if (
float(num_same)
/ min(len(old_template_sequence), len(new_template_sequence))
< 0.9
):
raise QueryToTemplateAlignError(
'Insufficient similarity of the sequence in the database: %s to the '
'actual sequence in the mmCIF file %s_%s: %s. We require at least '
'90 %% similarity wrt to the shorter of the sequences. This is not a '
'problem unless you think this is a template that should be included.' %
(old_template_sequence, mmcif_object.file_id, template_chain_id,
new_template_sequence))
"Insufficient similarity of the sequence in the database: %s to the "
"actual sequence in the mmCIF file %s_%s: %s. We require at least "
"90 %% similarity wrt to the shorter of the sequences. This is not a "
"problem unless you think this is a template that should be included."
% (
old_template_sequence,
mmcif_object.file_id,
template_chain_id,
new_template_sequence,
)
)
new_query_to_template_mapping = {}
for query_index, old_template_index in old_mapping.items():
new_query_to_template_mapping[query_index] = (
old_to_new_template_mapping.get(old_template_index, -1))
new_query_to_template_mapping[
query_index
] = old_to_new_template_mapping.get(old_template_index, -1)
new_template_sequence = new_template_sequence.replace('-', '')
new_template_sequence = new_template_sequence.replace("-", "")
return new_template_sequence, new_query_to_template_mapping
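# The remapping above is dictionary composition: query->old_template composed
# with old_template->new_template, with -1 marking residues that fell out of
# the new alignment. A toy sketch (index values are illustrative assumptions):
old_mapping = {0: 10, 1: 11, 2: 12}  # query index -> old template index
old_to_new = {10: 0, 12: 1}          # old template -> new template (11 unaligned)
new_mapping = {q: old_to_new.get(t, -1) for q, t in old_mapping.items()}
assert new_mapping == {0: 0, 1: -1, 2: 1}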
def _check_residue_distances(all_positions: np.ndarray,
def _check_residue_distances(
all_positions: np.ndarray,
all_positions_mask: np.ndarray,
max_ca_ca_distance: float):
max_ca_ca_distance: float,
):
"""Checks if the distance between unmasked neighbor residues is ok."""
ca_position = residue_constants.atom_order['CA']
ca_position = residue_constants.atom_order["CA"]
prev_is_unmasked = False
prev_calpha = None
for i, (coords, mask) in enumerate(zip(all_positions, all_positions_mask)):
......@@ -441,8 +497,9 @@ def _check_residue_distances(all_positions: np.ndarray,
distance = np.linalg.norm(this_calpha - prev_calpha)
if distance > max_ca_ca_distance:
raise CaDistanceError(
'The distance between residues %d and %d is %f > limit %f.' % (
i, i + 1, distance, max_ca_ca_distance))
"The distance between residues %d and %d is %f > limit %f."
% (i, i + 1, distance, max_ca_ca_distance)
)
prev_calpha = this_calpha
prev_is_unmasked = this_is_unmasked
......@@ -450,7 +507,8 @@ def _check_residue_distances(all_positions: np.ndarray,
def _get_atom_positions(
mmcif_object: mmcif_parsing.MmcifObject,
auth_chain_id: str,
max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
max_ca_ca_distance: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id
......@@ -469,7 +527,8 @@ def _extract_template_features(
template_sequence: str,
query_sequence: str,
template_chain_id: str,
kalign_binary_path: str) -> Tuple[Dict[str, Any], Optional[str]]:
kalign_binary_path: str,
) -> Tuple[Dict[str, Any], Optional[str]]:
"""Parses atom positions in the target structure and aligns with the query.
Atoms for each residue in the template structure are indexed to coincide
......@@ -509,21 +568,25 @@ def _extract_template_features(
unmasked residues.
"""
if mmcif_object is None or not mmcif_object.chain_to_seqres:
raise NoChainsError('No chains in PDB: %s_%s' % (pdb_id, template_chain_id))
raise NoChainsError(
"No chains in PDB: %s_%s" % (pdb_id, template_chain_id)
)
warning = None
try:
seqres, chain_id, mapping_offset = _find_template_in_pdb(
template_chain_id=template_chain_id,
template_sequence=template_sequence,
mmcif_object=mmcif_object)
mmcif_object=mmcif_object,
)
except SequenceNotInTemplateError:
# If PDB70 contains a different version of the template, we use the sequence
# from the mmcif_object.
chain_id = template_chain_id
warning = (
f'The exact sequence {template_sequence} was not found in '
f'{pdb_id}_{chain_id}. Realigning the template to the actual sequence.')
f"The exact sequence {template_sequence} was not found in "
f"{pdb_id}_{chain_id}. Realigning the template to the actual sequence."
)
logging.warning(warning)
# This throws an exception if it fails to realign the hit.
seqres, mapping = _realign_pdb_template_to_query(
......@@ -531,9 +594,15 @@ def _extract_template_features(
template_chain_id=template_chain_id,
mmcif_object=mmcif_object,
old_mapping=mapping,
kalign_binary_path=kalign_binary_path)
logging.info('Sequence in %s_%s: %s successfully realigned to %s',
pdb_id, chain_id, template_sequence, seqres)
kalign_binary_path=kalign_binary_path,
)
logging.info(
"Sequence in %s_%s: %s successfully realigned to %s",
pdb_id,
chain_id,
template_sequence,
seqres,
)
# The template sequence changed.
template_sequence = seqres
# No mapping offset, the query is aligned to the actual sequence.
......@@ -543,13 +612,16 @@ def _extract_template_features(
# Essentially set to infinity - we don't want to reject templates unless
# they're really, really bad.
all_atom_positions, all_atom_mask = _get_atom_positions(
mmcif_object, chain_id, max_ca_ca_distance=150.0)
mmcif_object, chain_id, max_ca_ca_distance=150.0
)
except (CaDistanceError, KeyError) as ex:
raise NoAtomDataInTemplateError(
'Could not get atom data (%s_%s): %s' % (pdb_id, chain_id, str(ex))
"Could not get atom data (%s_%s): %s" % (pdb_id, chain_id, str(ex))
) from ex
all_atom_positions = np.split(all_atom_positions, all_atom_positions.shape[0])
all_atom_positions = np.split(
all_atom_positions, all_atom_positions.shape[0]
)
all_atom_masks = np.split(all_atom_mask, all_atom_mask.shape[0])
output_templates_sequence = []
......@@ -559,9 +631,12 @@ def _extract_template_features(
for _ in query_sequence:
# Residues in the query_sequence that are not in the template_sequence:
templates_all_atom_positions.append(
np.zeros((residue_constants.atom_type_num, 3)))
templates_all_atom_masks.append(np.zeros(residue_constants.atom_type_num))
output_templates_sequence.append('-')
np.zeros((residue_constants.atom_type_num, 3))
)
templates_all_atom_masks.append(
np.zeros(residue_constants.atom_type_num)
)
output_templates_sequence.append("-")
for k, v in mapping.items():
template_index = v + mapping_offset
......@@ -572,24 +647,33 @@ def _extract_template_features(
# Alanine (the smallest AA apart from Glycine) has 5 atoms (C, CA, CB, N, O).
if np.sum(templates_all_atom_masks) < 5:
raise TemplateAtomMaskAllZerosError(
'Template all atom mask was all zeros: %s_%s. Residue range: %d-%d' %
(pdb_id, chain_id, min(mapping.values()) + mapping_offset,
max(mapping.values()) + mapping_offset))
"Template all atom mask was all zeros: %s_%s. Residue range: %d-%d"
% (
pdb_id,
chain_id,
min(mapping.values()) + mapping_offset,
max(mapping.values()) + mapping_offset,
)
)
output_templates_sequence = ''.join(output_templates_sequence)
output_templates_sequence = "".join(output_templates_sequence)
templates_aatype = residue_constants.sequence_to_onehot(
output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID)
output_templates_sequence, residue_constants.HHBLITS_AA_TO_ID
)
return (
{
'template_all_atom_positions': np.array(templates_all_atom_positions),
'template_all_atom_mask': np.array(templates_all_atom_masks),
'template_sequence': output_templates_sequence.encode(),
'template_aatype': np.array(templates_aatype),
'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
"template_all_atom_positions": np.array(
templates_all_atom_positions
),
"template_all_atom_mask": np.array(templates_all_atom_masks),
"template_sequence": output_templates_sequence.encode(),
"template_aatype": np.array(templates_aatype),
"template_domain_names": f"{pdb_id.lower()}_{chain_id}".encode(),
},
warning)
warning,
)
def _build_query_to_hit_index_mapping(
......@@ -597,7 +681,8 @@ def _build_query_to_hit_index_mapping(
hit_sequence: str,
indices_hit: Sequence[int],
indices_query: Sequence[int],
original_query_sequence: str) -> Mapping[int, int]:
original_query_sequence: str,
) -> Mapping[int, int]:
"""Gets mapping from indices in original query sequence to indices in the hit.
hit_query_sequence and hit_sequence are two aligned sequences containing gap
......@@ -624,15 +709,15 @@ def _build_query_to_hit_index_mapping(
return {}
# Remove gaps and find the offset of hit.query relative to original query.
hhsearch_query_sequence = hit_query_sequence.replace('-', '')
hit_sequence = hit_sequence.replace('-', '')
hhsearch_query_offset = original_query_sequence.find(hhsearch_query_sequence)
hhsearch_query_sequence = hit_query_sequence.replace("-", "")
hit_sequence = hit_sequence.replace("-", "")
hhsearch_query_offset = original_query_sequence.find(
hhsearch_query_sequence
)
# An index of -1 marks gap characters. Subtract the min index, ignoring gaps.
min_idx = min(x for x in indices_hit if x > -1)
fixed_indices_hit = [
x - min_idx if x > -1 else -1 for x in indices_hit
]
fixed_indices_hit = [x - min_idx if x > -1 else -1 for x in indices_hit]
min_idx = min(x for x in indices_query if x > -1)
fixed_indices_query = [x - min_idx if x > -1 else -1 for x in indices_query]
......@@ -641,8 +726,9 @@ def _build_query_to_hit_index_mapping(
mapping = {}
for q_i, q_t in zip(fixed_indices_query, fixed_indices_hit):
if q_t != -1 and q_i != -1:
if (q_t >= len(hit_sequence) or
q_i + hhsearch_query_offset >= len(original_query_sequence)):
if q_t >= len(hit_sequence) or q_i + hhsearch_query_offset >= len(
original_query_sequence
):
continue
mapping[q_i + hhsearch_query_offset] = q_t
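# A hypothetical worked example of the mapping built above. The index lists
# are invented; -1 marks gap columns in the HHSearch alignment.
indices_query = [0, 1, -1, 2, 3]   # query side of the alignment
indices_hit = [5, 6, -1, 7, 8]     # hit side, offset into the hit sequence
min_q = min(x for x in indices_query if x > -1)  # 0
min_h = min(x for x in indices_hit if x > -1)    # 5
example_mapping = {
    q - min_q: h - min_h
    for q, h in zip(indices_query, indices_hit)
    if q != -1 and h != -1
}
print(example_mapping)  # {0: 0, 1: 1, 2: 2, 3: 3}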
......@@ -665,7 +751,8 @@ def _process_single_hit(
release_dates: Mapping[str, datetime.datetime],
obsolete_pdbs: Mapping[str, str],
kalign_binary_path: str,
strict_error_check: bool = False) -> SingleHitResult:
strict_error_check: bool = False,
) -> SingleHitResult:
"""Tries to extract template features from a single HHSearch hit."""
# Fail hard if we can't get the PDB ID and chain name from the hit.
hit_pdb_code, hit_chain_id = _get_pdb_id_and_chain(hit)
......@@ -682,41 +769,56 @@ def _process_single_hit(
query_sequence=query_sequence,
query_pdb_code=query_pdb_code,
release_dates=release_dates,
release_date_cutoff=max_template_date)
release_date_cutoff=max_template_date,
)
except PrefilterError as e:
msg = f'hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}'
logging.info('%s: %s', query_pdb_code, msg)
msg = f"hit {hit_pdb_code}_{hit_chain_id} did not pass prefilter: {str(e)}"
logging.info("%s: %s", query_pdb_code, msg)
if strict_error_check and isinstance(
e, (DateError, PdbIdError, DuplicateError)):
e, (DateError, PdbIdError, DuplicateError)
):
# In strict mode we treat some prefilter cases as errors.
return SingleHitResult(features=None, error=msg, warning=None)
return SingleHitResult(features=None, error=None, warning=None)
mapping = _build_query_to_hit_index_mapping(
hit.query, hit.hit_sequence, hit.indices_hit, hit.indices_query,
query_sequence)
hit.query,
hit.hit_sequence,
hit.indices_hit,
hit.indices_query,
query_sequence,
)
# The mapping is from the query to the actual hit sequence, so we need to
# remove gaps (which in any case lack a confidence score).
template_sequence = hit.hit_sequence.replace('-', '')
template_sequence = hit.hit_sequence.replace("-", "")
cif_path = os.path.join(mmcif_dir, hit_pdb_code + '.cif')
logging.info('Reading PDB entry from %s. Query: %s, template: %s',
cif_path, query_sequence, template_sequence)
cif_path = os.path.join(mmcif_dir, hit_pdb_code + ".cif")
logging.info(
"Reading PDB entry from %s. Query: %s, template: %s",
cif_path,
query_sequence,
template_sequence,
)
# Fail if we can't find the mmCIF file.
with open(cif_path, 'r') as cif_file:
with open(cif_path, "r") as cif_file:
cif_string = cif_file.read()
parsing_result = mmcif_parsing.parse(
file_id=hit_pdb_code, mmcif_string=cif_string)
file_id=hit_pdb_code, mmcif_string=cif_string
)
if parsing_result.mmcif_object is not None:
hit_release_date = datetime.datetime.strptime(
parsing_result.mmcif_object.header['release_date'], '%Y-%m-%d')
parsing_result.mmcif_object.header["release_date"], "%Y-%m-%d"
)
if hit_release_date > max_template_date:
error = ('Template %s date (%s) > max template date (%s).' %
(hit_pdb_code, hit_release_date, max_template_date))
error = "Template %s date (%s) > max template date (%s)." % (
hit_pdb_code,
hit_release_date,
max_template_date,
)
if strict_error_check:
return SingleHitResult(features=None, error=error, warning=None)
else:
......@@ -731,31 +833,52 @@ def _process_single_hit(
template_sequence=template_sequence,
query_sequence=query_sequence,
template_chain_id=hit_chain_id,
kalign_binary_path=kalign_binary_path)
features['template_sum_probs'] = [hit.sum_probs]
kalign_binary_path=kalign_binary_path,
)
features["template_sum_probs"] = [hit.sum_probs]
# It is possible there were some errors when parsing the other chains in the
# mmCIF file, but the template features for the chain we want were still
# computed. In that case, the mmCIF parsing errors are not relevant.
return SingleHitResult(
features=features, error=None, warning=realign_warning)
except (NoChainsError, NoAtomDataInTemplateError,
TemplateAtomMaskAllZerosError) as e:
features=features, error=None, warning=realign_warning
)
except (
NoChainsError,
NoAtomDataInTemplateError,
TemplateAtomMaskAllZerosError,
) as e:
# These 3 errors indicate missing mmCIF experimental data rather than a
# problem with the template search, so turn them into warnings.
warning = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
'%s, mmCIF parsing errors: %s'
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
str(e), parsing_result.errors))
warning = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
"%s, mmCIF parsing errors: %s"
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
if strict_error_check:
return SingleHitResult(features=None, error=warning, warning=None)
else:
return SingleHitResult(features=None, error=None, warning=warning)
except Error as e:
error = ('%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: '
'%s, mmCIF parsing errors: %s'
% (hit_pdb_code, hit_chain_id, hit.sum_probs, hit.index,
str(e), parsing_result.errors))
error = (
"%s_%s (sum_probs: %.2f, rank: %d): feature extracting errors: "
"%s, mmCIF parsing errors: %s"
% (
hit_pdb_code,
hit_chain_id,
hit.sum_probs,
hit.index,
str(e),
parsing_result.errors,
)
)
return SingleHitResult(features=None, error=error, warning=None)
......@@ -777,7 +900,8 @@ class TemplateHitFeaturizer:
kalign_binary_path: str,
release_dates_path: Optional[str],
obsolete_pdbs_path: Optional[str],
strict_error_check: bool = False):
strict_error_check: bool = False,
):
"""Initializes the Template Search.
Args:
......@@ -802,28 +926,34 @@ class TemplateHitFeaturizer:
* Any feature computation errors.
"""
self._mmcif_dir = mmcif_dir
if not glob.glob(os.path.join(self._mmcif_dir, '*.cif')):
logging.error('Could not find CIFs in %s', self._mmcif_dir)
raise ValueError(f'Could not find CIFs in {self._mmcif_dir}')
if not glob.glob(os.path.join(self._mmcif_dir, "*.cif")):
logging.error("Could not find CIFs in %s", self._mmcif_dir)
raise ValueError(f"Could not find CIFs in {self._mmcif_dir}")
try:
self._max_template_date = datetime.datetime.strptime(
max_template_date, '%Y-%m-%d')
max_template_date, "%Y-%m-%d"
)
except ValueError:
raise ValueError(
'max_template_date must be set and have format YYYY-MM-DD.')
"max_template_date must be set and have format YYYY-MM-DD."
)
self._max_hits = max_hits
self._kalign_binary_path = kalign_binary_path
self._strict_error_check = strict_error_check
if release_dates_path:
logging.info('Using precomputed release dates %s.', release_dates_path)
logging.info(
"Using precomputed release dates %s.", release_dates_path
)
self._release_dates = _parse_release_dates(release_dates_path)
else:
self._release_dates = {}
if obsolete_pdbs_path:
logging.info('Using precomputed obsolete pdbs %s.', obsolete_pdbs_path)
logging.info(
"Using precomputed obsolete pdbs %s.", obsolete_pdbs_path
)
self._obsolete_pdbs = _parse_obsolete(obsolete_pdbs_path)
else:
self._obsolete_pdbs = {}
......@@ -833,9 +963,10 @@ class TemplateHitFeaturizer:
query_sequence: str,
query_pdb_code: Optional[str],
query_release_date: Optional[datetime.datetime],
hits: Sequence[parsers.TemplateHit]) -> TemplateSearchResult:
hits: Sequence[parsers.TemplateHit],
) -> TemplateSearchResult:
"""Computes the templates for given query sequence (more details above)."""
logging.info('Searching for template for: %s', query_pdb_code)
logging.info("Searching for template for: %s", query_pdb_code)
template_features = {}
for template_feature_name in TEMPLATE_FEATURES:
......@@ -869,7 +1000,8 @@ class TemplateHitFeaturizer:
release_dates=self._release_dates,
obsolete_pdbs=self._obsolete_pdbs,
strict_error_check=self._strict_error_check,
kalign_binary_path=self._kalign_binary_path)
kalign_binary_path=self._kalign_binary_path,
)
if result.error:
errors.append(result.error)
......@@ -880,8 +1012,12 @@ class TemplateHitFeaturizer:
warnings.append(result.warning)
if result.features is None:
logging.info('Skipped invalid hit %s, error: %s, warning: %s',
hit.name, result.error, result.warning)
logging.info(
"Skipped invalid hit %s, error: %s, warning: %s",
hit.name,
result.error,
result.warning,
)
else:
# Increment the hit counter, since we got features out of this hit.
num_hits += 1
......@@ -891,10 +1027,14 @@ class TemplateHitFeaturizer:
for name in template_features:
if num_hits > 0:
template_features[name] = np.stack(
template_features[name], axis=0).astype(TEMPLATE_FEATURES[name])
template_features[name], axis=0
).astype(TEMPLATE_FEATURES[name])
else:
# Make sure the feature has correct dtype even if empty.
template_features[name] = np.array([], dtype=TEMPLATE_FEATURES[name])
template_features[name] = np.array(
[], dtype=TEMPLATE_FEATURES[name]
)
return TemplateSearchResult(
features=template_features, errors=errors, warnings=warnings)
features=template_features, errors=errors, warnings=warnings
)
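# A minimal sketch of the dtype handling above: with no accepted hits, each
# feature still comes back as an empty array with the dtype recorded in
# TEMPLATE_FEATURES. The feature name below is from the code; the dtype is
# an assumption for illustration.
import numpy as np

toy_template_features = {"template_sum_probs": []}
toy_dtypes = {"template_sum_probs": np.float32}
for name, rows in toy_template_features.items():
    if rows:
        toy_template_features[name] = np.stack(rows, axis=0).astype(
            toy_dtypes[name]
        )
    else:
        toy_template_features[name] = np.array([], dtype=toy_dtypes[name])
print(toy_template_features["template_sum_probs"].dtype)  # float32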
......@@ -30,7 +30,8 @@ _HHBLITS_DEFAULT_Z = 500
class HHBlits:
"""Python wrapper of the HHblits binary."""
def __init__(self,
def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
......@@ -44,7 +45,8 @@ class HHBlits:
all_seqs: bool = False,
alt: Optional[int] = None,
p: int = _HHBLITS_DEFAULT_P,
z: int = _HHBLITS_DEFAULT_Z):
z: int = _HHBLITS_DEFAULT_Z,
):
"""Initializes the Python HHblits wrapper.
Args:
......@@ -77,9 +79,13 @@ class HHBlits:
self.databases = databases
for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHBlits database %s', database_path)
raise ValueError(f'Could not find HHBlits database {database_path}')
if not glob.glob(database_path + "_*"):
logging.error(
"Could not find HHBlits database %s", database_path
)
raise ValueError(
f"Could not find HHBlits database {database_path}"
)
self.n_cpu = n_cpu
self.n_iter = n_iter
......@@ -95,52 +101,66 @@ class HHBlits:
def query(self, input_fasta_path: str) -> Mapping[str, Any]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, "output.a3m")
db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append("-d")
db_cmd.append(db_path)
cmd = [
self.binary_path,
'-i', input_fasta_path,
'-cpu', str(self.n_cpu),
'-oa3m', a3m_path,
'-o', '/dev/null',
'-n', str(self.n_iter),
'-e', str(self.e_value),
'-maxseq', str(self.maxseq),
'-realign_max', str(self.realign_max),
'-maxfilt', str(self.maxfilt),
'-min_prefilter_hits', str(self.min_prefilter_hits)]
"-i",
input_fasta_path,
"-cpu",
str(self.n_cpu),
"-oa3m",
a3m_path,
"-o",
"/dev/null",
"-n",
str(self.n_iter),
"-e",
str(self.e_value),
"-maxseq",
str(self.maxseq),
"-realign_max",
str(self.realign_max),
"-maxfilt",
str(self.maxfilt),
"-min_prefilter_hits",
str(self.min_prefilter_hits),
]
if self.all_seqs:
cmd += ['-all']
cmd += ["-all"]
if self.alt:
cmd += ['-alt', str(self.alt)]
cmd += ["-alt", str(self.alt)]
if self.p != _HHBLITS_DEFAULT_P:
cmd += ['-p', str(self.p)]
cmd += ["-p", str(self.p)]
if self.z != _HHBLITS_DEFAULT_Z:
cmd += ['-Z', str(self.z)]
cmd += ["-Z", str(self.z)]
cmd += db_cmd
logging.info('Launching subprocess "%s"', ' '.join(cmd))
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing('HHblits query'):
with utils.timing("HHblits query"):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Logs have a 15k character limit, so log HHblits error line by line.
logging.error('HHblits failed. HHblits stderr begin:')
for error_line in stderr.decode('utf-8').splitlines():
logging.error("HHblits failed. HHblits stderr begin:")
for error_line in stderr.decode("utf-8").splitlines():
if error_line.strip():
logging.error(error_line.strip())
logging.error('HHblits stderr end')
raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
logging.error("HHblits stderr end")
raise RuntimeError(
"HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
)
with open(a3m_path) as f:
a3m = f.read()
......@@ -150,5 +170,6 @@ class HHBlits:
output=stdout,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)
e_value=self.e_value,
)
return raw_output
......@@ -26,12 +26,14 @@ from openfold.data.np import utils
class HHSearch:
"""Python wrapper of the HHsearch binary."""
def __init__(self,
def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 2,
maxseq: int = 1_000_000):
maxseq: int = 1_000_000,
):
"""Initializes the Python HHsearch wrapper.
Args:
......@@ -52,41 +54,52 @@ class HHSearch:
self.maxseq = maxseq
for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHsearch database %s', database_path)
raise ValueError(f'Could not find HHsearch database {database_path}')
if not glob.glob(database_path + "_*"):
logging.error(
"Could not find HHsearch database %s", database_path
)
raise ValueError(
f"Could not find HHsearch database {database_path}"
)
def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, 'query.a3m')
hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
with open(input_path, 'w') as f:
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, "query.a3m")
hhr_path = os.path.join(query_tmp_dir, "output.hhr")
with open(input_path, "w") as f:
f.write(a3m)
db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append("-d")
db_cmd.append(db_path)
cmd = [self.binary_path,
'-i', input_path,
'-o', hhr_path,
'-maxseq', str(self.maxseq),
'-cpu', str(self.n_cpu),
cmd = [
self.binary_path,
"-i",
input_path,
"-o",
hhr_path,
"-maxseq",
str(self.maxseq),
"-cpu",
str(self.n_cpu),
] + db_cmd
logging.info('Launching subprocess "%s"', ' '.join(cmd))
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing('HHsearch query'):
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing("HHsearch query"):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Stderr is truncated to prevent proto size errors in Beam.
raise RuntimeError(
'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
"HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
)
with open(hhr_path) as f:
hhr = f.read()
......
......@@ -29,7 +29,8 @@ from openfold.data.tools import utils
class Jackhmmer:
"""Python wrapper of the Jackhmmer binary."""
def __init__(self,
def __init__(
self,
*,
binary_path: str,
database_path: str,
......@@ -44,7 +45,8 @@ class Jackhmmer:
incdom_e: Optional[float] = None,
dom_e: Optional[float] = None,
num_streamed_chunks: Optional[int] = None,
streaming_callback: Optional[Callable[[int], None]] = None):
streaming_callback: Optional[Callable[[int], None]] = None,
):
"""Initializes the Python Jackhmmer wrapper.
Args:
......@@ -69,9 +71,14 @@ class Jackhmmer:
self.database_path = database_path
self.num_streamed_chunks = num_streamed_chunks
if not os.path.exists(self.database_path) and num_streamed_chunks is None:
logging.error('Could not find Jackhmmer database %s', database_path)
raise ValueError(f'Could not find Jackhmmer database {database_path}')
if (
not os.path.exists(self.database_path)
and num_streamed_chunks is None
):
logging.error("Could not find Jackhmmer database %s", database_path)
raise ValueError(
f"Could not find Jackhmmer database {database_path}"
)
self.n_cpu = n_cpu
self.n_iter = n_iter
......@@ -85,11 +92,12 @@ class Jackhmmer:
self.get_tblout = get_tblout
self.streaming_callback = streaming_callback
def _query_chunk(self, input_fasta_path: str, database_path: str
def _query_chunk(
self, input_fasta_path: str, database_path: str
) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, 'output.sto')
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, "output.sto")
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# stages (which get progressively more expensive), reducing these
......@@ -98,48 +106,63 @@ class Jackhmmer:
# amount of time.
cmd_flags = [
# Don't pollute stdout with Jackhmmer output.
'-o', '/dev/null',
'-A', sto_path,
'--noali',
'--F1', str(self.filter_f1),
'--F2', str(self.filter_f2),
'--F3', str(self.filter_f3),
'--incE', str(self.e_value),
"-o",
"/dev/null",
"-A",
sto_path,
"--noali",
"--F1",
str(self.filter_f1),
"--F2",
str(self.filter_f2),
"--F3",
str(self.filter_f3),
"--incE",
str(self.e_value),
# Report only sequences with E-values <= x in per-sequence output.
'-E', str(self.e_value),
'--cpu', str(self.n_cpu),
'-N', str(self.n_iter)
"-E",
str(self.e_value),
"--cpu",
str(self.n_cpu),
"-N",
str(self.n_iter),
]
if self.get_tblout:
tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
cmd_flags.extend(['--tblout', tblout_path])
tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
cmd_flags.extend(["--tblout", tblout_path])
if self.z_value:
cmd_flags.extend(['-Z', str(self.z_value)])
cmd_flags.extend(["-Z", str(self.z_value)])
if self.dom_e is not None:
cmd_flags.extend(['--domE', str(self.dom_e)])
cmd_flags.extend(["--domE", str(self.dom_e)])
if self.incdom_e is not None:
cmd_flags.extend(['--incdomE', str(self.incdom_e)])
cmd_flags.extend(["--incdomE", str(self.incdom_e)])
cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
database_path]
cmd = (
[self.binary_path]
+ cmd_flags
+ [input_fasta_path, database_path]
)
logging.info('Launching subprocess "%s"', ' '.join(cmd))
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing(
f'Jackhmmer ({os.path.basename(database_path)}) query'):
f"Jackhmmer ({os.path.basename(database_path)}) query"
):
_, stderr = process.communicate()
retcode = process.wait()
if retcode:
raise RuntimeError(
'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
"Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
)
# Get e-values for each target name
tbl = ''
tbl = ""
if self.get_tblout:
with open(tblout_path) as f:
tbl = f.read()
......@@ -152,7 +175,8 @@ class Jackhmmer:
tbl=tbl,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)
e_value=self.e_value,
)
return raw_output
......@@ -162,15 +186,15 @@ class Jackhmmer:
return [self._query_chunk(input_fasta_path, self.database_path)]
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"
# Remove existing files to prevent OOM
for f in glob.glob(db_local_chunk('[0-9]*')):
for f in glob.glob(db_local_chunk("[0-9]*")):
try:
os.remove(f)
except OSError:
print(f'OSError while deleting {f}')
print(f"OSError while deleting {f}")
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with futures.ThreadPoolExecutor(max_workers=2) as executor:
......@@ -179,15 +203,22 @@ class Jackhmmer:
# Copy the chunk locally
if i == 1:
future = executor.submit(
request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
request.urlretrieve,
db_remote_chunk(i),
db_local_chunk(i),
)
if i < self.num_streamed_chunks:
next_future = executor.submit(
request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))
request.urlretrieve,
db_remote_chunk(i + 1),
db_local_chunk(i + 1),
)
# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
self._query_chunk(input_fasta_path, db_local_chunk(i)))
self._query_chunk(input_fasta_path, db_local_chunk(i))
)
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
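# A simplified, standalone sketch of the overlap pattern above: chunk i+1 is
# fetched while chunk i is being queried. fetch() and run_query() are
# invented stand-ins for request.urlretrieve and _query_chunk.
import time
from concurrent import futures

def fetch(idx):
    time.sleep(0.1)          # pretend to download chunk idx
    return f"chunk-{idx}"

def run_query(chunk):
    time.sleep(0.2)          # pretend to run Jackhmmer on the chunk
    return f"result({chunk})"

num_chunks = 3
outputs = []
with futures.ThreadPoolExecutor(max_workers=2) as executor:
    future = executor.submit(fetch, 1)
    for i in range(1, num_chunks + 1):
        next_future = executor.submit(fetch, i + 1) if i < num_chunks else None
        outputs.append(run_query(future.result()))
        future = next_future
print(outputs)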
......
......@@ -25,12 +25,12 @@ from openfold.data.tools import utils
def _to_a3m(sequences: Sequence[str]) -> str:
"""Converts sequences to an a3m file."""
names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
a3m = []
for sequence, name in zip(sequences, names):
a3m.append(u'>' + name + u'\n')
a3m.append(sequence + u'\n')
return ''.join(a3m)
a3m.append(u">" + name + u"\n")
a3m.append(sequence + u"\n")
return "".join(a3m)
class Kalign:
......@@ -63,40 +63,51 @@ class Kalign:
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
"""
logging.info('Aligning %d sequences', len(sequences))
logging.info("Aligning %d sequences", len(sequences))
for s in sequences:
if len(s) < 6:
raise ValueError('Kalign requires all sequences to be at least 6 '
'residues long. Got %s (%d residues).' % (s, len(s)))
raise ValueError(
"Kalign requires all sequences to be at least 6 "
"residues long. Got %s (%d residues)." % (s, len(s))
)
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
with open(input_fasta_path, 'w') as f:
with open(input_fasta_path, "w") as f:
f.write(_to_a3m(sequences))
cmd = [
self.binary_path,
'-i', input_fasta_path,
'-o', output_a3m_path,
'-format', 'fasta',
"-i",
input_fasta_path,
"-o",
output_a3m_path,
"-format",
"fasta",
]
logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing('Kalign query'):
with utils.timing("Kalign query"):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
stdout.decode('utf-8'), stderr.decode('utf-8'))
logging.info(
"Kalign stdout:\n%s\n\nstderr:\n%s\n",
stdout.decode("utf-8"),
stderr.decode("utf-8"),
)
if retcode:
raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
% (stdout.decode('utf-8'), stderr.decode('utf-8')))
raise RuntimeError(
"Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr.decode("utf-8"))
)
with open(output_a3m_path) as f:
a3m = f.read()
......
......@@ -35,11 +35,11 @@ def tmpdir_manager(base_dir: Optional[str] = None):
@contextlib.contextmanager
def timing(msg: str):
logging.info('Started %s', msg)
logging.info("Started %s", msg)
tic = time.time()
yield
toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic)
logging.info("Finished %s in %.3f seconds", msg, toc - tic)
def to_date(s: str):
......
......@@ -3,13 +3,14 @@ import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [os.path.basename(f)[:-3] for f in _files if os.path.isfile(f) and not f.endswith("__init__.py")]
_modules = [(m, importlib.import_module('.' + m, __name__)) for m in __all__]
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
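# A runnable sketch of the name-derivation step above, with invented paths
# standing in for a real package directory (the isfile check is dropped):
import os

toy_files = ["/pkg/foo.py", "/pkg/bar.py", "/pkg/__init__.py"]
toy_all = [
    os.path.basename(f)[:-3]
    for f in toy_files
    if not f.endswith("__init__.py")
]
print(toy_all)  # ['foo', 'bar']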
......@@ -26,6 +26,7 @@ class Dropout(nn.Module):
If not in training mode, this module computes the identity function.
"""
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
"""
Args:
......@@ -37,7 +38,7 @@ class Dropout(nn.Module):
super(Dropout, self).__init__()
self.r = r
if(type(batch_dim) == int):
if type(batch_dim) == int:
batch_dim = [batch_dim]
self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r)
......@@ -50,7 +51,7 @@ class Dropout(nn.Module):
compatible with self.batch_dim
"""
shape = list(x.shape)
if(self.batch_dim is not None):
if self.batch_dim is not None:
for bd in self.batch_dim:
shape[bd] = 1
mask = x.new_ones(shape)
......@@ -64,6 +65,7 @@ class DropoutRowwise(Dropout):
Convenience class for rowwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
......@@ -72,4 +74,5 @@ class DropoutColumnwise(Dropout):
Convenience class for columnwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
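# A standalone sketch of the shared-mask idea above: with batch_dim=-3
# (rowwise), a single Bernoulli mask of shape [1, N_res, C] is broadcast
# across all MSA rows, so every row drops the same positions. Shapes are
# invented and the Dropout module itself is bypassed.
import torch

x = torch.ones(4, 8, 16)                 # [N_seq, N_res, C]
shape = list(x.shape)
shape[-3] = 1                            # collapse the row (sequence) dim
mask = torch.nn.functional.dropout(x.new_ones(shape), p=0.25)
out = x * mask                           # identical zero pattern in each row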
......@@ -27,6 +27,7 @@ class InputEmbedder(nn.Module):
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
......@@ -67,9 +68,7 @@ class InputEmbedder(nn.Module):
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self,
ri: torch.Tensor
):
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
......@@ -86,7 +85,8 @@ class InputEmbedder(nn.Module):
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
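# A standalone sketch of relpos (Algorithm 4), approximating the one_hot
# call above with clamping; k and the toy residue_index are invented, and
# the final Linear projection is omitted.
import torch

k = 2
ri = torch.arange(5)                               # toy residue_index
d = ri[..., :, None] - ri[..., None, :]            # [N, N] pairwise offsets
d = torch.clamp(d, -k, k) + k                      # shift into [0, 2k]
oh = torch.nn.functional.one_hot(d, num_classes=2 * k + 1).float()
print(oh.shape)  # torch.Size([5, 5, 5]), i.e. [N_res, N_res, no_bins]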
def forward(self,
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
......@@ -132,14 +132,16 @@ class RecyclingEmbedder(nn.Module):
Implements Algorithm 32.
"""
def __init__(self,
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs
**kwargs,
):
"""
Args:
......@@ -169,7 +171,8 @@ class RecyclingEmbedder(nn.Module):
self.layer_norm_m = nn.LayerNorm(self.c_m)
self.layer_norm_z = nn.LayerNorm(self.c_z)
def forward(self,
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
......@@ -188,13 +191,13 @@ class RecyclingEmbedder(nn.Module):
z:
[*, N_res, N_res, C_z] pair embedding update
"""
if(self.bins is None):
if self.bins is None:
self.bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device
device=x.device,
)
# [*, N, C_m]
......@@ -205,15 +208,10 @@ class RecyclingEmbedder(nn.Module):
# couldn't find in time.
squared_bins = self.bins ** 2
upper = torch.cat(
[
squared_bins[1:],
squared_bins.new_tensor([self.inf])
], dim=-1
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2,
dim=-1,
keepdims=True
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
......@@ -232,7 +230,9 @@ class TemplateAngleEmbedder(nn.Module):
Implements Algorithm 2, line 7.
"""
def __init__(self,
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
......@@ -253,9 +253,7 @@ class TemplateAngleEmbedder(nn.Module):
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
......@@ -275,7 +273,9 @@ class TemplatePairEmbedder(nn.Module):
Implements Algorithm 2, line 9.
"""
def __init__(self,
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
......@@ -295,7 +295,8 @@ class TemplatePairEmbedder(nn.Module):
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, init="relu")
def forward(self,
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
......@@ -316,7 +317,9 @@ class ExtraMSAEmbedder(nn.Module):
Implements Algorithm 2, line 15
"""
def __init__(self,
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
......@@ -335,9 +338,7 @@ class ExtraMSAEmbedder(nn.Module):
self.linear = Linear(self.c_in, self.c_out)
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
......
......@@ -45,6 +45,7 @@ class MSATransition(nn.Module):
Implements Algorithm 9
"""
def __init__(self, c_m, n, chunk_size):
"""
Args:
......@@ -71,7 +72,8 @@ class MSATransition(nn.Module):
m = self.linear_2(m) * mask
return m
def forward(self,
def forward(
self,
m: torch.Tensor,
mask: torch.Tensor = None,
) -> torch.Tensor:
......@@ -86,7 +88,7 @@ class MSATransition(nn.Module):
[*, N_seq, N_res, C_m] MSA activation update
"""
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
if(mask is None):
if mask is None:
mask = m.new_ones(m.shape[:-1])
mask = mask.unsqueeze(-1)
......@@ -94,7 +96,7 @@ class MSATransition(nn.Module):
m = self.layer_norm(m)
inp = {"m": m, "mask": mask}
if(self.chunk_size is not None):
if self.chunk_size is not None:
m = chunk_layer(
self._transition,
inp,
......@@ -108,7 +110,8 @@ class MSATransition(nn.Module):
class EvoformerBlock(nn.Module):
def __init__(self,
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
......@@ -136,7 +139,7 @@ class EvoformerBlock(nn.Module):
inf=inf,
)
if(_is_extra_msa_stack):
if _is_extra_msa_stack:
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
......@@ -201,7 +204,8 @@ class EvoformerBlock(nn.Module):
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(self,
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
......@@ -233,7 +237,9 @@ class EvoformerStack(nn.Module):
Implements Algorithm 6.
"""
def __init__(self,
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
......@@ -313,10 +319,11 @@ class EvoformerStack(nn.Module):
)
self.blocks.append(block)
if(not self._is_extra_msa_stack):
if not self._is_extra_msa_stack:
self.linear = Linear(c_m, c_s)
def forward(self,
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
......@@ -348,14 +355,15 @@ class EvoformerStack(nn.Module):
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=_mask_trans,
) for b in self.blocks
)
for b in self.blocks
],
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = None
if(not self._is_extra_msa_stack):
if not self._is_extra_msa_stack:
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
......@@ -368,7 +376,9 @@ class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
......@@ -411,12 +421,13 @@ class ExtraMSAStack(nn.Module):
_is_extra_msa_stack=True,
)
def forward(self,
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
......@@ -436,6 +447,6 @@ class ExtraMSAStack(nn.Module):
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=_mask_trans
_mask_trans=_mask_trans,
)
return z
......@@ -44,7 +44,7 @@ class AuxiliaryHeads(nn.Module):
**config["experimentally_resolved"],
)
if(config.tm.enabled):
if config.tm.enabled:
self.tm = TMScoreHead(
**config.tm,
)
......@@ -68,19 +68,22 @@ class AuxiliaryHeads(nn.Module):
experimentally_resolved_logits = self.experimentally_resolved(
outputs["single"]
)
aux_out["experimentally_resolved_logits"] = (
experimentally_resolved_logits
)
aux_out[
"experimentally_resolved_logits"
] = experimentally_resolved_logits
if(self.config.tm.enabled):
if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
tm_logits, **self.config.tm
)
aux_out.update(compute_predicted_aligned_error(
tm_logits, **self.config.tm,
))
aux_out.update(
compute_predicted_aligned_error(
tm_logits,
**self.config.tm,
)
)
return aux_out
......@@ -118,6 +121,7 @@ class DistogramHead(nn.Module):
For use in computation of distogram loss, subsection 1.9.8
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
......@@ -133,9 +137,7 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self,
z # [*, N, N, C_z]
):
def forward(self, z): # [*, N, N, C_z]
"""
Args:
z:
......@@ -153,6 +155,7 @@ class TMScoreHead(nn.Module):
"""
For use in computation of TM-score, subsection 1.9.7
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
......@@ -185,6 +188,7 @@ class MaskedMSAHead(nn.Module):
"""
For use in computation of masked MSA loss, subsection 1.9.9
"""
def __init__(self, c_m, c_out, **kwargs):
"""
Args:
......@@ -218,6 +222,7 @@ class ExperimentallyResolvedHead(nn.Module):
For use in computation of "experimentally resolved" loss, subsection
1.9.10
"""
def __init__(self, c_s, c_out, **kwargs):
"""
Args:
......
......@@ -54,6 +54,7 @@ class AlphaFold(nn.Module):
Implements Algorithm 2 (but with training).
"""
def __init__(self, config):
"""
Args:
......@@ -115,7 +116,7 @@ class AlphaFold(nn.Module):
)
single_template_embeds = {}
if(self.config.template.embed_angles):
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
......@@ -130,18 +131,18 @@ class AlphaFold(nn.Module):
single_template_feats,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram
**self.config.template.distogram,
)
t = self.template_pair_embedder(t)
t = self.template_pair_stack(
t,
pair_mask.unsqueeze(-3),
_mask_trans=self.config._mask_trans
t, pair_mask.unsqueeze(-3), _mask_trans=self.config._mask_trans
)
single_template_embeds.update({
single_template_embeds.update(
{
"pair": t,
})
}
)
template_embeds.append(single_template_embeds)
......@@ -152,19 +153,19 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
t = self.template_pointwise_att(
template_embeds["pair"],
z,
template_mask=batch["template_mask"]
template_embeds["pair"], z, template_mask=batch["template_mask"]
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if(self.config.template.embed_angles):
if self.config.template.embed_angles:
ret["template_angle_embedding"] = template_embeds["angle"]
ret.update({
ret.update(
{
"template_pair_embedding": t,
})
}
)
return ret
......@@ -195,9 +196,9 @@ class AlphaFold(nn.Module):
)
# Inject information from previous recycling iterations
if(self.config.num_recycle > 0):
if self.config.num_recycle > 0:
# Initialize the recycling embeddings, if need be
if(None in [m_1_prev, z_prev, x_prev]):
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m),
......@@ -213,11 +214,7 @@ class AlphaFold(nn.Module):
(*batch_dims, n, residue_constants.atom_type_num, 3),
)
x_prev = pseudo_beta_fn(
feats["aatype"],
x_prev,
None
)
x_prev = pseudo_beta_fn(feats["aatype"], x_prev, None)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
......@@ -237,9 +234,9 @@ class AlphaFold(nn.Module):
del m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
if self.config.template.enabled:
template_feats = {
k:v for k,v in feats.items() if k.startswith("template_")
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
......@@ -251,11 +248,10 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if(self.config.template.embed_angles):
if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
[m, template_embeds["template_angle_embedding"]], dim=-3
)
# [*, S, N]
......@@ -265,7 +261,7 @@ class AlphaFold(nn.Module):
)
# Embed extra MSA features + merge with pairwise embeddings
if(self.config.extra_msa.enabled):
if self.config.extra_msa.enabled:
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
......@@ -287,7 +283,7 @@ class AlphaFold(nn.Module):
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans
_mask_trans=self.config._mask_trans,
)
outputs["msa"] = m[..., :n_seq, :, :]
......@@ -296,7 +292,10 @@ class AlphaFold(nn.Module):
# Predict 3D structure
outputs["sm"] = self.structure_module(
s, z, feats["aatype"], mask=feats["seq_mask"],
s,
z,
feats["aatype"],
mask=feats["seq_mask"],
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
......@@ -397,16 +396,19 @@ class AlphaFold(nn.Module):
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == self.config.num_recycle)
is_final_iter = cycle_no == self.config.num_recycle
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug discussed in pytorch issue #65766
if(is_final_iter):
if is_final_iter:
self._enable_activation_checkpointing()
if(torch.is_autocast_enabled()):
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev,
feats,
m_1_prev,
z_prev,
x_prev,
)
# Run auxiliary heads
......
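# A minimal sketch of the grad gating above: early recycling passes run
# without autograd, and only the final pass builds a graph. The one-parameter
# toy_iteration is a stand-in, not the real model.
import torch

w = torch.nn.Parameter(torch.tensor(2.0))

def toy_iteration(x):
    return x * w

out = torch.ones(3)
num_recycle = 3
for cycle_no in range(num_recycle + 1):
    is_final_iter = cycle_no == num_recycle
    with torch.set_grad_enabled(is_final_iter):
        out = toy_iteration(out)

out.sum().backward()   # gradient flows through the final pass only
print(w.grad)          # tensor(24.): 3 elements, each entering the pass as 8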
......@@ -27,7 +27,8 @@ from openfold.utils.tensor_utils import (
class MSAAttention(nn.Module):
def __init__(self,
def __init__(
self,
c_in,
c_hidden,
no_heads,
......@@ -64,17 +65,14 @@ class MSAAttention(nn.Module):
self.layer_norm_m = nn.LayerNorm(self.c_in)
if(self.pair_bias):
if self.pair_bias:
self.layer_norm_z = nn.LayerNorm(self.c_z)
self.linear_z = Linear(
self.c_z, self.no_heads, bias=False, init="normal"
)
self.mha = Attention(
self.c_in, self.c_in, self.c_in,
self.c_hidden,
self.no_heads
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
)
def forward(self, m, z=None, mask=None):
......@@ -92,7 +90,7 @@ class MSAAttention(nn.Module):
m = self.layer_norm_m(m)
n_seq, n_res = m.shape[-3:-1]
if(mask is None):
if mask is None:
# [*, N_seq, N_res]
mask = m.new_ones(
m.shape[:-3] + (n_seq, n_res),
......@@ -106,7 +104,7 @@ class MSAAttention(nn.Module):
((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
)
biases = [bias]
if(self.pair_bias):
if self.pair_bias:
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
......@@ -118,18 +116,13 @@ class MSAAttention(nn.Module):
biases.append(z)
mha_inputs = {
"q_x": m,
"k_x": m,
"v_x": m,
"biases": biases
}
if(self.chunk_size is not None):
mha_inputs = {"q_x": m, "k_x": m, "v_x": m, "biases": biases}
if self.chunk_size is not None:
m = chunk_layer(
self.mha,
mha_inputs,
chunk_size=self.chunk_size,
no_batch_dims=len(m.shape[:-2])
no_batch_dims=len(m.shape[:-2]),
)
else:
m = self.mha(**mha_inputs)
......@@ -141,6 +134,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
"""
Implements Algorithm 7.
"""
def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9):
"""
Args:
......@@ -170,6 +164,7 @@ class MSAColumnAttention(MSAAttention):
"""
Implements Algorithm 8.
"""
def __init__(self, c_m, c_hidden, no_heads, chunk_size=4, inf=1e9):
"""
Args:
......@@ -192,7 +187,6 @@ class MSAColumnAttention(MSAAttention):
inf=inf,
)
def forward(self, m, mask=None):
"""
Args:
......@@ -203,26 +197,21 @@ class MSAColumnAttention(MSAAttention):
"""
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
if(mask is not None):
if mask is not None:
mask = mask.transpose(-1, -2)
m = super().forward(m, mask=mask)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
if(mask is not None):
if mask is not None:
mask = mask.transpose(-1, -2)
return m
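# A standalone sketch of the transpose trick above: column attention reuses
# the row-attention machinery on an MSA whose sequence and residue dims are
# swapped, then swaps them back. Shapes are invented.
import torch

m = torch.randn(2, 5, 7, 16)       # [*, N_seq, N_res, C_in]
m_t = m.transpose(-2, -3)          # [*, N_res, N_seq, C_in]: columns as rows
assert torch.equal(m_t.transpose(-2, -3), m)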
class MSAColumnGlobalAttention(nn.Module):
def __init__(self,
c_in,
c_hidden,
no_heads,
chunk_size=4,
inf=1e9,
eps=1e-10
def __init__(
self, c_in, c_hidden, no_heads, chunk_size=4, inf=1e9, eps=1e-10
):
super(MSAColumnGlobalAttention, self).__init__()
......@@ -243,13 +232,12 @@ class MSAColumnGlobalAttention(nn.Module):
eps=eps,
)
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None
def forward(
self, m: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
if(mask is None):
if mask is None:
# [*, N_seq, N_res]
mask = torch.ones(
m.shape[:-1],
......@@ -268,12 +256,12 @@ class MSAColumnGlobalAttention(nn.Module):
"m": m,
"mask": mask,
}
if(self.chunk_size is not None):
if self.chunk_size is not None:
m = chunk_layer(
self.global_attention,
mha_input,
chunk_size=self.chunk_size,
no_batch_dims=len(m.shape[:-2])
no_batch_dims=len(m.shape[:-2]),
)
else:
m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])
......
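# A hedged sketch of the chunked evaluation used above: the layer is applied
# to slices along the batch dims and the results are re-concatenated, trading
# a little speed for much lower peak memory. chunk_layer's real signature may
# differ; chunked_apply is a simplified stand-in.
import torch

def chunked_apply(fn, x, chunk_size):
    pieces = [fn(p) for p in torch.split(x, chunk_size, dim=0)]
    return torch.cat(pieces, dim=0)

x = torch.randn(10, 8, 16)
assert torch.allclose(chunked_apply(torch.tanh, x, chunk_size=4), torch.tanh(x))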