Commit b5f885d0 authored by Gustaf Ahdritz
Browse files

Add ProteinNet support to parser

parent 94ab346e
......@@ -17,9 +17,9 @@ Try it out with our [Colab notebook](https://colab.research.google.com/github/aq
(not yet visible from Colab because the repo is still private).
Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
with [DeepSpeed](https://github.com/microsoft/deepspeed) and with mixed
precision. `bfloat16` training is not currently supported, but will be in the
future.
## Installation (Linux)
......
......@@ -46,11 +46,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
output by an AlignmentRunner
(defined in openfold.features.alignment_runner).
I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
files.
config:
A dataset config object. See openfold.config
mapping_path:
......@@ -97,7 +94,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.data_pipeline = data_pipeline.DataPipeline(
template_featurizer=template_featurizer,
use_small_bfd=use_small_bfd,
)
if(not self.output_raw):
......
......@@ -260,47 +260,65 @@ class DataPipeline:
def __init__(
    self,
    template_featurizer: templates.TemplateHitFeaturizer,
):
    """Assembles input features for the model.

    Args:
        template_featurizer:
            Featurizer applied to template hits parsed from .hhr files.

    NOTE(review): this reconstruction drops the former use_small_bfd
    parameter, which this commit appears to remove — callers that still
    pass it must be updated; confirm against the full file.
    """
    self.template_featurizer = template_featurizer
def _parse_msa_data(
    self,
    alignment_dir: str,
) -> Mapping[str, Any]:
    """Parse every MSA file found in alignment_dir.

    Files ending in .a3m are parsed with parsers.parse_a3m and files
    ending in .sto with parsers.parse_stockholm; all other extensions
    are skipped.

    Args:
        alignment_dir: Directory containing alignment files.

    Returns:
        A dict mapping each alignment filename to a dict with keys
        "msa" and "deletion_matrix".
    """
    msa_data = {}
    for f in os.listdir(alignment_dir):
        path = os.path.join(alignment_dir, f)
        ext = os.path.splitext(f)[-1]

        if(ext == ".a3m"):
            with open(path, "r") as fp:
                msa, deletion_matrix = parsers.parse_a3m(fp.read())
        elif(ext == ".sto"):
            with open(path, "r") as fp:
                # parse_stockholm also returns sequence descriptions,
                # which are not needed here.
                msa, deletion_matrix, _ = parsers.parse_stockholm(
                    fp.read()
                )
        else:
            continue

        msa_data[f] = {
            "msa": msa,
            "deletion_matrix": deletion_matrix,
        }

    return msa_data
def _parse_template_hits(
    self,
    alignment_dir: str,
) -> Mapping[str, Any]:
    """Parse template hits from every .hhr file in alignment_dir.

    Args:
        alignment_dir: Directory containing alignment files.

    Returns:
        A dict mapping each .hhr filename to its parsed list of hits.
    """
    hits_by_file = {}
    for fname in os.listdir(alignment_dir):
        if(os.path.splitext(fname)[-1] != ".hhr"):
            continue
        hhr_path = os.path.join(alignment_dir, fname)
        with open(hhr_path, "r") as hhr_file:
            hits_by_file[fname] = parsers.parse_hhr(hhr_file.read())
    return hits_by_file
def _process_msa_feats(
    self,
    alignment_dir: str,
) -> Mapping[str, Any]:
    """Parse all MSAs in alignment_dir and build MSA features.

    Args:
        alignment_dir: Directory containing alignment files.

    Returns:
        The feature dict produced by make_msa_features.
    """
    parsed = self._parse_msa_data(alignment_dir)
    pairs = [
        (entry["msa"], entry["deletion_matrix"])
        for entry in parsed.values()
    ]
    # NOTE: raises ValueError if no MSA files were found (empty unpack).
    msas, deletion_matrices = zip(*pairs)
    return make_msa_features(
        msas=msas,
        deletion_matrices=deletion_matrices,
    )
def process_fasta(
self,
......@@ -319,13 +337,13 @@ class DataPipeline:
input_description = input_descs[0]
num_res = len(input_sequence)
alignments = self._parse_alignment_output(alignment_dir)
hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), [])
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=None,
hits=alignments["hhsearch_hits"],
hits=hits_cat,
)
sequence_features = make_sequence_features(
......@@ -334,18 +352,8 @@ class DataPipeline:
num_res=num_res,
)
msa_features = make_msa_features(
msas=(
alignments["uniref90_msa"],
alignments["bfd_msa"],
alignments["mgnify_msa"],
),
deletion_matrices=(
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
),
)
msa_features = self._process_msa_feats(alignment_dir)
return {
**sequence_features,
**msa_features,
......@@ -373,28 +381,18 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id)
alignments = self._parse_alignment_output(alignment_dir)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), [])
print(len(hits_cat))
templates_result = self.template_featurizer.get_templates(
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=to_date(mmcif.header["release_date"]),
hits=alignments["hhsearch_hits"],
hits=hits_cat,
)
msa_features = make_msa_features(
msas=(
alignments["uniref90_msa"],
alignments["bfd_msa"],
alignments["mgnify_msa"],
),
deletion_matrices=(
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
),
)
msa_features = self._process_msa_feats(alignment_dir)
return {**mmcif_feats, **templates_result.features, **msa_features}
......@@ -413,26 +411,15 @@ class DataPipeline:
pdb_feats = make_pdb_features(protein_object)
alignments = self._parse_alignment_output(alignment_dir)
hits = self._parse_template_hits(alignment_dir)
hits_cat = sum(hits.values(), [])
templates_result = self.template_featurizer.get_templates(
query_sequence=protein_object.aatype,
query_sequence=input_sequence,
query_pdb_code=None,
query_release_date=None,
hits=alignments["hhsearch_hits"],
hits=hits_cat,
)
msa_features = make_msa_features(
msas=(
alignments["uniref90_msa"],
alignments["bfd_msa"],
alignments["mgnify_msa"],
),
deletion_matrices=(
alignments["uniref90_deletion_matrix"],
alignments["bfd_deletion_matrix"],
alignments["mgnify_deletion_matrix"],
),
)
msa_features = self._process_msa_feats(alignment_dir)
return {**pdb_feats, **templates_result.features, **msa_features}
import argparse
import logging
import os
import shutil
def main(args):
    """Sort ProteinNet MSA files into per-chain output directories.

    MSA filenames are expected to look like {PDB_ID}_{NO}_{CHAIN_ID}.ext;
    files that don't split into exactly three '_'-separated parts are
    skipped. An MSA is kept only if a matching mmCIF file (same PDB ID,
    case-insensitive) exists in args.mmcif_dir; it is then copied or
    moved to args.out_dir/{PDB_ID}_{CHAIN_ID}/.

    Both listings are sorted so a single forward-moving index into the
    mmCIF list suffices (merge-join over two sorted sequences).
    """
    count = 0
    # -1 never equals count, so no bound is applied when max_count is None
    max_count = args.max_count if args.max_count is not None else -1
    msas = sorted(f for f in os.listdir(args.msa_dir))
    mmcifs = sorted(f for f in os.listdir(args.mmcif_dir))
    mmcif_idx = 0
    for f in msas:
        if(count == max_count):
            break
        path = os.path.join(args.msa_dir, f)
        name = os.path.splitext(f)[0]
        spl = name.upper().split('_')
        if(len(spl) != 3):
            continue
        pdb_id, _, chain_id = spl
        # Advance to the first mmCIF not lexicographically behind pdb_id.
        # Bug fix: bound the index so an MSA whose PDB ID sorts after
        # every mmCIF no longer raises IndexError.
        while(
            mmcif_idx < len(mmcifs) and
            pdb_id > os.path.splitext(mmcifs[mmcif_idx])[0].upper()
        ):
            mmcif_idx += 1
        if(mmcif_idx == len(mmcifs)):
            # No remaining mmCIFs can match any later (sorted) MSA
            break
        # Only consider files with matching mmCIF files
        if(pdb_id == os.path.splitext(mmcifs[mmcif_idx])[0].upper()):
            dirname = os.path.join(args.out_dir, '_'.join([pdb_id, chain_id]))
            os.makedirs(dirname, exist_ok=True)
            dest = os.path.join(dirname, f)
            if(args.copy):
                shutil.copyfile(path, dest)
            else:
                os.rename(path, dest)
            count += 1
if __name__ == "__main__":
    def _str_to_bool(v):
        """Parse a command-line boolean.

        Fixes the classic argparse pitfall: type=bool treats every
        non-empty string as True, so `--copy False` used to yield True.
        """
        if(isinstance(v, bool)):
            return v
        return v.lower() not in ("false", "0", "no", "")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "msa_dir", type=str, help="Directory containing ProteinNet MSAs"
    )
    parser.add_argument(
        "mmcif_dir", type=str, help="Directory containing PDB mmCIFs"
    )
    parser.add_argument(
        "out_dir", type=str,
        help="Directory to which output should be saved"
    )
    parser.add_argument(
        "--copy", type=_str_to_bool, default=True,
        help="Whether to copy the MSAs to out_dir rather than moving them"
    )
    parser.add_argument(
        "--max_count", type=int, default=None,
        help="A bound on the number of MSAs to process"
    )
    args = parser.parse_args()

    main(args)
......@@ -2,7 +2,7 @@ import argparse
import logging
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "6"
#os.environ["CUDA_VISIBLE_DEVICES"] = "5"
#os.environ["MASTER_ADDR"]="10.119.81.14"
#os.environ["MASTER_PORT"]="42069"
#os.environ["NODE_RANK"]="0"
......@@ -14,7 +14,6 @@ import time
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
import torch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment