Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
143ba486
"segmentation/git@developer.sourcefind.cn:OpenDAS/dcnv3.git" did not exist on "7142c933b5dc64f33a88539e12108bdbcce11b5e"
Commit
143ba486
authored
Jun 21, 2022
by
Gustaf Ahdritz
Browse files
Refactor inplace operations, fix training
parent
f1402490
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
249 additions
and
137 deletions
+249
-137
openfold/model/embedders.py
openfold/model/embedders.py
+5
-6
openfold/model/evoformer.py
openfold/model/evoformer.py
+124
-56
openfold/model/model.py
openfold/model/model.py
+76
-42
openfold/model/msa.py
openfold/model/msa.py
+14
-7
openfold/model/outer_product_mean.py
openfold/model/outer_product_mean.py
+2
-2
openfold/model/structure_module.py
openfold/model/structure_module.py
+4
-2
openfold/model/template.py
openfold/model/template.py
+18
-15
openfold/model/triangular_attention.py
openfold/model/triangular_attention.py
+4
-3
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+2
-2
openfold/utils/chunk_utils.py
openfold/utils/chunk_utils.py
+0
-2
No files found.
openfold/model/embedders.py
View file @
143ba486
...
...
@@ -95,6 +95,7 @@ class InputEmbedder(nn.Module):
tf
:
torch
.
Tensor
,
ri
:
torch
.
Tensor
,
msa
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
...
...
@@ -111,8 +112,6 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding
"""
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# [*, N_res, c_z]
tf_emb_i
=
self
.
linear_tf_z_i
(
tf
)
tf_emb_j
=
self
.
linear_tf_z_j
(
tf
)
...
...
@@ -187,7 +186,7 @@ class RecyclingEmbedder(nn.Module):
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Args:
...
...
@@ -205,13 +204,13 @@ class RecyclingEmbedder(nn.Module):
"""
# [*, N, C_m]
m_update
=
self
.
layer_norm_m
(
m
)
if
(
_
inplace
):
if
(
inplace
_safe
):
m
.
copy_
(
m_update
)
m_update
=
m
# [*, N, N, C_z]
z_update
=
self
.
layer_norm_z
(
z
)
if
(
_
inplace
):
if
(
inplace
_safe
):
z
.
copy_
(
z_update
)
z_update
=
z
...
...
@@ -237,7 +236,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z]
d
=
self
.
linear
(
d
)
z_update
=
add
(
z_update
,
d
,
_
inplace
)
z_update
=
add
(
z_update
,
d
,
inplace
_safe
)
return
m_update
,
z_update
...
...
openfold/model/evoformer.py
View file @
143ba486
...
...
@@ -182,6 +182,7 @@ class EvoformerBlockCore(nn.Module):
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
...
...
@@ -191,15 +192,12 @@ class EvoformerBlockCore(nn.Module):
# the original.
msa_trans_mask
=
msa_mask
if
_mask_trans
else
None
pair_trans_mask
=
pair_mask
if
_mask_trans
else
None
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
m
,
z
=
input_tensors
# Need to dodge activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
m
=
add
(
m
,
self
.
msa_transition
(
...
...
@@ -213,9 +211,9 @@ class EvoformerBlockCore(nn.Module):
input_tensors
[
1
]
=
input_tensors
[
1
].
cpu
()
torch
.
cuda
.
empty_cache
()
m
,
z
=
input_tensors
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
_
inplace
=
inplace_safe
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace
_safe
=
inplace_safe
)
if
(
_offload_inference
and
inplace_safe
):
...
...
@@ -230,7 +228,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update
=
self
.
tri_mul_out
(
z
,
mask
=
pair_mask
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
):
...
...
@@ -243,7 +241,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update
=
self
.
tri_mul_in
(
z
,
mask
=
pair_mask
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
)
if
(
not
inplace_safe
):
...
...
@@ -259,7 +257,8 @@ class EvoformerBlockCore(nn.Module):
z
,
mask
=
pair_mask
,
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
inplace
=
inplace_safe
,
...
...
@@ -277,6 +276,7 @@ class EvoformerBlockCore(nn.Module):
mask
=
pair_mask
.
transpose
(
-
1
,
-
2
),
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
inplace
=
inplace_safe
,
...
...
@@ -355,20 +355,27 @@ class EvoformerBlock(nn.Module):
)
def
forward
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
m
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
if
(
_offload_inference
and
inplace_safe
):
input_tensors
=
_offloadable_inputs
del
_offloadable_inputs
else
:
input_tensors
=
[
m
,
z
]
m
,
z
=
input_tensors
m
=
add
(
m
,
...
...
@@ -404,17 +411,13 @@ class EvoformerBlock(nn.Module):
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
)
if
(
inplace_safe
):
out
=
input_tensors
else
:
out
=
[
m
,
z
]
return
out
return
m
,
z
class
ExtraMSABlock
(
nn
.
Module
):
...
...
@@ -477,22 +480,29 @@ class ExtraMSABlock(nn.Module):
)
def
forward
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
m
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
if
(
_offload_inference
and
inplace_safe
):
input_tensors
=
_offloadable_inputs
del
_offloadable_inputs
else
:
input_tensors
=
[
m
,
z
]
m
,
z
=
input_tensors
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# If function calls could speak...
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
...
...
@@ -509,8 +519,11 @@ class ExtraMSABlock(nn.Module):
inplace
=
inplace_safe
,
)
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
def
fn
(
input_tensors
):
m
=
add
(
input_tensors
[
0
],
self
.
msa_att_col
(
...
...
@@ -523,7 +536,7 @@ class ExtraMSABlock(nn.Module):
)
if
(
not
inplace_safe
):
input_tensors
[
m
,
input_tensors
[
1
]]
input_tensors
=
[
m
,
input_tensors
[
1
]]
del
m
...
...
@@ -533,6 +546,7 @@ class ExtraMSABlock(nn.Module):
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
...
...
@@ -647,15 +661,16 @@ class EvoformerStack(nn.Module):
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
_forward_list
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
def
_prep_blocks
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_offload_inference
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
use_lma
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
,
_mask_trans
:
bool
,
):
blocks
=
[
partial
(
b
,
...
...
@@ -663,8 +678,8 @@ class EvoformerStack(nn.Module):
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_offload_inference
=
_offload_inference
,
)
for
b
in
self
.
blocks
]
...
...
@@ -677,11 +692,11 @@ class EvoformerStack(nn.Module):
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
print
(
"evo"
)
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
# We don't want to write in-place during chunk tuning runs
args
=
(
[
t
.
clone
()
for
t
in
input_tensors
]
,),
args
=
(
m
.
clone
()
,
z
.
clone
()
,),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
...
...
@@ -693,16 +708,43 @@ class EvoformerStack(nn.Module):
)
for
b
in
blocks
]
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
(
not
torch
.
is_grad_enabled
()):
blocks_per_ckpt
=
None
return
blocks
m
,
z
=
checkpoint_blocks
(
blocks
,
args
=
input_tensors
,
blocks_per_ckpt
=
blocks_per_ckpt
,
)[
0
]
def
_forward_offload
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
(
not
(
self
.
training
or
torch
.
is_grad_enabled
()))
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# this function
m
=
input_tensors
[
0
],
z
=
input_tensors
[
1
],
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
True
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
m
,
z
=
b
(
None
,
None
,
_offload_inference
=
True
,
_offloadable_inputs
=
input_tensors
,
)
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
del
m
,
z
m
,
z
=
input_tensors
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
return
m
,
z
,
s
...
...
@@ -714,6 +756,7 @@ class EvoformerStack(nn.Module):
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
...
...
@@ -738,15 +781,31 @@ class EvoformerStack(nn.Module):
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
return
self
.
_forward_list
(
[
m
,
z
],
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
blocks
=
self
.
_prep_blocks
(
m
=
m
,
z
=
z
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
(
not
torch
.
is_grad_enabled
()):
blocks_per_ckpt
=
None
m
,
z
=
checkpoint_blocks
(
blocks
,
args
=
(
m
,
z
),
blocks_per_ckpt
=
blocks_per_ckpt
,
)
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
return
m
,
z
,
s
class
ExtraMSAStack
(
nn
.
Module
):
"""
...
...
@@ -769,7 +828,6 @@ class ExtraMSAStack(nn.Module):
eps
:
float
,
ckpt
:
bool
,
clear_cache_between_blocks
:
bool
=
False
,
chunk_msa_attn
:
bool
=
False
,
tune_chunk_size
:
bool
=
False
,
**
kwargs
,
):
...
...
@@ -777,7 +835,6 @@ class ExtraMSAStack(nn.Module):
self
.
ckpt
=
ckpt
self
.
clear_cache_between_blocks
=
clear_cache_between_blocks
self
.
chunk_msa_attn
=
chunk_msa_attn
self
.
blocks
=
nn
.
ModuleList
()
for
_
in
range
(
no_blocks
):
block
=
ExtraMSABlock
(
...
...
@@ -794,7 +851,7 @@ class ExtraMSAStack(nn.Module):
pair_dropout
=
pair_dropout
,
inf
=
inf
,
eps
=
eps
,
ckpt
=
ckpt
if
chunk_msa_attn
else
False
,
ckpt
=
False
,
)
self
.
blocks
.
append
(
block
)
...
...
@@ -810,6 +867,7 @@ class ExtraMSAStack(nn.Module):
use_lma
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
,
_mask_trans
:
bool
,
):
blocks
=
[
...
...
@@ -819,6 +877,7 @@ class ExtraMSAStack(nn.Module):
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
...
...
@@ -831,10 +890,12 @@ class ExtraMSAStack(nn.Module):
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
print
(
"extra"
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
([
m
.
clone
(),
z
.
clone
()],),
# Tensors cloned to avoid getting written to in-place
# A corollary is that chunk size tuning should be disabled for
# large N, when z gets really big
args
=
(
m
.
clone
(),
z
.
clone
(),),
min_chunk_size
=
chunk_size
,
)
blocks
=
[
...
...
@@ -848,16 +909,15 @@ class ExtraMSAStack(nn.Module):
return
blocks
def
_forward_
list
(
self
,
def
_forward_
offload
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
chunk_size
:
int
,
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_mask_trans
:
bool
=
True
,
_offload_inference
:
bool
=
False
,
)
->
torch
.
Tensor
:
assert
(
not
self
.
training
)
assert
(
not
(
self
.
training
or
torch
.
is_grad_enabled
())
)
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# this function
...
...
@@ -867,11 +927,17 @@ class ExtraMSAStack(nn.Module):
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
True
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
m
,
z
=
b
(
input_tensors
,
_offload_inference
=
_offload_inference
)
m
,
z
=
b
(
None
,
None
,
_offload_inference
=
True
,
_offloadable_inputs
=
input_tensors
,
)
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
del
m
,
z
...
...
@@ -885,6 +951,7 @@ class ExtraMSAStack(nn.Module):
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
torch
.
Tensor
:
"""
...
...
@@ -910,12 +977,13 @@ class ExtraMSAStack(nn.Module):
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
m
,
z
=
checkpoint_fn
(
b
,
(
m
,
z
)
)
m
,
z
=
checkpoint_fn
(
b
,
m
,
z
)
else
:
m
,
z
=
b
(
m
,
z
)
...
...
openfold/model/model.py
View file @
143ba486
...
...
@@ -107,19 +107,16 @@ class AlphaFold(nn.Module):
self
.
config
[
"heads"
],
)
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
):
if
(
self
.
template_config
.
offload_templates
):
return
embed_templates_offload
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
return
embed_templates_offload
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
elif
(
self
.
template_config
.
average_templates
):
return
embed_templates_average
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
return
embed_templates_average
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds
=
[]
n
=
z
.
shape
[
-
2
]
...
...
@@ -168,6 +165,7 @@ class AlphaFold(nn.Module):
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
t_pair
...
...
@@ -186,6 +184,11 @@ class AlphaFold(nn.Module):
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
ret
.
update
({
"template_pair_embedding"
:
t
})
del
t
if
self
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
batch
...
...
@@ -196,10 +199,6 @@ class AlphaFold(nn.Module):
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
del
t
return
ret
def
iteration
(
self
,
feats
,
prevs
,
_recycle
=
True
):
...
...
@@ -218,6 +217,9 @@ class AlphaFold(nn.Module):
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
device
=
feats
[
"target_feat"
].
device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Prep some features
...
...
@@ -233,10 +235,11 @@ class AlphaFold(nn.Module):
feats
[
"target_feat"
],
feats
[
"residue_index"
],
feats
[
"msa_feat"
],
inplace_safe
=
inplace_safe
,
)
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function
.
# them to be freed further down in this function
, saving memory
m_1_prev
,
z_prev
,
x_prev
=
reversed
([
prevs
.
pop
()
for
_
in
range
(
3
)])
# Initialize the recycling embeddings, if needs be
...
...
@@ -263,6 +266,7 @@ class AlphaFold(nn.Module):
feats
[
"aatype"
],
x_prev
,
None
).
to
(
dtype
=
z
.
dtype
)
# The recycling embedder is memory-intensive, so we offload first
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
m
=
m
.
cpu
()
z
=
z
.
cpu
()
...
...
@@ -273,7 +277,7 @@ class AlphaFold(nn.Module):
m_1_prev
,
z_prev
,
x_prev
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
)
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
...
...
@@ -286,10 +290,13 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
z
=
add
(
z
,
z_prev_emb
,
inplace
=
inplace_safe
)
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if
self
.
config
.
template
.
enabled
:
if
self
.
config
.
template
.
enabled
:
template_feats
=
{
k
:
v
for
k
,
v
in
feats
.
items
()
if
k
.
startswith
(
"template_"
)
}
...
...
@@ -298,6 +305,7 @@ class AlphaFold(nn.Module):
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
no_batch_dims
,
inplace_safe
=
inplace_safe
,
)
# [*, N, N, C_z]
...
...
@@ -306,7 +314,7 @@ class AlphaFold(nn.Module):
inplace_safe
,
)
if
self
.
config
.
template
.
embed
_angle
s
:
if
"template_angle_embedding"
in
template
_
embeds
:
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
[
m
,
template_embeds
[
"template_angle_embedding"
]],
...
...
@@ -324,39 +332,64 @@ class AlphaFold(nn.Module):
if
self
.
config
.
extra_msa
.
enabled
:
# [*, S_e, N, C_e]
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
input_tensors
=
[
a
,
z
]
del
a
,
z
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
.
_forward_list
(
if
(
self
.
globals
.
offload_inference
):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors
=
[
a
,
z
]
del
a
,
z
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
.
_forward_offload
(
input_tensors
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
input_tensors
else
:
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
(
a
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
if
(
self
.
globals
.
offload_inference
):
input_tensors
=
[
m
,
z
]
del
m
,
z
m
,
z
,
s
=
self
.
evoformer
.
_forward_offload
(
input_tensors
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
msa_mask
=
msa_mask
.
to
(
dtype
=
input_tensors
[
0
].
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
input_tensors
[
1
].
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
_mask_trans
=
self
.
config
.
_mask_trans
,
_offload_inference
=
self
.
globals
.
offload_inference
,
)
del
input_tensors
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
input_tensors
=
[
m
,
z
]
del
m
,
z
m
,
z
,
s
=
self
.
evoformer
.
_forward_list
(
input_tensors
,
msa_mask
=
msa_mask
.
to
(
dtype
=
input_tensors
[
0
].
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
input_tensors
[
1
].
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
del
input_tensors
else
:
m
,
z
,
s
=
self
.
evoformer
(
m
,
z
,
msa_mask
=
msa_mask
.
to
(
dtype
=
m
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
outputs
[
"msa"
]
=
m
[...,
:
n_seq
,
:,
:]
outputs
[
"pair"
]
=
z
...
...
@@ -369,6 +402,7 @@ class AlphaFold(nn.Module):
outputs
,
feats
[
"aatype"
],
mask
=
feats
[
"seq_mask"
].
to
(
dtype
=
s
.
dtype
),
inplace_safe
=
inplace_safe
,
_offload_inference
=
self
.
globals
.
offload_inference
,
)
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
...
...
openfold/model/msa.py
View file @
143ba486
...
...
@@ -117,10 +117,9 @@ class MSAAttention(nn.Module):
def
_prep_inputs
(
self
,
m
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
],
mask
:
Optional
[
torch
.
Tensor
]
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
_inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
n_seq
,
n_res
=
m
.
shape
[
-
3
:
-
1
]
if
mask
is
None
:
# [*, N_seq, N_res]
...
...
@@ -163,6 +162,7 @@ class MSAAttention(nn.Module):
mask
:
Optional
[
torch
.
Tensor
],
chunk_logits
:
int
,
checkpoint
:
bool
,
inplace_safe
:
bool
=
False
)
->
torch
.
Tensor
:
"""
MSA attention with training-time chunking of the softmax computation.
...
...
@@ -172,7 +172,9 @@ class MSAAttention(nn.Module):
MSA_DIM
=
-
4
def
_get_qkv
(
m
,
z
):
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
,
inplace_safe
=
inplace_safe
)
q
,
k
,
v
=
self
.
mha
.
_prep_qkv
(
m
,
m
)
return
m
,
q
,
k
,
v
,
mask_bias
,
z
...
...
@@ -208,6 +210,7 @@ class MSAAttention(nn.Module):
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_chunk_logits
:
Optional
[
int
]
=
None
,
_checkpoint_chunks
:
Optional
[
bool
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -229,10 +232,14 @@ class MSAAttention(nn.Module):
if
(
_chunk_logits
is
not
None
):
return
self
.
_chunked_msa_attn
(
m
=
m
,
z
=
z
,
mask
=
mask
,
chunk_logits
=
_chunk_logits
,
checkpoint
=
_checkpoint_chunks
chunk_logits
=
_chunk_logits
,
checkpoint
=
_checkpoint_chunks
,
inplace_safe
=
inplace_safe
,
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
,
inplace_safe
=
inplace_safe
)
biases
=
[
mask_bias
]
if
(
z
is
not
None
):
...
...
openfold/model/outer_product_mean.py
View file @
143ba486
...
...
@@ -97,7 +97,7 @@ class OuterProductMean(nn.Module):
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
Args:
...
...
@@ -137,7 +137,7 @@ class OuterProductMean(nn.Module):
norm
=
norm
+
self
.
eps
# [*, N_res, N_res, C_z]
if
(
_
inplace
):
if
(
inplace
_safe
):
outer
/=
norm
else
:
outer
=
outer
/
norm
...
...
openfold/model/structure_module.py
View file @
143ba486
...
...
@@ -232,6 +232,7 @@ class InvariantPointAttention(nn.Module):
z
:
Optional
[
torch
.
Tensor
],
r
:
Rigid
,
mask
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_z_reference_list
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -248,12 +249,11 @@ class InvariantPointAttention(nn.Module):
Returns:
[*, N_res, C_s] single representation update
"""
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
if
(
_offload_inference
and
inplace_safe
):
z
=
_z_reference_list
else
:
z
=
[
z
]
#######################################
# Generate scalar and point activations
#######################################
...
...
@@ -619,6 +619,7 @@ class StructureModule(nn.Module):
evoformer_output_dict
,
aatype
,
mask
=
None
,
inplace_safe
=
False
,
_offload_inference
=
False
,
):
"""
...
...
@@ -674,6 +675,7 @@ class StructureModule(nn.Module):
z
,
rigids
,
mask
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
,
_z_reference_list
=
z_reference_list
)
...
...
openfold/model/template.py
View file @
143ba486
...
...
@@ -201,8 +201,8 @@ class TemplatePairStackBlock(nn.Module):
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_inplace
:
bool
=
False
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
):
if
(
_attn_chunk_size
is
None
):
...
...
@@ -214,6 +214,7 @@ class TemplatePairStackBlock(nn.Module):
single_templates_masks
=
[
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
]
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
...
...
@@ -225,9 +226,10 @@ class TemplatePairStackBlock(nn.Module):
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
_
inplace
,
inplace
_safe
,
)
single
=
add
(
single
,
...
...
@@ -237,18 +239,19 @@ class TemplatePairStackBlock(nn.Module):
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
),
_
inplace
,
inplace
_safe
,
)
tmu_update
=
self
.
tri_mul_out
(
single
,
mask
=
single_mask
,
_
inplace
=
_inplace
,
inplace_
safe
=
inplace
_safe
,
_add_with_inplace
=
True
,
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
...
...
@@ -258,10 +261,10 @@ class TemplatePairStackBlock(nn.Module):
tmu_update
=
self
.
tri_mul_in
(
single
,
mask
=
single_mask
,
_
inplace
=
_inplace
,
inplace_
safe
=
inplace
_safe
,
_add_with_inplace
=
True
,
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
single
=
tmu_update
...
...
@@ -274,13 +277,13 @@ class TemplatePairStackBlock(nn.Module):
mask
=
single_mask
if
_mask_trans
else
None
,
chunk_size
=
chunk_size
,
),
_
inplace
,
inplace
_safe
,
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single_templates
[
i
]
=
single
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
return
z
...
...
@@ -352,6 +355,7 @@ class TemplatePairStack(nn.Module):
mask
:
torch
.
tensor
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
):
"""
...
...
@@ -374,13 +378,14 @@ class TemplatePairStack(nn.Module):
mask
=
mask
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
for
b
in
self
.
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
args
=
(
t
.
clone
(),),
...
...
@@ -411,6 +416,7 @@ def embed_templates_offload(
pair_mask
,
templ_dim
,
template_chunk_size
=
256
,
inplace_safe
=
False
,
):
"""
Args:
...
...
@@ -435,8 +441,6 @@ def embed_templates_offload(
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu
=
[]
n
=
z
.
shape
[
-
2
]
...
...
@@ -519,6 +523,7 @@ def embed_templates_average(
pair_mask
,
templ_dim
,
templ_group_size
=
2
,
inplace_safe
=
False
,
):
"""
Args:
...
...
@@ -547,8 +552,6 @@ def embed_templates_average(
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
...
...
openfold/model/triangular_attention.py
View file @
143ba486
...
...
@@ -64,6 +64,7 @@ class TriangleAttention(nn.Module):
chunk_size
:
int
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
"triangle! triangle!"
mha_inputs
=
{
...
...
@@ -72,8 +73,6 @@ class TriangleAttention(nn.Module):
"biases"
:
biases
,
}
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
return
chunk_layer
(
partial
(
self
.
mha
,
...
...
@@ -92,6 +91,7 @@ class TriangleAttention(nn.Module):
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
"""
Args:
...
...
@@ -130,7 +130,8 @@ class TriangleAttention(nn.Module):
biases
,
chunk_size
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
else
:
x
=
self
.
mha
(
...
...
openfold/model/triangular_multiplicative_update.py
View file @
143ba486
...
...
@@ -357,7 +357,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
def
forward
(
self
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
_add_with_inplace
:
bool
=
False
,
_inplace_chunk_size
:
Optional
[
int
]
=
256
,
)
->
torch
.
Tensor
:
...
...
@@ -370,7 +370,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if
(
_
inplace
):
if
(
inplace
_safe
):
x
=
self
.
_inference_forward
(
z
,
mask
,
...
...
openfold/utils/chunk_utils.py
View file @
143ba486
...
...
@@ -415,8 +415,6 @@ class ChunkSizeTuner:
# Otherwise, we can reuse the precomputed value
consistent
=
False
print
(
consistent
)
if
(
not
consistent
):
self
.
cached_chunk_size
=
self
.
_determine_favorable_chunk_size
(
representative_fn
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment