Commit 07e64267 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Standardize code style

parent de07730f
This diff is collapsed.
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -27,76 +27,79 @@ from openfold.np import residue_constants ...@@ -27,76 +27,79 @@ from openfold.np import residue_constants
FeatureDict = Mapping[str, np.ndarray] FeatureDict = Mapping[str, np.ndarray]
def make_sequence_features( def make_sequence_features(
sequence: str, sequence: str, description: str, num_res: int
description: str,
num_res: int
) -> FeatureDict: ) -> FeatureDict:
"""Construct a feature dict of sequence features.""" """Construct a feature dict of sequence features."""
features = {} features = {}
features['aatype'] = residue_constants.sequence_to_onehot( features["aatype"] = residue_constants.sequence_to_onehot(
sequence=sequence, sequence=sequence,
mapping=residue_constants.restype_order_with_x, mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True map_unknown_to_x=True,
) )
features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32) features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features['domain_name'] = np.array( features["domain_name"] = np.array(
[description.encode('utf-8')], dtype=np.object_ [description.encode("utf-8")], dtype=np.object_
) )
features['residue_index'] = np.array(range(num_res), dtype=np.int32) features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features['sequence'] = np.array( features["sequence"] = np.array(
[sequence.encode('utf-8')], dtype=np.object_ [sequence.encode("utf-8")], dtype=np.object_
) )
return features return features
def make_mmcif_features( def make_mmcif_features(
mmcif_object: mmcif_parsing.MmcifObject, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
chain_id: str
) -> FeatureDict: ) -> FeatureDict:
input_sequence = mmcif_object.chain_to_seqres[chain_id] input_sequence = mmcif_object.chain_to_seqres[chain_id]
description = '_'.join([mmcif_object.file_id, chain_id]) description = "_".join([mmcif_object.file_id, chain_id])
num_res = len(input_sequence) num_res = len(input_sequence)
mmcif_feats = {} mmcif_feats = {}
mmcif_feats.update(make_sequence_features( mmcif_feats.update(
sequence=input_sequence, make_sequence_features(
description=description, sequence=input_sequence,
num_res=num_res, description=description,
)) num_res=num_res,
)
)
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords( all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id mmcif_object=mmcif_object, chain_id=chain_id
) )
mmcif_feats["all_atom_positions"] = all_atom_positions mmcif_feats["all_atom_positions"] = all_atom_positions
mmcif_feats["all_atom_mask"] = all_atom_mask mmcif_feats["all_atom_mask"] = all_atom_mask
mmcif_feats["resolution"] = np.array( mmcif_feats["resolution"] = np.array(
[mmcif_object.header["resolution"]], dtype=np.float32 [mmcif_object.header["resolution"]], dtype=np.float32
) )
mmcif_feats["release_date"] = np.array( mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode('utf-8')], dtype=np.object_ [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
) )
return mmcif_feats return mmcif_feats
def make_msa_features( def make_msa_features(
msas: Sequence[Sequence[str]], msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict: deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
"""Constructs a feature dict of MSA features.""" """Constructs a feature dict of MSA features."""
if not msas: if not msas:
raise ValueError('At least one MSA must be provided.') raise ValueError("At least one MSA must be provided.")
int_msa = [] int_msa = []
deletion_matrix = [] deletion_matrix = []
seen_sequences = set() seen_sequences = set()
for msa_index, msa in enumerate(msas): for msa_index, msa in enumerate(msas):
if not msa: if not msa:
raise ValueError(f'MSA {msa_index} must contain at least one sequence.') raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa): for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences: if sequence in seen_sequences:
continue continue
...@@ -109,30 +112,32 @@ def make_msa_features( ...@@ -109,30 +112,32 @@ def make_msa_features(
num_res = len(msas[0][0]) num_res = len(msas[0][0])
num_alignments = len(int_msa) num_alignments = len(int_msa)
features = {} features = {}
features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
features['msa'] = np.array(int_msa, dtype=np.int32) features["msa"] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array( features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32 [num_alignments] * num_res, dtype=np.int32
) )
return features return features
class AlignmentRunner: class AlignmentRunner:
""" Runs alignment tools and saves the results """ """Runs alignment tools and saves the results"""
def __init__(self,
jackhmmer_binary_path: str, def __init__(
hhblits_binary_path: str, self,
hhsearch_binary_path: str, jackhmmer_binary_path: str,
uniref90_database_path: str, hhblits_binary_path: str,
mgnify_database_path: str, hhsearch_binary_path: str,
bfd_database_path: Optional[str], uniref90_database_path: str,
uniclust30_database_path: Optional[str], mgnify_database_path: str,
small_bfd_database_path: Optional[str], bfd_database_path: Optional[str],
pdb70_database_path: str, uniclust30_database_path: Optional[str],
use_small_bfd: bool, small_bfd_database_path: Optional[str],
no_cpus: int, pdb70_database_path: str,
uniref_max_hits: int = 10000, use_small_bfd: bool,
mgnify_max_hits: int = 5000, no_cpus: int,
uniref_max_hits: int = 10000,
mgnify_max_hits: int = 5000,
): ):
self._use_small_bfd = use_small_bfd self._use_small_bfd = use_small_bfd
self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer(
...@@ -161,115 +166,120 @@ class AlignmentRunner: ...@@ -161,115 +166,120 @@ class AlignmentRunner:
) )
self.hhsearch_pdb70_runner = hhsearch.HHSearch( self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path, binary_path=hhsearch_binary_path, databases=[pdb70_database_path]
databases=[pdb70_database_path]
) )
self.uniref_max_hits = uniref_max_hits self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits self.mgnify_max_hits = mgnify_max_hits
def run(self, def run(
self,
fasta_path: str, fasta_path: str,
output_dir: str, output_dir: str,
): ):
"""Runs alignment tools on a sequence""" """Runs alignment tools on a sequence"""
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(fasta_path)[0] jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits jackhmmer_uniref90_result["sto"], max_sequences=self.uniref_max_hits
) )
uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m') uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, 'w') as f: with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m) f.write(uniref90_msa_as_a3m)
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(fasta_path)[0] jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m( mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits jackhmmer_mgnify_result["sto"], max_sequences=self.mgnify_max_hits
) )
mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m') mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, 'w') as f: with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m) f.write(mgnify_msa_as_a3m)
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m) hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr') pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, 'w') as f: with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result) f.write(hhsearch_result)
if self._use_small_bfd: if self._use_small_bfd:
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(fasta_path)[0] jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto') fasta_path
with open(bfd_out_path, 'w') as f: )[0]
f.write(jackhmmer_small_bfd_result['sto']) bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
else: else:
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path) hhblits_bfd_uniclust_result = (
if(output_dir is not None): self.hhblits_bfd_uniclust_runner.query(fasta_path)
bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m') )
with open(bfd_out_path, 'w') as f: if output_dir is not None:
f.write(hhblits_bfd_uniclust_result['a3m']) bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniclust_result["a3m"])
class DataPipeline: class DataPipeline:
"""Assembles input features.""" """Assembles input features."""
def __init__(self,
template_featurizer: templates.TemplateHitFeaturizer, def __init__(
use_small_bfd: bool, self,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
): ):
self.template_featurizer = template_featurizer self.template_featurizer = template_featurizer
self.use_small_bfd = use_small_bfd self.use_small_bfd = use_small_bfd
def _parse_alignment_output(self, def _parse_alignment_output(
self,
alignment_dir: str, alignment_dir: str,
) -> Mapping[str, Any]: ) -> Mapping[str, Any]:
uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m') uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, 'r') as f: with open(uniref90_out_path, "r") as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m( uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())
f.read()
)
mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m') mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, 'r') as f: with open(mgnify_out_path, "r") as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m( mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())
f.read()
)
pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr') pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, 'r') as f: with open(pdb70_out_path, "r") as f:
hhsearch_hits = parsers.parse_hhr( hhsearch_hits = parsers.parse_hhr(f.read())
f.read()
)
if(self.use_small_bfd): if self.use_small_bfd:
bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto') bfd_out_path = os.path.join(alignment_dir, "small_bfd_hits.sto")
with open(bfd_out_path, 'r') as f: with open(bfd_out_path, "r") as f:
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm( bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
f.read() f.read()
) )
else: else:
bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m') bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, 'r') as f: with open(bfd_out_path, "r") as f:
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m( bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())
f.read()
)
return { return {
'uniref90_msa': uniref90_msa, "uniref90_msa": uniref90_msa,
'uniref90_deletion_matrix': uniref90_deletion_matrix, "uniref90_deletion_matrix": uniref90_deletion_matrix,
'mgnify_msa': mgnify_msa, "mgnify_msa": mgnify_msa,
'mgnify_deletion_matrix': mgnify_deletion_matrix, "mgnify_deletion_matrix": mgnify_deletion_matrix,
'hhsearch_hits': hhsearch_hits, "hhsearch_hits": hhsearch_hits,
'bfd_msa': bfd_msa, "bfd_msa": bfd_msa,
'bfd_deletion_matrix': bfd_deletion_matrix, "bfd_deletion_matrix": bfd_deletion_matrix,
} }
def process_fasta(self, def process_fasta(
self,
fasta_path: str, fasta_path: str,
alignment_dir: str, alignment_dir: str,
) -> FeatureDict: ) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file""" """Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f: with open(fasta_path) as f:
fasta_str = f.read() fasta_str = f.read()
input_seqs, input_descs = parsers.parse_fasta(fasta_str) input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1: if len(input_seqs) != 1:
raise ValueError( raise ValueError(
f'More than one input sequence found in {fasta_path}.') f"More than one input sequence found in {fasta_path}."
)
input_sequence = input_seqs[0] input_sequence = input_seqs[0]
input_description = input_descs[0] input_description = input_descs[0]
num_res = len(input_sequence) num_res = len(input_sequence)
...@@ -280,47 +290,46 @@ class DataPipeline: ...@@ -280,47 +290,46 @@ class DataPipeline:
query_sequence=input_sequence, query_sequence=input_sequence,
query_pdb_code=None, query_pdb_code=None,
query_release_date=None, query_release_date=None,
hits=alignments['hhsearch_hits'] hits=alignments["hhsearch_hits"],
) )
sequence_features = make_sequence_features( sequence_features = make_sequence_features(
sequence=input_sequence, sequence=input_sequence,
description=input_description, description=input_description,
num_res=num_res num_res=num_res,
) )
msa_features = make_msa_features( msa_features = make_msa_features(
msas=( msas=(
alignments['uniref90_msa'], alignments["uniref90_msa"],
alignments['bfd_msa'], alignments["bfd_msa"],
alignments['mgnify_msa'] alignments["mgnify_msa"],
), ),
deletion_matrices=( deletion_matrices=(
alignments['uniref90_deletion_matrix'], alignments["uniref90_deletion_matrix"],
alignments['bfd_deletion_matrix'], alignments["bfd_deletion_matrix"],
alignments['mgnify_deletion_matrix'] alignments["mgnify_deletion_matrix"],
) ),
) )
return {**sequence_features, **msa_features, **templates_result.data} return {**sequence_features, **msa_features, **templates_result.data}
def process_mmcif(self, def process_mmcif(
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str, alignment_dir: str,
chain_id: Optional[str] = None, chain_id: Optional[str] = None,
) -> FeatureDict: ) -> FeatureDict:
""" """
Assembles features for a specific chain in an mmCIF object. Assembles features for a specific chain in an mmCIF object.
If chain_id is None, it is assumed that there is only one chain If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown. in the object. Otherwise, a ValueError is thrown.
""" """
if(chain_id is None): if chain_id is None:
chains = mmcif.structure.get_chains() chains = mmcif.structure.get_chains()
chain = next(chains, None) chain = next(chains, None)
if(chain is None): if chain is None:
raise ValueError( raise ValueError("No chains in mmCIF file")
'No chains in mmCIF file'
)
chain_id = chain.id chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id) mmcif_feats = make_mmcif_features(mmcif, chain_id)
...@@ -332,20 +341,20 @@ class DataPipeline: ...@@ -332,20 +341,20 @@ class DataPipeline:
query_sequence=input_sequence, query_sequence=input_sequence,
query_pdb_code=None, query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]), query_release_date=to_date(mmcif.header["release_date"]),
hits=alignments['hhsearch_hits'] hits=alignments["hhsearch_hits"],
) )
msa_features = make_msa_features( msa_features = make_msa_features(
msas=( msas=(
alignments['uniref90_msa'], alignments["uniref90_msa"],
alignments['bfd_msa'], alignments["bfd_msa"],
alignments['mgnify_msa'] alignments["mgnify_msa"],
),
deletion_matrices=(
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
), ),
deletion_matrices = (
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
) )
return {**mmcif_feats, **templates_result.data, **msa_features} return {**mmcif_feats, **templates_result.data, **msa_features}
This diff is collapsed.
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -26,10 +26,11 @@ from openfold.data import input_pipeline ...@@ -26,10 +26,11 @@ from openfold.data import input_pipeline
FeatureDict = Mapping[str, np.ndarray] FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor] TensorDict = Dict[str, torch.Tensor]
def np_to_tensor_dict( def np_to_tensor_dict(
np_example: Mapping[str, np.ndarray], np_example: Mapping[str, np.ndarray],
features: Sequence[str], features: Sequence[str],
) -> TensorDict: ) -> TensorDict:
"""Creates dict of tensors from a dict of NumPy arrays. """Creates dict of tensors from a dict of NumPy arrays.
Args: Args:
...@@ -47,14 +48,14 @@ def np_to_tensor_dict( ...@@ -47,14 +48,14 @@ def np_to_tensor_dict(
def make_data_config( def make_data_config(
config: ml_collections.ConfigDict, config: ml_collections.ConfigDict,
mode: str, mode: str,
num_res: int, num_res: int,
) -> Tuple[ml_collections.ConfigDict, List[str]]: ) -> Tuple[ml_collections.ConfigDict, List[str]]:
cfg = copy.deepcopy(config) cfg = copy.deepcopy(config)
mode_cfg = cfg[mode] mode_cfg = cfg[mode]
with cfg.unlocked(): with cfg.unlocked():
if(mode_cfg.crop_size is None): if mode_cfg.crop_size is None:
mode_cfg.crop_size = num_res mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features feature_names = cfg.common.unsupervised_features
...@@ -62,7 +63,7 @@ def make_data_config( ...@@ -62,7 +63,7 @@ def make_data_config(
if cfg.common.use_templates: if cfg.common.use_templates:
feature_names += cfg.common.template_features feature_names += cfg.common.template_features
if(cfg[mode].supervised): if cfg[mode].supervised:
feature_names += cfg.common.supervised_features feature_names += cfg.common.supervised_features
return cfg, feature_names return cfg, feature_names
...@@ -75,47 +76,47 @@ def np_example_to_features( ...@@ -75,47 +76,47 @@ def np_example_to_features(
batch_mode: str, batch_mode: str,
): ):
np_example = dict(np_example) np_example = dict(np_example)
num_res = int(np_example['seq_length'][0]) num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config( cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
config, mode=mode, num_res=num_res
)
if 'deletion_matrix_int' in np_example: if "deletion_matrix_int" in np_example:
np_example['deletion_matrix'] = ( np_example["deletion_matrix"] = np_example.pop(
np_example.pop('deletion_matrix_int').astype(np.float32) "deletion_matrix_int"
) ).astype(np.float32)
if batch_mode == 'clamped': if batch_mode == "clamped":
np_example['use_clamped_fape'] = ( np_example["use_clamped_fape"] = np.array(1.0).astype(np.float32)
np.array(1.).astype(np.float32) elif batch_mode == "unclamped":
) np_example["use_clamped_fape"] = np.array(0.0).astype(np.float32)
elif batch_mode == 'unclamped':
np_example['use_clamped_fape'] = (
np.array(0.).astype(np.float32)
)
tensor_dict = np_to_tensor_dict( tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names np_example=np_example, features=feature_names
) )
with torch.no_grad(): with torch.no_grad():
features = input_pipeline.process_tensors_from_config( features = input_pipeline.process_tensors_from_config(
tensor_dict, cfg.common, cfg[mode], batch_mode=batch_mode, tensor_dict,
cfg.common,
cfg[mode],
batch_mode=batch_mode,
) )
return {k: v for k, v in features.items()} return {k: v for k, v in features.items()}
class FeaturePipeline: class FeaturePipeline:
def __init__(self, def __init__(
config: ml_collections.ConfigDict, self,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None): config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
):
self.config = config self.config = config
self.params = params self.params = params
def process_features(self, def process_features(
self,
raw_features: FeatureDict, raw_features: FeatureDict,
mode: str = 'train', mode: str = "train",
batch_mode: str = 'clamped', batch_mode: str = "clamped",
) -> FeatureDict: ) -> FeatureDict:
return np_example_to_features( return np_example_to_features(
np_example=raw_features, np_example=raw_features,
......
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_hhblits_profile, data_transforms.make_hhblits_profile,
] ]
if common_cfg.use_templates: if common_cfg.use_templates:
transforms.extend([ transforms.extend(
data_transforms.fix_templates_aatype, [
data_transforms.make_template_mask, data_transforms.fix_templates_aatype,
data_transforms.make_pseudo_beta('template_') data_transforms.make_template_mask,
]) data_transforms.make_pseudo_beta("template_"),
if(common_cfg.use_template_torsion_angles): ]
transforms.extend([ )
data_transforms.atom37_to_torsion_angles('template_'), if common_cfg.use_template_torsion_angles:
]) transforms.extend(
[
transforms.extend([ data_transforms.atom37_to_torsion_angles("template_"),
data_transforms.make_atom14_masks, ]
]) )
if(mode_cfg.supervised): transforms.extend(
transforms.extend([ [
data_transforms.make_atom14_positions, data_transforms.make_atom14_masks,
data_transforms.atom37_to_frames, ]
data_transforms.atom37_to_torsion_angles(''), )
data_transforms.make_pseudo_beta(''),
data_transforms.get_backbone_frames, if mode_cfg.supervised:
data_transforms.get_chi_angles, transforms.extend(
]) [
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
]
)
return transforms return transforms
...@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode): ...@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
data_transforms.sample_msa(max_msa_clusters, keep_extra=True) data_transforms.sample_msa(max_msa_clusters, keep_extra=True)
) )
if 'masked_msa' in common_cfg: if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that # Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about # the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations. # the masked locations and secret corrupted locations.
transforms.append( transforms.append(
data_transforms.make_masked_msa( data_transforms.make_masked_msa(
common_cfg.masked_msa, common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
mode_cfg.masked_msa_replace_fraction
) )
) )
...@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode): ...@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
if mode_cfg.fixed_size: if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats))) transforms.append(data_transforms.select_feat(list(crop_feats)))
transforms.append(data_transforms.random_crop_to_size( transforms.append(
mode_cfg.crop_size, data_transforms.random_crop_to_size(
mode_cfg.max_templates, mode_cfg.crop_size,
crop_feats, mode_cfg.max_templates,
mode_cfg.subsample_templates, crop_feats,
batch_mode=batch_mode, mode_cfg.subsample_templates,
seed=torch.Generator().seed() batch_mode=batch_mode,
)) seed=torch.Generator().seed(),
transforms.append(data_transforms.make_fixed_size( )
crop_feats, )
pad_msa_clusters, transforms.append(
common_cfg.max_extra_msa, data_transforms.make_fixed_size(
mode_cfg.crop_size, crop_feats,
mode_cfg.max_templates pad_msa_clusters,
)) common_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates,
)
)
else: else:
transforms.append( transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates) data_transforms.crop_templates(mode_cfg.max_templates)
...@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode): ...@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
def process_tensors_from_config( def process_tensors_from_config(
tensors, common_cfg, mode_cfg, batch_mode='clamped' tensors, common_cfg, mode_cfg, batch_mode="clamped"
): ):
"""Based on the config, apply filters and transformations to the data.""" """Based on the config, apply filters and transformations to the data."""
...@@ -136,12 +147,10 @@ def process_tensors_from_config( ...@@ -136,12 +147,10 @@ def process_tensors_from_config(
d = data.copy() d = data.copy()
fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode) fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
fn = compose(fns) fn = compose(fns)
d['ensemble_index'] = i d["ensemble_index"] = i
return fn(d) return fn(d)
tensors = compose( tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
nonensembled_transform_fns(common_cfg, mode_cfg)
)(tensors)
tensors_0 = wrap_ensemble_fn(tensors, 0) tensors_0 = wrap_ensemble_fn(tensors, 0)
num_ensemble = mode_cfg.num_ensemble num_ensemble = mode_cfg.num_ensemble
...@@ -150,8 +159,9 @@ def process_tensors_from_config( ...@@ -150,8 +159,9 @@ def process_tensors_from_config(
num_ensemble *= common_cfg.num_recycle + 1 num_ensemble *= common_cfg.num_recycle + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1: if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn(lambda x: wrap_ensemble_fn(tensors, x), tensors = map_fn(
torch.arange(num_ensemble)) lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
)
else: else:
tensors = tree.map_structure(lambda x: x[None], tensors_0) tensors = tree.map_structure(lambda x: x[None], tensors_0)
......
This diff is collapsed.
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple ...@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
DeletionMatrix = Sequence[Sequence[int]] DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TemplateHit: class TemplateHit:
"""Class representing a template hit.""" """Class representing a template hit."""
index: int index: int
name: str name: str
aligned_cols: int aligned_cols: int
...@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: ...@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
index = -1 index = -1
for line in fasta_string.splitlines(): for line in fasta_string.splitlines():
line = line.strip() line = line.strip()
if line.startswith('>'): if line.startswith(">"):
index += 1 index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning. descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append('') sequences.append("")
continue continue
elif not line: elif not line:
continue # Skip blank lines. continue # Skip blank lines.
...@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: ...@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return sequences, descriptions return sequences, descriptions
def parse_stockholm(stockholm_string: str def parse_stockholm(
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]: stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
"""Parses sequences and deletion matrix from stockholm format alignment. """Parses sequences and deletion matrix from stockholm format alignment.
Args: Args:
...@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str ...@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str
name_to_sequence = collections.OrderedDict() name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines(): for line in stockholm_string.splitlines():
line = line.strip() line = line.strip()
if not line or line.startswith(('#', '//')): if not line or line.startswith(("#", "//")):
continue continue
name, sequence = line.split() name, sequence = line.split()
if name not in name_to_sequence: if name not in name_to_sequence:
name_to_sequence[name] = '' name_to_sequence[name] = ""
name_to_sequence[name] += sequence name_to_sequence[name] += sequence
msa = [] msa = []
deletion_matrix = [] deletion_matrix = []
query = '' query = ""
keep_columns = [] keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()): for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0: if seq_index == 0:
# Gather the columns with gaps from the query # Gather the columns with gaps from the query
query = sequence query = sequence
keep_columns = [i for i, res in enumerate(query) if res != '-'] keep_columns = [i for i, res in enumerate(query) if res != "-"]
# Remove the columns with gaps in the query from all sequences. # Remove the columns with gaps in the query from all sequences.
aligned_sequence = ''.join([sequence[c] for c in keep_columns]) aligned_sequence = "".join([sequence[c] for c in keep_columns])
msa.append(aligned_sequence) msa.append(aligned_sequence)
...@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str ...@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str
deletion_vec = [] deletion_vec = []
deletion_count = 0 deletion_count = 0
for seq_res, query_res in zip(sequence, query): for seq_res, query_res in zip(sequence, query):
if seq_res != '-' or query_res != '-': if seq_res != "-" or query_res != "-":
if query_res == '-': if query_res == "-":
deletion_count += 1 deletion_count += 1
else: else:
deletion_vec.append(deletion_count) deletion_vec.append(deletion_count)
...@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]: ...@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
deletion_matrix.append(deletion_vec) deletion_matrix.append(deletion_vec)
# Make the MSA matrix out of aligned (deletion-free) sequences. # Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase) deletion_table = str.maketrans("", "", string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences] aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix return aligned_sequences, deletion_matrix
def _convert_sto_seq_to_a3m( def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]: query_non_gaps: Sequence[bool], sto_seq: str
) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap: if is_query_res_non_gap:
yield sequence_res yield sequence_res
elif sequence_res != '-': elif sequence_res != "-":
yield sequence_res.lower() yield sequence_res.lower()
def convert_stockholm_to_a3m(stockholm_format: str, def convert_stockholm_to_a3m(
max_sequences: Optional[int] = None) -> str: stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
"""Converts MSA in Stockholm format to the A3M format.""" """Converts MSA in Stockholm format to the A3M format."""
descriptions = {} descriptions = {}
sequences = {} sequences = {}
reached_max_sequences = False reached_max_sequences = False
for line in stockholm_format.splitlines(): for line in stockholm_format.splitlines():
reached_max_sequences = max_sequences and len(sequences) >= max_sequences reached_max_sequences = (
if line.strip() and not line.startswith(('#', '//')): max_sequences and len(sequences) >= max_sequences
)
if line.strip() and not line.startswith(("#", "//")):
# Ignore blank lines, markup and end symbols - remainder are alignment # Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts. # sequence parts.
seqname, aligned_seq = line.split(maxsplit=1) seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences: if seqname not in sequences:
if reached_max_sequences: if reached_max_sequences:
continue continue
sequences[seqname] = '' sequences[seqname] = ""
sequences[seqname] += aligned_seq sequences[seqname] += aligned_seq
for line in stockholm_format.splitlines(): for line in stockholm_format.splitlines():
if line[:4] == '#=GS': if line[:4] == "#=GS":
# Description row - example format is: # Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3) columns = line.split(maxsplit=3)
seqname, feature = columns[1:3] seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else '' value = columns[3] if len(columns) == 4 else ""
if feature != 'DE': if feature != "DE":
continue continue
if reached_max_sequences and seqname not in sequences: if reached_max_sequences and seqname not in sequences:
continue continue
...@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str, ...@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str,
a3m_sequences = {} a3m_sequences = {}
# query_sequence is assumed to be the first sequence # query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values())) query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence] query_non_gaps = [res != "-" for res in query_sequence]
for seqname, sto_sequence in sequences.items(): for seqname, sto_sequence in sequences.items():
a3m_sequences[seqname] = ''.join( a3m_sequences[seqname] = "".join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)) _convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
)
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" fasta_chunks = (
for k in a3m_sequences) f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. for k in a3m_sequences
)
return "\n".join(fasta_chunks) + "\n" # Include terminating newline.
def _get_hhr_line_regex_groups( def _get_hhr_line_regex_groups(
regex_pattern: str, line: str) -> Sequence[Optional[str]]: regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line) match = re.match(regex_pattern, line)
if match is None: if match is None:
raise RuntimeError(f'Could not parse query line {line}') raise RuntimeError(f"Could not parse query line {line}")
return match.groups() return match.groups()
def _update_hhr_residue_indices_list( def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]): sequence: str, start_index: int, indices_list: List[int]
):
"""Computes the relative indices for each residue with respect to the original sequence.""" """Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index counter = start_index
for symbol in sequence: for symbol in sequence:
if symbol == '-': if symbol == "-":
indices_list.append(-1) indices_list.append(-1)
else: else:
indices_list.append(counter) indices_list.append(counter)
...@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Parse the summary line. # Parse the summary line.
pattern = ( pattern = (
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' "Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' " ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
']*Template_Neff=(.*)') "]*Template_Neff=(.*)"
)
match = re.match(pattern, detailed_lines[2]) match = re.match(pattern, detailed_lines[2])
if match is None: if match is None:
raise RuntimeError( raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.' % "Could not parse section: %s. Expected this: \n%s to contain summary."
(detailed_lines, detailed_lines[2])) % (detailed_lines, detailed_lines[2])
(prob_true, e_value, _, aligned_cols, _, _, sum_probs, )
neff) = [float(x) for x in match.groups()] (prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
float(x) for x in match.groups()
]
# The next section reads the detailed comparisons. These are in a 'human # The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to # readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse # assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block. # that with a regexp in order to deduce the fixed length used for that block.
query = '' query = ""
hit_sequence = '' hit_sequence = ""
indices_query = [] indices_query = []
indices_hit = [] indices_hit = []
length_block = None length_block = None
for line in detailed_lines[3:]: for line in detailed_lines[3:]:
# Parse the query sequence line # Parse the query sequence line
if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and if (
not line.startswith('Q ss_pred') and line.startswith("Q ")
not line.startswith('Q Consensus')): and not line.startswith("Q ss_dssp")
and not line.startswith("Q ss_pred")
and not line.startswith("Q Consensus")
):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse # Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that. # everything after that.
# start sequence end total_sequence_length # start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:]) groups = _get_hhr_line_regex_groups(patt, line[17:])
# Get the length of the parsed block using the start and finish indices, # Get the length of the parsed block using the start and finish indices,
...@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
start = int(groups[0]) - 1 # Make index zero based. start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1] delta_query = groups[1]
end = int(groups[2]) end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-']) num_insertions = len([x for x in delta_query if x == "-"])
length_block = end - start + num_insertions length_block = end - start + num_insertions
assert length_block == len(delta_query) assert length_block == len(delta_query)
...@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
query += delta_query query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query) _update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith('T '): elif line.startswith("T "):
# Parse the hit sequence. # Parse the hit sequence.
if (not line.startswith('T ss_dssp') and if (
not line.startswith('T ss_pred') and not line.startswith("T ss_dssp")
not line.startswith('T Consensus')): and not line.startswith("T ss_pred")
and not line.startswith("T Consensus")
):
# Thus the first 17 characters must be 'T <hit_name> ', and we can # Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that. # parse everything after that.
# start sequence end total_sequence_length # start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:]) groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based. start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1] delta_hit_sequence = groups[1]
...@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: ...@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Update the hit sequence and indices list. # Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit) _update_hhr_residue_indices_list(
delta_hit_sequence, start, indices_hit
)
return TemplateHit( return TemplateHit(
index=number_of_hit, index=number_of_hit,
...@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: ...@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We # "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit. # iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')] block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]
hits = [] hits = []
if block_starts: if block_starts:
block_starts.append(len(lines)) # Add the end of the final block. block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1): for i in range(len(block_starts) - 1):
hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) hits.append(
_parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
)
return hits return hits
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string.""" """Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {'query': 0} e_values = {"query": 0}
lines = [line for line in tblout.splitlines() if line[0] != '#'] lines = [line for line in tblout.splitlines() if line[0] != "#"]
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and # space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1). # (5) E-value (full sequence) (numbering from 1).
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# Copyright 2021 AlQuraishi Laboratory # Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited # Copyright 2021 DeepMind Technologies Limited
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
...@@ -25,21 +25,21 @@ from typing import Optional ...@@ -25,21 +25,21 @@ from typing import Optional
@contextlib.contextmanager
def tmpdir_manager(base_dir: Optional[str] = None):
    """Context manager that deletes a temporary directory on exit.

    Args:
        base_dir: Optional parent directory in which to create the temp dir;
            defaults to the system temp location.

    Yields:
        The path of the freshly created temporary directory.
    """
    temp_path = tempfile.mkdtemp(dir=base_dir)
    try:
        yield temp_path
    finally:
        # Best-effort cleanup: ignore errors such as files already removed.
        shutil.rmtree(temp_path, ignore_errors=True)
@contextlib.contextmanager
def timing(msg: str):
    """Context manager that logs the elapsed time of its body.

    Emits "Started <msg>" on entry and "Finished <msg> in X.XXX seconds"
    on normal exit. If the body raises, the finish line is not logged
    (matching the original behavior).

    Args:
        msg: Human-readable description of the timed operation.
    """
    logging.info("Started %s", msg)
    # perf_counter is the monotonic high-resolution clock meant for
    # measuring durations; time.time() can jump (NTP/DST adjustments)
    # and produce negative or wildly wrong elapsed values.
    tic = time.perf_counter()
    yield
    toc = time.perf_counter()
    logging.info("Finished %s in %.3f seconds", msg, toc - tic)
def to_date(s: str): def to_date(s: str):
......
...@@ -3,13 +3,14 @@ import glob ...@@ -3,13 +3,14 @@ import glob
import importlib as importlib import importlib as importlib
# Auto-import every sibling module in this package so that each one is
# reachable as an attribute of the package and listed in __all__.
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [
    os.path.basename(_path)[:-3]
    for _path in _files
    if os.path.isfile(_path) and not _path.endswith("__init__.py")
]
_modules = [
    (_name, importlib.import_module("." + _name, __name__))
    for _name in __all__
]
for _m in _modules:
    globals()[_m[0]] = _m[1]

# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment