Commit 29962990 authored by Sachin Kadyan

Added documentation for some sequence embedding model changes.

parent a83c6fcc
@@ -154,8 +154,10 @@ def model_config(
c.loss.tm.weight = 0.1
# SINGLE SEQUENCE EMBEDDING PRESETS
elif name == "seqemb_initial_training":
# Tell the data pipeline that we will use sequence embeddings instead of MSAs.
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
# In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
elif name == "seqemb_finetuning":
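For reference, a minimal usage sketch of the new presets. The preset name and config flags come from the hunk above; the import path is an assumption based on the usual OpenFold layout, and this snippet is not part of the commit.

# Minimal sketch, assuming model_config lives in openfold.config.
from openfold.config import model_config

config = model_config("seqemb_initial_training")

# Flags set by the preset above.
assert config.data.seqemb_mode.enabled
assert config.globals.seqemb_mode_enabled
assert not config.model.extra_msa.enabled
assert config.model.evoformer_stack.no_column_attention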
@@ -211,7 +213,11 @@ c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
# For seqemb mode, the dimension of the per-residue sequence embedding passed to the model.
# In the current model this is the ESM-1b embedding dimension, i.e. 1280.
preemb_dim_size = mlc.FieldReference(1280, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
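Because preemb_dim_size is an ml_collections FieldReference, a hypothetical embedder with a different output width only needs this one value overridden. The config path below is assumed from the preembedding_embedder entry later in this diff, and 2560 is purely an illustrative number, not something this commit introduces.

# Hypothetical sketch: retarget the pre-embedding width for a different
# language model (2560 is illustrative only; config path assumed).
from openfold.config import model_config

config = model_config("seqemb_initial_training")
config.model.preembedding_embedder.preembedding_dim = 2560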
@@ -322,13 +328,13 @@ config = mlc.ConfigDict(
"deletion_matrix",
"no_recycling_iters",
],
"seqemb_features": [
"seqemb_features": [ # List of features to be generated in seqemb mode
"seq_embedding"
],
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
},
"seqemb_mode": {
"seqemb_mode": { # Configuration for sequence embedding mode
"enabled": False, # If True, use seq emb instead of MSA
"seqemb_config": {
"max_msa_clusters": 0,
@@ -400,7 +406,7 @@ config = mlc.ConfigDict(
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"seqemb_mode_enabled": False,
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
@@ -426,7 +432,7 @@ config = mlc.ConfigDict(
"c_m": c_m,
"relpos_k": 32,
},
"preembedding_embedder": {
"preembedding_embedder": { # Used in sequence embedding mode
"tf_dim": 22,
"preembedding_dim": preemb_dim_size,
"c_z": c_z,
@@ -261,6 +261,7 @@ def make_msa_features(
return features
# Generate single-sequence MSA features containing only the input sequence
def make_dummy_msa_feats(input_sequence):
msas = [[input_sequence]]
deletion_matrices = [[[0 for _ in input_sequence]]]
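A hedged usage sketch of this helper follows; the query string is made up and the keys/shapes of the returned feature dict are assumptions based on the single-sequence intent described above.

# Hedged usage sketch: in seqemb mode the pipeline substitutes a depth-1 "MSA"
# built from the query alone, so downstream featurization stays unchanged.
dummy = make_dummy_msa_feats("MKTAYIAKQR")  # hypothetical query sequence
print(dummy["msa"].shape)  # assumption: a single-row MSA with all-zero deletion counts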
@@ -639,6 +640,7 @@ class DataPipeline:
return msa_features
# Load and process sequence embedding features
def _process_seqemb_features(self,
alignment_dir: str,
) -> Mapping[str, Any]:
@@ -648,6 +650,7 @@ class DataPipeline:
ext = os.path.splitext(f)[-1]
if (ext == ".pt"):
# Load embedding file
seqemb_data = torch.load(path)
seqemb_features["seq_embedding"] = seqemb_data
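Producing the .pt file this loop looks for is outside the scope of this commit. Below is one plausible way to precompute a per-residue ESM-1b embedding and drop it into the alignment directory; the directory layout, file name, and the exact tensor shape the pipeline expects are assumptions.

# Hedged sketch (not part of this commit): precompute a per-residue ESM-1b
# embedding and save it where _process_seqemb_features can find it.
import os
import torch
import esm  # fair-esm package

model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
batch_converter = alphabet.get_batch_converter()
model.eval()

sequence = "MKTAYIAKQR"  # hypothetical query
alignment_dir = "alignments/query"  # hypothetical layout
os.makedirs(alignment_dir, exist_ok=True)

_, _, tokens = batch_converter([("query", sequence)])
with torch.no_grad():
    out = model(tokens, repr_layers=[33])
# Drop the BOS/EOS tokens so the embedding is per-residue: [len(sequence), 1280].
embedding = out["representations"][33][0, 1 : len(sequence) + 1]
torch.save(embedding, os.path.join(alignment_dir, "query.pt"))  # any *.pt file is picked up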
@@ -686,6 +689,7 @@ class DataPipeline:
)
sequence_embedding_features = {}
# If using seqemb mode, generate dummy MSA features using just the sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
@@ -732,6 +736,7 @@ class DataPipeline:
)
sequence_embedding_features = {}
# If using seqemb mode, generate dummy MSA features using just the sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
@@ -41,6 +41,7 @@ def np_to_tensor_dict(
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
# torch.tensor() emits a warning if the feature is already a torch Tensor
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
tensor_dict = {
k: to_tensor(v) for k, v in np_example.items() if k in features
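A small illustration of why the guard matters here, assuming the seq_embedding feature arrives as a torch.Tensor from torch.load while the remaining features are NumPy arrays.

# Illustration (assumed inputs): NumPy features are wrapped with torch.tensor(),
# while an already-tensor seq_embedding is cloned/detached instead, avoiding the
# UserWarning that torch.tensor() raises on tensor inputs.
import numpy as np
import torch

np_example = {
    "aatype": np.zeros((7,), dtype=np.int64),   # NumPy array -> torch.tensor
    "seq_embedding": torch.zeros(7, 1280),      # already a Tensor -> clone().detach()
}
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
tensor_dict = {k: to_tensor(v) for k, v in np_example.items()}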
@@ -62,6 +63,7 @@ def make_data_config(
feature_names = cfg.common.unsupervised_features
# Add seqemb-related features if using seqemb mode.
if cfg.seqemb_mode.enabled:
feature_names += cfg.common.seqemb_features
@@ -337,6 +337,7 @@ class EvoformerBlock(nn.Module):
inf=inf,
)
# Specifically, seqemb mode does not use column attention (a single-sequence MSA leaves nothing to attend over along the column dimension)
self.no_column_attention = no_column_attention
if self.no_column_attention == False:
self.msa_att_col = MSAColumnAttention(
@@ -401,6 +402,7 @@ class EvoformerBlock(nn.Module):
inplace=inplace_safe,
)
# Specifically, column attention is not used in seqemb mode.
if self.no_column_attention == False:
m = add(m,
self.msa_att_col(
@@ -637,6 +639,9 @@ class EvoformerStack(nn.Module):
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
no_column_attention:
When True, column attention is not used. Required for running
sequence embedding mode
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
@@ -75,6 +75,8 @@ class AlphaFold(nn.Module):
self.seqemb_mode = config.globals.seqemb_mode_enabled
# Main trunk + structure module
# If using seqemb mode, embed the sequence embeddings passed
# to the model ("preembeddings") instead of embedding the sequence
if self.seqemb_mode:
self.preembedding_embedder = PreembeddingEmbedder(
**self.config["preembedding_embedder"],
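As a concrete follow-on to the hunk above, the seqemb-mode embedder can be built straight from the config section shown earlier. The import path and the config.model prefix are assumptions; this sketch is not code from the commit.

# Hedged sketch: construct the seqemb-mode embedder from the
# preembedding_embedder config section (module path assumed).
from openfold.config import model_config
from openfold.model.embedders import PreembeddingEmbedder  # assumed location

config = model_config("seqemb_initial_training")
embedder = PreembeddingEmbedder(**config.model.preembedding_embedder)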
@@ -73,6 +73,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
os.makedirs(local_alignment_dir)
# In seqemb mode, use AlignmentRunner only to generate templates
if args.use_single_seq_mode:
alignment_runner = data_pipeline.AlignmentRunner(
jackhmmer_binary_path=args.jackhmmer_binary_path,