Commit 8185c307 authored by Sachin Kadyan

Just-in-time embedding generation for the SoloSeq model

parent 4c8e3764
run_pretrained_openfold.py

@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import (
     pad_feature_dict_seq,
     trace_model_,
 )
+from scripts.precompute_embeddings import EmbeddingGenerator
 from scripts.utils import add_data_args
@@ -82,6 +83,8 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
                 pdb70_database_path=args.pdb70_database_path,
                 no_cpus=args.cpus,
             )
+            embedding_generator = EmbeddingGenerator()
+            embedding_generator.run(args.fasta_dir, local_alignment_dir)
         else:
             alignment_runner = data_pipeline.AlignmentRunner(
                 jackhmmer_binary_path=args.jackhmmer_binary_path,
...
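Taken together with the new import, this hunk removes the need for a separate offline embedding pass when alignments are computed on the fly. A minimal sketch of the resulting just-in-time flow; the directory names are illustrative, not values from this commit:

```python
# Minimal sketch of the just-in-time path added above (directory names illustrative).
# Instead of requiring precomputed .pt files, the inference script now writes the
# ESM-1b embeddings into the run's alignment directory before featurization.
from scripts.precompute_embeddings import EmbeddingGenerator

fasta_dir = "examples/fastas"            # one sequence per FASTA file, as run() expects
local_alignment_dir = "out/alignments"   # run() saves embeddings as <tag>/<tag>.pt here

embedding_generator = EmbeddingGenerator()  # loads ESM-1b once, on GPU when available
embedding_generator.run(fasta_dir, local_alignment_dir)
```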
scripts/precompute_embeddings.py

@@ -58,17 +58,43 @@ class SequenceDataset(object):
         _flush_current_buf()
         return batches

-def main(args):
+
+class EmbeddingGenerator:
+    """Generates the ESM-1b embeddings for the single sequence model"""
+    def __init__(self,
+                 toks_per_batch: int = 4096,
+                 truncate: bool = True,
+                 use_local_esm: str = None,
+                 nogpu: bool = False,
+                 ):
+        self.toks_per_batch = toks_per_batch
+        self.truncate = truncate
+        self.use_local_esm = use_local_esm
+        self.nogpu = nogpu
+
+        # Load the ESM-1b model once, so repeated run() calls can reuse it
+        if self.use_local_esm:
+            self.model, self.alphabet = torch.hub.load(self.use_local_esm, "esm1b_t33_650M_UR50S", source='local')
+        else:
+            self.model, self.alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
+        if torch.cuda.is_available() and not self.nogpu:
+            self.model = self.model.to(device="cuda")
+
+    def run(
+        self,
+        fasta_dir,
+        output_dir,
+    ):
         labels = []
         seqs = []
         # Generate a single bulk file
-        for f in os.listdir(args.fasta_dir):
+        for f in os.listdir(fasta_dir):
             f_name, ext = os.path.splitext(f)
             if ext != '.fasta' and ext != '.fa':
                 logging.warning(f"Ignoring non-FASTA file: {f}")
                 continue
-            with open(os.path.join(args.fasta_dir, f), 'r') as infile:
+            with open(os.path.join(fasta_dir, f), 'r') as infile:
                 seq = infile.readlines()[1].strip()
             labels.append(f_name)
             seqs.append(seq)
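Since the constructor now owns the expensive steps (the torch.hub download/load and device placement), a single EmbeddingGenerator instance can serve repeated run() calls. A sketch of the non-default options; the local ESM path and values are assumptions for illustration, not values from this commit:

```python
# Sketch: constructing the generator with non-default options (values illustrative).
gen = EmbeddingGenerator(
    toks_per_batch=2048,        # tighter token budget per batch for smaller GPUs
    truncate=True,              # clip inputs to ESM-1b's 1022-token limit
    use_local_esm="/opt/esm",   # assumed local clone of facebookresearch/esm for torch.hub
    nogpu=True,                 # keep the model on CPU even if CUDA is available
)
gen.run("fastas/", "embeddings/")
```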
@@ -77,22 +103,15 @@ def main(args):
         for label, seq in zip(labels, seqs):
             lines += f'>{label}\n'
             lines += f'{seq}\n'
-        os.makedirs(args.output_dir, exist_ok=True)
-        temp_fasta_file = os.path.join(args.output_dir, 'temp.fasta')
+        os.makedirs(output_dir, exist_ok=True)
+        temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
         with open(temp_fasta_file, 'w') as outfile:
             outfile.writelines(lines)

-    # Generate embeddings in bulk
-    if args.use_local_esm:
-        model, alphabet = torch.hub.load(args.use_local_esm, "esm1b_t33_650M_UR50S", source='local')
-    else:
-        model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
-    if torch.cuda.is_available() and not args.nogpu:
-        model = model.to(device="cuda")
-
         dataset = SequenceDataset.from_file(temp_fasta_file)
-        batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
+        batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
         data_loader = torch.utils.data.DataLoader(
-            dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
+            dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
         )
         logging.info("Loaded all sequences")
         repr_layers = [33]
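The batching here follows the token-budget scheme from ESM's extraction script: get_batch_indices packs sequence indices into batches whose summed token count stays under toks_per_batch, with extra_toks_per_seq=1 leaving room for the start token prepended to each sequence. A small standalone sketch, assuming the bulk temp.fasta has already been written:

```python
# Sketch: token-budget batching in isolation (file path illustrative).
dataset = SequenceDataset.from_file("embeddings/temp.fasta")
batches = dataset.get_batch_indices(4096, extra_toks_per_seq=1)
# Each entry in `batches` is a list of dataset indices whose sequences fit the budget.
print(f"Packed sequences into {len(batches)} batches")
```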
@@ -100,21 +119,20 @@ def main(args):
         with torch.no_grad():
             for batch_idx, (labels, strs, toks) in enumerate(data_loader):
                 logging.info(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
-                if torch.cuda.is_available() and not args.nogpu:
+                if torch.cuda.is_available() and not self.nogpu:
                     toks = toks.to(device="cuda", non_blocking=True)
-                if args.truncate:
+                if self.truncate:
                     toks = toks[:, :1022]
-                out = model(toks, repr_layers=repr_layers, return_contacts=False)
-                logits = out["logits"].to(device="cpu")
+                out = self.model(toks, repr_layers=repr_layers, return_contacts=False)
                 representations = {
                     33: out["representations"][33].to(device="cpu")
                 }
                 for i, label in enumerate(labels):
-                    os.makedirs(os.path.join(args.output_dir, label), exist_ok=True)
+                    os.makedirs(os.path.join(output_dir, label), exist_ok=True)
                     result = {"label": label}
                     result["representations"] = {
@@ -122,10 +140,24 @@ def main(args):
                     }
                     torch.save(
                         result,
-                        os.path.join(args.output_dir, label, label+".pt")
+                        os.path.join(output_dir, label, label+".pt")
                     )
         os.remove(temp_fasta_file)

+
+def main(args):
+    logging.info("Loading the model...")
+    embedding_generator = EmbeddingGenerator(
+        args.toks_per_batch,
+        args.truncate,
+        args.use_local_esm,
+        args.nogpu)
+    logging.info("Loading the sequences and running the inference...")
+    embedding_generator.run(
+        args.fasta_dir,
+        args.output_dir
+    )
     logging.info("Completed.")
...
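After a run, each input tag gets its own subdirectory holding one .pt file with the layer-33 representations. A quick sanity check of the output; the tag name is illustrative, and 1280 is the embedding width of esm1b_t33_650M_UR50S:

```python
import torch

# Sketch: inspecting a generated embedding (tag "T1078" is illustrative).
emb = torch.load("embeddings/T1078/T1078.pt")
rep = emb["representations"][33]   # per-residue features from ESM-1b layer 33
print(emb["label"], rep.shape)     # trailing dimension should be 1280
```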