Commit d7c11537 authored by Christina Floristean's avatar Christina Floristean

Merge branch 'main' into multimer

parents 7c7dffd0 9e32781f
...@@ -29,7 +29,7 @@ vice versa (see `scripts/convert_of_weights_to_jax.py`).
OpenFold has the following advantages over the reference implementation:
- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on Ampere or higher architecture GPUs.
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
...@@ -232,6 +232,51 @@ efficient AlphaFold-Multimer more than double the time. Use the
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option.
#### SoloSeq Inference
To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk or generate them during inference.
To generate ESM-1b embeddings in bulk, use the provided script `scripts/precompute_embeddings.py`. The script takes a directory of FASTA files (one sequence per file) and writes ESM-1b embeddings in the format and directory structure required by SoloSeq. The following is an example command:
```bash
python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
```
In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (HHSearch output), which describe the structures you want to use as templates (see the example layout below). If no such file is present, templates are not used and the structure is predicted from the ESM-1b embeddings alone. If you do want to use templates, you must also pass the PDB mmCIF dataset to the command.
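For reference, a minimal sketch of the expected layout (the labels `T1084` and `T1090` and the file names are hypothetical; the `.hhr` files are optional):
```
embeddings_output_dir/
├── T1084/
│   ├── T1084.pt    # ESM-1b embedding written by precompute_embeddings.py
│   └── T1084.hhr   # optional HHSearch hits used for template search
└── T1090/
    └── T1090.pt
```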
Now, you are ready to run inference:
```bash
python run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--use_precomputed_alignments embeddings_output_dir \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt
```
To generate the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will also be generated if you pass the paths to the relevant databases and tools, as in the command below. If you omit the database and tool arguments, HHSearch will not be used to find templates and only the generated ESM-1b embeddings will be used to predict the structure.
```bash
python3 run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b_ptm" \
--openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt \
--uniref90_database_path data/uniref90/uniref90.fasta \
--pdb70_database_path data/pdb70/pdb70 \
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
```
For generating template information, you will need the UniRef90 and PDB70 databases and the JackHmmer and HHSearch binaries.
SoloSeq allows you to use the same flags and optimizations as the MSA-based OpenFold. For example, you can skip relaxation using `--skip_relaxation`, save all model outputs using `--save_outputs`, and generate output files in MMCIF format using `--cif_output`.
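For example, these flags can simply be appended to the precomputed-embeddings command shown earlier (a sketch; only the flags named in this section are added):
```bash
python run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --use_precomputed_alignments embeddings_output_dir \
    --output_dir ./ \
    --model_device "cuda:0" \
    --config_preset "seq_model_esm1b_ptm" \
    --openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt \
    --skip_relaxation \
    --save_outputs \
    --cif_output
```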
**NOTE:** Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.
### Training
To train the model, you will first need to precompute protein alignments.
...@@ -439,17 +484,27 @@ Please cite our paper:
```bibtex
@article {Ahdritz2022.11.20.517210,
author = {Ahdritz, Gustaf and Bouatta, Nazim and Floristean, Christina and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
title = {{O}pen{F}old: {R}etraining {A}lpha{F}old2 yields new insights into its learning mechanisms and capacity for generalization},
elocation-id = {2022.11.20.517210},
year = {2022},
doi = {10.1101/2022.11.20.517210},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/10.1101/2022.11.20.517210},
eprint = {https://www.biorxiv.org/content/early/2022/11/22/2022.11.20.517210.full.pdf},
journal = {bioRxiv}
}
```
If you use OpenProteinSet, please also cite:
```bibtex
@misc{ahdritz2023openproteinset,
title={{O}pen{P}rotein{S}et: {T}raining data for structural biology at scale},
author={Gustaf Ahdritz and Nazim Bouatta and Sachin Kadyan and Lukas Jarosch and Daniel Berenberg and Ian Fisk and Andrew M. Watkins and Stephen Ra and Richard Bonneau and Mohammed AlQuraishi},
year={2023},
eprint={2308.05326},
archivePrefix={arXiv},
primaryClass={q-bio.BM}
}
```
Any work that cites OpenFold should also cite AlphaFold.
...@@ -153,7 +153,32 @@ def model_config(
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif name.startswith("seq"):  # SINGLE SEQUENCE EMBEDDING PRESETS
c.update(seq_mode_config.copy_and_resolve_references())
if name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "seq_model_esm1b":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
elif name == "seq_model_esm1b_ptm":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
elif "multimer" in name: # MULTIMER PRESETS
c.update(multimer_config_update.copy_and_resolve_references())
# Not used in multimer
...@@ -224,6 +249,11 @@ c_m = mlc.FieldReference(256, field_type=int)
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)
# For seqemb mode, the dimension of the per-residue sequence embedding passed to the model
# In the current model this is the ESM-1b embedding dimension, i.e. 1280.
preemb_dim_size = mlc.FieldReference(1280, field_type=int)
blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
...@@ -336,6 +366,9 @@ config = mlc.ConfigDict(
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
},
"seqemb_mode": { # Configuration for sequence embedding mode
"enabled": False, # If True, use seq emb instead of MSA
},
"supervised": { "supervised": {
"clamp_prob": 0.9, "clamp_prob": 0.9,
"supervised_features": [ "supervised_features": [
...@@ -422,6 +455,7 @@ config = mlc.ConfigDict( ...@@ -422,6 +455,7 @@ config = mlc.ConfigDict(
"c_s": c_s, "c_s": c_s,
"eps": eps, "eps": eps,
"is_multimer": False, "is_multimer": False,
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
},
"model": {
"_mask_trans": False,
...@@ -539,6 +573,7 @@ config = mlc.ConfigDict(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"no_column_attention": False,
"opm_first": False, "opm_first": False,
"fuse_projection_weights": False, "fuse_projection_weights": False,
"blocks_per_ckpt": blocks_per_ckpt, "blocks_per_ckpt": blocks_per_ckpt,
...@@ -857,3 +892,38 @@ multimer_config_update = mlc.ConfigDict({ ...@@ -857,3 +892,38 @@ multimer_config_update = mlc.ConfigDict({
} }
} }
}) })
seq_mode_config = mlc.ConfigDict({
"data": {
"common": {
"feat": {
"seq_embedding": [NUM_RES, None],
},
"seqemb_features": [ # List of features to be generated in seqemb mode
"seq_embedding"
],
},
"seqemb_mode": { # Configuration for sequence embedding mode
"enabled": True, # If True, use seq emb instead of MSA
},
},
"globals": {
"seqemb_mode_enabled": True,
},
"model": {
"preembedding_embedder": { # Used in sequence embedding mode
"tf_dim": 22,
"preembedding_dim": preemb_dim_size,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
},
"extra_msa": {
"enabled": False # Disable Extra MSA Stack
},
"evoformer_stack": {
"no_column_attention": True # Turn off Evoformer's column attention
},
}
})
...@@ -192,7 +192,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled
)
return data
...@@ -244,6 +245,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
elif ext == ".core":
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
elif ext == ".pdb":
structure_index = None
...@@ -256,6 +258,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=structure_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
else:
raise ValueError("Extension branch missing")
...@@ -265,6 +268,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
fasta_path=path,
alignment_dir=alignment_dir,
alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
if self._output_raw:
...
...@@ -23,6 +23,8 @@ import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
import numpy as np
import torch
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
...@@ -269,6 +271,19 @@ def run_msa_tool(
return result
# Generate 1-sequence MSA features having only the input sequence
def make_dummy_msa_feats(input_sequence):
msas = [[input_sequence]]
deletion_matrices = [[[0 for _ in input_sequence]]]
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)
return msa_features
def make_sequence_features_with_custom_template(
sequence: str,
mmcif_path: str,
...@@ -821,11 +836,28 @@ class DataPipeline:
return msa_features
# Load and process sequence embedding features
def _process_seqemb_features(self,
alignment_dir: str,
) -> Mapping[str, Any]:
seqemb_features = {}
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]
if (ext == ".pt"):
# Load embedding file
seqemb_data = torch.load(path)
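# Layer 33 is the final ESM-1b transformer layer; its per-residue representation is used as the sequence embedding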
seqemb_features["seq_embedding"] = seqemb_data["representations"][33]
return seqemb_features
def process_fasta(
self,
fasta_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
...@@ -857,12 +889,19 @@ class DataPipeline:
num_res=num_res,
)
sequence_embedding_features = {}
# If using seqemb mode, generate dummy MSA features using just the sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {
**sequence_features,
**msa_features,
**template_features,
**sequence_embedding_features
}
def process_mmcif(
...@@ -871,6 +910,7 @@ class DataPipeline:
alignment_dir: str,
chain_id: Optional[str] = None,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
...@@ -899,9 +939,15 @@ class DataPipeline:
self.template_featurizer
)
sequence_embedding_features = {}
# If using seqemb mode, generate dummy MSA features using just the sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**mmcif_feats, **template_features, **msa_features, **sequence_embedding_features}
def process_pdb(
self,
...@@ -911,6 +957,7 @@ class DataPipeline:
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
...@@ -949,15 +996,22 @@ class DataPipeline:
self.template_featurizer,
)
sequence_embedding_features = {}
# If in sequence embedding mode, generate dummy MSA features using just the input sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**pdb_feats, **template_features, **msa_features, **sequence_embedding_features}
def process_core(
self,
core_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
...@@ -982,9 +1036,15 @@ class DataPipeline:
self.template_featurizer,
)
sequence_embedding_features = {}
# If in sequence embedding mode, generate dummy MSA features using just the input sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence)
return {**core_feats, **template_features, **msa_features, **sequence_embedding_features}
def process_multiseq_fasta(self,
fasta_path: str,
...
...@@ -40,9 +40,11 @@ def np_to_tensor_dict(
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
# torch generates warnings if feature is already a torch Tensor
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
tensor_dict = {
k: to_tensor(v) for k, v in np_example.items() if k in features
}
return tensor_dict
...@@ -61,6 +63,10 @@ def make_data_config(
feature_names = cfg.common.unsupervised_features
# Add seqemb related features if using seqemb mode.
if cfg.seqemb_mode.enabled:
feature_names += cfg.common.seqemb_features
if cfg.common.use_templates:
feature_names += cfg.common.template_features
...
...@@ -309,6 +309,99 @@ class InputEmbedderMultimer(nn.Module):
return msa_emb, pair_emb
class PreembeddingEmbedder(nn.Module):
"""
Embeds the sequence pre-embedding passed to the model and the target_feat features.
"""
def __init__(
self,
tf_dim: int,
preembedding_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
End channel dimension of the incoming target features
preembedding_dim:
End channel dimension of the incoming embeddings
c_z:
Pair embedding dimension
c_m:
Single-Seq embedding dimension
relpos_k:
Window size used in relative position encoding
"""
super(PreembeddingEmbedder, self).__init__()
self.tf_dim = tf_dim
self.preembedding_dim = preembedding_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_preemb_m = Linear(self.preembedding_dim, c_m)
self.linear_preemb_z_i = Linear(self.preembedding_dim, c_z)
self.linear_preemb_z_j = Linear(self.preembedding_dim, c_z)
# Relative Positional Encoding
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
Args:
ri:
"residue_index" feature of shape [*, N]
Returns:
Relative positional encoding of protein using the
residue_index feature
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
d = d[..., None] - reshaped_bins
d = torch.abs(d)
d = torch.argmin(d, dim=-1)
d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
d = d.to(ri.dtype)
return self.linear_relpos(d)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
preemb: torch.Tensor,
inplace_safe: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
)
preemb_emb = self.linear_preemb_m(preemb[..., None, :, :]) + tf_m
preemb_emb_i = self.linear_preemb_z_i(preemb)
preemb_emb_j = self.linear_preemb_z_j(preemb)
pair_emb = self.relpos(ri.type(preemb_emb_i.dtype))
pair_emb = add(pair_emb,
preemb_emb_i[..., None, :],
inplace=inplace_safe)
pair_emb = add(pair_emb,
preemb_emb_j[..., None, :, :],
inplace=inplace_safe)
return preemb_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
...
...@@ -263,6 +263,7 @@ class PairStack(nn.Module):
return z
class MSABlock(nn.Module, ABC):
@abstractmethod
def __init__(self,
...@@ -383,6 +384,7 @@ class EvoformerBlock(MSABlock):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
no_column_attention: bool,
opm_first: bool,
fuse_projection_weights: bool,
inf: float,
...@@ -404,12 +406,16 @@ class EvoformerBlock(MSABlock):
inf=inf,
eps=eps)
# Specifically, seqemb mode does not use column attention
self.no_column_attention = no_column_attention
if not self.no_column_attention:
    self.msa_att_col = MSAColumnAttention(
        c_m,
        c_hidden_msa_att,
        no_heads_msa,
        inf=inf,
    )
def forward(self,
m: Optional[torch.Tensor],
...@@ -470,16 +476,18 @@ class EvoformerBlock(MSABlock):
torch.cuda.empty_cache()
m, z = input_tensors
# Specifically, column attention is not used in seqemb mode.
if not self.no_column_attention:
    m = add(m,
        self.msa_att_col(
            m,
            mask=msa_mask,
            chunk_size=chunk_size,
            use_lma=use_lma,
            use_flash=use_flash,
        ),
        inplace=inplace_safe,
    )
m = add(
m,
...@@ -749,6 +757,7 @@ class EvoformerStack(nn.Module):
transition_n: int,
msa_dropout: float,
pair_dropout: float,
no_column_attention: bool,
opm_first: bool,
fuse_projection_weights: bool,
blocks_per_ckpt: int,
...@@ -787,6 +796,16 @@ class EvoformerStack(nn.Module):
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
no_column_attention:
When True, doesn't use column attention. Required for running
sequence embedding mode
opm_first:
When True, Outer Product Mean is performed at the beginning of
the Evoformer block instead of after the MSA Stack.
Used in Multimer pipeline.
fuse_projection_weights:
When True, uses FusedTriangleMultiplicativeUpdate variant in
the Pair Stack. Used in Multimer pipeline.
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
...@@ -815,6 +834,7 @@ class EvoformerStack(nn.Module):
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
no_column_attention=no_column_attention,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
...
...@@ -33,6 +33,7 @@ from openfold.model.embedders import (
TemplateEmbedder,
TemplateEmbedderMultimer,
ExtraMSAEmbedder,
PreembeddingEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
from openfold.model.heads import AuxiliaryHeads
...@@ -80,11 +81,18 @@ class AlphaFold(nn.Module):
self.config = config.model
self.template_config = self.config.template
self.extra_msa_config = self.config.extra_msa
self.seqemb_mode = config.globals.seqemb_mode_enabled
# Main trunk + structure module
if self.globals.is_multimer:
self.input_embedder = InputEmbedderMultimer(
**self.config["input_embedder"]
)
elif self.seqemb_mode:
# If using seqemb mode, embed the sequence embeddings passed
# to the model ("preembeddings") instead of embedding the sequence
self.input_embedder = PreembeddingEmbedder(
**self.config["preembedding_embedder"],
)
else:
self.input_embedder = InputEmbedder(
...@@ -220,15 +228,23 @@ class AlphaFold(nn.Module):
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
if self.globals.is_multimer:
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(feats)
elif self.seqemb_mode:
# Initialize the SingleSeq and pair representations
# m: [*, 1, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["seq_embedding"]
)
else:
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
...
...@@ -159,7 +159,7 @@ def run_model(model, batch, tag, output_dir):
out = model(batch)
inference_time = time.perf_counter() - t
logger.info(f"Inference time: {inference_time}")
update_timings({tag: {"inference": inference_time}}, os.path.join(output_dir, "timings.json"))
model.config.template.enabled = template_enabled
...
...@@ -58,6 +58,7 @@ from openfold.utils.trace_utils import (
pad_feature_dict_seq,
trace_model_,
)
from scripts.precompute_embeddings import EmbeddingGenerator
from scripts.utils import add_data_args
...@@ -80,17 +81,28 @@ def precompute_alignments(tags, seqs, alignment_dir, args, is_multimer):
os.makedirs(local_alignment_dir)
# In seqemb mode, use AlignmentRunner only to generate templates
if args.use_single_seq_mode:
    alignment_runner = data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        no_cpus=args.cpus,
    )
    embedding_generator = EmbeddingGenerator()
    embedding_generator.run(tmp_fasta_path, alignment_dir)
else:
    alignment_runner = data_pipeline.AlignmentRunner(
        jackhmmer_binary_path=args.jackhmmer_binary_path,
        hhblits_binary_path=args.hhblits_binary_path,
        uniref90_database_path=args.uniref90_database_path,
        mgnify_database_path=args.mgnify_database_path,
        bfd_database_path=args.bfd_database_path,
        uniref30_database_path=args.uniref30_database_path,
        uniclust30_database_path=args.uniclust30_database_path,
        uniprot_database_path=args.uniprot_database_path,
        no_cpus=args.cpus,
    )
alignment_runner.run(
tmp_fasta_path, local_alignment_dir
)
...@@ -123,7 +135,9 @@ def generate_feature_dict(
local_alignment_dir = os.path.join(alignment_dir, tag)
feature_dict = data_processor.process_fasta(
fasta_path=tmp_fasta_path,
alignment_dir=local_alignment_dir,
seqemb_mode=args.use_single_seq_mode,
)
elif "multimer" in args.config_preset:
with open(tmp_fasta_path, "w") as fp:
...@@ -155,6 +169,9 @@ def main(args):
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
if args.config_preset.startswith("seq"):
args.use_single_seq_mode = True
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
if(args.trace_model):
...@@ -389,6 +406,10 @@ if __name__ == "__main__":
help="""Path to alignment directory. If provided, alignment computation
is skipped and database path arguments are ignored."""
)
parser.add_argument(
"--use_single_seq_mode", action="store_true", default=False,
help="""Use single sequence embeddings instead of MSAs."""
)
parser.add_argument(
"--output_dir", type=str, default=os.getcwd(),
help="""Name of the directory in which to output the prediction""",
...
...@@ -46,4 +46,4 @@ echo "Downloading AlphaFold parameters..."
bash scripts/download_alphafold_params.sh openfold/resources
# Decompress test data
gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.pickle
# Some functions borrowed from [ESM](https://www.github.com/facebookresearch/esm)
import argparse
import logging
import os
import torch
from openfold.data import parsers
logging.basicConfig(level=logging.INFO)
class SequenceDataset(object):
def __init__(self, labels, sequences) -> None:
self.labels = labels
self.sequences = sequences
@classmethod
def from_file(cls, fasta_file):
labels, sequences = [], []
with open(fasta_file, "r") as infile:
fasta_str = infile.read()
sequences, labels = parsers.parse_fasta(fasta_str)
assert len(set(labels)) == len(labels),\
"Sequence labels need to be unique. Duplicates found!"
return cls(labels, sequences)
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return self.labels[idx], self.sequences[idx]
def get_batch_indices(self, toks_per_batch, extra_toks_per_seq):
sizes = [(len(s), i) for i, s in enumerate(self.sequences)]
sizes.sort()
batches = []
buf = []
max_len = 0
def _flush_current_buf():
nonlocal max_len, buf
if len(buf) == 0:
return
batches.append(buf)
buf = []
max_len = 0
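# Greedily pack the length-sorted sequences into batches of at most toks_per_batch tokens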
for sz, i in sizes:
sz += extra_toks_per_seq
if max(sz, max_len) * (len(buf)+1) > toks_per_batch:
_flush_current_buf()
max_len = max(max_len, sz)
buf.append(i)
_flush_current_buf()
return batches
class EmbeddingGenerator:
"""Generates the ESM-1b embeddings for the single sequence model"""
def __init__(self,
toks_per_batch: int = 4096,
truncate: bool = True,
use_local_esm: str = None,
nogpu: bool = False,
):
self.toks_per_batch = toks_per_batch
self.truncate = truncate
self.use_local_esm = use_local_esm
self.nogpu = nogpu
# Generate embeddings in bulk
if self.use_local_esm:
self.model, self.alphabet = torch.hub.load(self.use_local_esm, "esm1b_t33_650M_UR50S", source='local')
else:
self.model, self.alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
if torch.cuda.is_available() and not self.nogpu:
self.model = self.model.to(device="cuda")
def parse_sequences(self, fasta_dir, output_dir):
labels = []
seqs = []
# Generate a single bulk file
for f in os.listdir(fasta_dir):
f_name, ext = os.path.splitext(f)
if ext != '.fasta' and ext != '.fa':
logging.warning(f"Ignoring non-FASTA file: {f}")
continue
with open(os.path.join(fasta_dir, f), 'r') as infile:
seq = "".join(line.strip() for line in infile.readlines()[1:])  # join in case the sequence is wrapped over multiple lines
labels.append(f_name)
seqs.append(seq)
lines = []
for label, seq in zip(labels, seqs):
lines.append(f'>{label}\n')
lines.append(f'{seq}\n')
os.makedirs(output_dir, exist_ok=True)
temp_fasta_file = os.path.join(output_dir, 'temp.fasta')
with open(temp_fasta_file, 'w') as outfile:
outfile.writelines(lines)
return temp_fasta_file
def run(
self,
fasta_file,
output_dir,
):
dataset = SequenceDataset.from_file(fasta_file)
batches = dataset.get_batch_indices(self.toks_per_batch, extra_toks_per_seq=1)
data_loader = torch.utils.data.DataLoader(
dataset, collate_fn=self.alphabet.get_batch_converter(), batch_sampler=batches
)
logging.info("Loaded all sequences")
repr_layers = [33]
with torch.no_grad():
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
logging.info(f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)")
if torch.cuda.is_available() and not self.nogpu:
toks = toks.to(device="cuda", non_blocking=True)
if self.truncate:
# Truncate along the sequence dimension to the ESM-1b limit of 1022 tokens
toks = toks[:, :1022]
out = self.model(toks, repr_layers=repr_layers, return_contacts=False)
representations = {
33: out["representations"][33].to(device="cpu")
}
for i, label in enumerate(labels):
os.makedirs(os.path.join(output_dir, label), exist_ok=True)
result = {"label": label}
result["representations"] = {
33: representations[33][i, 1: len(strs[i]) + 1].clone()
}
torch.save(
result,
os.path.join(output_dir, label, label+".pt")
)
def main(args):
logging.info("Loading the model...")
embedding_generator = EmbeddingGenerator(
args.toks_per_batch,
args.truncate,
args.use_local_esm,
args.nogpu)
logging.info("Loading the sequences and running the inference...")
temp_fasta_file = embedding_generator.parse_sequences(
args.fasta_dir,
args.output_dir
)
embedding_generator.run(
temp_fasta_file,
args.output_dir
)
os.remove(temp_fasta_file)
logging.info("Completed.")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"fasta_dir", type=str,
help="""Path to directory containing FASTA files."""
)
parser.add_argument(
"output_dir", type=str,
help="Directory in which to output embeddings"
)
parser.add_argument(
"--toks_per_batch", type=int, default=4096,
help="maximum tokens in a batch"
)
parser.add_argument(
"--truncate", action="store_true", default=True,
help="Truncate sequences longer than 1022 (ESM restriction). Default: True"
)
parser.add_argument(
"--use_local_esm", type=str, default=None,
help="Use a local ESM repository instead of cloning from Github"
)
parser.add_argument(
"--nogpu", action="store_true",
help="Do not use GPU"
)
args = parser.parse_args()
main(args)
...@@ -20,6 +20,7 @@ from tests.data_utils import random_asym_ids
from openfold.model.embedders import (
InputEmbedder,
InputEmbedderMultimer,
PreembeddingEmbedder,
RecyclingEmbedder,
TemplateSingleEmbedder,
TemplatePairEmbedder
...@@ -66,6 +67,28 @@ class TestInputEmbedder(unittest.TestCase):
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
class TestPreembeddingEmbedder(unittest.TestCase):
def test_shape(self):
tf_dim = 22
preembedding_dim = 1280
c_z = 4
c_m = 6
relpos_k = 10
batch_size = 4
num_res = 20
tf = torch.rand((batch_size, num_res, tf_dim))
ri = torch.rand((batch_size, num_res))
preemb = torch.rand((batch_size, num_res, preembedding_dim))
pe = PreembeddingEmbedder(tf_dim, preembedding_dim, c_z, c_m, relpos_k)
seq_emb, pair_emb = pe(tf, ri, preemb)
self.assertTrue(seq_emb.shape == (batch_size, 1, num_res, c_m))
self.assertTrue(pair_emb.shape == (batch_size, num_res, num_res, c_z))
class TestRecyclingEmbedder(unittest.TestCase):
def test_shape(self):
batch_size = 2
...
...@@ -68,8 +68,9 @@ class TestEvoformerStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
no_column_attention=False,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
...@@ -91,6 +92,64 @@ class TestEvoformerStack(unittest.TestCase):
self.assertTrue(z.shape == shape_z_before)
self.assertTrue(s.shape == (batch_size, n_res, c_s))
def test_shape_without_column_attention(self):
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
c_m = consts.c_m
c_z = consts.c_z
c_hidden_msa_att = 12
c_hidden_opm = 17
c_hidden_mul = 19
c_hidden_pair_att = 14
c_s = consts.c_s
no_heads_msa = 3
no_heads_pair = 7
no_blocks = 2
transition_n = 2
msa_dropout = 0.15
pair_stack_dropout = 0.25
inf = 1e9
eps = 1e-10
es = EvoformerStack(
c_m,
c_z,
c_hidden_msa_att,
c_hidden_opm,
c_hidden_mul,
c_hidden_pair_att,
c_s,
no_heads_msa,
no_heads_pair,
no_blocks,
transition_n,
msa_dropout,
pair_stack_dropout,
no_column_attention=True,
opm_first=False,
fuse_projection_weights=False,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
).eval()
m_init = torch.rand((batch_size, n_seq, n_res, c_m))
z_init = torch.rand((batch_size, n_res, n_res, c_z))
msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
shape_m_before = m_init.shape
shape_z_before = z_init.shape
m, z, s = es(
m_init, z_init, chunk_size=4, msa_mask=msa_mask, pair_mask=pair_mask
)
self.assertTrue(m.shape == shape_m_before)
self.assertTrue(z.shape == shape_z_before)
self.assertTrue(s.shape == (batch_size, n_res, c_s))
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_ei(activations, masks):
...@@ -215,7 +274,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res,
),
device="cuda",
).float()
pair_mask = torch.randint(
0,
2,
...@@ -225,7 +284,7 @@ class TestExtraMSAStack(unittest.TestCase):
n_res,
),
device="cuda",
).float()
shape_z_before = z.shape
...
...@@ -60,7 +60,7 @@ class TestModel(unittest.TestCase):
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
model = AlphaFold(c).cuda()
model.eval()
batch = {}
...@@ -94,6 +94,49 @@ class TestModel(unittest.TestCase):
)
batch = tensor_tree_map(add_recycling_dims, batch)
to_cuda_device = lambda t: t.cuda()
batch = tensor_tree_map(to_cuda_device, batch)
with torch.no_grad():
out = model(batch)
def test_dry_run_seqemb_mode(self):
n_seq = 1
n_templ = consts.n_templ
n_res = consts.n_res
msa_dim = 49
c = model_config("seq_model_esm1b")
c.model.evoformer_stack.no_blocks = 2
c.model.evoformer_stack.blocks_per_ckpt = None
model = AlphaFold(c)
model.to(torch.device('cuda'))
model.eval()
batch = {}
tf = torch.randint(c.model.preembedding_embedder.tf_dim - 1, size=(n_res,))
batch["target_feat"] = nn.functional.one_hot(tf, c.model.preembedding_embedder.tf_dim).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, msa_dim))
batch["seq_embedding"] = torch.rand((n_res, c.model.preembedding_embedder.preembedding_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(data_transforms.make_atom14_masks(batch))
batch["msa_mask"] = torch.randint(low=0, high=2, size=(n_seq, n_res)).float()
batch["no_recycling_iters"] = torch.tensor(2.)
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
)
batch = tensor_tree_map(add_recycling_dims, batch)
to_cuda_device = lambda t: t.to(torch.device("cuda"))
batch = tensor_tree_map(to_cuda_device, batch)
with torch.no_grad():
out = model(batch)
...
...@@ -436,7 +436,11 @@ if __name__ == "__main__":
)
parser.add_argument(
"--train_mmcif_data_cache_path", type=str, default=None,
help="Path to the json file which records all the information of mmcif structures used during training"
)
parser.add_argument(
"--use_single_seq_mode", type=str, default=False,
help="Use single sequence embeddings instead of MSAs."
)
parser.add_argument(
"--distillation_data_dir", type=str, default=None,
...