Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
143ba486
Commit
143ba486
authored
Jun 21, 2022
by
Gustaf Ahdritz
Browse files
Refactor inplace operations, fix training
parent
f1402490
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
249 additions
and
137 deletions
+249
-137
openfold/model/embedders.py
openfold/model/embedders.py
+5
-6
openfold/model/evoformer.py
openfold/model/evoformer.py
+124
-56
openfold/model/model.py
openfold/model/model.py
+76
-42
openfold/model/msa.py
openfold/model/msa.py
+14
-7
openfold/model/outer_product_mean.py
openfold/model/outer_product_mean.py
+2
-2
openfold/model/structure_module.py
openfold/model/structure_module.py
+4
-2
openfold/model/template.py
openfold/model/template.py
+18
-15
openfold/model/triangular_attention.py
openfold/model/triangular_attention.py
+4
-3
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+2
-2
openfold/utils/chunk_utils.py
openfold/utils/chunk_utils.py
+0
-2
No files found.
openfold/model/embedders.py
View file @
143ba486
...
...
@@ -95,6 +95,7 @@ class InputEmbedder(nn.Module):
tf
:
torch
.
Tensor
,
ri
:
torch
.
Tensor
,
msa
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
...
...
@@ -111,8 +112,6 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding
"""
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# [*, N_res, c_z]
tf_emb_i
=
self
.
linear_tf_z_i
(
tf
)
tf_emb_j
=
self
.
linear_tf_z_j
(
tf
)
...
...
@@ -187,7 +186,7 @@ class RecyclingEmbedder(nn.Module):
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
...
...
@@ -205,13 +204,13 @@ class RecyclingEmbedder(nn.Module):
"""
# [*, N, C_m]
m_update
=
self
.
layer_norm_m
(
m
)
if
(
_
inplace
):
if
(
inplace
_safe
):
m
.
copy_
(
m_update
)
m_update
=
m
# [*, N, N, C_z]
z_update
=
self
.
layer_norm_z
(
z
)
if
(
_
inplace
):
if
(
inplace
_safe
):
z
.
copy_
(
z_update
)
z_update
=
z
...
...
@@ -237,7 +236,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z]
d
=
self
.
linear
(
d
)
z_update
=
add
(
z_update
,
d
,
_
inplace
)
z_update
=
add
(
z_update
,
d
,
inplace
_safe
)
return
m_update
,
z_update
...
...
openfold/model/evoformer.py
View file @
143ba486
...
...
@@ -182,6 +182,7 @@ class EvoformerBlockCore(nn.Module):
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
...
...
@@ -197,9 +198,6 @@ class EvoformerBlockCore(nn.Module):
m
,
z
=
input_tensors
# Need to dodge activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
m
=
add
(
m
,
self
.
msa_transition
(
...
...
@@ -215,7 +213,7 @@ class EvoformerBlockCore(nn.Module):
m
,
z
=
input_tensors
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
_
inplace
=
inplace_safe
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace
_safe
=
inplace_safe
)
if
(
_offload_inference
and
inplace_safe
):
...
...
@@ -230,7 +228,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update
=
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
):
...
...
@@ -243,7 +241,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update
=
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
):
...
...
@@ -259,7 +257,8 @@ class EvoformerBlockCore(nn.Module):
z
,
mask
=
pair_mask
,
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
inplace
=
inplace_safe
,
...
...
@@ -277,6 +276,7 @@ class EvoformerBlockCore(nn.Module):
mask
=
pair_mask
.
transpose
(
-
1
,
-
2
),
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
inplace
=
inplace_safe
,
...
...
@@ -355,20 +355,27 @@ class EvoformerBlock(nn.Module):
)
def
forward
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
m
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
if
(
_offload_inference
and
inplace_safe
):
input_tensors
=
_offloadable_inputs
del
_offloadable_inputs
else
:
input_tensors
=
[
m
,
z
]
m
,
z
=
input_tensors
m
=
add
(
m
,
...
...
@@ -404,17 +411,13 @@ class EvoformerBlock(nn.Module):
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
)
if
(
inplace_safe
):
out
=
input_tensors
else
:
out
=
[
m
,
z
]
return
out
return
m
,
z
class
ExtraMSABlock
(
nn
.
Module
):
...
...
@@ -477,22 +480,29 @@ class ExtraMSABlock(nn.Module):
)
def
forward
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
m
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
if
(
_offload_inference
and
inplace_safe
):
input_tensors
=
_offloadable_inputs
del
_offloadable_inputs
else
:
input_tensors
=
[
m
,
z
]
m
,
z
=
input_tensors
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# If function calls could speak...
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
...
...
@@ -509,6 +519,9 @@ class ExtraMSABlock(nn.Module):
inplace
=
inplace_safe
,
)
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
def
fn
(
input_tensors
):
...
...
@@ -523,7 +536,7 @@ class ExtraMSABlock(nn.Module):
)
if
(
not
inplace_safe
):
input_tensors
[
m
,
input_tensors
[
1
]]
input_tensors
=
[
m
,
input_tensors
[
1
]]
del
m
...
...
@@ -533,6 +546,7 @@ class ExtraMSABlock(nn.Module):
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
...
...
@@ -647,15 +661,16 @@ class EvoformerStack(nn.Module):
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
_forward_list
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
def
_prep_blocks
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_offload_inference
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
use_lma
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
,
_mask_trans
:
bool
,
):
blocks
=
[
partial
(
b
,
...
...
@@ -663,8 +678,8 @@ class EvoformerStack(nn.Module):
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_offload_inference
=
_offload_inference
,
)
for
b
in
self
.
blocks
]
...
...
@@ -677,11 +692,11 @@ class EvoformerStack(nn.Module):
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
print
(
"evo"
)
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
# We don't want to write in-place during chunk tuning runs
args
=
(
[
t
.
clone
()
for
t
in
input_tensors
]
,),
args
=
(
m
.
clone
()
,
z
.
clone
()
,),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
...
...
@@ -693,15 +708,42 @@ class EvoformerStack(nn.Module):
)
for
b
in
blocks
]
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
(
not
torch
.
is_grad_enabled
()):
blocks_per_ckpt
=
None
return
blocks
m
,
z
=
checkpoint_blocks
(
blocks
,
args
=
input_tensors
,
blocks_per_ckpt
=
blocks_per_ckpt
,
)[
0
]
def
_forward_offload
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
(
not
(
self
.
training
or
torch
.
is_grad_enabled
()))
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# this function
m
=
input_tensors
[
0
],
z
=
input_tensors
[
1
],
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
True
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
m
,
z
=
b
(
None
,
None
,
_offload_inference
=
True
,
_offloadable_inputs
=
input_tensors
,
)
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
del
m
,
z
m
,
z
=
input_tensors
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
...
...
@@ -714,6 +756,7 @@ class EvoformerStack(nn.Module):
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
...
...
@@ -738,15 +781,31 @@ class EvoformerStack(nn.Module):
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
return
self
.
_forward_list
(
[
m
,
z
],
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
blocks
=
self
.
_prep_blocks
(
m
=
m
,
z
=
z
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
(
not
torch
.
is_grad_enabled
()):
blocks_per_ckpt
=
None
m
,
z
=
checkpoint_blocks
(
blocks
,
args
=
(
m
,
z
),
blocks_per_ckpt
=
blocks_per_ckpt
,
)
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
return
m
,
z
,
s
class
ExtraMSAStack
(
nn
.
Module
):
"""
...
...
@@ -769,7 +828,6 @@ class ExtraMSAStack(nn.Module):
eps
:
float
,
ckpt
:
bool
,
clear_cache_between_blocks
:
bool
=
False
,
chunk_msa_attn
:
bool
=
False
,
tune_chunk_size
:
bool
=
False
,
**
kwargs
,
):
...
...
@@ -777,7 +835,6 @@ class ExtraMSAStack(nn.Module):
self
.
ckpt
=
ckpt
self
.
clear_cache_between_blocks
=
clear_cache_between_blocks
self
.
chunk_msa_attn
=
chunk_msa_attn
self
.
blocks
=
nn
.
ModuleList
()
for
_
in
range
(
no_blocks
):
block
=
ExtraMSABlock
(
...
...
@@ -794,7 +851,7 @@ class ExtraMSAStack(nn.Module):
pair_dropout
=
pair_dropout
,
inf
=
inf
,
eps
=
eps
,
ckpt
=
ckpt
if
chunk_msa_attn
else
False
,
ckpt
=
False
,
)
self
.
blocks
.
append
(
block
)
...
...
@@ -810,6 +867,7 @@ class ExtraMSAStack(nn.Module):
use_lma
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
,
_mask_trans
:
bool
,
):
blocks
=
[
...
...
@@ -819,6 +877,7 @@ class ExtraMSAStack(nn.Module):
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
...
...
@@ -831,10 +890,12 @@ class ExtraMSAStack(nn.Module):
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
print
(
"extra"
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
([
m
.
clone
(),
z
.
clone
()],),
# Tensors cloned to avoid getting written to in-place
# A corollary is that chunk size tuning should be disabled for
# large N, when z gets really big
args
=
(
m
.
clone
(),
z
.
clone
(),),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
...
...
@@ -848,16 +909,15 @@ class ExtraMSAStack(nn.Module):
return
blocks
def
_forward_
list
(
self
,
def
_forward_
offload
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
chunk_size
:
int
,
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_mask_trans
:
bool
=
True
,
_offload_inference
:
bool
=
False
,
)
->
torch
.
Tensor
:
assert
(
not
self
.
training
)
assert
(
not
(
self
.
training
or
torch
.
is_grad_enabled
())
)
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# this function
...
...
@@ -867,11 +927,17 @@ class ExtraMSAStack(nn.Module):
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
True
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
m
,
z
=
b
(
input_tensors
,
_offload_inference
=
_offload_inference
)
m
,
z
=
b
(
None
,
None
,
_offload_inference
=
True
,
_offloadable_inputs
=
input_tensors
,
)
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
del
m
,
z
...
...
@@ -885,6 +951,7 @@ class ExtraMSAStack(nn.Module):
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""
...
...
@@ -910,12 +977,13 @@ class ExtraMSAStack(nn.Module):
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
m
,
z
=
checkpoint_fn
(
b
,
(
m
,
z
)
)
m
,
z
=
checkpoint_fn
(
b
,
m
,
z
)
else
:
m
,
z
=
b
(
m
,
z
)
...
...
openfold/model/model.py
View file @
143ba486
...
...
@@ -107,19 +107,16 @@ class AlphaFold(nn.Module):
self
.
config
[
"heads"
],
)
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
):
if
(
self
.
template_config
.
offload_templates
):
return
embed_templates_offload
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
return
embed_templates_offload
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
elif
(
self
.
template_config
.
average_templates
):
return
embed_templates_average
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
return
embed_templates_average
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds
=
[]
n
=
z
.
shape
[
-
2
]
...
...
@@ -168,6 +165,7 @@ class AlphaFold(nn.Module):
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
t_pair
...
...
@@ -186,6 +184,11 @@ class AlphaFold(nn.Module):
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
ret
.
update
({
"template_pair_embedding"
:
t
})
del
t
if
self
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
...
...
@@ -196,10 +199,6 @@ class AlphaFold(nn.Module):
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
del
t
return
ret
def
iteration
(
self
,
feats
,
prevs
,
_recycle
=
True
):
...
...
@@ -218,6 +217,9 @@ class AlphaFold(nn.Module):
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
device
=
feats
[
"target_feat"
].
device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Prep some features
...
...
@@ -233,10 +235,11 @@ class AlphaFold(nn.Module):
feats
[
"target_feat"
],
feats
[
"residue_index"
],
feats
[
"msa_feat"
],
inplace_safe
=
inplace_safe
,
)
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function
.
# them to be freed further down in this function
, saving memory
m_1_prev
,
z_prev
,
x_prev
=
reversed
([
prevs
.
pop
()
for
_
in
range
(
3
)])
# Initialize the recycling embeddings, if needs be
...
...
@@ -263,6 +266,7 @@ class AlphaFold(nn.Module):
feats
[
"aatype"
],
x_prev
,
None
).
to
(
dtype
=
z
.
dtype
)
# The recycling embedder is memory-intensive, so we offload first
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
m
=
m
.
cpu
()
z
=
z
.
cpu
()
...
...
@@ -273,7 +277,7 @@ class AlphaFold(nn.Module):
m_1_prev
,
z_prev
,
x_prev
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
)
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
...
...
@@ -286,6 +290,9 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z
=
add
(
z
,
z_prev_emb
,
inplace
=
inplace_safe
)
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
...
...
@@ -298,6 +305,7 @@ class AlphaFold(nn.Module):
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
no_batch_dims
,
inplace_safe
=
inplace_safe
,
)
# [*, N, N, C_z]
...
...
@@ -306,7 +314,7 @@ class AlphaFold(nn.Module):
inplace_safe
,
)
if
self
.
config
.
template
.
embed
_angle
s
:
if
"template_angle_embedding"
in
template
_
embeds
:
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
[
m
,
template_embeds
[
"template_angle_embedding"
]],
...
...
@@ -325,29 +333,43 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e]
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
if
(
self
.
globals
.
offload_inference
):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors
=
[
a
,
z
]
del
a
,
z
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
.
_forward_
list
(
z
=
self
.
extra_msa_stack
.
_forward_
offload
(
input_tensors
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
_mask_trans
=
self
.
config
.
_mask_trans
,
_offload_inference
=
self
.
globals
.
offload_inference
,
)
del
input_tensors
else
:
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
(
a
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if
(
self
.
globals
.
offload_inference
):
input_tensors
=
[
m
,
z
]
del
m
,
z
m
,
z
,
s
=
self
.
evoformer
.
_forward_
list
(
m
,
z
,
s
=
self
.
evoformer
.
_forward_
offload
(
input_tensors
,
msa_mask
=
msa_mask
.
to
(
dtype
=
input_tensors
[
0
].
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
input_tensors
[
1
].
dtype
),
...
...
@@ -357,6 +379,17 @@ class AlphaFold(nn.Module):
)
del
input_tensors
else
:
m
,
z
,
s
=
self
.
evoformer
(
m
,
z
,
msa_mask
=
msa_mask
.
to
(
dtype
=
m
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
outputs
[
"msa"
]
=
m
[...,
:
n_seq
,
:,
:]
outputs
[
"pair"
]
=
z
...
...
@@ -369,6 +402,7 @@ class AlphaFold(nn.Module):
outputs
,
feats
[
"aatype"
],
mask
=
feats
[
"seq_mask"
].
to
(
dtype
=
s
.
dtype
),
inplace_safe
=
inplace_safe
,
_offload_inference
=
self
.
globals
.
offload_inference
,
)
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
...
...
openfold/model/msa.py
View file @
143ba486
...
...
@@ -117,10 +117,9 @@ class MSAAttention(nn.Module):
def
_prep_inputs
(
self
,
m
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
],
mask
:
Optional
[
torch
.
Tensor
]
mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
_inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
n_seq
,
n_res
=
m
.
shape
[
-
3
:
-
1
]
if
mask
is
None
:
# [*, N_seq, N_res]
...
...
@@ -163,6 +162,7 @@ class MSAAttention(nn.Module):
mask
:
Optional
[
torch
.
Tensor
],
chunk_logits
:
int
,
checkpoint
:
bool
,
inplace_safe
:
bool
=
False
)
->
torch
.
Tensor
:
"""
MSA attention with training-time chunking of the softmax computation.
...
...
@@ -172,7 +172,9 @@ class MSAAttention(nn.Module):
MSA_DIM
=
-
4
def
_get_qkv
(
m
,
z
):
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
,
inplace_safe
=
inplace_safe
)
q
,
k
,
v
=
self
.
mha
.
_prep_qkv
(
m
,
m
)
return
m
,
q
,
k
,
v
,
mask_bias
,
z
...
...
@@ -208,6 +210,7 @@ class MSAAttention(nn.Module):
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_chunk_logits
:
Optional
[
int
]
=
None
,
_checkpoint_chunks
:
Optional
[
bool
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -229,10 +232,14 @@ class MSAAttention(nn.Module):
if
(
_chunk_logits
is
not
None
):
return
self
.
_chunked_msa_attn
(
m
=
m
,
z
=
z
,
mask
=
mask
,
chunk_logits
=
_chunk_logits
,
checkpoint
=
_checkpoint_chunks
chunk_logits
=
_chunk_logits
,
checkpoint
=
_checkpoint_chunks
,
inplace_safe
=
inplace_safe
,
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
,
inplace_safe
=
inplace_safe
)
biases
=
[
mask_bias
]
if
(
z
is
not
None
):
...
...
openfold/model/outer_product_mean.py
View file @
143ba486
...
...
@@ -97,7 +97,7 @@ class OuterProductMean(nn.Module):
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
Args:
...
...
@@ -137,7 +137,7 @@ class OuterProductMean(nn.Module):
norm
=
norm
+
self
.
eps
# [*, N_res, N_res, C_z]
if
(
_
inplace
):
if
(
inplace
_safe
):
outer
/=
norm
else
:
outer
=
outer
/
norm
...
...
openfold/model/structure_module.py
View file @
143ba486
...
...
@@ -232,6 +232,7 @@ class InvariantPointAttention(nn.Module):
z
:
Optional
[
torch
.
Tensor
],
r
:
Rigid
,
mask
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_z_reference_list
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -248,7 +249,6 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
if
(
_offload_inference
and
inplace_safe
):
z
=
_z_reference_list
else
:
...
...
@@ -619,6 +619,7 @@ class StructureModule(nn.Module):
evoformer_output_dict
,
aatype
,
mask
=
None
,
inplace_safe
=
False
,
_offload_inference
=
False
,
):
"""
...
...
@@ -674,6 +675,7 @@ class StructureModule(nn.Module):
z
,
rigids
,
mask
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
,
_z_reference_list
=
z_reference_list
)
...
...
openfold/model/template.py
View file @
143ba486
...
...
@@ -201,8 +201,8 @@ class TemplatePairStackBlock(nn.Module):
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_inplace
:
bool
=
False
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
):
if
(
_attn_chunk_size
is
None
):
...
...
@@ -214,6 +214,7 @@ class TemplatePairStackBlock(nn.Module):
single_templates_masks
=
[
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
]
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
...
...
@@ -225,9 +226,10 @@ class TemplatePairStackBlock(nn.Module):
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
_
inplace
,
inplace
_safe
,
)
single
=
add
(
single
,
...
...
@@ -237,18 +239,19 @@ class TemplatePairStackBlock(nn.Module):
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
_
inplace
,
inplace
_safe
,
)
tmu_update
=
self
.
tri_mul_out
(
single
,
mask
=
single_mask
,
_
inplace
=
_inplace
,
inplace_
safe
=
inplace
_safe
,
_add_with_inplace
=
True
,
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
...
...
@@ -258,10 +261,10 @@ class TemplatePairStackBlock(nn.Module):
tmu_update
=
self
.
tri_mul_in
(
single
,
mask
=
single_mask
,
_
inplace
=
_inplace
,
inplace_
safe
=
inplace
_safe
,
_add_with_inplace
=
True
,
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
...
...
@@ -274,13 +277,13 @@ class TemplatePairStackBlock(nn.Module):
mask
=
single_mask
if
_mask_trans
else
None
,
chunk_size
=
chunk_size
,
),
_
inplace
,
inplace
_safe
,
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single_templates
[
i
]
=
single
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
return
z
...
...
@@ -352,6 +355,7 @@ class TemplatePairStack(nn.Module):
mask
:
torch
.
tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
):
"""
...
...
@@ -374,13 +378,14 @@ class TemplatePairStack(nn.Module):
mask
=
mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
for
b
in
self
.
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
t
.
clone
(),),
...
...
@@ -411,6 +416,7 @@ def embed_templates_offload(
pair_mask
,
templ_dim
,
template_chunk_size
=
256
,
inplace_safe
=
False
,
):
"""
Args:
...
...
@@ -435,8 +441,6 @@ def embed_templates_offload(
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu
=
[]
n
=
z
.
shape
[
-
2
]
...
...
@@ -519,6 +523,7 @@ def embed_templates_average(
pair_mask
,
templ_dim
,
templ_group_size
=
2
,
inplace_safe
=
False
,
):
"""
Args:
...
...
@@ -547,8 +552,6 @@ def embed_templates_average(
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
...
...
openfold/model/triangular_attention.py
View file @
143ba486
...
...
@@ -64,6 +64,7 @@ class TriangleAttention(nn.Module):
chunk_size
:
int
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
"triangle! triangle!"
mha_inputs
=
{
...
...
@@ -72,8 +73,6 @@ class TriangleAttention(nn.Module):
"biases"
:
biases
,
}
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
return
chunk_layer
(
partial
(
self
.
mha
,
...
...
@@ -92,6 +91,7 @@ class TriangleAttention(nn.Module):
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
Args:
...
...
@@ -130,7 +130,8 @@ class TriangleAttention(nn.Module):
biases
,
chunk_size
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
else
:
x
=
self
.
mha
(
...
...
openfold/model/triangular_multiplicative_update.py
View file @
143ba486
...
...
@@ -357,7 +357,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
def
forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
_add_with_inplace
:
bool
=
False
,
_inplace_chunk_size
:
Optional
[
int
]
=
256
,
)
->
torch
.
Tensor
:
...
...
@@ -370,7 +370,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if
(
_
inplace
):
if
(
inplace
_safe
):
x
=
self
.
_inference_forward
(
z
,
mask
,
...
...
openfold/utils/chunk_utils.py
View file @
143ba486
...
...
@@ -415,8 +415,6 @@ class ChunkSizeTuner:
# Otherwise, we can reuse the precomputed value
consistent
=
False
print
(
consistent
)
if
(
not
consistent
):
self
.
cached_chunk_size
=
self
.
_determine_favorable_chunk_size
(
representative_fn
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment