Commit 86b990d6 authored by Sachin Kadyan

Separate out input parsing code in `EmbeddingGenerator`

Bugfix: Corrected paths for just-in-time embedding generation
parent 8185c307
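For orientation, here is a minimal sketch of the caller-side flow after this change, assembled from the hunks below. The wrapper function name `generate_embeddings` is hypothetical, and the `EmbeddingGenerator` constructor arguments are elided:

import os

def generate_embeddings(embedding_generator, fasta_dir, output_dir):
    # Step 1 (new parse_sequences method): concatenate the FASTA inputs
    # into output_dir/temp.fasta and return that path.
    temp_fasta_file = embedding_generator.parse_sequences(fasta_dir, output_dir)
    # Step 2: run ESM inference on the parsed file; run() no longer
    # creates or deletes the temporary file itself.
    embedding_generator.run(temp_fasta_file, output_dir)
    # Cleanup now happens at the call site rather than inside run().
    os.remove(temp_fasta_file)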
@@ -84,7 +84,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
             no_cpus=args.cpus,
         )
         embedding_generator = EmbeddingGenerator()
-        embedding_generator.run(args.fasta_dir, local_alignment_dir)
+        embedding_generator.run(tmp_fasta_path, alignment_dir)
     else:
         alignment_runner = data_pipeline.AlignmentRunner(
             jackhmmer_binary_path=args.jackhmmer_binary_path,
@@ -79,12 +79,8 @@ class EmbeddingGenerator:
         self.model, self.alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
         if torch.cuda.is_available() and not self.nogpu:
             self.model = self.model.to(device="cuda")

-    def run(
-        self,
-        fasta_dir,
-        output_dir,
-    ):
+    def parse_sequences(self, fasta_dir, output_dir):
         labels = []
         seqs = []
@@ -107,8 +103,15 @@ class EmbeddingGenerator:
         temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
         with open(temp_fasta_file, 'w') as outfile:
             outfile.writelines(lines)
-        dataset = SequenceDataset.from_file(temp_fasta_file)
+        return temp_fasta_file
+
+    def run(
+        self,
+        fasta_file,
+        output_dir,
+    ):
+        dataset = SequenceDataset.from_file(fasta_file)
         batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
         data_loader = torch.utils.data.DataLoader(
             dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
@@ -143,7 +146,6 @@ class EmbeddingGenerator:
                 os.path.join(output_dir, label, label+".pt")
             )
-        os.remove(temp_fasta_file)

 def main(args):
@@ -154,10 +156,15 @@ def main(args):
         args.use_local_esm,
         args.nogpu)
     logging.info("Loading the sequences and running the inference...")
-    embedding_generator.run(
+    temp_fasta_file = embedding_generator.parse_sequences(
         args.fasta_dir,
         args.output_dir
     )
+    embedding_generator.run(
+        temp_fasta_file,
+        args.output_dir
+    )
+    os.remove(temp_fasta_file)
     logging.info("Completed.")
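Taken together, the hunks leave `EmbeddingGenerator` with roughly the following shape. This is a skeleton reconstructed from the diff above, not the full implementation; method bodies are abbreviated with `...`:

import os

class EmbeddingGenerator:
    def parse_sequences(self, fasta_dir, output_dir):
        # Collect labels and sequences from the FASTA files in fasta_dir
        # (body elided) and write them to a single temporary file.
        temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
        ...
        return temp_fasta_file

    def run(self, fasta_file, output_dir):
        # Build a SequenceDataset from fasta_file, batch it, run the ESM
        # model, and save one embedding per sequence to
        # output_dir/<label>/<label>.pt (body elided).
        ...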