Commit b5f885d0 authored by Gustaf Ahdritz

Add ProteinNet support to parser

parent 94ab346e
@@ -17,9 +17,9 @@ Try it out with our [Colab notebook](https://colab.research.google.com/github/aq
 (not yet visible from Colab because the repo is still private).
 
 Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
-with or without [DeepSpeed](https://github.com/microsoft/deepspeed) and with
-mixed precision. `bfloat16` training is not currently supported, but will be
-in the future.
+with [DeepSpeed](https://github.com/microsoft/deepspeed) and with mixed
+precision. `bfloat16` training is not currently supported, but will be in the
+future.
 
 ## Installation (Linux)
...
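For context, here is a minimal sketch of what training "with DeepSpeed and with mixed precision" can look like in PyTorch Lightning 1.x, the framework and version whose plugin API train_openfold.py imports below. Everything except the DeepSpeedPlugin import is a placeholder assumption: ToyModule and the random tensors stand in for OpenFold's real module and data pipeline, and the exact plugin arguments depend on the installed Lightning version.

import pytorch_lightning as pl
import torch
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin


class ToyModule(pl.LightningModule):
    # Stand-in for OpenFold's (much larger) LightningModule
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.linear(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(32, 8), torch.randn(32, 1)),
    batch_size=4,
)

trainer = pl.Trainer(
    gpus=1,
    precision=16,                        # fp16 mixed precision; bf16 is the
                                         # part the README says is unsupported
    plugins=[DeepSpeedPlugin(stage=2)],  # ZeRO stage-2 sharding
    max_epochs=1,
)
trainer.fit(ToyModule(), loader)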
@@ -46,11 +46,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 output by an AlignmentRunner
                 (defined in openfold.features.alignment_runner).
                 I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
-                or simply {PDB_ID}, each containing:
-                    * bfd_uniclust_hits.a3m/small_bfd_hits.sto
-                    * mgnify_hits.a3m
-                    * pdb70_hits.hhr
-                    * uniref90_hits.a3m
+                or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
+                files.
             config:
                 A dataset config object. See openfold.config
             mapping_path:
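For concreteness, one directory in the new scheme might look like the following, where 1ABC_A is a made-up {PDB_ID}_{CHAIN_ID} name and the files are the database-specific ones the old docstring enumerated; under the new convention, any other .a3m, .sto, or .hhr names are equally valid:

alignments/
    1ABC_A/
        bfd_uniclust_hits.a3m
        mgnify_hits.a3m
        pdb70_hits.hhr
        uniref90_hits.a3m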
@@ -97,7 +94,6 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         self.data_pipeline = data_pipeline.DataPipeline(
             template_featurizer=template_featurizer,
-            use_small_bfd=use_small_bfd,
         )
 
         if(not self.output_raw):
...
@@ -260,47 +260,65 @@ class DataPipeline:
     def __init__(
         self,
         template_featurizer: templates.TemplateHitFeaturizer,
-        use_small_bfd: bool,
     ):
         self.template_featurizer = template_featurizer
-        self.use_small_bfd = use_small_bfd
 
-    def _parse_alignment_output(
+    def _parse_msa_data(
         self,
         alignment_dir: str,
     ) -> Mapping[str, Any]:
-        uniref90_out_path = os.path.join(alignment_dir, "uniref90_hits.a3m")
-        with open(uniref90_out_path, "r") as f:
-            uniref90_msa, uniref90_deletion_matrix = parsers.parse_a3m(f.read())
-
-        mgnify_out_path = os.path.join(alignment_dir, "mgnify_hits.a3m")
-        with open(mgnify_out_path, "r") as f:
-            mgnify_msa, mgnify_deletion_matrix = parsers.parse_a3m(f.read())
-
-        pdb70_out_path = os.path.join(alignment_dir, "pdb70_hits.hhr")
-        with open(pdb70_out_path, "r") as f:
-            hhsearch_hits = parsers.parse_hhr(f.read())
-
-        if self.use_small_bfd:
-            bfd_out_path = os.path.join(alignment_dir, "small_bfd_hits.sto")
-            with open(bfd_out_path, "r") as f:
-                bfd_msa, bfd_deletion_matrix, _ = parsers.parse_stockholm(
-                    f.read()
-                )
-        else:
-            bfd_out_path = os.path.join(alignment_dir, "bfd_uniclust_hits.a3m")
-            with open(bfd_out_path, "r") as f:
-                bfd_msa, bfd_deletion_matrix = parsers.parse_a3m(f.read())
-
-        return {
-            "uniref90_msa": uniref90_msa,
-            "uniref90_deletion_matrix": uniref90_deletion_matrix,
-            "mgnify_msa": mgnify_msa,
-            "mgnify_deletion_matrix": mgnify_deletion_matrix,
-            "hhsearch_hits": hhsearch_hits,
-            "bfd_msa": bfd_msa,
-            "bfd_deletion_matrix": bfd_deletion_matrix,
-        }
+        msa_data = {}
+        for f in os.listdir(alignment_dir):
+            path = os.path.join(alignment_dir, f)
+            ext = os.path.splitext(f)[-1]
+
+            if(ext == ".a3m"):
+                with open(path, "r") as fp:
+                    msa, deletion_matrix = parsers.parse_a3m(fp.read())
+                data = {"msa": msa, "deletion_matrix": deletion_matrix}
+            elif(ext == ".sto"):
+                with open(path, "r") as fp:
+                    msa, deletion_matrix, _ = parsers.parse_stockholm(
+                        fp.read()
+                    )
+                data = {"msa": msa, "deletion_matrix": deletion_matrix}
+            else:
+                continue
+
+            msa_data[f] = data
+
+        return msa_data
+
+    def _parse_template_hits(
+        self,
+        alignment_dir: str,
+    ) -> Mapping[str, Any]:
+        all_hits = {}
+        for f in os.listdir(alignment_dir):
+            path = os.path.join(alignment_dir, f)
+            ext = os.path.splitext(f)[-1]
+
+            if(ext == ".hhr"):
+                with open(path, "r") as fp:
+                    hits = parsers.parse_hhr(fp.read())
+                all_hits[f] = hits
+
+        return all_hits
+
+    def _process_msa_feats(
+        self,
+        alignment_dir: str,
+    ) -> Mapping[str, Any]:
+        msa_data = self._parse_msa_data(alignment_dir)
+        msas, deletion_matrices = zip(*[
+            (v["msa"], v["deletion_matrix"]) for v in msa_data.values()
+        ])
+        msa_features = make_msa_features(
+            msas=msas,
+            deletion_matrices=deletion_matrices,
+        )
+
+        return msa_features
 
     def process_fasta(
         self,
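The refactor above replaces hardcoded, database-specific filenames with dispatch on file extension: any .a3m or .sto file dropped into the alignment directory is picked up, which is what makes arbitrary ProteinNet MSAs usable. A standalone sketch of that pattern follows; the names scan_alignment_dir and parsers_by_ext are illustrative, not OpenFold API.

import os
from typing import Any, Callable, Mapping


def scan_alignment_dir(
    alignment_dir: str,
    parsers_by_ext: Mapping[str, Callable[[str], Any]],
) -> Mapping[str, Any]:
    parsed = {}
    for f in os.listdir(alignment_dir):
        ext = os.path.splitext(f)[-1]  # ".a3m", ".sto", ".hhr", ...
        parse = parsers_by_ext.get(ext)
        if parse is None:
            continue  # unrecognized files are skipped, as in the pipeline
        with open(os.path.join(alignment_dir, f), "r") as fp:
            parsed[f] = parse(fp.read())
    return parsed


# e.g., with a toy "parser" that just splits the file into lines:
# scan_alignment_dir("alignments/1ABC_A", {".a3m": str.splitlines})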
@@ -319,13 +337,13 @@ class DataPipeline:
         input_description = input_descs[0]
         num_res = len(input_sequence)
 
-        alignments = self._parse_alignment_output(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir)
+        hits_cat = sum(hits.values(), [])
 
         templates_result = self.template_featurizer.get_templates(
             query_sequence=input_sequence,
             query_pdb_code=None,
             query_release_date=None,
-            hits=alignments["hhsearch_hits"],
+            hits=hits_cat,
         )
 
         sequence_features = make_sequence_features(
@@ -334,18 +352,8 @@ class DataPipeline:
             num_res=num_res,
         )
 
-        msa_features = make_msa_features(
-            msas=(
-                alignments["uniref90_msa"],
-                alignments["bfd_msa"],
-                alignments["mgnify_msa"],
-            ),
-            deletion_matrices=(
-                alignments["uniref90_deletion_matrix"],
-                alignments["bfd_deletion_matrix"],
-                alignments["mgnify_deletion_matrix"],
-            ),
-        )
+        msa_features = self._process_msa_feats(alignment_dir)
 
         return {
             **sequence_features,
             **msa_features,
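A note on the hits_cat idiom introduced above: _parse_template_hits returns a dict mapping each .hhr filename to a list of hits, so sum(hits.values(), []) concatenates the per-file lists into one. A toy illustration, with strings standing in for parsed template-hit objects and the second filename invented:

hits = {
    "pdb70_hits.hhr": ["hit_a", "hit_b"],
    "other_hits.hhr": ["hit_c"],
}
hits_cat = sum(hits.values(), [])  # [] + ["hit_a", "hit_b"] + ["hit_c"]
assert hits_cat == ["hit_a", "hit_b", "hit_c"]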
@@ -373,28 +381,18 @@ class DataPipeline:
         mmcif_feats = make_mmcif_features(mmcif, chain_id)
 
-        alignments = self._parse_alignment_output(alignment_dir)
-
         input_sequence = mmcif.chain_to_seqres[chain_id]
+        hits = self._parse_template_hits(alignment_dir)
+        hits_cat = sum(hits.values(), [])
+        print(len(hits_cat))
         templates_result = self.template_featurizer.get_templates(
             query_sequence=input_sequence,
             query_pdb_code=None,
             query_release_date=to_date(mmcif.header["release_date"]),
-            hits=alignments["hhsearch_hits"],
+            hits=hits_cat,
         )
 
-        msa_features = make_msa_features(
-            msas=(
-                alignments["uniref90_msa"],
-                alignments["bfd_msa"],
-                alignments["mgnify_msa"],
-            ),
-            deletion_matrices=(
-                alignments["uniref90_deletion_matrix"],
-                alignments["bfd_deletion_matrix"],
-                alignments["mgnify_deletion_matrix"],
-            ),
-        )
+        msa_features = self._process_msa_feats(alignment_dir)
 
         return {**mmcif_feats, **templates_result.features, **msa_features}
@@ -413,26 +411,15 @@ class DataPipeline:
         pdb_feats = make_pdb_features(protein_object)
 
-        alignments = self._parse_alignment_output(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir)
+        hits_cat = sum(hits.values(), [])
 
         templates_result = self.template_featurizer.get_templates(
-            query_sequence=protein_object.aatype,
+            query_sequence=input_sequence,
             query_pdb_code=None,
             query_release_date=None,
-            hits=alignments["hhsearch_hits"],
+            hits=hits_cat,
        )
 
-        msa_features = make_msa_features(
-            msas=(
-                alignments["uniref90_msa"],
-                alignments["bfd_msa"],
-                alignments["mgnify_msa"],
-            ),
-            deletion_matrices=(
-                alignments["uniref90_deletion_matrix"],
-                alignments["bfd_deletion_matrix"],
-                alignments["mgnify_deletion_matrix"],
-            ),
-        )
+        msa_features = self._process_msa_feats(alignment_dir)
 
         return {**pdb_feats, **templates_result.features, **msa_features}
import argparse
import logging
import os
import shutil


def main(args):
    count = 0
    max_count = args.max_count if args.max_count is not None else -1

    # Sort both listings so MSAs and mmCIFs can be matched in a single
    # linear pass rather than with a quadratic search
    msas = sorted(os.listdir(args.msa_dir))
    mmcifs = sorted(os.listdir(args.mmcif_dir))

    mmcif_idx = 0
    for f in msas:
        if(count == max_count):
            break

        path = os.path.join(args.msa_dir, f)
        name = os.path.splitext(f)[0]
        spl = name.upper().split('_')

        # ProteinNet MSA names consist of three '_'-separated fields,
        # the first being the PDB ID and the last the chain ID
        if(len(spl) != 3):
            continue

        pdb_id, _, chain_id = spl

        # Advance the mmCIF cursor until it reaches or passes pdb_id
        while(
            mmcif_idx < len(mmcifs) and
            pdb_id > os.path.splitext(mmcifs[mmcif_idx])[0].upper()
        ):
            mmcif_idx += 1

        # If the mmCIFs are exhausted, no later (sorted) MSA can match either
        if(mmcif_idx == len(mmcifs)):
            break

        # Only consider files with matching mmCIF files
        if(pdb_id == os.path.splitext(mmcifs[mmcif_idx])[0].upper()):
            dirname = os.path.join(args.out_dir, '_'.join([pdb_id, chain_id]))
            os.makedirs(dirname, exist_ok=True)
            dest = os.path.join(dirname, f)
            if(args.copy):
                shutil.copyfile(path, dest)
            else:
                os.rename(path, dest)
            count += 1


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "msa_dir", type=str, help="Directory containing ProteinNet MSAs"
    )
    parser.add_argument(
        "mmcif_dir", type=str, help="Directory containing PDB mmCIFs"
    )
    parser.add_argument(
        "out_dir", type=str,
        help="Directory to which output should be saved"
    )
    # argparse's type=bool treats any non-empty string, including "False",
    # as True, so the off switch is exposed as a store_false flag instead
    parser.add_argument(
        "--no_copy", action="store_false", dest="copy",
        help="Move the MSAs into out_dir instead of copying them (copying "
             "is the default)"
    )
    parser.add_argument(
        "--max_count", type=int, default=None,
        help="A bound on the number of MSAs to process"
    )

    args = parser.parse_args()
    main(args)
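The script above matches MSAs to mmCIF files with a single cursor walk over the two sorted listings, so the cost is linear in the combined directory sizes. A self-contained illustration of that merge step, with all filenames invented:

import os

msas = sorted(["2XYZ_1_B.a3m", "1ABC_1_A.a3m", "9ZZZ_1_A.a3m"])
mmcifs = sorted(["1abc.cif", "3def.cif", "9zzz.cif"])

mmcif_idx = 0
matched = []
for f in msas:
    pdb_id = f.upper().split("_")[0]
    # Advance until the current mmCIF is >= the MSA's PDB ID
    while (
        mmcif_idx < len(mmcifs)
        and pdb_id > os.path.splitext(mmcifs[mmcif_idx])[0].upper()
    ):
        mmcif_idx += 1
    if (
        mmcif_idx < len(mmcifs)
        and pdb_id == os.path.splitext(mmcifs[mmcif_idx])[0].upper()
    ):
        matched.append(f)

assert matched == ["1ABC_1_A.a3m", "9ZZZ_1_A.a3m"]  # 2XYZ has no mmCIF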
@@ -2,7 +2,7 @@ import argparse
 import logging
 import os
 
-#os.environ["CUDA_VISIBLE_DEVICES"] = "6"
+#os.environ["CUDA_VISIBLE_DEVICES"] = "5"
 #os.environ["MASTER_ADDR"]="10.119.81.14"
 #os.environ["MASTER_PORT"]="42069"
 #os.environ["NODE_RANK"]="0"
@@ -14,7 +14,6 @@ import time
 import numpy as np
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
-from pytorch_lightning.plugins import DDPPlugin
 from pytorch_lightning.plugins.training_type import DeepSpeedPlugin
 import torch
...