"vscode:/vscode.git/clone" did not exist on "0a0dbb39602ff3aa1acb35b68fc517d3e6afb051"
Commit 75889e9a authored by Sachin Kadyan's avatar Sachin Kadyan
Browse files

Added switch for using the single sequence embedder when using the model in `seqemb` mode.

- `seqemb_mode_enabled` added as a configuration option.
- `model.py` switches to using the `PreembeddingEmbedder` when the flag is `True`.
parent aacf1b6f
......@@ -377,6 +377,7 @@ config = mlc.ConfigDict(
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"seqemb_mode_enabled": False,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
......
......@@ -24,6 +24,7 @@ from openfold.model.embedders import (
TemplateAngleEmbedder,
TemplatePairEmbedder,
ExtraMSAEmbedder,
PreembeddingEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
......@@ -71,8 +72,14 @@ class AlphaFold(nn.Module):
self.config = config.model
self.template_config = self.config.template
self.extra_msa_config = self.config.extra_msa
self.seqemb_mode = config.globals.seqemb_mode_enabled
# Main trunk + structure module
if self.seqemb_mode:
self.preembedding_embedder = PreembeddingEmbedder(
**self.config["preembedding_embedder"],
)
else:
self.input_embedder = InputEmbedder(
**self.config["input_embedder"],
)
......@@ -238,8 +245,18 @@ class AlphaFold(nn.Module):
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
## Initialize the MSA and pair representations
## Initialize the SingleSeq and pair representations
# m: [*, 1, N, C_m]
# z: [*, N, N, C_z]
if self.seqemb_mode:
m, z = self.preembedding_embedder(
feats["target_feat"],
feats["residue_index"],
feats["seq_embedding"]
)
else:
## Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment