Commit d48c052c authored by Gustaf Ahdritz

Add training parsers

parent eeda001c
import os
import datetime
import numpy as np
from typing import Mapping, Optional, Sequence, Any

from openfold.features import templates, parsers, mmcif_parsing
from openfold.features.np import jackhmmer, hhblits, hhsearch
from openfold.features.np.utils import to_date
from openfold.np import residue_constants

FeatureDict = Mapping[str, np.ndarray]
def make_sequence_features(
    sequence: str,
    description: str,
    num_res: int
) -> FeatureDict:
    """Construct a feature dict of sequence features."""
    features = {}
    features['aatype'] = residue_constants.sequence_to_onehot(
@@ -19,13 +25,50 @@ def make_sequence_features(
        map_unknown_to_x=True
    )
    features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
    features['domain_name'] = np.array(
        [description.encode('utf-8')], dtype=np.object_
    )
    features['residue_index'] = np.array(range(num_res), dtype=np.int32)
    features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
    features['sequence'] = np.array(
        [sequence.encode('utf-8')], dtype=np.object_
    )
    return features
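For a toy two-residue input, the feature shapes come out as follows (an illustrative sketch, not part of the commit, assuming the standard 20-residue-plus-X alphabet):

    feats = make_sequence_features(sequence='GA', description='toy', num_res=2)
    assert feats['aatype'].shape == (2, 21)        # one-hot over 20 AAs + X
    assert feats['residue_index'].tolist() == [0, 1]
    assert feats['seq_length'].tolist() == [2, 2]  # length broadcast per residue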
def make_mmcif_features(
mmcif_object: mmcif_parsing.MmcifObject,
chain_id: str
) -> FeatureDict:
input_sequence = mmcif_object.chain_to_seqres[chain_id]
description = '_'.join([mmcif_object.file_id, chain_id])
num_res = len(input_sequence)
mmcif_feats = {}
mmcif_feats.update(make_sequence_features(
sequence=input_sequence,
description=description,
num_res=num_res,
))
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
)
mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask
mmcif_feats["resolution"] = np.array(
[mmcif_object.header["resolution"]], dtype=np.float32
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode('utf-8')], dtype=np.object_
)
return mmcif_feats
def make_msa_features(
    msas: Sequence[Sequence[str]],
    deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
@@ -58,9 +101,9 @@ def make_msa_features(
    )
    return features
class AlignmentRunner:
    """Runs alignment tools and saves the results."""
    def __init__(self,
        jackhmmer_binary_path: str,
        hhblits_binary_path: str,
@@ -71,106 +114,158 @@ class DataPipeline:
        uniclust30_database_path: Optional[str],
        small_bfd_database_path: Optional[str],
        pdb70_database_path: str,
        use_small_bfd: bool,
        no_cpus: int,
        uniref_max_hits: int = 10000,
        mgnify_max_hits: int = 5000,
    ):
        self._use_small_bfd = use_small_bfd
        self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=uniref90_database_path,
            n_cpu=no_cpus,
        )
        if use_small_bfd:
            self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer(
                binary_path=jackhmmer_binary_path,
                database_path=small_bfd_database_path,
                n_cpu=no_cpus,
            )
        else:
            self.hhblits_bfd_uniclust_runner = hhblits.HHBlits(
                binary_path=hhblits_binary_path,
                databases=[bfd_database_path, uniclust30_database_path],
                n_cpu=no_cpus,
            )
        self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer(
            binary_path=jackhmmer_binary_path,
            database_path=mgnify_database_path,
            n_cpu=no_cpus,
        )
        self.hhsearch_pdb70_runner = hhsearch.HHSearch(
            binary_path=hhsearch_binary_path,
            databases=[pdb70_database_path]
        )
        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits
    def run(self,
        fasta_path: str,
        output_dir: str,
    ):
        """Runs alignment tools on a sequence"""
        jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(fasta_path)[0]
        uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits
        )
        uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m')
        with open(uniref90_out_path, 'w') as f:
            f.write(uniref90_msa_as_a3m)

        jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(fasta_path)[0]
        mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits
        )
        mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m')
        with open(mgnify_out_path, 'w') as f:
            f.write(mgnify_msa_as_a3m)

        hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
        pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr')
        with open(pdb70_out_path, 'w') as f:
            f.write(hhsearch_result)

        if self._use_small_bfd:
            jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(fasta_path)[0]
            bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto')
            with open(bfd_out_path, 'w') as f:
                f.write(jackhmmer_small_bfd_result['sto'])
        else:
            hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path)
            if(output_dir is not None):
                bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m')
                with open(bfd_out_path, 'w') as f:
                    f.write(hhblits_bfd_uniclust_result['a3m'])
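A minimal usage sketch for the new runner (binary and database paths below are hypothetical placeholders, and the constructor parameters elided by the hunk marker above are assumed to mirror the keyword arguments used in the run script at the end of this diff):

    runner = AlignmentRunner(
        jackhmmer_binary_path='/usr/bin/jackhmmer',
        hhblits_binary_path='/usr/bin/hhblits',
        hhsearch_binary_path='/usr/bin/hhsearch',
        uniref90_database_path='data/uniref90/uniref90.fasta',
        mgnify_database_path='data/mgnify/mgy_clusters.fa',
        bfd_database_path=None,
        uniclust30_database_path=None,
        small_bfd_database_path='data/small_bfd/bfd-first_non_consensus_sequences.fasta',
        pdb70_database_path='data/pdb70/pdb70',
        use_small_bfd=True,
        no_cpus=4,
    )
    # Writes uniref90_hits.a3m, mgnify_hits.a3m, pdb70_hits.hhr, and the BFD
    # hits into the output directory for later consumption by DataPipeline.
    runner.run('query.fasta', 'alignments/query')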
class DataPipeline:
"""Assembles input features."""
def __init__(self,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
):
self.template_featurizer = template_featurizer
self.use_small_bfd = use_small_bfd
def _parse_alignment_output(self,
alignment_dir: str,
) -> Mapping[str, Any]:
uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'r') as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(
f.read()
)
mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'r') as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(
f.read()
)
pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'r') as f:
hhsearch_hits = parsers.parse_hhr(
f.read()
)
if(self.use_small_bfd):
bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'r') as f:
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
f.read()
)
else:
bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'r') as f:
                bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
                    f.read()
                )
return {
'uniref90_msa': uniref90_msa,
'uniref90_deletion_matrix': uniref90_deletion_matrix,
'mgnify_msa': mgnify_msa,
'mgnify_deletion_matrix': mgnify_deletion_matrix,
'hhsearch_hits': hhsearch_hits,
'bfd_msa': bfd_msa,
'bfd_deletion_matrix': bfd_deletion_matrix,
}
def process_fasta(self,
fasta_path: str,
alignment_dir: str,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {fasta_path}.')
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
alignments = self._parse_alignment_output(alignment_dir)
        templates_result = self.template_featurizer.get_templates(
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=None,
            hits=alignments['hhsearch_hits']
        )

        sequence_features = make_sequence_features(
@@ -180,9 +275,62 @@ class DataPipeline:
        )
        msa_features = make_msa_features(
            msas=(
                alignments['uniref90_msa'],
                alignments['bfd_msa'],
                alignments['mgnify_msa']
            ),
            deletion_matrices=(
                alignments['uniref90_deletion_matrix'],
                alignments['bfd_deletion_matrix'],
                alignments['mgnify_deletion_matrix']
            )
        )

        return {**sequence_features, **msa_features, **templates_result.features}
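Usage sketch (not part of the commit): with alignments precomputed by AlignmentRunner into alignment_dir, feature assembly is now a pure parsing step.

    data_processor = DataPipeline(
        template_featurizer=template_featurizer,  # a templates.TemplateHitFeaturizer
        use_small_bfd=True,
    )
    feature_dict = data_processor.process_fasta(
        fasta_path='query.fasta',
        alignment_dir='alignments/query',
    )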
def process_mmcif(self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown.
"""
if(chain_id is None):
chains = mmcif.structure.get_chains()
chain = next(chains, None)
if(chain is None):
raise ValueError(
'No chains in mmCIF file'
)
chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id)
alignments = self._parse_alignment_output(alignment_dir)
input_sequence = mmcif.chain_to_seqres[chain_id]
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]),
hits=alignments['hhsearch_hits']
)
msa_features = make_msa_features(
msas=(
alignments['uniref90_msa'],
alignments['bfd_msa'],
alignments['mgnify_msa']
),
            deletion_matrices=(
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
)
return {**mmcif_feats, **templates_result.features, **msa_features}
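Usage sketch for the training path (file names hypothetical, not part of the commit); the mmCIF object is parsed once up front and reused, since parsing is the expensive step.

    with open('1abc.cif', 'r') as fp:
        mmcif_string = fp.read()
    parse_result = mmcif_parsing.parse(file_id='1abc', mmcif_string=mmcif_string)
    feature_dict = data_processor.process_mmcif(
        mmcif=parse_result.mmcif_object,
        alignment_dir='alignments/1abc_A',
        chain_id='A',
    )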
@@ -25,39 +25,67 @@ def np_to_tensor_dict(
    A dictionary of features mapping feature names to features. Only the given
    features are returned, all other ones are filtered out.
    """
    tensor_dict = {
        k: torch.tensor(v) for k, v in np_example.items() if k in features
    }
    return tensor_dict
def make_data_config(
    config: ml_collections.ConfigDict,
    mode: str,
    num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]:
    cfg = copy.deepcopy(config)
    mode_cfg = cfg[mode]
    with cfg.unlocked():
        if(mode_cfg.crop_size is None):
            mode_cfg.crop_size = num_res

    feature_names = cfg.common.unsupervised_features
    if cfg.common.use_templates:
        feature_names += cfg.common.template_features
    if(cfg[mode].supervised):
        feature_names += cfg.common.supervised_features

    return cfg, feature_names
def np_example_to_features(
    np_example: FeatureDict,
    config: ml_collections.ConfigDict,
    mode: str,
    batch_mode: str,
):
    np_example = dict(np_example)
    num_res = int(np_example['seq_length'][0])
    cfg, feature_names = make_data_config(
        config, mode=mode, num_res=num_res
    )

    if 'deletion_matrix_int' in np_example:
        np_example['deletion_matrix'] = (
            np_example.pop('deletion_matrix_int').astype(np.float32)
        )

    if batch_mode == 'clamped':
        np_example['use_clamped_fape'] = (
            np.array(1.).astype(np.float32)
        )
    elif batch_mode == 'unclamped':
        np_example['use_clamped_fape'] = (
            np.array(0.).astype(np.float32)
        )

    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names
    )
    with torch.no_grad():
        features = input_pipeline.process_tensors_from_config(
            tensor_dict, cfg.common, cfg[mode], batch_mode=batch_mode,
        )

    return {k: v for k, v in features.items()}
@@ -71,9 +99,12 @@ class FeaturePipeline:
    def process_features(self,
        raw_features: FeatureDict,
        mode: str = 'train',
        batch_mode: str = 'clamped',
    ) -> FeatureDict:
        return np_example_to_features(
            np_example=raw_features,
            config=self.config,
            mode=mode,
            batch_mode=batch_mode,
        )
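Usage sketch (not part of the commit): the mode/batch_mode pair replaces the old random_seed argument; batch_mode also decides the 'use_clamped_fape' feature set in np_example_to_features above.

    pipeline = FeaturePipeline(config)  # config: the data ConfigDict this class expects
    processed = pipeline.process_features(
        raw_features=feature_dict, mode='train', batch_mode='clamped'
    )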
from functools import partial
import torch
from openfold.features import data_transforms
def nonensembled_transform_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    transforms = [
        data_transforms.cast_to_64bit_ints,
        data_transforms.correct_msa_restypes,
@@ -23,23 +22,36 @@ def nonensembled_transform_fns(data_config):
            data_transforms.make_template_mask,
            data_transforms.make_pseudo_beta('template_')
        ])
if(common_cfg.use_template_torsion_angles):
transforms.extend([
data_transforms.atom37_to_torsion_angles('template_'),
])
    transforms.extend([
        data_transforms.make_atom14_masks,
    ])
if(mode_cfg.supervised):
transforms.extend([
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(''),
data_transforms.make_pseudo_beta(''),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
])
    return transforms
def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
    """Input pipeline data transformers that can be ensembled and averaged."""
    transforms = []

    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
    else:
        pad_msa_clusters = mode_cfg.max_msa_clusters

    max_msa_clusters = pad_msa_clusters
    max_extra_msa = common_cfg.max_extra_msa
@@ -53,8 +65,10 @@ def ensembled_transform_fns(data_config):
    # the clustering and full MSA profile do not leak information about
    # the masked locations and secret corrupted locations.
    transforms.append(
        data_transforms.make_masked_msa(
            common_cfg.masked_msa,
            mode_cfg.masked_msa_replace_fraction
        )
    )

    if common_cfg.msa_cluster_features:
@@ -69,44 +83,55 @@ def ensembled_transform_fns(data_config):
    transforms.append(data_transforms.make_msa_feat())

    crop_feats = dict(common_cfg.feat)

    if mode_cfg.fixed_size:
        transforms.append(data_transforms.select_feat(list(crop_feats)))
        transforms.append(data_transforms.random_crop_to_size(
            mode_cfg.crop_size,
            mode_cfg.max_templates,
            crop_feats,
            mode_cfg.subsample_templates,
            batch_mode=batch_mode,
            seed=torch.Generator().seed()
        ))
        transforms.append(data_transforms.make_fixed_size(
            crop_feats,
            pad_msa_clusters,
            common_cfg.max_extra_msa,
            mode_cfg.crop_size,
            mode_cfg.max_templates
        ))
    else:
        transforms.append(
            data_transforms.crop_templates(mode_cfg.max_templates)
        )

    return transforms
def process_tensors_from_config(
    tensors, common_cfg, mode_cfg, batch_mode='clamped'
):
    """Based on the config, apply filters and transformations to the data."""

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
        fn = compose(fns)
        d['ensemble_index'] = i
        return fn(d)
    tensors = compose(
        nonensembled_transform_fns(common_cfg, mode_cfg)
    )(tensors)

    tensors_0 = wrap_ensemble_fn(tensors, 0)
    num_ensemble = mode_cfg.num_ensemble
    if common_cfg.resample_msa_in_recycling:
        # Separate batch per ensembling & recycling step.
        num_ensemble *= common_cfg.num_recycle + 1
    if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
        tensors = map_fn(lambda x: wrap_ensemble_fn(tensors, x),
@@ -116,16 +141,20 @@ def process_tensors_from_config(tensors, data_config):
    return tensors


@data_transforms.curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x
def map_fn(fun, x):
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = torch.stack(
            [dict_i[feat] for dict_i in ensembles], dim=-1
        )
    return ensembled_dict
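A small sketch of the new stacking behavior (not part of the commit): map_fn now stacks per-ensemble dicts along a trailing dimension, which is what lets the model index recycling/ensembling iterations with t[..., cycle_no] later on.

    dicts = [{'x': torch.zeros(5)}, {'x': torch.ones(5)}]
    stacked = map_fn(lambda d: d, dicts)
    assert stacked['x'].shape == (5, 2)  # ensemble dimension is last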
"""Parses the mmCIF file format.""" """Parses the mmCIF file format."""
import collections import collections
import dataclasses import dataclasses
import io import io
import json
import logging
import os
from typing import Any, Mapping, Optional, Sequence, Tuple from typing import Any, Mapping, Optional, Sequence, Tuple
from absl import logging
from Bio import PDB from Bio import PDB
from Bio.Data import SCOPData from Bio.Data import SCOPData
import numpy as np
import openfold.np.residue_constants as residue_constants
# Type aliases: # Type aliases:
ChainId = str ChainId = str
...@@ -369,3 +374,73 @@ def _get_protein_chains( ...@@ -369,3 +374,73 @@ def _get_protein_chains(
def _is_set(data: str) -> bool:
    """Returns False if data is a special mmCIF character indicating 'unset'."""
    return data not in ('.', '?')
def get_atom_coords(
mmcif_object: MmcifObject,
chain_id: str
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
relevant_chains = [c for c in chains if c.id == chain_id]
if len(relevant_chains) != 1:
raise MultipleChainsError(
f'Expected exactly one chain in structure with id {chain_id}.'
)
chain = relevant_chains[0]
# Extract the coordinates
num_res = len(mmcif_object.chain_to_seqres[chain_id])
all_atom_positions = np.zeros(
[num_res, residue_constants.atom_type_num, 3], dtype=np.float32
)
all_atom_mask = np.zeros(
[num_res, residue_constants.atom_type_num], dtype=np.float32
)
for res_index in range(num_res):
pos = np.zeros([residue_constants.atom_type_num, 3], dtype=np.float32)
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
if not res_at_position.is_missing:
res = chain[(res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code)]
for atom in res.get_atoms():
atom_name = atom.get_name()
x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
# Put the coords of the selenium atom in the sulphur column
pos[residue_constants.atom_order['SD']] = [x, y, z]
mask[residue_constants.atom_order['SD']] = 1.0
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
return all_atom_positions, all_atom_mask
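Usage sketch (not part of the commit), assuming mmcif is a parsed MmcifObject: coordinates come back in atom37 format, one row per SEQRES residue, with the mask marking atoms actually resolved in the structure.

    pos, mask = get_atom_coords(mmcif_object=mmcif, chain_id='A')
    num_res = len(mmcif.chain_to_seqres['A'])
    assert pos.shape == (num_res, residue_constants.atom_type_num, 3)  # (N, 37, 3)
    assert mask.shape == (num_res, residue_constants.atom_type_num)    # (N, 37)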
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
data = {}
for f in os.listdir(mmcif_dir):
if(f.endswith('.cif')):
with open(os.path.join(mmcif_dir, f), 'r') as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
if(mmcif.mmcif_object is None):
logging.warning(f'Could not parse {f}. Skipping...')
continue
else:
mmcif = mmcif.mmcif_object
local_data = {}
local_data['release_date'] = mmcif.header["release_date"]
local_data['no_chains'] = len(list(mmcif.structure.get_chains()))
data[file_id] = local_data
with open(out_path, 'w') as fp:
fp.write(json.dumps(data))
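A sketch of the cache this produces (directory and PDB ID hypothetical): one JSON entry per successfully parsed file, keyed by file ID.

    generate_mmcif_cache('data/pdb_mmcif/', 'mmcif_cache.json')
    # mmcif_cache.json would then look roughly like:
    #   {"1abc": {"release_date": "2021-10-04", "no_chains": 2}}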
@@ -18,6 +18,7 @@ class HHSearch:
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 2,
        maxseq: int = 1_000_000):
        """Initializes the Python HHsearch wrapper.
@@ -26,6 +27,7 @@ class HHSearch:
          databases: A sequence of HHsearch database paths. This should be the
            common prefix for the database files (i.e. up to but not including
            _hhm.ffindex etc.)
          n_cpu: The number of CPUs to use.
          maxseq: The maximum number of rows in an input alignment. Note that this
            parameter is only supported in HHBlits version 3.1 and higher.
@@ -34,6 +36,7 @@ class HHSearch:
""" """
self.binary_path = binary_path self.binary_path = binary_path
self.databases = databases self.databases = databases
self.n_cpu = n_cpu
self.maxseq = maxseq self.maxseq = maxseq
for database_path in self.databases: for database_path in self.databases:
...@@ -56,7 +59,8 @@ class HHSearch: ...@@ -56,7 +59,8 @@ class HHSearch:
        cmd = [self.binary_path,
               '-i', input_path,
               '-o', hhr_path,
               '-maxseq', str(self.maxseq),
               '-cpu', str(self.n_cpu),
               ] + db_cmd
        logging.info('Launching subprocess "%s"', ' '.join(cmd))
...
@@ -3,14 +3,12 @@
from concurrent import futures
import glob
import logging
import os
import subprocess
from typing import Any, Callable, Mapping, Optional, Sequence
from urllib import request

from openfold.features.np import utils
...
"""Common utilities for data pipeline tools.""" """Common utilities for data pipeline tools."""
import contextlib import contextlib
import datetime
import shutil import shutil
import tempfile import tempfile
import time import time
...@@ -25,3 +26,9 @@ def timing(msg: str): ...@@ -25,3 +26,9 @@ def timing(msg: str):
yield yield
toc = time.time() toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic) logging.info('Finished %s in %.3f seconds', msg, toc - tic)
def to_date(s: str) -> datetime.datetime:
    """Parses a date string of the form YYYY-MM-DD, e.g. "2021-10-04"."""
    return datetime.datetime(
        year=int(s[:4]), month=int(s[5:7]), day=int(s[8:10])
    )
@@ -2,16 +2,17 @@
import dataclasses
import datetime
import glob
import json
import logging
import os
import re
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple

import numpy as np

from openfold.features import parsers, mmcif_parsing
from openfold.features.np import kalign
from openfold.features.np.utils import to_date
from openfold.np import residue_constants
@@ -74,7 +75,7 @@ class LengthError(PrefilterError):
TEMPLATE_FEATURES = {
    'template_aatype': np.int64,
    'template_all_atom_mask': np.float32,
    'template_all_atom_positions': np.float32,
    'template_domain_names': np.object,
    'template_sequence': np.object,
@@ -133,23 +134,40 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
    return result
def generate_release_dates_cache(mmcif_dir: str, out_path: str):
dates = {}
for f in os.listdir(mmcif_dir):
if(f.endswith('.cif')):
path = os.path.join(mmcif_dir, f)
with open(path, 'r') as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_string
)
if(mmcif.mmcif_object is None):
logging.warning(f'Failed to parse {f}. Skipping...')
continue
mmcif = mmcif.mmcif_object
release_date = mmcif.header['release_date']
dates[file_id] = release_date
    with open(out_path, 'w') as fp:
        fp.write(json.dumps(dates))
def _parse_release_dates(path: str) -> Mapping[str, datetime.datetime]:
    """Parses release dates file, returns a mapping from PDBs to release dates."""
    with open(path, 'r') as fp:
        data = json.load(fp)

    return {
        pdb: to_date(v) for pdb, d in data.items() for k, v in d.items()
        if k == "release_date"
    }
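A sketch of the expected input (not part of the commit): the comprehension assumes the nested per-PDB layout written by generate_mmcif_cache in mmcif_parsing.py, i.e. dicts containing a "release_date" key.

    # Given {"1abc": {"release_date": "2021-10-04", "no_chains": 2}},
    # _parse_release_dates returns {"1abc": datetime.datetime(2021, 10, 4)}.
    release_dates = _parse_release_dates('mmcif_cache.json')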
def _assess_hhsearch_hit(
    hit: parsers.TemplateHit,
@@ -419,42 +437,14 @@ def _get_atom_positions(
    auth_chain_id: str,
    max_ca_ca_distance: float) -> Tuple[np.ndarray, np.ndarray]:
    """Gets atom positions and mask for a chain of an mmCIF object."""
    coords_with_mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=auth_chain_id
    )
    all_atom_positions, all_atom_mask = coords_with_mask
    _check_residue_distances(
        all_atom_positions, all_atom_mask, max_ca_ca_distance
    )
    return all_atom_positions, all_atom_mask
def _extract_template_features(
@@ -579,7 +569,7 @@ def _extract_template_features(
    return (
        {
            'template_all_atom_positions': np.array(templates_all_atom_positions),
            'template_all_atom_mask': np.array(templates_all_atom_masks),
            'template_sequence': output_templates_sequence.encode(),
            'template_aatype': np.array(templates_aatype),
            'template_domain_names': f'{pdb_id.lower()}_{chain_id}'.encode(),
...
@@ -19,7 +19,6 @@ import torch.nn as nn
from openfold.utils.feats import (
    pseudo_beta_fn,
    build_extra_msa_feat,
    build_template_angle_feat,
    build_template_pair_feat,
@@ -115,22 +114,17 @@ class AlphaFold(nn.Module):
                batch,
            )

            single_template_embeds = {}
            if(self.config.template.embed_angles):
                template_angle_feat = build_template_angle_feat(
                    single_template_feats,
                )

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            t = build_template_pair_feat(
                single_template_feats,
@@ -145,12 +139,12 @@ class AlphaFold(nn.Module):
                _mask_trans=self.config._mask_trans
            )

            single_template_embeds.update({
                "pair": t,
            })

            template_embeds.append(single_template_embeds)

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
@@ -164,11 +158,15 @@ class AlphaFold(nn.Module):
        )

        t = t * (torch.sum(batch["template_mask"]) > 0)

        ret = {}
        if(self.config.template.embed_angles):
            ret["template_angle_embedding"] = template_embeds["angle"]

        ret.update({
            "template_pair_embedding": t,
        })

        return ret

    def iteration(self, feats, m_1_prev, z_prev, x_prev):
        # Primary output dictionary
@@ -197,7 +195,7 @@ class AlphaFold(nn.Module):
        )

        # Inject information from previous recycling iterations
        if(self.config.num_recycle > 0):
            # Initialize the recycling embeddings, if needs be
            if(None in [m_1_prev, z_prev, x_prev]):
                # [*, N, C_m]
@@ -241,7 +239,7 @@ class AlphaFold(nn.Module):
        # Embed the templates + merge with MSA/pair embeddings
        if(self.config.template.enabled):
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
            template_embeds = self.embed_templates(
                template_feats,
@@ -261,7 +259,7 @@ class AlphaFold(nn.Module):
            )

            # [*, S, N]
            torsion_angles_mask = feats["template_torsion_angles_mask"]
            msa_mask = torch.cat(
                [feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
            )
@@ -374,7 +372,8 @@ class AlphaFold(nn.Module):
"template_aatype" ([*, N_templ, N_res]) "template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown)) than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3]) "template_all_atom_positions"
([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37]) "template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask Template atom coordinate mask
...@@ -392,13 +391,13 @@ class AlphaFold(nn.Module): ...@@ -392,13 +391,13 @@ class AlphaFold(nn.Module):
            self._disable_activation_checkpointing()

        # Main recycling loop
        for cycle_no in range(self.config.num_recycle + 1):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad iff we're training and it's the final recycling layer
            is_final_iter = (cycle_no == self.config.num_recycle)
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                # Sidestep AMP bug discussed in pytorch issue #65766
                if(is_final_iter):
...
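A sketch of the batch layout the recycling loop above assumes (not part of the commit): the feature pipeline stacks per-iteration features along a trailing dimension (see map_fn in input_pipeline.py), so every tensor in batch ends in a recycling dimension of size num_recycle + 1.

    # tensor_tree_map comes from openfold.utils.tensor_utils
    batch = {'residue_index': torch.zeros(256, 4)}  # 4 = num_recycle + 1
    feats_0 = tensor_tree_map(lambda t: t[..., 0], batch)
    assert feats_0['residue_index'].shape == (256,)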
@@ -29,6 +29,7 @@ class ExponentialMovingAverage:
        self.decay = decay

    def _update_state_dict_(self, update, state_dict):
        with torch.no_grad():
            for k, v in update.items():
                stored = state_dict[k]
                if(not isinstance(v, torch.Tensor)):
...
@@ -49,32 +49,6 @@ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks):
    return pseudo_beta
def get_chi_atom_indices():
"""Returns atom indices needed to compute chi angles for all residue types.
Returns:
A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
in the order specified in rc.restypes + unknown residue type
at the end. For chi angles which are not defined on the residue, the
positions indices are by default set to 0.
"""
chi_atom_indices = []
for residue_name in rc.restypes:
residue_name = rc.restype_1to3[residue_name]
residue_chi_angles = rc.chi_angles_atoms[residue_name]
atom_indices = []
for chi_angle in residue_chi_angles:
atom_indices.append(
[rc.atom_order[atom] for atom in chi_angle])
for _ in range(4 - len(atom_indices)):
atom_indices.append([0, 0, 0, 0]) # For chi angles not defined on the AA.
chi_atom_indices.append(atom_indices)
chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
return chi_atom_indices
def atom14_to_atom37(atom14, batch):
    atom37_data = batched_gather(
        atom14,
@@ -88,320 +62,13 @@ def atom14_to_atom37(atom14, batch):
    return atom37_data


def atom37_to_torsion_angles(
    aatype: torch.Tensor,
    all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
eps: float = 1e-8,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""
Convert coordinates to torsion angles.
This function is extremely sensitive to floating point imprecisions
and should be run with double precision whenever possible.
Args:
aatype:
[*, N_res] residue indices
all_atom_positions:
[*, N_res, 37, 3] atom positions (in atom37
format)
all_atom_mask:
[*, N_res, 37] atom position mask
Returns:
Dictionary of the following features:
"torsion_angles_sin_cos" ([*, N_res, 7, 2])
Torsion angles
"alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
Alternate torsion angles (accounting for 180-degree symmetry)
"torsion_angles_mask" ([*, N_res, 7])
Torsion angles mask
"""
aatype = torch.clamp(aatype, max=20)
pad = all_atom_positions.new_zeros(
[*all_atom_positions.shape[:-3], 1, 37, 3]
)
prev_all_atom_positions = torch.cat(
[pad, all_atom_positions[..., :-1, :, :]], dim=-3
)
pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
pre_omega_atom_pos = torch.cat(
[
prev_all_atom_positions[..., 1:3, :],
all_atom_positions[..., :2, :]
], dim=-2
)
phi_atom_pos = torch.cat(
[
prev_all_atom_positions[..., 2:3, :],
all_atom_positions[..., :3, :]
], dim=-2
)
psi_atom_pos = torch.cat(
[
all_atom_positions[..., :3, :],
all_atom_positions[..., 4:5, :]
], dim=-2
)
pre_omega_mask = (
torch.prod(prev_all_atom_mask[..., 1:3], dim=-1) *
torch.prod(all_atom_mask[..., :2], dim=-1)
)
phi_mask = (
prev_all_atom_mask[..., 2] *
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
)
psi_mask = (
torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype) *
all_atom_mask[..., 4]
)
chi_atom_indices = torch.as_tensor(
get_chi_atom_indices(), device=aatype.device
)
atom_indices = chi_atom_indices[..., aatype, :, :]
chis_atom_pos = batched_gather(
all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
)
chi_angles_mask = list(rc.chi_angles_mask)
chi_angles_mask.append([0., 0., 0., 0.])
chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
chis_mask = chi_angles_mask[aatype, :]
chi_angle_atoms_mask = batched_gather(
all_atom_mask,
atom_indices,
dim=-1,
no_batch_dims=len(atom_indices.shape[:-2])
)
chi_angle_atoms_mask = torch.prod(
chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
)
chis_mask = chis_mask * chi_angle_atoms_mask
torsions_atom_pos = torch.cat(
[
pre_omega_atom_pos[..., None, :, :],
phi_atom_pos[..., None, :, :],
psi_atom_pos[..., None, :, :],
chis_atom_pos,
], dim=-3
)
torsion_angles_mask = torch.cat(
[
pre_omega_mask[..., None],
phi_mask[..., None],
psi_mask[..., None],
chis_mask,
], dim=-1
)
torsion_frames = T.from_3_points(
torsions_atom_pos[..., 1, :],
torsions_atom_pos[..., 2, :],
torsions_atom_pos[..., 0, :],
eps=eps,
)
fourth_atom_rel_pos = torsion_frames.invert().apply(
torsions_atom_pos[..., 3, :]
)
torsion_angles_sin_cos = torch.stack(
[fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
)
denom = torch.sqrt(
torch.sum(
torch.square(torsion_angles_sin_cos),
dim=-1,
dtype=torsion_angles_sin_cos.dtype,
keepdims=True
) + eps
)
torsion_angles_sin_cos = torsion_angles_sin_cos / denom
torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
[1., 1., -1., 1., 1., 1., 1.],
)[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
rc.chi_pi_periodic,
)[aatype, ...]
mirror_torsion_angles = torch.cat(
[
all_atom_mask.new_ones(*aatype.shape, 3),
1. - 2. * chi_is_ambiguous
], dim=-1
)
    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )
return {
"torsion_angles_sin_cos": torsion_angles_sin_cos,
"alt_torsion_angles_sin_cos": alt_torsion_angles_sin_cos,
"torsion_angles_mask": torsion_angles_mask,
}
def atom37_to_frames(
aatype: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
eps: float,
**kwargs,
) -> Dict[str, torch.Tensor]:
batch_dims = len(aatype.shape[:-1])
restype_rigidgroup_base_atom_names = np.full([21, 8, 3], '', dtype=object)
restype_rigidgroup_base_atom_names[:, 0, :] = ['C', 'CA', 'N']
restype_rigidgroup_base_atom_names[:, 3, :] = ['CA', 'C', 'O']
for restype, restype_letter in enumerate(rc.restypes):
resname = rc.restype_1to3[restype_letter]
for chi_idx in range(4):
if(rc.chi_angles_mask[restype][chi_idx]):
names = rc.chi_angles_atoms[resname][chi_idx]
restype_rigidgroup_base_atom_names[
restype, chi_idx + 4, :
] = names[1:]
restype_rigidgroup_mask = all_atom_mask.new_zeros(
(*aatype.shape[:-1], 21, 8),
)
restype_rigidgroup_mask[..., 0] = 1
restype_rigidgroup_mask[..., 3] = 1
restype_rigidgroup_mask[..., :20, 4:] = (
all_atom_mask.new_tensor(rc.chi_angles_mask)
)
lookuptable = rc.atom_order.copy()
lookuptable[''] = 0
lookup = np.vectorize(lambda x: lookuptable[x])
restype_rigidgroup_base_atom37_idx = lookup(
restype_rigidgroup_base_atom_names,
)
restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
restype_rigidgroup_base_atom37_idx,
)
restype_rigidgroup_base_atom37_idx = (
restype_rigidgroup_base_atom37_idx.view(
*((1,) * batch_dims),
*restype_rigidgroup_base_atom37_idx.shape
)
)
residx_rigidgroup_base_atom37_idx = batched_gather(
restype_rigidgroup_base_atom37_idx,
aatype,
dim=-3,
no_batch_dims=batch_dims,
)
base_atom_pos = batched_gather(
all_atom_positions,
residx_rigidgroup_base_atom37_idx,
dim=-2,
no_batch_dims=len(all_atom_positions.shape[:-2]),
)
gt_frames = T.from_3_points(
p_neg_x_axis=base_atom_pos[..., 0, :],
origin=base_atom_pos[..., 1, :],
p_xy_plane=base_atom_pos[..., 2, :],
eps=eps,
)
group_exists = batched_gather(
restype_rigidgroup_mask,
aatype,
dim=-2,
no_batch_dims=batch_dims,
    )
gt_atoms_exist = batched_gather(
all_atom_mask,
residx_rigidgroup_base_atom37_idx,
dim=-1,
no_batch_dims=len(all_atom_mask.shape[:-1])
)
gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
rots[..., 0, 0, 0] = -1
rots[..., 0, 2, 2] = -1
gt_frames = gt_frames.compose(T(rots, None))
restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
*((1,) * batch_dims), 21, 8
)
restype_rigidgroup_rots = torch.eye(
3, dtype=all_atom_mask.dtype, device=aatype.device
)
restype_rigidgroup_rots = torch.tile(
restype_rigidgroup_rots,
(*((1,) * batch_dims), 21, 8, 1, 1),
)
for resname, _ in rc.residue_atom_renaming_swaps.items():
restype = rc.restype_order[
rc.restype_3to1[resname]
]
chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
residx_rigidgroup_is_ambiguous = batched_gather(
restype_rigidgroup_is_ambiguous,
aatype,
dim=-2,
no_batch_dims=batch_dims,
)
residx_rigidgroup_ambiguity_rot = batched_gather(
restype_rigidgroup_rots,
aatype,
dim=-4,
no_batch_dims=batch_dims,
)
alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))
gt_frames_tensor = gt_frames.to_4x4()
alt_gt_frames_tensor = alt_gt_frames.to_4x4()
return {
'rigidgroups_gt_frames': gt_frames_tensor,
'rigidgroups_gt_exists': gt_exists,
'rigidgroups_group_exists': group_exists,
'rigidgroups_group_is_ambiguous': residx_rigidgroup_is_ambiguous,
'rigidgroups_alt_gt_frames': alt_gt_frames_tensor,
}
def build_template_angle_feat(template_feats):
    template_aatype = template_feats["template_aatype"]
    torsion_angles_sin_cos = template_feats["template_torsion_angles_sin_cos"]
    alt_torsion_angles_sin_cos = (
        template_feats["template_alt_torsion_angles_sin_cos"]
    )
    torsion_angles_mask = template_feats["template_torsion_angles_mask"]
    template_angle_feat = torch.cat(
        [
            nn.functional.one_hot(template_aatype, 22),
@@ -465,7 +132,7 @@ def build_template_pair_feat(
        eps + torch.sum(affine_vec ** 2, dim=-1)
    )

    t_aa_masks = batch["template_all_atom_mask"]
    template_mask = (
        t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., c]
    )
@@ -534,53 +201,6 @@ def build_msa_feat(batch):
    return batch
def build_ambiguity_feats(batch: Dict[str, torch.Tensor]) -> None:
"""
Compute features required by compute_renamed_ground_truth (Alg. 26)
Args:
batch:
str/tensor dictionary containing:
* atom14_gt_positions: [*, N, 14, 3] ground truth pos.
* atom14_gt_exists: [*, N, 14] atom mask
* aatype: [*, N] residue indices
Returns:
str/tensor dictionary containing:
* atom14_atom_is_ambiguous: [*, N, 14] mask of ambiguous atoms
* atom14_alt_gt_positions: [*, N, 14, 3] renamed positions
"""
ambiguous_atoms = (
batch["atom14_gt_positions"].new_tensor(
rc.restype_atom14_ambiguous_atoms
)
)
atom14_atom_is_ambiguous = ambiguous_atoms[batch["aatype"], ...]
# Swap pairs of ambiguous positions
swap_idx = rc.restype_atom14_ambiguous_atoms_swap_idx
swap_mat = np.eye(swap_idx.shape[-1])[swap_idx] # one-hot swap_idx
swap_mat = batch["atom14_gt_positions"].new_tensor(swap_mat)
swap_mat = swap_mat[batch["aatype"], ...]
atom14_alt_gt_positions = (
torch.sum(
batch["atom14_gt_positions"][..., None, :] * swap_mat[..., None],
dim=-3
)
)
atom14_alt_gt_exists = (
torch.sum(
batch["atom14_gt_exists"][..., None] * swap_mat, dim=-2
)
)
return {
"atom14_atom_is_ambiguous": atom14_atom_is_ambiguous,
"atom14_alt_gt_positions": atom14_alt_gt_positions,
"atom14_alt_gt_exists": atom14_alt_gt_exists,
}
def torsion_angles_to_frames(
    t: T,
    alpha: torch.Tensor,
...
@@ -18,6 +18,7 @@ import ml_collections
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple

from openfold.np import residue_constants
@@ -117,7 +118,9 @@ def compute_fape(
    return normed_error
# DISCREPANCY: From the way this function is written, it's possible that
# DeepMind clamped 90% of individual residue losses, not 90% of all batches.
# We defer to the text, which seems to imply the latter.
def backbone_loss(
    backbone_affine_tensor: torch.Tensor,
    backbone_affine_mask: torch.Tensor,
@@ -142,7 +145,6 @@ def backbone_loss(
        length_scale=loss_unit_distance,
        eps=eps,
    )
    if(use_clamped_fape is not None):
        unclamped_fape_loss = compute_fape(
            pred_aff,
@@ -162,7 +164,7 @@ def backbone_loss(
    )

    # Take the mean over the layer dimension
    fape_loss = torch.mean(fape_loss, dim=-1)

    return fape_loss
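In the elided middle of this function, the clamped and unclamped terms have to be blended using the per-example flag; a plausible sketch of that logic, not a verbatim excerpt from the commit:

    # use_clamped_fape is 1. for 'clamped' batches and 0. for 'unclamped' ones
    # (set in feature_pipeline.np_example_to_features above), so roughly:
    # fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (1 - use_clamped_fape)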
@@ -1461,41 +1463,12 @@ class AlphaFoldLoss(nn.Module):
            **self.config.violation,
        )
if("atom14_atom_is_ambiguous" not in batch.keys()):
batch.update(feats.build_ambiguity_feats(batch))
if("renamed_atom14_gt_positions" not in out.keys()): if("renamed_atom14_gt_positions" not in out.keys()):
batch.update(compute_renamed_ground_truth( batch.update(compute_renamed_ground_truth(
batch, batch,
out["sm"]["positions"][-1], out["sm"]["positions"][-1],
)) ))
if("backbone_affine_tensor" not in batch.keys()):
batch.update(feats.atom37_to_frames(eps=self.config.eps, **batch))
# TODO: Verify that this is correct
batch["backbone_affine_tensor"] = (
batch["rigidgroups_gt_frames"][..., 0, :, :]
)
batch["backbone_affine_mask"] = (
batch["rigidgroups_gt_exists"][..., 0]
)
if("chi_angles_sin_cos" not in batch.keys()):
with torch.no_grad():
batch.update(feats.atom37_to_torsion_angles(
aatype=batch["aatype"],
all_atom_positions=batch["all_atom_positions"].double(),
all_atom_mask=batch["all_atom_mask"].double(),
eps=self.config.eps,
))
# TODO: Verify that this is correct
batch["chi_angles_sin_cos"] = (
batch["torsion_angles_sin_cos"][..., 3:, :]
).to(batch["all_atom_mask"].dtype)
batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:].to(batch["all_atom_mask"].dtype)
loss_fns = { loss_fns = {
"distogram": "distogram":
lambda: distogram_loss( lambda: distogram_loss(
......
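The block removed above relied on the AlphaFold torsion layout. For reference, a minimal sketch of that slicing (assuming the standard [omega, phi, psi, chi1, chi2, chi3, chi4] ordering along the torsion axis):

import torch

torsion_angles_sin_cos = torch.randn(8, 128, 7, 2)  # [batch, N_res, 7 torsions, sin/cos]
chi_angles_sin_cos = torsion_angles_sin_cos[..., 3:, :]  # last four torsions are the chis
assert chi_angles_sin_cos.shape == (8, 128, 4, 2)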
...@@ -15,17 +15,17 @@
import argparse
from datetime import date
-import pickle
+import logging
import os
# A hack to get OpenMM and PyTorch to peacefully coexist
os.environ["OPENMM_DEFAULT_PLATFORM"] = "OpenCL"
+import pickle
import random
import sys
-from openfold.features import templates, feature_pipeline
-from openfold.features.np import data_pipeline
+from openfold.features import templates, feature_pipeline, data_pipeline
import time
...@@ -43,28 +43,29 @@ from openfold.utils.tensor_utils import (
    tensor_tree_map,
)
-MAX_TEMPLATE_HITS = 20
+from scripts.utils import add_data_args
def main(args):
    config = model_config(args.model_name)
    model = AlphaFold(config.model)
    model = model.eval()
    import_jax_weights_(model, args.param_path)
-    model = model.to(args.device)
+    model = model.to(args.model_device)
    # FEATURE COLLECTION AND PROCESSING
-    use_small_bfd = args.preset == "reduced_dbs"
    num_ensemble = 1
    template_featurizer = templates.TemplateHitFeaturizer(
        mmcif_dir=args.template_mmcif_dir,
        max_template_date=args.max_template_date,
-        max_hits=MAX_TEMPLATE_HITS,
+        max_hits=args.max_template_hits,
        kalign_binary_path=args.kalign_binary_path,
        release_dates_path=None,
        obsolete_pdbs_path=args.obsolete_pdbs_path
    )
+    use_small_bfd = (args.bfd_database_path is None)
    alignment_runner = data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
...@@ -76,6 +77,7 @@ def main(args):
        small_bfd_database_path=args.small_bfd_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=use_small_bfd,
+        no_cpus=args.cpus,
    )
    data_processor = data_pipeline.DataPipeline(
...@@ -87,7 +89,7 @@ def main(args):
    random_seed = args.data_random_seed
    if random_seed is None:
        random_seed = random.randrange(sys.maxsize)
-    config.data.eval.num_ensemble = num_ensemble
+    config.data.predict.num_ensemble = num_ensemble
    feature_processor = feature_pipeline.FeaturePipeline(config)
    if not os.path.exists(output_dir_base):
        os.makedirs(output_dir_base)
...@@ -95,7 +97,7 @@ def main(args):
    if not os.path.exists(alignment_dir):
        os.makedirs(alignment_dir)
-    print("Generating features...")
+    logging.info("Generating features...")
    alignment_runner.run(
        args.fasta_path, alignment_dir
    )
...@@ -105,42 +107,20 @@ def main(args):
    )
    processed_feature_dict = feature_processor.process_features(
-        feature_dict, random_seed
+        feature_dict, mode='predict',
    )
-    for k, v in processed_feature_dict.items():
-        print(k)
-        print(v.shape)
-    print("Executing model...")
+    logging.info("Executing model...")
    batch = processed_feature_dict
    with torch.no_grad():
        batch = {
-            k: torch.as_tensor(v, device=args.device)
+            k: torch.as_tensor(v, device=args.model_device)
            for k, v in batch.items()
        }
-        longs = [
-            "aatype",
-            "template_aatype",
-            "extra_msa",
-            "residx_atom37_to_atom14",
-            "residx_atom14_to_atom37",
-            "true_msa",
-            "residue_index",
-        ]
-        for l in longs:
-            batch[l] = batch[l].long()
-        # Move the recycling dimension to the end
-        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
-        batch = tensor_tree_map(move_dim, batch)
-        make_contig = lambda t: t.contiguous()
-        batch = tensor_tree_map(make_contig, batch)
        t = time.time()
        out = model(batch)
-        print(f"Inference time: {time.time() - t}")
+        logging.info(f"Inference time: {time.time() - t}")
    # Toss out the recycling dimensions --- we don't need them anymore
    batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), batch)
...@@ -159,8 +139,6 @@ def main(args):
        b_factors=plddt_b_factors
    )
-    os.environ["CUDA_VISIBLE_DEVICES"] = "7"
    amber_relaxer = relax.AmberRelaxation(
        **config.relax
    )
...@@ -168,7 +146,7 @@ def main(args):
    # Relax the prediction.
    t = time.time()
    relaxed_pdb_str, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
-    print(f"Relaxation time: {time.time() - t}")
+    logging.info(f"Relaxation time: {time.time() - t}")
    # Save the relaxed PDB.
    relaxed_output_path = os.path.join(
...@@ -183,53 +161,14 @@ if __name__ == "__main__":
    parser.add_argument(
        "fasta_path", type=str,
    )
-    parser.add_argument(
-        'uniref90_database_path', type=str,
-    )
-    parser.add_argument(
-        'mgnify_database_path', type=str,
-    )
-    parser.add_argument(
-        'pdb70_database_path', type=str,
-    )
-    parser.add_argument(
-        'template_mmcif_dir', type=str,
-    )
-    parser.add_argument(
-        '--uniclust30_database_path', type=str,
-    )
-    parser.add_argument(
-        '--bfd_database_path', type=str, default=None,
-    )
-    parser.add_argument(
-        '--small_bfd_database_path', type=str, default=None
-    )
-    parser.add_argument(
-        '--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
-    )
-    parser.add_argument(
-        '--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
-    )
-    parser.add_argument(
-        '--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
-    )
-    parser.add_argument(
-        '--kalign_binary_path', type=str, default='/usr/bin/kalign'
-    )
-    parser.add_argument(
-        '--max_template_date', type=str,
-        default=date.today().strftime("%Y-%m-%d"),
-    )
-    parser.add_argument(
-        '--obsolete_pdbs_path', type=str, default=None
-    )
+    add_data_args(parser)
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="""Name of the directory in which to output the prediction""",
        required=True
    )
    parser.add_argument(
-        "--device", type=str, default="cpu",
+        "--model_device", type=str, default="cpu",
        help="""Name of the device on which to run the model. Any valid torch
                device name is accepted (e.g. "cpu", "cuda:0")"""
    )
...@@ -244,6 +183,10 @@ if __name__ == "__main__":
                automatically according to the model name from
                openfold/resources/params"""
    )
+    parser.add_argument(
+        "--cpus", type=int, default=4,
+        help="""Number of CPUs to use to run alignment tools"""
+    )
    parser.add_argument(
        '--preset', type=str, default='full_dbs',
        choices=('reduced_dbs', 'full_dbs')
......
...@@ -15,6 +15,7 @@
import argparse
import json
parser = argparse.ArgumentParser(description='''Outputs a DeepSpeed
                                                configuration file to
                                                stdout''')
......
import argparse
import logging
import os
import tempfile
import openfold.features.mmcif_parsing as mmcif_parsing
from openfold.features import parsers  # needed for parsers.parse_fasta below
from openfold.features.data_pipeline import AlignmentRunner
from scripts.utils import add_data_args
def main(args):
# Build the alignment tool runner
alignment_runner = AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
hhblits_binary_path=args.hhblits_binary_path,
hhsearch_binary_path=args.hhsearch_binary_path,
uniref90_database_path=args.uniref90_database_path,
mgnify_database_path=args.mgnify_database_path,
bfd_database_path=args.bfd_database_path,
uniclust30_database_path=args.uniclust30_database_path,
small_bfd_database_path=args.small_bfd_database_path,
pdb70_database_path=args.pdb70_database_path,
use_small_bfd=args.bfd_database_path is None,
no_cpus=args.cpus,
)
for f in os.listdir(args.input_dir):
path = os.path.join(args.input_dir, f)
is_mmcif = f.endswith('.cif')
is_fasta = f.endswith('.fasta')
file_id = os.path.splitext(f)[0]
seqs = {}
if(is_mmcif):
with open(path, 'r') as fp:
mmcif_str = fp.read()
mmcif = mmcif_parsing.parse(
file_id=file_id, mmcif_string=mmcif_str
)
if(mmcif.mmcif_object is None):
logging.warning(f'Failed to parse {f}...')
if(args.raise_errors):
raise list(mmcif.errors.values())[0]
else:
continue
mmcif = mmcif.mmcif_object
for k,v in mmcif.chain_to_seqres.items():
chain_id = '_'.join([file_id, k])
seqs[chain_id] = v
elif(is_fasta):
with open(path, 'r') as fp:
fasta_str = fp.read()
input_seqs, _ = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
                msg = f'More than one input sequence found in {f}'
if(args.raise_errors):
raise ValueError(msg)
else:
logging.warning(msg)
input_sequence = input_seqs[0]
seqs[file_id] = input_sequence
else:
continue
for name, seq in seqs.items():
alignment_dir = os.path.join(args.output_dir, name)
        if(os.path.isdir(alignment_dir)):
            logging.info(f'{name} has already been processed. Skipping...')
            continue
os.makedirs(alignment_dir)
if(not is_fasta):
fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
with os.fdopen(fd, 'w') as fp:
fp.write(f'>query\n{seq}')
        alignment_runner.run(
            path if is_fasta else fasta_path, alignment_dir
        )
if(not is_fasta):
os.remove(fasta_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"input_dir", type=str,
help="Path to directory containing mmCIF and/or FASTA files"
)
parser.add_argument(
"output_dir", type=str,
help="Directory in which to output alignments"
)
add_data_args(parser)
    parser.add_argument(
        "--raise_errors", action="store_true",
        help="Whether to crash on parsing errors"
    )
parser.add_argument(
"--cpus", type=int, default=4,
help="Number of CPUs to use"
)
args = parser.parse_args()
main(args)
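For reference, a minimal sketch (all paths hypothetical) of driving AlignmentRunner directly for a single sequence, mirroring what the loop above does per chain:

from openfold.features.data_pipeline import AlignmentRunner

runner = AlignmentRunner(
    jackhmmer_binary_path='/usr/bin/jackhmmer',
    hhblits_binary_path='/usr/bin/hhblits',
    hhsearch_binary_path='/usr/bin/hhsearch',
    uniref90_database_path='/data/uniref90/uniref90.fasta',
    mgnify_database_path='/data/mgnify/mgy_clusters.fa',
    bfd_database_path=None,
    uniclust30_database_path='/data/uniclust30/uniclust30_2018_08',
    small_bfd_database_path='/data/small_bfd/bfd-first_non_consensus_sequences.fasta',
    pdb70_database_path='/data/pdb70/pdb70',
    use_small_bfd=True,  # no full BFD path given above
    no_cpus=4,
)
# Writes alignment files for the query into the target directory
runner.run('query.fasta', 'alignments/query')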
import argparse
from datetime import date
def add_data_args(parser: argparse.ArgumentParser):
parser.add_argument(
'uniref90_database_path', type=str,
)
parser.add_argument(
'mgnify_database_path', type=str,
)
parser.add_argument(
'pdb70_database_path', type=str,
)
parser.add_argument(
'template_mmcif_dir', type=str,
)
parser.add_argument(
'uniclust30_database_path', type=str,
)
parser.add_argument(
'--bfd_database_path', type=str, default=None,
)
parser.add_argument(
'--small_bfd_database_path', type=str, default=None
)
parser.add_argument(
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
)
parser.add_argument(
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
)
parser.add_argument(
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
)
parser.add_argument(
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
)
parser.add_argument(
'--max_template_date', type=str,
default=date.today().strftime("%Y-%m-%d"),
)
parser.add_argument(
'--max_template_hits', type=int, default=20,
)
parser.add_argument(
'--obsolete_pdbs_path', type=str, default=None
)
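A minimal sketch (hypothetical paths) of how a consumer script composes these shared flags with its own arguments, as run_pretrained_openfold.py now does via add_data_args(parser):

import argparse
from scripts.utils import add_data_args

parser = argparse.ArgumentParser()
parser.add_argument("fasta_path", type=str)  # script-specific positional
add_data_args(parser)  # shared database/binary arguments
args = parser.parse_args([
    "query.fasta",
    "/data/uniref90/uniref90.fasta",
    "/data/mgnify/mgy_clusters.fa",
    "/data/pdb70/pdb70",
    "/data/pdb_mmcif/mmcif_files",
    "/data/uniclust30/uniclust30_2018_08",
    "--max_template_hits", "20",
])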
import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "4,"
import importlib
import pkgutil
import sys
......
...@@ -16,6 +16,7 @@ import torch
import numpy as np
import unittest
+import openfold.features.data_transforms as data_transforms
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
...@@ -168,7 +169,7 @@ class TestFeats(unittest.TestCase):
        to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)
-        out_repro = feats.atom37_to_frames(eps=1e-8, **batch)
+        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
        for k, v in out_gt.items():
......