Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
fe9ad07e
Commit
fe9ad07e
authored
Jun 21, 2022
by
Gustaf Ahdritz
Browse files
Add offloading to evoformer
parent
b40fab25
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
238 additions
and
129 deletions
+238
-129
openfold/model/evoformer.py
openfold/model/evoformer.py
+238
-129
No files found.
openfold/model/evoformer.py
View file @
fe9ad07e
...
...
@@ -16,7 +16,7 @@
import
math
import
torch
import
torch.nn
as
nn
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Sequence
,
Optional
from
functools
import
partial
from
openfold.model.primitives
import
Linear
,
LayerNorm
...
...
@@ -29,6 +29,7 @@ from openfold.model.msa import (
from
openfold.model.outer_product_mean
import
OuterProductMean
from
openfold.model.pair_transition
import
PairTransition
from
openfold.model.triangular_attention
import
(
TriangleAttention
,
TriangleAttentionStartingNode
,
TriangleAttentionEndingNode
,
)
...
...
@@ -37,7 +38,8 @@ from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationIncoming
,
)
from
openfold.utils.checkpointing
import
checkpoint_blocks
,
get_checkpoint_fn
from
openfold.utils.tensor_utils
import
add
,
chunk_layer
,
ChunkSizeTuner
from
openfold.utils.chunk_utils
import
chunk_layer
,
ChunkSizeTuner
from
openfold.utils.tensor_utils
import
add
class
MSATransition
(
nn
.
Module
):
...
...
@@ -66,6 +68,7 @@ class MSATransition(nn.Module):
self
.
linear_2
=
Linear
(
self
.
n
*
self
.
c_m
,
self
.
c_m
,
init
=
"final"
)
def
_transition
(
self
,
m
,
mask
):
m
=
self
.
layer_norm
(
m
)
m
=
self
.
linear_1
(
m
)
m
=
self
.
relu
(
m
)
m
=
self
.
linear_2
(
m
)
*
mask
...
...
@@ -107,8 +110,6 @@ class MSATransition(nn.Module):
mask
=
mask
.
unsqueeze
(
-
1
)
m
=
self
.
layer_norm
(
m
)
if
chunk_size
is
not
None
:
m
=
self
.
_chunk
(
m
,
mask
,
chunk_size
)
else
:
...
...
@@ -155,13 +156,13 @@ class EvoformerBlockCore(nn.Module):
c_hidden_mul
,
)
self
.
tri_att_start
=
TriangleAttention
StartingNode
(
self
.
tri_att_start
=
TriangleAttention
(
c_z
,
c_hidden_pair_att
,
no_heads_pair
,
inf
=
inf
,
)
self
.
tri_att_end
=
TriangleAttention
EndingNode
(
self
.
tri_att_end
=
TriangleAttention
(
c_z
,
c_hidden_pair_att
,
no_heads_pair
,
...
...
@@ -174,18 +175,16 @@ class EvoformerBlockCore(nn.Module):
)
self
.
ps_dropout_row_layer
=
DropoutRowwise
(
pair_dropout
)
self
.
ps_dropout_col_layer
=
DropoutColumnwise
(
pair_dropout
)
def
forward
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
def
forward
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
...
...
@@ -196,6 +195,8 @@ class EvoformerBlockCore(nn.Module):
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
m
,
z
=
input_tensors
# Need to dodge activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
...
...
@@ -205,13 +206,26 @@ class EvoformerBlockCore(nn.Module):
m
,
mask
=
msa_trans_mask
,
chunk_size
=
chunk_size
,
),
inplace
=
inplace_safe
,
)
if
(
_offload_inference
and
inplace_safe
):
del
m
,
z
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
torch
.
cuda
.
empty_cache
()
m
,
z
=
input_tensors
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
_inplace
=
inplace_safe
)
z
=
add
(
z
,
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
_inplace
=
inplace_safe
),
inplace
=
inplace_safe
,
)
if
(
_offload_inference
and
inplace_safe
):
del
m
,
z
input_tensors
[
0
]
=
input_tensors
[
0
].
cpu
()
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
opm
.
device
)
m
,
z
=
input_tensors
z
=
add
(
z
,
opm
,
inplace
=
inplace_safe
)
del
opm
tmu_update
=
self
.
tri_mul_out
(
z
,
...
...
@@ -250,17 +264,30 @@ class EvoformerBlockCore(nn.Module):
),
inplace
=
inplace_safe
,
)
z
=
add
(
z
,
self
.
ps_dropout_col_layer
(
z
=
z
.
transpose
(
-
2
,
-
3
)
if
(
inplace_safe
):
input_tensors
[
1
]
=
z
.
contiguous
()
z
=
input_tensors
[
1
]
z
=
add
(
z
,
self
.
ps_dropout_row_layer
(
self
.
tri_att_end
(
z
,
mask
=
pair_mask
,
z
,
mask
=
pair_mask
.
transpose
(
-
1
,
-
2
)
,
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
,
)
),
inplace
=
inplace_safe
,
)
z
=
z
.
transpose
(
-
2
,
-
3
)
if
(
inplace_safe
):
input_tensors
[
1
]
=
z
.
contiguous
()
z
=
input_tensors
[
1
]
z
=
add
(
z
,
self
.
pair_transition
(
z
,
mask
=
pair_trans_mask
,
chunk_size
=
chunk_size
,
...
...
@@ -268,6 +295,13 @@ class EvoformerBlockCore(nn.Module):
inplace
=
inplace_safe
,
)
if
(
_offload_inference
and
inplace_safe
):
device
=
z
.
device
del
m
,
z
input_tensors
[
0
]
=
input_tensors
[
0
].
to
(
device
)
input_tensors
[
1
]
=
input_tensors
[
1
].
to
(
device
)
m
,
z
=
input_tensors
return
m
,
z
...
...
@@ -321,23 +355,22 @@ class EvoformerBlock(nn.Module):
)
def
forward
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
print
(
chunk_size
)
print
(
_attn_chunk_size
)
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
m
,
z
=
input_tensors
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
...
...
@@ -359,18 +392,29 @@ class EvoformerBlock(nn.Module):
),
inplace
=
inplace_safe
,
)
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
input_tensors
[
1
]]
del
m
,
z
m
,
z
=
self
.
core
(
m
,
z
,
input_tensors
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
)
return
m
,
z
if
(
inplace_safe
):
out
=
input_tensors
else
:
out
=
[
m
,
z
]
return
out
class
ExtraMSABlock
(
nn
.
Module
):
...
...
@@ -433,19 +477,21 @@ class ExtraMSABlock(nn.Module):
)
def
forward
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
_chunk_logits
:
Optional
[
int
]
=
1024
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
_offload_inference
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
m
,
z
=
input_tensors
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# If function calls could speak...
m
=
add
(
m
,
self
.
msa_dropout_layer
(
...
...
@@ -455,44 +501,50 @@ class ExtraMSABlock(nn.Module):
mask
=
msa_mask
,
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
,
use_memory_efficient_kernel
=
not
_chunk_logits
and
not
use_lma
,
_chunk_logits
=
_chunk_logits
if
torch
.
is_grad_enabled
()
else
None
,
use_memory_efficient_kernel
=
not
use_lma
,
_checkpoint_chunks
=
self
.
ckpt
if
torch
.
is_grad_enabled
()
else
False
,
)
),
inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
,
inplace
=
inplace_safe
,
)
del
m
,
z
def
fn
(
m
,
z
):
m
=
add
(
m
,
def
fn
(
input_tensors
):
m
=
add
(
input_tensors
[
0
]
,
self
.
msa_att_col
(
m
,
input_tensors
[
0
]
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
),
inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
,
inplace
=
inplace_safe
,
)
if
(
not
inplace_safe
):
input_tensors
[
m
,
input_tensors
[
1
]]
del
m
m
,
z
=
self
.
core
(
m
,
z
,
input_tensors
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
)
return
m
,
z
if
(
torch
.
is_grad_enabled
()
and
self
.
ckpt
):
checkpoint_fn
=
get_checkpoint_fn
()
m
,
z
=
checkpoint_fn
(
fn
,
m
,
z
)
m
,
z
=
checkpoint_fn
(
fn
,
input_tensors
)
else
:
m
,
z
=
fn
(
m
,
z
)
m
,
z
=
fn
(
input_tensors
)
return
m
,
z
...
...
@@ -595,37 +647,15 @@ class EvoformerStack(nn.Module):
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
forward
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
def
_forward_list
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
_offload_inference
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
blocks
=
[
partial
(
b
,
...
...
@@ -634,6 +664,7 @@ class EvoformerStack(nn.Module):
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
_mask_trans
=
_mask_trans
,
_offload_inference
=
_offload_inference
,
)
for
b
in
self
.
blocks
]
...
...
@@ -646,9 +677,11 @@ class EvoformerStack(nn.Module):
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
print
(
"evo"
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
m
,
z
),
# We don't want to write in-place during chunk tuning runs
args
=
([
t
.
clone
()
for
t
in
input_tensors
],),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
...
...
@@ -666,14 +699,54 @@ class EvoformerStack(nn.Module):
m
,
z
=
checkpoint_blocks
(
blocks
,
args
=
(
m
,
z
)
,
args
=
input_tensors
,
blocks_per_ckpt
=
blocks_per_ckpt
,
)
)
[
0
]
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
return
m
,
z
,
s
def
forward
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
chunk_size:
Inference-time subbatch size. Acts as a minimum if
self.tune_chunk_size is True
use_lma: Whether to use low-memory attention during inference
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
return
self
.
_forward_list
(
[
m
,
z
],
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
_mask_trans
=
_mask_trans
,
)
class
ExtraMSAStack
(
nn
.
Module
):
"""
...
...
@@ -730,6 +803,81 @@ class ExtraMSAStack(nn.Module):
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
_prep_blocks
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
_mask_trans
:
bool
,
):
blocks
=
[
partial
(
b
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
def
clear_cache
(
b
,
*
args
,
**
kwargs
):
torch
.
cuda
.
empty_cache
()
return
b
(
*
args
,
**
kwargs
)
if
(
self
.
clear_cache_between_blocks
):
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
print
(
"extra"
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
([
m
.
clone
(),
z
.
clone
()],),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
partial
(
b
,
chunk_size
=
tuned_chunk_size
,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
4
),
)
for
b
in
blocks
]
return
blocks
def
_forward_list
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
chunk_size
:
int
,
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_mask_trans
:
bool
=
True
,
_offload_inference
:
bool
=
False
,
)
->
torch
.
Tensor
:
assert
(
not
self
.
training
)
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# this function
m
=
input_tensors
[
0
],
z
=
input_tensors
[
1
],
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
m
,
z
=
b
(
input_tensors
,
_offload_inference
=
_offload_inference
)
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
del
m
,
z
return
input_tensors
[
1
]
def
forward
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
...
...
@@ -754,60 +902,21 @@ class ExtraMSAStack(nn.Module):
Returns:
[*, N_res, N_res, C_z] pair update
"""
if
(
not
self
.
chunk_msa_attn
):
checkpoint_fn
=
get_checkpoint_fn
()
blocks
=
[
partial
(
b
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
_chunk_logits
=
None
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
def
clear_cache
(
b
,
*
args
,
**
kwargs
):
torch
.
cuda
.
empty_cache
()
return
b
(
*
args
,
**
kwargs
)
if
(
self
.
clear_cache_between_blocks
):
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
m
,
z
),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
partial
(
b
,
chunk_size
=
tuned_chunk_size
,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size
=
max
(
chunk_size
,
tuned_chunk_size
//
4
),
)
for
b
in
blocks
]
for
b
in
blocks
:
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
m
,
z
=
checkpoint_fn
(
b
,
*
(
m
,
z
))
else
:
m
,
z
=
b
(
m
,
z
)
else
:
for
b
in
self
.
blocks
:
m
,
z
=
b
(
m
,
z
,
msa_mask
,
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
_mask_trans
=
_mask_trans
)
checkpoint_fn
=
get_checkpoint_fn
()
blocks
=
self
.
_prep_blocks
(
m
=
m
,
z
=
z
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
_mask_trans
=
_mask_trans
,
)
if
(
self
.
clear_cache_between_blocks
):
torch
.
cuda
.
empty_cache
()
for
b
in
blocks
:
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
m
,
z
=
checkpoint_fn
(
b
,
(
m
,
z
))
else
:
m
,
z
=
b
(
m
,
z
)
return
z
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment