Commit d48c052c authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add training parsers

parent eeda001c
......@@ -6,42 +6,42 @@ def set_inf(c, inf):
for k, v in c.items():
if(isinstance(v, mlc.ConfigDict)):
set_inf(v, inf)
elif(k == "inf"):
elif(k == 'inf'):
c[k] = inf
def model_config(name, train=False, low_prec=False):
c = copy.deepcopy(config)
if(name == "model_1"):
if(name == 'model_1'):
pass
elif(name == "model_2"):
elif(name == 'model_2'):
pass
elif(name == "model_3"):
elif(name == 'model_3'):
c.model.template.enabled = False
elif(name == "model_4"):
elif(name == 'model_4'):
c.model.template.enabled = False
elif(name == "model_5"):
elif(name == 'model_5'):
c.model.template.enabled = False
elif(name == "model_1_ptm"):
elif(name == 'model_1_ptm'):
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif(name == "model_2_ptm"):
elif(name == 'model_2_ptm'):
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif(name == "model_3_ptm"):
elif(name == 'model_3_ptm'):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif(name == "model_4_ptm"):
elif(name == 'model_4_ptm'):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif(name == "model_5_ptm"):
elif(name == 'model_5_ptm'):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
else:
raise ValueError("Invalid model name")
raise ValueError('Invalid model name')
if(train):
c.globals.blocks_per_ckpt = 1
......@@ -65,6 +65,9 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
num_recycle = mlc.FieldReference(3, field_type=int)
templates_enabled = mlc.FieldReference(True, field_type=bool)
embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
NUM_RES = 'num residues placeholder'
NUM_MSA_SEQ = 'msa placeholder'
......@@ -74,29 +77,7 @@ NUM_TEMPLATES = 'num templates placeholder'
config = mlc.ConfigDict({
'data': {
'common': {
'masked_msa': {
'profile_prob': 0.1,
'same_prob': 0.1,
'uniform_prob': 0.1
},
'max_extra_msa': 1024,
'msa_cluster_features': True,
'num_recycle': 3,
'reduce_msa_clusters_by_max_templates': False,
'resample_msa_in_recycling': True,
'template_features': [
'template_all_atom_positions', 'template_sum_probs',
'template_aatype', 'template_all_atom_masks',
# 'template_domain_names'
],
'unsupervised_features': [
'aatype', 'residue_index', 'msa', # 'sequence', #'domain_name',
'num_alignments', 'seq_length', 'between_segment_residues',
'deletion_matrix'
],
'use_templates': True,
},
'eval': {
'batch_modes': [('clamped', 0.9), ('unclamped', 0.1)],
'feat': {
'aatype': [NUM_RES],
'all_atom_mask': [NUM_RES, None],
......@@ -110,7 +91,7 @@ config = mlc.ConfigDict({
'atom14_gt_positions': [NUM_RES, None, None],
'atom37_atom_exists': [NUM_RES, None],
'backbone_affine_mask': [NUM_RES],
'backbone_affine_tensor': [NUM_RES, None],
'backbone_affine_tensor': [NUM_RES, None, None],
'bert_mask': [NUM_MSA_SEQ, NUM_RES],
'chi_angles': [NUM_RES, None],
'chi_mask': [NUM_RES, None],
......@@ -125,266 +106,333 @@ config = mlc.ConfigDict({
'msa_row_mask': [NUM_MSA_SEQ],
'pseudo_beta': [NUM_RES, None],
'pseudo_beta_mask': [NUM_RES],
'random_crop_to_size_seed': [None],
'residue_index': [NUM_RES],
'residx_atom14_to_atom37': [NUM_RES, None],
'residx_atom37_to_atom14': [NUM_RES, None],
'resolution': [],
'rigidgroups_alt_gt_frames': [NUM_RES, None, None],
'rigidgroups_alt_gt_frames': [NUM_RES, None, None, None],
'rigidgroups_group_exists': [NUM_RES, None],
'rigidgroups_group_is_ambiguous': [NUM_RES, None],
'rigidgroups_gt_exists': [NUM_RES, None],
'rigidgroups_gt_frames': [NUM_RES, None, None],
'rigidgroups_gt_frames': [NUM_RES, None, None, None],
'seq_length': [],
'seq_mask': [NUM_RES],
'target_feat': [NUM_RES, None],
'template_aatype': [NUM_TEMPLATES, NUM_RES],
'template_all_atom_masks': [NUM_TEMPLATES, NUM_RES, None],
'template_all_atom_positions': [
NUM_TEMPLATES, NUM_RES, None, None],
'template_all_atom_mask': [NUM_TEMPLATES, NUM_RES, None],
'template_all_atom_positions':
[NUM_TEMPLATES, NUM_RES, None, None],
'template_alt_torsion_angles_sin_cos':
[NUM_TEMPLATES, NUM_RES, None, None],
'template_backbone_affine_mask': [NUM_TEMPLATES, NUM_RES],
'template_backbone_affine_tensor': [
NUM_TEMPLATES, NUM_RES, None],
NUM_TEMPLATES, NUM_RES, None, None],
'template_mask': [NUM_TEMPLATES],
'template_pseudo_beta': [NUM_TEMPLATES, NUM_RES, None],
'template_pseudo_beta_mask': [NUM_TEMPLATES, NUM_RES],
'template_sum_probs': [NUM_TEMPLATES, None],
'true_msa': [NUM_MSA_SEQ, NUM_RES]
'template_torsion_angles_mask': [NUM_TEMPLATES, NUM_RES, None],
'template_torsion_angles_sin_cos':
[NUM_TEMPLATES, NUM_RES, None, None],
'true_msa': [NUM_MSA_SEQ, NUM_RES],
'use_clamped_fape': [],
},
'masked_msa': {
'profile_prob': 0.1,
'same_prob': 0.1,
'uniform_prob': 0.1
},
'max_extra_msa': 1024,
'msa_cluster_features': True,
'num_recycle': num_recycle,
'reduce_msa_clusters_by_max_templates': False,
'resample_msa_in_recycling': True,
'template_features': [
'template_all_atom_positions', 'template_sum_probs',
'template_aatype', 'template_all_atom_mask',
],
'unsupervised_features': [
'aatype', 'residue_index', 'msa', 'num_alignments',
'seq_length', 'between_segment_residues', 'deletion_matrix'
],
'use_templates': templates_enabled,
'use_template_torsion_angles': embed_template_torsion_angles,
'supervised_features': [
'all_atom_mask', 'all_atom_positions', 'resolution',
'use_clamped_fape',
],
},
'predict': {
'fixed_size': True,
'subsample_templates': False, # We want top templates.
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
'crop': False,
'crop_size': None,
'supervised': False,
},
'eval': {
'fixed_size': True,
'subsample_templates': False, # We want top templates.
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
'crop': False,
'crop_size': None,
'supervised': True,
},
'train': {
'fixed_size': True,
'subsample_templates': True,
'masked_msa_replace_fraction': 0.15,
'max_msa_clusters': 512,
'max_templates': 4,
'num_ensemble': 1,
'crop': True,
'crop_size': 256,
'supervised': True,
},
'data_module': {
'use_small_bfd': False,
'data_loaders': {
'batch_size': 1,
'num_workers': 1,
},
}
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"c_z": c_z,
"c_m": c_m,
"c_t": c_t,
"c_e": c_e,
"c_s": c_s,
"eps": eps,
'globals': {
'blocks_per_ckpt': blocks_per_ckpt,
'chunk_size': chunk_size,
'c_z': c_z,
'c_m': c_m,
'c_t': c_t,
'c_e': c_e,
'c_s': c_s,
'eps': eps,
},
"model": {
"no_cycles": 4,
"_mask_trans": False,
"input_embedder": {
"tf_dim": 22,
"msa_dim": 49,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
'model': {
'num_recycle': num_recycle,
'_mask_trans': False,
'input_embedder': {
'tf_dim': 22,
'msa_dim': 49,
'c_z': c_z,
'c_m': c_m,
'relpos_k': 32,
},
"recycling_embedder": {
"c_z": c_z,
"c_m": c_m,
"min_bin": 3.25,
"max_bin": 20.75,
"no_bins": 15,
"inf": 1e8,
'recycling_embedder': {
'c_z': c_z,
'c_m': c_m,
'min_bin': 3.25,
'max_bin': 20.75,
'no_bins': 15,
'inf': 1e8,
},
"template": {
"distogram": {
"min_bin": 3.25,
"max_bin": 50.75,
"no_bins": 39,
'template': {
'distogram': {
'min_bin': 3.25,
'max_bin': 50.75,
'no_bins': 39,
},
"template_angle_embedder": {
'template_angle_embedder': {
# DISCREPANCY: c_in is supposed to be 51.
"c_in": 57,
"c_out": c_m,
'c_in': 57,
'c_out': c_m,
},
"template_pair_embedder": {
"c_in": 88,
"c_out": c_t,
'template_pair_embedder': {
'c_in': 88,
'c_out': c_t,
},
"template_pair_stack": {
"c_t": c_t,
'template_pair_stack': {
'c_t': c_t,
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
# as 64. In the code, it's 16.
"c_hidden_tri_att": 16,
"c_hidden_tri_mul": 64,
"no_blocks": 2,
"no_heads": 4,
"pair_transition_n": 2,
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5,#1e9,
'c_hidden_tri_att': 16,
'c_hidden_tri_mul': 64,
'no_blocks': 2,
'no_heads': 4,
'pair_transition_n': 2,
'dropout_rate': 0.25,
'blocks_per_ckpt': blocks_per_ckpt,
'chunk_size': chunk_size,
'inf': 1e5,#1e9,
},
"template_pointwise_attention": {
"c_t": c_t,
"c_z": c_z,
'template_pointwise_attention': {
'c_t': c_t,
'c_z': c_z,
# DISCREPANCY: c_hidden here is given in the supplement as 64.
# It's actually 16.
"c_hidden": 16,
"no_heads": 4,
"chunk_size": chunk_size,
"inf": 1e5,#1e9,
'c_hidden': 16,
'no_heads': 4,
'chunk_size': chunk_size,
'inf': 1e5,#1e9,
},
"inf": 1e5,#1e9,
"eps": eps,#1e-6,
"enabled": True,
"embed_angles": True,
'inf': 1e5,#1e9,
'eps': eps,#1e-6,
'enabled': templates_enabled,
'embed_angles': embed_template_torsion_angles,
},
"extra_msa": {
"extra_msa_embedder": {
"c_in": 25,
"c_out": c_e,
'extra_msa': {
'extra_msa_embedder': {
'c_in': 25,
'c_out': c_e,
},
"extra_msa_stack": {
"c_m": c_e,
"c_z": c_z,
"c_hidden_msa_att": 8,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 4,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5,#1e9,
"eps": eps,#1e-10,
'extra_msa_stack': {
'c_m': c_e,
'c_z': c_z,
'c_hidden_msa_att': 8,
'c_hidden_opm': 32,
'c_hidden_mul': 128,
'c_hidden_pair_att': 32,
'no_heads_msa': 8,
'no_heads_pair': 4,
'no_blocks': 4,
'transition_n': 4,
'msa_dropout': 0.15,
'pair_dropout': 0.25,
'blocks_per_ckpt': blocks_per_ckpt,
'chunk_size': chunk_size,
'inf': 1e5,#1e9,
'eps': eps,#1e-10,
},
"enabled": True,
'enabled': True,
},
"evoformer_stack": {
"c_m": c_m,
"c_z": c_z,
"c_hidden_msa_att": 32,
"c_hidden_opm": 32,
"c_hidden_mul": 128,
"c_hidden_pair_att": 32,
"c_s": c_s,
"no_heads_msa": 8,
"no_heads_pair": 4,
"no_blocks": 48,
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": 1e5,#1e9,
"eps": eps,#1e-10,
'evoformer_stack': {
'c_m': c_m,
'c_z': c_z,
'c_hidden_msa_att': 32,
'c_hidden_opm': 32,
'c_hidden_mul': 128,
'c_hidden_pair_att': 32,
'c_s': c_s,
'no_heads_msa': 8,
'no_heads_pair': 4,
'no_blocks': 48,
'transition_n': 4,
'msa_dropout': 0.15,
'pair_dropout': 0.25,
'blocks_per_ckpt': blocks_per_ckpt,
'chunk_size': chunk_size,
'inf': 1e5,#1e9,
'eps': eps,#1e-10,
},
"structure_module": {
"c_s": c_s,
"c_z": c_z,
"c_ipa": 16,
"c_resnet": 128,
"no_heads_ipa": 12,
"no_qk_points": 4,
"no_v_points": 8,
"dropout_rate": 0.1,
"no_blocks": 8,
"no_transition_layers": 1,
"no_resnet_blocks": 2,
"no_angles": 7,
"trans_scale_factor": 10,
"epsilon": eps,#1e-12,
"inf": 1e5,
'structure_module': {
'c_s': c_s,
'c_z': c_z,
'c_ipa': 16,
'c_resnet': 128,
'no_heads_ipa': 12,
'no_qk_points': 4,
'no_v_points': 8,
'dropout_rate': 0.1,
'no_blocks': 8,
'no_transition_layers': 1,
'no_resnet_blocks': 2,
'no_angles': 7,
'trans_scale_factor': 10,
'epsilon': eps,#1e-12,
'inf': 1e5,
},
"heads": {
"lddt": {
"no_bins": 50,
"c_in": c_s,
"c_hidden": 128,
'heads': {
'lddt': {
'no_bins': 50,
'c_in': c_s,
'c_hidden': 128,
},
"distogram": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
'distogram': {
'c_z': c_z,
'no_bins': aux_distogram_bins,
},
"tm": {
"c_z": c_z,
"no_bins": aux_distogram_bins,
"enabled": False,
'tm': {
'c_z': c_z,
'no_bins': aux_distogram_bins,
'enabled': False,
},
"masked_msa": {
"c_m": c_m,
"c_out": 23,
'masked_msa': {
'c_m': c_m,
'c_out': 23,
},
"experimentally_resolved": {
"c_s": c_s,
"c_out": 37,
'experimentally_resolved': {
'c_s': c_s,
'c_out': 37,
},
},
},
"relax": {
"max_iterations": 0, # no max
"tolerance": 2.39,
"stiffness": 10.0,
"max_outer_iterations": 20,
"exclude_residues": [],
'relax': {
'max_iterations': 0, # no max
'tolerance': 2.39,
'stiffness': 10.0,
'max_outer_iterations': 20,
'exclude_residues': [],
},
"loss": {
"distogram": {
"min_bin": 2.3125,
"max_bin": 21.6875,
"no_bins": 64,
"eps": eps,#1e-6,
"weight": 0.3,
'loss': {
'distogram': {
'min_bin': 2.3125,
'max_bin': 21.6875,
'no_bins': 64,
'eps': eps,#1e-6,
'weight': 0.3,
},
"experimentally_resolved": {
"eps": eps,#1e-8,
"min_resolution": 0.1,
"max_resolution": 3.0,
"weight": 0.,
'experimentally_resolved': {
'eps': eps,#1e-8,
'min_resolution': 0.1,
'max_resolution': 3.0,
'weight': 0.,
},
"fape": {
"backbone": {
"clamp_distance": 10.,
"loss_unit_distance": 10.,
"weight": 0.5,
'fape': {
'backbone': {
'clamp_distance': 10.,
'loss_unit_distance': 10.,
'weight': 0.5,
},
"sidechain": {
"clamp_distance": 10.,
"length_scale": 10.,
"weight": 0.5,
'sidechain': {
'clamp_distance': 10.,
'length_scale': 10.,
'weight': 0.5,
},
"eps": 1e-4,
"weight": 1.0,
'eps': 1e-4,
'weight': 1.0,
},
"lddt": {
"min_resolution": 0.1,
"max_resolution": 3.0,
"cutoff": 15.,
"no_bins": 50,
"eps": eps,#1e-10,
"weight": 0.01,
'lddt': {
'min_resolution': 0.1,
'max_resolution': 3.0,
'cutoff': 15.,
'no_bins': 50,
'eps': eps,#1e-10,
'weight': 0.01,
},
"masked_msa": {
"eps": eps,#1e-8,
"weight": 2.0,
'masked_msa': {
'eps': eps,#1e-8,
'weight': 2.0,
},
"supervised_chi": {
"chi_weight": 0.5,
"angle_norm_weight": 0.01,
"eps": eps,#1e-6,
"weight": 1.0,
'supervised_chi': {
'chi_weight': 0.5,
'angle_norm_weight': 0.01,
'eps': eps,#1e-6,
'weight': 1.0,
},
"violation": {
"violation_tolerance_factor": 12.0,
"clash_overlap_tolerance": 1.5,
"eps": eps,#1e-6,
"weight": 0.,
'violation': {
'violation_tolerance_factor': 12.0,
'clash_overlap_tolerance': 1.5,
'eps': eps,#1e-6,
'weight': 0.,
},
"tm": {
"max_bin": 31,
"no_bins": 64,
"min_resolution": 0.1,
"max_resolution": 3.0,
"eps": eps,#1e-8,
"weight": 0.,
'tm': {
'max_bin': 31,
'no_bins': 64,
'min_resolution': 0.1,
'max_resolution': 3.0,
'eps': eps,#1e-8,
'weight': 0.,
},
"eps": eps,
'eps': eps,
},
'ema': {
'decay': 0.999
},
})
import os
import datetime
import numpy as np
from typing import Mapping, Optional, Sequence
from typing import Mapping, Optional, Sequence, Any
from openfold.features import templates, parsers
from openfold.features import templates, parsers, mmcif_parsing
from openfold.features.np import jackhmmer, hhblits, hhsearch
from openfold.features.np.utils import to_date
from openfold.np import residue_constants
FeatureDict = Mapping[str, np.ndarray]
FeatureDict = Mapping[str, np.ndarray]
def make_sequence_features(sequence: str, description: str, num_res: int) -> FeatureDict:
def make_sequence_features(
sequence: str,
description: str,
num_res: int
) -> FeatureDict:
"""Construct a feature dict of sequence features."""
features = {}
features['aatype'] = residue_constants.sequence_to_onehot(
......@@ -19,13 +25,50 @@ def make_sequence_features(sequence: str, description: str, num_res: int) -> Fea
map_unknown_to_x=True
)
features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
features['domain_name'] = np.array([description.encode('utf-8')], dtype=np.object_)
features['domain_name'] = np.array(
[description.encode('utf-8')], dtype=np.object_
)
features['residue_index'] = np.array(range(num_res), dtype=np.int32)
features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
features['sequence'] = np.array(
[sequence.encode('utf-8')], dtype=np.object_
)
return features
def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject,
    chain_id: str
) -> FeatureDict:
    """Assemble a feature dict for one chain of a parsed mmCIF object.

    Combines the standard sequence features for the chain with the
    ground-truth atom coordinates plus the resolution and release date
    taken from the mmCIF header.

    Args:
        mmcif_object: a parsed mmCIF structure.
        chain_id: ID of the chain to featurize.

    Returns:
        A FeatureDict with sequence features plus 'all_atom_positions',
        'all_atom_mask', 'resolution' and 'release_date'.
    """
    seqres = mmcif_object.chain_to_seqres[chain_id]
    desc = '_'.join([mmcif_object.file_id, chain_id])

    # Start from the generic sequence features shared with the FASTA path.
    feats = dict(make_sequence_features(
        sequence=seqres,
        description=desc,
        num_res=len(seqres),
    ))

    positions, mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    feats["all_atom_positions"] = positions
    feats["all_atom_mask"] = mask

    feats["resolution"] = np.array(
        [mmcif_object.header["resolution"]], dtype=np.float32
    )
    feats["release_date"] = np.array(
        [mmcif_object.header["release_date"].encode('utf-8')],
        dtype=np.object_,
    )

    return feats
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
......@@ -58,9 +101,9 @@ def make_msa_features(
)
return features
class DataPipeline:
"""Runs the alignment tools and assembles the input features."""
class AlignmentRunner:
""" Runs alignment tools and saves the results """
def __init__(self,
jackhmmer_binary_path: str,
hhblits_binary_path: str,
......@@ -71,106 +114,158 @@ class DataPipeline:
uniclust30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
pdb70_database_path: str,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000
no_cpus: int,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
):
"""Constructs a feature dict for a given FASTA file."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path
database_path=small_bfd_database_path,
n_cpu=no_cpus,
)
else:
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniclust30_database_path]
databases=[bfd_database_path, uniclust30_database_path],
n_cpu=no_cpus,
)
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path
database_path=mgnify_database_path,
n_cpu=no_cpus,
)
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path]
)
self.template_featurizer = template_featurizer
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {input_fasta_path}.'
)
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(input_fasta_path)[0]
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(input_fasta_path)[0]
def run(self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(fasta_path)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits
)
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'w') as f:
f.write(jackhmmer_uniref90_result['sto'])
f.write(uniref90_msa_as_a3m)
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.so')
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(fasta_path)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits
)
mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'w') as f:
f.write(jackhmmer_mgnify_result['sto'])
f.write(mgnify_msa_as_a3m)
pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr')
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'w') as f:
f.write(hhsearch_result)
uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_uniref90_result['sto']
)
mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_mgnify_result['sto']
)
hhsearch_hits = parsers.parse_hhr(hhsearch_result)
mgnify_msa = mgnify_msa[:self.mgnify_max_hits]
mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits]
if self._use_small_bfd:
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(input_fasta_path)[0]
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.a3m')
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(fasta_path)[0]
bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'w') as f:
f.write(jackhmmer_small_bfd_result['sto'])
else:
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path)
if(output_dir is not None):
bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'w') as f:
f.write(hhblits_bfd_uniclust_result['a3m'])
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_small_bfd_result['sto']
class DataPipeline:
"""Assembles input features."""
def __init__(
    self,
    template_featurizer: templates.TemplateHitFeaturizer,
    use_small_bfd: bool,
):
    """Initialize the feature-assembly pipeline.

    Args:
        template_featurizer: featurizer used to turn hhsearch hits into
            template features.
        use_small_bfd: selects which precomputed BFD alignment file the
            pipeline reads (small BFD stockholm vs. full BFD a3m).
    """
    self.template_featurizer = template_featurizer
    self.use_small_bfd = use_small_bfd
def _parse_alignment_output(self,
alignment_dir: str,
) -> Mapping[str, Any]:
uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'r') as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(
f.read()
)
mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'r') as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(
f.read()
)
else:
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(input_fasta_path)
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'w') as f:
f.write(hhblits_bfd_uniclust_result['a3m'])
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
hhblits_bfd_uniclust_result['a3m']
pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'r') as f:
hhsearch_hits = parsers.parse_hhr(
f.read()
)
if(self.use_small_bfd):
bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'r') as f:
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
f.read()
)
else:
bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'r') as f:
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
f.read()
)
return {
'uniref90_msa': uniref90_msa,
'uniref90_deletion_matrix': uniref90_deletion_matrix,
'mgnify_msa': mgnify_msa,
'mgnify_deletion_matrix': mgnify_deletion_matrix,
'hhsearch_hits': hhsearch_hits,
'bfd_msa': bfd_msa,
'bfd_deletion_matrix': bfd_deletion_matrix,
}
def process_fasta(self,
fasta_path: str,
alignment_dir: str,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
alignments = self._parse_alignment_output(alignment_dir)
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=None,
hits=hhsearch_hits
hits=alignments['hhsearch_hits']
)
sequence_features = make_sequence_features(
......@@ -180,9 +275,62 @@ class DataPipeline:
)
msa_features = make_msa_features(
msas=(uniref90_msa, bfd_msa, mgnify_msa),
deletion_matrices = (uniref90_deletion_matrix,
bfd_deletion_matrix,
mgnify_deletion_matrix)
msas=(
alignments['uniref90_msa'],
alignments['bfd_msa'],
alignments['mgnify_msa']
),
deletion_matrices=(
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
)
return {**sequence_features, **msa_features, **templates_result.features}
def process_mmcif(
    self,
    mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
    alignment_dir: str,
    chain_id: Optional[str] = None,
) -> FeatureDict:
    """Assemble training features for a specific chain of an mmCIF object.

    When chain_id is None, the first chain of the structure is used;
    a ValueError is raised if the structure contains no chains.

    Args:
        mmcif: a parsed mmCIF structure.
        alignment_dir: directory holding precomputed alignment files.
        chain_id: chain to featurize, or None to take the first chain.

    Returns:
        Merged mmCIF, template, and MSA features.
    """
    if chain_id is None:
        chain = next(mmcif.structure.get_chains(), None)
        if chain is None:
            raise ValueError(
                'No chains in mmCIF file'
            )
        chain_id = chain.id

    mmcif_feats = make_mmcif_features(mmcif, chain_id)

    alignments = self._parse_alignment_output(alignment_dir)

    templates_result = self.template_featurizer.get_templates(
        query_sequence=mmcif.chain_to_seqres[chain_id],
        query_pdb_code=None,
        query_release_date=to_date(mmcif.header["release_date"]),
        hits=alignments['hhsearch_hits'],
    )

    msa_features = make_msa_features(
        msas=(
            alignments['uniref90_msa'],
            alignments['bfd_msa'],
            alignments['mgnify_msa'],
        ),
        deletion_matrices=(
            alignments['uniref90_deletion_matrix'],
            alignments['bfd_deletion_matrix'],
            alignments['mgnify_deletion_matrix'],
        ),
    )

    return {**mmcif_feats, **templates_result.features, **msa_features}
......@@ -6,8 +6,10 @@ import torch
from operator import add
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
from openfold.np import residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import tree_map, tensor_tree_map, batched_gather
MSA_FEATURE_NAMES = [
'msa', 'deletion_matrix', 'msa_mask', 'msa_row_mask', 'bert_mask', 'true_msa'
......@@ -59,7 +61,7 @@ def fix_templates_aatype(protein):
num_templates = protein['template_aatype'].shape[0]
protein['template_aatype'] = torch.argmax(protein['template_aatype'], dim=-1)
# Map hhsearch-aatype to our aatype.
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
new_order_list, dtype=torch.int64
).expand(num_templates, -1)
......@@ -69,8 +71,8 @@ def fix_templates_aatype(protein):
return protein
def correct_msa_restypes(protein):
"""Correct MSA restype to have the same order as residue_constants."""
new_order_list = residue_constants.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
"""Correct MSA restype to have the same order as rc."""
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
new_order = torch.tensor(
[new_order_list]*protein['msa'].shape[1], dtype=protein['msa'].dtype
).transpose(0,1)
......@@ -93,7 +95,7 @@ def squeeze_features(protein):
for k in [
'domain_name', 'msa', 'num_alignments', 'seq_length', 'sequence',
'superfamily', 'deletion_matrix', 'resolution',
'between_segment_residues', 'residue_index', 'template_all_atom_masks']:
'between_segment_residues', 'residue_index', 'template_all_atom_mask']:
if k in protein:
final_dim = protein[k].shape[-1]
if isinstance(final_dim, int) and final_dim == 1:
......@@ -104,12 +106,6 @@ def squeeze_features(protein):
protein[k] = protein[k][0]
return protein
def make_protein_crop_to_size_seed(protein):
protein['random_crop_to_size_seed'] = torch.distributions.Uniform(
low=torch.int32, high=torch.int32).sample((2)
)
return protein
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
"""Replace a portion of the MSA with 'X'."""
......@@ -284,19 +280,19 @@ def make_msa_mask(protein):
protein['msa_row_mask'] = torch.ones(protein['msa'].shape[0], dtype=torch.float32)
return protein
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
"""Create pseudo beta features."""
is_gly = torch.eq(aatype, residue_constants.restype_order['G'])
ca_idx = residue_constants.atom_order['CA']
cb_idx = residue_constants.atom_order['CB']
is_gly = torch.eq(aatype, rc.restype_order['G'])
ca_idx = rc.atom_order['CA']
cb_idx = rc.atom_order['CB']
pseudo_beta = torch.where(
torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
all_atom_positions[..., ca_idx, :],
all_atom_positions[..., cb_idx, :])
if all_atom_masks is not None:
if all_atom_mask is not None:
pseudo_beta_mask = torch.where(
is_gly, all_atom_masks[..., ca_idx], all_atom_masks[..., cb_idx])
is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx])
return pseudo_beta, pseudo_beta_mask
else:
return pseudo_beta
......@@ -307,9 +303,9 @@ def make_pseudo_beta(protein, prefix=''):
assert prefix in ['', 'template_']
protein[prefix + 'pseudo_beta'], protein[prefix + 'pseudo_beta_mask'] = (
pseudo_beta_fn(
protein['template_aatype' if prefix else 'all_atom_aatype'],
protein['template_aatype' if prefix else 'aatype'],
protein[prefix + 'all_atom_positions'],
protein['template_all_atom_masks' if prefix else 'all_atom_mask']))
protein['template_all_atom_mask' if prefix else 'all_atom_mask']))
return protein
@curry1
......@@ -456,10 +452,12 @@ def make_msa_feat(protein):
protein['target_feat'] = torch.cat(target_feat, dim=-1)
return protein
@curry1
def select_feat(protein, feature_list):
return {k: v for k, v in protein.items() if k in feature_list}
@curry1
def crop_templates(protein, max_templates):
for k, v in protein.items():
......@@ -467,72 +465,74 @@ def crop_templates(protein, max_templates):
protein[k] = v[:max_templates]
return protein
def make_atom14_masks(protein):
"""Construct denser atom positions (14 dimensions instead of 37)."""
restype_atom14_to_atom37 = []
restype_atom37_to_atom14 = []
restype_atom14_mask = []
for rt in residue_constants.restypes:
atom_names = residue_constants.restype_name_to_atom14_names[
residue_constants.restype_1to3[rt]
for rt in rc.restypes:
atom_names = rc.restype_name_to_atom14_names[
rc.restype_1to3[rt]
]
restype_atom14_to_atom37.append([
(residue_constants.atom_order[name] if name else 0)
(rc.atom_order[name] if name else 0)
for name in atom_names
])
atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
restype_atom37_to_atom14.append([
(atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
for name in residue_constants.atom_types
for name in rc.atom_types
])
# Since all 14 atoms are not present in every residue, use this mask to
# tell which atom is there in this residue
restype_atom14_mask.append([(1. if name else 0.) for name in atom_names])
# Add dummy mapping for restype 'UNK'
restype_atom14_to_atom37.append([0] * 14)
restype_atom37_to_atom14.append([0] * 37)
restype_atom14_mask.append([0.] * 14)
restype_atom14_to_atom37 = torch.tensor(
restype_atom14_to_atom37, dtype=torch.int32
restype_atom14_to_atom37,
dtype=torch.int32,
device=protein['aatype'].device,
)
restype_atom37_to_atom14 = torch.tensor(
restype_atom37_to_atom14, dtype=torch.int32
restype_atom37_to_atom14,
dtype=torch.int32,
device=protein['aatype'].device,
)
restype_atom14_mask = torch.tensor(
restype_atom14_mask, dtype=torch.float32
restype_atom14_mask,
dtype=torch.float32,
device=protein['aatype'].device,
)
# create the mapping for (residx, atom14) --> atom37, i.e. an array
# with shape (num_res, 14) containing the atom37 indices for this protein
residx_atom14_to_atom37 = torch.index_select(
restype_atom14_to_atom37, 0, protein['aatype']
)
residx_atom14_mask = torch.index_select(
restype_atom14_mask, 0, protein['aatype']
)
residx_atom14_to_atom37 = restype_atom14_to_atom37[protein['aatype']]
residx_atom14_mask = restype_atom14_mask[protein['aatype']]
protein['atom14_atom_exists'] = residx_atom14_mask
protein['residx_atom14_to_atom37'] = residx_atom14_to_atom37.long()
# create the gather indices for mapping back
residx_atom37_to_atom14 = torch.index_select(
restype_atom37_to_atom14, 0, protein['aatype']
)
residx_atom37_to_atom14 = restype_atom37_to_atom14[protein['aatype']]
protein['residx_atom37_to_atom14'] = residx_atom37_to_atom14.long()
# create the corresponding mask
restype_atom37_mask = torch.zeros([21, 37], dtype=torch.float32)
for restype, restype_letter in enumerate(residue_constants.restypes):
restype_name = residue_constants.restype_1to3[restype_letter]
atom_names = residue_constants.residue_atoms[restype_name]
restype_atom37_mask = torch.zeros(
[21, 37], dtype=torch.float32, device=protein['aatype'].device
)
for restype, restype_letter in enumerate(rc.restypes):
restype_name = rc.restype_1to3[restype_letter]
atom_names = rc.residue_atoms[restype_name]
for atom_name in atom_names:
atom_type = residue_constants.atom_order[atom_name]
atom_type = rc.atom_order[atom_name]
restype_atom37_mask[restype, atom_type] = 1
residx_atom37_mask = torch.index_select(
restype_atom37_mask, 0, protein['aatype']
)
residx_atom37_mask = restype_atom37_mask[protein['aatype']]
protein['atom37_atom_exists'] = residx_atom37_mask
return protein
......@@ -543,3 +543,546 @@ def make_atom14_masks_np(batch):
out = make_atom14_masks(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
return out
def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Expects the atom14 bookkeeping keys ("atom14_atom_exists",
    "residx_atom14_to_atom37") plus the ground-truth "all_atom_mask" and
    "all_atom_positions" (atom37 layout) to be present in `protein`.

    Adds to `protein`:
        "atom14_gt_exists": atom14 mask restricted to atoms with known
            ground-truth coordinates
        "atom14_gt_positions": ground-truth coordinates in atom14 layout
        "atom14_alt_gt_positions": same, with ambiguous atom pairs swapped
        "atom14_alt_gt_exists": mask for the swapped coordinates
        "atom14_atom_is_ambiguous": per-atom naming-ambiguity flag
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # Create a mask for known ground truth positions: an atom14 slot counts
    # only if the residue type has that atom AND it was resolved in atom37.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        protein["all_atom_mask"],
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(protein["all_atom_mask"].shape[:-1])
    )

    # Gather the ground truth positions (zeroed where the mask is 0).
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            protein["all_atom_positions"],
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(protein["all_atom_positions"].shape[:-2])
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [
        rc.restype_1to3[res] for res in rc.restypes
    ]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms. Start with an identity
    # (no swap) for every residue type ...
    all_matrices = {
        res: torch.eye(
            14,
            dtype=protein["all_atom_mask"].dtype,
            device=protein["all_atom_mask"].device
        ) for res in restype_3
    }
    # ... then replace the entries for residues with ambiguous atoms by the
    # 14x14 permutation matrix that swaps each ambiguous pair.
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        correspondences = torch.arange(14, device=protein["all_atom_mask"].device)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = rc.restype_name_to_atom14_names[
                resname].index(source_atom_swap)
            target_index = rc.restype_name_to_atom14_names[
                resname].index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.
        all_matrices[resname] = renaming_matrix

    # [21, 14, 14] stack, ordered like restype_3 (20 residues + UNK).
    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc",
        residx_atom14_gt_positions,
        renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb",
        residx_atom14_gt_mask,
        renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = rc.restype_order[
                rc.restype_3to1[resname]]
            atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
                atom_name1)
            atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
                atom_name2)
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = (
        restype_atom14_is_ambiguous[protein["aatype"]]
    )

    return protein
def atom37_to_frames(protein):
    """Computes per-residue rigid-group frames from atom37 coordinates.

    Reads "aatype" ([*, N_res]), "all_atom_positions" ([*, N_res, 37, 3])
    and "all_atom_mask" ([*, N_res, 37]); adds "rigidgroups_gt_frames",
    "rigidgroups_gt_exists", "rigidgroups_group_exists",
    "rigidgroups_group_is_ambiguous" and "rigidgroups_alt_gt_frames".
    Each residue gets 8 rigid groups; groups 4-7 correspond to chi1-chi4.
    """
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    # Names of the 3 atoms that define each rigid group's frame, per
    # residue type. Groups 1 and 2 are left empty here (and masked below).
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
    restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']

    # Chi groups (4-7) use the last three atoms of each chi-angle quad.
    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if(rc.chi_angles_mask[restype][chi_idx]):
                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    # Which of the 8 groups are defined for each residue type.
    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :20, 4:] = (
        all_atom_mask.new_tensor(rc.chi_angles_mask)
    )

    # Translate atom names into atom37 indices; '' maps to index 0.
    lookuptable = rc.atom_order.copy()
    lookuptable[''] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    # Prepend singleton batch dims so the table broadcasts over the batch.
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims),
            *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    # Per-residue base-atom indices, selected by residue type.
    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    # Coordinates of the three base atoms of every group.
    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    # Ground-truth frames built from the three base atoms.
    gt_frames = T.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=1e-8,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    # A frame exists only if all three base atoms were resolved AND the
    # group is defined for the residue type.
    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1])
    )
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Flip the x and z axes of group 0's rotation (diag(-1, 1, -1)).
    rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1
    gt_frames = gt_frames.compose(T(rots, None))

    # For residue types with ambiguous atom naming, the last defined chi
    # group gets an alternative frame rotated by 180 degrees about x
    # (diag(1, -1, -1)).
    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    for resname, _ in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[
            rc.restype_3to1[resname]
        ]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )
    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    # Alternative frames with the ambiguity rotation applied.
    alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))

    gt_frames_tensor = gt_frames.to_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_4x4()

    protein['rigidgroups_gt_frames'] = gt_frames_tensor
    protein['rigidgroups_gt_exists'] = gt_exists
    protein['rigidgroups_group_exists'] = group_exists
    protein['rigidgroups_group_is_ambiguous'] = residx_rigidgroup_is_ambiguous
    protein['rigidgroups_alt_gt_frames'] = alt_gt_frames_tensor

    return protein
def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
        A nested list of shape [residue_types=21, chis=4, atoms=4], suitable
        for wrapping in torch.as_tensor. The residue types are in the order
        specified in rc.restypes, with the unknown residue type at the end.
        For chi angles which are not defined on the residue, the atom
        indices default to 0.

    Note: the original docstring claimed a tensor was returned; the function
    actually builds plain Python lists (callers convert via torch.as_tensor).
    """
    chi_atom_indices = []
    for residue_name in rc.restypes:
        residue_name = rc.restype_1to3[residue_name]
        residue_chi_angles = rc.chi_angles_atoms[residue_name]
        # One [4]-list of atom37 indices per defined chi angle.
        atom_indices = [
            [rc.atom_order[atom] for atom in chi_angle]
            for chi_angle in residue_chi_angles
        ]
        # Pad to exactly 4 chi angles per residue.
        atom_indices += [[0, 0, 0, 0]] * (4 - len(atom_indices))
        chi_atom_indices.append(atom_indices)

    chi_atom_indices.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return chi_atom_indices
@curry1
def atom37_to_torsion_angles(
    protein,
    prefix='',
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    Args:
        Dict containing:
            * (prefix)aatype:
                [*, N_res] residue indices
            * (prefix)all_atom_positions:
                [*, N_res, 37, 3] atom positions (in atom37
                format)
            * (prefix)all_atom_mask:
                [*, N_res, 37] atom position mask
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    all_atom_positions = protein[prefix + "all_atom_positions"]
    all_atom_mask = protein[prefix + "all_atom_mask"]

    # Clamp out-of-range residue indices to 20 (UNK).
    aatype = torch.clamp(aatype, max=20)

    # Shift positions one residue to the right so each residue can see its
    # predecessor's atoms; residue 0 gets an all-zero predecessor.
    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, 37, 3]
    )
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    # Same shift for the atom mask.
    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
    prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

    # pre-omega: atoms 1:3 of the previous residue + atoms 0:2 of this one.
    pre_omega_atom_pos = torch.cat(
        [
            prev_all_atom_positions[..., 1:3, :],
            all_atom_positions[..., :2, :]
        ], dim=-2
    )
    # phi: atom 2 of the previous residue + atoms 0:3 of this one.
    phi_atom_pos = torch.cat(
        [
            prev_all_atom_positions[..., 2:3, :],
            all_atom_positions[..., :3, :]
        ], dim=-2
    )
    # psi: atoms 0:3 + atom 4 of this residue.
    psi_atom_pos = torch.cat(
        [
            all_atom_positions[..., :3, :],
            all_atom_positions[..., 4:5, :]
        ], dim=-2
    )

    # Each angle is valid only if all four of its atoms are present.
    pre_omega_mask = (
        torch.prod(prev_all_atom_mask[..., 1:3], dim=-1) *
        torch.prod(all_atom_mask[..., :2], dim=-1)
    )
    phi_mask = (
        prev_all_atom_mask[..., 2] *
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
    )
    psi_mask = (
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) *
        all_atom_mask[..., 4]
    )

    # [21, 4, 4] lookup of the atoms defining each chi angle per residue.
    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    atom_indices = chi_atom_indices[..., aatype, :, :]
    # Chi-angle atom positions: [*, N_res, 4, 4, 3].
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    # Per-residue-type chi validity, padded with a row for UNK.
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0., 0., 0., 0.])
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]

    # Also require that all four atoms of each chi angle are resolved.
    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2])
    )
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    # Stack the 7 torsions: [pre-omega, phi, psi, chi1..chi4].
    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ], dim=-3
    )
    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ], dim=-1
    )

    # Frame from the second, third and first atoms of each torsion quad;
    # the fourth atom's position in this frame encodes the dihedral.
    torsion_frames = T.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )
    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    # (sin, cos) from the y/z components of the fourth atom's frame-local
    # position.
    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    # Normalize to unit length (eps guards atoms at the frame origin).
    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdims=True
        ) + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    # Negating both sin and cos rotates psi (index 2) by pi — sign
    # convention for the psi angle.
    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
        [1., 1., -1., 1., 1., 1., 1.],
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    # For pi-periodic chis, the alternative angle is the original shifted
    # by pi (sin/cos both negated); other torsions are unchanged.
    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]
    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*aatype.shape, 3),
            1. - 2. * chi_is_ambiguous
        ], dim=-1
    )
    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein
def get_backbone_frames(protein):
    """Extract the backbone frame (rigid group 0) and its validity mask.

    Reads "rigidgroups_gt_frames" and "rigidgroups_gt_exists" and stores
    the entries for rigid group 0 under "backbone_affine_tensor" /
    "backbone_affine_mask".
    """
    # NOTE(review): the original carried a TODO to verify this is correct.
    frames = protein["rigidgroups_gt_frames"]
    exists = protein["rigidgroups_gt_exists"]
    protein["backbone_affine_tensor"] = frames[..., 0, :, :]
    protein["backbone_affine_mask"] = exists[..., 0]
    return protein
def get_chi_angles(protein):
    """Slice the four chi angles (and their mask) out of the torsion angles.

    The torsion axis is ordered [pre-omega, phi, psi, chi1..chi4], so the
    chi entries occupy positions 3: along that axis. Both outputs are cast
    to the dtype of "all_atom_mask".
    """
    target_dtype = protein["all_atom_mask"].dtype
    sin_cos = protein["torsion_angles_sin_cos"][..., 3:, :]
    mask = protein["torsion_angles_mask"][..., 3:]
    protein["chi_angles_sin_cos"] = sin_cos.to(target_dtype)
    protein["chi_mask"] = mask.to(target_dtype)
    return protein
@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
    batch_mode='clamped'
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Args:
        protein: feature dict of tensors to crop in place
        crop_size: maximum number of residues to keep
        max_templates: maximum number of templates to keep
        shape_schema: maps feature names to per-dimension placeholders;
            dimensions equal to NUM_RES are cropped along the residue axis
        subsample_templates: if True, also randomly permute and subsample
            the template features
        seed: optional generator seed so every copy in an ensemble is
            cropped identically
        batch_mode: 'clamped' or 'unclamped'; controls how the residue
            crop start is drawn
    Returns:
        The input dict, cropped in place.
    Raises:
        ValueError: if `batch_mode` is neither 'clamped' nor 'unclamped'.
    """
    seq_length = protein['seq_length']

    if 'template_mask' in protein:
        num_templates = protein['template_mask'].shape[-1]
    else:
        # No templates in this example.
        num_templates = protein['aatype'].new_zeros((1,))

    num_res_crop_size = min(seq_length, crop_size)

    # We want each ensemble to be cropped the same way
    g = torch.Generator(device=protein['seq_length'].device)
    if(seed is not None):
        g.manual_seed(seed)

    def _randint(lower, upper):
        # Draws one int in [lower, upper) from the shared generator.
        return int(torch.randint(
                lower, upper, (1,),
                device=protein['seq_length'].device, generator=g
        )[0])

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates + 1)
        templates_select_indices = torch.randperm(
            num_templates, device=protein['seq_length'].device, generator=g
        )
        num_templates_crop_size = min(
            num_templates - templates_crop_start, max_templates
        )
    else:
        templates_crop_start = 0
        num_templates_crop_size = num_templates

    n = seq_length - num_res_crop_size
    if(batch_mode == 'clamped'):
        right_anchor = n + 1
    elif(batch_mode == 'unclamped'):
        x = _randint(0, n)
        right_anchor = n - x + 1
    else:
        raise ValueError("Invalid batch mode")
    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if (k not in shape_schema or
            ('template' not in k and NUM_RES not in shape_schema[k])
        ):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith('template') and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = (dim_size == NUM_RES)
            if i == 0 and k.startswith('template'):
                # BUG FIX: the original reused the name `crop_size` here,
                # silently clobbering the function parameter. Use a
                # loop-local name instead.
                dim_crop_size = num_templates_crop_size
                crop_start = templates_crop_start
            else:
                crop_start = num_res_crop_start if is_num_res else 0
                dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(slice(crop_start, crop_start + dim_crop_size))

        # BUG FIX: index with a tuple of slices. Indexing a tensor with a
        # *list* of slices is treated as (deprecated) advanced indexing and
        # errors on newer PyTorch versions.
        protein[k] = v[tuple(slices)]

    protein['seq_length'] = (
        protein['seq_length'].new_tensor(num_res_crop_size)
    )

    return protein
......@@ -25,39 +25,67 @@ def np_to_tensor_dict(
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
tensor_dict = {k: torch.tensor(v) for k, v in np_example.items() if k in features}
tensor_dict = {
k: torch.tensor(v) for k, v in np_example.items() if k in features
}
return tensor_dict
def make_data_config(
config: ml_collections.ConfigDict,
mode: str,
num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
cfg = copy.deepcopy(config.data)
) -> Tuple[ml_collections.ConfigDict, List[str]]:
cfg = copy.deepcopy(config)
mode_cfg = cfg[mode]
with cfg.unlocked():
if(mode_cfg.crop_size is None):
mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
with cfg.unlocked():
cfg.eval.crop_size = num_res
if(cfg[mode].supervised):
feature_names += cfg.common.supervised_features
return cfg, feature_names
def np_example_to_features(np_example: FeatureDict,
config: ml_collections.ConfigDict,
random_seed: int = 0):
def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
mode: str,
batch_mode: str,
):
np_example = dict(np_example)
num_res = int(np_example['seq_length'][0])
cfg, feature_names = make_data_config(config, num_res=num_res)
cfg, feature_names = make_data_config(
config, mode=mode, num_res=num_res
)
if 'deletion_matrix_int' in np_example:
np_example['deletion_matrix'] = (
np_example.pop('deletion_matrix_int').astype(np.float32))
np_example.pop('deletion_matrix_int').astype(np.float32)
)
if batch_mode == 'clamped':
np_example['use_clamped_fape'] = (
np.array(1.).astype(np.float32)
)
elif batch_mode == 'unclamped':
np_example['use_clamped_fape'] = (
np.array(0.).astype(np.float32)
)
torch.manual_seed(random_seed)
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names)
features = input_pipeline.process_tensors_from_config(tensor_dict, cfg)
np_example=np_example, features=feature_names
)
with torch.no_grad():
features = input_pipeline.process_tensors_from_config(
tensor_dict, cfg.common, cfg[mode], batch_mode=batch_mode,
)
return {k: v for k, v in features.items()}
......@@ -70,10 +98,13 @@ class FeaturePipeline:
self.params = params
def process_features(self,
raw_features: FeatureDict,
random_seed: int) -> FeatureDict:
raw_features: FeatureDict,
mode: str = 'train',
batch_mode: str = 'clamped',
) -> FeatureDict:
return np_example_to_features(
np_example=raw_features,
config=self.config,
random_seed=random_seed
)
\ No newline at end of file
mode=mode,
batch_mode=batch_mode,
)
from functools import partial
import torch
from openfold.features import data_transforms
def nonensembled_transform_fns(data_config):
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
common_cfg = data_config.common
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms.correct_msa_restypes,
......@@ -23,23 +22,36 @@ def nonensembled_transform_fns(data_config):
data_transforms.make_template_mask,
data_transforms.make_pseudo_beta('template_')
])
if(common_cfg.use_template_torsion_angles):
transforms.extend([
data_transforms.atom37_to_torsion_angles('template_'),
])
transforms.extend([
data_transforms.make_atom14_masks,
])
if(mode_cfg.supervised):
transforms.extend([
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(''),
data_transforms.make_pseudo_beta(''),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
])
return transforms
def ensembled_transform_fns(data_config):
def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
"""Input pipeline data transformers that can be ensembled and averaged."""
common_cfg = data_config.common
eval_cfg = data_config.eval
transforms = []
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
pad_msa_clusters = eval_cfg.max_msa_clusters
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
......@@ -53,8 +65,10 @@ def ensembled_transform_fns(data_config):
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms.make_masked_msa(common_cfg.masked_msa,
eval_cfg.masked_msa_replace_fraction)
data_transforms.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction
)
)
if common_cfg.msa_cluster_features:
......@@ -69,44 +83,55 @@ def ensembled_transform_fns(data_config):
transforms.append(data_transforms.make_msa_feat())
crop_feats = dict(eval_cfg.feat)
crop_feats = dict(common_cfg.feat)
if eval_cfg.fixed_size:
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
transforms.append(data_transforms.random_crop_to_size(
mode_cfg.crop_size,
mode_cfg.max_templates,
crop_feats,
mode_cfg.subsample_templates,
batch_mode=batch_mode,
seed=torch.Generator().seed()
))
transforms.append(data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
eval_cfg.crop_size,
eval_cfg.max_templates
mode_cfg.crop_size,
mode_cfg.max_templates
))
else:
transforms.append(data_transforms.crop_templates(eval_cfg.max_templates))
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms
def process_tensors_from_config(tensors, data_config):
def process_tensors_from_config(
tensors, common_cfg, mode_cfg, batch_mode='clamped'
):
"""Based on the config, apply filters and transformations to the data."""
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(data_config)
fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
fn = compose(fns)
d['ensemble_index'] = i
return fn(d)
eval_cfg = data_config.eval
tensors = compose(
nonensembled_transform_fns(data_config)
nonensembled_transform_fns(common_cfg, mode_cfg)
)(tensors)
tensors_0 = wrap_ensemble_fn(tensors, 0)
num_ensemble = eval_cfg.num_ensemble
if data_config.common.resample_msa_in_recycling:
num_ensemble = mode_cfg.num_ensemble
if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
num_ensemble *= data_config.common.num_recycle + 1
num_ensemble *= common_cfg.num_recycle + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn(lambda x: wrap_ensemble_fn(tensors, x),
......@@ -116,16 +141,20 @@ def process_tensors_from_config(tensors, data_config):
return tensors
@data_transforms.curry1
def compose(x, fs):
    """Apply each transform in `fs` to `x` in order; return the final value."""
    result = x
    for transform in fs:
        result = transform(result)
    return result
def map_fn(fun, x):
    """Apply `fun` to every element of `x` and stack the results per feature.

    Every call to `fun` must return a dict with the same keys; for each key
    the per-call tensors are stacked along a new trailing dimension.

    Args:
        fun: callable mapping an element of `x` to a dict of tensors
        x: iterable of inputs (e.g. ensemble indices)
    Returns:
        Dict mapping each feature name to the tensors stacked with dim=-1.
    """
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        # BUG FIX: the original assigned this key twice in a row (once
        # without dim=-1, immediately overwritten); keep only the
        # effective stack along the trailing dimension.
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles], dim=-1
        )
    return ensembled_dict
"""Parses the mmCIF file format."""
import collections
import dataclasses
import io
import json
import logging
import os
from typing import Any, Mapping, Optional, Sequence, Tuple
from absl import logging
from Bio import PDB
from Bio.Data import SCOPData
import numpy as np
import openfold.np.residue_constants as residue_constants
# Type aliases:
ChainId = str
......@@ -369,3 +374,73 @@ def _get_protein_chains(
def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?')
def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Extract atom37 coordinates and mask for one chain of an mmCIF object.

    Args:
        mmcif_object: parsed mmCIF structure with seqres metadata
        chain_id: ID of the (single) chain to extract
    Returns:
        A pair (positions, mask): positions of shape
        [num_res, atom_type_num, 3] and mask of shape
        [num_res, atom_type_num], both float32. Entries for missing
        residues/atoms are left zeroed.
    Raises:
        MultipleChainsError: if `chain_id` does not match exactly one chain.
    """
    # Locate the right chain
    matching = [
        c for c in mmcif_object.structure.get_chains() if c.id == chain_id
    ]
    if len(matching) != 1:
        raise MultipleChainsError(
            f'Expected exactly one chain in structure with id {chain_id}.'
        )
    chain = matching[0]

    # Extract the coordinates
    num_res = len(mmcif_object.chain_to_seqres[chain_id])
    atom_count = residue_constants.atom_type_num
    all_atom_positions = np.zeros([num_res, atom_count, 3], dtype=np.float32)
    all_atom_mask = np.zeros([num_res, atom_count], dtype=np.float32)
    for res_index in range(num_res):
        pos = np.zeros([atom_count, 3], dtype=np.float32)
        mask = np.zeros([atom_count], dtype=np.float32)
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if not res_at_position.is_missing:
            res = chain[(
                res_at_position.hetflag,
                res_at_position.position.residue_number,
                res_at_position.position.insertion_code,
            )]
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                x, y, z = atom.get_coord()
                if atom_name in residue_constants.atom_order.keys():
                    atom_idx = residue_constants.atom_order[atom_name]
                    pos[atom_idx] = [x, y, z]
                    mask[atom_idx] = 1.0
                elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
                    # Put the coords of the selenium atom in the sulphur column
                    sd_idx = residue_constants.atom_order['SD']
                    pos[sd_idx] = [x, y, z]
                    mask[sd_idx] = 1.0
        all_atom_positions[res_index] = pos
        all_atom_mask[res_index] = mask

    return all_atom_positions, all_atom_mask
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    """Parse every .cif file in `mmcif_dir` and write a JSON summary cache.

    For each file that parses successfully, records its release date and
    chain count under the file's basename in the JSON written to `out_path`.
    Files that fail to parse are skipped with a warning.
    """
    data = {}
    for f in os.listdir(mmcif_dir):
        if(not f.endswith('.cif')):
            continue
        with open(os.path.join(mmcif_dir, f), 'r') as fp:
            mmcif_string = fp.read()
        file_id = os.path.splitext(f)[0]
        parsed = parse(file_id=file_id, mmcif_string=mmcif_string)
        if(parsed.mmcif_object is None):
            logging.warning(f'Could not parse {f}. Skipping...')
            continue
        mmcif = parsed.mmcif_object
        data[file_id] = {
            'release_date': mmcif.header["release_date"],
            'no_chains': len(list(mmcif.structure.get_chains())),
        }
    with open(out_path, 'w') as fp:
        fp.write(json.dumps(data))
......@@ -18,6 +18,7 @@ class HHSearch:
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 2,
maxseq: int = 1_000_000):
"""Initializes the Python HHsearch wrapper.
......@@ -26,6 +27,7 @@ class HHSearch:
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to use
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
......@@ -34,6 +36,7 @@ class HHSearch:
"""
self.binary_path = binary_path
self.databases = databases
self.n_cpu = n_cpu
self.maxseq = maxseq
for database_path in self.databases:
......@@ -56,7 +59,8 @@ class HHSearch:
cmd = [self.binary_path,
'-i', input_path,
'-o', hhr_path,
'-maxseq', str(self.maxseq)
'-maxseq', str(self.maxseq),
'-cpu', str(self.n_cpu),
] + db_cmd
logging.info('Launching subprocess "%s"', ' '.join(cmd))
......
......@@ -3,14 +3,12 @@
from concurrent import futures
import glob
import logging
import os
import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request
from absl import logging
from openfold.features.np import utils
......
"""Common utilities for data pipeline tools."""
import contextlib
import datetime
import shutil
import tempfile
import time
......@@ -25,3 +26,9 @@ def timing(msg: str):
yield
toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic)
def to_date(s: str):
    """Parse a 'YYYY-MM-DD...' string into a datetime.datetime (midnight).

    Only the first ten characters are consumed, so ISO strings with a time
    component are accepted and the time part is ignored.
    """
    year = int(s[0:4])
    month = int(s[5:7])
    day = int(s[8:10])
    return datetime.datetime(year=year, month=month, day=day)
......@@ -2,16 +2,17 @@
import dataclasses
import datetime
import glob
import json
import logging
import os
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
from absl import logging
import numpy as np
from openfold.features import parsers, mmcif_parsing
from openfold.features.np import kalign
from openfold.features.np.utils import to_date
from openfold.np import residue_constants
......@@ -74,7 +75,7 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES = {
'template_aatype': np.int64,
'template_all_atom_masks': np.float32,
'template_all_atom_mask': np.float32,
'template_all_atom_positions': np.float32,
'template_domain_names': np.object,
'template_sequence': np.object,
......@@ -133,23 +134,40 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
return result
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Parse every .cif file in `mmcif_dir` and cache release dates as JSON.

    Writes a {file_id: release_date} mapping to `out_path`, skipping files
    that fail to parse (with a warning).
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if(f.endswith('.cif')):
            path = os.path.join(mmcif_dir, f)
            with open(path, 'r') as fp:
                mmcif_string = fp.read()
            file_id = os.path.splitext(f)[0]
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_string
            )
            if(mmcif.mmcif_object is None):
                logging.warning(f'Failed to parse {f}. Skipping...')
                continue
            mmcif = mmcif.mmcif_object
            release_date = mmcif.header['release_date']
            dates[file_id] = release_date
    # BUG FIX: the output file was opened with mode 'r' (read-only), so the
    # subsequent fp.write() raised io.UnsupportedOperation. Open for writing.
    with open(out_path, 'w') as fp:
        fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
"""Parses release dates file, returns a mapping from PDBs to release dates."""
if path.endswith('txt'):
release_dates = {}
with open(path, 'r') as f:
for line in f:
pdb_id, date = line.split(':')
date = date.strip()
# Python 3.6 doesn't have datetime.date.fromisoformat() which is about
# 90x faster than strptime. However, splitting the string manually is
# about 10x faster than strptime.
release_dates[pdb_id.strip()] = datetime.datetime(
year=int(date[:4]), month=int(date[5:7]), day=int(date[8:10]))
return release_dates
else:
raise ValueError('Invalid format of the release date file %s.' % path)
with open(path, 'r') as fp:
data = json.load(fp)
return {
pdb:to_date(v) for pdb,d in data.items() for k,v in d.items()
if k == "release_date"
}
def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
......@@ -419,42 +437,14 @@ def _get_atom_positions(
auth_chain_id: str,
max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
num_res = len(mmcif_object.chain_to_seqres[auth_chain_id])
relevant_chains = [c for c in mmcif_object.structure.get_chains()
if c.id == auth_chain_id]
if len(relevant_chains) != 1:
raise MultipleChainsError(
f'Expected exactly one chain in structure with id {auth_chain_id}.')
chain = relevant_chains[0]
all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3])
all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num],
dtype=np.int64)
for res_index in range(num_res):
pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index]
if not res_at_position.is_missing:
res = chain[(res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code)]
for atom in res.get_atoms():
atom_name = atom.get_name()
x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
# Put the coordinates of the selenium atom in the sulphur column.
pos[residue_constants.atom_order['SD']] = [x, y, z]
mask[residue_constants.atom_order['SD']] = 1.0
all_positions[res_index] = pos
all_positions_mask[res_index] = mask
coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id
)
all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances(
all_positions, all_positions_mask, max_ca_ca_distance)
return all_positions, all_positions_mask
all_atom_positions, all_atom_mask, max_ca_ca_distance
)
return all_atom_positions, all_atom_mask
def _extract_template_features(
......@@ -579,7 +569,7 @@ def _extract_template_features(
return (
{
'template_all_atom_positions': np.array(templates_all_atom_positions),
'template_all_atom_masks': np.array(templates_all_atom_masks),
'template_all_atom_mask': np.array(templates_all_atom_masks),
'template_sequence': output_templates_sequence.encode(),
'template_aatype': np.array(templates_aatype),
'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
......
......@@ -19,7 +19,6 @@ import torch.nn as nn
from openfold.utils.feats import (
pseudo_beta_fn,
atom37_to_torsion_angles,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
......@@ -115,21 +114,16 @@ class AlphaFold(nn.Module):
batch,
)
# Build template angle feats
angle_feats = atom37_to_torsion_angles(
single_template_feats["template_aatype"],
single_template_feats["template_all_atom_positions"],#.float(),
single_template_feats["template_all_atom_masks"],#.float(),
eps=self.config.template.eps,
)
template_angle_feat = build_template_angle_feat(
angle_feats,
single_template_feats["template_aatype"],
)
single_template_embeds = {}
if(self.config.template.embed_angles):
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
......@@ -145,11 +139,11 @@ class AlphaFold(nn.Module):
_mask_trans=self.config._mask_trans
)
template_embeds.append({
"angle": a,
"pair": t,
"torsion_mask": angle_feats["torsion_angles_mask"]
single_template_embeds.update({
"pair": t,
})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
......@@ -164,11 +158,15 @@ class AlphaFold(nn.Module):
)
t = t * (torch.sum(batch["template_mask"]) > 0)
return {
"template_angle_embedding": template_embeds["angle"],
ret = {}
if(self.config.template.embed_angles):
ret["template_angle_embedding"] = template_embeds["angle"]
ret.update({
"template_pair_embedding": t,
"torsion_angles_mask": template_embeds["torsion_mask"],
}
})
return ret
def iteration(self, feats, m_1_prev, z_prev, x_prev):
# Primary output dictionary
......@@ -197,7 +195,7 @@ class AlphaFold(nn.Module):
)
# Inject information from previous recycling iterations
if(self.config.no_cycles > 1):
if(self.config.num_recycle > 0):
# Initialize the recycling embeddings, if needs be
if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m]
......@@ -241,7 +239,7 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
template_feats = {
k:v for k,v in feats.items() if "template_" in k
k:v for k,v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
......@@ -261,7 +259,7 @@ class AlphaFold(nn.Module):
)
# [*, S, N]
torsion_angles_mask = template_embeds["torsion_angles_mask"]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
)
......@@ -374,7 +372,8 @@ class AlphaFold(nn.Module):
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
"template_all_atom_positions"
([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
......@@ -392,13 +391,13 @@ class AlphaFold(nn.Module):
self._disable_activation_checkpointing()
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
for cycle_no in range(self.config.num_recycle + 1):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == (self.config.no_cycles - 1))
is_final_iter = (cycle_no == self.config.num_recycle)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug discussed in pytorch issue #65766
if(is_final_iter):
......
......@@ -29,14 +29,15 @@ class ExponentialMovingAverage:
self.decay = decay
def _update_state_dict_(self, update, state_dict):
for k, v in update.items():
stored = state_dict[k]
if(not isinstance(v, torch.Tensor)):
self._update_state_dict_(v, stored)
else:
diff = stored - v
diff *= (1 - self.decay)
stored -= diff
with torch.no_grad():
for k, v in update.items():
stored = state_dict[k]
if(not isinstance(v, torch.Tensor)):
self._update_state_dict_(v, stored)
else:
diff = stored - v
diff *= (1 - self.decay)
stored -= diff
def update(self, model: torch.nn.Module) -> None:
"""
......
......@@ -49,32 +49,6 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
return pseudo_beta
def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
        A [residue_types=21][chis=4][atoms=4] nested list. The residue types
        are in the order specified in rc.restypes, with the unknown residue
        type appended at the end. For chi angles that are not defined on a
        residue, the atom indices default to 0.
    """
    indices_per_restype = []
    for one_letter in rc.restypes:
        three_letter = rc.restype_1to3[one_letter]
        defined_chis = rc.chi_angles_atoms[three_letter]
        per_chi = [
            [rc.atom_order[atom_name] for atom_name in chi_angle]
            for chi_angle in defined_chis
        ]
        # Pad with dummy rows so every residue type has exactly 4 chi entries.
        per_chi.extend([0, 0, 0, 0] for _ in range(4 - len(per_chi)))
        indices_per_restype.append(per_chi)

    # All-zero entries for the UNKNOWN residue type.
    indices_per_restype.append([[0, 0, 0, 0]] * 4)

    return indices_per_restype
def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
......@@ -88,320 +62,13 @@ def atom14_to_atom37(atom14, batch):
return atom37_data
def atom37_to_torsion_angles(
aatype: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
eps: float = 1e-8,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
aatype:
[*, N_res] residue indices
all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
Dictionary of the following features:
"torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype = torch.clamp(aatype, max=20)
pad = all_atom_positions.new_zeros(
[*all_atom_positions.shape[:-3], 1, 37, 3]
)
prev_all_atom_positions = torch.cat(
[pad, all_atom_positions[..., :-1, :, :]], dim=-3
)
pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
pre_omega_atom_pos = torch.cat(
[
prev_all_atom_positions[..., 1:3, :],
all_atom_positions[..., :2, :]
], dim=-2
)
phi_atom_pos = torch.cat(
[
prev_all_atom_positions[..., 2:3, :],
all_atom_positions[..., :3, :]
], dim=-2
)
psi_atom_pos = torch.cat(
[
all_atom_positions[..., :3, :],
all_atom_positions[..., 4:5, :]
], dim=-2
)
pre_omega_mask = (
torch.prod(prev_all_atom_mask[..., 1:3], dim=-1) *
torch.prod(all_atom_mask[..., :2], dim=-1)
)
phi_mask = (
prev_all_atom_mask[..., 2] *
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
)
psi_mask = (
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) *
all_atom_mask[..., 4]
)
chi_atom_indices = torch.as_tensor(
get_chi_atom_indices(), device=aatype.device
)
atom_indices = chi_atom_indices[..., aatype, :, :]
chis_atom_pos = batched_gather(
all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
)
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0., 0., 0., 0.])
chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
chis_mask = chi_angles_mask[aatype, :]
chi_angle_atoms_mask = batched_gather(
all_atom_mask,
atom_indices,
dim=-1,
no_batch_dims=len(atom_indices.shape[:-2])
)
chi_angle_atoms_mask = torch.prod(
chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
)
chis_mask = chis_mask * chi_angle_atoms_mask
torsions_atom_pos = torch.cat(
[
pre_omega_atom_pos[..., None, :, :],
phi_atom_pos[..., None, :, :],
psi_atom_pos[..., None, :, :],
chis_atom_pos,
], dim=-3
)
torsion_angles_mask = torch.cat(
[
pre_omega_mask[..., None],
phi_mask[..., None],
psi_mask[..., None],
chis_mask,
], dim=-1
)
torsion_frames = T.from_3_points(
torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :],
eps=eps,
)
fourth_atom_rel_pos = torsion_frames.invert().apply(
torsions_atom_pos[..., 3, :]
)
torsion_angles_sin_cos = torch.stack(
[fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
)
denom = torch.sqrt(
torch.sum(
torch.square(torsion_angles_sin_cos),
dim=-1,
dtype=torsion_angles_sin_cos.dtype,
keepdims=True
) + eps
)
torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
[1., 1., -1., 1., 1., 1., 1.],
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
rc.chi_pi_periodic,
)[aatype, ...]
mirror_torsion_angles = torch.cat(
[
all_atom_mask.new_ones(*aatype.shape, 3),
1. - 2. * chi_is_ambiguous
], dim=-1
)
def build_template_angle_feat(template_feats):
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = (
torsion_angles_sin_cos * mirror_torsion_angles[..., None]
)
return {
"torsion_angles_sin_cos": torsion_angles_sin_cos,
"alt_torsion_angles_sin_cos": alt_torsion_angles_sin_cos,
"torsion_angles_mask": torsion_angles_mask,
}
def atom37_to_frames(
aatype: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
eps: float,
**kwargs,
) -> Dict[str, torch.Tensor]:
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
for restype, restype_letter in enumerate(rc.restypes):
resname = rc.restype_1to3[restype_letter]
for chi_idx in range(4):
if(rc.chi_angles_mask[restype][chi_idx]):
names = rc.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :
] = names[1:]
restype_rigidgroup_mask = all_atom_mask.new_zeros(
(*aatype.shape[:-1], 21, 8),
)
restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1
restype_rigidgroup_mask[..., :20, 4:] = (
all_atom_mask.new_tensor(rc.chi_angles_mask)
)
lookuptable = rc.atom_order.copy()
lookuptable[''] = 0
lookup = np.vectorize(lambda x: lookuptable[x])
restype_rigidgroup_base_atom37_idx = lookup(
restype_rigidgroup_base_atom_names,
)
restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
restype_rigidgroup_base_atom37_idx,
)
restype_rigidgroup_base_atom37_idx = (
restype_rigidgroup_base_atom37_idx.view(
*((1,) * batch_dims),
*restype_rigidgroup_base_atom37_idx.shape
)
)
residx_rigidgroup_base_atom37_idx = batched_gather(
restype_rigidgroup_base_atom37_idx,
aatype,
dim=-3,
no_batch_dims=batch_dims,
template_feats["template_alt_torsion_angles_sin_cos"]
)
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = T.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
group_exists = batched_gather(
restype_rigidgroup_mask,
aatype,
dim=-2,
no_batch_dims=batch_dims,
)
gt_atoms_exist = batched_gather(
all_atom_mask,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1])
)
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
gt_frames = gt_frames.compose(T(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
)
restype_rigidgroup_rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots,
(*((1,) * batch_dims), 21, 8, 1, 1),
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[
rc.restype_3to1[resname]
]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
residx_rigidgroup_is_ambiguous = batched_gather(
restype_rigidgroup_is_ambiguous,
aatype,
dim=-2,
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = batched_gather(
restype_rigidgroup_rots,
aatype,
dim=-4,
no_batch_dims=batch_dims,
)
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))
gt_frames_tensor = gt_frames.to_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4()
return {
'rigidgroups_gt_frames': gt_frames_tensor,
'rigidgroups_gt_exists': gt_exists,
'rigidgroups_group_exists': group_exists,
'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous,
'rigidgroups_alt_gt_frames': alt_gt_frames_tensor,
}
def build_template_angle_feat(angle_feats, template_aatype):
torsion_angles_sin_cos = angle_feats["torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = angle_feats["alt_torsion_angles_sin_cos"]
torsion_angles_mask = angle_feats["torsion_angles_mask"]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
nn.functional.one_hot(template_aatype, 22),
......@@ -465,7 +132,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
eps + torch.sum(affine_vec ** 2, dim=-1)
)
t_aa_masks = batch["template_all_atom_masks"]
t_aa_masks = batch["template_all_atom_mask"]
template_mask = (
t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
)
......@@ -534,53 +201,6 @@ def build_msa_feat(batch):
return batch
def build_ambiguity_feats(
    batch: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """
    Compute features required by compute_renamed_ground_truth (Alg. 26)

    Args:
        batch:
            str/tensor dictionary containing:
            * atom14_gt_positions: [*, N, 14, 3] ground truth pos.
            * atom14_gt_exists: [*, N, 14] atom mask
            * aatype: [*, N] residue indices
    Returns:
        str/tensor dictionary containing:
        * atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
        * atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
        * atom14_alt_gt_exists: [*, N, 14] renamed atom mask
    """
    # FIX: the return annotation was "-> None" although the function returns
    # a dict (as its own docstring already stated).
    gt_positions = batch["atom14_gt_positions"]

    # Per-residue-type ambiguity table, gathered per residue via aatype.
    ambiguous_atoms = gt_positions.new_tensor(
        rc.restype_atom14_ambiguous_atoms
    )
    atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]

    # Swap pairs of ambiguous positions
    swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
    swap_mat = np.eye(swap_idx.shape[-1])[swap_idx]  # one-hot swap_idx
    swap_mat = gt_positions.new_tensor(swap_mat)
    swap_mat = swap_mat[batch["aatype"], ...]

    # Apply the per-residue permutation matrix to positions and masks.
    atom14_alt_gt_positions = torch.sum(
        gt_positions[..., None, :] * swap_mat[..., None],
        dim=-3,
    )
    atom14_alt_gt_exists = torch.sum(
        batch["atom14_gt_exists"][..., None] * swap_mat,
        dim=-2,
    )

    return {
        "atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
        "atom14_alt_gt_positions": atom14_alt_gt_positions,
        "atom14_alt_gt_exists": atom14_alt_gt_exists,
    }
def torsion_angles_to_frames(
t: T,
alpha: torch.Tensor,
......
......@@ -18,6 +18,7 @@ import ml_collections
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
......@@ -117,7 +118,9 @@ def compute_fape(
return normed_error
# DISCREPANCY: figure out if loss clamping happens in 90% of each bach or in 90% of batches
# DISCREPANCY: From the way this function is written, it's possible that
# DeepMind clamped 90% of individual residue losses, not 90% of all batches.
# We defer to the text, which seems to imply the latter.
def backbone_loss(
backbone_affine_tensor: torch.Tensor,
backbone_affine_mask: torch.Tensor,
......@@ -130,7 +133,7 @@ def backbone_loss(
) -> torch.Tensor:
pred_aff = T.from_tensor(traj)
gt_aff = T.from_tensor(backbone_affine_tensor)
fape_loss = compute_fape(
pred_aff,
gt_aff[..., None, :],
......@@ -142,7 +145,6 @@ def backbone_loss(
length_scale=loss_unit_distance,
eps=eps,
)
if(use_clamped_fape is not None):
unclamped_fape_loss = compute_fape(
pred_aff,
......@@ -157,12 +159,12 @@ def backbone_loss(
)
fape_loss = (
fape_loss * use_clamped_fape +
fape_loss * use_clamped_fape +
unclamped_fape_loss * (1 - use_clamped_fape)
)
# Take the mean over the layer dimension
fape_loss = torch.mean(fape_loss, dim=0)
fape_loss = torch.mean(fape_loss, dim=-1)
return fape_loss
......@@ -1453,7 +1455,7 @@ class AlphaFoldLoss(nn.Module):
super(AlphaFoldLoss, self).__init__()
self.config = config
def forward(self, out, batch):
def forward(self, out, batch):
if("violation" not in out.keys() and self.config.violation.weight):
out["violation"] = find_structural_violations(
batch,
......@@ -1461,41 +1463,12 @@ class AlphaFoldLoss(nn.Module):
**self.config.violation,
)
if("atom14_atom_is_ambiguous" not in batch.keys()):
batch.update(feats.build_ambiguity_feats(batch))
if("renamed_atom14_gt_positions" not in out.keys()):
batch.update(compute_renamed_ground_truth(
batch,
out["sm"]["positions"][-1],
))
if("backbone_affine_tensor" not in batch.keys()):
batch.update(feats.atom37_to_frames(eps=self.config.eps, **batch))
# TODO: Verify that this is correct
batch["backbone_affine_tensor"] = (
batch["rigidgroups_gt_frames"][..., 0, :, :]
)
batch["backbone_affine_mask"] = (
batch["rigidgroups_gt_exists"][..., 0]
)
if("chi_angles_sin_cos" not in batch.keys()):
with torch.no_grad():
batch.update(feats.atom37_to_torsion_angles(
aatype=batch["aatype"],
all_atom_positions=batch["all_atom_positions"].double(),
all_atom_mask=batch["all_atom_mask"].double(),
eps=self.config.eps,
))
# TODO: Verify that this is correct
batch["chi_angles_sin_cos"] = (
batch["torsion_angles_sin_cos"][..., 3:, :]
).to(batch["all_atom_mask"].dtype)
batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:].to(batch["all_atom_mask"].dtype)
loss_fns = {
"distogram":
lambda: distogram_loss(
......
......@@ -15,17 +15,17 @@
import argparse
from datetime import date
import pickle
import logging
import os
# A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import pickle
import random
import sys
from openfold.features import templates, feature_pipeline
from openfold.features.np import data_pipeline
from openfold.features import templates, feature_pipeline, data_pipeline
import time
......@@ -43,28 +43,29 @@ from openfold.utils.tensor_utils import (
tensor_tree_map,
)
MAX_TEMPLATE_HITS = 20
from scripts.utils import add_data_args
def main(args):
config = model_config(args.model_name)
model = AlphaFold(config.model)
model = model.eval()
import_jax_weights_(model, args.param_path)
model = model.to(args.device)
model = model.to(args.model_device)
# FEATURE COLLECTION AND PROCESSING
use_small_bfd = args.preset == "reduced_dbs"
num_ensemble = 1
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=MAX_TEMPLATE_HITS,
max_hits=args.max_template_hits,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
use_small_bfd=(args.bfd_database_path is None)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
......@@ -76,6 +77,7 @@ def main(args):
small_bfd_database_path=args.small_bfd_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
data_processor = data_pipeline.DataPipeline(
......@@ -87,7 +89,7 @@ def main(args):
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
config.data.eval.num_ensemble = num_ensemble
config.data.predict.num_ensemble = num_ensemble
feature_processor = feature_pipeline.FeaturePipeline(config)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
......@@ -95,7 +97,7 @@ def main(args):
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
print("Generating features...")
logging.info("Generating features...")
alignment_runner.run(
args.fasta_path, alignment_dir
)
......@@ -105,42 +107,20 @@ def main(args):
)
processed_feature_dict = feature_processor.process_features(
feature_dict, random_seed
feature_dict, mode='predict',
)
for k, v in processed_feature_dict.items():
print(k)
print(v.shape)
print("Executing model...")
logging.info("Executing model...")
batch = processed_feature_dict
with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.device)
k:torch.as_tensor(v, device=args.model_device)
for k,v in batch.items()
}
longs = [
"aatype",
"template_aatype",
"extra_msa",
"residx_atom37_to_atom14",
"residx_atom14_to_atom37",
"true_msa",
"residue_index",
]
for l in longs:
batch[l] = batch[l].long()
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
make_contig = lambda t: t.contiguous()
batch = tensor_tree_map(make_contig, batch)
t = time.time()
out = model(batch)
print(f"Inference time: {time.time() - t}")
logging.info(f"Inference time: {time.time() - t}")
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
......@@ -158,9 +138,7 @@ def main(args):
result=out,
b_factors=plddt_b_factors
)
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
amber_relaxer = relax.AmberRelaxation(
**config.relax
)
......@@ -168,7 +146,7 @@ def main(args):
# Relax the prediction.
t = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.time() - t}")
logging.info(f"Relaxation time: {time.time() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
......@@ -183,53 +161,14 @@ if __name__ == "__main__":
parser.add_argument(
"fasta_path", type=str,
)
parser.add_argument(
'uniref90_database_path', type=str,
)
parser.add_argument(
'mgnify_database_path', type=str,
)
parser.add_argument(
'pdb70_database_path', type=str,
)
parser.add_argument(
'template_mmcif_dir', type=str,
)
parser.add_argument(
'--uniclust30_database_path', type=str,
)
parser.add_argument(
'--bfd_database_path', type=str, default=None,
)
parser.add_argument(
'--small_bfd_database_path', type=str, default=None
)
parser.add_argument(
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
)
parser.add_argument(
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
)
parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
)
parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
)
parser.add_argument(
'--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None
)
add_data_args(parser)
parser.add_argument(
"--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
required=True
)
parser.add_argument(
"--device", type=str, default="cpu",
"--model_device", type=str, default="cpu",
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")"""
)
......@@ -244,6 +183,10 @@ if __name__ == "__main__":
automatically according to the model name from
openfold/resources/params"""
)
parser.add_argument(
"--cpus", type=int, default=4,
help="""Number of CPUs to use to run alignment tools"""
)
parser.add_argument(
'--preset', type=str, default='full_dbs',
choices=('reduced_dbs', 'full_dbs')
......
......@@ -15,6 +15,7 @@
import argparse
import json
parser = argparse.ArgumentParser(description='''Outputs a DeepSpeed
configuration file to
stdout''')
......
import argparse
import logging
import os
import tempfile

import openfold.features.mmcif_parsing as mmcif_parsing
from openfold.features import parsers
from openfold.features.data_pipeline import AlignmentRunner

from scripts.utils import add_data_args
def main(args):
    """Precompute alignments for every mmCIF/FASTA file in args.input_dir.

    For each chain of each mmCIF file (and each single-sequence FASTA file),
    runs the alignment tools and writes their output to a per-chain
    subdirectory of args.output_dir. Chains whose alignment directory
    already exists are skipped.
    """
    # Build the alignment tool runner
    alignment_runner = AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        small_bfd_database_path=args.small_bfd_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus,
    )

    for f in os.listdir(args.input_dir):
        path = os.path.join(args.input_dir, f)
        is_mmcif = f.endswith('.cif')
        is_fasta = f.endswith('.fasta')
        file_id = os.path.splitext(f)[0]
        seqs = {}
        if(is_mmcif):
            with open(path, 'r') as fp:
                mmcif_str = fp.read()
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_str
            )
            if(mmcif.mmcif_object is None):
                logging.warning(f'Failed to parse {f}...')
                if(args.raise_errors):
                    raise list(mmcif.errors.values())[0]
                else:
                    continue
            mmcif = mmcif.mmcif_object
            for k,v in mmcif.chain_to_seqres.items():
                chain_id = '_'.join([file_id, k])
                seqs[chain_id] = v
        elif(is_fasta):
            with open(path, 'r') as fp:
                fasta_str = fp.read()
            input_seqs, _ = parsers.parse_fasta(fasta_str)
            if len(input_seqs) != 1:
                msg = f'More than one input_sequence found in {f}'
                if(args.raise_errors):
                    raise ValueError(msg)
                else:
                    logging.warning(msg)
            input_sequence = input_seqs[0]
            seqs[file_id] = input_sequence
        else:
            continue

        for name, seq in seqs.items():
            alignment_dir = os.path.join(args.output_dir, name)
            if(os.path.isdir(alignment_dir)):
                logging.info(f'{f} has already been processed. Skipping...')
                continue

            os.makedirs(alignment_dir)

            if(not is_fasta):
                # mmCIF chains have no FASTA file on disk; write the chain
                # sequence to a temporary FASTA for the alignment tools.
                fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
                with os.fdopen(fd, 'w') as fp:
                    fp.write(f'>query\n{seq}')

            # BUG FIX: the FASTA branch previously passed the bare filename
            # f, which only resolves when the working directory happens to
            # be args.input_dir. Use the full path instead.
            alignment_runner.run(
                path if is_fasta else fasta_path, alignment_dir
            )

            if(not is_fasta):
                os.remove(fasta_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"input_dir", type=str,
help="Path to directory containing mmCIF and/or FASTA files"
)
parser.add_argument(
"output_dir", type=str,
help="Directory in which to output alignments"
)
add_data_args(parser)
parser.add_argument(
"--raise_errors", type=bool, default=False,
help="Whether to crash on parsing errors"
)
parser.add_argument(
"--cpus", type=int, default=4,
help="Number of CPUs to use"
)
args = parser.parse_args()
main(args)
import argparse
from datetime import date
def add_data_args(parser: argparse.ArgumentParser):
    """Register the standard database/tool-path arguments on *parser*.

    Positional arguments name the required genetic databases and the
    template mmCIF directory; optional flags cover alternate databases,
    alignment tool binaries, and template search settings.
    """
    for database_arg in (
        'uniref90_database_path',
        'mgnify_database_path',
        'pdb70_database_path',
        'template_mmcif_dir',
        'uniclust30_database_path',
    ):
        parser.add_argument(database_arg, type=str)

    parser.add_argument('--bfd_database_path', type=str, default=None)
    parser.add_argument('--small_bfd_database_path', type=str, default=None)

    # Default binary locations assume a standard system-wide install.
    for tool in ('jackhmmer', 'hhblits', 'hhsearch', 'kalign'):
        parser.add_argument(
            f'--{tool}_binary_path', type=str, default=f'/usr/bin/{tool}'
        )

    parser.add_argument(
        '--max_template_date', type=str,
        default=date.today().strftime("%Y-%m-%d"),
    )
    parser.add_argument('--max_template_hits', type=int, default=20)
    parser.add_argument('--obsolete_pdbs_path', type=str, default=None)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,"
import importlib
import pkgutil
import sys
......
......@@ -16,6 +16,7 @@ import torch
import numpy as np
import unittest
import openfold.features.data_transforms as data_transforms
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
......@@ -168,7 +169,7 @@ class TestFeats(unittest.TestCase):
to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
batch = tree_map(to_tensor, batch, np.ndarray)
out_repro = feats.atom37_to_frames(eps=1e-8, **batch)
out_repro = data_transforms.atom37_to_frames(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
for k,v in out_gt.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment