"vscode:/vscode.git/clone" did not exist on "e9cb035ac72b1b3f88635242db5e773da66d93b8"
Commit d48c052c authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add training parsers

parent eeda001c
This diff is collapsed.
import os
import datetime
import numpy as np
from typing import Mapping, Optional, Sequence
from typing import Mapping, Optional, Sequence, Any
from openfold.features import templates, parsers
from openfold.features import templates, parsers, mmcif_parsing
from openfold.features.np import jackhmmer, hhblits, hhsearch
from openfold.features.np.utils import to_date
from openfold.np import residue_constants
FeatureDict = Mapping[str, np.ndarray]
FeatureDict = Mapping[str, np.ndarray]
def make_sequence_features(sequence: str, description: str, num_res: int) -> FeatureDict:
def make_sequence_features(
sequence: str,
description: str,
num_res: int
) -> FeatureDict:
"""Construct a feature dict of sequence features."""
features = {}
features['aatype'] = residue_constants.sequence_to_onehot(
......@@ -19,13 +25,50 @@ def make_sequence_features(sequence: str, description: str, num_res: int) -> Fea
map_unknown_to_x=True
)
features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
features['domain_name'] = np.array([description.encode('utf-8')], dtype=np.object_)
features['domain_name'] = np.array(
[description.encode('utf-8')], dtype=np.object_
)
features['residue_index'] = np.array(range(num_res), dtype=np.int32)
features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
features['sequence'] = np.array([sequence.encode('utf-8')], dtype=np.object_)
features['sequence'] = np.array(
[sequence.encode('utf-8')], dtype=np.object_
)
return features
def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject,
    chain_id: str
) -> FeatureDict:
    """Assembles the feature dict for a single chain of a parsed mmCIF.

    Combines the basic sequence features with the atom37 coordinates/mask
    extracted from the structure, plus resolution and release-date metadata
    taken from the mmCIF header.
    """
    seqres = mmcif_object.chain_to_seqres[chain_id]
    desc = '_'.join([mmcif_object.file_id, chain_id])

    # Start from the generic sequence features for this chain.
    feats = dict(
        make_sequence_features(
            sequence=seqres,
            description=desc,
            num_res=len(seqres),
        )
    )

    coords, coord_mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
    )
    feats["all_atom_positions"] = coords
    feats["all_atom_mask"] = coord_mask

    feats["resolution"] = np.array(
        [mmcif_object.header["resolution"]], dtype=np.float32
    )
    feats["release_date"] = np.array(
        [mmcif_object.header["release_date"].encode('utf-8')],
        dtype=np.object_
    )

    return feats
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
......@@ -58,9 +101,9 @@ def make_msa_features(
)
return features
class DataPipeline:
"""Runs the alignment tools and assembles the input features."""
class AlignmentRunner:
""" Runs alignment tools and saves the results """
def __init__(self,
jackhmmer_binary_path: str,
hhblits_binary_path: str,
......@@ -71,106 +114,158 @@ class DataPipeline:
uniclust30_database_path: Optional[str],
small_bfd_database_path: Optional[str],
pdb70_database_path: str,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
mgnify_max_hits: int = 501,
uniref_max_hits: int = 10000
no_cpus: int,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
):
"""Constructs a feature dict for a given FASTA file."""
self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=uniref90_database_path
database_path=uniref90_database_path,
n_cpu=no_cpus,
)
if use_small_bfd:
self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=small_bfd_database_path
database_path=small_bfd_database_path,
n_cpu=no_cpus,
)
else:
self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
binary_path=hhblits_binary_path,
databases=[bfd_database_path, uniclust30_database_path]
databases=[bfd_database_path, uniclust30_database_path],
n_cpu=no_cpus,
)
self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
binary_path=jackhmmer_binary_path,
database_path=mgnify_database_path
database_path=mgnify_database_path,
n_cpu=no_cpus,
)
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path]
)
self.template_featurizer = template_featurizer
self.mgnify_max_hits = mgnify_max_hits
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
def process(self, input_fasta_path: str, msa_output_dir: str) -> FeatureDict:
"""Runs alignment tools on the input sequence and creates features."""
with open(input_fasta_path) as f:
input_fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {input_fasta_path}.'
)
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(input_fasta_path)[0]
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(input_fasta_path)[0]
def run(self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(fasta_path)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits
)
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto')
uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'w') as f:
f.write(jackhmmer_uniref90_result['sto'])
f.write(uniref90_msa_as_a3m)
mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.so')
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(fasta_path)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits
)
mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'w') as f:
f.write(jackhmmer_mgnify_result['sto'])
f.write(mgnify_msa_as_a3m)
pdb70_out_path = os.path.join(msa_output_dir, 'pdb70_hits.hhr')
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'w') as f:
f.write(hhsearch_result)
uniref90_msa, uniref90_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_uniref90_result['sto']
)
mgnify_msa, mgnify_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_mgnify_result['sto']
)
hhsearch_hits = parsers.parse_hhr(hhsearch_result)
mgnify_msa = mgnify_msa[:self.mgnify_max_hits]
mgnify_deletion_matrix = mgnify_deletion_matrix[:self.mgnify_max_hits]
if self._use_small_bfd:
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(input_fasta_path)[0]
bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.a3m')
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(fasta_path)[0]
bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'w') as f:
f.write(jackhmmer_small_bfd_result['sto'])
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
jackhmmer_small_bfd_result['sto']
)
else:
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(input_fasta_path)
bfd_out_path = os.path.join(msa_output_dir, 'bfd_uniclust_hits.a3m')
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path)
if(output_dir is not None):
bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'w') as f:
f.write(hhblits_bfd_uniclust_result['a3m'])
class DataPipeline:
"""Assembles input features."""
def __init__(self,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
):
self.template_featurizer = template_featurizer
self.use_small_bfd = use_small_bfd
def _parse_alignment_output(self,
alignment_dir: str,
) -> Mapping[str, Any]:
uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'r') as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(
f.read()
)
mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'r') as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(
f.read()
)
pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'r') as f:
hhsearch_hits = parsers.parse_hhr(
f.read()
)
if(self.use_small_bfd):
bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'r') as f:
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
f.read()
)
else:
bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'r') as f:
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
hhblits_bfd_uniclust_result['a3m']
f.read()
)
return {
'uniref90_msa': uniref90_msa,
'uniref90_deletion_matrix': uniref90_deletion_matrix,
'mgnify_msa': mgnify_msa,
'mgnify_deletion_matrix': mgnify_deletion_matrix,
'hhsearch_hits': hhsearch_hits,
'bfd_msa': bfd_msa,
'bfd_deletion_matrix': bfd_deletion_matrix,
}
def process_fasta(self,
fasta_path: str,
alignment_dir: str,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
alignments = self._parse_alignment_output(alignment_dir)
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=None,
hits=hhsearch_hits
hits=alignments['hhsearch_hits']
)
sequence_features = make_sequence_features(
......@@ -180,9 +275,62 @@ class DataPipeline:
)
msa_features = make_msa_features(
msas=(uniref90_msa, bfd_msa, mgnify_msa),
deletion_matrices = (uniref90_deletion_matrix,
bfd_deletion_matrix,
mgnify_deletion_matrix)
msas=(
alignments['uniref90_msa'],
alignments['bfd_msa'],
alignments['mgnify_msa']
),
deletion_matrices=(
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
)
return {**sequence_features, **msa_features, **templates_result.features}
def process_mmcif(self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown.
"""
if(chain_id is None):
chains = mmcif.structure.get_chains()
chain = next(chains, None)
if(chain is None):
raise ValueError(
'No chains in mmCIF file'
)
chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id)
alignments = self._parse_alignment_output(alignment_dir)
input_sequence = mmcif.chain_to_seqres[chain_id]
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]),
hits=alignments['hhsearch_hits']
)
msa_features = make_msa_features(
msas=(
alignments['uniref90_msa'],
alignments['bfd_msa'],
alignments['mgnify_msa']
),
deletion_matrices = (
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
)
return {**mmcif_feats, **templates_result.features, **msa_features}
This diff is collapsed.
......@@ -25,39 +25,67 @@ def np_to_tensor_dict(
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
tensor_dict = {k: torch.tensor(v) for k, v in np_example.items() if k in features}
tensor_dict = {
k: torch.tensor(v) for k, v in np_example.items() if k in features
}
return tensor_dict
def make_data_config(
config: ml_collections.ConfigDict,
mode: str,
num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
cfg = copy.deepcopy(config.data)
) -> Tuple[ml_collections.ConfigDict, List[str]]:
cfg = copy.deepcopy(config)
mode_cfg = cfg[mode]
with cfg.unlocked():
if(mode_cfg.crop_size is None):
mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
with cfg.unlocked():
cfg.eval.crop_size = num_res
if(cfg[mode].supervised):
feature_names += cfg.common.supervised_features
return cfg, feature_names
def np_example_to_features(np_example: FeatureDict,
def np_example_to_features(
np_example: FeatureDict,
config: ml_collections.ConfigDict,
random_seed: int = 0):
mode: str,
batch_mode: str,
):
np_example = dict(np_example)
num_res = int(np_example['seq_length'][0])
cfg, feature_names = make_data_config(config, num_res=num_res)
cfg, feature_names = make_data_config(
config, mode=mode, num_res=num_res
)
if 'deletion_matrix_int' in np_example:
np_example['deletion_matrix'] = (
np_example.pop('deletion_matrix_int').astype(np.float32))
np_example.pop('deletion_matrix_int').astype(np.float32)
)
if batch_mode == 'clamped':
np_example['use_clamped_fape'] = (
np.array(1.).astype(np.float32)
)
elif batch_mode == 'unclamped':
np_example['use_clamped_fape'] = (
np.array(0.).astype(np.float32)
)
torch.manual_seed(random_seed)
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names)
features = input_pipeline.process_tensors_from_config(tensor_dict, cfg)
np_example=np_example, features=feature_names
)
with torch.no_grad():
features = input_pipeline.process_tensors_from_config(
tensor_dict, cfg.common, cfg[mode], batch_mode=batch_mode,
)
return {k: v for k, v in features.items()}
......@@ -71,9 +99,12 @@ class FeaturePipeline:
def process_features(self,
raw_features: FeatureDict,
random_seed: int) -> FeatureDict:
mode: str = 'train',
batch_mode: str = 'clamped',
) -> FeatureDict:
return np_example_to_features(
np_example=raw_features,
config=self.config,
random_seed=random_seed
mode=mode,
batch_mode=batch_mode,
)
from functools import partial
import torch
from openfold.features import data_transforms
def nonensembled_transform_fns(data_config):
def nonensembled_transform_fns(common_cfg, mode_cfg):
"""Input pipeline data transformers that are not ensembled."""
common_cfg = data_config.common
transforms = [
data_transforms.cast_to_64bit_ints,
data_transforms.correct_msa_restypes,
......@@ -23,23 +22,36 @@ def nonensembled_transform_fns(data_config):
data_transforms.make_template_mask,
data_transforms.make_pseudo_beta('template_')
])
if(common_cfg.use_template_torsion_angles):
transforms.extend([
data_transforms.atom37_to_torsion_angles('template_'),
])
transforms.extend([
data_transforms.make_atom14_masks,
])
if(mode_cfg.supervised):
transforms.extend([
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(''),
data_transforms.make_pseudo_beta(''),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
])
return transforms
def ensembled_transform_fns(data_config):
def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
"""Input pipeline data transformers that can be ensembled and averaged."""
common_cfg = data_config.common
eval_cfg = data_config.eval
transforms = []
if common_cfg.reduce_msa_clusters_by_max_templates:
pad_msa_clusters = eval_cfg.max_msa_clusters - eval_cfg.max_templates
pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
else:
pad_msa_clusters = eval_cfg.max_msa_clusters
pad_msa_clusters = mode_cfg.max_msa_clusters
max_msa_clusters = pad_msa_clusters
max_extra_msa = common_cfg.max_extra_msa
......@@ -53,8 +65,10 @@ def ensembled_transform_fns(data_config):
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms.make_masked_msa(common_cfg.masked_msa,
eval_cfg.masked_msa_replace_fraction)
data_transforms.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction
)
)
if common_cfg.msa_cluster_features:
......@@ -69,44 +83,55 @@ def ensembled_transform_fns(data_config):
transforms.append(data_transforms.make_msa_feat())
crop_feats = dict(eval_cfg.feat)
crop_feats = dict(common_cfg.feat)
if eval_cfg.fixed_size:
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
transforms.append(data_transforms.random_crop_to_size(
mode_cfg.crop_size,
mode_cfg.max_templates,
crop_feats,
mode_cfg.subsample_templates,
batch_mode=batch_mode,
seed=torch.Generator().seed()
))
transforms.append(data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
eval_cfg.crop_size,
eval_cfg.max_templates
mode_cfg.crop_size,
mode_cfg.max_templates
))
else:
transforms.append(data_transforms.crop_templates(eval_cfg.max_templates))
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
)
return transforms
def process_tensors_from_config(tensors, data_config):
def process_tensors_from_config(
tensors, common_cfg, mode_cfg, batch_mode='clamped'
):
"""Based on the config, apply filters and transformations to the data."""
def wrap_ensemble_fn(data, i):
"""Function to be mapped over the ensemble dimension."""
d = data.copy()
fns = ensembled_transform_fns(data_config)
fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
fn = compose(fns)
d['ensemble_index'] = i
return fn(d)
eval_cfg = data_config.eval
tensors = compose(
nonensembled_transform_fns(data_config)
nonensembled_transform_fns(common_cfg, mode_cfg)
)(tensors)
tensors_0 = wrap_ensemble_fn(tensors, 0)
num_ensemble = eval_cfg.num_ensemble
if data_config.common.resample_msa_in_recycling:
num_ensemble = mode_cfg.num_ensemble
if common_cfg.resample_msa_in_recycling:
# Separate batch per ensembling & recycling step.
num_ensemble *= data_config.common.num_recycle + 1
num_ensemble *= common_cfg.num_recycle + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn(lambda x: wrap_ensemble_fn(tensors, x),
......@@ -116,16 +141,20 @@ def process_tensors_from_config(tensors, data_config):
return tensors
@data_transforms.curry1
def compose(x, fs):
    """Applies each transform in fs to x, in order, and returns the result."""
    result = x
    for transform in fs:
        result = transform(result)
    return result
def map_fn(fun, x):
ensembles = [fun(elem) for elem in x]
features = ensembles[0].keys()
ensembled_dict = {}
for feat in features:
ensembled_dict[feat] = torch.stack([dict_i[feat] for dict_i in ensembles])
ensembled_dict[feat] = torch.stack(
[dict_i[feat] for dict_i in ensembles], dim=-1
)
return ensembled_dict
"""Parses the mmCIF file format."""
import collections
import dataclasses
import io
import json
import logging
import os
from typing import Any, Mapping, Optional, Sequence, Tuple
from absl import logging
from Bio import PDB
from Bio.Data import SCOPData
import numpy as np
import openfold.np.residue_constants as residue_constants
# Type aliases:
ChainId = str
......@@ -369,3 +374,73 @@ def _get_protein_chains(
def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?')
def get_atom_coords(
    mmcif_object: MmcifObject,
    chain_id: str
) -> Tuple[np.ndarray, np.ndarray]:
    """Extracts atom37 coordinates and a matching mask for one chain.

    Returns a pair (positions, mask) of shapes
    [num_res, atom_type_num, 3] and [num_res, atom_type_num]; rows for
    missing residues (and slots for absent atoms) are left as zeros.
    """
    # Locate the right chain
    matching = [
        c for c in mmcif_object.structure.get_chains() if c.id == chain_id
    ]
    if len(matching) != 1:
        raise MultipleChainsError(
            f'Expected exactly one chain in structure with id {chain_id}.'
        )
    chain = matching[0]

    # Extract the coordinates
    n_res = len(mmcif_object.chain_to_seqres[chain_id])
    n_atom_types = residue_constants.atom_type_num
    all_atom_positions = np.zeros([n_res, n_atom_types, 3], dtype=np.float32)
    all_atom_mask = np.zeros([n_res, n_atom_types], dtype=np.float32)

    residue_info = mmcif_object.seqres_to_structure[chain_id]
    for res_index in range(n_res):
        entry = residue_info[res_index]
        if entry.is_missing:
            # Leave the zero row in place for unresolved residues.
            continue
        residue = chain[(
            entry.hetflag,
            entry.position.residue_number,
            entry.position.insertion_code,
        )]
        for atom in residue.get_atoms():
            atom_name = atom.get_name()
            x, y, z = atom.get_coord()
            if atom_name in residue_constants.atom_order.keys():
                col = residue_constants.atom_order[atom_name]
                all_atom_positions[res_index, col] = [x, y, z]
                all_atom_mask[res_index, col] = 1.0
            elif atom_name.upper() == 'SE' and residue.get_resname() == 'MSE':
                # Put the coords of the selenium atom in the sulphur column
                col = residue_constants.atom_order['SD']
                all_atom_positions[res_index, col] = [x, y, z]
                all_atom_mask[res_index, col] = 1.0

    return all_atom_positions, all_atom_mask
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    """Scans mmcif_dir for .cif files and writes a JSON cache to out_path.

    The cache maps each file ID to its release date and chain count;
    files that fail to parse are logged and skipped.
    """
    data = {}
    for fname in os.listdir(mmcif_dir):
        if(not fname.endswith('.cif')):
            continue
        with open(os.path.join(mmcif_dir, fname), 'r') as fp:
            mmcif_string = fp.read()
        file_id = os.path.splitext(fname)[0]
        result = parse(file_id=file_id, mmcif_string=mmcif_string)
        if(result.mmcif_object is None):
            logging.warning(f'Could not parse {fname}. Skipping...')
            continue
        parsed = result.mmcif_object
        data[file_id] = {
            'release_date': parsed.header["release_date"],
            'no_chains': len(list(parsed.structure.get_chains())),
        }
    with open(out_path, 'w') as fp:
        fp.write(json.dumps(data))
......@@ -18,6 +18,7 @@ class HHSearch:
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 2,
maxseq: int = 1_000_000):
"""Initializes the Python HHsearch wrapper.
......@@ -26,6 +27,7 @@ class HHSearch:
databases: A sequence of HHsearch database paths. This should be the
common prefix for the database files (i.e. up to but not including
_hhm.ffindex etc.)
n_cpu: The number of CPUs to use
maxseq: The maximum number of rows in an input alignment. Note that this
parameter is only supported in HHBlits version 3.1 and higher.
......@@ -34,6 +36,7 @@ class HHSearch:
"""
self.binary_path = binary_path
self.databases = databases
self.n_cpu = n_cpu
self.maxseq = maxseq
for database_path in self.databases:
......@@ -56,7 +59,8 @@ class HHSearch:
cmd = [self.binary_path,
'-i', input_path,
'-o', hhr_path,
'-maxseq', str(self.maxseq)
'-maxseq', str(self.maxseq),
'-cpu', str(self.n_cpu),
] + db_cmd
logging.info('Launching subprocess "%s"', ' '.join(cmd))
......
......@@ -3,14 +3,12 @@
from concurrent import futures
import glob
import logging
import os
import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request
from absl import logging
from openfold.features.np import utils
......
"""Common utilities for data pipeline tools."""
import contextlib
import datetime
import shutil
import tempfile
import time
......@@ -25,3 +26,9 @@ def timing(msg: str):
yield
toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic)
def to_date(s: str):
    """Parses the leading 'YYYY-MM-DD' of a date string into a datetime.

    Only the first ten characters are read, so trailing time components
    (e.g. 'YYYY-MM-DD HH:MM') are ignored.
    """
    year, month, day = s[:4], s[5:7], s[8:10]
    return datetime.datetime(int(year), int(month), int(day))
......@@ -2,16 +2,17 @@
import dataclasses
import datetime
import glob
import json
import logging
import os
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple
from absl import logging
import numpy as np
from openfold.features import parsers, mmcif_parsing
from openfold.features.np import kalign
from openfold.features.np.utils import to_date
from openfold.np import residue_constants
......@@ -74,7 +75,7 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES = {
'template_aatype': np.int64,
'template_all_atom_masks': np.float32,
'template_all_atom_mask': np.float32,
'template_all_atom_positions': np.float32,
'template_domain_names': np.object,
'template_sequence': np.object,
......@@ -133,23 +134,40 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
return result
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
    """Scans mmcif_dir for .cif files and writes a JSON release-date cache.

    Args:
        mmcif_dir: Directory containing .cif files to parse.
        out_path: Path to which the JSON mapping {file_id: release_date}
            is written.

    Files that fail to parse are logged and skipped.
    """
    dates = {}
    for f in os.listdir(mmcif_dir):
        if(f.endswith('.cif')):
            path = os.path.join(mmcif_dir, f)
            with open(path, 'r') as fp:
                mmcif_string = fp.read()
            file_id = os.path.splitext(f)[0]
            mmcif = mmcif_parsing.parse(
                file_id=file_id, mmcif_string=mmcif_string
            )
            if(mmcif.mmcif_object is None):
                logging.warning(f'Failed to parse {f}. Skipping...')
                continue
            mmcif = mmcif.mmcif_object
            release_date = mmcif.header['release_date']
            dates[file_id] = release_date
    # Bug fix: the output file must be opened for writing ('w'), not
    # reading ('r') — the original mode raised on fp.write(). This also
    # matches the sibling generate_mmcif_cache.
    with open(out_path, 'w') as fp:
        fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
"""Parses release dates file, returns a mapping from PDBs to release dates."""
if path.endswith('txt'):
release_dates = {}
with open(path, 'r') as f:
for line in f:
pdb_id, date = line.split(':')
date = date.strip()
# Python 3.6 doesn't have datetime.date.fromisoformat() which is about
# 90x faster than strptime. However, splitting the string manually is
# about 10x faster than strptime.
release_dates[pdb_id.strip()] = datetime.datetime(
year=int(date[:4]), month=int(date[5:7]), day=int(date[8:10]))
return release_dates
else:
raise ValueError('Invalid format of the release date file %s.' % path)
with open(path, 'r') as fp:
data = json.load(fp)
return {
pdb:to_date(v) for pdb,d in data.items() for k,v in d.items()
if k == "release_date"
}
def _assess_hhsearch_hit(
hit: parsers.TemplateHit,
......@@ -419,42 +437,14 @@ def _get_atom_positions(
auth_chain_id: str,
max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
"""Gets atom positions and mask from a list of Biopython Residues."""
num_res = len(mmcif_object.chain_to_seqres[auth_chain_id])
relevant_chains = [c for c in mmcif_object.structure.get_chains()
if c.id == auth_chain_id]
if len(relevant_chains) != 1:
raise MultipleChainsError(
f'Expected exactly one chain in structure with id {auth_chain_id}.')
chain = relevant_chains[0]
all_positions = np.zeros([num_res, residue_constants.atom_type_num, 3])
all_positions_mask = np.zeros([num_res, residue_constants.atom_type_num],
dtype=np.int64)
for res_index in range(num_res):
pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[auth_chain_id][res_index]
if not res_at_position.is_missing:
res = chain[(res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code)]
for atom in res.get_atoms():
atom_name = atom.get_name()
x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
# Put the coordinates of the selenium atom in the sulphur column.
pos[residue_constants.atom_order['SD']] = [x, y, z]
mask[residue_constants.atom_order['SD']] = 1.0
all_positions[res_index] = pos
all_positions_mask[res_index] = mask
coords_with_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=auth_chain_id
)
all_atom_positions, all_atom_mask = coords_with_mask
_check_residue_distances(
all_positions, all_positions_mask, max_ca_ca_distance)
return all_positions, all_positions_mask
all_atom_positions, all_atom_mask, max_ca_ca_distance
)
return all_atom_positions, all_atom_mask
def _extract_template_features(
......@@ -579,7 +569,7 @@ def _extract_template_features(
return (
{
'template_all_atom_positions': np.array(templates_all_atom_positions),
'template_all_atom_masks': np.array(templates_all_atom_masks),
'template_all_atom_mask': np.array(templates_all_atom_masks),
'template_sequence': output_templates_sequence.encode(),
'template_aatype': np.array(templates_aatype),
'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
......
......@@ -19,7 +19,6 @@ import torch.nn as nn
from openfold.utils.feats import (
pseudo_beta_fn,
atom37_to_torsion_angles,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
......@@ -115,22 +114,17 @@ class AlphaFold(nn.Module):
batch,
)
# Build template angle feats
angle_feats = atom37_to_torsion_angles(
single_template_feats["template_aatype"],
single_template_feats["template_all_atom_positions"],#.float(),
single_template_feats["template_all_atom_masks"],#.float(),
eps=self.config.template.eps,
)
single_template_embeds = {}
if(self.config.template.embed_angles):
template_angle_feat = build_template_angle_feat(
angle_feats,
single_template_feats["template_aatype"],
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
......@@ -145,12 +139,12 @@ class AlphaFold(nn.Module):
_mask_trans=self.config._mask_trans
)
template_embeds.append({
"angle": a,
single_template_embeds.update({
"pair": t,
"torsion_mask": angle_feats["torsion_angles_mask"]
})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
......@@ -164,11 +158,15 @@ class AlphaFold(nn.Module):
)
t = t * (torch.sum(batch["template_mask"]) > 0)
return {
"template_angle_embedding": template_embeds["angle"],
ret = {}
if(self.config.template.embed_angles):
ret["template_angle_embedding"] = template_embeds["angle"]
ret.update({
"template_pair_embedding": t,
"torsion_angles_mask": template_embeds["torsion_mask"],
}
})
return ret
def iteration(self, feats, m_1_prev, z_prev, x_prev):
# Primary output dictionary
......@@ -197,7 +195,7 @@ class AlphaFold(nn.Module):
)
# Inject information from previous recycling iterations
if(self.config.no_cycles > 1):
if(self.config.num_recycle > 0):
# Initialize the recycling embeddings, if needs be
if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m]
......@@ -241,7 +239,7 @@ class AlphaFold(nn.Module):
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
template_feats = {
k:v for k,v in feats.items() if "template_" in k
k:v for k,v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
......@@ -261,7 +259,7 @@ class AlphaFold(nn.Module):
)
# [*, S, N]
torsion_angles_mask = template_embeds["torsion_angles_mask"]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
)
......@@ -374,7 +372,8 @@ class AlphaFold(nn.Module):
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
"template_all_atom_positions"
([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
......@@ -392,13 +391,13 @@ class AlphaFold(nn.Module):
self._disable_activation_checkpointing()
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
for cycle_no in range(self.config.num_recycle + 1):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == (self.config.no_cycles - 1))
is_final_iter = (cycle_no == self.config.num_recycle)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
# Sidestep AMP bug discussed in pytorch issue #65766
if(is_final_iter):
......
......@@ -29,6 +29,7 @@ class ExponentialMovingAverage:
self.decay = decay
def _update_state_dict_(self, update, state_dict):
with torch.no_grad():
for k, v in update.items():
stored = state_dict[k]
if(not isinstance(v, torch.Tensor)):
......
......@@ -49,32 +49,6 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
return pseudo_beta
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in rc.restypes:
residue_name = rc.restype_1to3[residue_name]
residue_chi_angles = rc.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[rc.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return chi_atom_indices
def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
......@@ -88,320 +62,13 @@ def atom14_to_atom37(atom14, batch):
return atom37_data
def atom37_to_torsion_angles(
aatype: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
eps: float = 1e-8,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
aatype:
[*, N_res] residue indices
all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
Dictionary of the following features:
"torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype = torch.clamp(aatype, max=20)
pad = all_atom_positions.new_zeros(
[*all_atom_positions.shape[:-3], 1, 37, 3]
)
prev_all_atom_positions = torch.cat(
[pad, all_atom_positions[..., :-1, :, :]], dim=-3
)
pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
pre_omega_atom_pos = torch.cat(
[
prev_all_atom_positions[..., 1:3, :],
all_atom_positions[..., :2, :]
], dim=-2
)
phi_atom_pos = torch.cat(
[
prev_all_atom_positions[..., 2:3, :],
all_atom_positions[..., :3, :]
], dim=-2
)
psi_atom_pos = torch.cat(
[
all_atom_positions[..., :3, :],
all_atom_positions[..., 4:5, :]
], dim=-2
)
pre_omega_mask = (
torch.prod(prev_all_atom_mask[..., 1:3], dim=-1) *
torch.prod(all_atom_mask[..., :2], dim=-1)
)
phi_mask = (
prev_all_atom_mask[..., 2] *
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
)
psi_mask = (
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) *
all_atom_mask[..., 4]
)
chi_atom_indices = torch.as_tensor(
get_chi_atom_indices(), device=aatype.device
)
atom_indices = chi_atom_indices[..., aatype, :, :]
chis_atom_pos = batched_gather(
all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
)
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0., 0., 0., 0.])
chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
chis_mask = chi_angles_mask[aatype, :]
chi_angle_atoms_mask = batched_gather(
all_atom_mask,
atom_indices,
dim=-1,
no_batch_dims=len(atom_indices.shape[:-2])
)
chi_angle_atoms_mask = torch.prod(
chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
)
chis_mask = chis_mask * chi_angle_atoms_mask
torsions_atom_pos = torch.cat(
[
pre_omega_atom_pos[..., None, :, :],
phi_atom_pos[..., None, :, :],
psi_atom_pos[..., None, :, :],
chis_atom_pos,
], dim=-3
)
torsion_angles_mask = torch.cat(
[
pre_omega_mask[..., None],
phi_mask[..., None],
psi_mask[..., None],
chis_mask,
], dim=-1
)
torsion_frames = T.from_3_points(
torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :],
eps=eps,
)
fourth_atom_rel_pos = torsion_frames.invert().apply(
torsions_atom_pos[..., 3, :]
)
torsion_angles_sin_cos = torch.stack(
[fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
)
denom = torch.sqrt(
torch.sum(
torch.square(torsion_angles_sin_cos),
dim=-1,
dtype=torsion_angles_sin_cos.dtype,
keepdims=True
) + eps
)
torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
[1., 1., -1., 1., 1., 1., 1.],
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
rc.chi_pi_periodic,
)[aatype, ...]
mirror_torsion_angles = torch.cat(
[
all_atom_mask.new_ones(*aatype.shape, 3),
1. - 2. * chi_is_ambiguous
], dim=-1
)
def build_template_angle_feat(template_feats):
template_aatype = template_feats["template_aatype"]
torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = (
torsion_angles_sin_cos * mirror_torsion_angles[..., None]
)
return {
"torsion_angles_sin_cos": torsion_angles_sin_cos,
"alt_torsion_angles_sin_cos": alt_torsion_angles_sin_cos,
"torsion_angles_mask": torsion_angles_mask,
}
def atom37_to_frames(
aatype: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
eps: float,
**kwargs,
) -> Dict[str, torch.Tensor]:
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
for restype, restype_letter in enumerate(rc.restypes):
resname = rc.restype_1to3[restype_letter]
for chi_idx in range(4):
if(rc.chi_angles_mask[restype][chi_idx]):
names = rc.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :
] = names[1:]
restype_rigidgroup_mask = all_atom_mask.new_zeros(
(*aatype.shape[:-1], 21, 8),
)
restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1
restype_rigidgroup_mask[..., :20, 4:] = (
all_atom_mask.new_tensor(rc.chi_angles_mask)
)
lookuptable = rc.atom_order.copy()
lookuptable[''] = 0
lookup = np.vectorize(lambda x: lookuptable[x])
restype_rigidgroup_base_atom37_idx = lookup(
restype_rigidgroup_base_atom_names,
)
restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
restype_rigidgroup_base_atom37_idx,
)
restype_rigidgroup_base_atom37_idx = (
restype_rigidgroup_base_atom37_idx.view(
*((1,) * batch_dims),
*restype_rigidgroup_base_atom37_idx.shape
)
)
residx_rigidgroup_base_atom37_idx = batched_gather(
restype_rigidgroup_base_atom37_idx,
aatype,
dim=-3,
no_batch_dims=batch_dims,
)
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = T.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
group_exists = batched_gather(
restype_rigidgroup_mask,
aatype,
dim=-2,
no_batch_dims=batch_dims,
template_feats["template_alt_torsion_angles_sin_cos"]
)
gt_atoms_exist = batched_gather(
all_atom_mask,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1])
)
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
gt_frames = gt_frames.compose(T(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
)
restype_rigidgroup_rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots,
(*((1,) * batch_dims), 21, 8, 1, 1),
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[
rc.restype_3to1[resname]
]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
residx_rigidgroup_is_ambiguous = batched_gather(
restype_rigidgroup_is_ambiguous,
aatype,
dim=-2,
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = batched_gather(
restype_rigidgroup_rots,
aatype,
dim=-4,
no_batch_dims=batch_dims,
)
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))
gt_frames_tensor = gt_frames.to_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4()
return {
'rigidgroups_gt_frames': gt_frames_tensor,
'rigidgroups_gt_exists': gt_exists,
'rigidgroups_group_exists': group_exists,
'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous,
'rigidgroups_alt_gt_frames': alt_gt_frames_tensor,
}
def build_template_angle_feat(angle_feats, template_aatype):
torsion_angles_sin_cos = angle_feats["torsion_angles_sin_cos"]
alt_torsion_angles_sin_cos = angle_feats["alt_torsion_angles_sin_cos"]
torsion_angles_mask = angle_feats["torsion_angles_mask"]
torsion_angles_mask = template_feats["template_torsion_angles_mask"]
template_angle_feat = torch.cat(
[
nn.functional.one_hot(template_aatype, 22),
......@@ -465,7 +132,7 @@ def build_template_pair_feat(batch, min_bin, max_bin, no_bins, eps=1e-20, inf=1e
eps + torch.sum(affine_vec ** 2, dim=-1)
)
t_aa_masks = batch["template_all_atom_masks"]
t_aa_masks = batch["template_all_atom_mask"]
template_mask = (
t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
)
......@@ -534,53 +201,6 @@ def build_msa_feat(batch):
return batch
def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
"""
Compute features required by compute_renamed_ground_truth (Alg. 26)
Args:
batch:
str/tensor dictionary containing:
* atom14_gt_positions: [*, N, 14, 3] ground truth pos.
* atom14_gt_exists: [*, N, 14] atom mask
* aatype: [*, N] residue indices
Returns:
str/tensor dictionary containing:
* atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
* atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
"""
ambiguous_atoms = (
batch["atom14_gt_positions"].new_tensor(
rc.restype_atom14_ambiguous_atoms
)
)
atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]
# Swap pairs of ambiguous positions
swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
swap_mat = np.eye(swap_idx.shape[-1])[swap_idx] # one-hot swap_idx
swap_mat = batch["atom14_gt_positions"].new_tensor(swap_mat)
swap_mat = swap_mat[batch["aatype"], ...]
atom14_alt_gt_positions = (
torch.sum(
batch["atom14_gt_positions"][..., None, :] * swap_mat[..., None],
dim=-3
)
)
atom14_alt_gt_exists = (
torch.sum(
batch["atom14_gt_exists"][..., None] * swap_mat, dim=-2
)
)
return {
"atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
"atom14_alt_gt_positions": atom14_alt_gt_positions,
"atom14_alt_gt_exists": atom14_alt_gt_exists,
}
def torsion_angles_to_frames(
t: T,
alpha: torch.Tensor,
......
......@@ -18,6 +18,7 @@ import ml_collections
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple
from openfold.np import residue_constants
......@@ -117,7 +118,9 @@ def compute_fape(
return normed_error
# DISCREPANCY: figure out if loss clamping happens in 90% of each bach or in 90% of batches
# DISCREPANCY: From the way this function is written, it's possible that
# DeepMind clamped 90% of individual residue losses, not 90% of all batches.
# We defer to the text, which seems to imply the latter.
def backbone_loss(
backbone_affine_tensor: torch.Tensor,
backbone_affine_mask: torch.Tensor,
......@@ -142,7 +145,6 @@ def backbone_loss(
length_scale=loss_unit_distance,
eps=eps,
)
if(use_clamped_fape is not None):
unclamped_fape_loss = compute_fape(
pred_aff,
......@@ -162,7 +164,7 @@ def backbone_loss(
)
# Take the mean over the layer dimension
fape_loss = torch.mean(fape_loss, dim=0)
fape_loss = torch.mean(fape_loss, dim=-1)
return fape_loss
......@@ -1461,41 +1463,12 @@ class AlphaFoldLoss(nn.Module):
**self.config.violation,
)
if("atom14_atom_is_ambiguous" not in batch.keys()):
batch.update(feats.build_ambiguity_feats(batch))
if("renamed_atom14_gt_positions" not in out.keys()):
batch.update(compute_renamed_ground_truth(
batch,
out["sm"]["positions"][-1],
))
if("backbone_affine_tensor" not in batch.keys()):
batch.update(feats.atom37_to_frames(eps=self.config.eps, **batch))
# TODO: Verify that this is correct
batch["backbone_affine_tensor"] = (
batch["rigidgroups_gt_frames"][..., 0, :, :]
)
batch["backbone_affine_mask"] = (
batch["rigidgroups_gt_exists"][..., 0]
)
if("chi_angles_sin_cos" not in batch.keys()):
with torch.no_grad():
batch.update(feats.atom37_to_torsion_angles(
aatype=batch["aatype"],
all_atom_positions=batch["all_atom_positions"].double(),
all_atom_mask=batch["all_atom_mask"].double(),
eps=self.config.eps,
))
# TODO: Verify that this is correct
batch["chi_angles_sin_cos"] = (
batch["torsion_angles_sin_cos"][..., 3:, :]
).to(batch["all_atom_mask"].dtype)
batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:].to(batch["all_atom_mask"].dtype)
loss_fns = {
"distogram":
lambda: distogram_loss(
......
......@@ -15,17 +15,17 @@
import argparse
from datetime import date
import pickle
import logging
import os
# A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
import pickle
import random
import sys
from openfold.features import templates, feature_pipeline
from openfold.features.np import data_pipeline
from openfold.features import templates, feature_pipeline, data_pipeline
import time
......@@ -43,28 +43,29 @@ from openfold.utils.tensor_utils import (
tensor_tree_map,
)
MAX_TEMPLATE_HITS = 20
from scripts.utils import add_data_args
def main(args):
config = model_config(args.model_name)
model = AlphaFold(config.model)
model = model.eval()
import_jax_weights_(model, args.param_path)
model = model.to(args.device)
model = model.to(args.model_device)
# FEATURE COLLECTION AND PROCESSING
use_small_bfd = args.preset == "reduced_dbs"
num_ensemble = 1
template_featurizer = templates.TemplateHitFeaturizer(
mmcif_dir=args.template_mmcif_dir,
max_template_date=args.max_template_date,
max_hits=MAX_TEMPLATE_HITS,
max_hits=args.max_template_hits,
kalign_binary_path=args.kalign_binary_path,
release_dates_path=None,
obsolete_pdbs_path=args.obsolete_pdbs_path
)
use_small_bfd=(args.bfd_database_path is None)
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
......@@ -76,6 +77,7 @@ def main(args):
small_bfd_database_path=args.small_bfd_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=use_small_bfd,
no_cpus=args.cpus,
)
data_processor = data_pipeline.DataPipeline(
......@@ -87,7 +89,7 @@ def main(args):
random_seed = args.data_random_seed
if random_seed is None:
random_seed = random.randrange(sys.maxsize)
config.data.eval.num_ensemble = num_ensemble
config.data.predict.num_ensemble = num_ensemble
feature_processor = feature_pipeline.FeaturePipeline(config)
if not os.path.exists(output_dir_base):
os.makedirs(output_dir_base)
......@@ -95,7 +97,7 @@ def main(args):
if not os.path.exists(alignment_dir):
os.makedirs(alignment_dir)
print("Generating features...")
logging.info("Generating features...")
alignment_runner.run(
args.fasta_path, alignment_dir
)
......@@ -105,42 +107,20 @@ def main(args):
)
processed_feature_dict = feature_processor.process_features(
feature_dict, random_seed
feature_dict, mode='predict',
)
for k, v in processed_feature_dict.items():
print(k)
print(v.shape)
print("Executing model...")
logging.info("Executing model...")
batch = processed_feature_dict
with torch.no_grad():
batch = {
k:torch.as_tensor(v, device=args.device)
k:torch.as_tensor(v, device=args.model_device)
for k,v in batch.items()
}
longs = [
"aatype",
"template_aatype",
"extra_msa",
"residx_atom37_to_atom14",
"residx_atom14_to_atom37",
"true_msa",
"residue_index",
]
for l in longs:
batch[l] = batch[l].long()
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
make_contig = lambda t: t.contiguous()
batch = tensor_tree_map(make_contig, batch)
t = time.time()
out = model(batch)
print(f"Inference time: {time.time() - t}")
logging.info(f"Inference time: {time.time() - t}")
# Toss out the recycling dimensions --- we don't need them anymore
batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
......@@ -159,8 +139,6 @@ def main(args):
b_factors=plddt_b_factors
)
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
amber_relaxer = relax.AmberRelaxation(
**config.relax
)
......@@ -168,7 +146,7 @@ def main(args):
# Relax the prediction.
t = time.time()
relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
print(f"Relaxation time: {time.time() - t}")
logging.info(f"Relaxation time: {time.time() - t}")
# Save the relaxed PDB.
relaxed_output_path = os.path.join(
......@@ -183,53 +161,14 @@ if __name__ == "__main__":
parser.add_argument(
"fasta_path", type=str,
)
parser.add_argument(
'uniref90_database_path', type=str,
)
parser.add_argument(
'mgnify_database_path', type=str,
)
parser.add_argument(
'pdb70_database_path', type=str,
)
parser.add_argument(
'template_mmcif_dir', type=str,
)
parser.add_argument(
'--uniclust30_database_path', type=str,
)
parser.add_argument(
'--bfd_database_path', type=str, default=None,
)
parser.add_argument(
'--small_bfd_database_path', type=str, default=None
)
parser.add_argument(
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
)
parser.add_argument(
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
)
parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
)
parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
)
parser.add_argument(
'--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None
)
add_data_args(parser)
parser.add_argument(
"--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
required=True
)
parser.add_argument(
"--device", type=str, default="cpu",
"--model_device", type=str, default="cpu",
help="""Name of the device on which to run the model. Any valid torch
device name is accepted (e.g. "cpu", "cuda:0")"""
)
......@@ -244,6 +183,10 @@ if __name__ == "__main__":
automatically according to the model name from
openfold/resources/params"""
)
parser.add_argument(
"--cpus", type=int, default=4,
help="""Number of CPUs to use to run alignment tools"""
)
parser.add_argument(
'--preset', type=str, default='full_dbs',
choices=('reduced_dbs', 'full_dbs')
......
......@@ -15,6 +15,7 @@
import argparse
import json
parser = argparse.ArgumentParser(description='''Outputs a DeepSpeed
configuration file to
stdout''')
......
import argparse
import logging
import os
import tempfile
import openfold.features.mmcif_parsing as mmcif_parsing
from openfold.features.data_pipeline import AlignmentRunner
from scripts.utils import add_data_args
def main(args):
# Build the alignment tool runner
alignment_runner = AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
small_bfd_database_path=args.small_bfd_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=args.bfd_database_path is None,
no_cpus=args.cpus,
)
for f in os.listdir(args.input_dir):
path = os.path.join(args.input_dir, f)
is_mmcif = f.endswith('.cif')
is_fasta = f.endswith('.fasta')
file_id = os.path.splitext(f)[0]
seqs = {}
if(is_mmcif):
with open(path, 'r') as fp:
mmcif_str = fp.read()
mmcif = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_str
)
if(mmcif.mmcif_object is None):
logging.warning(f'Failed to parse {f}...')
if(args.raise_errors):
raise list(mmcif.errors.values())[0]
else:
continue
mmcif = mmcif.mmcif_object
for k,v in mmcif.chain_to_seqres.items():
chain_id = '_'.join([file_id, k])
seqs[chain_id] = v
elif(is_fasta):
with open(path, 'r') as fp:
fasta_str = fp.read()
input_seqs, _ = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
msg = f'More than one input_sequence found in {f}'
if(args.raise_errors):
raise ValueError(msg)
else:
logging.warning(msg)
input_sequence = input_seqs[0]
seqs[file_id] = input_sequence
else:
continue
for name, seq in seqs.items():
alignment_dir = os.path.join(args.output_dir, name)
if(os.path.isdir(alignment_dir)):
logging.info(f'{f} has already been processed. Skipping...')
continue
os.makedirs(alignment_dir)
if(not is_fasta):
fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
with os.fdopen(fd, 'w') as fp:
fp.write(f'>query\n{seq}')
alignment_runner.run(
f if is_fasta else fasta_path, alignment_dir
)
if(not is_fasta):
os.remove(fasta_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"input_dir", type=str,
help="Path to directory containing mmCIF and/or FASTA files"
)
parser.add_argument(
"output_dir", type=str,
help="Directory in which to output alignments"
)
add_data_args(parser)
parser.add_argument(
"--raise_errors", type=bool, default=False,
help="Whether to crash on parsing errors"
)
parser.add_argument(
"--cpus", type=int, default=4,
help="Number of CPUs to use"
)
args = parser.parse_args()
main(args)
import argparse
from datetime import date
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'uniref90_database_path', type=str,
)
parser.add_argument(
'mgnify_database_path', type=str,
)
parser.add_argument(
'pdb70_database_path', type=str,
)
parser.add_argument(
'template_mmcif_dir', type=str,
)
parser.add_argument(
'uniclust30_database_path', type=str,
)
parser.add_argument(
'--bfd_database_path', type=str, default=None,
)
parser.add_argument(
'--small_bfd_database_path', type=str, default=None
)
parser.add_argument(
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
)
parser.add_argument(
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
)
parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
)
parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
)
parser.add_argument(
'--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument(
'--max_template_hits', type=int, default=20,
)
parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None
)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,"
import importlib
import pkgutil
import sys
......
......@@ -16,6 +16,7 @@ import torch
import numpy as np
import unittest
import openfold.features.data_transforms as data_transforms
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
......@@ -168,7 +169,7 @@ class TestFeats(unittest.TestCase):
to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
batch = tree_map(to_tensor, batch, np.ndarray)
out_repro = feats.atom37_to_frames(eps=1e-8, **batch)
out_repro = data_transforms.atom37_to_frames(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
for k,v in out_gt.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment