Commit 294c04ad authored by Brian Loyal
Browse files

Updated code to fix bad FASTA inputs

parent 7cf3125e
......@@ -29,6 +29,7 @@ import random
import sys
import time
import torch
import re
from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
......@@ -163,6 +164,7 @@ def prep_output(out, batch, feature_dict, feature_processor, args):
def parse_fasta(data):
data = re.sub('>$', '', data, flags=re.M)
lines = [
l.replace('\n', '')
for prot in data.split('>') for l in prot.strip().split('\n', 1)
......@@ -272,6 +274,8 @@ def load_models_from_command_line(args, config):
"be specified."
)
def list_files_with_extensions(dir, extensions):
    """Return the names of regular files in ``dir`` ending with ``extensions``.

    Args:
        dir: Path of the directory to scan. Only its immediate entries are
            examined; sub-directories are not descended into.
        extensions: A single suffix string (e.g. ``".fasta"``) or any iterable
            of suffix strings (tuple, list, set, ...).

    Returns:
        A list of bare entry names (not joined with ``dir``), in whatever
        order ``os.listdir`` yields them.
    """
    # str.endswith only accepts a str or a tuple of str; coerce any other
    # iterable so that passing a list does not raise TypeError.
    if not isinstance(extensions, str):
        extensions = tuple(extensions)
    return [
        f
        for f in os.listdir(dir)
        # os.listdir also returns directories; keep only real files so that
        # callers can open every returned name.
        if f.endswith(extensions) and os.path.isfile(os.path.join(dir, f))
    ]
def main(args):
# Create the output directory
......@@ -307,19 +311,13 @@ def main(args):
prediction_dir = os.path.join(args.output_dir, "predictions")
os.makedirs(prediction_dir, exist_ok=True)
for fasta_file in [f for f in os.listdir(args.fasta_dir) if os.path.splitext(f)[1] in [".fasta", ".fa"]]:
for fasta_file in list_files_with_extensions(args.fasta_dir, (".fasta", ".fa")):
# Gather input sequences
with open(os.path.join(args.fasta_dir, fasta_file), "r") as fp:
data = fp.read()
tags = []
seqs = []
for prot in data.split(">")[1::]:
lines = prot.strip().split("\n")
tags.append(lines[0].strip().split()[0])
seqs.append("".join(lines[1:]))
assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tags, seqs = parse_fasta(data)
# assert len(tags) == len(set(tags)), "All FASTA tags must be unique"
tag = '-'.join(tags)
output_name = f'{tag}_{args.config_preset}'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment