Commit 8185c307 authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Just-in-time embedding generation for the SoloSeq model

parent 4c8e3764
...@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import ( ...@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import (
pad_feature_dict_seq, pad_feature_dict_seq,
trace_model_, trace_model_,
) )
from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args from scripts.utils import add_data_args
...@@ -82,6 +83,8 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -82,6 +83,8 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
pdb70_database_path=args.pdb70_database_path, pdb70_database_path=args.pdb70_database_path,
no_cpus=args.cpus, no_cpus=args.cpus,
) )
embedding_generator = EmbeddingGenerator()
embedding_generator.run(args.fasta_dir, local_alignment_dir)
else: else:
alignment_runner = data_pipeline.AlignmentRunner( alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path, jackhmmer_binary_path=args.jackhmmer_binary_path,
......
...@@ -58,74 +58,106 @@ class SequenceDataset(object): ...@@ -58,74 +58,106 @@ class SequenceDataset(object):
_flush_current_buf() _flush_current_buf()
return batches return batches
class EmbeddingGenerator:
    """Generates ESM-1b residue embeddings for the single-sequence (SoloSeq) model.

    The ``esm1b_t33_650M_UR50S`` model is loaded once at construction time via
    ``torch.hub``; each :meth:`run` call then embeds every FASTA file in a
    directory and saves one ``<label>/<label>.pt`` file per sequence under the
    output directory.
    """

    def __init__(self,
                 toks_per_batch: int = 4096,
                 truncate: bool = True,
                 use_local_esm: str = None,
                 nogpu: bool = False,
                 ):
        """
        Args:
            toks_per_batch:
                Approximate token budget per inference batch.
            truncate:
                If True, truncate each sequence to the ESM-1b maximum of
                1022 tokens (longer inputs are not supported by the model).
            use_local_esm:
                Optional path to a local torch.hub checkout of the ESM
                repository; when None, the model is fetched from
                ``facebookresearch/esm``.
            nogpu:
                If True, run on CPU even when CUDA is available.
        """
        self.toks_per_batch = toks_per_batch
        self.truncate = truncate
        self.use_local_esm = use_local_esm
        self.nogpu = nogpu

        # Load the model once so repeated run() calls do not re-fetch it.
        if self.use_local_esm:
            self.model, self.alphabet = torch.hub.load(
                self.use_local_esm, "esm1b_t33_650M_UR50S", source='local'
            )
        else:
            self.model, self.alphabet = torch.hub.load(
                "facebookresearch/esm:main", "esm1b_t33_650M_UR50S"
            )
        if torch.cuda.is_available() and not self.nogpu:
            self.model = self.model.to(device="cuda")

    def run(
        self,
        fasta_dir,
        output_dir,
    ):
        """Embed every ``.fasta``/``.fa`` file in ``fasta_dir``.

        Writes one ``<output_dir>/<label>/<label>.pt`` file per sequence,
        each containing ``{"label": ..., "representations": {33: tensor}}``
        with the final-layer (layer 33) per-residue embeddings on CPU.
        A temporary bulk FASTA file is created in ``output_dir`` and removed
        when inference finishes.
        """
        labels = []
        seqs = []

        # Collect (label, sequence) pairs from every FASTA file.
        for f in os.listdir(fasta_dir):
            f_name, ext = os.path.splitext(f)
            if ext != '.fasta' and ext != '.fa':
                logging.warning(f"Ignoring non-FASTA file: {f}")
                continue
            with open(os.path.join(fasta_dir, f), 'r') as infile:
                # Join all post-header lines so multi-line FASTA records are
                # read in full (previously only the first sequence line was
                # used). Single-line records behave exactly as before.
                seq = ''.join(
                    line.strip()
                    for line in infile.readlines()[1:]
                    if not line.startswith('>')
                )
            labels.append(f_name)
            seqs.append(seq)

        # Write a single temporary bulk FASTA file for batched inference.
        lines = []
        for label, seq in zip(labels, seqs):
            # append, not +=: list += str extends character-by-character and
            # only produced correct output by accident through writelines().
            lines.append(f'>{label}\n')
            lines.append(f'{seq}\n')
        os.makedirs(output_dir, exist_ok=True)
        temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
        with open(temp_fasta_file, 'w') as outfile:
            outfile.writelines(lines)

        # Batch the sequences under the token budget and run inference.
        dataset = SequenceDataset.from_file(temp_fasta_file)
        batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
        data_loader = torch.utils.data.DataLoader(
            dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
        )
        logging.info("Loaded all sequences")
        repr_layers = [33]  # final transformer layer of ESM-1b

        with torch.no_grad():
            # batch_labels (not labels) to avoid shadowing the outer list.
            for batch_idx, (batch_labels, strs, toks) in enumerate(data_loader):
                logging.info(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
                if torch.cuda.is_available() and not self.nogpu:
                    toks = toks.to(device="cuda", non_blocking=True)

                if self.truncate:
                    # Truncate along the sequence dimension, matching the ESM
                    # reference implementation. The previous toks[:1022]
                    # sliced the *batch* dimension, silently dropping
                    # sequences instead of shortening them.
                    toks = toks[:, :1022]

                out = self.model(toks, repr_layers=repr_layers, return_contacts=False)

                # Move to CPU before saving so the .pt files are device-agnostic.
                representations = {
                    33: out["representations"][33].to(device="cpu")
                }

                for i, label in enumerate(batch_labels):
                    os.makedirs(os.path.join(output_dir, label), exist_ok=True)
                    result = {"label": label}
                    # Keep positions 1..len(seq): drop the BOS token at index
                    # 0 and any EOS/padding past the true sequence length.
                    result["representations"] = {
                        33: representations[33][i, 1: len(strs[i]) + 1].clone()
                    }
                    torch.save(
                        result,
                        os.path.join(output_dir, label, label + ".pt")
                    )
        os.remove(temp_fasta_file)
def main(args):
    """CLI entry point: build an EmbeddingGenerator from the parsed arguments
    and embed every FASTA file under ``args.fasta_dir`` into
    ``args.output_dir``."""
    logging.info("Loading the model...")
    generator = EmbeddingGenerator(
        toks_per_batch=args.toks_per_batch,
        truncate=args.truncate,
        use_local_esm=args.use_local_esm,
        nogpu=args.nogpu,
    )

    logging.info("Loading the sequences and running the inference...")
    generator.run(args.fasta_dir, args.output_dir)

    logging.info("Completed.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment