Commit 07e64267 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Standardize code style

parent de07730f
This diff is collapsed.
......@@ -27,45 +27,45 @@ from openfold.np import residue_constants
FeatureDict = Mapping[str, np.ndarray]
def make_sequence_features(
sequence: str,
description: str,
num_res: int
sequence: str, description: str, num_res: int
) -> FeatureDict:
"""Construct a feature dict of sequence features."""
features = {}
features['aatype'] = residue_constants.sequence_to_onehot(
features["aatype"] = residue_constants.sequence_to_onehot(
sequence=sequence,
mapping=residue_constants.restype_order_with_x,
map_unknown_to_x=True
map_unknown_to_x=True,
)
features['between_segment_residues'] = np.zeros((num_res,), dtype=np.int32)
features['domain_name'] = np.array(
[description.encode('utf-8')], dtype=np.object_
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_
)
features['residue_index'] = np.array(range(num_res), dtype=np.int32)
features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32)
features['sequence'] = np.array(
[sequence.encode('utf-8')], dtype=np.object_
features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_
)
return features
def make_mmcif_features(
mmcif_object: mmcif_parsing.MmcifObject,
chain_id: str
mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
input_sequence = mmcif_object.chain_to_seqres[chain_id]
description = '_'.join([mmcif_object.file_id, chain_id])
description = "_".join([mmcif_object.file_id, chain_id])
num_res = len(input_sequence)
mmcif_feats = {}
mmcif_feats.update(make_sequence_features(
mmcif_feats.update(
make_sequence_features(
sequence=input_sequence,
description=description,
num_res=num_res,
))
)
)
all_atom_positions, all_atom_mask = mmcif_parsing.get_atom_coords(
mmcif_object=mmcif_object, chain_id=chain_id
......@@ -78,7 +78,7 @@ def make_mmcif_features(
)
mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode('utf-8')], dtype=np.object_
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
)
return mmcif_feats
......@@ -86,17 +86,20 @@ def make_mmcif_features(
def make_msa_features(
msas: Sequence[Sequence[str]],
deletion_matrices: Sequence[parsers.DeletionMatrix]) -> FeatureDict:
deletion_matrices: Sequence[parsers.DeletionMatrix],
) -> FeatureDict:
"""Constructs a feature dict of MSA features."""
if not msas:
raise ValueError('At least one MSA must be provided.')
raise ValueError("At least one MSA must be provided.")
int_msa = []
deletion_matrix = []
seen_sequences = set()
for msa_index, msa in enumerate(msas):
if not msa:
raise ValueError(f'MSA {msa_index} must contain at least one sequence.')
raise ValueError(
f"MSA {msa_index} must contain at least one sequence."
)
for sequence_index, sequence in enumerate(msa):
if sequence in seen_sequences:
continue
......@@ -109,17 +112,19 @@ def make_msa_features(
num_res = len(msas[0][0])
num_alignments = len(int_msa)
features = {}
features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32)
features['msa'] = np.array(int_msa, dtype=np.int32)
features['num_alignments'] = np.array(
features["deletion_matrix_int"] = np.array(deletion_matrix, dtype=np.int32)
features["msa"] = np.array(int_msa, dtype=np.int32)
features["num_alignments"] = np.array(
[num_alignments] * num_res, dtype=np.int32
)
return features
class AlignmentRunner:
""" Runs alignment tools and saves the results """
def __init__(self,
"""Runs alignment tools and saves the results"""
def __init__(
self,
jackhmmer_binary_path: str,
hhblits_binary_path: str,
hhsearch_binary_path: str,
......@@ -161,105 +166,109 @@ class AlignmentRunner:
)
self.hhsearch_pdb70_runner = hhsearch.HHSearch(
binary_path=hhsearch_binary_path,
databases=[pdb70_database_path]
binary_path=hhsearch_binary_path, databases=[pdb70_database_path]
)
self.uniref_max_hits = uniref_max_hits
self.mgnify_max_hits = mgnify_max_hits
def run(self,
def run(
self,
fasta_path: str,
output_dir: str,
):
"""Runs alignment tools on a sequence"""
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(fasta_path)[0]
jackhmmer_uniref90_result = self.jackhmmer_uniref90_runner.query(
fasta_path
)[0]
uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_uniref90_result['sto'], max_sequences=self.uniref_max_hits
jackhmmer_uniref90_result["sto"], max_sequences=self.uniref_max_hits
)
uniref90_out_path = os.path.join(output_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'w') as f:
uniref90_out_path = os.path.join(output_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "w") as f:
f.write(uniref90_msa_as_a3m)
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(fasta_path)[0]
jackhmmer_mgnify_result = self.jackhmmer_mgnify_runner.query(
fasta_path
)[0]
mgnify_msa_as_a3m = parsers.convert_stockholm_to_a3m(
jackhmmer_mgnify_result['sto'], max_sequences=self.mgnify_max_hits
jackhmmer_mgnify_result["sto"], max_sequences=self.mgnify_max_hits
)
mgnify_out_path = os.path.join(output_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'w') as f:
mgnify_out_path = os.path.join(output_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "w") as f:
f.write(mgnify_msa_as_a3m)
hhsearch_result = self.hhsearch_pdb70_runner.query(uniref90_msa_as_a3m)
pdb70_out_path = os.path.join(output_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'w') as f:
pdb70_out_path = os.path.join(output_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "w") as f:
f.write(hhsearch_result)
if self._use_small_bfd:
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(fasta_path)[0]
bfd_out_path = os.path.join(output_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'w') as f:
f.write(jackhmmer_small_bfd_result['sto'])
jackhmmer_small_bfd_result = self.jackhmmer_small_bfd_runner.query(
fasta_path
)[0]
bfd_out_path = os.path.join(output_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "w") as f:
f.write(jackhmmer_small_bfd_result["sto"])
else:
hhblits_bfd_uniclust_result = self.hhblits_bfd_uniclust_runner.query(fasta_path)
if(output_dir is not None):
bfd_out_path = os.path.join(output_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'w') as f:
f.write(hhblits_bfd_uniclust_result['a3m'])
hhblits_bfd_uniclust_result = (
self.hhblits_bfd_uniclust_runner.query(fasta_path)
)
if output_dir is not None:
bfd_out_path = os.path.join(output_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "w") as f:
f.write(hhblits_bfd_uniclust_result["a3m"])
class DataPipeline:
"""Assembles input features."""
def __init__(self,
def __init__(
self,
template_featurizer: templates.TemplateHitFeaturizer,
use_small_bfd: bool,
):
self.template_featurizer = template_featurizer
self.use_small_bfd = use_small_bfd
def _parse_alignment_output(self,
def _parse_alignment_output(
self,
alignment_dir: str,
) -> Mapping[str, Any]:
uniref90_out_path = os.path.join(alignment_dir, 'uniref90_hits.a3m')
with open(uniref90_out_path, 'r') as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(
f.read()
)
uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
with open(uniref90_out_path, "r") as f:
uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())
mgnify_out_path = os.path.join(alignment_dir, 'mgnify_hits.a3m')
with open(mgnify_out_path, 'r') as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(
f.read()
)
mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
with open(mgnify_out_path, "r") as f:
mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())
pdb70_out_path = os.path.join(alignment_dir, 'pdb70_hits.hhr')
with open(pdb70_out_path, 'r') as f:
hhsearch_hits = parsers.parse_hhr(
f.read()
)
pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
with open(pdb70_out_path, "r") as f:
hhsearch_hits = parsers.parse_hhr(f.read())
if(self.use_small_bfd):
bfd_out_path = os.path.join(alignment_dir, 'small_bfd_hits.sto')
with open(bfd_out_path, 'r') as f:
if self.use_small_bfd:
bfd_out_path = os.path.join(alignment_dir, "small_bfd_hits.sto")
with open(bfd_out_path, "r") as f:
bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
f.read()
)
else:
bfd_out_path = os.path.join(alignment_dir, 'bfd_uniclust_hits.a3m')
with open(bfd_out_path, 'r') as f:
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(
f.read()
)
bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
with open(bfd_out_path, "r") as f:
bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())
return {
'uniref90_msa': uniref90_msa,
'uniref90_deletion_matrix': uniref90_deletion_matrix,
'mgnify_msa': mgnify_msa,
'mgnify_deletion_matrix': mgnify_deletion_matrix,
'hhsearch_hits': hhsearch_hits,
'bfd_msa': bfd_msa,
'bfd_deletion_matrix': bfd_deletion_matrix,
"uniref90_msa": uniref90_msa,
"uniref90_deletion_matrix": uniref90_deletion_matrix,
"mgnify_msa": mgnify_msa,
"mgnify_deletion_matrix": mgnify_deletion_matrix,
"hhsearch_hits": hhsearch_hits,
"bfd_msa": bfd_msa,
"bfd_deletion_matrix": bfd_deletion_matrix,
}
def process_fasta(self,
def process_fasta(
self,
fasta_path: str,
alignment_dir: str,
) -> FeatureDict:
......@@ -269,7 +278,8 @@ class DataPipeline:
input_seqs, input_descs = parsers.parse_fasta(fasta_str)
if len(input_seqs) != 1:
raise ValueError(
f'More than one input sequence found in {fasta_path}.')
f"More than one input sequence found in {fasta_path}."
)
input_sequence = input_seqs[0]
input_description = input_descs[0]
num_res = len(input_sequence)
......@@ -280,30 +290,31 @@ class DataPipeline:
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=None,
hits=alignments['hhsearch_hits']
hits=alignments["hhsearch_hits"],
)
sequence_features = make_sequence_features(
sequence=input_sequence,
description=input_description,
num_res=num_res
num_res=num_res,
)
msa_features = make_msa_features(
msas=(
alignments['uniref90_msa'],
alignments['bfd_msa'],
alignments['mgnify_msa']
alignments["uniref90_msa"],
alignments["bfd_msa"],
alignments["mgnify_msa"],
),
deletion_matrices=(
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
),
)
return {**sequence_features, **msa_features, **templates_result.data}
def process_mmcif(self,
def process_mmcif(
self,
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
......@@ -314,13 +325,11 @@ class DataPipeline:
If chain_id is None, it is assumed that there is only one chain
in the object. Otherwise, a ValueError is thrown.
"""
if(chain_id is None):
if chain_id is None:
chains = mmcif.structure.get_chains()
chain = next(chains, None)
if(chain is None):
raise ValueError(
'No chains in mmCIF file'
)
if chain is None:
raise ValueError("No chains in mmCIF file")
chain_id = chain.id
mmcif_feats = make_mmcif_features(mmcif, chain_id)
......@@ -332,20 +341,20 @@ class DataPipeline:
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]),
hits=alignments['hhsearch_hits']
hits=alignments["hhsearch_hits"],
)
msa_features = make_msa_features(
msas=(
alignments['uniref90_msa'],
alignments['bfd_msa'],
alignments['mgnify_msa']
alignments["uniref90_msa"],
alignments["bfd_msa"],
alignments["mgnify_msa"],
),
deletion_matrices=(
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
),
deletion_matrices = (
alignments['uniref90_deletion_matrix'],
alignments['bfd_deletion_matrix'],
alignments['mgnify_deletion_matrix']
)
)
return {**mmcif_feats, **templates_result.data, **msa_features}
This diff is collapsed.
......@@ -26,10 +26,11 @@ from openfold.data import input_pipeline
FeatureDict = Mapping[str, np.ndarray]
TensorDict = Dict[str, torch.Tensor]
def np_to_tensor_dict(
np_example: Mapping[str, np.ndarray],
features: Sequence[str],
) -> TensorDict:
) -> TensorDict:
"""Creates dict of tensors from a dict of NumPy arrays.
Args:
......@@ -54,7 +55,7 @@ def make_data_config(
cfg = copy.deepcopy(config)
mode_cfg = cfg[mode]
with cfg.unlocked():
if(mode_cfg.crop_size is None):
if mode_cfg.crop_size is None:
mode_cfg.crop_size = num_res
feature_names = cfg.common.unsupervised_features
......@@ -62,7 +63,7 @@ def make_data_config(
if cfg.common.use_templates:
feature_names += cfg.common.template_features
if(cfg[mode].supervised):
if cfg[mode].supervised:
feature_names += cfg.common.supervised_features
return cfg, feature_names
......@@ -75,47 +76,47 @@ def np_example_to_features(
batch_mode: str,
):
np_example = dict(np_example)
num_res = int(np_example['seq_length'][0])
cfg, feature_names = make_data_config(
config, mode=mode, num_res=num_res
)
num_res = int(np_example["seq_length"][0])
cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res)
if 'deletion_matrix_int' in np_example:
np_example['deletion_matrix'] = (
np_example.pop('deletion_matrix_int').astype(np.float32)
)
if "deletion_matrix_int" in np_example:
np_example["deletion_matrix"] = np_example.pop(
"deletion_matrix_int"
).astype(np.float32)
if batch_mode == 'clamped':
np_example['use_clamped_fape'] = (
np.array(1.).astype(np.float32)
)
elif batch_mode == 'unclamped':
np_example['use_clamped_fape'] = (
np.array(0.).astype(np.float32)
)
if batch_mode == "clamped":
np_example["use_clamped_fape"] = np.array(1.0).astype(np.float32)
elif batch_mode == "unclamped":
np_example["use_clamped_fape"] = np.array(0.0).astype(np.float32)
tensor_dict = np_to_tensor_dict(
np_example=np_example, features=feature_names
)
with torch.no_grad():
features = input_pipeline.process_tensors_from_config(
tensor_dict, cfg.common, cfg[mode], batch_mode=batch_mode,
tensor_dict,
cfg.common,
cfg[mode],
batch_mode=batch_mode,
)
return {k: v for k, v in features.items()}
class FeaturePipeline:
def __init__(self,
def __init__(
self,
config: ml_collections.ConfigDict,
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None):
params: Optional[Mapping[str, Mapping[str, np.ndarray]]] = None,
):
self.config = config
self.params = params
def process_features(self,
def process_features(
self,
raw_features: FeatureDict,
mode: str = 'train',
batch_mode: str = 'clamped',
mode: str = "train",
batch_mode: str = "clamped",
) -> FeatureDict:
return np_example_to_features(
np_example=raw_features,
......
......@@ -33,29 +33,37 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.make_hhblits_profile,
]
if common_cfg.use_templates:
transforms.extend([
transforms.extend(
[
data_transforms.fix_templates_aatype,
data_transforms.make_template_mask,
data_transforms.make_pseudo_beta('template_')
])
if(common_cfg.use_template_torsion_angles):
transforms.extend([
data_transforms.atom37_to_torsion_angles('template_'),
])
transforms.extend([
data_transforms.make_pseudo_beta("template_"),
]
)
if common_cfg.use_template_torsion_angles:
transforms.extend(
[
data_transforms.atom37_to_torsion_angles("template_"),
]
)
transforms.extend(
[
data_transforms.make_atom14_masks,
])
]
)
if(mode_cfg.supervised):
transforms.extend([
if mode_cfg.supervised:
transforms.extend(
[
data_transforms.make_atom14_positions,
data_transforms.atom37_to_frames,
data_transforms.atom37_to_torsion_angles(''),
data_transforms.make_pseudo_beta(''),
data_transforms.atom37_to_torsion_angles(""),
data_transforms.make_pseudo_beta(""),
data_transforms.get_backbone_frames,
data_transforms.get_chi_angles,
])
]
)
return transforms
......@@ -76,14 +84,13 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
data_transforms.sample_msa(max_msa_clusters, keep_extra=True)
)
if 'masked_msa' in common_cfg:
if "masked_msa" in common_cfg:
# Masked MSA should come *before* MSA clustering so that
# the clustering and full MSA profile do not leak information about
# the masked locations and secret corrupted locations.
transforms.append(
data_transforms.make_masked_msa(
common_cfg.masked_msa,
mode_cfg.masked_msa_replace_fraction
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
)
)
......@@ -103,21 +110,25 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
if mode_cfg.fixed_size:
transforms.append(data_transforms.select_feat(list(crop_feats)))
transforms.append(data_transforms.random_crop_to_size(
transforms.append(
data_transforms.random_crop_to_size(
mode_cfg.crop_size,
mode_cfg.max_templates,
crop_feats,
mode_cfg.subsample_templates,
batch_mode=batch_mode,
seed=torch.Generator().seed()
))
transforms.append(data_transforms.make_fixed_size(
seed=torch.Generator().seed(),
)
)
transforms.append(
data_transforms.make_fixed_size(
crop_feats,
pad_msa_clusters,
common_cfg.max_extra_msa,
mode_cfg.crop_size,
mode_cfg.max_templates
))
mode_cfg.max_templates,
)
)
else:
transforms.append(
data_transforms.crop_templates(mode_cfg.max_templates)
......@@ -127,7 +138,7 @@ def ensembled_transform_fns(common_cfg, mode_cfg, batch_mode):
def process_tensors_from_config(
tensors, common_cfg, mode_cfg, batch_mode='clamped'
tensors, common_cfg, mode_cfg, batch_mode="clamped"
):
"""Based on the config, apply filters and transformations to the data."""
......@@ -136,12 +147,10 @@ def process_tensors_from_config(
d = data.copy()
fns = ensembled_transform_fns(common_cfg, mode_cfg, batch_mode)
fn = compose(fns)
d['ensemble_index'] = i
d["ensemble_index"] = i
return fn(d)
tensors = compose(
nonensembled_transform_fns(common_cfg, mode_cfg)
)(tensors)
tensors = compose(nonensembled_transform_fns(common_cfg, mode_cfg))(tensors)
tensors_0 = wrap_ensemble_fn(tensors, 0)
num_ensemble = mode_cfg.num_ensemble
......@@ -150,8 +159,9 @@ def process_tensors_from_config(
num_ensemble *= common_cfg.num_recycle + 1
if isinstance(num_ensemble, torch.Tensor) or num_ensemble > 1:
tensors = map_fn(lambda x: wrap_ensemble_fn(tensors, x),
torch.arange(num_ensemble))
tensors = map_fn(
lambda x: wrap_ensemble_fn(tensors, x), torch.arange(num_ensemble)
)
else:
tensors = tree.map_structure(lambda x: x[None], tensors_0)
......
......@@ -90,6 +90,7 @@ class MmcifObject:
...}}
raw_string: The raw string used to construct the MmcifObject.
"""
file_id: str
header: PdbHeader
structure: PdbStructure
......@@ -107,6 +108,7 @@ class ParsingResult:
parsed.
errors: A dict mapping (file_id, chain_id) to any exception generated.
"""
mmcif_object: Optional[MmcifObject]
errors: Mapping[Tuple[str, str], Any]
......@@ -115,8 +117,9 @@ class ParseError(Exception):
"""An error indicating that an mmCIF file could not be parsed."""
def mmcif_loop_to_list(prefix: str,
parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]:
def mmcif_loop_to_list(
prefix: str, parsed_info: MmCIFDict
) -> Sequence[Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a list.
Reference for loop_ in mmCIF:
......@@ -140,15 +143,17 @@ def mmcif_loop_to_list(prefix: str,
data.append(value)
assert all([len(xs) == len(data[0]) for xs in data]), (
'mmCIF error: Not all loops are the same length: %s' % cols)
"mmCIF error: Not all loops are the same length: %s" % cols
)
return [dict(zip(cols, xs)) for xs in zip(*data)]
def mmcif_loop_to_dict(prefix: str,
def mmcif_loop_to_dict(
prefix: str,
index: str,
parsed_info: MmCIFDict,
) -> Mapping[str, Mapping[str, str]]:
) -> Mapping[str, Mapping[str, str]]:
"""Extracts loop associated with a prefix from mmCIF data as a dictionary.
Args:
......@@ -167,10 +172,9 @@ def mmcif_loop_to_dict(prefix: str,
return {entry[index]: entry for entry in entries}
def parse(*,
file_id: str,
mmcif_string: str,
catch_all_errors: bool = True) -> ParsingResult:
def parse(
*, file_id: str, mmcif_string: str, catch_all_errors: bool = True
) -> ParsingResult:
"""Entry point, parses an mmcif_string.
Args:
......@@ -188,7 +192,7 @@ def parse(*,
try:
parser = PDB.MMCIFParser(QUIET=True)
handle = io.StringIO(mmcif_string)
full_structure = parser.get_structure('', handle)
full_structure = parser.get_structure("", handle)
first_model_structure = _get_first_model(full_structure)
# Extract the _mmcif_dict from the parser, which contains useful fields not
# reflected in the Biopython structure.
......@@ -206,9 +210,12 @@ def parse(*,
valid_chains = _get_protein_chains(parsed_info=parsed_info)
if not valid_chains:
return ParsingResult(
None, {(file_id, ''): 'No protein chains found in this file.'})
seq_start_num = {chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()}
None, {(file_id, ""): "No protein chains found in this file."}
)
seq_start_num = {
chain_id: min([monomer.num for monomer in seq])
for chain_id, seq in valid_chains.items()
}
# Loop over the atoms for which we have coordinates. Populate two mappings:
# -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used
......@@ -217,34 +224,42 @@ def parse(*,
mmcif_to_author_chain_id = {}
seq_to_structure_mappings = {}
for atom in _get_atom_site_list(parsed_info):
if atom.model_num != '1':
if atom.model_num != "1":
# We only process the first model at the moment.
continue
mmcif_to_author_chain_id[atom.mmcif_chain_id] = atom.author_chain_id
if atom.mmcif_chain_id in valid_chains:
hetflag = ' '
if atom.hetatm_atom == 'HETATM':
hetflag = " "
if atom.hetatm_atom == "HETATM":
# Water atoms are assigned a special hetflag of W in Biopython. We
# need to do the same, so that this hetflag can be used to fetch
# a residue from the Biopython structure by id.
if atom.residue_name in ('HOH', 'WAT'):
hetflag = 'W'
if atom.residue_name in ("HOH", "WAT"):
hetflag = "W"
else:
hetflag = 'H_' + atom.residue_name
hetflag = "H_" + atom.residue_name
insertion_code = atom.insertion_code
if not _is_set(atom.insertion_code):
insertion_code = ' '
position = ResiduePosition(chain_id=atom.author_chain_id,
insertion_code = " "
position = ResiduePosition(
chain_id=atom.author_chain_id,
residue_number=int(atom.author_seq_num),
insertion_code=insertion_code)
seq_idx = int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
current = seq_to_structure_mappings.get(atom.author_chain_id, {})
current[seq_idx] = ResidueAtPosition(position=position,
insertion_code=insertion_code,
)
seq_idx = (
int(atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id]
)
current = seq_to_structure_mappings.get(
atom.author_chain_id, {}
)
current[seq_idx] = ResidueAtPosition(
position=position,
name=atom.residue_name,
is_missing=False,
hetflag=hetflag)
hetflag=hetflag,
)
seq_to_structure_mappings[atom.author_chain_id] = current
# Add missing residue information to seq_to_structure_mappings.
......@@ -253,19 +268,21 @@ def parse(*,
current_mapping = seq_to_structure_mappings[author_chain]
for idx, monomer in enumerate(seq_info):
if idx not in current_mapping:
current_mapping[idx] = ResidueAtPosition(position=None,
current_mapping[idx] = ResidueAtPosition(
position=None,
name=monomer.id,
is_missing=True,
hetflag=' ')
hetflag=" ",
)
author_chain_to_sequence = {}
for chain_id, seq_info in valid_chains.items():
author_chain = mmcif_to_author_chain_id[chain_id]
seq = []
for monomer in seq_info:
code = SCOPData.protein_letters_3to1.get(monomer.id, 'X')
seq.append(code if len(code) == 1 else 'X')
seq = ''.join(seq)
code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
seq.append(code if len(code) == 1 else "X")
seq = "".join(seq)
author_chain_to_sequence[author_chain] = seq
mmcif_object = MmcifObject(
......@@ -274,11 +291,12 @@ def parse(*,
structure=first_model_structure,
chain_to_seqres=author_chain_to_sequence,
seqres_to_structure=seq_to_structure_mappings,
raw_string=parsed_info)
raw_string=parsed_info,
)
return ParsingResult(mmcif_object=mmcif_object, errors=errors)
except Exception as e: # pylint:disable=broad-except
errors[(file_id, '')] = e
errors[(file_id, "")] = e
if not catch_all_errors:
raise
return ParsingResult(mmcif_object=None, errors=errors)
......@@ -288,12 +306,13 @@ def _get_first_model(structure: PdbStructure) -> PdbStructure:
"""Returns the first model in a Biopython structure."""
return next(structure.get_models())
_MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21
def get_release_date(parsed_info: MmCIFDict) -> str:
"""Returns the oldest revision date."""
revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date']
revision_dates = parsed_info["_pdbx_audit_revision_history.revision_date"]
return min(revision_dates)
......@@ -301,47 +320,58 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
"""Returns a basic header containing method, release date and resolution."""
header = {}
experiments = mmcif_loop_to_list('_exptl.', parsed_info)
header['structure_method'] = ','.join([
experiment['_exptl.method'].lower() for experiment in experiments])
experiments = mmcif_loop_to_list("_exptl.", parsed_info)
header["structure_method"] = ",".join(
[experiment["_exptl.method"].lower() for experiment in experiments]
)
# Note: The release_date here corresponds to the oldest revision. We prefer to
# use this for dataset filtering over the deposition_date.
if '_pdbx_audit_revision_history.revision_date' in parsed_info:
header['release_date'] = get_release_date(parsed_info)
if "_pdbx_audit_revision_history.revision_date" in parsed_info:
header["release_date"] = get_release_date(parsed_info)
else:
logging.warning('Could not determine release_date: %s',
parsed_info['_entry.id'])
logging.warning(
"Could not determine release_date: %s", parsed_info["_entry.id"]
)
header['resolution'] = 0.00
for res_key in ('_refine.ls_d_res_high', '_em_3d_reconstruction.resolution',
'_reflns.d_resolution_high'):
header["resolution"] = 0.00
for res_key in (
"_refine.ls_d_res_high",
"_em_3d_reconstruction.resolution",
"_reflns.d_resolution_high",
):
if res_key in parsed_info:
try:
raw_resolution = parsed_info[res_key][0]
header['resolution'] = float(raw_resolution)
header["resolution"] = float(raw_resolution)
except ValueError:
logging.warning('Invalid resolution format: %s', parsed_info[res_key])
logging.warning(
"Invalid resolution format: %s", parsed_info[res_key]
)
return header
def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]:
"""Returns list of atom sites; contains data not present in the structure."""
return [AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension
parsed_info['_atom_site.label_comp_id'],
parsed_info['_atom_site.auth_asym_id'],
parsed_info['_atom_site.label_asym_id'],
parsed_info['_atom_site.auth_seq_id'],
parsed_info['_atom_site.label_seq_id'],
parsed_info['_atom_site.pdbx_PDB_ins_code'],
parsed_info['_atom_site.group_PDB'],
parsed_info['_atom_site.pdbx_PDB_model_num'],
)]
return [
AtomSite(*site)
for site in zip( # pylint:disable=g-complex-comprehension
parsed_info["_atom_site.label_comp_id"],
parsed_info["_atom_site.auth_asym_id"],
parsed_info["_atom_site.label_asym_id"],
parsed_info["_atom_site.auth_seq_id"],
parsed_info["_atom_site.label_seq_id"],
parsed_info["_atom_site.pdbx_PDB_ins_code"],
parsed_info["_atom_site.group_PDB"],
parsed_info["_atom_site.pdbx_PDB_model_num"],
)
]
def _get_protein_chains(
*, parsed_info: Mapping[str, Any]) -> Mapping[ChainId, Sequence[Monomer]]:
*, parsed_info: Mapping[str, Any]
) -> Mapping[ChainId, Sequence[Monomer]]:
"""Extracts polymer information for protein chains only.
Args:
......@@ -351,26 +381,29 @@ def _get_protein_chains(
A dict mapping mmcif chain id to a list of Monomers.
"""
# Get polymer information for each entity in the structure.
entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info)
entity_poly_seqs = mmcif_loop_to_list("_entity_poly_seq.", parsed_info)
polymers = collections.defaultdict(list)
for entity_poly_seq in entity_poly_seqs:
polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append(
Monomer(id=entity_poly_seq['_entity_poly_seq.mon_id'],
num=int(entity_poly_seq['_entity_poly_seq.num'])))
polymers[entity_poly_seq["_entity_poly_seq.entity_id"]].append(
Monomer(
id=entity_poly_seq["_entity_poly_seq.mon_id"],
num=int(entity_poly_seq["_entity_poly_seq.num"]),
)
)
# Get chemical compositions. Will allow us to identify which of these polymers
# are proteins.
chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', parsed_info)
chem_comps = mmcif_loop_to_dict("_chem_comp.", "_chem_comp.id", parsed_info)
# Get chains information for each entity. Necessary so that we can return a
# dict keyed on chain id rather than entity.
struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info)
struct_asyms = mmcif_loop_to_list("_struct_asym.", parsed_info)
entity_to_mmcif_chains = collections.defaultdict(list)
for struct_asym in struct_asyms:
chain_id = struct_asym['_struct_asym.id']
entity_id = struct_asym['_struct_asym.entity_id']
chain_id = struct_asym["_struct_asym.id"]
entity_id = struct_asym["_struct_asym.entity_id"]
entity_to_mmcif_chains[entity_id].append(chain_id)
# Identify and return the valid protein chains.
......@@ -379,8 +412,12 @@ def _get_protein_chains(
chain_ids = entity_to_mmcif_chains[entity_id]
# Reject polymers without any peptide-like components, such as DNA/RNA.
if any(['peptide' in chem_comps[monomer.id]['_chem_comp.type']
for monomer in seq_info]):
if any(
[
"peptide" in chem_comps[monomer.id]["_chem_comp.type"]
for monomer in seq_info
]
):
for chain_id in chain_ids:
valid_chains[chain_id] = seq_info
return valid_chains
......@@ -388,19 +425,18 @@ def _get_protein_chains(
def _is_set(data: str) -> bool:
"""Returns False if data is a special mmCIF character indicating 'unset'."""
return data not in ('.', '?')
return data not in (".", "?")
def get_atom_coords(
mmcif_object: MmcifObject,
chain_id: str
mmcif_object: MmcifObject, chain_id: str
) -> Tuple[np.ndarray, np.ndarray]:
# Locate the right chain
chains = list(mmcif_object.structure.get_chains())
relevant_chains = [c for c in chains if c.id == chain_id]
if len(relevant_chains) != 1:
raise MultipleChainsError(
f'Expected exactly one chain in structure with id {chain_id}.'
f"Expected exactly one chain in structure with id {chain_id}."
)
chain = relevant_chains[0]
......@@ -417,19 +453,23 @@ def get_atom_coords(
mask = np.zeros([residue_constants.atom_type_num], dtype=np.float32)
res_at_position = mmcif_object.seqres_to_structure[chain_id][res_index]
if not res_at_position.is_missing:
res = chain[(res_at_position.hetflag,
res = chain[
(
res_at_position.hetflag,
res_at_position.position.residue_number,
res_at_position.position.insertion_code)]
res_at_position.position.insertion_code,
)
]
for atom in res.get_atoms():
atom_name = atom.get_name()
x, y, z = atom.get_coord()
if atom_name in residue_constants.atom_order.keys():
pos[residue_constants.atom_order[atom_name]] = [x, y, z]
mask[residue_constants.atom_order[atom_name]] = 1.0
elif atom_name.upper() == 'SE' and res.get_resname() == 'MSE':
elif atom_name.upper() == "SE" and res.get_resname() == "MSE":
# Put the coords of the selenium atom in the sulphur column
pos[residue_constants.atom_order['SD']] = [x, y, z]
mask[residue_constants.atom_order['SD']] = 1.0
pos[residue_constants.atom_order["SD"]] = [x, y, z]
mask[residue_constants.atom_order["SD"]] = 1.0
all_atom_positions[res_index] = pos
all_atom_mask[res_index] = mask
......@@ -440,22 +480,22 @@ def get_atom_coords(
def generate_mmcif_cache(mmcif_dir: str, out_path: str):
data = {}
for f in os.listdir(mmcif_dir):
if(f.endswith('.cif')):
with open(os.path.join(mmcif_dir, f), 'r') as fp:
if f.endswith(".cif"):
with open(os.path.join(mmcif_dir, f), "r") as fp:
mmcif_string = fp.read()
file_id = os.path.splitext(f)[0]
mmcif = parse(file_id=file_id, mmcif_string=mmcif_string)
if(mmcif.mmcif_object is None):
logging.warning(f'Could not parse {f}. Skipping...')
if mmcif.mmcif_object is None:
logging.warning(f"Could not parse {f}. Skipping...")
continue
else:
mmcif = mmcif.mmcif_object
local_data = {}
local_data['release_date'] = mmcif.header["release_date"]
local_data['no_chains'] = len(list(mmcif.structure.get_chains()))
local_data["release_date"] = mmcif.header["release_date"]
local_data["no_chains"] = len(list(mmcif.structure.get_chains()))
data[file_id] = local_data
with open(out_path, 'w') as fp:
with open(out_path, "w") as fp:
fp.write(json.dumps(data))
......@@ -23,9 +23,11 @@ from typing import Dict, Iterable, List, Optional, Sequence, Tuple
DeletionMatrix = Sequence[Sequence[int]]
@dataclasses.dataclass(frozen=True)
class TemplateHit:
"""Class representing a template hit."""
index: int
name: str
aligned_cols: int
......@@ -53,10 +55,10 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
index = -1
for line in fasta_string.splitlines():
line = line.strip()
if line.startswith('>'):
if line.startswith(">"):
index += 1
descriptions.append(line[1:]) # Remove the '>' at the beginning.
sequences.append('')
sequences.append("")
continue
elif not line:
continue # Skip blank lines.
......@@ -65,8 +67,9 @@ def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]:
return sequences, descriptions
def parse_stockholm(stockholm_string: str
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
def parse_stockholm(
stockholm_string: str,
) -> Tuple[Sequence[str], DeletionMatrix, Sequence[str]]:
"""Parses sequences and deletion matrix from stockholm format alignment.
Args:
......@@ -86,26 +89,26 @@ def parse_stockholm(stockholm_string: str
name_to_sequence = collections.OrderedDict()
for line in stockholm_string.splitlines():
line = line.strip()
if not line or line.startswith(('#', '//')):
if not line or line.startswith(("#", "//")):
continue
name, sequence = line.split()
if name not in name_to_sequence:
name_to_sequence[name] = ''
name_to_sequence[name] = ""
name_to_sequence[name] += sequence
msa = []
deletion_matrix = []
query = ''
query = ""
keep_columns = []
for seq_index, sequence in enumerate(name_to_sequence.values()):
if seq_index == 0:
# Gather the columns with gaps from the query
query = sequence
keep_columns = [i for i, res in enumerate(query) if res != '-']
keep_columns = [i for i, res in enumerate(query) if res != "-"]
# Remove the columns with gaps in the query from all sequences.
aligned_sequence = ''.join([sequence[c] for c in keep_columns])
aligned_sequence = "".join([sequence[c] for c in keep_columns])
msa.append(aligned_sequence)
......@@ -113,8 +116,8 @@ def parse_stockholm(stockholm_string: str
deletion_vec = []
deletion_count = 0
for seq_res, query_res in zip(sequence, query):
if seq_res != '-' or query_res != '-':
if query_res == '-':
if seq_res != "-" or query_res != "-":
if query_res == "-":
deletion_count += 1
else:
deletion_vec.append(deletion_count)
......@@ -153,47 +156,51 @@ def parse_a3m(a3m_string: str) -> Tuple[Sequence[str], DeletionMatrix]:
deletion_matrix.append(deletion_vec)
# Make the MSA matrix out of aligned (deletion-free) sequences.
deletion_table = str.maketrans('', '', string.ascii_lowercase)
deletion_table = str.maketrans("", "", string.ascii_lowercase)
aligned_sequences = [s.translate(deletion_table) for s in sequences]
return aligned_sequences, deletion_matrix
def _convert_sto_seq_to_a3m(
query_non_gaps: Sequence[bool], sto_seq: str) -> Iterable[str]:
query_non_gaps: Sequence[bool], sto_seq: str
) -> Iterable[str]:
for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq):
if is_query_res_non_gap:
yield sequence_res
elif sequence_res != '-':
elif sequence_res != "-":
yield sequence_res.lower()
def convert_stockholm_to_a3m(stockholm_format: str,
max_sequences: Optional[int] = None) -> str:
def convert_stockholm_to_a3m(
stockholm_format: str, max_sequences: Optional[int] = None
) -> str:
"""Converts MSA in Stockholm format to the A3M format."""
descriptions = {}
sequences = {}
reached_max_sequences = False
for line in stockholm_format.splitlines():
reached_max_sequences = max_sequences and len(sequences) >= max_sequences
if line.strip() and not line.startswith(('#', '//')):
reached_max_sequences = (
max_sequences and len(sequences) >= max_sequences
)
if line.strip() and not line.startswith(("#", "//")):
# Ignore blank lines, markup and end symbols - remainder are alignment
# sequence parts.
seqname, aligned_seq = line.split(maxsplit=1)
if seqname not in sequences:
if reached_max_sequences:
continue
sequences[seqname] = ''
sequences[seqname] = ""
sequences[seqname] += aligned_seq
for line in stockholm_format.splitlines():
if line[:4] == '#=GS':
if line[:4] == "#=GS":
# Description row - example format is:
# #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ...
columns = line.split(maxsplit=3)
seqname, feature = columns[1:3]
value = columns[3] if len(columns) == 4 else ''
if feature != 'DE':
value = columns[3] if len(columns) == 4 else ""
if feature != "DE":
continue
if reached_max_sequences and seqname not in sequences:
continue
......@@ -205,30 +212,35 @@ def convert_stockholm_to_a3m(stockholm_format: str,
a3m_sequences = {}
# query_sequence is assumed to be the first sequence
query_sequence = next(iter(sequences.values()))
query_non_gaps = [res != '-' for res in query_sequence]
query_non_gaps = [res != "-" for res in query_sequence]
for seqname, sto_sequence in sequences.items():
a3m_sequences[seqname] = ''.join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence))
a3m_sequences[seqname] = "".join(
_convert_sto_seq_to_a3m(query_non_gaps, sto_sequence)
)
fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences)
return '\n'.join(fasta_chunks) + '\n' # Include terminating newline.
fasta_chunks = (
f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}"
for k in a3m_sequences
)
return "\n".join(fasta_chunks) + "\n" # Include terminating newline.
def _get_hhr_line_regex_groups(
regex_pattern: str, line: str) -> Sequence[Optional[str]]:
regex_pattern: str, line: str
) -> Sequence[Optional[str]]:
match = re.match(regex_pattern, line)
if match is None:
raise RuntimeError(f'Could not parse query line {line}')
raise RuntimeError(f"Could not parse query line {line}")
return match.groups()
def _update_hhr_residue_indices_list(
sequence: str, start_index: int, indices_list: List[int]):
sequence: str, start_index: int, indices_list: List[int]
):
"""Computes the relative indices for each residue with respect to the original sequence."""
counter = start_index
for symbol in sequence:
if symbol == '-':
if symbol == "-":
indices_list.append(-1)
else:
indices_list.append(counter)
......@@ -256,36 +268,42 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Parse the summary line.
pattern = (
'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t'
' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t '
']*Template_Neff=(.*)')
"Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t"
" ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t "
"]*Template_Neff=(.*)"
)
match = re.match(pattern, detailed_lines[2])
if match is None:
raise RuntimeError(
'Could not parse section: %s. Expected this: \n%s to contain summary.' %
(detailed_lines, detailed_lines[2]))
(prob_true, e_value, _, aligned_cols, _, _, sum_probs,
neff) = [float(x) for x in match.groups()]
"Could not parse section: %s. Expected this: \n%s to contain summary."
% (detailed_lines, detailed_lines[2])
)
(prob_true, e_value, _, aligned_cols, _, _, sum_probs, neff) = [
float(x) for x in match.groups()
]
# The next section reads the detailed comparisons. These are in a 'human
# readable' format which has a fixed length. The strategy employed is to
# assume that each block starts with the query sequence line, and to parse
# that with a regexp in order to deduce the fixed length used for that block.
query = ''
hit_sequence = ''
query = ""
hit_sequence = ""
indices_query = []
indices_hit = []
length_block = None
for line in detailed_lines[3:]:
# Parse the query sequence line
if (line.startswith('Q ') and not line.startswith('Q ss_dssp') and
not line.startswith('Q ss_pred') and
not line.startswith('Q Consensus')):
if (
line.startswith("Q ")
and not line.startswith("Q ss_dssp")
and not line.startswith("Q ss_pred")
and not line.startswith("Q Consensus")
):
# Thus the first 17 characters must be 'Q <query_name> ', and we can parse
# everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)'
patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:])
# Get the length of the parsed block using the start and finish indices,
......@@ -293,7 +311,7 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
start = int(groups[0]) - 1 # Make index zero based.
delta_query = groups[1]
end = int(groups[2])
num_insertions = len([x for x in delta_query if x == '-'])
num_insertions = len([x for x in delta_query if x == "-"])
length_block = end - start + num_insertions
assert length_block == len(delta_query)
......@@ -301,15 +319,17 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
query += delta_query
_update_hhr_residue_indices_list(delta_query, start, indices_query)
elif line.startswith('T '):
elif line.startswith("T "):
# Parse the hit sequence.
if (not line.startswith('T ss_dssp') and
not line.startswith('T ss_pred') and
not line.startswith('T Consensus')):
if (
not line.startswith("T ss_dssp")
and not line.startswith("T ss_pred")
and not line.startswith("T Consensus")
):
# Thus the first 17 characters must be 'T <hit_name> ', and we can
# parse everything after that.
# start sequence end total_sequence_length
patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)'
patt = r"[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)"
groups = _get_hhr_line_regex_groups(patt, line[17:])
start = int(groups[0]) - 1 # Make index zero based.
delta_hit_sequence = groups[1]
......@@ -317,7 +337,9 @@ def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit:
# Update the hit sequence and indices list.
hit_sequence += delta_hit_sequence
_update_hhr_residue_indices_list(delta_hit_sequence, start, indices_hit)
_update_hhr_residue_indices_list(
delta_hit_sequence, start, indices_hit
)
return TemplateHit(
index=number_of_hit,
......@@ -339,20 +361,22 @@ def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]:
# "paragraphs", each paragraph starting with a line 'No <hit number>'. We
# iterate through each paragraph to parse each hit.
block_starts = [i for i, line in enumerate(lines) if line.startswith('No ')]
block_starts = [i for i, line in enumerate(lines) if line.startswith("No ")]
hits = []
if block_starts:
block_starts.append(len(lines)) # Add the end of the final block.
for i in range(len(block_starts) - 1):
hits.append(_parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]]))
hits.append(
_parse_hhr_hit(lines[block_starts[i] : block_starts[i + 1]])
)
return hits
def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]:
"""Parse target to e-value mapping parsed from Jackhmmer tblout string."""
e_values = {'query': 0}
lines = [line for line in tblout.splitlines() if line[0] != '#']
e_values = {"query": 0}
lines = [line for line in tblout.splitlines() if line[0] != "#"]
# As per http://eddylab.org/software/hmmer/Userguide.pdf fields are
# space-delimited. Relevant fields are (1) target name: and
# (5) E-value (full sequence) (numbering from 1).
......
This diff is collapsed.
......@@ -30,7 +30,8 @@ _HHBLITS_DEFAULT_Z = 500
class HHBlits:
"""Python wrapper of the HHblits binary."""
def __init__(self,
def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
......@@ -44,7 +45,8 @@ class HHBlits:
all_seqs: bool = False,
alt: Optional[int] = None,
p: int = _HHBLITS_DEFAULT_P,
z: int = _HHBLITS_DEFAULT_Z):
z: int = _HHBLITS_DEFAULT_Z,
):
"""Initializes the Python HHblits wrapper.
Args:
......@@ -77,9 +79,13 @@ class HHBlits:
self.databases = databases
for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHBlits database %s', database_path)
raise ValueError(f'Could not find HHBlits database {database_path}')
if not glob.glob(database_path + "_*"):
logging.error(
"Could not find HHBlits database %s", database_path
)
raise ValueError(
f"Could not find HHBlits database {database_path}"
)
self.n_cpu = n_cpu
self.n_iter = n_iter
......@@ -95,52 +101,66 @@ class HHBlits:
def query(self, input_fasta_path: str) -> Mapping[str, Any]:
"""Queries the database using HHblits."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
a3m_path = os.path.join(query_tmp_dir, "output.a3m")
db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append("-d")
db_cmd.append(db_path)
cmd = [
self.binary_path,
'-i', input_fasta_path,
'-cpu', str(self.n_cpu),
'-oa3m', a3m_path,
'-o', '/dev/null',
'-n', str(self.n_iter),
'-e', str(self.e_value),
'-maxseq', str(self.maxseq),
'-realign_max', str(self.realign_max),
'-maxfilt', str(self.maxfilt),
'-min_prefilter_hits', str(self.min_prefilter_hits)]
"-i",
input_fasta_path,
"-cpu",
str(self.n_cpu),
"-oa3m",
a3m_path,
"-o",
"/dev/null",
"-n",
str(self.n_iter),
"-e",
str(self.e_value),
"-maxseq",
str(self.maxseq),
"-realign_max",
str(self.realign_max),
"-maxfilt",
str(self.maxfilt),
"-min_prefilter_hits",
str(self.min_prefilter_hits),
]
if self.all_seqs:
cmd += ['-all']
cmd += ["-all"]
if self.alt:
cmd += ['-alt', str(self.alt)]
cmd += ["-alt", str(self.alt)]
if self.p != _HHBLITS_DEFAULT_P:
cmd += ['-p', str(self.p)]
cmd += ["-p", str(self.p)]
if self.z != _HHBLITS_DEFAULT_Z:
cmd += ['-Z', str(self.z)]
cmd += ["-Z", str(self.z)]
cmd += db_cmd
logging.info('Launching subprocess "%s"', ' '.join(cmd))
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing('HHblits query'):
with utils.timing("HHblits query"):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Logs have a 15k character limit, so log HHblits error line by line.
logging.error('HHblits failed. HHblits stderr begin:')
for error_line in stderr.decode('utf-8').splitlines():
logging.error("HHblits failed. HHblits stderr begin:")
for error_line in stderr.decode("utf-8").splitlines():
if error_line.strip():
logging.error(error_line.strip())
logging.error('HHblits stderr end')
raise RuntimeError('HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr[:500_000].decode('utf-8')))
logging.error("HHblits stderr end")
raise RuntimeError(
"HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:500_000].decode("utf-8"))
)
with open(a3m_path) as f:
a3m = f.read()
......@@ -150,5 +170,6 @@ class HHBlits:
output=stdout,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)
e_value=self.e_value,
)
return raw_output
......@@ -26,12 +26,14 @@ from openfold.data.np import utils
class HHSearch:
"""Python wrapper of the HHsearch binary."""
def __init__(self,
def __init__(
self,
*,
binary_path: str,
databases: Sequence[str],
n_cpu: int = 2,
maxseq: int = 1_000_000):
maxseq: int = 1_000_000,
):
"""Initializes the Python HHsearch wrapper.
Args:
......@@ -52,41 +54,52 @@ class HHSearch:
self.maxseq = maxseq
for database_path in self.databases:
if not glob.glob(database_path + '_*'):
logging.error('Could not find HHsearch database %s', database_path)
raise ValueError(f'Could not find HHsearch database {database_path}')
if not glob.glob(database_path + "_*"):
logging.error(
"Could not find HHsearch database %s", database_path
)
raise ValueError(
f"Could not find HHsearch database {database_path}"
)
def query(self, a3m: str) -> str:
"""Queries the database using HHsearch using a given a3m."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, 'query.a3m')
hhr_path = os.path.join(query_tmp_dir, 'output.hhr')
with open(input_path, 'w') as f:
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
input_path = os.path.join(query_tmp_dir, "query.a3m")
hhr_path = os.path.join(query_tmp_dir, "output.hhr")
with open(input_path, "w") as f:
f.write(a3m)
db_cmd = []
for db_path in self.databases:
db_cmd.append('-d')
db_cmd.append("-d")
db_cmd.append(db_path)
cmd = [self.binary_path,
'-i', input_path,
'-o', hhr_path,
'-maxseq', str(self.maxseq),
'-cpu', str(self.n_cpu),
cmd = [
self.binary_path,
"-i",
input_path,
"-o",
hhr_path,
"-maxseq",
str(self.maxseq),
"-cpu",
str(self.n_cpu),
] + db_cmd
logging.info('Launching subprocess "%s"', ' '.join(cmd))
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
with utils.timing('HHsearch query'):
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing("HHsearch query"):
stdout, stderr = process.communicate()
retcode = process.wait()
if retcode:
# Stderr is truncated to prevent proto size errors in Beam.
raise RuntimeError(
'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % (
stdout.decode('utf-8'), stderr[:100_000].decode('utf-8')))
"HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr[:100_000].decode("utf-8"))
)
with open(hhr_path) as f:
hhr = f.read()
......
......@@ -29,7 +29,8 @@ from openfold.data.tools import utils
class Jackhmmer:
"""Python wrapper of the Jackhmmer binary."""
def __init__(self,
def __init__(
self,
*,
binary_path: str,
database_path: str,
......@@ -44,7 +45,8 @@ class Jackhmmer:
incdom_e: Optional[float] = None,
dom_e: Optional[float] = None,
num_streamed_chunks: Optional[int] = None,
streaming_callback: Optional[Callable[[int], None]] = None):
streaming_callback: Optional[Callable[[int], None]] = None,
):
"""Initializes the Python Jackhmmer wrapper.
Args:
......@@ -69,9 +71,14 @@ class Jackhmmer:
self.database_path = database_path
self.num_streamed_chunks = num_streamed_chunks
if not os.path.exists(self.database_path) and num_streamed_chunks is None:
logging.error('Could not find Jackhmmer database %s', database_path)
raise ValueError(f'Could not find Jackhmmer database {database_path}')
if (
not os.path.exists(self.database_path)
and num_streamed_chunks is None
):
logging.error("Could not find Jackhmmer database %s", database_path)
raise ValueError(
f"Could not find Jackhmmer database {database_path}"
)
self.n_cpu = n_cpu
self.n_iter = n_iter
......@@ -85,11 +92,12 @@ class Jackhmmer:
self.get_tblout = get_tblout
self.streaming_callback = streaming_callback
def _query_chunk(self, input_fasta_path: str, database_path: str
def _query_chunk(
self, input_fasta_path: str, database_path: str
) -> Mapping[str, Any]:
"""Queries the database chunk using Jackhmmer."""
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, 'output.sto')
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
sto_path = os.path.join(query_tmp_dir, "output.sto")
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# stages (which get progressively more expensive), reducing these
......@@ -98,48 +106,63 @@ class Jackhmmer:
# amount of time.
cmd_flags = [
# Don't pollute stdout with Jackhmmer output.
'-o', '/dev/null',
'-A', sto_path,
'--noali',
'--F1', str(self.filter_f1),
'--F2', str(self.filter_f2),
'--F3', str(self.filter_f3),
'--incE', str(self.e_value),
"-o",
"/dev/null",
"-A",
sto_path,
"--noali",
"--F1",
str(self.filter_f1),
"--F2",
str(self.filter_f2),
"--F3",
str(self.filter_f3),
"--incE",
str(self.e_value),
# Report only sequences with E-values <= x in per-sequence output.
'-E', str(self.e_value),
'--cpu', str(self.n_cpu),
'-N', str(self.n_iter)
"-E",
str(self.e_value),
"--cpu",
str(self.n_cpu),
"-N",
str(self.n_iter),
]
if self.get_tblout:
tblout_path = os.path.join(query_tmp_dir, 'tblout.txt')
cmd_flags.extend(['--tblout', tblout_path])
tblout_path = os.path.join(query_tmp_dir, "tblout.txt")
cmd_flags.extend(["--tblout", tblout_path])
if self.z_value:
cmd_flags.extend(['-Z', str(self.z_value)])
cmd_flags.extend(["-Z", str(self.z_value)])
if self.dom_e is not None:
cmd_flags.extend(['--domE', str(self.dom_e)])
cmd_flags.extend(["--domE", str(self.dom_e)])
if self.incdom_e is not None:
cmd_flags.extend(['--incdomE', str(self.incdom_e)])
cmd_flags.extend(["--incdomE", str(self.incdom_e)])
cmd = [self.binary_path] + cmd_flags + [input_fasta_path,
database_path]
cmd = (
[self.binary_path]
+ cmd_flags
+ [input_fasta_path, database_path]
)
logging.info('Launching subprocess "%s"', ' '.join(cmd))
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing(
f'Jackhmmer ({os.path.basename(database_path)}) query'):
f"Jackhmmer ({os.path.basename(database_path)}) query"
):
_, stderr = process.communicate()
retcode = process.wait()
if retcode:
raise RuntimeError(
'Jackhmmer failed\nstderr:\n%s\n' % stderr.decode('utf-8'))
"Jackhmmer failed\nstderr:\n%s\n" % stderr.decode("utf-8")
)
# Get e-values for each target name
tbl = ''
tbl = ""
if self.get_tblout:
with open(tblout_path) as f:
tbl = f.read()
......@@ -152,7 +175,8 @@ class Jackhmmer:
tbl=tbl,
stderr=stderr,
n_iter=self.n_iter,
e_value=self.e_value)
e_value=self.e_value,
)
return raw_output
......@@ -162,15 +186,15 @@ class Jackhmmer:
return [self._query_chunk(input_fasta_path, self.database_path)]
db_basename = os.path.basename(self.database_path)
db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}'
db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}'
db_remote_chunk = lambda db_idx: f"{self.database_path}.{db_idx}"
db_local_chunk = lambda db_idx: f"/tmp/ramdisk/{db_basename}.{db_idx}"
# Remove existing files to prevent OOM
for f in glob.glob(db_local_chunk('[0-9]*')):
for f in glob.glob(db_local_chunk("[0-9]*")):
try:
os.remove(f)
except OSError:
print(f'OSError while deleting {f}')
print(f"OSError while deleting {f}")
# Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk
with futures.ThreadPoolExecutor(max_workers=2) as executor:
......@@ -179,15 +203,22 @@ class Jackhmmer:
# Copy the chunk locally
if i == 1:
future = executor.submit(
request.urlretrieve, db_remote_chunk(i), db_local_chunk(i))
request.urlretrieve,
db_remote_chunk(i),
db_local_chunk(i),
)
if i < self.num_streamed_chunks:
next_future = executor.submit(
request.urlretrieve, db_remote_chunk(i+1), db_local_chunk(i+1))
request.urlretrieve,
db_remote_chunk(i + 1),
db_local_chunk(i + 1),
)
# Run Jackhmmer with the chunk
future.result()
chunked_output.append(
self._query_chunk(input_fasta_path, db_local_chunk(i)))
self._query_chunk(input_fasta_path, db_local_chunk(i))
)
# Remove the local copy of the chunk
os.remove(db_local_chunk(i))
......
......@@ -25,12 +25,12 @@ from openfold.data.tools import utils
def _to_a3m(sequences: Sequence[str]) -> str:
"""Converts sequences to an a3m file."""
names = ['sequence %d' % i for i in range(1, len(sequences) + 1)]
names = ["sequence %d" % i for i in range(1, len(sequences) + 1)]
a3m = []
for sequence, name in zip(sequences, names):
a3m.append(u'>' + name + u'\n')
a3m.append(sequence + u'\n')
return ''.join(a3m)
a3m.append(u">" + name + u"\n")
a3m.append(sequence + u"\n")
return "".join(a3m)
class Kalign:
......@@ -63,40 +63,51 @@ class Kalign:
RuntimeError: If Kalign fails.
ValueError: If any of the sequences is less than 6 residues long.
"""
logging.info('Aligning %d sequences', len(sequences))
logging.info("Aligning %d sequences", len(sequences))
for s in sequences:
if len(s) < 6:
raise ValueError('Kalign requires all sequences to be at least 6 '
'residues long. Got %s (%d residues).' % (s, len(s)))
raise ValueError(
"Kalign requires all sequences to be at least 6 "
"residues long. Got %s (%d residues)." % (s, len(s))
)
with utils.tmpdir_manager(base_dir='/tmp') as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta')
output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
with utils.tmpdir_manager(base_dir="/tmp") as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, "input.fasta")
output_a3m_path = os.path.join(query_tmp_dir, "output.a3m")
with open(input_fasta_path, 'w') as f:
with open(input_fasta_path, "w") as f:
f.write(_to_a3m(sequences))
cmd = [
self.binary_path,
'-i', input_fasta_path,
'-o', output_a3m_path,
'-format', 'fasta',
"-i",
input_fasta_path,
"-o",
output_a3m_path,
"-format",
"fasta",
]
logging.info('Launching subprocess "%s"', ' '.join(cmd))
process = subprocess.Popen(cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
logging.info('Launching subprocess "%s"', " ".join(cmd))
process = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
with utils.timing('Kalign query'):
with utils.timing("Kalign query"):
stdout, stderr = process.communicate()
retcode = process.wait()
logging.info('Kalign stdout:\n%s\n\nstderr:\n%s\n',
stdout.decode('utf-8'), stderr.decode('utf-8'))
logging.info(
"Kalign stdout:\n%s\n\nstderr:\n%s\n",
stdout.decode("utf-8"),
stderr.decode("utf-8"),
)
if retcode:
raise RuntimeError('Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n'
% (stdout.decode('utf-8'), stderr.decode('utf-8')))
raise RuntimeError(
"Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n"
% (stdout.decode("utf-8"), stderr.decode("utf-8"))
)
with open(output_a3m_path) as f:
a3m = f.read()
......
......@@ -35,11 +35,11 @@ def tmpdir_manager(base_dir: Optional[str] = None):
@contextlib.contextmanager
def timing(msg: str):
logging.info('Started %s', msg)
logging.info("Started %s", msg)
tic = time.time()
yield
toc = time.time()
logging.info('Finished %s in %.3f seconds', msg, toc - tic)
logging.info("Finished %s in %.3f seconds", msg, toc - tic)
def to_date(s: str):
......
......@@ -3,13 +3,14 @@ import glob
import importlib as importlib
_files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
__all__ = [os.path.basename(f)[:-3] for f in _files if os.path.isfile(f) and not f.endswith("__init__.py")]
_modules = [(m, importlib.import_module('.' + m, __name__)) for m in __all__]
__all__ = [
os.path.basename(f)[:-3]
for f in _files
if os.path.isfile(f) and not f.endswith("__init__.py")
]
_modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
for _m in _modules:
globals()[_m[0]] = _m[1]
# Avoid needlessly cluttering the global namespace
del _files, _m, _modules
......@@ -26,6 +26,7 @@ class Dropout(nn.Module):
If not in training mode, this module computes the identity function.
"""
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
"""
Args:
......@@ -37,7 +38,7 @@ class Dropout(nn.Module):
super(Dropout, self).__init__()
self.r = r
if(type(batch_dim) == int):
if type(batch_dim) == int:
batch_dim = [batch_dim]
self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r)
......@@ -50,7 +51,7 @@ class Dropout(nn.Module):
compatible with self.batch_dim
"""
shape = list(x.shape)
if(self.batch_dim is not None):
if self.batch_dim is not None:
for bd in self.batch_dim:
shape[bd] = 1
mask = x.new_ones(shape)
......@@ -64,6 +65,7 @@ class DropoutRowwise(Dropout):
Convenience class for rowwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
......@@ -72,4 +74,5 @@ class DropoutColumnwise(Dropout):
Convenience class for columnwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
......@@ -27,6 +27,7 @@ class InputEmbedder(nn.Module):
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
......@@ -67,9 +68,7 @@ class InputEmbedder(nn.Module):
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self,
ri: torch.Tensor
):
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
......@@ -86,7 +85,8 @@ class InputEmbedder(nn.Module):
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
def forward(self,
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
......@@ -132,14 +132,16 @@ class RecyclingEmbedder(nn.Module):
Implements Algorithm 32.
"""
def __init__(self,
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs
**kwargs,
):
"""
Args:
......@@ -169,7 +171,8 @@ class RecyclingEmbedder(nn.Module):
self.layer_norm_m = nn.LayerNorm(self.c_m)
self.layer_norm_z = nn.LayerNorm(self.c_z)
def forward(self,
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
......@@ -188,13 +191,13 @@ class RecyclingEmbedder(nn.Module):
z:
[*, N_res, N_res, C_z] pair embedding update
"""
if(self.bins is None):
if self.bins is None:
self.bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device
device=x.device,
)
# [*, N, C_m]
......@@ -205,15 +208,10 @@ class RecyclingEmbedder(nn.Module):
# couldn't find in time.
squared_bins = self.bins ** 2
upper = torch.cat(
[
squared_bins[1:],
squared_bins.new_tensor([self.inf])
], dim=-1
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2,
dim=-1,
keepdims=True
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
......@@ -232,7 +230,9 @@ class TemplateAngleEmbedder(nn.Module):
Implements Algorithm 2, line 7.
"""
def __init__(self,
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
......@@ -253,9 +253,7 @@ class TemplateAngleEmbedder(nn.Module):
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
......@@ -275,7 +273,9 @@ class TemplatePairEmbedder(nn.Module):
Implements Algorithm 2, line 9.
"""
def __init__(self,
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
......@@ -295,7 +295,8 @@ class TemplatePairEmbedder(nn.Module):
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, init="relu")
def forward(self,
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
......@@ -316,7 +317,9 @@ class ExtraMSAEmbedder(nn.Module):
Implements Algorithm 2, line 15
"""
def __init__(self,
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
......@@ -335,9 +338,7 @@ class ExtraMSAEmbedder(nn.Module):
self.linear = Linear(self.c_in, self.c_out)
def forward(self,
x: torch.Tensor
) -> torch.Tensor:
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
......
......@@ -45,6 +45,7 @@ class MSATransition(nn.Module):
Implements Algorithm 9
"""
def __init__(self, c_m, n, chunk_size):
"""
Args:
......@@ -71,7 +72,8 @@ class MSATransition(nn.Module):
m = self.linear_2(m) * mask
return m
def forward(self,
def forward(
self,
m: torch.Tensor,
mask: torch.Tensor = None,
) -> torch.Tensor:
......@@ -86,7 +88,7 @@ class MSATransition(nn.Module):
[*, N_seq, N_res, C_m] MSA activation update
"""
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
if(mask is None):
if mask is None:
mask = m.new_ones(m.shape[:-1])
mask = mask.unsqueeze(-1)
......@@ -94,7 +96,7 @@ class MSATransition(nn.Module):
m = self.layer_norm(m)
inp = {"m": m, "mask": mask}
if(self.chunk_size is not None):
if self.chunk_size is not None:
m = chunk_layer(
self._transition,
inp,
......@@ -108,7 +110,8 @@ class MSATransition(nn.Module):
class EvoformerBlock(nn.Module):
def __init__(self,
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
......@@ -136,7 +139,7 @@ class EvoformerBlock(nn.Module):
inf=inf,
)
if(_is_extra_msa_stack):
if _is_extra_msa_stack:
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
......@@ -201,7 +204,8 @@ class EvoformerBlock(nn.Module):
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(self,
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
......@@ -233,7 +237,9 @@ class EvoformerStack(nn.Module):
Implements Algorithm 6.
"""
def __init__(self,
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
......@@ -313,10 +319,11 @@ class EvoformerStack(nn.Module):
)
self.blocks.append(block)
if(not self._is_extra_msa_stack):
if not self._is_extra_msa_stack:
self.linear = Linear(c_m, c_s)
def forward(self,
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
......@@ -348,14 +355,15 @@ class EvoformerStack(nn.Module):
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=_mask_trans,
) for b in self.blocks
)
for b in self.blocks
],
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = None
if(not self._is_extra_msa_stack):
if not self._is_extra_msa_stack:
seq_dim = -3
index = torch.tensor([0], device=m.device)
s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
......@@ -368,7 +376,9 @@ class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
......@@ -411,12 +421,13 @@ class ExtraMSAStack(nn.Module):
_is_extra_msa_stack=True,
)
def forward(self,
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
......@@ -436,6 +447,6 @@ class ExtraMSAStack(nn.Module):
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=_mask_trans
_mask_trans=_mask_trans,
)
return z
......@@ -44,7 +44,7 @@ class AuxiliaryHeads(nn.Module):
**config["experimentally_resolved"],
)
if(config.tm.enabled):
if config.tm.enabled:
self.tm = TMScoreHead(
**config.tm,
)
......@@ -68,19 +68,22 @@ class AuxiliaryHeads(nn.Module):
experimentally_resolved_logits = self.experimentally_resolved(
outputs["single"]
)
aux_out["experimentally_resolved_logits"] = (
experimentally_resolved_logits
)
aux_out[
"experimentally_resolved_logits"
] = experimentally_resolved_logits
if(self.config.tm.enabled):
if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
tm_logits, **self.config.tm
)
aux_out.update(compute_predicted_aligned_error(
tm_logits, **self.config.tm,
))
aux_out.update(
compute_predicted_aligned_error(
tm_logits,
**self.config.tm,
)
)
return aux_out
......@@ -118,6 +121,7 @@ class DistogramHead(nn.Module):
For use in computation of distogram loss, subsection 1.9.8
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
......@@ -133,9 +137,7 @@ class DistogramHead(nn.Module):
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self,
z # [*, N, N, C_z]
):
def forward(self, z): # [*, N, N, C_z]
"""
Args:
z:
......@@ -153,6 +155,7 @@ class TMScoreHead(nn.Module):
"""
For use in computation of TM-score, subsection 1.9.7
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
......@@ -185,6 +188,7 @@ class MaskedMSAHead(nn.Module):
"""
For use in computation of masked MSA loss, subsection 1.9.9
"""
def __init__(self, c_m, c_out, **kwargs):
"""
Args:
......@@ -218,6 +222,7 @@ class ExperimentallyResolvedHead(nn.Module):
For use in computation of "experimentally resolved" loss, subsection
1.9.10
"""
def __init__(self, c_s, c_out, **kwargs):
"""
Args:
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment