"...models/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "52310ad9e6696473c90a659bfa879fde0c51898c"
Commit 02fc4376 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Beef up alignment script

parent 53bb9c10
import argparse import argparse
from functools import partial
import json
import logging import logging
import os import os
import threading
from multiprocessing import cpu_count
from shutil import copyfile
import tempfile import tempfile
import openfold.data.mmcif_parsing as mmcif_parsing import openfold.data.mmcif_parsing as mmcif_parsing
...@@ -10,30 +15,57 @@ from openfold.np import protein, residue_constants ...@@ -10,30 +15,57 @@ from openfold.np import protein, residue_constants
from utils import add_data_args from utils import add_data_args
#python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ data/uniref90/uniref90.fasta data/mgnify/mgy_clusters_2018_12.fa data/pdb70/pdb70 data/pdb_mmcif/mmcif_files/ data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt --cpus 16 --jackhmmer_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/jackhmmer --hhblits_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhblits --hhsearch_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/hhsearch --kalign_binary_path /home/u00u98too4mkqFBu8M357/openfold/lib/conda/envs/openfold_venv/bin/kalign
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.WARNING)
def run_seq_group_alignments(seq_groups, alignment_runner, args):
    """Compute alignments once per unique sequence and fan them out.

    Each element of seq_groups is a (seq, names) pair, where names lists
    every chain ID sharing that sequence. Alignments are computed for
    names[0] only, then copied into a directory for each remaining name.

    Args:
        seq_groups: Iterable of (sequence, [chain_name, ...]) tuples.
        alignment_runner: Object exposing .run(fasta_path, output_dir)
            (presumably openfold's AlignmentRunner — confirm at call site).
        args: Namespace providing output_dir.
    """
    dirs = set(os.listdir(args.output_dir))
    for seq, names in seq_groups:
        first_name = names[0]
        alignment_dir = os.path.join(args.output_dir, first_name)

        try:
            os.makedirs(alignment_dir)
        except Exception as e:
            logging.warning(f"Failed to create directory for {first_name} with exception {e}...")
            continue

        # Write the query sequence to a throwaway FASTA file for the runner
        fd, fasta_path = tempfile.mkstemp(suffix=".fasta")
        with os.fdopen(fd, 'w') as fp:
            fp.write(f'>query\n{seq}')

        try:
            alignment_runner.run(
                fasta_path, alignment_dir
            )
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate and the job can be cancelled cleanly
        except Exception:
            logging.warning(f"Failed to run alignments for {first_name}. Skipping...")
            os.remove(fasta_path)
            # Clear partial outputs first: a bare os.rmdir() raises OSError
            # on a non-empty directory, which would kill this whole worker
            for leftover in os.listdir(alignment_dir):
                os.remove(os.path.join(alignment_dir, leftover))
            os.rmdir(alignment_dir)
            continue

        os.remove(fasta_path)

        # Duplicate the finished alignments for every other chain in the group
        for name in names[1:]:
            if(name in dirs):
                logging.warning(
                    f'{name} has already been processed. Skipping...'
                )
                continue

            cp_dir = os.path.join(args.output_dir, name)
            os.makedirs(cp_dir)

            for f in os.listdir(alignment_dir):
                copyfile(os.path.join(alignment_dir, f), os.path.join(cp_dir, f))
def parse_and_align(files, alignment_runner, args):
for f in files:
path = os.path.join(args.input_dir, f) path = os.path.join(args.input_dir, f)
file_id = os.path.splitext(f)[0] file_id = os.path.splitext(f)[0]
seqs = {} seq_group_dict = {}
if(f.endswith('.cif')): if(f.endswith('.cif')):
with open(path, 'r') as fp: with open(path, 'r') as fp:
mmcif_str = fp.read() mmcif_str = fp.read()
...@@ -47,9 +79,10 @@ def main(args): ...@@ -47,9 +79,10 @@ def main(args):
else: else:
continue continue
mmcif = mmcif.mmcif_object mmcif = mmcif.mmcif_object
for k,v in mmcif.chain_to_seqres.items(): for chain_letter, seq in mmcif.chain_to_seqres.items():
chain_id = '_'.join([file_id, k]) chain_id = '_'.join([file_id, chain_letter])
seqs[chain_id] = v l = seq_group_dict.setdefault(seq, [])
l.append(chain_id)
elif(f.endswith('.fasta') or f.endswith('.fa')): elif(f.endswith('.fasta') or f.endswith('.fa')):
with open(path, 'r') as fp: with open(path, 'r') as fp:
fasta_str = fp.read() fasta_str = fp.read()
...@@ -61,7 +94,7 @@ def main(args): ...@@ -61,7 +94,7 @@ def main(args):
else: else:
logging.warning(msg) logging.warning(msg)
input_sequence = input_seqs[0] input_sequence = input_seqs[0]
seqs[file_id] = input_sequence seq_group_dict[input_sequence] = [file_id]
elif(f.endswith('.core')): elif(f.endswith('.core')):
with open(path, 'r') as fp: with open(path, 'r') as fp:
core_str = fp.read() core_str = fp.read()
...@@ -71,27 +104,114 @@ def main(args): ...@@ -71,27 +104,114 @@ def main(args):
residue_constants.restypes_with_x[aatype[i]] residue_constants.restypes_with_x[aatype[i]]
for i in range(len(aatype)) for i in range(len(aatype))
]) ])
seqs[file_id] = seq seq_group_dict[seq] = [file_id]
else: else:
continue continue
for name, seq in seqs.items(): seq_group_tuples = [(k,v) for k,v in seq_group_dict.items()]
alignment_dir = os.path.join(args.output_dir, name) run_seq_group_alignments(seq_group_tuples, alignment_runner, args)
if(os.path.isdir(alignment_dir)):
logging.info(f'{f} has already been processed. Skipping...')
continue
os.makedirs(alignment_dir)
def main(args):
    """Precompute MSAs/templates for every input, split across tasks/nodes.

    Builds an AlignmentRunner, filters out already-processed inputs using
    the optional mmCIF cache, shards the remaining work across SLURM nodes
    and local threads, and dispatches either the cached-sequence path
    (run_seq_group_alignments) or the parse-from-file path (parse_and_align).
    """
    # Build the alignment tool runner
    alignment_runner = AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        hhsearch_binary_path=args.hhsearch_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        pdb70_database_path=args.pdb70_database_path,
        use_small_bfd=args.bfd_database_path is None,
        no_cpus=args.cpus_per_task,
    )

    files = list(os.listdir(args.input_dir))

    # Names of chains that already have alignment output directories.
    # Computed once here because BOTH the filtering step and the cached-seqs
    # branch below need it (previously it was only defined inside the filter
    # branch, causing a NameError when --filter was disabled).
    dirs = set(os.listdir(args.output_dir))

    # Load the optional mmCIF metadata cache used for filtering/grouping
    if(args.mmcif_cache is not None):
        with open(args.mmcif_cache, "r") as fp:
            cache = json.load(fp)
    else:
        cache = None

    if(cache is not None and args.filter):
        def prot_is_done(f):
            # A protein counts as done only if every cached chain has an
            # output directory; unknown proteins are never "done"
            prot_id = os.path.splitext(f)[0]
            if(prot_id in cache):
                chain_ids = cache[prot_id]["chain_ids"]
                for c in chain_ids:
                    full_name = prot_id + "_" + c
                    if(full_name not in dirs):
                        return False
            else:
                return False
            return True

        files = [f for f in files if not prot_is_done(f)]

    def split_up_arglist(arglist):
        # Shard work: first across SLURM nodes (if any), then into
        # args.no_tasks round-robin slices for local threads
        if(os.environ.get("SLURM_JOB_NUM_NODES", 0)):
            num_nodes = int(os.environ["SLURM_JOB_NUM_NODES"])
            if(num_nodes > 1):
                node_id = int(os.environ["SLURM_NODEID"])
                logging.warning(f"Num nodes: {num_nodes}")
                logging.warning(f"Node ID: {node_id}")
                arglist = arglist[node_id::num_nodes]

        return [arglist[i::args.no_tasks] for i in range(args.no_tasks)]

    # The `and cache` guard prevents next(iter(...)) from raising
    # StopIteration on an empty cache
    if(cache is not None and cache and "seqs" in next(iter(cache.values()))):
        # Cache stores sequences: group chains by identical sequence so each
        # unique sequence is aligned only once
        seq_group_dict = {}
        for f in files:
            prot_id = os.path.splitext(f)[0]
            if(prot_id in cache):
                prot_cache = cache[prot_id]
                chains_seqs = zip(
                    prot_cache["chain_ids"], prot_cache["seqs"]
                )
                for chain, seq in chains_seqs:
                    chain_name = prot_id + "_" + chain
                    if(chain_name not in dirs):
                        l = seq_group_dict.setdefault(seq, [])
                        l.append(chain_name)

        func = partial(run_seq_group_alignments,
            alignment_runner=alignment_runner,
            args=args
        )

        seq_groups = [(k, v) for k, v in seq_group_dict.items()]
        # Sort them by group length so the tasks are approximately balanced
        seq_groups = sorted(seq_groups, key=lambda x: len(x[1]))
        task_arglist = [[a] for a in split_up_arglist(seq_groups)]
    else:
        # No sequence cache: each task parses its files itself
        func = partial(parse_and_align,
            alignment_runner=alignment_runner,
            args=args,
        )
        task_arglist = [[a] for a in split_up_arglist(files)]

    threads = []
    for i, task_args in enumerate(task_arglist):
        print(f"Started thread {i}...")
        t = threading.Thread(target=func, args=task_args)
        threads.append(t)
        t.start()

    for t in threads:
        t.join()
if __name__ == "__main__": if __name__ == "__main__":
...@@ -111,9 +231,19 @@ if __name__ == "__main__": ...@@ -111,9 +231,19 @@ if __name__ == "__main__":
help="Whether to crash on parsing errors" help="Whether to crash on parsing errors"
) )
parser.add_argument( parser.add_argument(
"--cpus", type=int, default=4, "--cpus_per_task", type=int, default=cpu_count(),
help="Number of CPUs to use" help="Number of CPUs to use"
) )
parser.add_argument(
"--mmcif_cache", type=str, default=None,
help="Path to mmCIF cache. Used to filter files to be parsed"
)
parser.add_argument(
"--no_tasks", type=int, default=1,
)
parser.add_argument(
"--filter", type=bool, default=True,
)
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment