Commit 86b990d6 authored by Sachin Kadyan
Browse files

Separate out input parsing code in `EmbeddingGenerator`

Bugfix: Corrected paths for just-in-time embedding generation
parent 8185c307
......@@ -84,7 +84,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
no_cpus=args.cpus,
)
embedding_generator = EmbeddingGenerator()
embedding_generator.run(args.fasta_dir, local_alignment_dir)
embedding_generator.run(tmp_fasta_path, alignment_dir)
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
......
......@@ -79,12 +79,8 @@ class EmbeddingGenerator:
self.model, self.alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
if torch.cuda.is_available() and not self.nogpu:
self.model = self.model.to(device="cuda")
def run(
self,
fasta_dir,
output_dir,
):
def parse_sequences(self, fasta_dir, output_dir):
labels = []
seqs = []
......@@ -107,8 +103,15 @@ class EmbeddingGenerator:
temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
with open(temp_fasta_file, 'w') as outfile:
outfile.writelines(lines)
return temp_fasta_file
def run(
self,
fasta_file,
output_dir,
):
dataset = SequenceDataset.from_file(temp_fasta_file)
dataset = SequenceDataset.from_file(fasta_file)
batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
......@@ -143,7 +146,6 @@ class EmbeddingGenerator:
os.path.join(output_dir, label, label+".pt")
)
os.remove(temp_fasta_file)
def main(args):
......@@ -154,10 +156,15 @@ def main(args):
args.use_local_esm,
args.nogpu)
logging.info("Loading the sequences and running the inference...")
embedding_generator.run(
temp_fasta_file = embedding_generator.parse_sequences(
args.fasta_dir,
args.output_dir
)
embedding_generator.run(
temp_fasta_file,
args.output_dir
)
os.remove(temp_fasta_file)
logging.info("Completed.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment