![header](imgs/of_banner.png)
_Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental structure of PDB 7KDX, chain B._

# OpenFold

A faithful but trainable PyTorch reproduction of DeepMind's
[AlphaFold 2](https://github.com/deepmind/alphafold).

## Features

OpenFold carefully reproduces (almost) all of the features of the original open
source inference code (v2.0.1). The sole exception is model ensembling, which
fared poorly in DeepMind's own ablation testing and is being phased out in future
DeepMind experiments. It is omitted here to reduce clutter. In cases where the
*Nature* paper differs from the source code, we always defer to the latter.

OpenFold is trainable in full precision or `bfloat16`, with or without DeepSpeed,
and we've trained it from scratch, matching the performance of the original.
We've publicly released model weights and our training data (some 400,000
MSAs and PDB70 template hit files) under a permissive license. Model weights
are available via scripts in this repository, while the MSAs are hosted by the
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
Try running inference yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).

OpenFold also supports inference using AlphaFold's official parameters.
Conversely, OpenFold weights can be converted for use with AlphaFold (see
`scripts/convert_of_weights_to_jax.py`).

OpenFold has the following advantages over the reference implementation:

- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on Ampere or newer GPUs.
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
  ([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
  sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
  kernels support in-place attention during inference and training. They use
  4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
  implementations, respectively.
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
- **FlashAttention** support greatly speeds up MSA attention.

## Installation (Linux)

All Python dependencies are specified in `environment.yml`. For producing sequence
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
and either `jackhmmer` or [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)
installed on your system. You'll need `git-lfs` to download OpenFold parameters.
Finally, some download scripts require `aria2c` and `aws`.

For convenience, we provide a script that installs Miniconda locally, creates a
`conda` virtual environment, installs all Python dependencies, and downloads
useful resources, including both sets of model parameters. Run:

```bash
scripts/install_third_party_dependencies.sh
```

To activate the environment, run:

```bash
source scripts/activate_conda_env.sh
```

To deactivate it, run:

```bash
source scripts/deactivate_conda_env.sh
```

With the environment active, compile OpenFold's CUDA kernels with:

```bash
python3 setup.py install
```

To install the HH-suite to `/usr/bin`, run:

```bash
# Writes to /usr/bin, so root privileges are required
sudo bash scripts/install_hh_suite.sh
```

## Usage

To download the databases used to train OpenFold and AlphaFold, run:

```bash
bash scripts/download_data.sh data/
```

You have two choices for downloading protein databases, depending on whether
you want to use DeepMind's MSA generation pipeline (with HMMER and HHblits) or
[ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster
MMseqs2 instead. For the former, run:

```bash
bash scripts/download_alphafold_dbs.sh data/
```

For the latter, run:

```bash
bash scripts/download_mmseqs_dbs.sh data/    # downloads .tar files
bash scripts/prep_mmseqs_dbs.sh data/        # unpacks and preps the databases
```

Make sure to run the latter command on the machine that will be used for MSA
generation (the script estimates how the precomputed database index used by
MMseqs2 should be split according to the memory available on the system).
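
The splitting decision comes down to simple arithmetic. The sketch below shows one plausible version of it, assuming the goal is the smallest number of equal index splits that each fit in a fixed share of RAM. `estimate_index_splits` and its `usable_fraction` parameter are hypothetical and are not taken from `scripts/prep_mmseqs_dbs.sh`.

```python
import math


def estimate_index_splits(index_size_bytes: int,
                          available_memory_bytes: int,
                          usable_fraction: float = 0.8) -> int:
    """Hypothetical helper: the smallest number of equal index splits such
    that a single split fits in `usable_fraction` of the available memory."""
    if index_size_bytes <= 0:
        raise ValueError("index size must be positive")
    budget = available_memory_bytes * usable_fraction
    if budget <= 0:
        raise ValueError("no memory budget available")
    return max(1, math.ceil(index_size_bytes / budget))


GiB = 1024 ** 3
# A 600 GiB index on a host with 256 GiB of RAM: 600 / (256 * 0.8) = 2.93, so 3 splits.
print(estimate_index_splits(600 * GiB, 256 * GiB))
```

This is also why the prep step should run on the MSA-generation machine: the answer depends on that host's memory.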

Alternatively, you can use raw MSAs from our aforementioned MSA database or
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
the latter database, use `scripts/prep_proteinnet_msas.py` to convert the data
into a format recognized by the OpenFold parser. The resulting directory
becomes the `alignment_dir` used in subsequent steps. Use
`scripts/unpack_proteinnet.py` to extract `.core` files from ProteinNet text
files.

For both inference and training, the model's hyperparameters can be tuned from
`openfold/config.py`. Of course, if you plan to perform inference using 
DeepMind's pretrained parameters, you will only be able to make changes that
do not affect the shapes of model parameters. For an example of initializing
the model, consult `run_pretrained_openfold.py`.
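
One way to check whether a config change keeps pretrained weights loadable is to compare parameter names and shapes. The sketch below works on plain `name -> shape` dictionaries; with PyTorch, such a map could be built from a model as `{k: tuple(v.shape) for k, v in model.state_dict().items()}`. The helper `shape_mismatches` and the parameter names in the example are hypothetical.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def shape_mismatches(pretrained: Dict[str, Shape],
                     configured: Dict[str, Shape]) -> List[str]:
    """Hypothetical helper: list every parameter that would block loading
    pretrained weights into a model built from a modified config."""
    problems = []
    for name in sorted(pretrained.keys() | configured.keys()):
        if name not in configured:
            problems.append(f"missing in new model: {name}")
        elif name not in pretrained:
            problems.append(f"not in checkpoint: {name}")
        elif pretrained[name] != configured[name]:
            problems.append(
                f"shape mismatch for {name}: {pretrained[name]} vs {configured[name]}")
    return problems


# Illustrative names and shapes only: changing a layer's width changes its weight shape.
ckpt = {"evoformer.linear.weight": (256, 256), "evoformer.linear.bias": (256,)}
new = {"evoformer.linear.weight": (256, 128), "evoformer.linear.bias": (256,)}
print(shape_mismatches(ckpt, new))
```

An empty list means the change only touches settings (such as chunk sizes) that do not alter parameter shapes.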

### Inference

To run inference on a sequence or multiple sequences using a set of pretrained
parameters (the example below uses an OpenFold checkpoint), run e.g.:

```bash
python3 run_pretrained_openfold.py \
    fasta_dir \
    data/pdb_mmcif/mmcif_files/ \
    --uniref90_database_path data/uniref90/uniref90.fasta \
    --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
    --pdb70_database_path data/pdb70/pdb70 \
    --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
    --output_dir ./ \
    --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
    --model_device "cuda:0" \
    --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
    --hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
    --hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
    --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
    --config_preset "model_1_ptm" \
    --openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
```

where `data` is the same directory as in the previous step. If `jackhmmer`,
`hhblits`, `hhsearch` and `kalign` are available at the default path of
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
If you've already computed alignments for the query, you have the option to
skip the expensive alignment computation here with
`--use_precomputed_alignments`.

`--openfold_checkpoint_path` and `--jax_param_path` accept comma-delimited lists
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
respectively. For a breakdown of the differences between the parameter files,
see the README downloaded to `openfold/resources/openfold_params/`. Since
OpenFold was trained under a newer training schedule than the one from which the
`model_n` config presets are derived, there is no clean correspondence between
`config_preset` settings and OpenFold checkpoints; the only constraint is that `*_ptm`
checkpoints must be run with `*_ptm` config presets.
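
To illustrate these two rules, the sketch below splits a comma-delimited flag value and rejects pTM checkpoints paired with a non-`*_ptm` preset. Deciding whether a file is a pTM checkpoint from its name is an assumption made for the example (it matches `finetuning_ptm_2.pt` and AlphaFold's `params_model_1_ptm.npz`); the helpers and the file `my_other_run.pt` are hypothetical and not part of OpenFold.

```python
from pathlib import Path
from typing import List


def parse_checkpoint_list(flag_value: str) -> List[str]:
    """Split a comma-delimited flag value, ignoring whitespace and empty items."""
    return [part.strip() for part in flag_value.split(",") if part.strip()]


def is_ptm_checkpoint(path: str) -> bool:
    """Naming heuristic (assumption): 'ptm' is an underscore-separated token
    of the file name, as in finetuning_ptm_2.pt or params_model_1_ptm.npz."""
    return "ptm" in Path(path).stem.split("_")


def check_preset(config_preset: str, checkpoints: List[str]) -> None:
    """Enforce the documented rule: *_ptm checkpoints need a *_ptm preset."""
    if config_preset.endswith("_ptm"):
        return
    offending = [c for c in checkpoints if is_ptm_checkpoint(c)]
    if offending:
        raise ValueError(
            f"{config_preset!r} is not a *_ptm preset, but these checkpoints are: {offending}")


ckpts = parse_checkpoint_list(
    "openfold/resources/openfold_params/finetuning_ptm_2.pt, my_other_run.pt")
check_preset("model_1_ptm", ckpts)  # OK: pTM preset
try:
    check_preset("model_1", ckpts)
except ValueError as err:
    print(err)
```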

Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
is enabled by default in inference mode. To disable it, set `globals.chunk_size`
to `None` in the config. If a value is specified, OpenFold will attempt to
tune the chunk size dynamically, treating the value in the config as a
minimum. This tuning process keeps runtimes consistently fast regardless of
input sequence length, but it also introduces some runtime variability, which
may be undesirable for certain users. It is also recommended to disable this
feature for very long chains (see below). To do so, set the
`tune_chunk_size` option in the config to `False`.
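
Conceptually, chunking applies a module to slices of a batch-like dimension and concatenates the results, so peak memory scales with the chunk size rather than with the full input. The plain-Python sketch below illustrates both that and the tuning idea. `fits` stands in for "run one chunk and check that it does not run out of memory", and the doubling search is an assumption, not necessarily OpenFold's tuning strategy.

```python
from typing import Callable, List, Optional, Sequence


def chunked_apply(fn: Callable[[Sequence[float]], List[float]],
                  rows: Sequence[float],
                  chunk_size: Optional[int]) -> List[float]:
    """Apply a row-wise `fn` to slices of `rows`; None disables chunking."""
    if chunk_size is None:
        return fn(rows)
    out: List[float] = []
    for start in range(0, len(rows), chunk_size):
        out.extend(fn(rows[start:start + chunk_size]))
    return out


def tune_chunk_size(min_chunk_size: int, max_chunk_size: int,
                    fits: Callable[[int], bool]) -> int:
    """Double the chunk size, starting from the configured minimum, for as long
    as the next size still fits and does not exceed `max_chunk_size`.

    Assumes `fits` is monotonic: if a size fails, every larger size fails too.
    """
    chunk_size = min_chunk_size
    while chunk_size * 2 <= max_chunk_size and fits(chunk_size * 2):
        chunk_size *= 2
    return chunk_size


def square_all(xs: Sequence[float]) -> List[float]:
    return [x * x for x in xs]


rows = [float(i) for i in range(10)]
# Chunking should not change the result.
assert chunked_apply(square_all, rows, 4) == chunked_apply(square_all, rows, None)

# Pretend any chunk larger than 256 rows runs out of memory.
print(tune_chunk_size(4, 4096, fits=lambda c: c <= 256))  # 256
```

Every `fits` probe costs a trial run, which is one reason tuning can waste time on very long chains.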

For large-scale batch inference, we offer an optional tracing mode, which
massively improves runtimes at the cost of a lengthy model compilation process.
To enable it, add `--trace_model` to the inference command.

To speed up inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
in the config. Note that it appears to work best for sequences with fewer than 1000 residues.

Input FASTA files containing multiple sequences are treated as complexes. In
this case, the inference script runs AlphaFold-Gap, a hack proposed
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer). To
run inference with AlphaFold-Multimer, use the (experimental) `multimer` branch 
instead.
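
The core of the AlphaFold-Gap trick is to treat the chains as one long sequence while inserting a large jump in the residue index at every chain boundary, so that residues on different chains look far apart in sequence. The sketch below parses a multi-record FASTA and builds such an index. The gap of 200 is an illustrative assumption, not necessarily the value OpenFold uses, and both helpers are hypothetical.

```python
from typing import List, Optional, Tuple


def parse_fasta(text: str) -> List[Tuple[str, str]]:
    """Minimal FASTA parser returning (description, sequence) pairs."""
    records: List[Tuple[str, str]] = []
    desc: Optional[str] = None
    seq: List[str] = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if desc is not None:
                records.append((desc, "".join(seq)))
            desc, seq = line[1:], []
        else:
            seq.append(line)
    if desc is not None:
        records.append((desc, "".join(seq)))
    return records


def gap_residue_index(chain_lengths: List[int], gap: int = 200) -> List[int]:
    """Residue indices for chains concatenated into one sequence, with an
    extra `gap` added at every chain boundary (gap size is an assumption)."""
    index: List[int] = []
    offset = 0
    for length in chain_lengths:
        index.extend(range(offset, offset + length))
        offset += length + gap
    return index


chains = parse_fasta(">chain_A\nMKT\n>chain_B\nGSHM\n")
sequence = "".join(seq for _, seq in chains)
print(sequence, gap_residue_index([len(seq) for _, seq in chains]))
# MKTGSHM [0, 1, 2, 203, 204, 205, 206]
```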

To minimize memory usage during inference on long sequences, consider the
following changes:

- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
stack is a major memory bottleneck for inference on long sequences. OpenFold
supports two mutually exclusive inference modes to address this issue. One,
`average_templates` in the `template` section of the config, is similar to the
solution offered by AlphaFold-Multimer, which is simply to average individual
template representations. Our version is modified slightly to accommodate
weights trained using the standard template algorithm. Using said weights, we
notice no significant difference in performance between our averaged template
embeddings and the standard ones. The second, `offload_templates`, temporarily
offloads individual template embeddings into CPU memory. The former is an
approximation, while the latter is slightly slower; both are memory-efficient
and allow the model to utilize arbitrarily many templates across sequence
lengths. Both are disabled by default, and it is up to the user to determine
which, if either, best suits their needs.
- Inference-time low-memory attention (LMA) can be enabled in the model config.
This setting trades off speed for vastly improved memory usage. By default,
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
These represent a favorable tradeoff in most memory-constrained cases.
Power users can choose to tweak these settings in
`openfold/model/primitives.py`. For more information on the LMA algorithm,
see the aforementioned Rabe & Staats preprint.
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only
wastes time.
- As a last resort, consider enabling `offload_inference`. This enables more
extensive CPU offloading at various bottlenecks throughout the model.
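
The idea behind LMA (Rabe & Staats) is that softmax attention can be accumulated over key chunks with a running maximum and running sum, so the full query-by-key score matrix is never materialized. The pure-Python sketch below demonstrates this on small lists and checks it against naive attention. It is not OpenFold's implementation (which operates on PyTorch tensors); it only mirrors the documented default chunk sizes of 1024 queries and 4096 keys.

```python
import math
import random
from typing import List

Matrix = List[List[float]]


def _dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference softmax(q k^T / sqrt(d)) v, building each full score row."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [_dot(qi, kj) * scale for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        z = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / z
                    for c in range(len(v[0]))])
    return out


def low_memory_attention(q: Matrix, k: Matrix, v: Matrix,
                         q_chunk: int = 1024, kv_chunk: int = 4096) -> Matrix:
    """Streams over key/value chunks, keeping a running max `m`, running
    normalizer `z`, and running weighted sum `acc` per query, so at most
    `kv_chunk` scores per query exist at once."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_start in range(0, len(q), q_chunk):  # a tensor version batches each query chunk
        for qi in q[q_start:q_start + q_chunk]:
            m, z, acc = -math.inf, 0.0, [0.0] * len(v[0])
            for k_start in range(0, len(k), kv_chunk):
                k_blk = k[k_start:k_start + kv_chunk]
                v_blk = v[k_start:k_start + kv_chunk]
                scores = [_dot(qi, kj) * scale for kj in k_blk]
                new_m = max(m, max(scores))
                rescale = math.exp(m - new_m)  # 0.0 on the first chunk
                z *= rescale
                acc = [a * rescale for a in acc]
                for s, vj in zip(scores, v_blk):
                    w = math.exp(s - new_m)
                    z += w
                    acc = [a + w * x for a, x in zip(acc, vj)]
                m = new_m
            out.append([a / z for a in acc])
    return out


random.seed(0)
q, k, v = ([[random.uniform(-1, 1) for _ in range(4)] for _ in range(10)]
           for _ in range(3))
ref = naive_attention(q, k, v)
lma = low_memory_attention(q, k, v, q_chunk=3, kv_chunk=4)
assert all(abs(a - b) < 1e-9 for r1, r2 in zip(ref, lma) for a, b in zip(r1, r2))
```

The result is exact up to floating-point rounding; the price is the extra per-chunk rescaling work, which is the speed cost mentioned above.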

Using the most conservative settings, we were able to run inference on a
4600-residue complex with a single A100. Compared to AlphaFold's own memory
offloading mode, ours is considerably faster; even the more efficient
AlphaFold-Multimer takes more than twice as long on the same complex.

### Training

To train the model, you will first need to precompute protein alignments. 

You have two options. You can use the same procedure DeepMind used by running
the following:

```bash
python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ \
    --uniref90_database_path data/uniref90/uniref90.fasta \
    --mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
    --pdb70_database_path data/pdb70/pdb70 \
    --uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
    --bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
    --cpus 16 \
    --jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
    --hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
    --hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
    --kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
```

As noted before, you can skip the `binary_path` arguments if these binaries are 
at `/usr/bin`. Expect this step to take a very long time, even for small 
numbers of proteins.

Alternatively, you can generate MSAs with the ColabFold pipeline (and templates
with HHsearch) with:

```bash
python3 scripts/precompute_alignments_mmseqs.py input.fasta \
    data/mmseqs_dbs \
    uniref30_2103_db \
    alignment_dir \
    ~/MMseqs2/build/bin/mmseqs \
    /usr/bin/hhsearch \
    --env_db colabfold_envdb_202108_db \
    --pdb70 data/pdb70/pdb70
```

where `input.fasta` is a FASTA file containing one or more query sequences. To
generate an input FASTA from a directory of mmCIF and/or ProteinNet .core
files, we provide `scripts/data_dir_to_fasta.py`.

Next, generate a cache of certain datapoints in the template mmCIF files:

```bash
python3 scripts/generate_mmcif_cache.py \
    mmcif_dir/ \
    mmcif_cache.json \
    --no_workers 16
```

This cache is used to pre-filter templates. 
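
Pre-filtering can be pictured as looking up each template hit's release date and comparing it against a cutoff, such as the `2021-10-10` date passed to the training command. In the sketch below, the JSON shape (`{"1abc": {"release_date": ...}}`), the lower-cased PDB ID keys, and the on-or-before rule are assumptions; it does not describe the exact output of `scripts/generate_mmcif_cache.py`.

```python
import json
from datetime import date
from typing import Dict, Iterable, List


def load_release_dates(cache_json: str) -> Dict[str, date]:
    """Read {pdb_id: release_date} from a cache assumed to look like
    {"1abc": {"release_date": "2020-01-01", ...}, ...}."""
    cache = json.loads(cache_json)
    return {pdb_id.lower(): date.fromisoformat(entry["release_date"])
            for pdb_id, entry in cache.items() if "release_date" in entry}


def prefilter_templates(hits: Iterable[str], release_dates: Dict[str, date],
                        max_template_date: date) -> List[str]:
    """Keep hits released on or before the cutoff; drop hits not in the cache."""
    kept = []
    for hit in hits:
        released = release_dates.get(hit.lower())
        if released is not None and released <= max_template_date:
            kept.append(hit)
    return kept


cache_json = ('{"1abc": {"release_date": "2019-05-01"}, '
              '"7xyz": {"release_date": "2022-02-02"}}')
dates = load_release_dates(cache_json)
print(prefilter_templates(["1ABC", "7XYZ", "9ZZZ"], dates, date(2021, 10, 10)))
# ['1ABC']
```

Caching the dates once avoids re-parsing every template mmCIF file for each training example.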

Next, generate a separate chain-level cache with data used for training-time 
data filtering:

```bash
python3 scripts/generate_chain_data_cache.py \
    mmcif_dir/ \
    chain_data_cache.json \
    --cluster_file clusters-by-entity-40.txt \
    --no_workers 16
```

where the `cluster_file` argument is a file of chain clusters, one cluster
per line (e.g. [PDB40](https://cdn.rcsb.org/resources/sequence/clusters/clusters-by-entity-40.txt)).
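
A cluster file in this format is easy to turn into a member-to-cluster-size map. The sketch below assumes whitespace-separated member IDs such as `1ABC_1` (entity-level IDs, as in the PDB40 file) and leaves out the mapping from chains to entities. Cluster sizes like these are commonly used to down-weight members of large clusters when sampling; the sketch does not claim to reproduce how `scripts/generate_chain_data_cache.py` uses the file.

```python
from typing import Dict, List


def parse_cluster_file(text: str) -> List[List[str]]:
    """One cluster per line; members are separated by whitespace."""
    return [line.split() for line in text.splitlines() if line.strip()]


def cluster_sizes(clusters: List[List[str]]) -> Dict[str, int]:
    """Map each (upper-cased) member ID to the size of its cluster."""
    sizes: Dict[str, int] = {}
    for members in clusters:
        for member in members:
            sizes[member.upper()] = len(members)
    return sizes


example = "1ABC_1 2DEF_1 3GHI_2\n4JKL_1\n"
sizes = cluster_sizes(parse_cluster_file(example))
print(sizes)  # {'1ABC_1': 3, '2DEF_1': 3, '3GHI_2': 3, '4JKL_1': 1}

# Hypothetical use: weight each member inversely to its cluster's size.
weights = {member: 1.0 / size for member, size in sizes.items()}
```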

Finally, call the training script:

```bash
# In multi-GPU settings, the seed must be specified (--seed)
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
    2021-10-10 \
    --template_release_dates_cache_path mmcif_cache.json \
    --precision 16 \
    --gpus 8 --replace_sampler_ddp=True \
    --seed 42 \
    --deepspeed_config_path deepspeed_config.json \
    --checkpoint_every_epoch \
    --resume_from_ckpt ckpt_dir/ \
    --train_chain_data_cache_path chain_data_cache.json \
    --obsolete_pdbs_file_path obsolete.dat
```

where `2021-10-10` is the maximum template release date and
`--template_release_dates_cache_path` is a path to the mmCIF cache.
Note that `template_mmcif_dir` can be the same as `mmcif_dir`, which contains
training targets. A suitable DeepSpeed configuration file can be generated with
`scripts/build_deepspeed_config.py`. The training script is
written with [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
and supports the full range of training options that entails, including
multi-node distributed training, validation, and so on. For more information,
consult the PyTorch Lightning documentation and the `--help` flag of the training
script.

If you're using your own MSAs or MSAs from the RODA repository, make sure that
the `alignment_dir` contains one directory per chain and that each of these
contains alignments (.sto, .a3m, and .hhr) corresponding to that chain.
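
Before a long training run, it can save time to check this layout. The sketch below walks an `alignment_dir`, records which of the `.sto`, `.a3m`, and `.hhr` file types each chain directory holds, and flags loose top-level files or chain directories without alignments. The helper names, the chain directory name `1abc_A`, and the file name `uniref90_hits.sto` are assumptions for illustration.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Set

ALIGNMENT_SUFFIXES = {".sto", ".a3m", ".hhr"}


def scan_alignment_dir(alignment_dir: str) -> Dict[str, Set[str]]:
    """Map each per-chain subdirectory to the alignment file types it holds."""
    found: Dict[str, Set[str]] = {}
    for chain_dir in sorted(Path(alignment_dir).iterdir()):
        if chain_dir.is_dir():
            found[chain_dir.name] = {
                f.suffix for f in chain_dir.iterdir()
                if f.is_file() and f.suffix in ALIGNMENT_SUFFIXES
            }
    return found


def find_layout_problems(alignment_dir: str) -> List[str]:
    """Report loose top-level files and chain directories without alignments."""
    issues = [f"loose file at top level: {p.name}"
              for p in sorted(Path(alignment_dir).iterdir()) if p.is_file()]
    issues += [f"no .sto/.a3m/.hhr files for chain {name}"
               for name, suffixes in scan_alignment_dir(alignment_dir).items()
               if not suffixes]
    return issues


with tempfile.TemporaryDirectory() as root:
    (Path(root) / "1abc_A").mkdir()
    (Path(root) / "1abc_A" / "uniref90_hits.sto").write_text("")
    (Path(root) / "1abc_B").mkdir()
    print(find_layout_problems(root))
    # ['no .sto/.a3m/.hhr files for chain 1abc_B']
```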

Note that, despite its variable name, `mmcif_dir` can also contain PDB files 
or even ProteinNet .core files. To emulate the AlphaFold training procedure, 
which uses a self-distillation set subject to special preprocessing steps, use
the family of `--distillation` flags.

In cases where it may be burdensome to create separate files for each chain's
alignments, alignment directories can be consolidated using the scripts in 
`scripts/alignment_db_scripts/`. First, run `create_alignment_db.py` to
consolidate an alignment directory into a pair of database and index files.
Once all alignment directories (or shards of a single alignment directory)
have been compiled, unify the indices with `unify_alignment_db_indices`. The
resulting index, `super.index`, can be passed to the training script flags
containing the phrase `alignment_index`. In this scenario, the `alignment_dir`
flags instead represent the directory containing the compiled alignment
databases. Both the training and distillation datasets can be compiled in this
way.
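
The consolidation scheme can be pictured as concatenating every alignment file into one database file, recording byte offsets in an index, and then merging per-shard indices into a single super index. The sketch below implements that idea with a JSON-serializable index. The index format, the function names, and the file names are hypothetical and do not describe the actual output of `create_alignment_db.py`.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List

Index = Dict[str, dict]


def create_db(alignment_dir: str, db_path: str) -> Index:
    """Append every file of every chain directory to one database file and
    return {chain: {"db": <db file name>, "files": [[name, offset, size], ...]}}."""
    index: Index = {}
    offset = 0
    with open(db_path, "wb") as db:
        for chain_dir in sorted(p for p in Path(alignment_dir).iterdir() if p.is_dir()):
            files = []
            for f in sorted(p for p in chain_dir.iterdir() if p.is_file()):
                data = f.read_bytes()
                db.write(data)
                files.append([f.name, offset, len(data)])
                offset += len(data)
            index[chain_dir.name] = {"db": Path(db_path).name, "files": files}
    return index


def unify_indices(indices: List[Index]) -> Index:
    """Merge per-shard indices into a single 'super' index keyed by chain."""
    merged: Index = {}
    for index in indices:
        for chain, entry in index.items():
            if chain in merged:
                raise ValueError(f"chain {chain} appears in more than one shard")
            merged[chain] = entry
    return merged


def read_alignment(db_dir: str, super_index: Index, chain: str, name: str) -> bytes:
    """Fetch one alignment file's bytes using the super index."""
    entry = super_index[chain]
    for file_name, start, size in entry["files"]:
        if file_name == name:
            with open(Path(db_dir) / entry["db"], "rb") as db:
                db.seek(start)
                return db.read(size)
    raise KeyError(f"{name} not found for chain {chain}")


with tempfile.TemporaryDirectory() as root:
    base = Path(root)
    for shard, chain in [("shard0", "1abc_A"), ("shard1", "2def_B")]:
        (base / shard / chain).mkdir(parents=True)
        (base / shard / chain / "hits.a3m").write_text(f">{chain}\nMKT\n")
    out = base / "dbs"
    out.mkdir()
    shard_indices = [create_db(str(base / s), str(out / f"{s}.db"))
                     for s in ("shard0", "shard1")]
    super_index = unify_indices(shard_indices)
    (out / "super.index").write_text(json.dumps(super_index))
    print(read_alignment(str(out), super_index, "2def_B", "hits.a3m").decode())
```

Because each index entry names its own database file, shards can be compiled independently and only the small indices need merging.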

## Testing

To run unit tests, use

```bash
scripts/run_unit_tests.sh
```

The script is a thin wrapper around Python's `unittest` suite and recognizes
`unittest` arguments. For example, to run a specific test verbosely:

```bash
scripts/run_unit_tests.sh -v tests.test_model
```

Certain tests require that AlphaFold (v2.0.1) be installed in the same Python
environment. These run components of AlphaFold and OpenFold side by side and
ensure that output activations are adequately similar. For most modules, we
target a maximum pointwise difference of `1e-4`.
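
A side-by-side comparison test boils down to computing the largest absolute elementwise difference between two activations and asserting that it stays below the tolerance. The sketch below does this on flat Python lists with made-up values; with PyTorch tensors, `(a - b).abs().max()` computes the same quantity. The helper and test names are hypothetical and not taken from OpenFold's test suite.

```python
import unittest
from typing import Sequence


def max_pointwise_diff(a: Sequence[float], b: Sequence[float]) -> float:
    """Largest absolute elementwise difference between two flattened outputs."""
    if len(a) != len(b):
        raise ValueError("outputs must have the same number of elements")
    return max((abs(x - y) for x, y in zip(a, b)), default=0.0)


class CompareActivations(unittest.TestCase):
    def test_within_tolerance(self):
        # Made-up activations standing in for OpenFold and AlphaFold outputs.
        openfold_out = [0.10, -0.25, 1.00]
        alphafold_out = [0.10004, -0.25002, 0.99995]
        self.assertLess(max_pointwise_diff(openfold_out, alphafold_out), 1e-4)


suite = unittest.TestLoader().loadTestsFromTestCase(CompareActivations)
result = unittest.TextTestRunner(verbosity=0).run(suite)
print(result.wasSuccessful())  # True
```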

## Building and using the docker container

### Building the docker image

OpenFold can be built as a Docker container using the included Dockerfile. To build it, run the following command from the root of this repository:

```bash
docker build -t openfold .
```

### Running the docker container 

The built container contains both `run_pretrained_openfold.py` and `train_openfold.py` as well as all necessary software dependencies. It does not contain the model parameters, sequence, or structural databases. These should be downloaded to the host machine following the instructions in the Usage section above. 

The Docker container installs all conda components to the base conda environment in `/opt/conda` and installs OpenFold itself in `/opt/openfold`.

Before running the docker container, you can verify that your docker installation is able to properly communicate with your GPU by running the following command:


```bash
docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi
```

Note the `--gpus all` option passed to `docker run`. This option is necessary in order for the container to use the GPUs on the host machine.

To run the inference code under Docker, you can use a command like the one below. In this example, the databases and model parameters are located at `/mnt/alphafold_database` on the host machine, and the input files are located in the current working directory. You can adjust the volume mount locations as needed to reflect the locations of your data.

```bash
docker run \
--gpus all \
-v "$PWD"/:/data \
-v /mnt/alphafold_database/:/database \
-ti openfold:latest \
python3 /opt/openfold/run_pretrained_openfold.py \
/data/fasta_dir \
/database/pdb_mmcif/mmcif_files/ \
--uniref90_database_path /database/uniref90/uniref90.fasta \
--mgnify_database_path /database/mgnify/mgy_clusters_2018_12.fa \
--pdb70_database_path /database/pdb70/pdb70 \
--uniclust30_database_path /database/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
--output_dir /data \
--bfd_database_path /database/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--model_device cuda:0 \
--jackhmmer_binary_path /opt/conda/bin/jackhmmer \
--hhblits_binary_path /opt/conda/bin/hhblits \
--hhsearch_binary_path /opt/conda/bin/hhsearch \
--kalign_binary_path /opt/conda/bin/kalign \
--config_preset "model_1_ptm" \
--openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
```

## Copyright notice

While AlphaFold's and, by extension, OpenFold's source code is licensed under
the permissive Apache License, Version 2.0, DeepMind's pretrained parameters
fall under the CC BY 4.0 license, a copy of which is downloaded to
`openfold/resources/params` by the installation script. Note that the latter
replaces the original, more restrictive CC BY-NC 4.0 license as of January 2022.

## Contributing

If you encounter problems using OpenFold, feel free to create an issue! We also
welcome pull requests from the community.

## Citing this work

For now, cite OpenFold as follows:

```bibtex
@software{Ahdritz_OpenFold_2021,
  author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and AlQuraishi, Mohammed},
  doi = {10.5281/zenodo.5709539},
  month = {11},
  title = {{OpenFold}},
  url = {https://github.com/aqlaboratory/openfold},
  year = {2021}
}
```

Any work that cites OpenFold should also cite AlphaFold.