OpenDAS / OpenFold · Commits

Commit dba44612, authored Apr 28, 2022 by Gustaf Ahdritz

    Resolve merge conflicts

Parents: 4bd1b4d5, 576174f0

Changes: 34 files in the commit; showing 20 changed files with 374 additions and 177 deletions (+374 / -177).
Dockerfile                                      +4    -3
README.md                                       +21   -8
deepspeed_config.json                           +2    -9
environment.yml                                 +4    -2
openfold/config.py                              +1    -0
openfold/data/data_pipeline.py                  +16   -7
openfold/data/data_transforms.py                +27   -11
openfold/data/feature_pipeline.py               +2    -1
openfold/model/embedders.py                     +9    -12
openfold/model/evoformer.py                     +50   -28
openfold/model/model.py                         +4    -3
openfold/model/msa.py                           +29   -15
openfold/model/primitives.py                    +40   -19
openfold/np/relax/amber_minimize.py             +0    -57
openfold/utils/__init__.py                      +3    -1
openfold/utils/feats.py                         +4    -1
openfold/utils/kernel/__init__.py               +0    -0
openfold/utils/kernel/attention_core.py         +103  -0
openfold/utils/kernel/csrc/compat.h             +11   -0
openfold/utils/kernel/csrc/softmax_cuda.cpp     +44   -0
Dockerfile

-FROM nvidia/cuda:11.0-cudnn8-runtime-ubuntu18.04
+FROM nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04

-RUN apt-get update && apt-get install -y wget cuda-minimal-build-11-0 git
+RUN apt-get update && apt-get install -y wget cuda-minimal-build-10-2 git

 RUN wget -P /tmp \
     "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \
     && bash /tmp/Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda \
 ...

@@ -21,4 +21,5 @@ COPY lib/openmm.patch /opt/openfold/lib/openmm.patch
 RUN wget -q -P /opt/openfold/openfold/resources \
     https://git.scicore.unibas.ch/schwede/openstructure/-/raw/7102c63615b64735c4941278d92b554ec94415f8/modules/mol/alg/src/stereo_chemical_props.txt
 RUN patch -p0 -d /opt/conda/lib/python3.7/site-packages/ < /opt/openfold/lib/openmm.patch
-RUN python3 /opt/openfold/setup.py install
+WORKDIR /opt/openfold
+RUN python3 setup.py install
README.md

@@ -15,12 +15,26 @@ cases where the *Nature* paper differs from the source, we always defer to the
 latter.

 OpenFold is built to support inference with AlphaFold's original JAX weights.
-Try it out with our
-[Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
+It's also faster than the official code on GPU. Try it out for yourself with our
+[Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).

 Unlike DeepMind's public code, OpenFold is also trainable. It can be trained
 with [DeepSpeed](https://github.com/microsoft/deepspeed) and with either `fp16`
 or `bfloat16` half-precision.

+OpenFold is equipped with an implementation of low-memory attention
+([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)), which
+enables inference on extremely long chains. We've modified
+[FastFold](https://github.com/hpcaitech/FastFold)'s custom CUDA
+kernels to support in-place attention during inference and training. These use
+4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
+implementations, respectively.
+
+We also make available efficient scripts for generating alignments. We've
+used them to generate millions of alignments that will be released alongside
+original OpenFold weights, trained from scratch using our code (more on that soon).

 ## Installation (Linux)

 All Python dependencies are specified in `environment.yml`. For producing sequence ...

@@ -48,6 +62,12 @@ To deactivate it, run:
 source scripts/deactivate_conda_env.sh
 ```

+With the environment active, compile OpenFold's CUDA kernels with
+
+```bash
+python3 setup.py install
+```
+
 To install the HH-suite to `/usr/bin`, run

 ```bash
 ...

@@ -129,13 +149,6 @@ to `None` in the config.

 ### Training

-After activating the OpenFold environment with
-`source scripts/activate_conda_env.sh`, install OpenFold by running
-
-```bash
-python setup.py install
-```
-
 To train the model, you will first need to precompute protein alignments.
 You have two options. You can use the same procedure DeepMind used by running ...
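The attention-related claims above map onto the new keyword arguments this commit adds to the `Attention` primitive (`use_memory_efficient_kernel`, alongside the existing `use_lma` and chunk sizes). A minimal sketch of how the three code paths might be selected; channel sizes and head counts here are illustrative, and the fused-kernel path assumes the CUDA extension from this commit has been compiled and a GPU is available:

```python
# Sketch only: shapes and sizes are illustrative, not OpenFold defaults.
import torch
from openfold.model.primitives import Attention

c_in, c_hidden, no_heads = 64, 16, 4
attn = Attention(c_in, c_in, c_in, c_hidden, no_heads).cuda()

n_seq, n_res = 8, 128
m = torch.rand(n_seq, n_res, c_in, device="cuda")

# Biases must broadcast to [*, H, Q, K]; a mask-style bias is the usual example.
mask = torch.ones(n_seq, n_res, device="cuda")
mask_bias = (1e9 * (mask - 1))[..., :, None, None, :]

with torch.no_grad():
    # Stock PyTorch attention
    o_ref = attn(q_x=m, kv_x=m, biases=[mask_bias])

    # Fused in-place kernel added in this commit (accepts at most two bias terms)
    o_kernel = attn(
        q_x=m, kv_x=m, biases=[mask_bias], use_memory_efficient_kernel=True
    )

    # Low-memory attention (Rabe & Staats 2021); requires chunk sizes
    o_lma = attn(
        q_x=m, kv_x=m, biases=[mask_bias],
        use_lma=True, q_chunk_size=1024, kv_chunk_size=4096,
    )
```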
deepspeed_config.json

 {
-    "optimizer": {
-        "type": "Adam",
-        "params": {
-            "lr": 0.001,
-            "eps": 1e-05
-        }
-    },
     "fp16": {
-        "enabled": true,
+        "enabled": false,
         "min_loss_scale": 1
     },
     "amp": {
 ...

@@ -15,7 +8,7 @@
         "opt_level": "O2"
     },
     "bfloat16": {
-        "enabled": false
+        "enabled": true
     },
     "zero_optimization": {
         "stage": 2,
 ...
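The net effect of the two toggles is easier to see when the file is loaded: fp16 is switched off and DeepSpeed's bfloat16 mode is switched on, matching the README's new `bfloat16` claim, while the optimizer block is dropped from the file. A small sanity-check sketch, assuming the config sits in the working directory:

```python
import json

with open("deepspeed_config.json") as f:
    ds_cfg = json.load(f)

# After this commit the config requests bfloat16 rather than fp16 half-precision.
assert ds_cfg["fp16"]["enabled"] is False
assert ds_cfg["bfloat16"]["enabled"] is True

# The file itself is consumed by the DeepSpeed-backed training script,
# not by this snippet; here we only inspect it.
print(ds_cfg["zero_optimization"]["stage"])
```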
environment.yml

@@ -6,7 +6,7 @@ channels:
 dependencies:
   - pip:
       - biopython==1.79
-      - deepspeed==0.5.3
+      - deepspeed==0.5.9
       - dm-tree==0.1.6
       - ml-collections==0.1.0
       - numpy==1.21.2
 ...

@@ -15,7 +15,7 @@ dependencies:
       - scipy==1.7.1
       - tqdm==4.62.2
       - typing-extensions==3.10.0.2
-      - pytorch_lightning==1.5.0
+      - pytorch_lightning==1.5.10
       - git+https://github.com/NVIDIA/dllogger.git
   - pytorch::pytorch=1.10.*
   - conda-forge::python=3.7
 ...

@@ -23,6 +23,8 @@ dependencies:
   - conda-forge::pip
   - conda-forge::openmm=7.5.1
   - conda-forge::pdbfixer
+  - conda-forge::cudatoolkit==10.2.*
+  - conda-forge::cudatoolkit-dev==10.*
   - bioconda::hmmer==3.3.2
   - bioconda::hhsuite==3.3.0
   - bioconda::kalign2==2.04
openfold/config.py

@@ -267,6 +267,7 @@ config = mlc.ConfigDict(
             "clamp_prob": 0.9,
             "max_distillation_msa_clusters": 1000,
             "uniform_recycling": True,
+            "distillation_prob": 0.75,
         },
         "data_module": {
             "use_small_bfd": False,
 ...
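The new `distillation_prob` entry sits alongside the other training-data knobs. If, as the name suggests, it is the probability of drawing a self-distillation sample rather than a true PDB sample, a data module might consume it roughly like this; a sketch under that assumption, with `distillation_set` and `pdb_set` as hypothetical stand-ins:

```python
import random

def pick_training_example(distillation_set, pdb_set, distillation_prob=0.75):
    # Hypothetical sketch: with probability `distillation_prob`, draw from the
    # self-distillation set; otherwise draw from the experimental PDB set.
    if random.random() < distillation_prob:
        return random.choice(distillation_set)
    return random.choice(pdb_set)

example = pick_training_example(["d1", "d2", "d3"], ["p1", "p2"])
```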
openfold/data/data_pipeline.py

@@ -176,8 +176,8 @@ def make_protein_features(
 def make_pdb_features(
     protein_object: protein.Protein,
     description: str,
-    confidence_threshold: float = 0.5,
+    is_distillation: bool = True,
+    confidence_threshold: float = 50.,
 ) -> FeatureDict:
     pdb_feats = make_protein_features(
         protein_object, description, _is_distillation=True
 ...

@@ -186,9 +186,7 @@ def make_pdb_features(
+    if(is_distillation):
         high_confidence = protein_object.b_factors > confidence_threshold
         high_confidence = np.any(high_confidence, axis=-1)
-        for i, confident in enumerate(high_confidence):
-            if(not confident):
-                pdb_feats["all_atom_mask"][i] = 0
+        pdb_feats["all_atom_mask"] *= high_confidence[..., None]

     return pdb_feats
 ...

@@ -832,13 +830,24 @@ class DataPipeline:
         alignment_dir: str,
+        is_distillation: bool = True,
         chain_id: Optional[str] = None,
+        _structure_index: Optional[str] = None,
         _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
            Assembles features for a protein in a PDB file.
         """
-        with open(pdb_path, 'r') as f:
-            pdb_str = f.read()
+        if(_structure_index is not None):
+            db_dir = os.path.dirname(pdb_path)
+            db = _structure_index["db"]
+            db_path = os.path.join(db_dir, db)
+            fp = open(db_path, "rb")
+            _, offset, length = _structure_index["files"][0]
+            fp.seek(offset)
+            pdb_str = fp.read(length).decode("utf-8")
+            fp.close()
+        else:
+            with open(pdb_path, 'r') as f:
+                pdb_str = f.read()

         protein_object = protein.from_pdb_string(pdb_str, chain_id)
         input_sequence = _aatype_to_str_sequence(protein_object.aatype)
 ...

@@ -846,7 +855,7 @@ class DataPipeline:
         pdb_feats = make_pdb_features(
             protein_object,
             description,
-            is_distillation
+            is_distillation=is_distillation
         )

         hits = self._parse_template_hits(
 ...
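The new `_structure_index` branch reads a single PDB record out of a packed database file instead of opening one file per chain. Only part of the index format is visible in this diff (a `"db"` filename plus `"files"` entries unpacked as `(_, offset, length)`), so the following reader is a sketch under that assumption, with the first tuple element treated as an opaque name and the example entry purely hypothetical:

```python
import os

def read_indexed_pdb(pdb_path, structure_index):
    # Mirrors the branch added to DataPipeline.process_pdb: the index names a
    # database file living next to `pdb_path` and byte-offsets into it.
    db_path = os.path.join(os.path.dirname(pdb_path), structure_index["db"])
    _, offset, length = structure_index["files"][0]
    with open(db_path, "rb") as fp:
        fp.seek(offset)
        return fp.read(length).decode("utf-8")

# Hypothetical index entry, for illustration only:
example_index = {"db": "pdb_data.db", "files": [["1abc_A", 0, 81920]]}
```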
openfold/data/data_transforms.py

@@ -50,7 +50,7 @@ def cast_to_64bit_ints(protein):
 def make_one_hot(x, num_classes):
-    x_one_hot = torch.zeros(*x.shape, num_classes)
+    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
     x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
     return x_one_hot
 ...

@@ -92,9 +92,9 @@ def fix_templates_aatype(protein):
     )
     # Map hhsearch-aatype to our aatype.
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
-    new_order = torch.tensor(
-        new_order_list, dtype=torch.int64
-    ).expand(num_templates, -1)
+    new_order = torch.tensor(
+        new_order_list,
+        dtype=torch.int64,
+        device=protein["aatype"].device,
+    ).expand(num_templates, -1)
     protein["template_aatype"] = torch.gather(
         new_order, 1, index=protein["template_aatype"]
     )
 ...

@@ -106,7 +106,8 @@ def correct_msa_restypes(protein):
     """Correct MSA restype to have the same order as rc."""
     new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
     new_order = torch.tensor(
-        [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
+        [new_order_list] * protein["msa"].shape[1],
+        dtype=protein["msa"].dtype,
+        device=protein["msa"].device,
     ).transpose(0, 1)
     protein["msa"] = torch.gather(new_order, 0, protein["msa"])
 ...

@@ -187,7 +188,10 @@ def sample_msa(protein, max_seq, keep_extra, seed=None):
     if seed is not None:
         g.manual_seed(seed)
     shuffled = torch.randperm(num_seq - 1, generator=g) + 1
-    index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
+    index_order = torch.cat(
+        (torch.tensor([0], device=shuffled.device), shuffled),
+        dim=0,
+    )
     num_sel = min(max_seq, num_seq)
     sel_seq, not_sel_seq = torch.split(
         index_order, [num_sel, num_seq - num_sel]
     )
 ...

@@ -242,7 +246,7 @@ def delete_extra_msa(protein):
 def block_delete_msa(protein, config):
     num_seq = protein["msa"].shape[0]
     block_num_seq = torch.floor(
-        torch.tensor(num_seq, dtype=torch.float32)
+        torch.tensor(num_seq, dtype=torch.float32, device=protein["msa"].device)
         * config.msa_fraction_per_block
     ).to(torch.int32)
 ...

@@ -275,7 +279,11 @@ def block_delete_msa(protein, config):
 @curry1
 def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
     weights = torch.cat(
-        [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
+        [
+            torch.ones(21, device=protein["msa"].device),
+            gap_agreement_weight * torch.ones(1, device=protein["msa"].device),
+            torch.zeros(1, device=protein["msa"].device),
+        ],
         0,
     )
 ...

@@ -324,7 +332,10 @@ def unsorted_segment_sum(data, segment_ids, num_segments):
     )
     segment_ids = segment_ids.expand(data.shape)
     shape = [num_segments] + list(data.shape[1:])
-    tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
+    tensor = (
+        torch.zeros(*shape, device=segment_ids.device)
+        .scatter_add_(0, segment_ids, data.float())
+    )
     tensor = tensor.type(data.dtype)
     return tensor
 ...

@@ -401,7 +412,7 @@ def make_pseudo_beta(protein, prefix=""):
 @curry1
 def add_constant_field(protein, key, value):
-    protein[key] = torch.tensor(value)
+    protein[key] = torch.tensor(value, device=protein["msa"].device)
     return protein
 ...

@@ -454,6 +465,7 @@ def make_masked_msa(protein, config, replace_fraction, seed):
         1.0 - config.profile_prob - config.same_prob - config.uniform_prob
     )
+    assert mask_prob >= 0.0
     categorical_probs = torch.nn.functional.pad(
         categorical_probs, pad_shapes, value=mask_prob,
     )
 ...

@@ -656,7 +668,11 @@ def make_atom14_masks(protein):
 def make_atom14_masks_np(batch):
-    batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
+    batch = tree_map(
+        lambda n: torch.tensor(n, device=batch["aatype"].device),
+        batch,
+        np.ndarray,
+    )
     out = make_atom14_masks(batch)
     out = tensor_tree_map(lambda t: np.array(t), out)
     return out
 ...
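Nearly every change in this file is the same fix: tensors that used to be created on the default (CPU) device are now created directly on the device of the feature they will be combined with, which avoids device-mismatch errors and implicit transfers when the transforms run on GPU-resident features. The pattern in isolation, using the `make_one_hot` change above:

```python
import torch

def make_one_hot(x, num_classes):
    # Allocate the output on x's device so the scatter below never crosses devices.
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
    return x_one_hot

aatype = torch.randint(0, 21, (8, 128))
if torch.cuda.is_available():
    aatype = aatype.cuda()
one_hot = make_one_hot(aatype, 22)  # lives on the same device as aatype either way
```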
openfold/data/feature_pipeline.py

@@ -40,10 +40,11 @@ def np_to_tensor_dict(
     Returns:
         A dictionary of features mapping feature names to features. Only the given
         features are returned, all other ones are filtered out.
     """
     tensor_dict = {
         k: torch.tensor(v) for k, v in np_example.items() if k in features
     }
     return tensor_dict
 ...
openfold/model/embedders.py

@@ -327,8 +327,6 @@ class RecyclingEmbedder(nn.Module):
         self.no_bins = no_bins
         self.inf = inf

-        self.bins = None

         self.linear = Linear(self.no_bins, self.c_z)
         self.layer_norm_m = LayerNorm(self.c_m)
         self.layer_norm_z = LayerNorm(self.c_z)
 ...

@@ -353,15 +351,14 @@ class RecyclingEmbedder(nn.Module):
             z:
                 [*, N_res, N_res, C_z] pair embedding update
         """
-        if self.bins is None:
-            self.bins = torch.linspace(
-                self.min_bin,
-                self.max_bin,
-                self.no_bins,
-                dtype=x.dtype,
-                device=x.device,
-                requires_grad=False,
-            )
+        bins = torch.linspace(
+            self.min_bin,
+            self.max_bin,
+            self.no_bins,
+            dtype=x.dtype,
+            device=x.device,
+            requires_grad=False,
+        )

         # [*, N, C_m]
         m_update = self.layer_norm_m(m)
 ...

@@ -369,7 +366,7 @@ class RecyclingEmbedder(nn.Module):
         # This squared method might become problematic in FP16 mode.
         # I'm using it because my homegrown method had a stubborn discrepancy I
         # couldn't find in time.
-        squared_bins = self.bins ** 2
+        squared_bins = bins ** 2
         upper = torch.cat(
             [squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
         )
 ...
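RecyclingEmbedder now rebuilds the distance bins on every call, on the dtype and device of the incoming coordinates, instead of caching them on the module. The "squared method" mentioned in the comment bins squared pairwise distances against squared bin edges, so no square root is ever taken. A standalone sketch of that binning step, with illustrative bin settings and coordinates:

```python
import torch

min_bin, max_bin, no_bins, inf = 3.25, 20.75, 15, 1e8
x = torch.rand(1, 64, 3)  # [*, N_res, 3] pseudo-beta coordinates

bins = torch.linspace(min_bin, max_bin, no_bins, dtype=x.dtype, device=x.device)
squared_bins = bins ** 2
upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([inf])], dim=-1)

# Squared pairwise distances, [*, N_res, N_res, 1]
d2 = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True)

# One-hot distogram without a sqrt, [*, N_res, N_res, no_bins]
d = ((d2 > squared_bins) * (d2 < upper)).type(x.dtype)
```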
openfold/model/evoformer.py

@@ -352,20 +352,31 @@ class ExtraMSABlock(nn.Module):
         chunk_size: Optional[int] = None,
+        _chunk_logits: Optional[int] = 1024,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        m = m + self.msa_dropout_layer(
+        def add(m1, m2):
+            # The first operation in a checkpoint can't be in-place, but it's
+            # nice to have in-place addition during inference. Thus...
+            if(torch.is_grad_enabled()):
+                m1 = m1 + m2
+            else:
+                m1 += m2
+            return m1
+
+        m = add(m,
+            self.msa_dropout_layer(
                 self.msa_att_row(
-                    m.clone(), z=z.clone(),
+                    m.clone() if torch.is_grad_enabled() else m,
+                    z=z.clone() if torch.is_grad_enabled() else z,
                     mask=msa_mask,
                     chunk_size=chunk_size,
+                    use_memory_efficient_kernel=not _chunk_logits,
+                    _chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
+                    _checkpoint_chunks=self.ckpt if torch.is_grad_enabled() else False,
                 )
             )
         )

         def fn(m, z):
-            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+            m = add(m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size))

             m, z = self.core(
                 m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
             )
 ...

@@ -548,11 +559,14 @@ class ExtraMSAStack(nn.Module):
         eps: float,
         ckpt: bool,
         clear_cache_between_blocks: bool = False,
+        chunk_msa_attn: bool = False,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()

         self.ckpt = ckpt
         self.clear_cache_between_blocks = clear_cache_between_blocks
+        self.chunk_msa_attn = chunk_msa_attn
         self.blocks = nn.ModuleList()
         for _ in range(no_blocks):
             block = ExtraMSABlock(
 ...

@@ -569,7 +583,7 @@ class ExtraMSAStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                ckpt=ckpt,
+                ckpt=ckpt if chunk_msa_attn else False,
             )

             self.blocks.append(block)
 ...

@@ -593,28 +607,36 @@ class ExtraMSAStack(nn.Module):
                 Optional [*, N_res, N_res] pair mask
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        #checkpoint_fn = get_checkpoint_fn()
-        #blocks = [
-        #    partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
-        #]
-        #def dodo(b, *args):
-        #    torch.cuda.empty_cache()
-        #    return b(*args)
-        #blocks = [partial(dodo, b) for b in blocks]
-        #for b in blocks:
-        #    if(torch.is_grad_enabled()):
-        #        m, z = checkpoint_fn(b, *(m, z))
-        #    else:
-        #        m, z = b(m, z)
-        for b in self.blocks:
-            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
-
-            if(self.clear_cache_between_blocks):
-                torch.cuda.empty_cache()
+        if(not self.chunk_msa_attn):
+            checkpoint_fn = get_checkpoint_fn()
+            blocks = [
+                partial(
+                    b,
+                    msa_mask=msa_mask,
+                    pair_mask=pair_mask,
+                    chunk_size=chunk_size,
+                    _chunk_logits=None,
+                )
+                for b in self.blocks
+            ]
+
+            def clear_cache(b, *args):
+                torch.cuda.empty_cache()
+                return b(*args)
+
+            if(self.clear_cache_between_blocks):
+                blocks = [partial(clear_cache, b) for b in blocks]
+
+            for b in blocks:
+                if(self.ckpt and torch.is_grad_enabled()):
+                    m, z = checkpoint_fn(b, *(m, z))
+                else:
+                    m, z = b(m, z)
+        else:
+            for b in self.blocks:
+                m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+
+                if(self.clear_cache_between_blocks):
+                    torch.cuda.empty_cache()

         return z
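The `add` helper above encodes a real constraint of activation checkpointing: the first operation inside a checkpointed function must not mutate its input in place (the saved input would be clobbered and the recomputation would be wrong), while at inference time an in-place `+=` avoids allocating a second MSA-sized tensor. A minimal standalone sketch of the same trick:

```python
import torch

def add(m1, m2):
    # Under autograd (e.g. inside a checkpointed block) use out-of-place addition;
    # under torch.no_grad() mutate m1 in place to save one large allocation.
    if torch.is_grad_enabled():
        m1 = m1 + m2
    else:
        m1 += m2
    return m1

m = torch.zeros(4, 4)
update = torch.ones(4, 4)

out_train = add(m.clone().requires_grad_(True), update)  # new tensor, grad-safe
with torch.no_grad():
    out_infer = add(m, update)                            # m itself is modified
assert out_infer is m
```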
openfold/model/model.py

@@ -173,9 +173,10 @@ class AlphaFold(nn.Module):
         # altogether. We zero them this way instead of computing them
         # conditionally to avoid leaving parameters unused, which has annoying
         # implications for DDP training.
-        if(not _recycle):
-            m_1_prev_emb *= 0
-            z_prev_emb *= 0
+        # EDIT: This has since been removed from the official codebase (2cd61a)
+        # if(not _recycle):
+        #     m_1_prev_emb *= 0
+        #     z_prev_emb *= 0

         # [*, S_c, N, C_m]
         m[..., 0, :, :] += m_1_prev_emb
 ...
openfold/model/msa.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 import math

 import torch
 import torch.nn as nn
 ...

@@ -79,20 +79,33 @@ class MSAAttention(nn.Module):
         )

         self.mha = Attention(
-            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
+            self.c_in,
+            self.c_in,
+            self.c_in,
+            self.c_hidden,
+            self.no_heads,
         )

     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
         biases: List[torch.Tensor],
+        use_memory_efficient_kernel: bool,
         chunk_size: int,
     ) -> torch.Tensor:
+        mha = partial(
+            self.mha,
+            use_memory_efficient_kernel=use_memory_efficient_kernel,
+        )
         return chunk_layer(
-            self.mha,
-            {"q_x": m, "kv_x": m, "biases": biases},
+            mha,
+            {
+                "q_x": m,
+                "kv_x": m,
+                "biases": biases,
+            },
             chunk_size=chunk_size,
-            no_batch_dims=len(m.shape[:-2])
+            no_batch_dims=len(m.shape[:-2]),
         )

     def _prep_inputs(self,
 ...

@@ -113,13 +126,6 @@ class MSAAttention(nn.Module):
         # [*, N_seq, 1, 1, N_res]
         mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

-        # This step simply returns a larger view of the bias, and does not
-        # consume additional memory.
-        # [*, N_seq, no_heads, N_res, N_res]
-        #bias = bias.expand(
-        #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
-        #)

         if(self.pair_bias and
             z is not None and                       # For the
             self.layer_norm_z is not None and       # benefit of
 ...

@@ -144,6 +150,11 @@ class MSAAttention(nn.Module):
         chunk_logits: int,
         checkpoint: bool,
     ) -> torch.Tensor:
+        """
+        MSA attention with training-time chunking of the softmax computation.
+        Saves memory in the extra MSA stack. Probably obviated by our fused
+        attention kernel, which is now used by default.
+        """
         MSA_DIM = -4

         def _get_qkv(m, z):
 ...

@@ -181,6 +192,7 @@ class MSAAttention(nn.Module):
         z: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
         _chunk_logits: Optional[int] = None,
         _checkpoint_chunks: Optional[bool] = None,
     ) -> torch.Tensor:
 ...

@@ -212,12 +224,13 @@ class MSAAttention(nn.Module):
             biases.append(z)

         if chunk_size is not None:
-            m = self._chunk(m, biases, chunk_size)
+            m = self._chunk(m, biases, use_memory_efficient_kernel, chunk_size)
         else:
             m = self.mha(
                 q_x=m,
                 kv_x=m,
-                biases=biases
+                biases=biases,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
             )

         return m
 ...

@@ -291,7 +304,8 @@ class MSAColumnAttention(nn.Module):
     def forward(self,
         m: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None
+        chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
     ) -> torch.Tensor:
         """
         Args:
 ...
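`_chunk` now threads the new `use_memory_efficient_kernel` flag through to the attention call by binding it with `functools.partial` before handing the callable to the chunking utility, since the chunker only splits the tensor arguments it is given. The same pattern in miniature; the `chunked_apply` helper here is a stand-in for OpenFold's `chunk_layer`, not its real signature:

```python
import torch
from functools import partial

def chunked_apply(fn, x, chunk_size):
    # Stand-in chunker: apply fn to slices of x along the first dim, re-concatenate.
    return torch.cat(
        [fn(x[i:i + chunk_size]) for i in range(0, x.shape[0], chunk_size)], dim=0
    )

def attention_like(x, scale=1.0, use_memory_efficient_kernel=False):
    # Toy computation standing in for the real attention call.
    return x * scale

x = torch.rand(10, 4)
# Bind the non-tensor keyword first; the chunker then only sees tensor inputs.
fn = partial(attention_like, scale=2.0, use_memory_efficient_kernel=True)
out = chunked_apply(fn, x, chunk_size=4)
assert out.shape == x.shape
```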
openfold/model/primitives.py

@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partial
 import math
 from typing import Optional, Callable, List, Tuple, Sequence
 ...

@@ -24,6 +23,7 @@
 import torch.nn as nn
 from scipy.stats import truncnorm

 from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.kernel.attention_core import attention_core
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
 ...

@@ -199,8 +199,9 @@ class LayerNorm(nn.Module):
     return out


 @torch.jit.ignore
-def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
     """
         Softmax, but without automatic casting to fp32 when the input is of
         type bfloat16
 ...

@@ -217,14 +218,8 @@
 #@torch.jit.script
 def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
-    # [*, H, Q, C_hidden]
-    query = permute_final_dims(query, (1, 0, 2))
-
-    # [*, H, C_hidden, K]
-    key = permute_final_dims(key, (1, 2, 0))
-
-    # [*, H, V, C_hidden]
-    value = permute_final_dims(value, (1, 0, 2))
+    # [*, H, C_hidden, K]
+    key = permute_final_dims(key, (1, 0))

     # [*, H, Q, K]
     a = torch.matmul(query, key)
 ...

@@ -232,14 +227,11 @@ def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, biases: List[torch.Tensor]) -> torch.Tensor:
     for b in biases:
         a += b

-    a = softmax(a, -1)
+    a = softmax_no_cast(a, -1)

     # [*, H, Q, C_hidden]
     a = torch.matmul(a, value)

-    # [*, Q, H, C_hidden]
-    a = a.transpose(-2, -3)
-
     return a
 ...

@@ -254,7 +246,8 @@ def _attention_chunked_trainable(
     def _checkpointable_attention(q, k, v, b1, b2):
         bs = [b for b in [b1, b2] if b is not None]
-        return _attention(q, k, v, bs)
+        a = _attention(q, k, v, bs)
+        return a

     o_chunks = []
     checkpoint_fn = get_checkpoint_fn()
 ...

@@ -289,7 +282,8 @@ def _attention_chunked_trainable(
         ]
         o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+        o_chunk = o_chunk.transpose(-2, -3)
         o_chunks.append(o_chunk)

     o = torch.cat(o_chunks, dim=chunk_dim)
 ...

@@ -374,6 +368,11 @@ class Attention(nn.Module):
         k = k.view(k.shape[:-1] + (self.no_heads, -1))
         v = v.view(v.shape[:-1] + (self.no_heads, -1))

+        # [*, H, Q/K, C_hidden]
+        q = q.transpose(-2, -3)
+        k = k.transpose(-2, -3)
+        v = v.transpose(-2, -3)
+
         q /= math.sqrt(self.c_hidden)

         return q, k, v
 ...

@@ -402,6 +401,7 @@ class Attention(nn.Module):
         q_x: torch.Tensor,
         kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
+        use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
         q_chunk_size: Optional[int] = None,
         kv_chunk_size: Optional[int] = None,
 ...

@@ -414,8 +414,15 @@ class Attention(nn.Module):
             [*, K, C_k] key data
         biases:
             List of biases that broadcast to [*, H, Q, K]
+        use_memory_efficient_kernel:
+            Whether to use a custom memory-efficient attention kernel.
+            This should be the default choice for most. If none of the
+            "use_<...>" flags are True, a stock PyTorch implementation
+            is used instead
         use_lma:
-            Whether to use low-memory attention
+            Whether to use low-memory attention (Staats & Rabe 2021). If
+            none of the "use_<...>" flags are True, a stock PyTorch
+            implementation is used instead
         q_chunk_size:
             Query chunk size (for LMA)
         kv_chunk_size:
 ...

@@ -430,18 +437,32 @@ class Attention(nn.Module):
                 "If use_lma is specified, q_chunk_size and kv_chunk_size must "
                 "be provided"
             )

+        if(use_memory_efficient_kernel and use_lma):
+            raise ValueError(
+                "Choose one of use_memory_efficient_kernel and use_lma"
+            )
+
         # [*, H, Q/K, C_hidden]
         q, k, v = self._prep_qkv(q_x, kv_x)

-        if(use_lma):
-            # [*, Q, H, C_hidden]
+        if(use_memory_efficient_kernel):
+            if(len(biases) > 2):
+                raise ValueError(
+                    "If use_memory_efficient_kernel is True, you may only "
+                    "provide up to two bias terms"
+                )
+            o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
+            o = o.transpose(-2, -3)
+        elif(use_lma):
+            biases = [
+                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
+                for b in biases
+            ]
+            o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
         else:
             o = _attention(q, k, v, biases)
+            o = o.transpose(-2, -3)

         o = self._wrap_up(o, q_x)
 ...

@@ -497,7 +518,7 @@ class GlobalAttention(nn.Module):
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
         a += bias
-        a = softmax(a)
+        a = softmax_no_cast(a)

         # [*, N_res, H, C_hidden]
         o = torch.matmul(
 ...
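After this change, `_prep_qkv` hands `_attention` head-major tensors ([*, H, Q, C_hidden]) with the 1/sqrt(c_hidden) scaling already applied, so `_attention` only has to transpose the key, add the broadcastable biases, soft-max, and multiply by the values. A reference version of that computation in stock PyTorch, which is also what the fused kernel in the new `attention_core.py` computes; a sketch for illustration, using a plain `torch.softmax` in place of `softmax_no_cast`:

```python
import math
import torch

def reference_attention(q, k, v, biases):
    # q, k, v: [*, H, Q/K, C_hidden]; each bias broadcasts to [*, H, Q, K]
    a = torch.matmul(q, k.transpose(-1, -2))
    for b in biases:
        a = a + b
    a = torch.softmax(a, dim=-1)
    return torch.matmul(a, v)  # [*, H, Q, C_hidden]

H, Q, C = 4, 16, 8
q = torch.rand(1, H, Q, C) / math.sqrt(C)  # _prep_qkv pre-scales the queries
k = torch.rand(1, H, Q, C)
v = torch.rand(1, H, Q, C)
mask_bias = torch.zeros(1, 1, 1, Q)        # e.g. large negative values at padding
o = reference_attention(q, k, v, [mask_bias])
```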
openfold/np/relax/amber_minimize.py

@@ -553,60 +553,3 @@ def run_pipeline(
         )
         iteration += 1

     return ret
-
-
-def get_initial_energies(
-    pdb_strs: Sequence[str],
-    stiffness: float = 0.0,
-    restraint_set: str = "non_hydrogen",
-    exclude_residues: Optional[Sequence[int]] = None,
-):
-    """Returns initial potential energies for a sequence of PDBs.
-
-    Assumes the input PDBs are ready for minimization, and all have the same
-    topology.
-    Allows time to be saved by not pdbfixing / rebuilding the system.
-
-    Args:
-      pdb_strs: List of PDB strings.
-      stiffness: kcal/mol A**2, spring constant of heavy atom restraining
-        potential.
-      restraint_set: Which atom types to restrain.
-      exclude_residues: An optional list of zero-indexed residues to exclude from
-        restraints.
-
-    Returns:
-      A list of initial energies in the same order as pdb_strs.
-    """
-    exclude_residues = exclude_residues or []
-    openmm_pdbs = [
-        openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
-    ]
-    force_field = openmm_app.ForceField("amber99sb.xml")
-    system = force_field.createSystem(
-        openmm_pdbs[0].topology, constraints=openmm_app.HBonds
-    )
-    stiffness = stiffness * ENERGY / (LENGTH ** 2)
-    if stiffness > 0 * ENERGY / (LENGTH ** 2):
-        _add_restraints(
-            system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
-        )
-    simulation = openmm_app.Simulation(
-        openmm_pdbs[0].topology,
-        system,
-        openmm.LangevinIntegrator(0, 0.01, 0.0),
-        openmm.Platform.getPlatformByName("CPU"),
-    )
-    energies = []
-    for pdb in openmm_pdbs:
-        try:
-            simulation.context.setPositions(pdb.positions)
-            state = simulation.context.getState(getEnergy=True)
-            energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
-        except Exception as e:  # pylint: disable=broad-except
-            logging.error("Error getting initial energy, returning large value %s", e)
-            energies.append(unit.Quantity(1e20, ENERGY))
-    return energies
openfold/utils/__init__.py

@@ -2,12 +2,14 @@ import os
 import glob
 import importlib

+from . import kernel
+
 _files = glob.glob(os.path.join(os.path.dirname(__file__), "*.py"))
 __all__ = [
     os.path.basename(f)[:-3]
     for f in _files
     if os.path.isfile(f) and not f.endswith("__init__.py")
-]
+] + ["kernel"]
 _modules = [(m, importlib.import_module("." + m, __name__)) for m in __all__]
 for _m in _modules:
     globals()[_m[0]] = _m[1]
 ...
openfold/utils/feats.py

@@ -107,7 +107,10 @@ def dgram_from_positions(
 def build_template_pair_feat(
-    batch, min_bin, max_bin, no_bins, use_unit_vector=False, eps=1e-20, inf=1e8
+    batch,
+    min_bin,
+    max_bin,
+    no_bins,
+    use_unit_vector=False,
+    eps=1e-20,
+    inf=1e8
 ):
     template_mask = batch["template_pseudo_beta_mask"]
     template_mask_2d = template_mask[..., None] * template_mask[..., None, :]
 ...
openfold/utils/kernel/__init__.py  0 → 100644 (new, empty file)
openfold/utils/kernel/attention_core.py  0 → 100644 (new file)

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from functools import reduce
from operator import mul

import torch

attn_core_inplace_cuda = importlib.import_module("attn_core_inplace_cuda")

SUPPORTED_DTYPES = [torch.float32, torch.bfloat16]


class AttentionCoreFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, bias_1=None, bias_2=None):
        if(bias_1 is None and bias_2 is not None):
            raise ValueError("bias_1 must be specified before bias_2")
        if(q.dtype not in SUPPORTED_DTYPES):
            raise ValueError("Unsupported datatype")

        q = q.contiguous()
        k = k.contiguous()

        # [*, H, Q, K]
        attention_logits = torch.matmul(
            q, k.transpose(-1, -2),
        )

        if(bias_1 is not None):
            attention_logits += bias_1
        if(bias_2 is not None):
            attention_logits += bias_2

        attn_core_inplace_cuda.forward_(
            attention_logits,
            reduce(mul, attention_logits.shape[:-1]),
            attention_logits.shape[-1],
        )

        o = torch.matmul(attention_logits, v)

        ctx.bias_1_shape = bias_1.shape if bias_1 is not None else None
        ctx.bias_2_shape = bias_2.shape if bias_2 is not None else None
        ctx.save_for_backward(q, k, v, attention_logits)

        return o

    @staticmethod
    def backward(ctx, grad_output):
        q, k, v, attention_logits = ctx.saved_tensors
        grad_q = grad_k = grad_v = grad_bias_1 = grad_bias_2 = None

        grad_v = torch.matmul(
            attention_logits.transpose(-1, -2),
            grad_output
        )

        attn_core_inplace_cuda.backward_(
            attention_logits,
            grad_output.contiguous(),
            v.contiguous(),  # v is implicitly transposed in the kernel
            reduce(mul, attention_logits.shape[:-1]),
            attention_logits.shape[-1],
            grad_output.shape[-1],
        )

        if(ctx.bias_1_shape is not None):
            grad_bias_1 = torch.sum(
                attention_logits,
                dim=tuple(i for i, d in enumerate(ctx.bias_1_shape) if d == 1),
                keepdim=True,
            )

        if(ctx.bias_2_shape is not None):
            grad_bias_2 = torch.sum(
                attention_logits,
                dim=tuple(i for i, d in enumerate(ctx.bias_2_shape) if d == 1),
                keepdim=True,
            )

        grad_q = torch.matmul(
            attention_logits, k
        )
        grad_k = torch.matmul(
            q.transpose(-1, -2), attention_logits,
        ).transpose(-1, -2)

        return grad_q, grad_k, grad_v, grad_bias_1, grad_bias_2

attention_core = AttentionCoreFunction.apply
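`attention_core` is a drop-in for the softmax-attention inner loop: head-major `q`, `k`, `v` plus at most two bias terms, with the softmax performed in place by the compiled `attn_core_inplace_cuda` extension, in float32 or bfloat16. Roughly how it is invoked; a sketch that assumes the extension has been built (`python3 setup.py install`) and that the tensors live on a CUDA device:

```python
import torch
from openfold.utils.kernel.attention_core import attention_core

H, Q, C = 4, 32, 16
q = torch.rand(1, H, Q, C, device="cuda", requires_grad=True)
k = torch.rand(1, H, Q, C, device="cuda", requires_grad=True)
v = torch.rand(1, H, Q, C, device="cuda", requires_grad=True)

# Up to two biases, each broadcastable to [*, H, Q, K]; bias_1 must come before bias_2.
mask_bias = torch.zeros(1, 1, 1, Q, device="cuda")

o = attention_core(q, k, v, mask_bias, None)  # [*, H, Q, C]
o.sum().backward()                            # gradients flow through the custom backward
```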
openfold/utils/kernel/csrc/compat.h  0 → 100644 (new file)

// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h

#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif

#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
openfold/utils/kernel/csrc/softmax_cuda.cpp  0 → 100644 (new file)

// Copyright 2021 AlQuraishi Laboratory
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// modified from fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp

#include <torch/extension.h>

void attn_softmax_inplace_forward_(
    at::Tensor input,
    long long rows, int cols
);
void attn_softmax_inplace_backward_(
    at::Tensor output,
    at::Tensor d_ov,
    at::Tensor values,
    long long rows,
    int cols_output,
    int cols_values
);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward_",
        &attn_softmax_inplace_forward_,
        "Softmax forward (CUDA)"
    );
    m.def(
        "backward_",
        &attn_softmax_inplace_backward_,
        "Softmax backward (CUDA)"
    );
}
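These C++/CUDA sources are what the README's new "compile OpenFold's CUDA kernels with `python3 setup.py install`" step builds. OpenFold's actual `setup.py` is not part of the files shown here, but a PyTorch extension exposing this pybind11 module is typically declared along these lines; a sketch only, and the `.cu` source name is an assumption:

```python
# sketch_setup.py -- illustrative only; the project's real build script may differ.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="attn_core_inplace_cuda",
    ext_modules=[
        CUDAExtension(
            name="attn_core_inplace_cuda",
            sources=[
                "openfold/utils/kernel/csrc/softmax_cuda.cpp",
                # assumed companion file holding the kernel bodies
                "openfold/utils/kernel/csrc/softmax_cuda_kernel.cu",
            ],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```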