Commit 2c7ce956 authored by Gustaf's avatar Gustaf
Browse files

Add more ProteinNet functionality

parent 78fa6c6e
...@@ -167,8 +167,8 @@ python3 scripts/precompute_alignments_mmseqs.py input.fasta \ ...@@ -167,8 +167,8 @@ python3 scripts/precompute_alignments_mmseqs.py input.fasta \
``` ```
where `input.fasta` is a FASTA file containing one or more query sequences. To where `input.fasta` is a FASTA file containing one or more query sequences. To
generate an input FASTA from a directory of mmCIF files, we provide generate an input FASTA from a directory of mmCIF and/or ProteinNet .core
`scripts/mmcif_dir_to_fasta.py`. files, we provide `scripts/data_dir_to_fasta.py`.
Next, generate a cache of certain datapoints in the mmCIF files: Next, generate a cache of certain datapoints in the mmCIF files:
......
...@@ -152,16 +152,19 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -152,16 +152,19 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
file_id, = spl file_id, = spl
chain_id = None chain_id = None
path = os.path.join(self.data_dir, file_id + '.cif') path = os.path.join(self.data_dir, file_id)
if(os.path.exists(path)): if(os.path.exists(path + ".cif")):
data = self._parse_mmcif( data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir path + ".cif", file_id, chain_id, alignment_dir
)
elif(os.path.exists(path + ".core")):
data = self.data_pipeline.process_core(
path + ".core", alignment_dir
) )
else: else:
# Try to search for a distillation PDB file instead # Try to search for a distillation PDB file instead
path = os.path.join(self.data_dir, file_id + '.pdb')
data = self.data_pipeline.process_pdb( data = self.data_pipeline.process_pdb(
pdb_path=path, pdb_path=path + ".pdb",
alignment_dir=alignment_dir alignment_dir=alignment_dir
) )
else: else:
......
...@@ -123,8 +123,9 @@ def make_mmcif_features( ...@@ -123,8 +123,9 @@ def make_mmcif_features(
def _aatype_to_str_sequence(aatype): def _aatype_to_str_sequence(aatype):
return str([ return ''.join([
residue_constants.restypes[aatype[i]] for i in range(len(aatype)) residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype))
]) ])
def make_protein_features( def make_protein_features(
......
...@@ -3,36 +3,47 @@ import logging ...@@ -3,36 +3,47 @@ import logging
import os import os
from openfold.data import mmcif_parsing from openfold.data import mmcif_parsing
from openfold.np import protein, residue_constants
def main(args): def main(args):
fasta = [] fasta = []
for fname in os.listdir(args.mmcif_dir): for fname in os.listdir(args.data_dir):
basename, ext = os.path.splitext(fname) basename, ext = os.path.splitext(fname)
basename = basename.upper() basename = basename.upper()
fpath = os.path.join(args.data_dir, fname)
if(ext == ".cif"):
with open(fpath, 'r') as fp:
mmcif_str = fp.read()
mmcif = mmcif_parsing.parse(
file_id=basename, mmcif_string=mmcif_str
)
if(mmcif.mmcif_object is None):
logging.warning(f'Failed to parse {fname}...')
if(args.raise_errors):
raise list(mmcif.errors.values())[0]
else:
continue
if(not ext == ".cif"): mmcif = mmcif.mmcif_object
continue for chain, seq in mmcif.chain_to_seqres.items():
chain_id = '_'.join([basename, chain])
fpath = os.path.join(args.mmcif_dir, fname) fasta.append(f">{chain_id}")
with open(fpath, 'r') as fp: fasta.append(seq)
mmcif_str = fp.read() elif(ext == ".core"):
with open(fpath, 'r') as fp:
mmcif = mmcif_parsing.parse( core_str = fp.read()
file_id=basename, mmcif_string=mmcif_str
) core_protein = protein.from_proteinnet_string(core_str)
if(mmcif.mmcif_object is None): aatype = core_protein.aatype
logging.warning(f'Failed to parse {fname}...') seq = ''.join([
if(args.raise_errors): residue_constants.restypes_with_x[aatype[i]]
raise list(mmcif.errors.values())[0] for i in range(len(aatype))
else: ])
continue fasta.append(f">{basename}")
mmcif = mmcif.mmcif_object
for chain, seq in mmcif.chain_to_seqres.items():
chain_id = '_'.join([basename, chain])
fasta.append(f">{chain_id}")
fasta.append(seq) fasta.append(seq)
with open(args.output_path, "w") as fp: with open(args.output_path, "w") as fp:
fp.write('\n'.join(fasta)) fp.write('\n'.join(fasta))
...@@ -41,8 +52,8 @@ def main(args): ...@@ -41,8 +52,8 @@ def main(args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"mmcif_dir", type=str, "data_dir", type=str,
help="Path to a directory containing mmCIF files" help="Path to a directory containing mmCIF or .core files"
) )
parser.add_argument( parser.add_argument(
"output_path", type=str, "output_path", type=str,
......
...@@ -5,6 +5,7 @@ import tempfile ...@@ -5,6 +5,7 @@ import tempfile
import openfold.data.mmcif_parsing as mmcif_parsing import openfold.data.mmcif_parsing as mmcif_parsing
from openfold.data.data_pipeline import AlignmentRunner from openfold.data.data_pipeline import AlignmentRunner
from openfold.np import protein, residue_constants
from utils import add_data_args from utils import add_data_args
...@@ -31,11 +32,9 @@ def main(args): ...@@ -31,11 +32,9 @@ def main(args):
for f in os.listdir(args.input_dir): for f in os.listdir(args.input_dir):
path = os.path.join(args.input_dir, f) path = os.path.join(args.input_dir, f)
is_mmcif = f.endswith('.cif')
is_fasta = f.endswith('.fasta')
file_id = os.path.splitext(f)[0] file_id = os.path.splitext(f)[0]
seqs = {} seqs = {}
if(is_mmcif): if(f.endswith('.cif')):
with open(path, 'r') as fp: with open(path, 'r') as fp:
mmcif_str = fp.read() mmcif_str = fp.read()
mmcif = mmcif_parsing.parse( mmcif = mmcif_parsing.parse(
...@@ -51,7 +50,7 @@ def main(args): ...@@ -51,7 +50,7 @@ def main(args):
for k,v in mmcif.chain_to_seqres.items(): for k,v in mmcif.chain_to_seqres.items():
chain_id = '_'.join([file_id, k]) chain_id = '_'.join([file_id, k])
seqs[chain_id] = v seqs[chain_id] = v
elif(is_fasta): elif(f.endswith('.fasta')):
with open(path, 'r') as fp: with open(path, 'r') as fp:
fasta_str = fp.read() fasta_str = fp.read()
input_seqs, _ = parsers.parse_fasta(fasta_str) input_seqs, _ = parsers.parse_fasta(fasta_str)
...@@ -63,6 +62,15 @@ def main(args): ...@@ -63,6 +62,15 @@ def main(args):
logging.warning(msg) logging.warning(msg)
input_sequence = input_seqs[0] input_sequence = input_seqs[0]
seqs[file_id] = input_sequence seqs[file_id] = input_sequence
elif(f.endswith('.core')):
with open(path, 'r') as fp:
core_str = fp.read()
core_prot = protein.from_proteinnet_string(core_str)
seq = ''.join([
residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype))
])
seqs[file_id] = seq
else: else:
continue continue
...@@ -74,17 +82,15 @@ def main(args): ...@@ -74,17 +82,15 @@ def main(args):
os.makedirs(alignment_dir) os.makedirs(alignment_dir)
if(not is_fasta): fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
fd, fasta_path = tempfile.mkstemp(suffix=".fasta") with os.fdopen(fd, 'w') as fp:
with os.fdopen(fd, 'w') as fp: fp.write(f'>query\n{seq}')
fp.write(f'>query\n{seq}')
alignment_runner.run( alignment_runner.run(
f if is_fasta else fasta_path, alignment_dir fasta_path, alignment_dir
) )
if(not is_fasta): os.remove(fasta_path)
os.remove(fasta_path)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -114,8 +114,6 @@ def main(args): ...@@ -114,8 +114,6 @@ def main(args):
not os.path.splitext(fname)[-1] == ".a3m"): not os.path.splitext(fname)[-1] == ".a3m"):
continue continue
print(fpath)
with open(fpath, "r") as fp: with open(fpath, "r") as fp:
a3m = fp.read() a3m = fp.read()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment