Commit 8185c307 authored by Sachin Kadyan

Just-in-time embedding generation for the SoloSeq model

parent 4c8e3764
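A minimal usage sketch of the interface introduced in this commit, with illustrative paths and values that are assumptions, not part of the diff. The ESM-1b model is loaded once when EmbeddingGenerator is constructed, and run() writes one <tag>/<tag>.pt embedding file per FASTA sequence, which is what precompute_alignments now invokes at inference time:

from scripts.precompute_embeddings import EmbeddingGenerator

# The ESM-1b weights are fetched in __init__ via torch.hub, so a single instance
# can be reused across sequences at inference time (the "just-in-time" behavior
# named in the commit title).
embedding_generator = EmbeddingGenerator(toks_per_batch=4096, truncate=True)

# Assumed example paths: a directory of single-sequence FASTA files and the
# directory the per-sequence embedding files should be written into.
embedding_generator.run("example_fasta_dir", "example_alignment_dir")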
@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import (
pad_feature_dict_seq,
trace_model_,
)
+from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args
@@ -82,6 +83,8 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
pdb70_database_path=args.pdb70_database_path,
no_cpus=args.cpus,
)
+embedding_generator = EmbeddingGenerator()
+embedding_generator.run(args.fasta_dir, local_alignment_dir)
else:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,
......
@@ -58,17 +58,43 @@ class SequenceDataset(object):
_flush_current_buf()
return batches
-def main(args):
+class EmbeddingGenerator:
+"""Generates the ESM-1b embeddings for the single sequence model"""
+def __init__(self,
+toks_per_batch: int = 4096,
+truncate: bool = True,
+use_local_esm: str = None,
+nogpu: bool = False,
+):
+self.toks_per_batch = toks_per_batch
+self.truncate = truncate
+self.use_local_esm = use_local_esm
+self.nogpu = nogpu
+# Generate embeddings in bulk
+if self.use_local_esm:
+self.model, self.alphabet = torch.hub.load(self.use_local_esm, "esm1b_t33_650M_UR50S", source='local')
+else:
+self.model, self.alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
+if torch.cuda.is_available() and not self.nogpu:
+self.model = self.model.to(device="cuda")
+def run(
+self,
+fasta_dir,
+output_dir,
+):
labels = []
seqs = []
# Generate a single bulk file
-for f in os.listdir(args.fasta_dir):
+for f in os.listdir(fasta_dir):
f_name, ext = os.path.splitext(f)
if ext != '.fasta' and ext != '.fa':
logging.warning(f"Ignoring non-FASTA file: {f}")
continue
-with open(os.path.join(args.fasta_dir, f), 'r') as infile:
+with open(os.path.join(fasta_dir, f), 'r') as infile:
seq = infile.readlines()[1].strip()
labels.append(f_name)
seqs.append(seq)
@@ -77,22 +103,15 @@ def main(args):
for label, seq in zip(labels, seqs):
lines += f'>{label}\n'
lines += f'{seq}\n'
-os.makedirs(args.output_dir, exist_ok=True)
-temp_fasta_file = os.path.join(args.output_dir, 'temp.fasta')
+os.makedirs(output_dir, exist_ok=True)
+temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
with open(temp_fasta_file, 'w') as outfile:
outfile.writelines(lines)
-# Generate embeddings in bulk
-if args.use_local_esm:
-model, alphabet = torch.hub.load(args.use_local_esm, "esm1b_t33_650M_UR50S", source='local')
-else:
-model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
-if torch.cuda.is_available() and not args.nogpu:
-model = model.to(device="cuda")
dataset = SequenceDataset.from_file(temp_fasta_file)
-batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
+batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(
-dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
+dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
)
logging.info("Loaded all sequences")
repr_layers = [33]
@@ -100,21 +119,20 @@ def main(args):
with torch.no_grad():
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
logging.info(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
-if torch.cuda.is_available() and not args.nogpu:
+if torch.cuda.is_available() and not self.nogpu:
toks = toks.to(device="cuda", non_blocking=True)
-if args.truncate:
+if self.truncate:
toks = toks[:1022]
-out = model(toks, repr_layers=repr_layers, return_contacts=False)
+out = self.model(toks, repr_layers=repr_layers, return_contacts=False)
logits = out["logits"].to(device="cpu")
representations = {
33: out["representations"][33].to(device="cpu")
}
for i, label in enumerate(labels):
-os.makedirs(os.path.join(args.output_dir, label), exist_ok=True)
+os.makedirs(os.path.join(output_dir, label), exist_ok=True)
result = {"label": label}
result["representations"] = {
@@ -122,10 +140,24 @@ def main(args):
}
torch.save(
result,
-os.path.join(args.output_dir, label, label+".pt")
+os.path.join(output_dir, label, label+".pt")
)
os.remove(temp_fasta_file)
+def main(args):
+logging.info("Loading the model...")
+embedding_generator = EmbeddingGenerator(
+args.toks_per_batch,
+args.truncate,
+args.use_local_esm,
+args.nogpu)
+logging.info("Loading the sequences and running the inference...")
+embedding_generator.run(
+args.fasta_dir,
+args.output_dir
+)
+logging.info("Completed.")
......
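A short sketch of reading back one of the embedding files written by run() above, based on the result dict saved in the loop; the directory name and sequence tag are assumptions for illustration:

import os
import torch

output_dir = "example_alignment_dir"   # assumed: the output_dir passed to run()
label = "example_tag"                  # assumed: FASTA file name / sequence tag
result = torch.load(os.path.join(output_dir, label, label + ".pt"))
embedding = result["representations"][33]   # layer-33 ESM-1b representations saved above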