"vscode:/vscode.git/clone" did not exist on "af9ee90e98e4089855f9aab7ae56da40e0af16e5"
Commit 07e64267 authored by Gustaf Ahdritz

Standardize code style

parent de07730f
@@ -27,45 +27,45 @@ from openfold.np import residue_constants

FeatureDict = Mapping[str, np.ndarray]


def make_sequence_features(
    sequence: str, description: str, num_res: int
) -> FeatureDict:
    """Construct a feature dict of sequence features."""
    features = {}
    features["aatype"] = residue_constants.sequence_to_onehot(
        sequence=sequence,
        mapping=residue_constants.restype_order_with_x,
        map_unknown_to_x=True,
    )
    features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
    features["domain_name"] = np.array(
        [description.encode("utf-8")], dtype=np.object_
    )
    features["residue_index"] = np.array(range(num_res), dtype=np.int32)
    features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
    features["sequence"] = np.array(
        [sequence.encode("utf-8")], dtype=np.object_
    )
    return features


def make_mmcif_features(
    mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
    input_sequence = mmcif_object.chain_to_seqres[chain_id]
    description = "_".join([mmcif_object.file_id, chain_id])
    num_res = len(input_sequence)

    mmcif_feats = {}

    mmcif_feats.update(
        make_sequence_features(
            sequence=input_sequence,
            description=description,
            num_res=num_res,
        )
    )

    all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
        mmcif_object=mmcif_object, chain_id=chain_id
@@ -78,7 +78,7 @@ def make_mmcif_features(
    )
    mmcif_feats["release_date"] = np.array(
        [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
    )
    return mmcif_feats
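A minimal usage sketch of the reformatted make_sequence_features (toy sequence and description, assuming the 21-type residue mapping of residue_constants.restype_order_with_x):

feats = make_sequence_features(
    sequence="MKFLVL", description="test_A", num_res=6
)
# feats["aatype"]: one-hot array of shape [6, 21] (20 amino acids plus X)
# feats["residue_index"]: array([0, 1, 2, 3, 4, 5], dtype=int32)
# feats["seq_length"]: array([6, 6, 6, 6, 6, 6], dtype=int32)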
@@ -86,17 +86,20 @@ def make_mmcif_features(

def make_msa_features(
    msas: Sequence[Sequence[str]],
    deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
    """Constructs a feature dict of MSA features."""
    if not msas:
        raise ValueError("At least one MSA must be provided.")

    int_msa = []
    deletion_matrix = []
    seen_sequences = set()
    for msa_index, msa in enumerate(msas):
        if not msa:
            raise ValueError(
                f"MSA {msa_index} must contain at least one sequence."
            )
        for sequence_index, sequence in enumerate(msa):
            if sequence in seen_sequences:
                continue
@@ -109,17 +112,19 @@ def make_msa_features(
    num_res = len(msas[0][0])
    num_alignments = len(int_msa)
    features = {}
    features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
    features["msa"] = np.array(int_msa, dtype=np.int32)
    features["num_alignments"] = np.array(
        [num_alignments] * num_res, dtype=np.int32
    )
    return features
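A hedged sketch of the deduplication behavior above (toy MSAs; the residue-to-integer conversion happens in lines elided from this hunk):

msas = (
    ["MKV", "MRV"],  # e.g. UniRef90 hits
    ["MKV", "MAV"],  # the duplicate "MKV" is skipped via seen_sequences
)
deletion_matrices = (
    [[0, 0, 0], [0, 0, 0]],
    [[0, 0, 0], [0, 0, 0]],
)
feats = make_msa_features(msas=msas, deletion_matrices=deletion_matrices)
# feats["msa"].shape == (3, 3): one row each for "MKV", "MRV", "MAV"
# feats["num_alignments"] repeats 3 for each of the 3 residues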

class AlignmentRunner:
    """Runs alignment tools and saves the results"""

    def __init__(
        self,
        jackhmmer_binary_path: str,
        hhblits_binary_path: str,
        hhsearch_binary_path: str,
@@ -161,105 +166,109 @@ class AlignmentRunner:
        )

        self.hhsearch_pdb70_runner = hhsearch.HHSearch(
            binary_path=hhsearch_binary_path, databases=[pdb70_database_path]
        )

        self.uniref_max_hits = uniref_max_hits
        self.mgnify_max_hits = mgnify_max_hits

    def run(
        self,
        fasta_path: str,
        output_dir: str,
    ):
        """Runs alignment tools on a sequence"""
        jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
            fasta_path
        )[0]
        uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_uniref90_result["sto"], max_sequences=self.uniref_max_hits
        )
        uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
        with open(uniref90_out_path, "w") as f:
            f.write(uniref90_msa_as_a3m)

        jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
            fasta_path
        )[0]
        mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
            jackhmmer_mgnify_result["sto"], max_sequences=self.mgnify_max_hits
        )
        mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
        with open(mgnify_out_path, "w") as f:
            f.write(mgnify_msa_as_a3m)

        hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
        pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
        with open(pdb70_out_path, "w") as f:
            f.write(hhsearch_result)

        if self._use_small_bfd:
            jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
                fasta_path
            )[0]
            bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
            with open(bfd_out_path, "w") as f:
                f.write(jackhmmer_small_bfd_result["sto"])
        else:
            hhblits_bfd_uniclust_result = (
                self.hhblits_bfd_uniclust_runner.query(fasta_path)
            )
            if output_dir is not None:
                bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
                with open(bfd_out_path, "w") as f:
                    f.write(hhblits_bfd_uniclust_result["a3m"])


class DataPipeline:
    """Assembles input features."""

    def __init__(
        self,
        template_featurizer: templates.TemplateHitFeaturizer,
        use_small_bfd: bool,
    ):
        self.template_featurizer = template_featurizer
        self.use_small_bfd = use_small_bfd

    def _parse_alignment_output(
        self,
        alignment_dir: str,
    ) -> Mapping[str, Any]:
        uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
        with open(uniref90_out_path, "r") as f:
            uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())

        mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
        with open(mgnify_out_path, "r") as f:
            mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())

        pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
        with open(pdb70_out_path, "r") as f:
            hhsearch_hits = parsers.parse_hhr(f.read())

        if self.use_small_bfd:
            bfd_out_path = os.path.join(alignment_dir, "small_bfd_hits.sto")
            with open(bfd_out_path, "r") as f:
                bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
                    f.read()
                )
        else:
            bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
            with open(bfd_out_path, "r") as f:
                bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())

        return {
            "uniref90_msa": uniref90_msa,
            "uniref90_deletion_matrix": uniref90_deletion_matrix,
            "mgnify_msa": mgnify_msa,
            "mgnify_deletion_matrix": mgnify_deletion_matrix,
            "hhsearch_hits": hhsearch_hits,
            "bfd_msa": bfd_msa,
            "bfd_deletion_matrix": bfd_deletion_matrix,
        }

    def process_fasta(
        self,
        fasta_path: str,
        alignment_dir: str,
    ) -> FeatureDict:
@@ -269,7 +278,8 @@ class DataPipeline:
        input_seqs, input_descs = parsers.parse_fasta(fasta_str)
        if len(input_seqs) != 1:
            raise ValueError(
                f"More than one input sequence found in {fasta_path}."
            )
        input_sequence = input_seqs[0]
        input_description = input_descs[0]
        num_res = len(input_sequence)
@@ -280,30 +290,31 @@ class DataPipeline:
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=None,
            hits=alignments["hhsearch_hits"],
        )

        sequence_features = make_sequence_features(
            sequence=input_sequence,
            description=input_description,
            num_res=num_res,
        )

        msa_features = make_msa_features(
            msas=(
                alignments["uniref90_msa"],
                alignments["bfd_msa"],
                alignments["mgnify_msa"],
            ),
            deletion_matrices=(
                alignments["uniref90_deletion_matrix"],
                alignments["bfd_deletion_matrix"],
                alignments["mgnify_deletion_matrix"],
            ),
        )

        return {**sequence_features, **msa_features, **templates_result.data}

    def process_mmcif(
        self,
        mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
        alignment_dir: str,
        chain_id: Optional[str] = None,
@@ -314,13 +325,11 @@ class DataPipeline:
        If chain_id is None, it is assumed that there is only one chain
        in the object. Otherwise, a ValueError is thrown.
        """
        if chain_id is None:
            chains = mmcif.structure.get_chains()
            chain = next(chains, None)
            if chain is None:
                raise ValueError("No chains in mmCIF file")
            chain_id = chain.id

        mmcif_feats = make_mmcif_features(mmcif, chain_id)
@@ -332,20 +341,20 @@ class DataPipeline:
            query_sequence=input_sequence,
            query_pdb_code=None,
            query_release_date=to_date(mmcif.header["release_date"]),
            hits=alignments["hhsearch_hits"],
        )

        msa_features = make_msa_features(
            msas=(
                alignments["uniref90_msa"],
                alignments["bfd_msa"],
                alignments["mgnify_msa"],
            ),
            deletion_matrices=(
                alignments["uniref90_deletion_matrix"],
                alignments["bfd_deletion_matrix"],
                alignments["mgnify_deletion_matrix"],
            ),
        )

        return {**mmcif_feats, **templates_result.data, **msa_features}
@@ -26,10 +26,11 @@ from openfold.data import input_pipeline

FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor]


def np_to_tensor_dict(
    np_example: Mapping[str, np.ndarray],
    features: Sequence[str],
) -> TensorDict:
    """Creates dict of tensors from a dict of NumPy arrays.

    Args:
@@ -54,7 +55,7 @@ def make_data_config(
    cfg = copy.deepcopy(config)
    mode_cfg = cfg[mode]
    with cfg.unlocked():
        if mode_cfg.crop_size is None:
            mode_cfg.crop_size = num_res

    feature_names = cfg.common.unsupervised_features
@@ -62,7 +63,7 @@ def make_data_config(
    if cfg.common.use_templates:
        feature_names += cfg.common.template_features

    if cfg[mode].supervised:
        feature_names += cfg.common.supervised_features

    return cfg, feature_names
@@ -75,47 +76,47 @@ def np_example_to_features(
    batch_mode: str,
):
    np_example = dict(np_example)
    num_res = int(np_example["seq_length"][0])
    cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)

    if "deletion_matrix_int" in np_example:
        np_example["deletion_matrix"] = np_example.pop(
            "deletion_matrix_int"
        ).astype(np.float32)

    if batch_mode == "clamped":
        np_example["use_clamped_fape"] = np.array(1.0).astype(np.float32)
    elif batch_mode == "unclamped":
        np_example["use_clamped_fape"] = np.array(0.0).astype(np.float32)

    tensor_dict = np_to_tensor_dict(
        np_example=np_example, features=feature_names
    )

    with torch.no_grad():
        features = input_pipeline.process_tensors_from_config(
            tensor_dict,
            cfg.common,
            cfg[mode],
            batch_mode=batch_mode,
        )

    return {k: v for k, v in features.items()}


class FeaturePipeline:
    def __init__(
        self,
        config: ml_collections.ConfigDict,
        params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
    ):
        self.config = config
        self.params = params

    def process_features(
        self,
        raw_features: FeatureDict,
        mode: str = "train",
        batch_mode: str = "clamped",
    ) -> FeatureDict:
        return np_example_to_features(
            np_example=raw_features,
...
@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
        data_transforms.make_hhblits_profile,
    ]
    if common_cfg.use_templates:
        transforms.extend(
            [
                data_transforms.fix_templates_aatype,
                data_transforms.make_template_mask,
                data_transforms.make_pseudo_beta("template_"),
            ]
        )
        if common_cfg.use_template_torsion_angles:
            transforms.extend(
                [
                    data_transforms.atom37_to_torsion_angles("template_"),
                ]
            )

    transforms.extend(
        [
            data_transforms.make_atom14_masks,
        ]
    )

    if mode_cfg.supervised:
        transforms.extend(
            [
                data_transforms.make_atom14_positions,
                data_transforms.atom37_to_frames,
                data_transforms.atom37_to_torsion_angles(""),
                data_transforms.make_pseudo_beta(""),
                data_transforms.get_backbone_frames,
                data_transforms.get_chi_angles,
            ]
        )

    return transforms
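The returned list is consumed by compose in process_tensors_from_config further below; a minimal sketch of that pattern, assuming each transform is a callable from a tensor dict to a tensor dict:

def apply_transforms(transforms, tensors):
    # Left-to-right composition, equivalent in spirit to compose(transforms)
    for t in transforms:
        tensors = t(tensors)
    return tensors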
@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
        data_transforms.sample_msa(max_msa_clusters, keep_extra=True)
    )

    if "masked_msa" in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        transforms.append(
            data_transforms.make_masked_msa(
                common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
            )
        )
@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
    if mode_cfg.fixed_size:
        transforms.append(data_transforms.select_feat(list(crop_feats)))
        transforms.append(
            data_transforms.random_crop_to_size(
                mode_cfg.crop_size,
                mode_cfg.max_templates,
                crop_feats,
                mode_cfg.subsample_templates,
                batch_mode=batch_mode,
                seed=torch.Generator().seed(),
            )
        )
        transforms.append(
            data_transforms.make_fixed_size(
                crop_feats,
                pad_msa_clusters,
                common_cfg.max_extra_msa,
                mode_cfg.crop_size,
                mode_cfg.max_templates,
            )
        )
    else:
        transforms.append(
            data_transforms.crop_templates(mode_cfg.max_templates)
@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):

def process_tensors_from_config(
    tensors, common_cfg, mode_cfg, batch_mode="clamped"
):
    """Based on the config, apply filters and transformations to the data."""
@@ -136,12 +147,10 @@ def process_tensors_from_config(
        d = data.copy()
        fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
        fn = compose(fns)
        d["ensemble_index"] = i
        return fn(d)

    tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)

    tensors_0 = wrap_ensemble_fn(tensors, 0)
    num_ensemble = mode_cfg.num_ensemble
@@ -150,8 +159,9 @@ def process_tensors_from_config(
        num_ensemble *= common_cfg.num_recycle + 1

    if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
        tensors = map_fn(
            lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
        )
    else:
        tensors = tree.map_structure(lambda x: x[None], tensors_0)
...
@@ -90,6 +90,7 @@ class MmcifObject:
          ...}}
      raw_string: The raw string used to construct the MmcifObject.
    """

    file_id: str
    header: PdbHeader
    structure: PdbStructure
@@ -107,6 +108,7 @@ class ParsingResult:
        parsed.
      errors: A dict mapping (file_id, chain_id) to any exception generated.
    """

    mmcif_object: Optional[MmcifObject]
    errors: Mapping[Tuple[str, str], Any]
@@ -115,8 +117,9 @@ class ParseError(Exception):
    """An error indicating that an mmCIF file could not be parsed."""


def mmcif_loop_to_list(
    prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a list.

    Reference for loop_ in mmCIF:
@@ -140,15 +143,17 @@ def mmcif_loop_to_list(prefix: str,
            data.append(value)

    assert all([len(xs) == len(data[0]) for xs in data]), (
        "mmCIF error: Not all loops are the same length: %s" % cols
    )

    return [dict(zip(cols, xs)) for xs in zip(*data)]
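A hedged toy example of the contract (hypothetical parsed_info; the column-gathering code sits in the elided lines above):

parsed_info = {
    "_exptl.entry_id": ["1ABC", "1ABC"],
    "_exptl.method": ["X-RAY DIFFRACTION", "NEUTRON DIFFRACTION"],
}
mmcif_loop_to_list("_exptl.", parsed_info)
# -> [{"_exptl.entry_id": "1ABC", "_exptl.method": "X-RAY DIFFRACTION"},
#     {"_exptl.entry_id": "1ABC", "_exptl.method": "NEUTRON DIFFRACTION"}]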

def mmcif_loop_to_dict(
    prefix: str,
    index: str,
    parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
    """Extracts loop associated with a prefix from mmCIF data as a dictionary.

    Args:
@@ -167,10 +172,9 @@ def mmcif_loop_to_dict(prefix: str,
    return {entry[index]: entry for entry in entries}


def parse(
    *, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
    """Entry point, parses an mmcif_string.

    Args:
@@ -188,7 +192,7 @@ def parse(*,
    try:
        parser = PDB.MMCIFParser(QUIET=True)
        handle = io.StringIO(mmcif_string)
        full_structure = parser.get_structure("", handle)
        first_model_structure = _get_first_model(full_structure)
        # Extract the _mmcif_dict from the parser, which contains useful fields not
        # reflected in the Biopython structure.
@@ -206,9 +210,12 @@ def parse(*,
        valid_chains = _get_protein_chains(parsed_info=parsed_info)
        if not valid_chains:
            return ParsingResult(
                None, {(file_id, ""): "No protein chains found in this file."}
            )
        seq_start_num = {
            chain_id: min([monomer.num for monomer in seq])
            for chain_id, seq in valid_chains.items()
        }

        # Loop over the atoms for which we have coordinates. Populate two mappings:
        # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
@@ -217,34 +224,42 @@ def parse(*,
        mmcif_to_author_chain_id = {}
        seq_to_structure_mappings = {}
        for atom in _get_atom_site_list(parsed_info):
            if atom.model_num != "1":
                # We only process the first model at the moment.
                continue

            mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id

            if atom.mmcif_chain_id in valid_chains:
                hetflag = " "
                if atom.hetatm_atom == "HETATM":
                    # Water atoms are assigned a special hetflag of W in Biopython. We
                    # need to do the same, so that this hetflag can be used to fetch
                    # a residue from the Biopython structure by id.
                    if atom.residue_name in ("HOH", "WAT"):
                        hetflag = "W"
                    else:
                        hetflag = "H_" + atom.residue_name

                insertion_code = atom.insertion_code
                if not _is_set(atom.insertion_code):
                    insertion_code = " "
                position = ResiduePosition(
                    chain_id=atom.author_chain_id,
                    residue_number=int(atom.author_seq_num),
                    insertion_code=insertion_code,
                )
                seq_idx = (
                    int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
                )
                current = seq_to_structure_mappings.get(
                    atom.author_chain_id, {}
                )
                current[seq_idx] = ResidueAtPosition(
                    position=position,
                    name=atom.residue_name,
                    is_missing=False,
                    hetflag=hetflag,
                )
                seq_to_structure_mappings[atom.author_chain_id] = current

        # Add missing residue information to seq_to_structure_mappings.
@@ -253,19 +268,21 @@ def parse(*,
            current_mapping = seq_to_structure_mappings[author_chain]
            for idx, monomer in enumerate(seq_info):
                if idx not in current_mapping:
                    current_mapping[idx] = ResidueAtPosition(
                        position=None,
                        name=monomer.id,
                        is_missing=True,
                        hetflag=" ",
                    )

        author_chain_to_sequence = {}
        for chain_id, seq_info in valid_chains.items():
            author_chain = mmcif_to_author_chain_id[chain_id]
            seq = []
            for monomer in seq_info:
                code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
                seq.append(code if len(code) == 1 else "X")
            seq = "".join(seq)
            author_chain_to_sequence[author_chain] = seq

        mmcif_object = MmcifObject(
@@ -274,11 +291,12 @@ def parse(*,
            structure=first_model_structure,
            chain_to_seqres=author_chain_to_sequence,
            seqres_to_structure=seq_to_structure_mappings,
            raw_string=parsed_info,
        )

        return ParsingResult(mmcif_object=mmcif_object, errors=errors)
    except Exception as e:  # pylint:disable=broad-except
        errors[(file_id, "")] = e
        if not catch_all_errors:
            raise
        return ParsingResult(mmcif_object=None, errors=errors)
@@ -288,12 +306,13 @@ def _get_first_model(structure: PdbStructure) -> PdbStructure:
    """Returns the first model in a Biopython structure."""
    return next(structure.get_models())


_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21


def get_release_date(parsed_info: MmCIFDict) -> str:
    """Returns the oldest revision date."""
    revision_dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
    return min(revision_dates)
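Since revision dates are ISO-formatted strings, min() picks the oldest one lexicographically; for example:

parsed_info = {
    "_pdbx_audit_revision_history.revision_date": ["2021-03-10", "2019-11-27"]
}
get_release_date(parsed_info)  # -> "2019-11-27"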
@@ -301,47 +320,58 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
    """Returns a basic header containing method, release date and resolution."""
    header = {}

    experiments = mmcif_loop_to_list("_exptl.", parsed_info)
    header["structure_method"] = ",".join(
        [experiment["_exptl.method"].lower() for experiment in experiments]
    )

    # Note: The release_date here corresponds to the oldest revision. We prefer to
    # use this for dataset filtering over the deposition_date.
    if "_pdbx_audit_revision_history.revision_date" in parsed_info:
        header["release_date"] = get_release_date(parsed_info)
    else:
        logging.warning(
            "Could not determine release_date: %s", parsed_info["_entry.id"]
        )

    header["resolution"] = 0.00
    for res_key in (
        "_refine.ls_d_res_high",
        "_em_3d_reconstruction.resolution",
        "_reflns.d_resolution_high",
    ):
        if res_key in parsed_info:
            try:
                raw_resolution = parsed_info[res_key][0]
                header["resolution"] = float(raw_resolution)
            except ValueError:
                logging.warning(
                    "Invalid resolution format: %s", parsed_info[res_key]
                )

    return header


def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
    """Returns list of atom sites; contains data not present in the structure."""
    return [
        AtomSite(*site)
        for site in zip(  # pylint:disable=g-complex-comprehension
            parsed_info["_atom_site.label_comp_id"],
            parsed_info["_atom_site.auth_asym_id"],
            parsed_info["_atom_site.label_asym_id"],
            parsed_info["_atom_site.auth_seq_id"],
            parsed_info["_atom_site.label_seq_id"],
            parsed_info["_atom_site.pdbx_PDB_ins_code"],
            parsed_info["_atom_site.group_PDB"],
            parsed_info["_atom_site.pdbx_PDB_model_num"],
        )
    ]


def _get_protein_chains(
    *, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
    """Extracts polymer information for protein chains only.

    Args:
@@ -351,26 +381,29 @@ def _get_protein_chains(
      A dict mapping mmcif chain id to a list of Monomers.
    """
    # Get polymer information for each entity in the structure.
    entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)

    polymers = collections.defaultdict(list)
    for entity_poly_seq in entity_poly_seqs:
        polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
            Monomer(
                id=entity_poly_seq["_entity_poly_seq.mon_id"],
                num=int(entity_poly_seq["_entity_poly_seq.num"]),
            )
        )

    # Get chemical compositions. Will allow us to identify which of these polymers
    # are proteins.
    chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)

    # Get chains information for each entity. Necessary so that we can return a
    # dict keyed on chain id rather than entity.
    struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)

    entity_to_mmcif_chains = collections.defaultdict(list)
    for struct_asym in struct_asyms:
        chain_id = struct_asym["_struct_asym.id"]
        entity_id = struct_asym["_struct_asym.entity_id"]
        entity_to_mmcif_chains[entity_id].append(chain_id)

    # Identify and return the valid protein chains.
@@ -379,8 +412,12 @@ def _get_protein_chains(
        chain_ids = entity_to_mmcif_chains[entity_id]

        # Reject polymers without any peptide-like components, such as DNA/RNA.
        if any(
            [
                "peptide" in chem_comps[monomer.id]["_chem_comp.type"]
                for monomer in seq_info
            ]
        ):
            for chain_id in chain_ids:
                valid_chains[chain_id] = seq_info
    return valid_chains
@@ -388,19 +425,18 @@ def _get_protein_chains(

def _is_set(data: str) -> bool:
    """Returns False if data is a special mmCIF character indicating 'unset'."""
    return data not in (".", "?")


def get_atom_coords(
    mmcif_object: MmcifObject, chain_id: str
) -> Tuple[np.ndarray, np.ndarray]:
    # Locate the right chain
    chains = list(mmcif_object.structure.get_chains())
    relevant_chains = [c for c in chains if c.id == chain_id]
    if len(relevant_chains) != 1:
        raise MultipleChainsError(
            f"Expected exactly one chain in structure with id {chain_id}."
        )
    chain = relevant_chains[0]
@@ -417,19 +453,23 @@ def get_atom_coords(
        mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
        res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
        if not res_at_position.is_missing:
            res = chain[
                (
                    res_at_position.hetflag,
                    res_at_position.position.residue_number,
                    res_at_position.position.insertion_code,
                )
            ]
            for atom in res.get_atoms():
                atom_name = atom.get_name()
                x, y, z = atom.get_coord()
                if atom_name in residue_constants.atom_order.keys():
                    pos[residue_constants.atom_order[atom_name]] = [x, y, z]
                    mask[residue_constants.atom_order[atom_name]] = 1.0
                elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
                    # Put the coords of the selenium atom in the sulphur column
                    pos[residue_constants.atom_order["SD"]] = [x, y, z]
                    mask[residue_constants.atom_order["SD"]] = 1.0

        all_atom_positions[res_index] = pos
        all_atom_mask[res_index] = mask
@@ -440,22 +480,22 @@ def get_atom_coords(

def generate_mmcif_cache(mmcif_dir: str, out_path: str):
    data = {}
    for f in os.listdir(mmcif_dir):
        if f.endswith(".cif"):
            with open(os.path.join(mmcif_dir, f), "r") as fp:
                mmcif_string = fp.read()
            file_id = os.path.splitext(f)[0]
            mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
            if mmcif.mmcif_object is None:
                logging.warning(f"Could not parse {f}. Skipping...")
                continue
            else:
                mmcif = mmcif.mmcif_object

            local_data = {}
            local_data["release_date"] = mmcif.header["release_date"]
            local_data["no_chains"] = len(list(mmcif.structure.get_chains()))

            data[file_id] = local_data

    with open(out_path, "w") as fp:
        fp.write(json.dumps(data))
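A usage sketch with hypothetical paths:

generate_mmcif_cache("/data/pdb_mmcif/", "mmcif_cache.json")
# mmcif_cache.json then maps each file id to its metadata, e.g.
# {"1abc": {"release_date": "1999-04-21", "no_chains": 2}, ...}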
@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple

DeletionMatrix = Sequence[Sequence[int]]


@dataclasses.dataclass(frozen=True)
class TemplateHit:
    """Class representing a template hit."""

    index: int
    name: str
    aligned_cols: int
@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    index = -1
    for line in fasta_string.splitlines():
        line = line.strip()
        if line.startswith(">"):
            index += 1
            descriptions.append(line[1:])  # Remove the '>' at the beginning.
            sequences.append("")
            continue
        elif not line:
            continue  # Skip blank lines.
@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
    return sequences, descriptions
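A toy example (assuming the elided loop body appends non-header lines to the current sequence):

fasta = ">query description\nMKFLVL\nLVLTG\n"
seqs, descs = parse_fasta(fasta)
# seqs == ["MKFLVLLVLTG"], descs == ["query description"]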

def parse_stockholm(
    stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
    """Parses sequences and deletion matrix from stockholm format alignment.

    Args:
@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str
    name_to_sequence = collections.OrderedDict()
    for line in stockholm_string.splitlines():
        line = line.strip()
        if not line or line.startswith(("#", "//")):
            continue
        name, sequence = line.split()
        if name not in name_to_sequence:
            name_to_sequence[name] = ""
        name_to_sequence[name] += sequence

    msa = []
    deletion_matrix = []

    query = ""
    keep_columns = []
    for seq_index, sequence in enumerate(name_to_sequence.values()):
        if seq_index == 0:
            # Gather the columns with gaps from the query
            query = sequence
            keep_columns = [i for i, res in enumerate(query) if res != "-"]

        # Remove the columns with gaps in the query from all sequences.
        aligned_sequence = "".join([sequence[c] for c in keep_columns])

        msa.append(aligned_sequence)
@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str
        deletion_vec = []
        deletion_count = 0
        for seq_res, query_res in zip(sequence, query):
            if seq_res != "-" or query_res != "-":
                if query_res == "-":
                    deletion_count += 1
                else:
                    deletion_vec.append(deletion_count)
@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
        deletion_matrix.append(deletion_vec)

    # Make the MSA matrix out of aligned (deletion-free) sequences.
    deletion_table = str.maketrans("", "", string.ascii_lowercase)
    aligned_sequences = [s.translate(deletion_table) for s in sequences]
    return aligned_sequences, deletion_matrix
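The maketrans table simply deletes lowercase characters, which mark insertions in A3M; for example:

deletion_table = str.maketrans("", "", string.ascii_lowercase)
"MKtlVG-A".translate(deletion_table)  # -> "MKVG-A"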

def _convert_sto_seq_to_a3m(
    query_non_gaps: Sequence[bool], sto_seq: str
) -> Iterable[str]:
    for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
        if is_query_res_non_gap:
            yield sequence_res
        elif sequence_res != "-":
            yield sequence_res.lower()


def convert_stockholm_to_a3m(
    stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
    """Converts MSA in Stockholm format to the A3M format."""
    descriptions = {}
    sequences = {}
    reached_max_sequences = False
    for line in stockholm_format.splitlines():
        reached_max_sequences = (
            max_sequences and len(sequences) >= max_sequences
        )
        if line.strip() and not line.startswith(("#", "//")):
            # Ignore blank lines, markup and end symbols - remainder are alignment
            # sequence parts.
            seqname, aligned_seq = line.split(maxsplit=1)
            if seqname not in sequences:
                if reached_max_sequences:
                    continue
                sequences[seqname] = ""
            sequences[seqname] += aligned_seq

    for line in stockholm_format.splitlines():
        if line[:4] == "#=GS":
            # Description row - example format is:
            # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
            columns = line.split(maxsplit=3)
            seqname, feature = columns[1:3]
            value = columns[3] if len(columns) == 4 else ""
            if feature != "DE":
                continue
            if reached_max_sequences and seqname not in sequences:
                continue
@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str,
    a3m_sequences = {}
    # query_sequence is assumed to be the first sequence
    query_sequence = next(iter(sequences.values()))
    query_non_gaps = [res != "-" for res in query_sequence]
    for seqname, sto_sequence in sequences.items():
        a3m_sequences[seqname] = "".join(
            _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
        )

    fasta_chunks = (
        f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
        for k in a3m_sequences
    )

    return "\n".join(fasta_chunks) + "\n"  # Include terminating newline.

def _get_hhr_line_regex_groups(
    regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
    match = re.match(regex_pattern, line)
    if match is None:
        raise RuntimeError(f"Could not parse query line {line}")
    return match.groups()


def _update_hhr_residue_indices_list(
    sequence: str, start_index: int, indices_list: List[int]
):
    """Computes the relative indices for each residue with respect to the original sequence."""
    counter = start_index
    for symbol in sequence:
        if symbol == "-":
            indices_list.append(-1)
        else:
            indices_list.append(counter)
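A sketch of the intended behavior (the tail of the loop, which advances counter past non-gap residues, is elided from this hunk):

indices = []
_update_hhr_residue_indices_list("AB-C", 5, indices)
# indices -> [5, 6, -1, 7]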
@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
    # Parse the summary line.
    pattern = (
        "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
        " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
        "]*Template_Neff=(.*)"
    )
    match = re.match(pattern, detailed_lines[2])
    if match is None:
        raise RuntimeError(
            "Could not parse section: %s. Expected this: \n%s to contain summary."
            % (detailed_lines, detailed_lines[2])
        )
    (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
        float(x) for x in match.groups()
    ]

    # The next section reads the detailed comparisons. These are in a 'human
    # readable' format which has a fixed length. The strategy employed is to
    # assume that each block starts with the query sequence line, and to parse
    # that with a regexp in order to deduce the fixed length used for that block.
    query = ""
    hit_sequence = ""
    indices_query = []
    indices_hit = []
    length_block = None

    for line in detailed_lines[3:]:
        # Parse the query sequence line
        if (
            line.startswith("Q ")
            and not line.startswith("Q ss_dssp")
            and not line.startswith("Q ss_pred")
            and not line.startswith("Q Consensus")
        ):
            # Thus the first 17 characters must be 'Q <query_name> ', and we can parse
            # everything after that.
            #              start    sequence       end       total_sequence_length
            patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
            groups = _get_hhr_line_regex_groups(patt, line[17:])
...@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
start = int(groups[0]) - 1 # Make index zero based. start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1] delta_query = groups[1]
end = int(groups[2]) end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-']) num_insertions = len([x for x in delta_query if x == "-"])
length_block = end - start + num_insertions length_block = end - start + num_insertions
assert length_block == len(delta_query) assert length_block == len(delta_query)
...@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
query += delta_query query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query) _update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith('T '): elif line.startswith("T "):
# Parse the hit sequence. # Parse the hit sequence.
if (not line.startswith('T ss_dssp') and if (
not line.startswith('T ss_pred') and not line.startswith("T ss_dssp")
not line.startswith('T Consensus')): and not line.startswith("T ss_pred")
and not line.startswith("T Consensus")
):
# Thus the first 17 characters must be 'T <hit_name> ', and we can # Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that. # parse everything after that.
# start sequence end total_sequence_length # start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:]) groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based. start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1] delta_hit_sequence = groups[1]
...@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Update the hit sequence and indices list. # Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit) _update_hhr_residue_indices_list(
delta_hit_sequence, start, indices_hit
)
return TemplateHit( return TemplateHit(
index=number_of_hit, index=number_of_hit,
...@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: ...@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We # "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit. # iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')] block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]
hits = [] hits = []
if block_starts: if block_starts:
block_starts.append(len(lines)) # Add the end of the final block. block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1): for i in range(len(block_starts) - 1):
hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) hits.append(
_parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
)
return hits return hits
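
# Illustrative usage of parse_hhr (the .hhr path is a placeholder, and only
# the index field of TemplateHit is visible in the constructor call above):
#
#     with open("output.hhr") as f:
#         hits = parse_hhr(f.read())
#     print(len(hits), hits[0].index)
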
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
    """Parses the target-to-e-value mapping from a Jackhmmer tblout string."""
    e_values = {"query": 0}
    lines = [line for line in tblout.splitlines() if line[0] != "#"]
    # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
    # space-delimited. Relevant fields are (1) target name and
    # (5) E-value (full sequence) (numbering from 1).
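    # The field extraction itself is elided from this diff; a sketch of what
    # it plausibly does, given the layout described above (an assumption, not
    # the verbatim code):
    #
    #     for line in lines:
    #         fields = line.split()
    #         e_values[fields[0]] = float(fields[4])
    #     return e_values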
...
@@ -30,7 +30,8 @@ _HHBLITS_DEFAULT_Z = 500
class HHBlits:
    """Python wrapper of the HHblits binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
@@ -44,7 +45,8 @@ class HHBlits:
        all_seqs: bool = False,
        alt: Optional[int] = None,
        p: int = _HHBLITS_DEFAULT_P,
        z: int = _HHBLITS_DEFAULT_Z,
    ):
        """Initializes the Python HHblits wrapper.

        Args:
@@ -77,9 +79,13 @@ class HHBlits:
        self.databases = databases

        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHBlits database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHBlits database {database_path}"
                )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
@@ -95,52 +101,66 @@ class HHBlits:
    def query(self, input_fasta_path: str) -> Mapping[str, Any]:
        """Queries the database using HHblits."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            db_cmd = []
            for db_path in self.databases:
                db_cmd.append("-d")
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                "-i",
                input_fasta_path,
                "-cpu",
                str(self.n_cpu),
                "-oa3m",
                a3m_path,
                "-o",
                "/dev/null",
                "-n",
                str(self.n_iter),
                "-e",
                str(self.e_value),
                "-maxseq",
                str(self.maxseq),
                "-realign_max",
                str(self.realign_max),
                "-maxfilt",
                str(self.maxfilt),
                "-min_prefilter_hits",
                str(self.min_prefilter_hits),
            ]
            if self.all_seqs:
                cmd += ["-all"]
            if self.alt:
                cmd += ["-alt", str(self.alt)]
            if self.p != _HHBLITS_DEFAULT_P:
                cmd += ["-p", str(self.p)]
            if self.z != _HHBLITS_DEFAULT_Z:
                cmd += ["-Z", str(self.z)]
            cmd += db_cmd

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("HHblits query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Logs have a 15k character limit, so log HHblits error line by line.
                logging.error("HHblits failed. HHblits stderr begin:")
                for error_line in stderr.decode("utf-8").splitlines():
                    if error_line.strip():
                        logging.error(error_line.strip())
                logging.error("HHblits stderr end")
                raise RuntimeError(
                    "HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
                )

            with open(a3m_path) as f:
                a3m = f.read()
@@ -150,5 +170,6 @@ class HHBlits:
            output=stdout,
            stderr=stderr,
            n_iter=self.n_iter,
            e_value=self.e_value,
        )
        return raw_output
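
# Hypothetical invocation of the wrapper (binary and database paths are
# placeholders, not shipped defaults):
#
#     hhblits = HHBlits(
#         binary_path="/usr/bin/hhblits",
#         databases=["/data/uniclust30/uniclust30_2018_08"],
#     )
#     result = hhblits.query("query.fasta")
#     print(result["output"])  # Raw HHblits stdout, per the dict above.
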
@@ -26,12 +26,14 @@ from openfold.data.np import utils
class HHSearch:
    """Python wrapper of the HHsearch binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        databases: Sequence[str],
        n_cpu: int = 2,
        maxseq: int = 1_000_000,
    ):
        """Initializes the Python HHsearch wrapper.

        Args:
@@ -52,41 +54,52 @@ class HHSearch:
        self.maxseq = maxseq

        for database_path in self.databases:
            if not glob.glob(database_path + "_*"):
                logging.error(
                    "Could not find HHsearch database %s", database_path
                )
                raise ValueError(
                    f"Could not find HHsearch database {database_path}"
                )

    def query(self, a3m: str) -> str:
        """Queries the database with HHsearch using a given a3m."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            input_path = os.path.join(query_tmp_dir, "query.a3m")
            hhr_path = os.path.join(query_tmp_dir, "output.hhr")
            with open(input_path, "w") as f:
                f.write(a3m)

            db_cmd = []
            for db_path in self.databases:
                db_cmd.append("-d")
                db_cmd.append(db_path)
            cmd = [
                self.binary_path,
                "-i",
                input_path,
                "-o",
                hhr_path,
                "-maxseq",
                str(self.maxseq),
                "-cpu",
                str(self.n_cpu),
            ] + db_cmd

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing("HHsearch query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                # Stderr is truncated to prevent proto size errors in Beam.
                raise RuntimeError(
                    "HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
                )

            with open(hhr_path) as f:
                hhr = f.read()
...
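
# Hypothetical end-to-end sketch chaining the two wrappers (paths are
# placeholders): the a3m produced by HHBlits.query is fed to HHSearch, and
# the resulting .hhr text can be parsed with parse_hhr from above.
#
#     hhsearch = HHSearch(
#         binary_path="/usr/bin/hhsearch",
#         databases=["/data/pdb70/pdb70"],
#     )
#     hhr = hhsearch.query(a3m_string)
#     template_hits = parse_hhr(hhr)
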
@@ -29,7 +29,8 @@ from openfold.data.tools import utils
class Jackhmmer:
    """Python wrapper of the Jackhmmer binary."""

    def __init__(
        self,
        *,
        binary_path: str,
        database_path: str,
@@ -44,7 +45,8 @@ class Jackhmmer:
        incdom_e: Optional[float] = None,
        dom_e: Optional[float] = None,
        num_streamed_chunks: Optional[int] = None,
        streaming_callback: Optional[Callable[[int], None]] = None,
    ):
        """Initializes the Python Jackhmmer wrapper.

        Args:
@@ -69,9 +71,14 @@ class Jackhmmer:
        self.database_path = database_path
        self.num_streamed_chunks = num_streamed_chunks

        if (
            not os.path.exists(self.database_path)
            and num_streamed_chunks is None
        ):
            logging.error("Could not find Jackhmmer database %s", database_path)
            raise ValueError(
                f"Could not find Jackhmmer database {database_path}"
            )

        self.n_cpu = n_cpu
        self.n_iter = n_iter
@@ -85,11 +92,12 @@ class Jackhmmer:
        self.get_tblout = get_tblout
        self.streaming_callback = streaming_callback

    def _query_chunk(
        self, input_fasta_path: str, database_path: str
    ) -> Mapping[str, Any]:
        """Queries the database chunk using Jackhmmer."""
        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            sto_path = os.path.join(query_tmp_dir, "output.sto")

            # The F1/F2/F3 are the expected proportion to pass each of the filtering
            # stages (which get progressively more expensive), reducing these
@@ -98,48 +106,63 @@ class Jackhmmer:
            # amount of time.
            cmd_flags = [
                # Don't pollute stdout with Jackhmmer output.
                "-o",
                "/dev/null",
                "-A",
                sto_path,
                "--noali",
                "--F1",
                str(self.filter_f1),
                "--F2",
                str(self.filter_f2),
                "--F3",
                str(self.filter_f3),
                "--incE",
                str(self.e_value),
                # Report only sequences with E-values <= x in per-sequence output.
                "-E",
                str(self.e_value),
                "--cpu",
                str(self.n_cpu),
                "-N",
                str(self.n_iter),
            ]

            if self.get_tblout:
                tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
                cmd_flags.extend(["--tblout", tblout_path])

            if self.z_value:
                cmd_flags.extend(["-Z", str(self.z_value)])

            if self.dom_e is not None:
                cmd_flags.extend(["--domE", str(self.dom_e)])

            if self.incdom_e is not None:
                cmd_flags.extend(["--incdomE", str(self.incdom_e)])

            cmd = (
                [self.binary_path]
                + cmd_flags
                + [input_fasta_path, database_path]
            )

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            with utils.timing(
                f"Jackhmmer ({os.path.basename(database_path)}) query"
            ):
                _, stderr = process.communicate()
                retcode = process.wait()

            if retcode:
                raise RuntimeError(
                    "Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
                )

            # Get e-values for each target name
            tbl = ""
            if self.get_tblout:
                with open(tblout_path) as f:
                    tbl = f.read()
@@ -152,7 +175,8 @@ class Jackhmmer:
            tbl=tbl,
            stderr=stderr,
            n_iter=self.n_iter,
            e_value=self.e_value,
        )
        return raw_output
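
    # When get_tblout is set, each returned mapping carries the raw tblout
    # text, which pairs naturally with parse_e_values_from_tblout shown
    # earlier (illustrative; the fasta path is a placeholder):
    #
    #     for chunk_out in jackhmmer.query("query.fasta"):
    #         e_values = parse_e_values_from_tblout(chunk_out["tbl"])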
@@ -162,15 +186,15 @@ class Jackhmmer:
            return [self._query_chunk(input_fasta_path, self.database_path)]

        db_basename = os.path.basename(self.database_path)
        db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
        db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"

        # Remove existing files to prevent OOM
        for f in glob.glob(db_local_chunk("[0-9]*")):
            try:
                os.remove(f)
            except OSError:
                print(f"OSError while deleting {f}")

        # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
        with futures.ThreadPoolExecutor(max_workers=2) as executor:
@@ -179,15 +203,22 @@ class Jackhmmer:
                # Copy the chunk locally
                if i == 1:
                    future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i),
                        db_local_chunk(i),
                    )
                if i < self.num_streamed_chunks:
                    next_future = executor.submit(
                        request.urlretrieve,
                        db_remote_chunk(i + 1),
                        db_local_chunk(i + 1),
                    )

                # Run Jackhmmer with the chunk
                future.result()
                chunked_output.append(
                    self._query_chunk(input_fasta_path, db_local_chunk(i))
                )

                # Remove the local copy of the chunk
                os.remove(db_local_chunk(i))
...
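
# The loop above overlaps the download of chunk i + 1 with the search over
# chunk i. A self-contained sketch of the same two-slot prefetch pattern
# (function and argument names here are illustrative, not part of the
# codebase):

from concurrent import futures

def run_with_prefetch(fetch, work, n_chunks):
    with futures.ThreadPoolExecutor(max_workers=2) as executor:
        future = executor.submit(fetch, 1)
        for i in range(1, n_chunks + 1):
            if i < n_chunks:
                next_future = executor.submit(fetch, i + 1)
            future.result()  # Block until chunk i has been fetched.
            work(i)  # Runs while chunk i + 1 downloads in the background.
            if i < n_chunks:
                future = next_future
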
@@ -25,12 +25,12 @@ from openfold.data.tools import utils
def _to_a3m(sequences: Sequence[str]) -> str:
    """Converts sequences to an a3m file."""
    names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
    a3m = []
    for sequence, name in zip(sequences, names):
        a3m.append(u">" + name + u"\n")
        a3m.append(sequence + u"\n")
    return "".join(a3m)
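
# For example, _to_a3m(["MKTAYIAK", "MKSAYIAR"]) returns the string:
#
#     >sequence 1
#     MKTAYIAK
#     >sequence 2
#     MKSAYIAR
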
class Kalign:
@@ -63,40 +63,51 @@ class Kalign:
            RuntimeError: If Kalign fails.
            ValueError: If any of the sequences is less than 6 residues long.
        """
        logging.info("Aligning %d sequences", len(sequences))

        for s in sequences:
            if len(s) < 6:
                raise ValueError(
                    "Kalign requires all sequences to be at least 6 "
                    "residues long. Got %s (%d residues)." % (s, len(s))
                )

        with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
            input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
            output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")

            with open(input_fasta_path, "w") as f:
                f.write(_to_a3m(sequences))

            cmd = [
                self.binary_path,
                "-i",
                input_fasta_path,
                "-o",
                output_a3m_path,
                "-format",
                "fasta",
            ]

            logging.info('Launching subprocess "%s"', " ".join(cmd))
            process = subprocess.Popen(
                cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )

            with utils.timing("Kalign query"):
                stdout, stderr = process.communicate()
                retcode = process.wait()
                logging.info(
                    "Kalign stdout:\n%s\n\nstderr:\n%s\n",
                    stdout.decode("utf-8"),
                    stderr.decode("utf-8"),
                )

            if retcode:
                raise RuntimeError(
                    "Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
                    % (stdout.decode("utf-8"), stderr.decode("utf-8"))
                )

            with open(output_a3m_path) as f:
                a3m = f.read()
...
@@ -35,11 +35,11 @@ def tmpdir_manager(base_dir: Optional[str] = None):
@contextlib.contextmanager
def timing(msg: str):
    logging.info("Started %s", msg)
    tic = time.time()
    yield
    toc = time.time()
    logging.info("Finished %s in %.3f seconds", msg, toc - tic)


def to_date(s: str):
...
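
# Usage sketch for the timing context manager:
#
#     with timing("expensive step"):
#         do_work()  # Logs "Started expensive step" on entry and
#                    # "Finished expensive step in 1.234 seconds" on exit.
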
@@ -3,13 +3,14 @@ import glob
import importlib as importlib

_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(f)[:-3]
    for f in _files
    if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]

for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
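
# The auto-import above is roughly equivalent to writing, for a package
# containing e.g. foo.py and bar.py (illustrative module names):
#
#     from . import foo
#     from . import bar
#     __all__ = ["foo", "bar"]
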
@@ -26,6 +26,7 @@ class Dropout(nn.Module):
    If not in training mode, this module computes the identity function.
    """

    def __init__(self, r: float, batch_dim: Union[int, List[int]]):
        """
        Args:
@@ -37,7 +38,7 @@ class Dropout(nn.Module):
        super(Dropout, self).__init__()

        self.r = r
        if type(batch_dim) == int:
            batch_dim = [batch_dim]
        self.batch_dim = batch_dim
        self.dropout = nn.Dropout(self.r)
@@ -50,7 +51,7 @@ class Dropout(nn.Module):
                compatible with self.batch_dim
        """
        shape = list(x.shape)
        if self.batch_dim is not None:
            for bd in self.batch_dim:
                shape[bd] = 1
        mask = x.new_ones(shape)
@@ -64,6 +65,7 @@ class DropoutRowwise(Dropout):
    Convenience class for rowwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-3)
@@ -72,4 +74,5 @@ class DropoutColumnwise(Dropout):
    Convenience class for columnwise dropout as described in subsection
    1.11.6.
    """

    __init__ = partialmethod(Dropout.__init__, batch_dim=-2)
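
# Usage sketch (illustrative): with batch_dim=-3, the generated mask has
# shape [*, 1, N_res, C_z] and is broadcast over dim -3, so every row of a
# [*, N_res, N_res, C_z] pair activation shares the same dropout pattern:
#
#     drop = DropoutRowwise(r=0.25)
#     z = torch.randn(1, 64, 64, 128)
#     z = drop(z)  # In training mode, positions are zeroed identically
#                  # across dim -3 with probability 0.25.
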
@@ -27,6 +27,7 @@ class InputEmbedder(nn.Module):
    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
    """

    def __init__(
        self,
        tf_dim: int,
@@ -67,9 +68,7 @@ class InputEmbedder(nn.Module):
        self.no_bins = 2 * relpos_k + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, ri: torch.Tensor):
        """
        Computes relative positional encodings
@@ -86,7 +85,8 @@ class InputEmbedder(nn.Module):
        oh = one_hot(d, boundaries).type(ri.dtype)
        return self.linear_relpos(oh)
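
    # A self-contained sketch of the relpos computation above, using
    # torch.nn.functional.one_hot in place of the project's one_hot helper
    # (an assumed equivalence, shown only for illustration):
    #
    #     d = ri[..., None] - ri[..., None, :]          # [*, N_res, N_res]
    #     d = torch.clamp(d, -relpos_k, relpos_k) + relpos_k
    #     oh = torch.nn.functional.one_hot(
    #         d.long(), num_classes=2 * relpos_k + 1
    #     ).type(ri.dtype)                              # [*, N_res, N_res, no_bins]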
    def forward(
        self,
        tf: torch.Tensor,
        ri: torch.Tensor,
        msa: torch.Tensor,
@@ -132,14 +132,16 @@ class RecyclingEmbedder(nn.Module):
    Implements Algorithm 32.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        min_bin: float,
        max_bin: float,
        no_bins: int,
        inf: float = 1e8,
        **kwargs,
    ):
        """
        Args:
@@ -169,7 +171,8 @@ class RecyclingEmbedder(nn.Module):
        self.layer_norm_m = nn.LayerNorm(self.c_m)
        self.layer_norm_z = nn.LayerNorm(self.c_z)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        x: torch.Tensor,
@@ -188,13 +191,13 @@ class RecyclingEmbedder(nn.Module):
            z:
                [*, N_res, N_res, C_z] pair embedding update
        """
        if self.bins is None:
            self.bins = torch.linspace(
                self.min_bin,
                self.max_bin,
                self.no_bins,
                dtype=x.dtype,
                device=x.device,
            )

        # [*, N, C_m]
@@ -205,15 +208,10 @@ class RecyclingEmbedder(nn.Module):
        # couldn't find in time.
        squared_bins = self.bins ** 2
        upper = torch.cat(
            [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
        )
        d = torch.sum(
            (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
        )

        # [*, N, N, no_bins]
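        # The binning step itself is elided from this diff; presumably it
        # one-hot encodes d against the bin edges, along the lines of (an
        # assumption, not the verbatim code):
        #
        #     d = ((d > squared_bins) * (d < upper)).type(x.dtype)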
@@ -232,7 +230,9 @@ class TemplateAngleEmbedder(nn.Module):
    Implements Algorithm 2, line 7.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
@@ -253,9 +253,7 @@ class TemplateAngleEmbedder(nn.Module):
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.c_out, self.c_out, init="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [*, N_templ, N_res, c_in] "template_angle_feat" features
@@ -275,7 +273,9 @@ class TemplatePairEmbedder(nn.Module):
    Implements Algorithm 2, line 9.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
@@ -295,7 +295,8 @@ class TemplatePairEmbedder(nn.Module):
        # Despite there being no relu nearby, the source uses that initializer
        self.linear = Linear(self.c_in, self.c_out, init="relu")

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
@@ -316,7 +317,9 @@ class ExtraMSAEmbedder(nn.Module):
    Implements Algorithm 2, line 15
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
@@ -335,9 +338,7 @@ class ExtraMSAEmbedder(nn.Module):
        self.linear = Linear(self.c_in, self.c_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
...
@@ -45,6 +45,7 @@ class MSATransition(nn.Module):
    Implements Algorithm 9
    """

    def __init__(self, c_m, n, chunk_size):
        """
        Args:
@@ -71,7 +72,8 @@ class MSATransition(nn.Module):
        m = self.linear_2(m) * mask

        return m

    def forward(
        self,
        m: torch.Tensor,
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
@@ -86,7 +88,7 @@ class MSATransition(nn.Module):
            [*, N_seq, N_res, C_m] MSA activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        mask = mask.unsqueeze(-1)
@@ -94,7 +96,7 @@ class MSATransition(nn.Module):
        m = self.layer_norm(m)

        inp = {"m": m, "mask": mask}
        if self.chunk_size is not None:
            m = chunk_layer(
                self._transition,
                inp,
@@ -108,7 +110,8 @@ class MSATransition(nn.Module):
class EvoformerBlock(nn.Module):
    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
@@ -136,7 +139,7 @@ class EvoformerBlock(nn.Module):
            inf=inf,
        )

        if _is_extra_msa_stack:
            self.msa_att_col = MSAColumnGlobalAttention(
                c_in=c_m,
                c_hidden=c_hidden_msa_att,
@@ -201,7 +204,8 @@ class EvoformerBlock(nn.Module):
        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
        self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
@@ -233,7 +237,9 @@ class EvoformerStack(nn.Module):
    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
@@ -313,10 +319,11 @@ class EvoformerStack(nn.Module):
            )
            self.blocks.append(block)

        if not self._is_extra_msa_stack:
            self.linear = Linear(c_m, c_s)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
@@ -348,14 +355,15 @@ class EvoformerStack(nn.Module):
                    msa_mask=msa_mask,
                    pair_mask=pair_mask,
                    _mask_trans=_mask_trans,
                )
                for b in self.blocks
            ],
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        s = None
        if not self._is_extra_msa_stack:
            seq_dim = -3
            index = torch.tensor([0], device=m.device)
            s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
@@ -368,7 +376,9 @@ class ExtraMSAStack(nn.Module):
    """
    Implements Algorithm 18.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
@@ -411,12 +421,13 @@ class ExtraMSAStack(nn.Module):
            _is_extra_msa_stack=True,
        )

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: Optional[torch.Tensor] = None,
        pair_mask: Optional[torch.Tensor] = None,
        _mask_trans: bool = True,
    ) -> torch.Tensor:
        """
        Args:
@@ -436,6 +447,6 @@ class ExtraMSAStack(nn.Module):
            z,
            msa_mask=msa_mask,
            pair_mask=pair_mask,
            _mask_trans=_mask_trans,
        )
        return z
@@ -44,7 +44,7 @@ class AuxiliaryHeads(nn.Module):
            **config["experimentally_resolved"],
        )

        if config.tm.enabled:
            self.tm = TMScoreHead(
                **config.tm,
            )
@@ -68,19 +68,22 @@ class AuxiliaryHeads(nn.Module):
        experimentally_resolved_logits = self.experimentally_resolved(
            outputs["single"]
        )
        aux_out[
            "experimentally_resolved_logits"
        ] = experimentally_resolved_logits

        if self.config.tm.enabled:
            tm_logits = self.tm(outputs["pair"])
            aux_out["tm_logits"] = tm_logits
            aux_out["predicted_tm_score"] = compute_tm(
                tm_logits, **self.config.tm
            )
            aux_out.update(
                compute_predicted_aligned_error(
                    tm_logits,
                    **self.config.tm,
                )
            )

        return aux_out
@@ -118,6 +121,7 @@ class DistogramHead(nn.Module):
    For use in computation of distogram loss, subsection 1.9.8
    """

    def __init__(self, c_z, no_bins, **kwargs):
        """
        Args:
@@ -133,9 +137,7 @@ class DistogramHead(nn.Module):
        self.linear = Linear(self.c_z, self.no_bins, init="final")

    def forward(self, z):  # [*, N, N, C_z]
        """
        Args:
            z:
@@ -153,6 +155,7 @@ class TMScoreHead(nn.Module):
    """
    For use in computation of TM-score, subsection 1.9.7
    """

    def __init__(self, c_z, no_bins, **kwargs):
        """
        Args:
@@ -185,6 +188,7 @@ class MaskedMSAHead(nn.Module):
    """
    For use in computation of masked MSA loss, subsection 1.9.9
    """

    def __init__(self, c_m, c_out, **kwargs):
        """
        Args:
@@ -218,6 +222,7 @@ class ExperimentallyResolvedHead(nn.Module):
    """
    For use in computation of "experimentally resolved" loss, subsection
    1.9.10
    """

    def __init__(self, c_s, c_out, **kwargs):
        """
        Args:
...