Commit 5aa54958 authored by Christina Floristean

Merge branch 'main' into deepspeed-evo-attention

parents f545323c 099769d2
version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "daily"
@@ -10,6 +10,6 @@ jobs:
   build:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v4
       - name: Build the Docker image
         run: docker build . --file Dockerfile --tag openfold:$(date +%s)
\ No newline at end of file
@@ -4,8 +4,8 @@ jobs:
   undefined_names:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v2
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v4
       - run: pip install --upgrade pip
       - run: pip install flake8
       - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu18.04
+FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04
 # metainformation
 LABEL org.opencontainers.image.version = "1.0.0"
@@ -13,24 +13,23 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
 RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
 RUN wget -P /tmp \
-    "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
-    && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
-    && rm /tmp/Miniconda3-latest-Linux-x86_64.sh
+    "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \
+    && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
+    && rm /tmp/Miniforge3-Linux-x86_64.sh
 ENV PATH /opt/conda/bin:$PATH
 COPY environment.yml /opt/openfold/environment.yml
 # installing into the base environment since the docker container wont do anything other than run openfold
-RUN conda env update -n base --file /opt/openfold/environment.yml && conda clean --all
+RUN mamba env update -n base --file /opt/openfold/environment.yml && mamba clean --all
+RUN export LD_LIBRARY_PATH=${CONDA_PREFIX}/lib:${LD_LIBRARY_PATH}
 COPY openfold /opt/openfold/openfold
 COPY scripts /opt/openfold/scripts
 COPY run_pretrained_openfold.py /opt/openfold/run_pretrained_openfold.py
 COPY train_openfold.py /opt/openfold/train_openfold.py
 COPY setup.py /opt/openfold/setup.py
-COPY lib/openmm.patch /opt/openfold/lib/openmm.patch
 RUN wget -q -P /opt/openfold/openfold/resources \
     https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
-RUN patch -p0 -d /opt/conda/lib/python3.9/site-packages/ < /opt/openfold/lib/openmm.patch
 WORKDIR /opt/openfold
 RUN python3 setup.py install
@@ -29,7 +29,7 @@ vice versa (see `scripts/convert_of_weights_to_jax.py`).
 OpenFold has the following advantages over the reference implementation:
-- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on (>= Ampere) GPUs.
+- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on Ampere or higher architecture GPUs.
 - **Inference on extremely long chains**, made possible by our implementation of low-memory attention
   ([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
   sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
@@ -49,37 +49,19 @@ and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (night
 installed on your system. You'll need `git-lfs` to download OpenFold parameters.
 Finally, some download scripts require `aria2c` and `aws`.
-For convenience, we provide a script that installs Miniconda locally, creates a
-`conda` virtual environment, installs all Python dependencies, and downloads
-useful resources, including both sets of model parameters. Run:
-
-```bash
-scripts/install_third_party_dependencies.sh
-```
-
-To activate the environment, run:
-
-```bash
-source scripts/activate_conda_env.sh
-```
-
-To deactivate it, run:
-
-```bash
-source scripts/deactivate_conda_env.sh
-```
-
-With the environment active, compile OpenFold's CUDA kernels with
-
-```bash
-python3 setup.py install
-```
-
-To install the HH-suite to `/usr/bin`, run
-
-```bash
-# scripts/install_hh_suite.sh
-```
+This package is currently supported for CUDA 11 and PyTorch 1.12.
+
+To install:
+1. Clone the repository, e.g. `git clone https://github.com/aqlaboratory/openfold.git`
+1. From the `openfold` repo:
+    - Create a [Mamba](https://github.com/conda-forge/miniforge/releases/latest/download/) environment, e.g.
+      `mamba env create -n openfold_env -f environment.yml`
+      Mamba is recommended, as the dependencies required by OpenFold are quite large and Mamba can speed up the process.
+    - Activate the environment, e.g. `conda activate openfold_env`
+1. Run `scripts/install_third_party_dependencies.sh` to configure kernels and folding resources.
+
+For some systems, it may help to append the Conda environment library path to `$LD_LIBRARY_PATH`. The `install_third_party_dependencies.sh` script does this once, but you may need to do this for each bash instance.
 ## Usage
@@ -233,6 +215,51 @@ efficent AlphaFold-Multimer more than double the time. Use the
 at once. The `run_pretrained_openfold.py` script can enable this config option with the
 `--long_sequence_inference` command line option
+#### SoloSeq Inference
+To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.
+
+For generating ESM-1b embeddings in bulk, use the provided script: `scripts/precompute_embeddings.py`. The script takes a directory of FASTA files (one sequence per file) and generates ESM-1b embeddings in the same format and directory structure as required by SoloSeq. Following is an example command to use the script:
+```bash
+python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
+```
+
+In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch), which contain details about the structures that you want to use as templates. If you do not place any such file, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you want to use templates, you need to pass the PDB mmCIF dataset to the command.
+
+Now, you are ready to run inference:
+```bash
+python run_pretrained_openfold.py \
+    fasta_dir \
+    data/pdb_mmcif/mmcif_files/ \
+    --use_precomputed_alignments embeddings_output_dir \
+    --output_dir ./ \
+    --model_device "cuda:0" \
+    --config_preset "seq_model_esm1b_ptm" \
+    --openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt
+```
+
+For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will be generated as well if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates, and only the generated ESM-1b embeddings will be used to predict the structure.
+```bash
+python3 run_pretrained_openfold.py \
+    fasta_dir \
+    data/pdb_mmcif/mmcif_files/ \
+    --output_dir ./ \
+    --model_device "cuda:0" \
+    --config_preset "seq_model_esm1b_ptm" \
+    --openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b_ptm.pt \
+    --uniref90_database_path data/uniref90/uniref90.fasta \
+    --pdb70_database_path data/pdb70/pdb70 \
+    --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
+    --hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
+    --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
+```
+
+For generating template information, you will need the UniRef90 and PDB70 databases and the JackHmmer and HHSearch binaries.
+
+SoloSeq allows you to use the same flags and optimizations as the MSA-based OpenFold. For example, you can skip relaxation using `--skip_relaxation`, save all model outputs using `--save_outputs`, and generate output files in mmCIF format using `--cif_output`.
+
+**NOTE:** Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.
 ### Training
 To train the model, you will first need to precompute protein alignments.
@@ -440,7 +467,7 @@ Please cite our paper:
 ```bibtex
 @article {Ahdritz2022.11.20.517210,
-  author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
+  author = {Ahdritz, Gustaf and Bouatta, Nazim and Floristean, Christina and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
 title = {{O}pen{F}old: {R}etraining {A}lpha{F}old2 yields new insights into its learning mechanisms and capacity for generalization},
 elocation-id = {2022.11.20.517210},
 year = {2022},
...
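The SoloSeq workflow above expects ESM-1b embeddings laid out the way `scripts/precompute_embeddings.py` writes them. As a rough sketch only (the file and directory naming here is an assumption; prefer the bundled script), a compatible `.pt` file for one sequence could be produced with the public `facebookresearch/esm` checkpoint, matching the `["representations"][33]` contract used by the data pipeline further down:

```python
import torch

# Hypothetical sketch: write one ESM-1b embedding file of the kind SoloSeq
# loads. Paths are illustrative; scripts/precompute_embeddings.py is the
# supported way to generate these in bulk.
model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
batch_converter = alphabet.get_batch_converter()
model.eval()

sequence = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"  # must be <= 1022 residues
_, _, tokens = batch_converter([("query", sequence)])
with torch.no_grad():
    out = model(tokens, repr_layers=[33])

# Drop the BOS/EOS tokens so the embedding is [num_res, 1280]
emb = out["representations"][33][0, 1:-1]
torch.save({"representations": {33: emb}}, "embeddings_output_dir/query/query.pt")
```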
-name: openfold_venv
+name: openfold-venv
 channels:
   - conda-forge
   - bioconda
   - pytorch
 dependencies:
-  - conda-forge::python=3.9
-  - conda-forge::setuptools=59.5.0
-  - conda-forge::pip
-  - conda-forge::openmm=7.5.1
-  - conda-forge::pdbfixer
-  - conda-forge::cudatoolkit==11.3.*
+  - python=3.9
+  - libgcc=7.2
+  - setuptools=59.5.0
+  - pip
+  - openmm=7.7
+  - pdbfixer
+  - cudatoolkit==11.3.*
+  - pytorch-lightning==1.5.10
+  - biopython==1.79
+  - numpy==1.21
+  - PyYAML==5.4.1
+  - requests
+  - scipy==1.7
+  - tqdm==4.62.2
+  - typing-extensions==3.10
+  - wandb==0.12.21
+  - modelcif==0.7
+  - awscli
+  - ml-collections
+  - aria2
+  - git
   - bioconda::hmmer==3.3.2
   - bioconda::hhsuite==3.3.0
   - bioconda::kalign2==2.04
   - pytorch::pytorch=1.12.*
   - pip:
-      - biopython==1.79
+      - deepspeed==0.12.2
       - dm-tree==0.1.6
-      - ml-collections==0.1.0
-      - numpy==1.21.2
-      - PyYAML==5.4.1
-      - requests==2.26.0
-      - scipy==1.7.1
-      - tqdm==4.62.2
-      - typing-extensions==3.10.0.2
-      - pytorch_lightning==1.5.10
-      - wandb==0.12.21
-      - modelcif==0.7
       - git+https://github.com/NVIDIA/dllogger.git
-      - git+https://github.com/microsoft/DeepSpeed.git
-      # TODO: Replace above when version becomes available
-      # - deepspeed==0.10.4
+      - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8
Index: simtk/openmm/app/topology.py
===================================================================
--- simtk.orig/openmm/app/topology.py
+++ simtk/openmm/app/topology.py
@@ -356,19 +356,35 @@
def isCyx(res):
names = [atom.name for atom in res._atoms]
return 'SG' in names and 'HG' not in names
+ # This function is used to prevent multiple di-sulfide bonds from being
+ # assigned to a given atom. This is a DeepMind modification.
+ def isDisulfideBonded(atom):
+ for b in self._bonds:
+ if (atom in b and b[0].name == 'SG' and
+ b[1].name == 'SG'):
+ return True
+
+ return False
cyx = [res for res in self.residues() if res.name == 'CYS' and isCyx(res)]
atomNames = [[atom.name for atom in res._atoms] for res in cyx]
for i in range(len(cyx)):
sg1 = cyx[i]._atoms[atomNames[i].index('SG')]
pos1 = positions[sg1.index]
+ candidate_distance, candidate_atom = 0.3*nanometers, None
for j in range(i):
sg2 = cyx[j]._atoms[atomNames[j].index('SG')]
pos2 = positions[sg2.index]
delta = [x-y for (x,y) in zip(pos1, pos2)]
distance = sqrt(delta[0]*delta[0] + delta[1]*delta[1] + delta[2]*delta[2])
- if distance < 0.3*nanometers:
- self.addBond(sg1, sg2)
+ if distance < candidate_distance and not isDisulfideBonded(sg2):
+ candidate_distance = distance
+ candidate_atom = sg2
+ # Assign bond to closest pair.
+ if candidate_atom:
+ self.addBond(sg1, candidate_atom)
+
+
class Chain(object):
"""A Chain object represents a chain within a Topology."""
@@ -152,9 +152,42 @@ def model_config(
         c.model.template.enabled = False
         c.model.heads.tm.enabled = True
         c.loss.tm.weight = 0.1
+    # SINGLE SEQUENCE EMBEDDING PRESETS
+    elif name == "seqemb_initial_training":
+        c.data.train.max_msa_clusters = 1
+        c.data.eval.max_msa_clusters = 1
+        c.data.train.max_distillation_msa_clusters = 1
+    elif name == "seqemb_finetuning":
+        c.data.train.max_msa_clusters = 1
+        c.data.eval.max_msa_clusters = 1
+        c.data.train.max_distillation_msa_clusters = 1
+        c.data.train.crop_size = 384
+        c.loss.violation.weight = 1.
+        c.loss.experimentally_resolved.weight = 0.01
+    elif name == "seq_model_esm1b":
+        c.data.common.use_templates = True
+        c.data.common.use_template_torsion_angles = True
+        c.model.template.enabled = True
+        c.data.predict.max_msa_clusters = 1
+    elif name == "seq_model_esm1b_ptm":
+        c.data.common.use_templates = True
+        c.data.common.use_template_torsion_angles = True
+        c.model.template.enabled = True
+        c.data.predict.max_msa_clusters = 1
+        c.model.heads.tm.enabled = True
+        c.loss.tm.weight = 0.1
     else:
         raise ValueError("Invalid model name")
+
+    if name.startswith("seq"):
+        # Tell the data pipeline that we will use sequence embeddings instead of MSAs.
+        c.data.seqemb_mode.enabled = True
+        c.globals.seqemb_mode_enabled = True
+        # In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
+        c.model.extra_msa.enabled = False
+        c.model.evoformer_stack.no_column_attention = True
+        c.update(seq_mode_config.copy_and_resolve_references())
+
     if long_sequence_inference:
         assert(not train)
         c.globals.offload_inference = True
@@ -189,6 +222,11 @@ c_m = mlc.FieldReference(256, field_type=int)
 c_t = mlc.FieldReference(64, field_type=int)
 c_e = mlc.FieldReference(64, field_type=int)
 c_s = mlc.FieldReference(384, field_type=int)
+
+# For seqemb mode, dimension size of the per-residue sequence embedding passed to the model.
+# In the current model, the dimension size is the ESM-1b dimension size, i.e. 1280.
+preemb_dim_size = mlc.FieldReference(1280, field_type=int)
+
 blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
 chunk_size = mlc.FieldReference(4, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64, field_type=int)
@@ -301,6 +339,9 @@ config = mlc.ConfigDict(
                 "use_templates": templates_enabled,
                 "use_template_torsion_angles": embed_template_torsion_angles,
             },
+            "seqemb_mode": {  # Configuration for sequence embedding mode
+                "enabled": False,  # If True, use seq emb instead of MSA
+            },
             "supervised": {
                 "clamp_prob": 0.9,
                 "supervised_features": [
@@ -365,6 +406,7 @@ config = mlc.ConfigDict(
         },
         # Recurring FieldReferences that can be changed globally here
         "globals": {
+            "seqemb_mode_enabled": False,  # Global flag for enabling seq emb mode
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
             # Use DeepSpeed memory-efficient attention kernel. Mutually
@@ -497,6 +539,7 @@ config = mlc.ConfigDict(
             "transition_n": 4,
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
+            "no_column_attention": False,
             "blocks_per_ckpt": blocks_per_ckpt,
             "clear_cache_between_blocks": False,
             "tune_chunk_size": tune_chunk_size,
@@ -618,3 +661,31 @@ config = mlc.ConfigDict(
         "ema": {"decay": 0.999},
     }
 )
+
+seq_mode_config = mlc.ConfigDict({
+    "data": {
+        "common": {
+            "feat": {
+                "seq_embedding": [NUM_RES, None],
+            },
+            "seqemb_features": [  # List of features to be generated in seqemb mode
+                "seq_embedding"
+            ],
+        },
+        "seqemb_mode": {  # Configuration for sequence embedding mode
+            "enabled": True,  # If True, use seq emb instead of MSA
+        },
+    },
+    "globals": {
+        "seqemb_mode_enabled": True,
+    },
+    "model": {
+        "preembedding_embedder": {  # Used in sequence embedding mode
+            "tf_dim": 22,
+            "preembedding_dim": preemb_dim_size,
+            "c_z": c_z,
+            "c_m": c_m,
+            "relpos_k": 32,
+        },
+    }
+})
\ No newline at end of file
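A quick sanity check of the new presets (a sketch; it assumes the flags wired up in the `name.startswith("seq")` branch above):

```python
from openfold.config import model_config

# Loading a "seq*" preset should flip all of the sequence-embedding switches.
cfg = model_config("seq_model_esm1b_ptm")
assert cfg.data.seqemb_mode.enabled
assert cfg.globals.seqemb_mode_enabled
assert not cfg.model.extra_msa.enabled
assert cfg.model.evoformer_stack.no_column_attention
```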
@@ -186,7 +186,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 mmcif=mmcif_object,
                 alignment_dir=alignment_dir,
                 chain_id=chain_id,
-                alignment_index=alignment_index
+                alignment_index=alignment_index,
+                seqemb_mode=self.config.seqemb_mode.enabled
             )
             return data
@@ -239,6 +240,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         elif(ext == ".core"):
             data = self.data_pipeline.process_core(
                 path, alignment_dir, alignment_index,
+                seqemb_mode=self.config.seqemb_mode.enabled,
             )
         elif(ext == ".pdb"):
             structure_index = None
@@ -251,6 +253,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 chain_id=chain_id,
                 alignment_index=alignment_index,
                 _structure_index=structure_index,
+                seqemb_mode=self.config.seqemb_mode.enabled,
             )
         else:
             raise ValueError("Extension branch missing")
@@ -260,6 +263,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             fasta_path=path,
             alignment_dir=alignment_dir,
             alignment_index=alignment_index,
+            seqemb_mode=self.config.seqemb_mode.enabled,
         )
         if(self._output_raw):
...
@@ -19,6 +19,7 @@ from multiprocessing import cpu_count
 from typing import Mapping, Optional, Sequence, Any
 import numpy as np
+import torch
 from openfold.data import templates, parsers, mmcif_parsing
 from openfold.data.templates import get_custom_template_features
@@ -260,6 +261,18 @@ def make_msa_features(
     return features
+
+# Generate 1-sequence MSA features having only the input sequence
+def make_dummy_msa_feats(input_sequence):
+    msas = [[input_sequence]]
+    deletion_matrices = [[[0 for _ in input_sequence]]]
+    msa_features = make_msa_features(
+        msas=msas,
+        deletion_matrices=deletion_matrices,
+    )
+    return msa_features
+
 def make_sequence_features_with_custom_template(
     sequence: str,
     mmcif_path: str,
@@ -627,11 +640,28 @@ class DataPipeline:
         return msa_features
+
+    # Load and process sequence embedding features
+    def _process_seqemb_features(self,
+        alignment_dir: str,
+    ) -> Mapping[str, Any]:
+        seqemb_features = {}
+        for f in os.listdir(alignment_dir):
+            path = os.path.join(alignment_dir, f)
+            ext = os.path.splitext(f)[-1]
+            if (ext == ".pt"):
+                # Load embedding file
+                seqemb_data = torch.load(path)
+                seqemb_features["seq_embedding"] = seqemb_data["representations"][33]
+        return seqemb_features
+
     def process_fasta(
         self,
         fasta_path: str,
         alignment_dir: str,
         alignment_index: Optional[str] = None,
+        seqemb_mode: bool = False,
     ) -> FeatureDict:
         """Assembles features for a single sequence in a FASTA file"""
         with open(fasta_path) as f:
@@ -658,12 +688,19 @@ class DataPipeline:
             num_res=num_res,
         )
-        msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
+        sequence_embedding_features = {}
+        # If using seqemb mode, generate dummy MSA features using just the sequence
+        if seqemb_mode:
+            msa_features = make_dummy_msa_feats(input_sequence)
+            sequence_embedding_features = self._process_seqemb_features(alignment_dir)
+        else:
+            msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
         return {
             **sequence_features,
             **msa_features,
-            **template_features
+            **template_features,
+            **sequence_embedding_features,
         }
     def process_mmcif(
@@ -672,6 +709,7 @@ class DataPipeline:
         alignment_dir: str,
         chain_id: Optional[str] = None,
         alignment_index: Optional[str] = None,
+        seqemb_mode: bool = False,
     ) -> FeatureDict:
         """
         Assembles features for a specific chain in an mmCIF object.
@@ -696,10 +734,16 @@ class DataPipeline:
             self.template_featurizer,
             query_release_date=to_date(mmcif.header["release_date"])
         )
-        msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
-        return {**mmcif_feats, **template_features, **msa_features}
+        sequence_embedding_features = {}
+        # If using seqemb mode, generate dummy MSA features using just the sequence
+        if seqemb_mode:
+            msa_features = make_dummy_msa_feats(input_sequence)
+            sequence_embedding_features = self._process_seqemb_features(alignment_dir)
+        else:
+            msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
+
+        return {**mmcif_feats, **template_features, **msa_features, **sequence_embedding_features}
     def process_pdb(
         self,
@@ -709,6 +753,7 @@ class DataPipeline:
         chain_id: Optional[str] = None,
         _structure_index: Optional[str] = None,
         alignment_index: Optional[str] = None,
+        seqemb_mode: bool = False,
     ) -> FeatureDict:
         """
         Assembles features for a protein in a PDB file.
@@ -742,15 +787,22 @@ class DataPipeline:
             self.template_featurizer,
         )
-        msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
+        sequence_embedding_features = {}
+        # If in sequence embedding mode, generate dummy MSA features using just the input sequence
+        if seqemb_mode:
+            msa_features = make_dummy_msa_feats(input_sequence)
+            sequence_embedding_features = self._process_seqemb_features(alignment_dir)
+        else:
+            msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
-        return {**pdb_feats, **template_features, **msa_features}
+        return {**pdb_feats, **template_features, **msa_features, **sequence_embedding_features}
     def process_core(
         self,
         core_path: str,
         alignment_dir: str,
         alignment_index: Optional[str] = None,
+        seqemb_mode: bool = False,
     ) -> FeatureDict:
         """
         Assembles features for a protein in a ProteinNet .core file.
@@ -770,9 +822,15 @@ class DataPipeline:
             self.template_featurizer,
         )
-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        sequence_embedding_features = {}
+        # If in sequence embedding mode, generate dummy MSA features using just the input sequence
+        if seqemb_mode:
+            msa_features = make_dummy_msa_feats(input_sequence)
+            sequence_embedding_features = self._process_seqemb_features(alignment_dir)
+        else:
+            msa_features = self._process_msa_feats(alignment_dir, input_sequence)
-        return {**core_feats, **template_features, **msa_features}
+        return {**core_feats, **template_features, **msa_features, **sequence_embedding_features}
     def process_multiseq_fasta(self,
         fasta_path: str,
...
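In seqemb mode the pipeline swaps real alignments for a one-row MSA. A small sketch of the resulting features (key names and shapes inferred from `make_msa_features` above; treat them as assumptions):

```python
from openfold.data.data_pipeline import make_dummy_msa_feats

# The dummy "MSA" contains exactly one row, the query sequence itself, with an
# all-zero deletion matrix, so downstream featurization works unchanged.
feats = make_dummy_msa_feats("MKTAYIAKQR")
print(feats["msa"].shape)                  # (1, 10): one alignment row
print(feats["deletion_matrix_int"].shape)  # (1, 10): no deletions
```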
@@ -40,9 +40,11 @@ def np_to_tensor_dict(
     Returns:
         A dictionary of features mapping feature names to features. Only the given
         features are returned, all other ones are filtered out.
     """
+    # torch generates warnings if feature is already a torch Tensor
+    to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
     tensor_dict = {
-        k: torch.tensor(v) for k, v in np_example.items() if k in features
+        k: to_tensor(v) for k, v in np_example.items() if k in features
     }
     return tensor_dict
@@ -61,6 +63,10 @@ def make_data_config(
     feature_names = cfg.common.unsupervised_features
+
+    # Add seqemb related features if using seqemb mode.
+    if cfg.seqemb_mode.enabled:
+        feature_names += cfg.common.seqemb_features
+
     if cfg.common.use_templates:
         feature_names += cfg.common.template_features
...
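The `to_tensor` helper above sidesteps PyTorch's warning about calling `torch.tensor()` on something that is already a tensor. A minimal demonstration:

```python
import numpy as np
import torch

# Mirrors the helper: NumPy arrays are converted, while existing tensors are
# cloned and detached instead of re-wrapped (which would warn and copy).
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()

print(to_tensor(np.zeros(3)))     # array -> tensor
print(to_tensor(torch.zeros(3)))  # tensor -> detached clone, no warning
```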
@@ -139,6 +139,100 @@ class InputEmbedder(nn.Module):
         return msa_emb, pair_emb
+
+class PreembeddingEmbedder(nn.Module):
+    """
+    Embeds the sequence pre-embedding passed to the model and the target_feat features.
+    """
+
+    def __init__(
+        self,
+        tf_dim: int,
+        preembedding_dim: int,
+        c_z: int,
+        c_m: int,
+        relpos_k: int,
+        **kwargs,
+    ):
+        """
+        Args:
+            tf_dim:
+                End channel dimension of the incoming target features
+            preembedding_dim:
+                End channel dimension of the incoming embeddings
+            c_z:
+                Pair embedding dimension
+            c_m:
+                Single-Seq embedding dimension
+            relpos_k:
+                Window size used in relative position encoding
+        """
+        super(PreembeddingEmbedder, self).__init__()
+
+        self.tf_dim = tf_dim
+        self.preembedding_dim = preembedding_dim
+        self.c_z = c_z
+        self.c_m = c_m
+
+        self.linear_tf_m = Linear(tf_dim, c_m)
+        self.linear_preemb_m = Linear(self.preembedding_dim, c_m)
+        self.linear_preemb_z_i = Linear(self.preembedding_dim, c_z)
+        self.linear_preemb_z_j = Linear(self.preembedding_dim, c_z)
+
+        # Relative Positional Encoding
+        self.relpos_k = relpos_k
+        self.no_bins = 2 * relpos_k + 1
+        self.linear_relpos = Linear(self.no_bins, c_z)
+
+    def relpos(self, ri: torch.Tensor):
+        """
+        Computes relative positional encodings
+
+        Args:
+            ri:
+                "residue_index" feature of shape [*, N]
+        Returns:
+            Relative positional encoding of protein using the
+            residue_index feature
+        """
+        d = ri[..., None] - ri[..., None, :]
+        boundaries = torch.arange(
+            start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
+        )
+        reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
+        d = d[..., None] - reshaped_bins
+        d = torch.abs(d)
+        d = torch.argmin(d, dim=-1)
+        d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
+        d = d.to(ri.dtype)
+        return self.linear_relpos(d)
+
+    def forward(
+        self,
+        tf: torch.Tensor,
+        ri: torch.Tensor,
+        preemb: torch.Tensor,
+        inplace_safe: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        tf_m = (
+            self.linear_tf_m(tf)
+            .unsqueeze(-3)
+        )
+        preemb_emb = self.linear_preemb_m(preemb[..., None, :, :]) + tf_m
+        preemb_emb_i = self.linear_preemb_z_i(preemb)
+        preemb_emb_j = self.linear_preemb_z_j(preemb)
+        pair_emb = self.relpos(ri.type(preemb_emb_i.dtype))
+        pair_emb = add(pair_emb,
+                       preemb_emb_i[..., None, :],
+                       inplace=inplace_safe)
+        pair_emb = add(pair_emb,
+                       preemb_emb_j[..., None, :, :],
+                       inplace=inplace_safe)
+
+        return preemb_emb, pair_emb
+
 class RecyclingEmbedder(nn.Module):
     """
     Embeds the output of an iteration of the model for recycling.
...
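A shape-level sketch of the new embedder's contract (tf_dim, preembedding_dim, and relpos_k are taken from `seq_mode_config` above; c_m and c_z are the stock OpenFold channel sizes, and the batch dimension is omitted):

```python
import torch
from openfold.model.embedders import PreembeddingEmbedder

# Dimensions follow seq_mode_config: 22-dim target features, 1280-dim ESM-1b
# embeddings, relative-position window of 32.
embedder = PreembeddingEmbedder(tf_dim=22, preembedding_dim=1280, c_z=128, c_m=256, relpos_k=32)

n = 16
tf = torch.randn(n, 22)        # target_feat
ri = torch.arange(n).float()   # residue_index
preemb = torch.randn(n, 1280)  # per-residue sequence embedding

m, z = embedder(tf, ri, preemb)
print(m.shape, z.shape)  # torch.Size([1, 16, 256]) torch.Size([16, 16, 128])
```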
@@ -87,7 +87,6 @@ class MSATransition(nn.Module):
             no_batch_dims=len(m.shape[:-2]),
         )
-
     def forward(
         self,
         m: torch.Tensor,
@@ -326,6 +325,7 @@ class EvoformerBlock(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
+        no_column_attention: bool,
         inf: float,
         eps: float,
     ):
@@ -339,12 +339,15 @@ class EvoformerBlock(nn.Module):
             inf=inf,
         )
-        self.msa_att_col = MSAColumnAttention(
-            c_m,
-            c_hidden_msa_att,
-            no_heads_msa,
-            inf=inf,
-        )
+        # Specifically, seqemb mode does not use column attention
+        self.no_column_attention = no_column_attention
+        if not self.no_column_attention:
+            self.msa_att_col = MSAColumnAttention(
+                c_m,
+                c_hidden_msa_att,
+                no_heads_msa,
+                inf=inf,
+            )
         self.msa_dropout_layer = DropoutRowwise(msa_dropout)
@@ -402,18 +405,20 @@ class EvoformerBlock(nn.Module):
             ),
             inplace=inplace_safe,
         )
-        m = add(m,
-            self.msa_att_col(
-                m,
-                mask=msa_mask,
-                chunk_size=chunk_size,
-                use_deepspeed_evo_attention=use_deepspeed_evo_attention,
-                use_lma=use_lma,
-                use_flash=use_flash,
-            ),
-            inplace=inplace_safe,
-        )
+        # Specifically, column attention is not used in seqemb mode.
+        if not self.no_column_attention:
+            m = add(m,
+                self.msa_att_col(
+                    m,
+                    mask=msa_mask,
+                    chunk_size=chunk_size,
+                    use_deepspeed_evo_attention=use_deepspeed_evo_attention,
+                    use_lma=use_lma,
+                    use_flash=use_flash,
+                ),
+                inplace=inplace_safe,
+            )
         if(not inplace_safe):
             input_tensors = [m, input_tensors[1]]
@@ -605,6 +610,7 @@ class EvoformerStack(nn.Module):
         msa_dropout: float,
         pair_dropout: float,
         blocks_per_ckpt: int,
+        no_column_attention: bool,
         inf: float,
         eps: float,
         clear_cache_between_blocks: bool = False,
@@ -642,6 +648,9 @@ class EvoformerStack(nn.Module):
                 Dropout used for pair activations
             blocks_per_ckpt:
                 Number of Evoformer blocks in each activation checkpoint
+            no_column_attention:
+                When True, doesn't use column attention. Required for running
+                sequence embedding mode
             clear_cache_between_blocks:
                 Whether to clear CUDA's GPU memory cache between blocks of the
                 stack. Slows down each block but can reduce fragmentation
@@ -668,6 +677,7 @@ class EvoformerStack(nn.Module):
                 transition_n=transition_n,
                 msa_dropout=msa_dropout,
                 pair_dropout=pair_dropout,
+                no_column_attention=no_column_attention,
                 inf=inf,
                 eps=eps,
             )
...
@@ -24,6 +24,7 @@ from openfold.model.embedders import (
     TemplateAngleEmbedder,
     TemplatePairEmbedder,
     ExtraMSAEmbedder,
+    PreembeddingEmbedder,
 )
 from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
 from openfold.model.heads import AuxiliaryHeads
@@ -71,11 +72,19 @@ class AlphaFold(nn.Module):
         self.config = config.model
         self.template_config = self.config.template
         self.extra_msa_config = self.config.extra_msa
+        self.seqemb_mode = config.globals.seqemb_mode_enabled
         # Main trunk + structure module
-        self.input_embedder = InputEmbedder(
-            **self.config["input_embedder"],
-        )
+        # If using seqemb mode, embed the sequence embeddings passed
+        # to the model ("preembeddings") instead of embedding the sequence
+        if self.seqemb_mode:
+            self.input_embedder = PreembeddingEmbedder(
+                **self.config["preembedding_embedder"],
+            )
+        else:
+            self.input_embedder = InputEmbedder(
+                **self.config["input_embedder"],
+            )
         self.recycling_embedder = RecyclingEmbedder(
             **self.config["recycling_embedder"],
         )
@@ -238,17 +247,27 @@ class AlphaFold(nn.Module):
         seq_mask = feats["seq_mask"]
         pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
         msa_mask = feats["msa_mask"]
-        ## Initialize the MSA and pair representations
-        # m: [*, S_c, N, C_m]
-        # z: [*, N, N, C_z]
-        m, z = self.input_embedder(
-            feats["target_feat"],
-            feats["residue_index"],
-            feats["msa_feat"],
-            inplace_safe=inplace_safe,
-        )
+        if self.seqemb_mode:
+            ## Initialize the SingleSeq and pair representations
+            # m: [*, 1, N, C_m]
+            # z: [*, N, N, C_z]
+            m, z = self.input_embedder(
+                feats["target_feat"],
+                feats["residue_index"],
+                feats["seq_embedding"]
+            )
+        else:
+            ## Initialize the MSA and pair representations
+            # m: [*, S_c, N, C_m]
+            # z: [*, N, N, C_z]
+            m, z = self.input_embedder(
+                feats["target_feat"],
+                feats["residue_index"],
+                feats["msa_feat"],
+                inplace_safe=inplace_safe,
+            )
         # Unpack the recycling embeddings. Removing them from the list allows
         # them to be freed further down in this function, saving memory
...
@@ -28,18 +28,10 @@ import openfold.utils.loss as loss
 from openfold.np.relax import cleanup, utils
 import ml_collections
 import numpy as np
-try:
-    # openmm >= 7.6
-    import openmm
-    from openmm import unit
-    from openmm import app as openmm_app
-    from openmm.app.internal.pdbstructure import PdbStructure
-except ImportError:
-    # openmm < 7.6 (requires DeepMind patch)
-    from simtk import openmm
-    from simtk import unit
-    from simtk.openmm import app as openmm_app
-    from simtk.openmm.app.internal.pdbstructure import PdbStructure
+import openmm
+from openmm import unit
+from openmm import app as openmm_app
+from openmm.app.internal.pdbstructure import PdbStructure
 ENERGY = unit.kilocalories_per_mole
 LENGTH = unit.angstroms
...
@@ -20,14 +20,8 @@ cases like removing chains of length one (see clean_structure).
 import io
 import pdbfixer
-try:
-    # openmm >= 7.6
-    from openmm import app
-    from openmm.app import element
-except ImportError:
-    # openmm < 7.6 (requires DeepMind patch)
-    from simtk.openmm import app
-    from simtk.openmm.app import element
+from openmm import app
+from openmm.app import element
 def fix_pdb(pdbfile, alterations_info):
...
@@ -18,14 +18,8 @@ import io
 from openfold.np import residue_constants
 from Bio import PDB
 import numpy as np
-try:
-    # openmm >= 7.6
-    from openmm import app as openmm_app
-    from openmm.app.internal.pdbstructure import PdbStructure
-except ImportError:
-    # openmm < 7.6 (requires DeepMind patch)
-    from simtk.openmm import app as openmm_app
-    from simtk.openmm.app.internal.pdbstructure import PdbStructure
+from openmm import app as openmm_app
+from openmm.app.internal.pdbstructure import PdbStructure
 def overwrite_pdb_coordinates(pdb_str: str, pos) -> str:
...
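With `environment.yml` now pinning `openmm=7.7`, the top-level `openmm` namespace is always available, so the `simtk` fallbacks above are dead code and the deleted DeepMind patch is no longer applied. A sanity check, assuming the pinned environment:

```python
# openmm >= 7.6 ships the top-level "openmm" package directly; the simtk shim
# and the lib/openmm.patch workaround are not needed with the openmm=7.7 pin.
import openmm
from openmm import app, unit

print(openmm.Platform.getOpenMMVersion())  # expect "7.7" in the pinned env
```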
@@ -159,7 +159,7 @@ def run_model(model, batch, tag, output_dir):
         out = model(batch)
     inference_time = time.perf_counter() - t
     logger.info(f"Inference time: {inference_time}")
-    update_timings({"inference": inference_time}, os.path.join(output_dir, "timings.json"))
+    update_timings({tag: {"inference": inference_time}}, os.path.join(output_dir, "timings.json"))
     model.config.template.enabled = template_enabled
...
@@ -55,6 +55,7 @@ from openfold.utils.trace_utils import (
     pad_feature_dict_seq,
     trace_model_,
 )
+from scripts.precompute_embeddings import EmbeddingGenerator
 from scripts.utils import add_data_args
@@ -73,17 +74,29 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
         os.makedirs(local_alignment_dir)
-        alignment_runner = data_pipeline.AlignmentRunner(
-            jackhmmer_binary_path=args.jackhmmer_binary_path,
-            hhblits_binary_path=args.hhblits_binary_path,
-            hhsearch_binary_path=args.hhsearch_binary_path,
-            uniref90_database_path=args.uniref90_database_path,
-            mgnify_database_path=args.mgnify_database_path,
-            bfd_database_path=args.bfd_database_path,
-            uniclust30_database_path=args.uniclust30_database_path,
-            pdb70_database_path=args.pdb70_database_path,
-            no_cpus=args.cpus,
-        )
+        # In seqemb mode, use AlignmentRunner only to generate templates
+        if args.use_single_seq_mode:
+            alignment_runner = data_pipeline.AlignmentRunner(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhsearch_binary_path=args.hhsearch_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                pdb70_database_path=args.pdb70_database_path,
+                no_cpus=args.cpus,
+            )
+            embedding_generator = EmbeddingGenerator()
+            embedding_generator.run(tmp_fasta_path, alignment_dir)
+        else:
+            alignment_runner = data_pipeline.AlignmentRunner(
+                jackhmmer_binary_path=args.jackhmmer_binary_path,
+                hhblits_binary_path=args.hhblits_binary_path,
+                hhsearch_binary_path=args.hhsearch_binary_path,
+                uniref90_database_path=args.uniref90_database_path,
+                mgnify_database_path=args.mgnify_database_path,
+                bfd_database_path=args.bfd_database_path,
+                uniclust30_database_path=args.uniclust30_database_path,
+                pdb70_database_path=args.pdb70_database_path,
+                no_cpus=args.cpus,
+            )
         alignment_runner.run(
             tmp_fasta_path, local_alignment_dir
         )
@@ -116,7 +129,9 @@ def generate_feature_dict(
         local_alignment_dir = os.path.join(alignment_dir, tag)
         feature_dict = data_processor.process_fasta(
-            fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
+            fasta_path=tmp_fasta_path,
+            alignment_dir=local_alignment_dir,
+            seqemb_mode=args.use_single_seq_mode,
         )
     else:
         with open(tmp_fasta_path, "w") as fp:
@@ -140,6 +155,8 @@ def main(args):
     # Create the output directory
     os.makedirs(args.output_dir, exist_ok=True)
+    if args.config_preset.startswith("seq"):
+        args.use_single_seq_mode = True
     config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
     if(args.trace_model):
@@ -314,6 +331,10 @@ if __name__ == "__main__":
         help="""Path to alignment directory. If provided, alignment computation
         is skipped and database path arguments are ignored."""
     )
+    parser.add_argument(
+        "--use_single_seq_mode", action="store_true", default=False,
+        help="""Use single sequence embeddings instead of MSAs."""
+    )
     parser.add_argument(
         "--output_dir", type=str, default=os.getcwd(),
         help="""Name of the directory in which to output the prediction""",
...
 #!/bin/bash
-CONDA_INSTALL_URL=${CONDA_INSTALL_URL:-"https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh"}
-
-source scripts/vars.sh
-
-# Install Miniconda locally
-rm -rf lib/conda
-rm -f /tmp/Miniconda3-latest-Linux-x86_64.sh
-wget -P /tmp \
-    "${CONDA_INSTALL_URL}" \
-    && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p lib/conda \
-    && rm /tmp/Miniconda3-latest-Linux-x86_64.sh
-
-# Grab conda-only packages
-export PATH=lib/conda/bin:$PATH
-lib/conda/bin/python3 -m pip install nvidia-pyindex
-
-conda env create --name=${ENV_NAME} -f environment.yml
-source scripts/activate_conda_env.sh
-
-echo "Attempting to install FlashAttention"
-git clone https://github.com/HazyResearch/flash-attention
-CUR_DIR=$PWD
-cd flash-attention
-git checkout 5b838a8bef
-python3 setup.py install
-cd $CUR_DIR
-
-echo "Attempting to download CUTLASS, required for Deepspeed Evoformer attention kernel"
-git clone https://github.com/NVIDIA/cutlass.git
-conda env config vars set CUTLASS_PATH=$PWD/cutlass
-source scripts/activate_conda_env.sh
-
-# Install DeepMind's OpenMM patch
-OPENFOLD_DIR=$PWD
-pushd lib/conda/envs/$ENV_NAME/lib/python3.9/site-packages/ \
-    && patch -p0 < $OPENFOLD_DIR/lib/openmm.patch \
-    && popd
-
 # Download folding resources
-wget --no-check-certificate -P openfold/resources \
+wget -N --no-check-certificate -P openfold/resources \
     https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
 # Certain tests need access to this file
 mkdir -p tests/test_data/alphafold/common
 ln -rs openfold/resources/stereo_chemical_props.txt tests/test_data/alphafold/common
-echo "Downloading OpenFold parameters..."
-bash scripts/download_openfold_params.sh openfold/resources
-
-echo "Downloading AlphaFold parameters..."
-bash scripts/download_alphafold_params.sh openfold/resources
-
 # Decompress test data
-gunzip tests/test_data/sample_feats.pickle.gz
+gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.pickle
+
+python setup.py install
+
+export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
+
+echo "Attempting to download CUTLASS, required for Deepspeed Evoformer attention kernel"
+git clone https://github.com/NVIDIA/cutlass --depth 1
+conda env config vars set CUTLASS_PATH=$PWD/cutlass
+
+# This setting is used to fix a worker assignment issue during data loading
+conda env config vars set KMP_AFFINITY=none
+
+# Reactivate env so that the above environment variables take effect
+conda activate $CONDA_PREFIX