Commit 29962990 authored by Sachin Kadyan

Added documentation for some sequence embedding model changes.

parent a83c6fcc
@@ -154,8 +154,10 @@ def model_config(
        c.loss.tm.weight = 0.1
    # SINGLE SEQUENCE EMBEDDING PRESETS
    elif name == "seqemb_initial_training":
        # Tell the data pipeline that we will use sequence embeddings instead of MSAs.
        c.data.seqemb_mode.enabled = True
        c.globals.seqemb_mode_enabled = True
        # In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
        c.model.extra_msa.enabled = False
        c.model.evoformer_stack.no_column_attention = True
    elif name == "seqemb_finetuning":
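As a quick illustration of how these presets are meant to be consumed (a sketch only; the import path is assumed to be the project's config module shown in the hunk header), selecting the seqemb preset should leave the config in exactly the state the lines above describe:

```python
# Illustrative sketch, not part of the commit. Assumes model_config() is
# importable from the project's config module.
from openfold.config import model_config

cfg = model_config("seqemb_initial_training")
assert cfg.data.seqemb_mode.enabled                   # data pipeline uses embeddings, not MSAs
assert cfg.globals.seqemb_mode_enabled                # global seqemb flag
assert not cfg.model.extra_msa.enabled                # ExtraMSAStack disabled
assert cfg.model.evoformer_stack.no_column_attention  # column attention disabled
```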
@@ -211,7 +213,11 @@ c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
# For seqemb mode, dimension of the per-residue sequence embedding passed to the model.
# In the current model this is the ESM-1b embedding dimension, i.e. 1280.
preemb_dim_size = mlc.FieldReference(1280, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
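The 1280 here matches the per-residue representation width of ESM-1b. For orientation only, a sketch of how such an embedding could be produced with the fair-esm package; this tooling is not part of the diff, and the exact model choice is an assumption based on the comment above:

```python
# Sketch: compute a [n_res, 1280] per-residue embedding with ESM-1b via fair-esm.
import torch
import esm  # pip install fair-esm

model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
model.eval()
batch_converter = alphabet.get_batch_converter()

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # example sequence
_, _, tokens = batch_converter([("query", sequence)])
with torch.no_grad():
    out = model(tokens, repr_layers=[33])
# Drop the BOS/EOS tokens to get one 1280-dim vector per residue.
emb = out["representations"][33][0, 1 : len(sequence) + 1]  # [n_res, 1280]
```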
@@ -322,13 +328,13 @@ config = mlc.ConfigDict(
                    "deletion_matrix",
                    "no_recycling_iters",
                ],
                "seqemb_features": [  # List of features to be generated in seqemb mode
                    "seq_embedding"
                ],
                "use_templates": templates_enabled,
                "use_template_torsion_angles": embed_template_torsion_angles,
            },
            "seqemb_mode": {  # Configuration for sequence embedding mode
                "enabled": False,  # If True, use seq emb instead of MSA
                "seqemb_config": {
                    "max_msa_clusters": 0,
@@ -400,7 +406,7 @@ config = mlc.ConfigDict(
        },
        # Recurring FieldReferences that can be changed globally here
        "globals": {
            "seqemb_mode_enabled": False,  # Global flag for enabling seq emb mode
            "blocks_per_ckpt": blocks_per_ckpt,
            "chunk_size": chunk_size,
            # Use Staats & Rabe's low-memory attention algorithm. Mutually
@@ -426,7 +432,7 @@ config = mlc.ConfigDict(
                "c_m": c_m,
                "relpos_k": 32,
            },
            "preembedding_embedder": {  # Used in sequence embedding mode
                "tf_dim": 22,
                "preembedding_dim": preemb_dim_size,
                "c_z": c_z,
...
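For intuition only, here is a rough sketch of what a module consuming the preembedding_embedder config above might look like. This is not the project's PreembeddingEmbedder; it just shows one plausible way to map target features (tf_dim) and the precomputed per-residue embedding (preembedding_dim) into an MSA-like representation (c_m) and a pair representation (c_z):

```python
import torch
import torch.nn as nn

# Rough, hypothetical sketch -- not the actual PreembeddingEmbedder.
class PreembeddingEmbedderSketch(nn.Module):
    def __init__(self, tf_dim: int, preembedding_dim: int, c_z: int, c_m: int):
        super().__init__()
        self.linear_tf_z_i = nn.Linear(tf_dim, c_z)              # target feats -> pair (i side)
        self.linear_tf_z_j = nn.Linear(tf_dim, c_z)              # target feats -> pair (j side)
        self.linear_preemb_m = nn.Linear(preembedding_dim, c_m)  # seq embedding -> single "MSA" row

    def forward(self, tf: torch.Tensor, preemb: torch.Tensor):
        # tf: [*, n_res, tf_dim], preemb: [*, n_res, preembedding_dim]
        m = self.linear_preemb_m(preemb).unsqueeze(-3)  # [*, 1, n_res, c_m]
        z = self.linear_tf_z_i(tf).unsqueeze(-2) + self.linear_tf_z_j(tf).unsqueeze(-3)
        return m, z  # one-row MSA representation and pair representation

# e.g. PreembeddingEmbedderSketch(tf_dim=22, preembedding_dim=1280, c_z=128, c_m=256)
```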
@@ -261,6 +261,7 @@ def make_msa_features(
    return features


# Generate 1-sequence MSA features having only the input sequence
def make_dummy_msa_feats(input_sequence):
    msas = [[input_sequence]]
    deletion_matrices = [[[0 for _ in input_sequence]]]
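These two lines build a one-sequence "MSA" and an all-zero deletion matrix; the rest of the function presumably packs them into the usual MSA features. As a self-contained illustration of the resulting shapes (the residue-to-index mapping and feature names below are assumptions for the example, not the pipeline's actual encoding):

```python
import numpy as np

# Hypothetical, standalone illustration of single-sequence MSA features.
AA_TO_ID = {aa: i for i, aa in enumerate("ARNDCQEGHILKMFPSTWYV")}  # assumed ordering

def dummy_msa_arrays(seq: str):
    msa = np.array([[AA_TO_ID.get(a, 20) for a in seq]], dtype=np.int32)  # [1, n_res]
    deletion_matrix = np.zeros_like(msa)                                  # [1, n_res], no deletions
    return {"msa": msa, "deletion_matrix": deletion_matrix}
```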
@@ -639,6 +640,7 @@ class DataPipeline:
        return msa_features

    # Load and process sequence embedding features
    def _process_seqemb_features(self,
        alignment_dir: str,
    ) -> Mapping[str, Any]:
@@ -648,6 +650,7 @@ class DataPipeline:
            ext = os.path.splitext(f)[-1]
            if (ext == ".pt"):
                # Load embedding file
                seqemb_data = torch.load(path)
                seqemb_features["seq_embedding"] = seqemb_data
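So the loader simply picks up a .pt file from the target's alignment directory and stores its contents under "seq_embedding". A hedged example of writing such a file (the path, file name, and tensor shape are placeholders; only the .pt extension matters to the code above):

```python
import os
import torch

alignment_dir = "alignments/my_target"   # hypothetical target directory
seq_len = 123                             # length of the query sequence
emb = torch.randn(seq_len, 1280)          # placeholder; 1280 assumes an ESM-1b embedding

os.makedirs(alignment_dir, exist_ok=True)
torch.save(emb, os.path.join(alignment_dir, "seq_emb.pt"))  # file name is arbitrary
```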
@@ -686,6 +689,7 @@ class DataPipeline:
        )

        sequence_embedding_features = {}
        # If using seqemb mode, generate dummy MSA features using just the sequence
        if seqemb_mode:
            msa_features = make_dummy_msa_feats(input_sequence)
            sequence_embedding_features = self._process_seqemb_features(alignment_dir)
@@ -732,6 +736,7 @@ class DataPipeline:
        )

        sequence_embedding_features = {}
        # If using seqemb mode, generate dummy MSA features using just the sequence
        if seqemb_mode:
            msa_features = make_dummy_msa_feats(input_sequence)
            sequence_embedding_features = self._process_seqemb_features(alignment_dir)
...
@@ -41,6 +41,7 @@ def np_to_tensor_dict(
        A dictionary of features mapping feature names to features. Only the given
        features are returned, all other ones are filtered out.
    """
    # torch generates warnings if feature is already a torch Tensor
    to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
    tensor_dict = {
        k: to_tensor(v) for k, v in np_example.items() if k in features
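The guard exists because torch.tensor() called on an existing Tensor copies the data and emits a UserWarning recommending clone().detach(); plain lists and NumPy arrays go through torch.tensor() without complaint. A minimal demonstration:

```python
import torch

t = torch.arange(3)
fresh = torch.tensor([0, 1, 2])  # fine: input is a plain Python list
quiet = t.clone().detach()       # warning-free path used above for existing Tensors
noisy = torch.tensor(t)          # still works, but emits a UserWarning about copy-construction
```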
@@ -62,6 +63,7 @@ def make_data_config(
    feature_names = cfg.common.unsupervised_features

    # Add seqemb related features if using seqemb mode.
    if cfg.seqemb_mode.enabled:
        feature_names += cfg.common.seqemb_features
...
@@ -337,6 +337,7 @@ class EvoformerBlock(nn.Module):
            inf=inf,
        )

        # Specifically, seqemb mode does not use column attention
        self.no_column_attention = no_column_attention
        if not self.no_column_attention:
            self.msa_att_col = MSAColumnAttention(
@@ -401,6 +402,7 @@ class EvoformerBlock(nn.Module):
            inplace=inplace_safe,
        )

        # Specifically, column attention is not used in seqemb mode.
        if not self.no_column_attention:
            m = add(m,
                self.msa_att_col(
@@ -637,6 +639,9 @@ class EvoformerStack(nn.Module):
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            no_column_attention:
                When True, does not use column attention. Required for running
                sequence embedding mode
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
...
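To make the gating concrete, here is a toy, self-contained version of the pattern used in the two Evoformer hunks above (stand-in attention modules, not the real MSAColumnAttention or EvoformerBlock): when no_column_attention is True the column-attention submodule is never constructed or called, which is exactly what a single-row "MSA" in seqemb mode needs.

```python
import torch
import torch.nn as nn

# Toy stand-in for the no_column_attention pattern; not the OpenFold EvoformerBlock.
class ToyMSABlock(nn.Module):
    def __init__(self, c_m: int, no_column_attention: bool):
        super().__init__()
        self.row_att = nn.MultiheadAttention(c_m, num_heads=4, batch_first=True)
        self.no_column_attention = no_column_attention
        if not self.no_column_attention:
            # Only built when column attention is actually used.
            self.col_att = nn.MultiheadAttention(c_m, num_heads=4, batch_first=True)

    def forward(self, m: torch.Tensor) -> torch.Tensor:
        # m: [n_seq, n_res, c_m]; attend across residues within each sequence...
        m = m + self.row_att(m, m, m)[0]
        # ...and across sequences at each residue only if enabled. With a single
        # sequence (seqemb mode) there is nothing for column attention to mix.
        if not self.no_column_attention:
            mt = m.transpose(0, 1)  # [n_res, n_seq, c_m]
            m = m + self.col_att(mt, mt, mt)[0].transpose(0, 1)
        return m

# e.g. ToyMSABlock(c_m=64, no_column_attention=True)(torch.randn(1, 10, 64))
```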
@@ -75,6 +75,8 @@ class AlphaFold(nn.Module):
        self.seqemb_mode = config.globals.seqemb_mode_enabled

        # Main trunk + structure module
        # If using seqemb mode, embed the sequence embeddings passed
        # to the model ("preembeddings") instead of embedding the sequence
        if self.seqemb_mode:
            self.preembedding_embedder = PreembeddingEmbedder(
                **self.config["preembedding_embedder"],
...
@@ -73,6 +73,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
            os.makedirs(local_alignment_dir)

            # In seqemb mode, use AlignmentRunner only to generate templates
            if args.use_single_seq_mode:
                alignment_runner = data_pipeline.AlignmentRunner(
                    jackhmmer_binary_path=args.jackhmmer_binary_path,
...