Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
143ba486
Commit
143ba486
authored
Jun 21, 2022
by
Gustaf Ahdritz
Browse files
Refactor inplace operations, fix training
parent
f1402490
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
249 additions
and
137 deletions
+249
-137
openfold/model/embedders.py
openfold/model/embedders.py
+5
-6
openfold/model/evoformer.py
openfold/model/evoformer.py
+124
-56
openfold/model/model.py
openfold/model/model.py
+76
-42
openfold/model/msa.py
openfold/model/msa.py
+14
-7
openfold/model/outer_product_mean.py
openfold/model/outer_product_mean.py
+2
-2
openfold/model/structure_module.py
openfold/model/structure_module.py
+4
-2
openfold/model/template.py
openfold/model/template.py
+18
-15
openfold/model/triangular_attention.py
openfold/model/triangular_attention.py
+4
-3
openfold/model/triangular_multiplicative_update.py
openfold/model/triangular_multiplicative_update.py
+2
-2
openfold/utils/chunk_utils.py
openfold/utils/chunk_utils.py
+0
-2
No files found.
openfold/model/embedders.py
View file @
143ba486
...
@@ -95,6 +95,7 @@ class InputEmbedder(nn.Module):
...
@@ -95,6 +95,7 @@ class InputEmbedder(nn.Module):
tf
:
torch
.
Tensor
,
tf
:
torch
.
Tensor
,
ri
:
torch
.
Tensor
,
ri
:
torch
.
Tensor
,
msa
:
torch
.
Tensor
,
msa
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Args:
Args:
...
@@ -111,8 +112,6 @@ class InputEmbedder(nn.Module):
...
@@ -111,8 +112,6 @@ class InputEmbedder(nn.Module):
[*, N_res, N_res, C_z] pair embedding
[*, N_res, N_res, C_z] pair embedding
"""
"""
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# [*, N_res, c_z]
# [*, N_res, c_z]
tf_emb_i
=
self
.
linear_tf_z_i
(
tf
)
tf_emb_i
=
self
.
linear_tf_z_i
(
tf
)
tf_emb_j
=
self
.
linear_tf_z_j
(
tf
)
tf_emb_j
=
self
.
linear_tf_z_j
(
tf
)
...
@@ -187,7 +186,7 @@ class RecyclingEmbedder(nn.Module):
...
@@ -187,7 +186,7 @@ class RecyclingEmbedder(nn.Module):
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Args:
Args:
...
@@ -205,13 +204,13 @@ class RecyclingEmbedder(nn.Module):
...
@@ -205,13 +204,13 @@ class RecyclingEmbedder(nn.Module):
"""
"""
# [*, N, C_m]
# [*, N, C_m]
m_update
=
self
.
layer_norm_m
(
m
)
m_update
=
self
.
layer_norm_m
(
m
)
if
(
_
inplace
):
if
(
inplace
_safe
):
m
.
copy_
(
m_update
)
m
.
copy_
(
m_update
)
m_update
=
m
m_update
=
m
# [*, N, N, C_z]
# [*, N, N, C_z]
z_update
=
self
.
layer_norm_z
(
z
)
z_update
=
self
.
layer_norm_z
(
z
)
if
(
_
inplace
):
if
(
inplace
_safe
):
z
.
copy_
(
z_update
)
z
.
copy_
(
z_update
)
z_update
=
z
z_update
=
z
...
@@ -237,7 +236,7 @@ class RecyclingEmbedder(nn.Module):
...
@@ -237,7 +236,7 @@ class RecyclingEmbedder(nn.Module):
# [*, N, N, C_z]
# [*, N, N, C_z]
d
=
self
.
linear
(
d
)
d
=
self
.
linear
(
d
)
z_update
=
add
(
z_update
,
d
,
_
inplace
)
z_update
=
add
(
z_update
,
d
,
inplace
_safe
)
return
m_update
,
z_update
return
m_update
,
z_update
...
...
openfold/model/evoformer.py
View file @
143ba486
...
@@ -182,6 +182,7 @@ class EvoformerBlockCore(nn.Module):
...
@@ -182,6 +182,7 @@ class EvoformerBlockCore(nn.Module):
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
...
@@ -197,9 +198,6 @@ class EvoformerBlockCore(nn.Module):
...
@@ -197,9 +198,6 @@ class EvoformerBlockCore(nn.Module):
m
,
z
=
input_tensors
m
,
z
=
input_tensors
# Need to dodge activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
m
=
add
(
m
=
add
(
m
,
m
,
self
.
msa_transition
(
self
.
msa_transition
(
...
@@ -215,7 +213,7 @@ class EvoformerBlockCore(nn.Module):
...
@@ -215,7 +213,7 @@ class EvoformerBlockCore(nn.Module):
m
,
z
=
input_tensors
m
,
z
=
input_tensors
opm
=
self
.
outer_product_mean
(
opm
=
self
.
outer_product_mean
(
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
_
inplace
=
inplace_safe
m
,
mask
=
msa_mask
,
chunk_size
=
chunk_size
,
inplace
_safe
=
inplace_safe
)
)
if
(
_offload_inference
and
inplace_safe
):
if
(
_offload_inference
and
inplace_safe
):
...
@@ -230,7 +228,7 @@ class EvoformerBlockCore(nn.Module):
...
@@ -230,7 +228,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update
=
self
.
tri_mul_out
(
tmu_update
=
self
.
tri_mul_out
(
z
,
z
,
mask
=
pair_mask
,
mask
=
pair_mask
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
_add_with_inplace
=
True
,
)
)
if
(
not
inplace_safe
):
if
(
not
inplace_safe
):
...
@@ -243,7 +241,7 @@ class EvoformerBlockCore(nn.Module):
...
@@ -243,7 +241,7 @@ class EvoformerBlockCore(nn.Module):
tmu_update
=
self
.
tri_mul_in
(
tmu_update
=
self
.
tri_mul_in
(
z
,
z
,
mask
=
pair_mask
,
mask
=
pair_mask
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
_add_with_inplace
=
True
,
_add_with_inplace
=
True
,
)
)
if
(
not
inplace_safe
):
if
(
not
inplace_safe
):
...
@@ -259,7 +257,8 @@ class EvoformerBlockCore(nn.Module):
...
@@ -259,7 +257,8 @@ class EvoformerBlockCore(nn.Module):
z
,
z
,
mask
=
pair_mask
,
mask
=
pair_mask
,
chunk_size
=
_attn_chunk_size
,
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
)
),
),
inplace
=
inplace_safe
,
inplace
=
inplace_safe
,
...
@@ -277,6 +276,7 @@ class EvoformerBlockCore(nn.Module):
...
@@ -277,6 +276,7 @@ class EvoformerBlockCore(nn.Module):
mask
=
pair_mask
.
transpose
(
-
1
,
-
2
),
mask
=
pair_mask
.
transpose
(
-
1
,
-
2
),
chunk_size
=
_attn_chunk_size
,
chunk_size
=
_attn_chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
)
),
),
inplace
=
inplace_safe
,
inplace
=
inplace_safe
,
...
@@ -355,20 +355,27 @@ class EvoformerBlock(nn.Module):
...
@@ -355,20 +355,27 @@ class EvoformerBlock(nn.Module):
)
)
def
forward
(
self
,
def
forward
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
m
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
if
(
_attn_chunk_size
is
None
):
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
_attn_chunk_size
=
chunk_size
if
(
_offload_inference
and
inplace_safe
):
input_tensors
=
_offloadable_inputs
del
_offloadable_inputs
else
:
input_tensors
=
[
m
,
z
]
m
,
z
=
input_tensors
m
,
z
=
input_tensors
m
=
add
(
m
,
m
=
add
(
m
,
...
@@ -404,17 +411,13 @@ class EvoformerBlock(nn.Module):
...
@@ -404,17 +411,13 @@ class EvoformerBlock(nn.Module):
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
_offload_inference
=
_offload_inference
,
)
)
if
(
inplace_safe
):
return
m
,
z
out
=
input_tensors
else
:
out
=
[
m
,
z
]
return
out
class
ExtraMSABlock
(
nn
.
Module
):
class
ExtraMSABlock
(
nn
.
Module
):
...
@@ -477,22 +480,29 @@ class ExtraMSABlock(nn.Module):
...
@@ -477,22 +480,29 @@ class ExtraMSABlock(nn.Module):
)
)
def
forward
(
self
,
def
forward
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
m
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
msa_mask
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_offload_inference
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_offloadable_inputs
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
(
_attn_chunk_size
is
None
):
if
(
_attn_chunk_size
is
None
):
_attn_chunk_size
=
chunk_size
_attn_chunk_size
=
chunk_size
if
(
_offload_inference
and
inplace_safe
):
input_tensors
=
_offloadable_inputs
del
_offloadable_inputs
else
:
input_tensors
=
[
m
,
z
]
m
,
z
=
input_tensors
m
,
z
=
input_tensors
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# If function calls could speak...
m
=
add
(
m
,
m
=
add
(
m
,
self
.
msa_dropout_layer
(
self
.
msa_dropout_layer
(
self
.
msa_att_row
(
self
.
msa_att_row
(
...
@@ -509,6 +519,9 @@ class ExtraMSABlock(nn.Module):
...
@@ -509,6 +519,9 @@ class ExtraMSABlock(nn.Module):
inplace
=
inplace_safe
,
inplace
=
inplace_safe
,
)
)
if
(
not
inplace_safe
):
input_tensors
=
[
m
,
z
]
del
m
,
z
del
m
,
z
def
fn
(
input_tensors
):
def
fn
(
input_tensors
):
...
@@ -523,7 +536,7 @@ class ExtraMSABlock(nn.Module):
...
@@ -523,7 +536,7 @@ class ExtraMSABlock(nn.Module):
)
)
if
(
not
inplace_safe
):
if
(
not
inplace_safe
):
input_tensors
[
m
,
input_tensors
[
1
]]
input_tensors
=
[
m
,
input_tensors
[
1
]]
del
m
del
m
...
@@ -533,6 +546,7 @@ class ExtraMSABlock(nn.Module):
...
@@ -533,6 +546,7 @@ class ExtraMSABlock(nn.Module):
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
_attn_chunk_size
=
_attn_chunk_size
,
_attn_chunk_size
=
_attn_chunk_size
,
_offload_inference
=
_offload_inference
,
_offload_inference
=
_offload_inference
,
...
@@ -647,15 +661,16 @@ class EvoformerStack(nn.Module):
...
@@ -647,15 +661,16 @@ class EvoformerStack(nn.Module):
if
(
tune_chunk_size
):
if
(
tune_chunk_size
):
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
self
.
chunk_size_tuner
=
ChunkSizeTuner
()
def
_forward_list
(
self
,
def
_prep_blocks
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
m
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
use_lma
:
bool
,
_mask_trans
:
bool
=
True
,
msa_mask
:
Optional
[
torch
.
Tensor
],
_offload_inference
:
bool
=
False
,
pair_mask
:
Optional
[
torch
.
Tensor
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
inplace_safe
:
bool
,
_mask_trans
:
bool
,
):
blocks
=
[
blocks
=
[
partial
(
partial
(
b
,
b
,
...
@@ -663,8 +678,8 @@ class EvoformerStack(nn.Module):
...
@@ -663,8 +678,8 @@ class EvoformerStack(nn.Module):
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
_offload_inference
=
_offload_inference
,
)
)
for
b
in
self
.
blocks
for
b
in
self
.
blocks
]
]
...
@@ -677,11 +692,11 @@ class EvoformerStack(nn.Module):
...
@@ -677,11 +692,11 @@ class EvoformerStack(nn.Module):
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
print
(
"evo"
)
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
representative_fn
=
blocks
[
0
],
# We don't want to write in-place during chunk tuning runs
# We don't want to write in-place during chunk tuning runs
args
=
(
[
t
.
clone
()
for
t
in
input_tensors
]
,),
args
=
(
m
.
clone
()
,
z
.
clone
()
,),
min_chunk_size
=
chunk_size
,
min_chunk_size
=
chunk_size
,
)
)
blocks
=
[
blocks
=
[
...
@@ -693,15 +708,42 @@ class EvoformerStack(nn.Module):
...
@@ -693,15 +708,42 @@ class EvoformerStack(nn.Module):
)
for
b
in
blocks
)
for
b
in
blocks
]
]
blocks_per_ckpt
=
self
.
blocks_per_ckpt
return
blocks
if
(
not
torch
.
is_grad_enabled
()):
blocks_per_ckpt
=
None
m
,
z
=
checkpoint_blocks
(
def
_forward_offload
(
self
,
blocks
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
args
=
input_tensors
,
msa_mask
:
torch
.
Tensor
,
blocks_per_ckpt
=
blocks_per_ckpt
,
pair_mask
:
torch
.
Tensor
,
)[
0
]
chunk_size
:
int
,
use_lma
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
(
not
(
self
.
training
or
torch
.
is_grad_enabled
()))
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# this function
m
=
input_tensors
[
0
],
z
=
input_tensors
[
1
],
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
True
,
_mask_trans
=
_mask_trans
,
)
for
b
in
blocks
:
m
,
z
=
b
(
None
,
None
,
_offload_inference
=
True
,
_offloadable_inputs
=
input_tensors
,
)
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
del
m
,
z
m
,
z
=
input_tensors
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
...
@@ -714,6 +756,7 @@ class EvoformerStack(nn.Module):
...
@@ -714,6 +756,7 @@ class EvoformerStack(nn.Module):
pair_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
...
@@ -738,15 +781,31 @@ class EvoformerStack(nn.Module):
...
@@ -738,15 +781,31 @@ class EvoformerStack(nn.Module):
s:
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
"""
return
self
.
_forward_list
(
blocks
=
self
.
_prep_blocks
(
[
m
,
z
],
m
=
m
,
msa_mask
=
msa_mask
,
z
=
z
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
)
)
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
(
not
torch
.
is_grad_enabled
()):
blocks_per_ckpt
=
None
m
,
z
=
checkpoint_blocks
(
blocks
,
args
=
(
m
,
z
),
blocks_per_ckpt
=
blocks_per_ckpt
,
)
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
return
m
,
z
,
s
class
ExtraMSAStack
(
nn
.
Module
):
class
ExtraMSAStack
(
nn
.
Module
):
"""
"""
...
@@ -769,7 +828,6 @@ class ExtraMSAStack(nn.Module):
...
@@ -769,7 +828,6 @@ class ExtraMSAStack(nn.Module):
eps
:
float
,
eps
:
float
,
ckpt
:
bool
,
ckpt
:
bool
,
clear_cache_between_blocks
:
bool
=
False
,
clear_cache_between_blocks
:
bool
=
False
,
chunk_msa_attn
:
bool
=
False
,
tune_chunk_size
:
bool
=
False
,
tune_chunk_size
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
):
):
...
@@ -777,7 +835,6 @@ class ExtraMSAStack(nn.Module):
...
@@ -777,7 +835,6 @@ class ExtraMSAStack(nn.Module):
self
.
ckpt
=
ckpt
self
.
ckpt
=
ckpt
self
.
clear_cache_between_blocks
=
clear_cache_between_blocks
self
.
clear_cache_between_blocks
=
clear_cache_between_blocks
self
.
chunk_msa_attn
=
chunk_msa_attn
self
.
blocks
=
nn
.
ModuleList
()
self
.
blocks
=
nn
.
ModuleList
()
for
_
in
range
(
no_blocks
):
for
_
in
range
(
no_blocks
):
block
=
ExtraMSABlock
(
block
=
ExtraMSABlock
(
...
@@ -794,7 +851,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -794,7 +851,7 @@ class ExtraMSAStack(nn.Module):
pair_dropout
=
pair_dropout
,
pair_dropout
=
pair_dropout
,
inf
=
inf
,
inf
=
inf
,
eps
=
eps
,
eps
=
eps
,
ckpt
=
ckpt
if
chunk_msa_attn
else
False
,
ckpt
=
False
,
)
)
self
.
blocks
.
append
(
block
)
self
.
blocks
.
append
(
block
)
...
@@ -810,6 +867,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -810,6 +867,7 @@ class ExtraMSAStack(nn.Module):
use_lma
:
bool
,
use_lma
:
bool
,
msa_mask
:
Optional
[
torch
.
Tensor
],
msa_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
pair_mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
,
_mask_trans
:
bool
,
_mask_trans
:
bool
,
):
):
blocks
=
[
blocks
=
[
...
@@ -819,6 +877,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -819,6 +877,7 @@ class ExtraMSAStack(nn.Module):
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
)
for
b
in
self
.
blocks
]
]
...
@@ -831,10 +890,12 @@ class ExtraMSAStack(nn.Module):
...
@@ -831,10 +890,12 @@ class ExtraMSAStack(nn.Module):
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
blocks
=
[
partial
(
clear_cache
,
b
)
for
b
in
blocks
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
print
(
"extra"
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
representative_fn
=
blocks
[
0
],
args
=
([
m
.
clone
(),
z
.
clone
()],),
# Tensors cloned to avoid getting written to in-place
# A corollary is that chunk size tuning should be disabled for
# large N, when z gets really big
args
=
(
m
.
clone
(),
z
.
clone
(),),
min_chunk_size
=
chunk_size
,
min_chunk_size
=
chunk_size
,
)
)
blocks
=
[
blocks
=
[
...
@@ -848,16 +909,15 @@ class ExtraMSAStack(nn.Module):
...
@@ -848,16 +909,15 @@ class ExtraMSAStack(nn.Module):
return
blocks
return
blocks
def
_forward_
list
(
self
,
def
_forward_
offload
(
self
,
input_tensors
:
Sequence
[
torch
.
Tensor
],
input_tensors
:
Sequence
[
torch
.
Tensor
],
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
_offload_inference
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
(
not
self
.
training
)
assert
(
not
(
self
.
training
or
torch
.
is_grad_enabled
())
)
blocks
=
self
.
_prep_blocks
(
blocks
=
self
.
_prep_blocks
(
# We are very careful not to create references to these tensors in
# We are very careful not to create references to these tensors in
# this function
# this function
...
@@ -867,11 +927,17 @@ class ExtraMSAStack(nn.Module):
...
@@ -867,11 +927,17 @@ class ExtraMSAStack(nn.Module):
use_lma
=
use_lma
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
True
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
)
)
for
b
in
blocks
:
for
b
in
blocks
:
m
,
z
=
b
(
input_tensors
,
_offload_inference
=
_offload_inference
)
m
,
z
=
b
(
None
,
None
,
_offload_inference
=
True
,
_offloadable_inputs
=
input_tensors
,
)
input_tensors
[
0
]
=
m
input_tensors
[
0
]
=
m
input_tensors
[
1
]
=
z
input_tensors
[
1
]
=
z
del
m
,
z
del
m
,
z
...
@@ -885,6 +951,7 @@ class ExtraMSAStack(nn.Module):
...
@@ -885,6 +951,7 @@ class ExtraMSAStack(nn.Module):
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
msa_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
pair_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
...
@@ -910,12 +977,13 @@ class ExtraMSAStack(nn.Module):
...
@@ -910,12 +977,13 @@ class ExtraMSAStack(nn.Module):
use_lma
=
use_lma
,
use_lma
=
use_lma
,
msa_mask
=
msa_mask
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
pair_mask
=
pair_mask
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
)
)
for
b
in
blocks
:
for
b
in
blocks
:
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
if
(
self
.
ckpt
and
torch
.
is_grad_enabled
()):
m
,
z
=
checkpoint_fn
(
b
,
(
m
,
z
)
)
m
,
z
=
checkpoint_fn
(
b
,
m
,
z
)
else
:
else
:
m
,
z
=
b
(
m
,
z
)
m
,
z
=
b
(
m
,
z
)
...
...
openfold/model/model.py
View file @
143ba486
...
@@ -107,19 +107,16 @@ class AlphaFold(nn.Module):
...
@@ -107,19 +107,16 @@ class AlphaFold(nn.Module):
self
.
config
[
"heads"
],
self
.
config
[
"heads"
],
)
)
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
):
def
embed_templates
(
self
,
batch
,
z
,
pair_mask
,
templ_dim
):
if
(
self
.
template_config
.
offload_templates
):
if
(
self
.
template_config
.
offload_templates
):
return
embed_templates_offload
(
return
embed_templates_offload
(
self
,
self
,
batch
,
z
,
pair_mask
,
templ_dim
,
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
)
elif
(
self
.
template_config
.
average_templates
):
elif
(
self
.
template_config
.
average_templates
):
return
embed_templates_average
(
return
embed_templates_average
(
self
,
self
,
batch
,
z
,
pair_mask
,
templ_dim
batch
,
z
,
pair_mask
,
templ_dim
,
inplace_safe
=
inplace_safe
,
)
)
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds
=
[]
pair_embeds
=
[]
n
=
z
.
shape
[
-
2
]
n
=
z
.
shape
[
-
2
]
...
@@ -168,6 +165,7 @@ class AlphaFold(nn.Module):
...
@@ -168,6 +165,7 @@ class AlphaFold(nn.Module):
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
pair_mask
.
unsqueeze
(
-
3
).
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
)
del
t_pair
del
t_pair
...
@@ -186,6 +184,11 @@ class AlphaFold(nn.Module):
...
@@ -186,6 +184,11 @@ class AlphaFold(nn.Module):
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
t
=
t
*
(
torch
.
sum
(
batch
[
"template_mask"
],
dim
=-
1
)
>
0
)
ret
=
{}
ret
=
{}
ret
.
update
({
"template_pair_embedding"
:
t
})
del
t
if
self
.
config
.
template
.
embed_angles
:
if
self
.
config
.
template
.
embed_angles
:
template_angle_feat
=
build_template_angle_feat
(
template_angle_feat
=
build_template_angle_feat
(
batch
batch
...
@@ -196,10 +199,6 @@ class AlphaFold(nn.Module):
...
@@ -196,10 +199,6 @@ class AlphaFold(nn.Module):
ret
[
"template_angle_embedding"
]
=
a
ret
[
"template_angle_embedding"
]
=
a
ret
.
update
({
"template_pair_embedding"
:
t
})
del
t
return
ret
return
ret
def
iteration
(
self
,
feats
,
prevs
,
_recycle
=
True
):
def
iteration
(
self
,
feats
,
prevs
,
_recycle
=
True
):
...
@@ -218,6 +217,9 @@ class AlphaFold(nn.Module):
...
@@ -218,6 +217,9 @@ class AlphaFold(nn.Module):
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n
=
feats
[
"target_feat"
].
shape
[
-
2
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
n_seq
=
feats
[
"msa_feat"
].
shape
[
-
3
]
device
=
feats
[
"target_feat"
].
device
device
=
feats
[
"target_feat"
].
device
# Controls whether the model uses in-place operations throughout
# The dual condition accounts for activation checkpoints
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
# Prep some features
# Prep some features
...
@@ -233,10 +235,11 @@ class AlphaFold(nn.Module):
...
@@ -233,10 +235,11 @@ class AlphaFold(nn.Module):
feats
[
"target_feat"
],
feats
[
"target_feat"
],
feats
[
"residue_index"
],
feats
[
"residue_index"
],
feats
[
"msa_feat"
],
feats
[
"msa_feat"
],
inplace_safe
=
inplace_safe
,
)
)
# Unpack the recycling embeddings. Removing them from the list allows
# Unpack the recycling embeddings. Removing them from the list allows
# them to be freed further down in this function
.
# them to be freed further down in this function
, saving memory
m_1_prev
,
z_prev
,
x_prev
=
reversed
([
prevs
.
pop
()
for
_
in
range
(
3
)])
m_1_prev
,
z_prev
,
x_prev
=
reversed
([
prevs
.
pop
()
for
_
in
range
(
3
)])
# Initialize the recycling embeddings, if needs be
# Initialize the recycling embeddings, if needs be
...
@@ -263,6 +266,7 @@ class AlphaFold(nn.Module):
...
@@ -263,6 +266,7 @@ class AlphaFold(nn.Module):
feats
[
"aatype"
],
x_prev
,
None
feats
[
"aatype"
],
x_prev
,
None
).
to
(
dtype
=
z
.
dtype
)
).
to
(
dtype
=
z
.
dtype
)
# The recycling embedder is memory-intensive, so we offload first
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
m
=
m
.
cpu
()
m
=
m
.
cpu
()
z
=
z
.
cpu
()
z
=
z
.
cpu
()
...
@@ -273,7 +277,7 @@ class AlphaFold(nn.Module):
...
@@ -273,7 +277,7 @@ class AlphaFold(nn.Module):
m_1_prev
,
m_1_prev
,
z_prev
,
z_prev
,
x_prev
,
x_prev
,
_
inplace
=
inplace_safe
,
inplace
_safe
=
inplace_safe
,
)
)
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
if
(
self
.
globals
.
offload_inference
and
inplace_safe
):
...
@@ -286,6 +290,9 @@ class AlphaFold(nn.Module):
...
@@ -286,6 +290,9 @@ class AlphaFold(nn.Module):
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
add
(
z
,
z_prev_emb
,
inplace
=
inplace_safe
)
z
=
add
(
z
,
z_prev_emb
,
inplace
=
inplace_safe
)
# Deletions like these become significant for inference with large N,
# where they free unused tensors and remove references to others such
# that they can be offloaded later
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
del
m_1_prev
,
z_prev
,
x_prev
,
m_1_prev_emb
,
z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
# Embed the templates + merge with MSA/pair embeddings
...
@@ -298,6 +305,7 @@ class AlphaFold(nn.Module):
...
@@ -298,6 +305,7 @@ class AlphaFold(nn.Module):
z
,
z
,
pair_mask
.
to
(
dtype
=
z
.
dtype
),
pair_mask
.
to
(
dtype
=
z
.
dtype
),
no_batch_dims
,
no_batch_dims
,
inplace_safe
=
inplace_safe
,
)
)
# [*, N, N, C_z]
# [*, N, N, C_z]
...
@@ -306,7 +314,7 @@ class AlphaFold(nn.Module):
...
@@ -306,7 +314,7 @@ class AlphaFold(nn.Module):
inplace_safe
,
inplace_safe
,
)
)
if
self
.
config
.
template
.
embed
_angle
s
:
if
"template_angle_embedding"
in
template
_
embeds
:
# [*, S = S_c + S_t, N, C_m]
# [*, S = S_c + S_t, N, C_m]
m
=
torch
.
cat
(
m
=
torch
.
cat
(
[
m
,
template_embeds
[
"template_angle_embedding"
]],
[
m
,
template_embeds
[
"template_angle_embedding"
]],
...
@@ -325,29 +333,43 @@ class AlphaFold(nn.Module):
...
@@ -325,29 +333,43 @@ class AlphaFold(nn.Module):
# [*, S_e, N, C_e]
# [*, S_e, N, C_e]
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
a
=
self
.
extra_msa_embedder
(
build_extra_msa_feat
(
feats
))
if
(
self
.
globals
.
offload_inference
):
# To allow the extra MSA stack (and later the evoformer) to
# offload its inputs, we remove all references to them here
input_tensors
=
[
a
,
z
]
input_tensors
=
[
a
,
z
]
del
a
,
z
del
a
,
z
# [*, N, N, C_z]
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
.
_forward_
list
(
z
=
self
.
extra_msa_stack
.
_forward_
offload
(
input_tensors
,
input_tensors
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
_mask_trans
=
self
.
config
.
_mask_trans
,
_mask_trans
=
self
.
config
.
_mask_trans
,
_offload_inference
=
self
.
globals
.
offload_inference
,
)
)
del
input_tensors
del
input_tensors
else
:
# [*, N, N, C_z]
z
=
self
.
extra_msa_stack
(
a
,
z
,
msa_mask
=
feats
[
"extra_msa_mask"
].
to
(
dtype
=
m
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
pair_mask
=
pair_mask
.
to
(
dtype
=
m
.
dtype
),
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
# Run MSA + pair embeddings through the trunk of the network
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
# s: [*, N, C_s]
if
(
self
.
globals
.
offload_inference
):
input_tensors
=
[
m
,
z
]
input_tensors
=
[
m
,
z
]
del
m
,
z
del
m
,
z
m
,
z
,
s
=
self
.
evoformer
.
_forward_
list
(
m
,
z
,
s
=
self
.
evoformer
.
_forward_
offload
(
input_tensors
,
input_tensors
,
msa_mask
=
msa_mask
.
to
(
dtype
=
input_tensors
[
0
].
dtype
),
msa_mask
=
msa_mask
.
to
(
dtype
=
input_tensors
[
0
].
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
input_tensors
[
1
].
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
input_tensors
[
1
].
dtype
),
...
@@ -357,6 +379,17 @@ class AlphaFold(nn.Module):
...
@@ -357,6 +379,17 @@ class AlphaFold(nn.Module):
)
)
del
input_tensors
del
input_tensors
else
:
m
,
z
,
s
=
self
.
evoformer
(
m
,
z
,
msa_mask
=
msa_mask
.
to
(
dtype
=
m
.
dtype
),
pair_mask
=
pair_mask
.
to
(
dtype
=
z
.
dtype
),
chunk_size
=
self
.
globals
.
chunk_size
,
use_lma
=
self
.
globals
.
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
self
.
config
.
_mask_trans
,
)
outputs
[
"msa"
]
=
m
[...,
:
n_seq
,
:,
:]
outputs
[
"msa"
]
=
m
[...,
:
n_seq
,
:,
:]
outputs
[
"pair"
]
=
z
outputs
[
"pair"
]
=
z
...
@@ -369,6 +402,7 @@ class AlphaFold(nn.Module):
...
@@ -369,6 +402,7 @@ class AlphaFold(nn.Module):
outputs
,
outputs
,
feats
[
"aatype"
],
feats
[
"aatype"
],
mask
=
feats
[
"seq_mask"
].
to
(
dtype
=
s
.
dtype
),
mask
=
feats
[
"seq_mask"
].
to
(
dtype
=
s
.
dtype
),
inplace_safe
=
inplace_safe
,
_offload_inference
=
self
.
globals
.
offload_inference
,
_offload_inference
=
self
.
globals
.
offload_inference
,
)
)
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
outputs
[
"final_atom_positions"
]
=
atom14_to_atom37
(
...
...
openfold/model/msa.py
View file @
143ba486
...
@@ -117,10 +117,9 @@ class MSAAttention(nn.Module):
...
@@ -117,10 +117,9 @@ class MSAAttention(nn.Module):
def
_prep_inputs
(
self
,
def
_prep_inputs
(
self
,
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
z
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
mask
:
Optional
[
torch
.
Tensor
]
mask
:
Optional
[
torch
.
Tensor
],
inplace_safe
:
bool
=
False
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
_inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
n_seq
,
n_res
=
m
.
shape
[
-
3
:
-
1
]
n_seq
,
n_res
=
m
.
shape
[
-
3
:
-
1
]
if
mask
is
None
:
if
mask
is
None
:
# [*, N_seq, N_res]
# [*, N_seq, N_res]
...
@@ -163,6 +162,7 @@ class MSAAttention(nn.Module):
...
@@ -163,6 +162,7 @@ class MSAAttention(nn.Module):
mask
:
Optional
[
torch
.
Tensor
],
mask
:
Optional
[
torch
.
Tensor
],
chunk_logits
:
int
,
chunk_logits
:
int
,
checkpoint
:
bool
,
checkpoint
:
bool
,
inplace_safe
:
bool
=
False
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
MSA attention with training-time chunking of the softmax computation.
MSA attention with training-time chunking of the softmax computation.
...
@@ -172,7 +172,9 @@ class MSAAttention(nn.Module):
...
@@ -172,7 +172,9 @@ class MSAAttention(nn.Module):
MSA_DIM
=
-
4
MSA_DIM
=
-
4
def
_get_qkv
(
m
,
z
):
def
_get_qkv
(
m
,
z
):
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
,
inplace_safe
=
inplace_safe
)
q
,
k
,
v
=
self
.
mha
.
_prep_qkv
(
m
,
m
)
q
,
k
,
v
=
self
.
mha
.
_prep_qkv
(
m
,
m
)
return
m
,
q
,
k
,
v
,
mask_bias
,
z
return
m
,
q
,
k
,
v
,
mask_bias
,
z
...
@@ -208,6 +210,7 @@ class MSAAttention(nn.Module):
...
@@ -208,6 +210,7 @@ class MSAAttention(nn.Module):
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_chunk_logits
:
Optional
[
int
]
=
None
,
_chunk_logits
:
Optional
[
int
]
=
None
,
_checkpoint_chunks
:
Optional
[
bool
]
=
None
,
_checkpoint_chunks
:
Optional
[
bool
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -229,10 +232,14 @@ class MSAAttention(nn.Module):
...
@@ -229,10 +232,14 @@ class MSAAttention(nn.Module):
if
(
_chunk_logits
is
not
None
):
if
(
_chunk_logits
is
not
None
):
return
self
.
_chunked_msa_attn
(
return
self
.
_chunked_msa_attn
(
m
=
m
,
z
=
z
,
mask
=
mask
,
m
=
m
,
z
=
z
,
mask
=
mask
,
chunk_logits
=
_chunk_logits
,
checkpoint
=
_checkpoint_chunks
chunk_logits
=
_chunk_logits
,
checkpoint
=
_checkpoint_chunks
,
inplace_safe
=
inplace_safe
,
)
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
)
m
,
mask_bias
,
z
=
self
.
_prep_inputs
(
m
,
z
,
mask
,
inplace_safe
=
inplace_safe
)
biases
=
[
mask_bias
]
biases
=
[
mask_bias
]
if
(
z
is
not
None
):
if
(
z
is
not
None
):
...
...
openfold/model/outer_product_mean.py
View file @
143ba486
...
@@ -97,7 +97,7 @@ class OuterProductMean(nn.Module):
...
@@ -97,7 +97,7 @@ class OuterProductMean(nn.Module):
m
:
torch
.
Tensor
,
m
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -137,7 +137,7 @@ class OuterProductMean(nn.Module):
...
@@ -137,7 +137,7 @@ class OuterProductMean(nn.Module):
norm
=
norm
+
self
.
eps
norm
=
norm
+
self
.
eps
# [*, N_res, N_res, C_z]
# [*, N_res, N_res, C_z]
if
(
_
inplace
):
if
(
inplace
_safe
):
outer
/=
norm
outer
/=
norm
else
:
else
:
outer
=
outer
/
norm
outer
=
outer
/
norm
...
...
openfold/model/structure_module.py
View file @
143ba486
...
@@ -232,6 +232,7 @@ class InvariantPointAttention(nn.Module):
...
@@ -232,6 +232,7 @@ class InvariantPointAttention(nn.Module):
z
:
Optional
[
torch
.
Tensor
],
z
:
Optional
[
torch
.
Tensor
],
r
:
Rigid
,
r
:
Rigid
,
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
inplace_safe
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_offload_inference
:
bool
=
False
,
_z_reference_list
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
_z_reference_list
:
Optional
[
Sequence
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -248,7 +249,6 @@ class InvariantPointAttention(nn.Module):
...
@@ -248,7 +249,6 @@ class InvariantPointAttention(nn.Module):
Returns:
Returns:
[*, N_res, C_s] single representation update
[*, N_res, C_s] single representation update
"""
"""
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
if
(
_offload_inference
and
inplace_safe
):
if
(
_offload_inference
and
inplace_safe
):
z
=
_z_reference_list
z
=
_z_reference_list
else
:
else
:
...
@@ -619,6 +619,7 @@ class StructureModule(nn.Module):
...
@@ -619,6 +619,7 @@ class StructureModule(nn.Module):
evoformer_output_dict
,
evoformer_output_dict
,
aatype
,
aatype
,
mask
=
None
,
mask
=
None
,
inplace_safe
=
False
,
_offload_inference
=
False
,
_offload_inference
=
False
,
):
):
"""
"""
...
@@ -674,6 +675,7 @@ class StructureModule(nn.Module):
...
@@ -674,6 +675,7 @@ class StructureModule(nn.Module):
z
,
z
,
rigids
,
rigids
,
mask
,
mask
,
inplace_safe
=
inplace_safe
,
_offload_inference
=
_offload_inference
,
_offload_inference
=
_offload_inference
,
_z_reference_list
=
z_reference_list
_z_reference_list
=
z_reference_list
)
)
...
...
openfold/model/template.py
View file @
143ba486
...
@@ -201,8 +201,8 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -201,8 +201,8 @@ class TemplatePairStackBlock(nn.Module):
mask
:
torch
.
Tensor
,
mask
:
torch
.
Tensor
,
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
_inplace
:
bool
=
False
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
_attn_chunk_size
:
Optional
[
int
]
=
None
,
):
):
if
(
_attn_chunk_size
is
None
):
if
(
_attn_chunk_size
is
None
):
...
@@ -214,6 +214,7 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -214,6 +214,7 @@ class TemplatePairStackBlock(nn.Module):
single_templates_masks
=
[
single_templates_masks
=
[
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
m
.
unsqueeze
(
-
3
)
for
m
in
torch
.
unbind
(
mask
,
dim
=-
3
)
]
]
for
i
in
range
(
len
(
single_templates
)):
for
i
in
range
(
len
(
single_templates
)):
single
=
single_templates
[
i
]
single
=
single_templates
[
i
]
single_mask
=
single_templates_masks
[
i
]
single_mask
=
single_templates_masks
[
i
]
...
@@ -225,9 +226,10 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -225,9 +226,10 @@ class TemplatePairStackBlock(nn.Module):
chunk_size
=
_attn_chunk_size
,
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
mask
=
single_mask
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
)
),
),
_
inplace
,
inplace
_safe
,
)
)
single
=
add
(
single
,
single
=
add
(
single
,
...
@@ -237,18 +239,19 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -237,18 +239,19 @@ class TemplatePairStackBlock(nn.Module):
chunk_size
=
_attn_chunk_size
,
chunk_size
=
_attn_chunk_size
,
mask
=
single_mask
,
mask
=
single_mask
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
)
),
),
_
inplace
,
inplace
_safe
,
)
)
tmu_update
=
self
.
tri_mul_out
(
tmu_update
=
self
.
tri_mul_out
(
single
,
single
,
mask
=
single_mask
,
mask
=
single_mask
,
_
inplace
=
_inplace
,
inplace_
safe
=
inplace
_safe
,
_add_with_inplace
=
True
,
_add_with_inplace
=
True
,
)
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
else
:
single
=
tmu_update
single
=
tmu_update
...
@@ -258,10 +261,10 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -258,10 +261,10 @@ class TemplatePairStackBlock(nn.Module):
tmu_update
=
self
.
tri_mul_in
(
tmu_update
=
self
.
tri_mul_in
(
single
,
single
,
mask
=
single_mask
,
mask
=
single_mask
,
_
inplace
=
_inplace
,
inplace_
safe
=
inplace
_safe
,
_add_with_inplace
=
True
,
_add_with_inplace
=
True
,
)
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single
=
single
+
self
.
dropout_row
(
tmu_update
)
single
=
single
+
self
.
dropout_row
(
tmu_update
)
else
:
else
:
single
=
tmu_update
single
=
tmu_update
...
@@ -274,13 +277,13 @@ class TemplatePairStackBlock(nn.Module):
...
@@ -274,13 +277,13 @@ class TemplatePairStackBlock(nn.Module):
mask
=
single_mask
if
_mask_trans
else
None
,
mask
=
single_mask
if
_mask_trans
else
None
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
),
),
_
inplace
,
inplace
_safe
,
)
)
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
single_templates
[
i
]
=
single
single_templates
[
i
]
=
single
if
(
not
_
inplace
):
if
(
not
inplace
_safe
):
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
z
=
torch
.
cat
(
single_templates
,
dim
=-
4
)
return
z
return
z
...
@@ -352,6 +355,7 @@ class TemplatePairStack(nn.Module):
...
@@ -352,6 +355,7 @@ class TemplatePairStack(nn.Module):
mask
:
torch
.
tensor
,
mask
:
torch
.
tensor
,
chunk_size
:
int
,
chunk_size
:
int
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
_mask_trans
:
bool
=
True
,
_mask_trans
:
bool
=
True
,
):
):
"""
"""
...
@@ -374,13 +378,14 @@ class TemplatePairStack(nn.Module):
...
@@ -374,13 +378,14 @@ class TemplatePairStack(nn.Module):
mask
=
mask
,
mask
=
mask
,
chunk_size
=
chunk_size
,
chunk_size
=
chunk_size
,
use_lma
=
use_lma
,
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
_mask_trans
=
_mask_trans
,
_mask_trans
=
_mask_trans
,
_inplace
=
not
(
self
.
training
or
torch
.
is_grad_enabled
()),
)
)
for
b
in
self
.
blocks
for
b
in
self
.
blocks
]
]
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
if
(
chunk_size
is
not
None
and
self
.
chunk_size_tuner
is
not
None
):
assert
(
not
self
.
training
)
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
tuned_chunk_size
=
self
.
chunk_size_tuner
.
tune_chunk_size
(
representative_fn
=
blocks
[
0
],
representative_fn
=
blocks
[
0
],
args
=
(
t
.
clone
(),),
args
=
(
t
.
clone
(),),
...
@@ -411,6 +416,7 @@ def embed_templates_offload(
...
@@ -411,6 +416,7 @@ def embed_templates_offload(
pair_mask
,
pair_mask
,
templ_dim
,
templ_dim
,
template_chunk_size
=
256
,
template_chunk_size
=
256
,
inplace_safe
=
False
,
):
):
"""
"""
Args:
Args:
...
@@ -435,8 +441,6 @@ def embed_templates_offload(
...
@@ -435,8 +441,6 @@ def embed_templates_offload(
offloads the large template pair tensor to CPU. Slower but more frugal
offloads the large template pair tensor to CPU. Slower but more frugal
with GPU memory than the original. Useful for long-sequence inference.
with GPU memory than the original. Useful for long-sequence inference.
"""
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds_cpu
=
[]
pair_embeds_cpu
=
[]
n
=
z
.
shape
[
-
2
]
n
=
z
.
shape
[
-
2
]
...
@@ -519,6 +523,7 @@ def embed_templates_average(
...
@@ -519,6 +523,7 @@ def embed_templates_average(
pair_mask
,
pair_mask
,
templ_dim
,
templ_dim
,
templ_group_size
=
2
,
templ_group_size
=
2
,
inplace_safe
=
False
,
):
):
"""
"""
Args:
Args:
...
@@ -547,8 +552,6 @@ def embed_templates_average(
...
@@ -547,8 +552,6 @@ def embed_templates_average(
embedding, while its low memory footprint allows the number of templates
embedding, while its low memory footprint allows the number of templates
to scale almost indefinitely.
to scale almost indefinitely.
"""
"""
inplace_safe
=
not
(
model
.
training
or
torch
.
is_grad_enabled
())
# Embed the templates one at a time (with a poor man's vmap)
# Embed the templates one at a time (with a poor man's vmap)
n
=
z
.
shape
[
-
2
]
n
=
z
.
shape
[
-
2
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
n_templ
=
batch
[
"template_aatype"
].
shape
[
templ_dim
]
...
...
openfold/model/triangular_attention.py
View file @
143ba486
...
@@ -64,6 +64,7 @@ class TriangleAttention(nn.Module):
...
@@ -64,6 +64,7 @@ class TriangleAttention(nn.Module):
chunk_size
:
int
,
chunk_size
:
int
,
use_memory_efficient_kernel
:
bool
=
False
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"triangle! triangle!"
"triangle! triangle!"
mha_inputs
=
{
mha_inputs
=
{
...
@@ -72,8 +73,6 @@ class TriangleAttention(nn.Module):
...
@@ -72,8 +73,6 @@ class TriangleAttention(nn.Module):
"biases"
:
biases
,
"biases"
:
biases
,
}
}
inplace_safe
=
not
(
self
.
training
or
torch
.
is_grad_enabled
())
return
chunk_layer
(
return
chunk_layer
(
partial
(
partial
(
self
.
mha
,
self
.
mha
,
...
@@ -92,6 +91,7 @@ class TriangleAttention(nn.Module):
...
@@ -92,6 +91,7 @@ class TriangleAttention(nn.Module):
chunk_size
:
Optional
[
int
]
=
None
,
chunk_size
:
Optional
[
int
]
=
None
,
use_memory_efficient_kernel
:
bool
=
False
,
use_memory_efficient_kernel
:
bool
=
False
,
use_lma
:
bool
=
False
,
use_lma
:
bool
=
False
,
inplace_safe
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
Args:
Args:
...
@@ -130,7 +130,8 @@ class TriangleAttention(nn.Module):
...
@@ -130,7 +130,8 @@ class TriangleAttention(nn.Module):
biases
,
biases
,
chunk_size
,
chunk_size
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_memory_efficient_kernel
=
use_memory_efficient_kernel
,
use_lma
=
use_lma
use_lma
=
use_lma
,
inplace_safe
=
inplace_safe
,
)
)
else
:
else
:
x
=
self
.
mha
(
x
=
self
.
mha
(
...
...
openfold/model/triangular_multiplicative_update.py
View file @
143ba486
...
@@ -357,7 +357,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -357,7 +357,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
def
forward
(
self
,
def
forward
(
self
,
z
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
mask
:
Optional
[
torch
.
Tensor
]
=
None
,
_
inplace
:
bool
=
False
,
inplace
_safe
:
bool
=
False
,
_add_with_inplace
:
bool
=
False
,
_add_with_inplace
:
bool
=
False
,
_inplace_chunk_size
:
Optional
[
int
]
=
256
,
_inplace_chunk_size
:
Optional
[
int
]
=
256
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
...
@@ -370,7 +370,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
...
@@ -370,7 +370,7 @@ class TriangleMultiplicativeUpdate(nn.Module):
Returns:
Returns:
[*, N_res, N_res, C_z] output tensor
[*, N_res, N_res, C_z] output tensor
"""
"""
if
(
_
inplace
):
if
(
inplace
_safe
):
x
=
self
.
_inference_forward
(
x
=
self
.
_inference_forward
(
z
,
z
,
mask
,
mask
,
...
...
openfold/utils/chunk_utils.py
View file @
143ba486
...
@@ -415,8 +415,6 @@ class ChunkSizeTuner:
...
@@ -415,8 +415,6 @@ class ChunkSizeTuner:
# Otherwise, we can reuse the precomputed value
# Otherwise, we can reuse the precomputed value
consistent
=
False
consistent
=
False
print
(
consistent
)
if
(
not
consistent
):
if
(
not
consistent
):
self
.
cached_chunk_size
=
self
.
_determine_favorable_chunk_size
(
self
.
cached_chunk_size
=
self
.
_determine_favorable_chunk_size
(
representative_fn
,
representative_fn
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment