OpenDAS / OpenFold · Commits

Commit df89bb28, authored Jan 05, 2022 by Gustaf Ahdritz (parent 70d6bda5)

    Add chunking experiment

Showing 7 changed files with 359 additions and 249 deletions (+359 −249):

    openfold/config.py                        +1    −1
    openfold/model/evoformer.py               +210  −84
    openfold/model/model.py                   +6    −4
    openfold/model/msa.py                     +11   −4
    openfold/model/primitives.py              +116  −148
    openfold/model/triangular_attention.py    +2    −2
    openfold/utils/checkpointing.py           +13   −6
openfold/config.py  (view file @ df89bb28)

@@ -318,10 +318,10 @@ config = mlc.ConfigDict(
                 "transition_n": 4,
                 "msa_dropout": 0.15,
                 "pair_dropout": 0.25,
-                "blocks_per_ckpt": blocks_per_ckpt,
                 "clear_cache_between_blocks": True,
                 "inf": 1e9,
                 "eps": eps,  # 1e-10,
+                "ckpt": blocks_per_ckpt is not None,
             },
             "enabled": True,
         },
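Note: the extra-MSA stack's config swaps the blocks_per_ckpt entry for a boolean ckpt flag derived from it, which the model code reads back as config.extra_msa.extra_msa_stack.ckpt. A minimal sketch of that derivation (hypothetical snippet, not the real config file, which nests this dict more deeply):

    # Sketch only: shows how the new "ckpt" entry is derived from blocks_per_ckpt,
    # mirroring the hunk above; the surrounding structure is simplified.
    import ml_collections as mlc

    blocks_per_ckpt = 1  # None disables checkpointing entirely

    extra_msa_stack = mlc.ConfigDict(
        {
            "clear_cache_between_blocks": True,
            # consumed by AlphaFold._enable_activation_checkpointing (see model.py below)
            "ckpt": blocks_per_ckpt is not None,
        }
    )

    print(extra_msa_stack.ckpt)  # True whenever blocks_per_ckpt is set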
openfold/model/evoformer.py  (view file @ df89bb28)

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 import torch
 import torch.nn as nn
 from typing import Tuple, Optional

@@ -35,7 +36,7 @@ from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationOutgoing,
     TriangleMultiplicationIncoming,
 )
-from openfold.utils.checkpointing import checkpoint_blocks
+from openfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
 from openfold.utils.tensor_utils import chunk_layer

@@ -117,51 +118,23 @@ class MSATransition(nn.Module):
         return m


-class EvoformerBlock(nn.Module):
+class EvoformerBlockCore(nn.Module):
     def __init__(
         self,
         c_m: int,
         c_z: int,
-        c_hidden_msa_att: int,
         c_hidden_opm: int,
         c_hidden_mul: int,
         c_hidden_pair_att: int,
         no_heads_msa: int,
         no_heads_pair: int,
         transition_n: int,
-        msa_dropout: float,
         pair_dropout: float,
         inf: float,
         eps: float,
         _is_extra_msa_stack: bool = False,
     ):
-        super(EvoformerBlock, self).__init__()
-
-        self._is_extra_msa_stack = _is_extra_msa_stack
-
-        self.msa_att_row = MSARowAttentionWithPairBias(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden=c_hidden_msa_att,
-            no_heads=no_heads_msa,
-            inf=inf,
-        )
-
-        if _is_extra_msa_stack:
-            self.msa_att_col = MSAColumnGlobalAttention(
-                c_in=c_m,
-                c_hidden=c_hidden_msa_att,
-                no_heads=no_heads_msa,
-                inf=inf,
-                eps=eps,
-            )
-        else:
-            self.msa_att_col = MSAColumnAttention(
-                c_m,
-                c_hidden_msa_att,
-                no_heads_msa,
-                inf=inf,
-            )
+        super(EvoformerBlockCore, self).__init__()

         self.msa_transition = MSATransition(
             c_m=c_m,

@@ -201,7 +174,6 @@ class EvoformerBlock(nn.Module):
             transition_n,
         )

-        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
         self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
         self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

@@ -220,10 +192,6 @@ class EvoformerBlock(nn.Module):
         msa_trans_mask = msa_mask if _mask_trans else None
         pair_trans_mask = pair_mask if _mask_trans else None

-        m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
-        )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
         m = m + self.msa_transition(
             m, mask=msa_trans_mask, chunk_size=chunk_size
         )

@@ -245,6 +213,174 @@ class EvoformerBlock(nn.Module):
         return m, z


+class EvoformerBlock(nn.Module):
+    def __init__(
+        self,
+        c_m: int,
+        c_z: int,
+        c_hidden_msa_att: int,
+        c_hidden_opm: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        no_heads_msa: int,
+        no_heads_pair: int,
+        transition_n: int,
+        msa_dropout: float,
+        pair_dropout: float,
+        inf: float,
+        eps: float,
+    ):
+        super().__init__()
+
+        self.msa_att_row = MSARowAttentionWithPairBias(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_att_col = MSAColumnAttention(
+            c_m,
+            c_hidden_msa_att,
+            no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
+
+        self.core = EvoformerBlockCore(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden_opm=c_hidden_opm,
+            c_hidden_mul=c_hidden_mul,
+            c_hidden_pair_att=c_hidden_pair_att,
+            no_heads_msa=no_heads_msa,
+            no_heads_pair=no_heads_pair,
+            transition_n=transition_n,
+            pair_dropout=pair_dropout,
+            inf=inf,
+            eps=eps,
+        )
+
+    def forward(
+        self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        _mask_trans: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        m = m + self.msa_dropout_layer(
+            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+        )
+        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m, z = self.core(
+            m,
+            z,
+            msa_mask=msa_mask,
+            pair_mask=pair_mask,
+            chunk_size=chunk_size,
+            _mask_trans=_mask_trans,
+        )
+
+        return m, z
+
+
+class ExtraMSABlock(nn.Module):
+    """
+    Almost identical to the standard EvoformerBlock, except in that the
+    ExtraMSABlock uses GlobalAttention for MSA column attention and
+    requires more fine-grained control over checkpointing. Separated from
+    its twin to preserve the TorchScript-ability of the latter.
+    """
+    def __init__(
+        self,
+        c_m: int,
+        c_z: int,
+        c_hidden_msa_att: int,
+        c_hidden_opm: int,
+        c_hidden_mul: int,
+        c_hidden_pair_att: int,
+        no_heads_msa: int,
+        no_heads_pair: int,
+        transition_n: int,
+        msa_dropout: float,
+        pair_dropout: float,
+        inf: float,
+        eps: float,
+        ckpt: bool,
+    ):
+        super().__init__()
+
+        self.ckpt = ckpt
+
+        self.msa_att_row = MSARowAttentionWithPairBias(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+        )
+
+        self.msa_att_col = MSAColumnGlobalAttention(
+            c_in=c_m,
+            c_hidden=c_hidden_msa_att,
+            no_heads=no_heads_msa,
+            inf=inf,
+            eps=eps,
+        )
+
+        self.msa_dropout_layer = DropoutRowwise(msa_dropout)
+
+        self.core = EvoformerBlockCore(
+            c_m=c_m,
+            c_z=c_z,
+            c_hidden_opm=c_hidden_opm,
+            c_hidden_mul=c_hidden_mul,
+            c_hidden_pair_att=c_hidden_pair_att,
+            no_heads_msa=no_heads_msa,
+            no_heads_pair=no_heads_pair,
+            transition_n=transition_n,
+            pair_dropout=pair_dropout,
+            inf=inf,
+            eps=eps,
+        )
+
+    def forward(
+        self,
+        m: torch.Tensor,
+        z: torch.Tensor,
+        msa_mask: torch.Tensor,
+        pair_mask: torch.Tensor,
+        chunk_size: Optional[int] = None,
+        checkpoint_chunk_size: Optional[int] = 512,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        checkpoint_chunk_size = checkpoint_chunk_size if self.ckpt else None
+
+        m = m + self.msa_dropout_layer(
+            self.msa_att_row(
+                m,
+                z=z,
+                mask=msa_mask,
+                chunk_size=chunk_size,
+                _chunk_and_checkpoint=checkpoint_chunk_size,
+            )
+        )
+
+        def fn(m, z):
+            m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+            m, z = self.core(
+                m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
+            )
+            return m, z
+
+        if self.ckpt:
+            checkpoint_fn = get_checkpoint_fn()
+            m, z = checkpoint_fn(fn, m, z)
+        else:
+            m, z = fn(m, z)
+
+        return m, z
+
+
 class EvoformerStack(nn.Module):
     """
     Main Evoformer trunk.

@@ -271,7 +407,6 @@ class EvoformerStack(nn.Module):
         inf: float,
         eps: float,
         clear_cache_between_blocks: bool = False,
-        _is_extra_msa_stack: bool = False,
         **kwargs,
     ):
         """

@@ -313,7 +448,6 @@ class EvoformerStack(nn.Module):
         self.blocks_per_ckpt = blocks_per_ckpt
         self.clear_cache_between_blocks = clear_cache_between_blocks
-        self._is_extra_msa_stack = _is_extra_msa_stack

         self.blocks = nn.ModuleList()

@@ -332,15 +466,12 @@ class EvoformerStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                _is_extra_msa_stack=_is_extra_msa_stack,
             )
             self.blocks.append(block)

-        if not self._is_extra_msa_stack:
-            self.linear = Linear(c_m, c_s)
+        self.linear = Linear(c_m, c_s)

     def forward(
         self,
         m: torch.Tensor,
         z: torch.Tensor,
         msa_mask: torch.Tensor,

@@ -390,8 +521,6 @@ class EvoformerStack(nn.Module):
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )

-        s = None
-        if not self._is_extra_msa_stack:
-            seq_dim = -3
-            index = torch.tensor([0], device=m.device)
-            s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
+        seq_dim = -3
+        index = torch.tensor([0], device=m.device)
+        s = self.linear(torch.index_select(m, dim=seq_dim, index=index))

@@ -405,8 +534,7 @@ class ExtraMSAStack(nn.Module):
     Implements Algorithm 18.
     """
     def __init__(
         self,
         c_m: int,
         c_z: int,
         c_hidden_msa_att: int,

@@ -419,38 +547,38 @@ class ExtraMSAStack(nn.Module):
         transition_n: int,
         msa_dropout: float,
         pair_dropout: float,
-        blocks_per_ckpt: int,
         inf: float,
         eps: float,
+        ckpt: bool,
         clear_cache_between_blocks: bool = False,
         **kwargs,
     ):
         super(ExtraMSAStack, self).__init__()
-        c_s = None
         self.clear_cache_between_blocks = clear_cache_between_blocks
-        self.stack = EvoformerStack(
-            c_m=c_m,
-            c_z=c_z,
-            c_hidden_msa_att=c_hidden_msa_att,
-            c_hidden_opm=c_hidden_opm,
-            c_hidden_mul=c_hidden_mul,
-            c_hidden_pair_att=c_hidden_pair_att,
-            c_s=c_s,
-            no_heads_msa=no_heads_msa,
-            no_heads_pair=no_heads_pair,
-            no_blocks=no_blocks,
-            transition_n=transition_n,
-            msa_dropout=msa_dropout,
-            pair_dropout=pair_dropout,
-            blocks_per_ckpt=blocks_per_ckpt,
-            inf=inf,
-            eps=eps,
-            clear_cache_between_blocks=clear_cache_between_blocks,
-            _is_extra_msa_stack=True,
-        )
+        self.blocks = nn.ModuleList()
+        for _ in range(no_blocks):
+            block = ExtraMSABlock(
+                c_m=c_m,
+                c_z=c_z,
+                c_hidden_msa_att=c_hidden_msa_att,
+                c_hidden_opm=c_hidden_opm,
+                c_hidden_mul=c_hidden_mul,
+                c_hidden_pair_att=c_hidden_pair_att,
+                no_heads_msa=no_heads_msa,
+                no_heads_pair=no_heads_pair,
+                transition_n=transition_n,
+                msa_dropout=msa_dropout,
+                pair_dropout=pair_dropout,
+                inf=inf,
+                eps=eps,
+                ckpt=ckpt,
+            )
+            self.blocks.append(block)

     def forward(
         self,
         m: torch.Tensor,
         z: torch.Tensor,
         chunk_size: int,

@@ -471,12 +599,10 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        _, z, _ = self.stack(
-            m,
-            z,
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
-            chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
-        )
+        for b in self.blocks:
+            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+            if self.clear_cache_between_blocks:
+                torch.cuda.empty_cache()
+
         return z
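Note: the new ExtraMSABlock wraps everything after the row attention in a closure and, when ckpt is set, runs that closure under an activation checkpoint obtained from get_checkpoint_fn. A standalone sketch of that control flow, using plain torch.utils.checkpoint and made-up stand-in modules (TailBlock and CheckpointedBlock are hypothetical, not OpenFold classes):

    # Sketch of the "checkpoint the tail of the block" pattern in ExtraMSABlock.forward.
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint


    class TailBlock(nn.Module):
        # Stands in for msa_att_col + EvoformerBlockCore.
        def __init__(self, c_m: int, c_z: int):
            super().__init__()
            self.msa_update = nn.Linear(c_m, c_m)
            self.pair_update = nn.Linear(c_z, c_z)

        def forward(self, m, z):
            return m + self.msa_update(m), z + self.pair_update(z)


    class CheckpointedBlock(nn.Module):
        def __init__(self, c_m: int, c_z: int, ckpt: bool):
            super().__init__()
            self.ckpt = ckpt
            self.row_att = nn.Linear(c_m, c_m)  # stands in for msa_att_row
            self.tail = TailBlock(c_m, c_z)

        def forward(self, m, z):
            # The row attention runs outside the checkpoint (in the real block it is
            # chunked and checkpointed internally via _chunk_and_checkpoint).
            m = m + self.row_att(m)

            def fn(m, z):
                return self.tail(m, z)

            if self.ckpt:
                # Activations inside fn are freed and recomputed during backward.
                m, z = checkpoint(fn, m, z)
            else:
                m, z = fn(m, z)
            return m, z


    block = CheckpointedBlock(c_m=8, c_z=4, ckpt=True).train()
    m = torch.randn(2, 16, 10, 8, requires_grad=True)
    z = torch.randn(2, 10, 10, 4, requires_grad=True)
    m_out, z_out = block(m, z)
    (m_out.sum() + z_out.sum()).backward()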
openfold/model/model.py  (view file @ df89bb28)

@@ -336,7 +336,9 @@ class AlphaFold(nn.Module):
     def _disable_activation_checkpointing(self):
         self.template_pair_stack.blocks_per_ckpt = None
         self.evoformer.blocks_per_ckpt = None
-        self.extra_msa_stack.stack.blocks_per_ckpt = None
+
+        for b in self.extra_msa_stack.blocks:
+            b.ckpt = False

     def _enable_activation_checkpointing(self):
         self.template_pair_stack.blocks_per_ckpt = (

@@ -345,9 +347,9 @@ class AlphaFold(nn.Module):
         self.evoformer.blocks_per_ckpt = (
             self.config.evoformer_stack.blocks_per_ckpt
         )
-        self.extra_msa_stack.stack.blocks_per_ckpt = (
-            self.config.extra_msa.extra_msa_stack.blocks_per_ckpt
-        )
+
+        for b in self.extra_msa_stack.blocks:
+            b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt

     def forward(self, batch):
         """
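Note: since ExtraMSAStack now owns its blocks directly instead of wrapping an EvoformerStack, checkpointing is toggled by flipping each block's boolean ckpt flag rather than setting blocks_per_ckpt on the wrapped stack. A tiny sketch with hypothetical stand-in classes (not the real AlphaFold model):

    # Per-block flags toggled in bulk, mirroring the two methods above.
    import torch.nn as nn

    class _Block(nn.Module):
        def __init__(self):
            super().__init__()
            self.ckpt = False  # same role as ExtraMSABlock.ckpt

    blocks = nn.ModuleList(_Block() for _ in range(4))

    def enable(ckpt_from_config: bool):
        for b in blocks:
            b.ckpt = ckpt_from_config

    def disable():
        for b in blocks:
            b.ckpt = False

    enable(True)   # mirrors _enable_activation_checkpointing
    disable()      # mirrors _disable_activation_checkpointing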
openfold/model/msa.py  (view file @ df89bb28)

@@ -93,6 +93,7 @@ class MSAAttention(nn.Module):
         z: Optional[torch.Tensor] = None,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        _chunk_and_checkpoint: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Args:

@@ -125,9 +126,9 @@ class MSAAttention(nn.Module):
         # This step simply returns a larger view of the bias, and does not
         # consume additional memory.
         # [*, N_seq, no_heads, N_res, N_res]
-        bias = bias.expand(
-            ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
-        )
+        # bias = bias.expand(
+        #     ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
+        # )

         biases = [bias]

@@ -150,7 +151,13 @@ class MSAAttention(nn.Module):
         if chunk_size is not None:
             m = self._chunk(m, biases, chunk_size)
         else:
-            m = self.mha(q_x=m, k_x=m, v_x=m, biases=biases)
+            m = self.mha(
+                q_x=m,
+                k_x=m,
+                v_x=m,
+                biases=biases,
+                _chunk_and_checkpoint=_chunk_and_checkpoint
+            )

         return m
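Note: the expand that gets commented out here only produced a broadcastable view, and element-wise addition of the un-expanded bias broadcasts to the same result inside attention. A quick standalone check of that equivalence (shapes are arbitrary examples, not tied to any particular model size):

    # Adding the un-expanded bias broadcasts identically to adding the expanded view.
    import torch

    n_seq, no_heads, n_res = 5, 4, 16
    a = torch.randn(2, n_seq, no_heads, n_res, n_res)  # attention logits
    bias = torch.randn(2, n_seq, 1, 1, n_res)           # [*, N_seq, 1, 1, N_res] mask bias

    # The removed call: a pure view, no extra memory...
    expanded = bias.expand(
        ((-1,) * len(bias.shape[:-4])) + (-1, no_heads, n_res, -1)
    )
    # ...and unnecessary, because broadcasting in the addition does the same thing.
    assert torch.equal(a + bias, a + expanded)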
openfold/model/primitives.py  (view file @ df89bb28)

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 import math
 from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np

@@ -21,6 +22,7 @@ import torch
 import torch.nn as nn
 from scipy.stats import truncnorm

+from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,

@@ -164,6 +166,67 @@ class Linear(nn.Linear):
             raise ValueError("Invalid init string.")


+def _attention(query, key, value, biases):
+    a = torch.matmul(query, key)
+
+    for b in biases:
+        a += b
+
+    a = torch.nn.functional.softmax(a, dim=-1)
+
+    # [*, H, Q, C_hidden]
+    o = torch.matmul(a, value)
+
+    # [*, Q, H, C_hidden]
+    o = o.transpose(-2, -3)
+
+    return o
+
+
+@torch.jit.ignore
+def _attention_chunk_and_checkpoint(query, key, value, biases, chunk_size):
+    if len(biases) > 2:
+        raise ValueError("_chunk_and_checkpoint only permits two bias terms")
+
+    biases = biases + [None, None]
+    bias_1, bias_2 = biases[:2]
+
+    def _checkpointable_attention(q, k, v, b1, b2):
+        bs = [b1, b2]
+        return _attention(q, k, v, bs)
+
+    batch_dims = query.shape[:-3]
+    no_batch_dims = len(query.shape[:-3])
+
+    # q, k, and v are assumed to have no singleton dimensions
+    flat_q = query.reshape(-1, *query.shape[-3:])
+    flat_k = key.reshape(-1, *key.shape[-3:])
+    flat_v = value.reshape(-1, *value.shape[-3:])
+
+    o_chunks = []
+    checkpoint_fn = get_checkpoint_fn()
+    count = flat_q.shape[0]
+    for start in range(0, count, chunk_size):
+        end = start + chunk_size
+        q_chunk = flat_q[start:end, ...]
+        k_chunk = flat_k[start:end, ...]
+        v_chunk = flat_v[start:end, ...]
+        bias_1_chunk = _chunk_slice(bias_1, start, end, no_batch_dims)
+        bias_2_chunk = _chunk_slice(bias_2, start, end, no_batch_dims)
+
+        o_chunk = checkpoint_fn(
+            _checkpointable_attention,
+            q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
+        )
+
+        o_chunks.append(o_chunk)
+
+    o_flat = torch.cat(o_chunks, dim=0)
+
+    return o_flat.reshape(batch_dims + o_flat.shape[1:])
+
+
 class Attention(nn.Module):
     """
     Standard multi-head attention using AlphaFold's default layer

@@ -225,7 +288,6 @@ class Attention(nn.Module):
         )

         self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)

     def forward(
         self,

@@ -233,6 +295,10 @@ class Attention(nn.Module):
         k_x: torch.Tensor,
         v_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
+        use_lma: bool = False,
+        q_chunk_size: Optional[int] = None,
+        kv_chunk_size: Optional[int] = None,
+        _chunk_and_checkpoint: Optional[int] = None,
     ) -> torch.Tensor:
         """
         Args:

@@ -245,6 +311,18 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
+        if biases is None:
+            biases = []
+
+        if use_lma and (q_chunk_size is None or kv_chunk_size is None):
+            raise ValueError(
+                "If use_lma is specified, q_chunk_size and kv_chunk_size must "
+                "be provided"
+            )
+
+        if use_lma and _chunk_and_checkpoint is not None:
+            raise ValueError(
+                "use_lma and _chunk_and_checkpoint are mutually exclusive"
+            )
+
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(k_x)

@@ -255,34 +333,33 @@ class Attention(nn.Module):
         k = k.view(k.shape[:-1] + (self.no_heads, -1))
         v = v.view(v.shape[:-1] + (self.no_heads, -1))

-        # [*, H, Q, C_hidden]
-        q = permute_final_dims(q, (1, 0, 2))
-
-        # [*, H, C_hidden, K]
-        k = permute_final_dims(k, (1, 2, 0))
-
-        # [*, H, Q, K]
-        a = torch.matmul(q, k)
-        del q, k
-
-        norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a *= norm
-
-        if biases is not None:
-            for b in biases:
-                a += b
-
-        a = self.softmax(a)
-
-        # [*, H, V, C_hidden]
-        v = permute_final_dims(v, (1, 0, 2))
-
-        # [*, H, Q, C_hidden]
-        o = torch.matmul(a, v)
-
-        # [*, Q, H, C_hidden]
-        o = o.transpose(-2, -3)
+        q = q / math.sqrt(self.c_hidden)
+
+        if use_lma:
+            biases = [
+                b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
+                for b in biases
+            ]
+            o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
+        else:
+            # [*, H, Q, C_hidden]
+            q = permute_final_dims(q, (1, 0, 2))
+
+            # [*, H, C_hidden, K]
+            k = permute_final_dims(k, (1, 2, 0))
+
+            # [*, H, V, C_hidden]
+            v = permute_final_dims(v, (1, 0, 2))
+
+            if _chunk_and_checkpoint:
+                # REMEMBER THAT THE K, Q, V COMPUTATION AND GATING ARE *NOT*
+                # CHECKPOINTED HERE
+                o = _attention_chunk_and_checkpoint(
+                    q, k, v, biases, _chunk_and_checkpoint
+                )
+            else:
+                o = _attention(q, k, v, biases)

         if self.linear_g is not None:
             g = self.sigmoid(self.linear_g(q_x))

         # [*, Q, H, C_hidden]

@@ -374,14 +451,13 @@ class GlobalAttention(nn.Module):
         return m


-@torch.jit.script
 def _lma(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     biases: List[torch.Tensor],
     q_chunk_size: int,
-    kv_chunk_size: int
+    kv_chunk_size: int,
 ):
     no_q, no_kv = q.shape[-3], k.shape[-3]

@@ -389,7 +465,7 @@ def _lma(
     o = q.new_zeros(q.shape)
     for q_s in range(0, no_q, q_chunk_size):
         q_chunk = q[..., q_s: q_s + q_chunk_size, :, :]
-        big_bias_chunks = [
+        large_bias_chunks = [
             b[..., q_s: q_s + q_chunk_size, :] for b in biases
         ]

@@ -400,11 +476,11 @@ def _lma(
             k_chunk = k[..., kv_s: kv_s + kv_chunk_size, :, :]
             v_chunk = v[..., kv_s: kv_s + kv_chunk_size, :, :]
             small_bias_chunks = [
-                b[..., kv_s: kv_s + kv_chunk_size] for b in big_bias_chunks
+                b[..., kv_s: kv_s + kv_chunk_size] for b in large_bias_chunks
             ]

             a = torch.einsum(
-                "...qhd,...khd->...hqk", q_chunk, k_chunk
+                "...qhd,...khd->...hqk", query, key
             )

             for b in small_bias_chunks:

@@ -412,11 +488,11 @@ def _lma(
             a = a.transpose(-2, -3)

-            max_a = torch.max(a, dim=-1, keepdim=True)[0].detach()
+            max_a = torch.max(a, dim=-1, keepdim=True)[0]
             exp_a = torch.exp(a - max_a)
-            exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
+            exp_v = torch.einsum("...vhf,...qhv->...qhf", value, exp_a)

-            maxes.append(max_a.squeeze(-1))
+            maxes.append(max_a.detach().squeeze(-1))
             weights.append(torch.sum(exp_a, dim=-1))
             values.append(exp_v)

@@ -437,111 +513,3 @@ def _lma(
         o[..., q_s: q_s + q_chunk_size, :, :] = q_chunk_out

     return o
-
-
-class LowMemoryAttention(nn.Module):
-    """
-    Standard multi-head attention using AlphaFold's default layer
-    initialization. Allows multiple bias vectors. Implements Rabe and Staats'
-    low-memory self-attention algorithm.
-    """
-    def __init__(
-        self,
-        c_q: int,
-        c_k: int,
-        c_v: int,
-        c_hidden: int,
-        no_heads: int,
-        gating: bool = True,
-    ):
-        """
-        Args:
-            c_q:
-                Input dimension of query data
-            c_k:
-                Input dimension of key data
-            c_v:
-                Input dimension of value data
-            c_hidden:
-                Per-head hidden dimension
-            no_heads:
-                Number of attention heads
-            gating:
-                Whether the output should be gated using query data
-            chunk_size:
-                Trades memory for better parallelization. A low value
-                corresponds to lower memory usage.
-        """
-        super().__init__()
-
-        self.c_q = c_q
-        self.c_k = c_k
-        self.c_v = c_v
-        self.c_hidden = c_hidden
-        self.no_heads = no_heads
-        self.gating = gating
-
-        self.linear_q = Linear(
-            self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_k = Linear(
-            self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_v = Linear(
-            self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-        self.linear_o = Linear(
-            self.c_hidden * self.no_heads, self.c_q, init="final"
-        )
-
-        if self.gating:
-            self.linear_g = Linear(
-                self.c_q, self.c_hidden * self.no_heads, init="gating"
-            )
-
-        self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(
-        self,
-        q_x: torch.Tensor,
-        k_x: torch.Tensor,
-        v_x: torch.Tensor,
-        q_chunk_size: int,
-        kv_chunk_size: int,
-        biases: Optional[List[torch.Tensor]] = None,
-    ):
-        if biases is None:
-            biases = []
-        else:
-            biases = [
-                b.expand(
-                    b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],)
-                )
-                for b in biases
-            ]
-
-        # [*, Q/K/V, H * C_hidden]
-        q = self.linear_q(q_x)
-        k = self.linear_k(k_x)
-        v = self.linear_v(v_x)
-
-        # [*, Q/K, H, C_hidden]
-        q = q.view(q.shape[:-1] + (self.no_heads, -1))
-        k = k.view(k.shape[:-1] + (self.no_heads, -1))
-        v = v.view(v.shape[:-1] + (self.no_heads, -1))
-
-        q = q / math.sqrt(q.shape[-1])
-
-        o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
-
-        if self.gating:
-            g = self.sigmoid(self.linear_g(q_x))
-            # [*, Q, H, C_hidden]
-            g = g.view(g.shape[:-1] + (self.no_heads, -1))
-            o = o * g
-
-        # [*, Q, H * C_hidden]
-        o = flatten_final_dims(o, 2)
-
-        # [*, Q, C_q]
-        o = self.linear_o(o)
-
-        return o
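Note: _attention_chunk_and_checkpoint flattens the batch dimensions, slices the flattened batch into chunks of chunk_size, and runs each chunk through the plain _attention kernel under an activation checkpoint; as the capitalized comment in Attention.forward warns, the q/k/v projections and gating stay outside the checkpoint. A self-contained sketch of the idea, with a toy attention function standing in for _attention and none of the OpenFold helpers:

    # Sketch: chunk the flattened batch and checkpoint each chunk. The chunked
    # result matches the unchunked one; only peak activation memory differs.
    import torch
    from torch.utils.checkpoint import checkpoint


    def toy_attention(q, k, v, bias):
        # q: [B, H, Q, C], k: [B, H, C, K], v: [B, H, K, C], bias: [B, H, Q, K]
        a = torch.softmax(torch.matmul(q, k) + bias, dim=-1)
        return torch.matmul(a, v)


    def chunked_checkpointed_attention(q, k, v, bias, chunk_size):
        out = []
        for s in range(0, q.shape[0], chunk_size):
            e = s + chunk_size
            out.append(
                checkpoint(toy_attention, q[s:e], k[s:e], v[s:e], bias[s:e])
            )
        return torch.cat(out, dim=0)


    B, H, Q, K, C = 8, 4, 16, 16, 32
    q = torch.randn(B, H, Q, C, requires_grad=True)
    k = torch.randn(B, H, C, K, requires_grad=True)
    v = torch.randn(B, H, K, C, requires_grad=True)
    bias = torch.randn(B, H, Q, K)

    full = toy_attention(q, k, v, bias)
    chunked = chunked_checkpointed_attention(q, k, v, bias, chunk_size=2)
    assert torch.allclose(full, chunked, atol=1e-6)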
openfold/model/triangular_attention.py  (view file @ df89bb28)

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from functools import partialmethod
+from functools import partialmethod, partial
 import math
 from typing import Optional, List

@@ -70,7 +70,7 @@ class TriangleAttention(nn.Module):
             "biases": biases,
         }
         return chunk_layer(
-            self.mha,
+            partial(self.mha),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(x.shape[:-2]),
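Note: the only functional change here is handing chunk_layer a functools.partial wrapper around the bound method instead of the bound method itself; the commit does not say why. For reference, a zero-argument partial is just an ordinary callable that forwards to the original:

    # partial around a callable module/bound method behaves identically when called.
    from functools import partial

    class MHA:
        def __call__(self, q_x, k_x, v_x, biases):
            return f"attn({q_x}, {k_x}, {v_x}, biases={biases})"

    mha = MHA()
    wrapped = partial(mha)  # what chunk_layer now receives
    assert wrapped("q", "k", "v", biases=[]) == mha("q", "k", "v", biases=[])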
openfold/utils/checkpointing.py  (view file @ df89bb28)

@@ -15,17 +15,27 @@
 import deepspeed
 import torch
 import torch.utils.checkpoint
-from typing import Any, Tuple, List, Callable
+from typing import Any, Tuple, List, Callable, Optional


 BLOCK_ARG = Any
 BLOCK_ARGS = List[BLOCK_ARG]


+def get_checkpoint_fn():
+    if deepspeed.checkpointing.is_configured():
+        checkpoint = deepspeed.checkpointing.checkpoint
+    else:
+        checkpoint = torch.utils.checkpoint.checkpoint
+
+    return checkpoint
+
+
 @torch.jit.ignore
 def checkpoint_blocks(
     blocks: List[Callable],
     args: BLOCK_ARGS,
-    blocks_per_ckpt: int,
+    blocks_per_ckpt: Optional[int],
 ) -> BLOCK_ARGS:
     """
     Chunk a list of blocks and run each chunk with activation

@@ -68,10 +78,7 @@ def checkpoint_blocks(
     elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
         raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")

-    if deepspeed.checkpointing.is_configured():
-        checkpoint = deepspeed.checkpointing.checkpoint
-    else:
-        checkpoint = torch.utils.checkpoint.checkpoint
+    checkpoint = get_checkpoint_fn()

     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
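Note: get_checkpoint_fn centralizes the choice between DeepSpeed's and PyTorch's checkpoint implementations, and checkpoint_blocks groups the block list into runs of blocks_per_ckpt, checkpointing each run. A minimal sketch of the grouping idea using only torch.utils.checkpoint (the real helper also threads tuples of arguments through the blocks and handles blocks_per_ckpt=None, i.e. no checkpointing at all):

    # Sketch of the blocks_per_ckpt idea: run blocks in groups, checkpoint each group.
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint


    def run_blocks_with_ckpt(blocks, x, blocks_per_ckpt):
        def chunker(group):
            def run(x):
                for b in group:
                    x = b(x)
                return x
            return run

        for s in range(0, len(blocks), blocks_per_ckpt):
            group = blocks[s:s + blocks_per_ckpt]
            x = checkpoint(chunker(group), x)
        return x


    blocks = nn.ModuleList(nn.Linear(16, 16) for _ in range(6))
    x = torch.randn(4, 16, requires_grad=True)
    out = run_blocks_with_ckpt(list(blocks), x, blocks_per_ckpt=2)
    out.sum().backward()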