Commit 63293890 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz

Merge branch 'main' of github.com:aqlaboratory/openfold

parents e14a313e b59d0a81
cff-version: 1.2.0
preferred-citation:
  type: article
  authors:
    - family-names: "Ahdritz"
      given-names: "Gustaf"
      orcid: https://orcid.org/0000-0001-8283-5324
    - family-names: "Bouatta"
      given-names: "Nazim"
      orcid: https://orcid.org/0000-0002-6524-874X
    - family-names: "Kadyan"
      given-names: "Sachin"
      orcid: https://orcid.org/0000-0002-6079-7627
    - family-names: "Xia"
      given-names: "Qinghui"
    - family-names: "Gerecke"
      given-names: "William"
      orcid: https://orcid.org/0000-0002-9777-6192
    - family-names: "O'Donnell"
      given-names: "Timothy J"
      orcid: https://orcid.org/0000-0002-9949-069X
    - family-names: "Berenberg"
      given-names: "Daniel"
      orcid: https://orcid.org/0000-0003-4631-0947
    - family-names: "Fisk"
      given-names: "Ian"
    - family-names: "Zanichelli"
      given-names: "Niccolò"
      orcid: https://orcid.org/0000-0002-3093-3587
    - family-names: "Zhang"
      given-names: "Bo"
      orcid: https://orcid.org/0000-0002-9714-2827
    - family-names: "Nowaczynski"
      given-names: "Arkadiusz"
      orcid: https://orcid.org/0000-0002-3351-9584
    - family-names: "Wang"
      given-names: "Bei"
      orcid: https://orcid.org/0000-0003-4942-9652
    - family-names: "Stepniewska-Dziubinska"
      given-names: "Marta M"
      orcid: https://orcid.org/0000-0003-4942-9652
    - family-names: "Zhang"
      given-names: "Shang"
      orcid: https://orcid.org/0000-0003-0759-2080
    - family-names: "Ojewole"
      given-names: "Adegoke"
      orcid: https://orcid.org/0000-0003-2661-4388
    - family-names: "Guney"
      given-names: "Murat Efe"
    - family-names: "Biderman"
      given-names: "Stella"
      orcid: https://orcid.org/0000-0001-8228-1042
    - family-names: "Watkins"
      given-names: "Andrew M"
      orcid: https://orcid.org/0000-0003-1617-1720
    - family-names: "Ra"
      given-names: "Stephen"
      orcid: https://orcid.org/0000-0002-2820-0050
    - family-names: "Lorenzo"
      given-names: "Pablo Ribalta"
      orcid: https://orcid.org/0000-0002-3657-8053
    - family-names: "Nivon"
      given-names: "Lucas"
    - family-names: "Weitzner"
      given-names: "Brian"
      orcid: https://orcid.org/0000-0002-1909-0961
    - family-names: "Ban"
      given-names: "Yih-En Andrew"
      orcid: https://orcid.org/0000-0003-3698-3574
    - family-names: "Sorger"
      given-names: "Peter K"
      orcid: https://orcid.org/0000-0002-3364-1838
    - family-names: "Mostaque"
      given-names: "Emad"
    - family-names: "Zhang"
      given-names: "Zhao"
      orcid: https://orcid.org/0000-0001-5921-0035
    - family-names: "Bonneau"
      given-names: "Richard"
      orcid: https://orcid.org/0000-0003-4354-7906
    - family-names: "AlQuraishi"
      given-names: "Mohammed"
      orcid: https://orcid.org/0000-0001-6817-1322
  title: "OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization"
  doi: 10.1101/2022.11.20.517210
date-released: 2021-11-12
url: "https://doi.org/10.1101/2022.11.20.517210"
@@ -229,7 +229,8 @@ Using the most conservative settings, we were able to run inference on a
 offloading mode, ours is considerably faster; the same complex takes the more
 efficient AlphaFold-Multimer more than double the time. Use the
 `long_sequence_inference` config option to enable all of these interventions
-at once.
+at once. The `run_pretrained_openfold.py` script can enable this config option with the
+`--long_sequence_inference` command line option.
### Training
@@ -434,16 +435,20 @@ welcome pull requests from the community.
 ## Citing this work
-For now, cite OpenFold as follows:
+Please cite our paper:
 ```bibtex
-@software{Ahdritz_OpenFold_2021,
-    author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and AlQuraishi, Mohammed},
-    doi = {10.5281/zenodo.5709539},
-    month = {11},
-    title = {{OpenFold}},
-    url = {https://github.com/aqlaboratory/openfold},
-    year = {2021}
-}
+@article{Ahdritz2022.11.20.517210,
+    author = {Ahdritz, Gustaf and Bouatta, Nazim and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
+    title = {OpenFold: Retraining AlphaFold2 yields new insights into its learning mechanisms and capacity for generalization},
+    elocation-id = {2022.11.20.517210},
+    year = {2022},
+    doi = {10.1101/2022.11.20.517210},
+    publisher = {Cold Spring Harbor Laboratory},
+    abstract = {AlphaFold2 revolutionized structural biology with the ability to predict protein structures with exceptionally high accuracy. Its implementation, however, lacks the code and data required to train new models. These are necessary to (i) tackle new tasks, like protein-ligand complex structure prediction, (ii) investigate the process by which the model learns, which remains poorly understood, and (iii) assess the model{\textquoteright}s generalization capacity to unseen regions of fold space. Here we report OpenFold, a fast, memory-efficient, and trainable implementation of AlphaFold2, and OpenProteinSet, the largest public database of protein multiple sequence alignments. We use OpenProteinSet to train OpenFold from scratch, fully matching the accuracy of AlphaFold2. Having established parity, we assess OpenFold{\textquoteright}s capacity to generalize across fold space by retraining it using carefully designed datasets. We find that OpenFold is remarkably robust at generalizing despite extreme reductions in training set size and diversity, including near-complete elisions of classes of secondary structure elements. By analyzing intermediate structures produced by OpenFold during training, we also gain surprising insights into the manner in which the model learns to fold proteins, discovering that spatial dimensions are learned sequentially. Taken together, our studies demonstrate the power and utility of OpenFold, which we believe will prove to be a crucial new resource for the protein modeling community.},
+    URL = {https://www.biorxiv.org/content/10.1101/2022.11.20.517210},
+    eprint = {https://www.biorxiv.org/content/early/2022/11/22/2022.11.20.517210.full.pdf},
+    journal = {bioRxiv}
+}
 ```
@@ -334,7 +334,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
     """
     def __init__(self,
         datasets: Sequence[OpenFoldSingleDataset],
-        probabilities: Sequence[int],
+        probabilities: Sequence[float],
         epoch_len: int,
         generator: torch.Generator = None,
         _roll_at_init: bool = True,
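For context, the `probabilities` annotated above weight how often each constituent dataset is drawn when filling an epoch. A minimal standalone sketch of that kind of weighted index sampling (the helper name and signature are invented for illustration, not OpenFold's actual code):

```python
import torch

# Illustrative sketch of probability-weighted dataset selection; the helper
# name and flat arguments are assumptions, not OpenFold's implementation.
def sample_dataset_indices(probabilities, epoch_len, generator=None):
    weights = torch.tensor(probabilities, dtype=torch.float64)
    # torch.multinomial normalizes the weights internally, so fractional
    # values work directly -- hence a Sequence[float] annotation makes sense.
    return torch.multinomial(
        weights, num_samples=epoch_len, replacement=True, generator=generator
    ).tolist()

g = torch.Generator().manual_seed(0)
indices = sample_dataset_indices([0.75, 0.25], epoch_len=8, generator=g)
# Every entry is 0 or 1, with dataset 0 drawn roughly three times as often
```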
@@ -440,11 +440,6 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         stage_cfg = self.config[self.stage]
         max_iters = self.config.common.max_recycling_iters
-        if(stage_cfg.supervised):
-            clamp_prob = self.config.supervised.clamp_prob
-            keyed_probs.append(
-                ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
-            )
         if(stage_cfg.uniform_recycling):
             recycling_probs = [
@@ -94,6 +94,21 @@ def np_example_to_features(
         cfg[mode],
     )
+    if mode == "train":
+        p = torch.rand(1).item()
+        use_clamped_fape_value = float(p < cfg.supervised.clamp_prob)
+        features["use_clamped_fape"] = torch.full(
+            size=[cfg.common.max_recycling_iters + 1],
+            fill_value=use_clamped_fape_value,
+            dtype=torch.float32,
+        )
+    else:
+        features["use_clamped_fape"] = torch.full(
+            size=[cfg.common.max_recycling_iters + 1],
+            fill_value=0.0,
+            dtype=torch.float32,
+        )
     return {k: v for k, v in features.items()}
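The added branch moves the clamped-FAPE decision out of the data loader and into feature processing: each training example clamps FAPE with probability `clamp_prob`, and the resulting flag is replicated across recycling iterations. A self-contained sketch of the same logic (the function name and flat arguments are invented; the real code reads these values from `cfg`):

```python
import torch

# Standalone sketch of the per-example clamped-FAPE flag; the function
# name and arguments are invented for illustration.
def make_use_clamped_fape(mode, clamp_prob, max_recycling_iters):
    if mode == "train":
        # Clamp the FAPE loss for this example with probability clamp_prob
        value = float(torch.rand(1).item() < clamp_prob)
    else:
        # Evaluation never clamps
        value = 0.0
    # One copy of the flag per recycling iteration, plus the initial pass
    return torch.full([max_recycling_iters + 1], value, dtype=torch.float32)

eval_flags = make_use_clamped_fape("eval", clamp_prob=0.9, max_recycling_iters=3)
# → tensor([0., 0., 0., 0.])
```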
@@ -130,6 +130,22 @@ def _is_after_cutoff(
     return False
+def _replace_obsolete_references(obsolete_mapping) -> Mapping[str, str]:
+    """Generates a new obsolete mapping by tracing all cross-references
+    and storing the latest leaf target for every referencing node."""
+    obsolete_new = {}
+    obsolete_keys = obsolete_mapping.keys()
+    def _new_target(k):
+        v = obsolete_mapping[k]
+        if v in obsolete_keys:
+            return _new_target(v)
+        return v
+    for k in obsolete_keys:
+        obsolete_new[k] = _new_target(k)
+    return obsolete_new
 def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
     """Parses the data file from PDB that lists which PDB ids are obsolete."""
     with open(obsolete_file_path) as f:
@@ -143,7 +159,7 @@ def _parse_obsolete(obsolete_file_path: str) -> Mapping[str, str]:
             from_id = line[20:24].lower()
             to_id = line[29:33].lower()
             result[from_id] = to_id
-    return result
+    return _replace_obsolete_references(result)
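The new helper collapses chains of obsolete PDB IDs (A superseded by B, B in turn by C) so that every key maps straight to its final replacement. A dictionary-only sketch of the same idea, with invented IDs:

```python
# Dictionary-only sketch of resolving chained obsolete-ID mappings,
# mirroring _replace_obsolete_references; the PDB IDs here are invented.
def resolve_obsolete(mapping):
    def leaf(k):
        v = mapping[k]
        # Follow cross-references until reaching an ID that is not
        # itself obsolete
        return leaf(v) if v in mapping else v
    return {k: leaf(k) for k in mapping}

# '1abc' was superseded by '2def', which was itself superseded by '3ghi'
chain = {"1abc": "2def", "2def": "3ghi"}
resolve_obsolete(chain)
# → {'1abc': '3ghi', '2def': '3ghi'}
```

Like the original recursion, this assumes the obsolete mapping contains no cycles.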
@@ -22,6 +22,7 @@ from openfold.utils.loss import (
     compute_tm,
     compute_predicted_aligned_error,
 )
+from openfold.utils.precision_utils import is_fp16_enabled
 class AuxiliaryHeads(nn.Module):
@@ -150,15 +151,14 @@ class DistogramHead(nn.Module):
         logits = logits + logits.transpose(-2, -3)
         return logits
     def forward(self, z):
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 return self._forward(z.float())
         else:
             return self._forward(z)
 class TMScoreHead(nn.Module):
     """
     For use in computation of TM-score, subsection 1.9.7
@@ -21,6 +21,7 @@ import torch.nn as nn
 from openfold.model.primitives import Linear
 from openfold.utils.chunk_utils import chunk_layer
+from openfold.utils.precision_utils import is_fp16_enabled
 class OuterProductMean(nn.Module):
@@ -150,9 +151,7 @@ class OuterProductMean(nn.Module):
         chunk_size: Optional[int] = None,
         inplace_safe: bool = False,
     ) -> torch.Tensor:
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 return self._forward(m.float(), mask, chunk_size, inplace_safe)
         else:
@@ -35,6 +35,7 @@ from scipy.stats import truncnorm
 from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.chunk_utils import _chunk_slice
 from openfold.utils.kernel.attention_core import attention_core
+from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
@@ -479,9 +480,9 @@ class Attention(nn.Module):
         q, k, v = self._prep_qkv(q_x, kv_x)
         # [*, Q, H, C_hidden]
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled:
+        if is_fp16_enabled():
             use_memory_efficient_kernel = False
         if(use_memory_efficient_kernel):
             if(len(biases) > 2):
                 raise ValueError(
@@ -33,6 +33,7 @@ from openfold.utils.feats import (
     frames_and_literature_positions_to_atom14_pos,
     torsion_angles_to_frames,
 )
+from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.rigid_utils import Rotation, Rigid
 from openfold.utils.tensor_utils import (
     dict_multimap,
@@ -312,8 +313,7 @@ class InvariantPointAttention(nn.Module):
             z[0] = z[0].cpu()
         # [*, H, N_res, N_res]
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 a = torch.matmul(
                     permute_final_dims(q.float(), (1, 0, 2)),  # [*, H, N_res, C_hidden]
@@ -324,6 +324,7 @@ class InvariantPointAttention(nn.Module):
             permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
             permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
         )
         a *= math.sqrt(1.0 / (3 * self.c_hidden))
+        a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
@@ -21,6 +21,7 @@ import torch.nn as nn
 from openfold.model.primitives import Linear, LayerNorm
 from openfold.utils.chunk_utils import chunk_layer
+from openfold.utils.precision_utils import is_fp16_enabled
 from openfold.utils.tensor_utils import add, permute_final_dims
@@ -391,12 +392,13 @@ class TriangleMultiplicativeUpdate(nn.Module):
         b = mask
         b = b * self.sigmoid(self.linear_b_g(z))
         b = b * self.linear_b_p(z)
-        float16_enabled = (torch.get_autocast_gpu_dtype() == torch.float16)
-        if float16_enabled and torch.is_autocast_enabled():
+        if(is_fp16_enabled()):
             with torch.cuda.amp.autocast(enabled=False):
                 x = self._combine_projections(a.float(), b.float())
         else:
             x = self._combine_projections(a, b)
         del a, b
         x = self.layer_norm_out(x)
         x = self.linear_z(x)
// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
#include <stdexcept>
#include <torch/extension.h>

void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
)
{
    throw std::runtime_error("attn_softmax_inplace_forward_ not implemented on CPU");
}

void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
)
{
    throw std::runtime_error("attn_softmax_inplace_backward_ not implemented on CPU");
}
# Copyright 2022 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
def is_fp16_enabled():
# Autocast world
fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
return fp16_enabled
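Several modules in this commit use the new helper to drop out of fp16 autocast for numerically sensitive operations. A minimal sketch of that pattern (the `stable_forward` wrapper is invented for illustration; the real modules inline the check in their `forward` methods):

```python
import torch

def is_fp16_enabled():
    # True only when autocast is active and targeting float16
    fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
    fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
    return fp16_enabled

# Invented wrapper showing the pattern used by DistogramHead et al.:
# when fp16 autocast is on, disable it and compute in float32.
def stable_forward(fn, x):
    if is_fp16_enabled():
        with torch.cuda.amp.autocast(enabled=False):
            return fn(x.float())
    return fn(x)

y = stable_forward(lambda t: t * 2, torch.ones(2))
# Outside autocast this is a plain float32 call
```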
@@ -140,7 +140,7 @@ def main(args):
     # Create the output directory
     os.makedirs(args.output_dir, exist_ok=True)
-    config = model_config(args.config_preset)
+    config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
     if(args.trace_model):
         if(not config.data.predict.fixed_size):
@@ -369,6 +369,10 @@ if __name__ == "__main__":
         help="""Whether to output (100 - pLDDT) in the B-factor column instead
         of the pLDDT itself"""
     )
+    parser.add_argument(
+        "--long_sequence_inference", action="store_true", default=False,
+        help="""Enable options that reduce memory usage at the cost of speed,
+        helping longer sequences fit into GPU memory; see the README for details"""
+    )
     add_data_args(parser)
     args = parser.parse_args()
File mode changed from 100644 to 100755