Commit 5aa54958 authored by Christina Floristean

Merge branch 'main' into deepspeed-evo-attention

parents f545323c 099769d2
# Some functions borrowed from [ESM](https://www.github.com/facebookresearch/esm)
import argparse
import logging
import os
import torch
from openfold.data import parsers
logging.basicConfig(level=logging.INFO)

class SequenceDataset(object):
    def __init__(self, labels, sequences) -> None:
        self.labels = labels
        self.sequences = sequences

    @classmethod
    def from_file(cls, fasta_file):
        labels, sequences = [], []
        with open(fasta_file, "r") as infile:
            fasta_str = infile.read()
            sequences, labels = parsers.parse_fasta(fasta_str)

        assert len(set(labels)) == len(labels), \
            "Sequence labels need to be unique. Duplicates found!"

        return cls(labels, sequences)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.labels[idx], self.sequences[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq):
        # Group sequence indices into batches whose padded token count
        # (longest sequence in the batch * batch size) stays under toks_per_batch
        sizes = [(len(s), i) for i, s in enumerate(self.sequences)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches

class EmbeddingGenerator:
    """Generates the ESM-1b embeddings for the single-sequence model"""
    def __init__(self,
                 toks_per_batch: int = 4096,
                 truncate: bool = True,
                 use_local_esm: str = None,
                 nogpu: bool = False,
                 ):
        self.toks_per_batch = toks_per_batch
        self.truncate = truncate
        self.use_local_esm = use_local_esm
        self.nogpu = nogpu

        # Load the ESM-1b model used to generate embeddings in bulk
        if self.use_local_esm:
            self.model, self.alphabet = torch.hub.load(
                self.use_local_esm, "esm1b_t33_650M_UR50S", source="local"
            )
        else:
            self.model, self.alphabet = torch.hub.load(
                "facebookresearch/esm:main", "esm1b_t33_650M_UR50S"
            )

        if torch.cuda.is_available() and not self.nogpu:
            self.model = self.model.to(device="cuda")

    def parse_sequences(self, fasta_dir, output_dir):
        labels = []
        seqs = []

        # Collect the first sequence of every FASTA file into a single bulk file
        for f in os.listdir(fasta_dir):
            f_name, ext = os.path.splitext(f)
            if ext not in ('.fasta', '.fa'):
                logging.warning(f"Ignoring non-FASTA file: {f}")
                continue
            with open(os.path.join(fasta_dir, f), 'r') as infile:
                seq = infile.readlines()[1].strip()
            labels.append(f_name)
            seqs.append(seq)

        lines = []
        for label, seq in zip(labels, seqs):
            lines.append(f'>{label}\n')
            lines.append(f'{seq}\n')

        os.makedirs(output_dir, exist_ok=True)
        temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
        with open(temp_fasta_file, 'w') as outfile:
            outfile.writelines(lines)

        return temp_fasta_file
    def run(
        self,
        fasta_file,
        output_dir,
    ):
        dataset = SequenceDataset.from_file(fasta_file)
        batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
        data_loader = torch.utils.data.DataLoader(
            dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
        )
        logging.info("Loaded all sequences")

        repr_layers = [33]
        with torch.no_grad():
            for batch_idx, (labels, strs, toks) in enumerate(data_loader):
                logging.info(
                    f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
                )
                if torch.cuda.is_available() and not self.nogpu:
                    toks = toks.to(device="cuda", non_blocking=True)
                if self.truncate:
                    # Truncate along the sequence dimension: ESM-1b accepts at most 1022 residues
                    toks = toks[:, :1022]

                out = self.model(toks, repr_layers=repr_layers, return_contacts=False)
                representations = {
                    33: out["representations"][33].to(device="cpu")
                }

                # Write one embedding file per sequence, dropping the BOS token and padding
                for i, label in enumerate(labels):
                    os.makedirs(os.path.join(output_dir, label), exist_ok=True)
                    result = {"label": label}
                    result["representations"] = {
                        33: representations[33][i, 1: len(strs[i]) + 1].clone()
                    }
                    torch.save(
                        result,
                        os.path.join(output_dir, label, label + ".pt")
                    )

def main(args):
    logging.info("Loading the model...")
    embedding_generator = EmbeddingGenerator(
        args.toks_per_batch,
        args.truncate,
        args.use_local_esm,
        args.nogpu
    )

    logging.info("Loading the sequences and running the inference...")
    temp_fasta_file = embedding_generator.parse_sequences(
        args.fasta_dir,
        args.output_dir
    )
    embedding_generator.run(
        temp_fasta_file,
        args.output_dir
    )
    os.remove(temp_fasta_file)
    logging.info("Completed.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "fasta_dir", type=str,
        help="Path to directory containing FASTA files."
    )
    parser.add_argument(
        "output_dir", type=str,
        help="Directory in which to output embeddings"
    )
    parser.add_argument(
        "--toks_per_batch", type=int, default=4096,
        help="Maximum tokens in a batch"
    )
    parser.add_argument(
        "--truncate", action="store_true", default=True,
        help="Truncate sequences longer than 1022 (ESM restriction). Default: True"
    )
    parser.add_argument(
        "--use_local_esm", type=str, default=None,
        help="Use a local ESM repository instead of cloning from GitHub"
    )
    parser.add_argument(
        "--nogpu", action="store_true",
        help="Do not use GPU"
    )

    args = parser.parse_args()
    main(args)
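
For reference, each file written by run() above is a small PyTorch checkpoint keyed by the sequence label. A minimal sketch of loading one of these embeddings downstream; the output directory and label below are placeholders, not part of the commit:

import torch

# Hypothetical example: run() writes <output_dir>/<label>/<label>.pt containing
# {"label": ..., "representations": {33: tensor}} for each input sequence.
result = torch.load("embeddings/T1024/T1024.pt", map_location="cpu")
esm_repr = result["representations"][33]  # per-residue ESM-1b layer-33 features
print(result["label"], esm_repr.shape)
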
@@ -17,6 +17,7 @@ import numpy as np
 import unittest
 from openfold.model.embedders import (
     InputEmbedder,
+    PreembeddingEmbedder,
     RecyclingEmbedder,
     TemplateAngleEmbedder,
     TemplatePairEmbedder,
@@ -46,6 +47,28 @@ class TestInputEmbedder(unittest.TestCase):
         self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
 
+
+class TestPreembeddingEmbedder(unittest.TestCase):
+    def test_shape(self):
+        tf_dim = 22
+        preembedding_dim = 1280
+        c_z = 4
+        c_m = 6
+        relpos_k = 10
+        batch_size = 4
+        num_res = 20
+
+        tf = torch.rand((batch_size, num_res, tf_dim))
+        ri = torch.rand((batch_size, num_res))
+        preemb = torch.rand((batch_size, num_res, preembedding_dim))
+
+        pe = PreembeddingEmbedder(tf_dim, preembedding_dim, c_z, c_m, relpos_k)
+        seq_emb, pair_emb = pe(tf, ri, preemb)
+
+        self.assertTrue(seq_emb.shape == (batch_size, 1, num_res, c_m))
+        self.assertTrue(pair_emb.shape == (batch_size, num_res, num_res, c_z))
+
+
 class TestRecyclingEmbedder(unittest.TestCase):
     def test_shape(self):
         batch_size = 2
...
@@ -66,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
             msa_dropout,
             pair_stack_dropout,
             blocks_per_ckpt=None,
+            no_column_attention=False,
             inf=inf,
             eps=eps,
         ).eval()
@@ -86,6 +87,62 @@ class TestEvoformerStack(unittest.TestCase):
         self.assertTrue(z.shape == shape_z_before)
         self.assertTrue(s.shape == (batch_size, n_res, c_s))
 
+    def test_shape_without_column_attention(self):
+        batch_size = consts.batch_size
+        n_seq = consts.n_seq
+        n_res = consts.n_res
+        c_m = consts.c_m
+        c_z = consts.c_z
+        c_hidden_msa_att = 12
+        c_hidden_opm = 17
+        c_hidden_mul = 19
+        c_hidden_pair_att = 14
+        c_s = consts.c_s
+        no_heads_msa = 3
+        no_heads_pair = 7
+        no_blocks = 2
+        transition_n = 2
+        msa_dropout = 0.15
+        pair_stack_dropout = 0.25
+        inf = 1e9
+        eps = 1e-10
+
+        es = EvoformerStack(
+            c_m,
+            c_z,
+            c_hidden_msa_att,
+            c_hidden_opm,
+            c_hidden_mul,
+            c_hidden_pair_att,
+            c_s,
+            no_heads_msa,
+            no_heads_pair,
+            no_blocks,
+            transition_n,
+            msa_dropout,
+            pair_stack_dropout,
+            blocks_per_ckpt=None,
+            no_column_attention=True,
+            inf=inf,
+            eps=eps,
+        ).eval()
+
+        m_init = torch.rand((batch_size, n_seq, n_res, c_m))
+        z_init = torch.rand((batch_size, n_res, n_res, c_z))
+        msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
+        pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
+
+        shape_m_before = m_init.shape
+        shape_z_before = z_init.shape
+
+        m, z, s = es(
+            m_init, z_init, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
+        )
+
+        self.assertTrue(m.shape == shape_m_before)
+        self.assertTrue(z.shape == shape_z_before)
+        self.assertTrue(s.shape == (batch_size, n_res, c_s))
+
     @compare_utils.skip_unless_alphafold_installed()
     def test_compare(self):
         def run_ei(activations, masks):
@@ -206,7 +263,7 @@ class TestExtraMSAStack(unittest.TestCase):
                 n_res,
             ),
             device="cuda",
-        )
+        ).float()
         pair_mask = torch.randint(
             0,
             2,
@@ -216,7 +273,7 @@ class TestExtraMSAStack(unittest.TestCase):
                 n_res,
             ),
             device="cuda",
-        )
+        ).float()
         shape_z_before = z.shape
...
@@ -47,33 +47,73 @@ class TestModel(unittest.TestCase):
         c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
                                                         # deepspeed for this test
-        model = AlphaFold(c)
+        model = AlphaFold(c).cuda()
         model.eval()
 
         batch = {}
-        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
+        tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
         batch["target_feat"] = nn.functional.one_hot(
             tf, c.model.input_embedder.tf_dim
-        ).float()
+        ).float().cuda()
-        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
-        batch["residue_index"] = torch.arange(n_res)
+        batch["residue_index"] = torch.arange(n_res).cuda()
-        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
+        batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
         t_feats = random_template_feats(n_templ, n_res)
-        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
         extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
-        batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
+        batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
         batch["msa_mask"] = torch.randint(
             low=0, high=2, size=(n_seq, n_res)
-        ).float()
+        ).float().cuda()
+        batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
+        batch.update(data_transforms.make_atom14_masks(batch))
+        batch["no_recycling_iters"] = torch.tensor(2.).cuda()
+
+        add_recycling_dims = lambda t: (
+            t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
+        )
+        batch = tensor_tree_map(add_recycling_dims, batch)
+
+        with torch.no_grad():
+            out = model(batch)
+
+    def test_dry_run_seqemb_mode(self):
+        n_seq = 1
+        n_templ = consts.n_templ
+        n_res = consts.n_res
+        msa_dim = 49
+
+        c = model_config("seq_model_esm1b")
+        c.model.evoformer_stack.no_blocks = 2
+        c.model.evoformer_stack.blocks_per_ckpt = None
+
+        model = AlphaFold(c)
+        model.to(torch.device('cuda'))
+        model.eval()
+
+        batch = {}
+        tf = torch.randint(c.model.preembedding_embedder.tf_dim - 1, size=(n_res,))
+        batch["target_feat"] = nn.functional.one_hot(tf, c.model.preembedding_embedder.tf_dim).float()
+        batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
+        batch["residue_index"] = torch.arange(n_res)
+        batch["msa_feat"] = torch.rand((n_seq, n_res, msa_dim))
+        batch["seq_embedding"] = torch.rand((n_res, c.model.preembedding_embedder.preembedding_dim))
+        t_feats = random_template_feats(n_templ, n_res)
+        batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
         batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
         batch.update(data_transforms.make_atom14_masks(batch))
+        batch["msa_mask"] = torch.randint(low=0, high=2, size=(n_seq, n_res)).float()
         batch["no_recycling_iters"] = torch.tensor(2.)
 
         add_recycling_dims = lambda t: (
             t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
         )
         batch = tensor_tree_map(add_recycling_dims, batch)
+
+        to_cuda_device = lambda t: t.to(torch.device("cuda"))
+        batch = tensor_tree_map(to_cuda_device, batch)
 
         with torch.no_grad():
             out = model(batch)
...
@@ -416,6 +416,10 @@ if __name__ == "__main__":
         help='''Cutoff for all templates. In training mode, templates are also
             filtered by the release date of the target'''
     )
+    parser.add_argument(
+        "--use_single_seq_mode", type=str, default=False,
+        help="Use single sequence embeddings instead of MSAs."
+    )
     parser.add_argument(
         "--distillation_data_dir", type=str, default=None,
         help="Directory containing training PDB files"
...
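
A note on how the new option parses: argparse stores the value as a raw string when type=str is used, so --use_single_seq_mode evaluates as truthy whenever any value is supplied on the command line, including "false". A small standalone sketch of that behavior (not code from the repository):

import argparse

# Standalone sketch mirroring the flag definition added above
parser = argparse.ArgumentParser()
parser.add_argument(
    "--use_single_seq_mode", type=str, default=False,
    help="Use single sequence embeddings instead of MSAs."
)

print(bool(parser.parse_args([]).use_single_seq_mode))                                   # False (default)
print(bool(parser.parse_args(["--use_single_seq_mode", "true"]).use_single_seq_mode))    # True
print(bool(parser.parse_args(["--use_single_seq_mode", "false"]).use_single_seq_mode))   # True: any non-empty string is truthy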