OpenDAS / OpenFold · Commits

Commit 5aa21eae, authored Jun 21, 2022 by Gustaf Ahdritz
Move chunking code, add memory optimizations
Parent: fe9ad07e

Showing 11 changed files with 565 additions and 441 deletions (+565 −441)
Changed files:

    openfold/model/embedders.py                            +19    -5
    openfold/model/model.py                                +34   -19
    openfold/model/msa.py                                  +36   -18
    openfold/model/outer_product_mean.py                    +1    -1
    openfold/model/pair_transition.py                       +4    -4
    openfold/model/primitives.py                            +1    -1
    openfold/model/template.py                              +5    -3
    openfold/model/triangular_attention.py                 +33   -16
    openfold/model/triangular_multiplicative_update.py      +2    -1
    openfold/utils/chunk_utils.py                         +428    -0
    openfold/utils/tensor_utils.py                          +2  -373
openfold/model/embedders.py

@@ -81,9 +81,14 @@ class InputEmbedder(nn.Module):
         d = ri[..., None] - ri[..., None, :]
         boundaries = torch.arange(
             start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
         )
-        oh = one_hot(d, boundaries).type(ri.dtype)
-        return self.linear_relpos(oh)
+        reshaped_bins = boundaries.view(((1,) * len(d.shape)) + (len(boundaries),))
+        d = d[..., None] - reshaped_bins
+        d = torch.abs(d)
+        d = torch.argmin(d, dim=-1)
+        d = nn.functional.one_hot(d, num_classes=len(boundaries)).float()
+        d = d.to(ri.dtype)
+        return self.linear_relpos(d)

     def forward(
         self,
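The new relpos path replaces the one_hot helper with explicit binning: every pairwise residue offset is compared against the full set of relative-position bins and assigned to the nearest one via argmin, then one-hot encoded. A minimal standalone sketch of that binning step in plain PyTorch (relpos_k and ri are example values standing in for self.relpos_k and the residue_index feature):

    import torch

    relpos_k = 32
    ri = torch.arange(5)  # [N_res] residue indices

    # [N_res, N_res] pairwise offsets
    d = ri[..., None] - ri[..., None, :]

    # [2 * relpos_k + 1] bin centers
    boundaries = torch.arange(start=-relpos_k, end=relpos_k + 1)

    # Assign each offset to its nearest bin and one-hot encode the result,
    # mirroring the argmin-based replacement for the old one_hot helper.
    reshaped_bins = boundaries.view((1,) * len(d.shape) + (len(boundaries),))
    nearest = torch.argmin(torch.abs(d[..., None] - reshaped_bins), dim=-1)
    oh = torch.nn.functional.one_hot(nearest, num_classes=len(boundaries)).float()

    print(oh.shape)  # torch.Size([5, 5, 65])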
@@ -106,13 +111,22 @@ class InputEmbedder(nn.Module):
             [*, N_res, N_res, C_z] pair embedding
         """
+        inplace_safe = not (self.training or torch.is_grad_enabled())
+
         # [*, N_res, c_z]
         tf_emb_i = self.linear_tf_z_i(tf)
         tf_emb_j = self.linear_tf_z_j(tf)

         # [*, N_res, N_res, c_z]
-        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
-        pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
+        pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
+        pair_emb = add(pair_emb, tf_emb_i[..., None, :], inplace=inplace_safe)
+        pair_emb = add(pair_emb, tf_emb_j[..., None, :, :], inplace=inplace_safe)

         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
openfold/model/model.py

@@ -263,22 +263,29 @@ class AlphaFold(nn.Module):
             feats["aatype"], x_prev, None
         ).to(dtype=z.dtype)

+        if (self.globals.offload_inference and inplace_safe):
+            m = m.cpu()
+            z = z.cpu()
+
         # m_1_prev_emb: [*, N, C_m]
         # z_prev_emb: [*, N, N, C_z]
         m_1_prev_emb, z_prev_emb = self.recycling_embedder(
             m_1_prev,
             z_prev,
             x_prev,
-            _inplace=not (self.training or torch.is_grad_enabled()),
+            _inplace=inplace_safe,
         )

+        if (self.globals.offload_inference and inplace_safe):
+            m = m.to(m_1_prev_emb.device)
+            z = z.to(z_prev.device)
+
         # [*, S_c, N, C_m]
         m[..., 0, :, :] += m_1_prev_emb

         # [*, N, N, C_z]
-        z += z_prev_emb
+        z = add(z, z_prev_emb, inplace=inplace_safe)

+        # This matters during inference with large N
         del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb

         # Embed the templates + merge with MSA/pair embeddings
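Both changes in this hunk follow the same pattern: when neither training mode nor autograd is active, the pair update is written in place through the add helper, and the large MSA/pair activations can be parked on the CPU while the recycling embedder runs. A rough sketch of the pattern, assuming an add(m1, m2, inplace) helper with roughly these semantics (the real one lives in openfold.utils.tensor_utils and is shown as context later in this diff):

    import torch

    def add(m1: torch.Tensor, m2: torch.Tensor, inplace: bool) -> torch.Tensor:
        # In-place accumulation avoids materializing a second [*, N, N, C_z]
        # tensor; only safe when no autograd graph needs the original value.
        if inplace:
            m1 += m2
            return m1
        return m1 + m2

    inplace_safe = not torch.is_grad_enabled()  # the commit also checks self.training

    device = "cuda" if torch.cuda.is_available() else "cpu"
    z = torch.zeros(4, 4, 8, device=device)
    z_prev_emb = torch.randn(4, 4, 8, device=device)

    # Hypothetical offload step: keep z on the CPU while a memory-hungry
    # module runs, then move it back before it is needed again.
    z = z.cpu()
    # ... run the memory-hungry module here ...
    z = z.to(device)

    z = add(z, z_prev_emb, inplace=inplace_safe)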
@@ -317,44 +324,52 @@ class AlphaFold(nn.Module):
         if self.config.extra_msa.enabled:
             # [*, S_e, N, C_e]
             a = self.extra_msa_embedder(build_extra_msa_feat(feats))

+            input_tensors = [a, z]
+            del a, z
+
             # [*, N, N, C_z]
-            z = self.extra_msa_stack(
-                a,
-                z,
-                msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
+            z = self.extra_msa_stack._forward_list(
+                input_tensors,
+                msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
                 chunk_size=self.globals.chunk_size,
                 use_lma=self.globals.use_lma,
-                pair_mask=pair_mask.to(dtype=z.dtype),
+                pair_mask=pair_mask.to(dtype=m.dtype),
                 _mask_trans=self.config._mask_trans,
+                _offload_inference=self.globals.offload_inference,
             )
-            del a
+            del input_tensors

         # Run MSA + pair embeddings through the trunk of the network
         # m: [*, S, N, C_m]
         # z: [*, N, N, C_z]
         # s: [*, N, C_s]
-        m, z, s = self.evoformer(
-            m,
-            z,
-            msa_mask=msa_mask.to(dtype=m.dtype),
-            pair_mask=pair_mask.to(dtype=z.dtype),
+        input_tensors = [m, z]
+        del m, z
+
+        m, z, s = self.evoformer._forward_list(
+            input_tensors,
+            msa_mask=msa_mask.to(dtype=input_tensors[0].dtype),
+            pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
             chunk_size=self.globals.chunk_size,
            use_lma=self.globals.use_lma,
            _mask_trans=self.config._mask_trans,
         )
+        del input_tensors

         outputs["msa"] = m[..., :n_seq, :, :]
         outputs["pair"] = z
         outputs["single"] = s

+        del z
+
         # Predict 3D structure
         outputs["sm"] = self.structure_module(
-            s,
-            z,
+            outputs,
             feats["aatype"],
             mask=feats["seq_mask"].to(dtype=s.dtype),
+            _offload_inference=self.globals.offload_inference,
         )
         outputs["final_atom_positions"] = atom14_to_atom37(
             outputs["sm"]["positions"][-1], feats
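The extra-MSA stack and the Evoformer now receive their activations as a single input_tensors list, and the caller deletes its own m/z names immediately. Because the list is then the only remaining owner, the stack can offload or overwrite those tensors chunk by chunk without a stale reference keeping the full-size copies alive. A toy illustration of the ownership idea (stack_forward is a hypothetical stand-in, not the Evoformer API):

    import torch

    def stack_forward(input_tensors):
        # The callee may freely replace or offload the list entries; the
        # caller holds no other reference to them.
        for i, t in enumerate(input_tensors):
            input_tensors[i] = t * 2  # stand-in for one block of the stack
        return input_tensors[0], input_tensors[1]

    m = torch.randn(8, 16, 32)
    z = torch.randn(16, 16, 64)

    input_tensors = [m, z]
    del m, z                  # drop the caller's references, as in the commit

    m, z = stack_forward(input_tensors)
    del input_tensors         # nothing else pins the intermediate buffers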
@@ -368,7 +383,7 @@ class AlphaFold(nn.Module):
...
@@ -368,7 +383,7 @@ class AlphaFold(nn.Module):
m_1_prev
=
m
[...,
0
,
:,
:]
m_1_prev
=
m
[...,
0
,
:,
:]
# [*, N, N, C_z]
# [*, N, N, C_z]
z_prev
=
z
z_prev
=
outputs
[
"pair"
]
# [*, N, 3]
# [*, N, 3]
x_prev
=
outputs
[
"final_atom_positions"
]
x_prev
=
outputs
[
"final_atom_positions"
]
...
...
openfold/model/msa.py

@@ -26,8 +26,8 @@ from openfold.model.primitives import (
     _attention_chunked_trainable,
 )
 from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.chunk_utils import chunk_layer
 from openfold.utils.tensor_utils import (
-    chunk_layer,
     permute_final_dims,
     flatten_final_dims,
 )

@@ -94,16 +94,20 @@ class MSAAttention(nn.Module):
         use_memory_efficient_kernel: bool,
         use_lma: bool,
     ) -> torch.Tensor:
-        mha = partial(
-            self.mha,
-            use_memory_efficient_kernel=use_memory_efficient_kernel,
-            use_lma=use_lma,
-        )
+        def fn(m, biases):
+            m = self.layer_norm_m(m)
+            return self.mha(
+                q_x=m,
+                kv_x=m,
+                biases=biases,
+                use_memory_efficient_kernel=use_memory_efficient_kernel,
+                use_lma=use_lma,
+            )

         return chunk_layer(
-            mha,
+            fn,
             {
-                "q_x": m,
-                "kv_x": m,
+                "m": m,
                 "biases": biases,
             },
             chunk_size=chunk_size,

@@ -115,9 +119,8 @@ class MSAAttention(nn.Module):
         z: Optional[torch.Tensor],
         mask: Optional[torch.Tensor]
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        # [*, N_seq, N_res, C_m]
-        m = self.layer_norm_m(m)
+        _inplace_safe = not (self.training or torch.is_grad_enabled())

         n_seq, n_res = m.shape[-3:-1]
         if mask is None:
             # [*, N_seq, N_res]

@@ -133,11 +136,20 @@ class MSAAttention(nn.Module):
             self.layer_norm_z is not None and  # benefit of
             self.linear_z is not None          # TorchScript
         ):
-            # [*, N_res, N_res, C_z]
-            z = self.layer_norm_z(z)
-
-            # [*, N_res, N_res, no_heads]
-            z = self.linear_z(z)
+            chunks = []
+            for i in range(0, z.shape[-3], 256):
+                z_chunk = z[..., i: i + 256, :, :]
+
+                # [*, N_res, N_res, C_z]
+                z_chunk = self.layer_norm_z(z_chunk)
+
+                # [*, N_res, N_res, no_heads]
+                z_chunk = self.linear_z(z_chunk)
+
+                chunks.append(z_chunk)
+
+            z = torch.cat(chunks, dim=-3)

             # [*, 1, no_heads, N_res, N_res]
             z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

@@ -376,8 +388,13 @@ class MSAColumnGlobalAttention(nn.Module):
             "m": m,
             "mask": mask,
         }

+        def fn(m, mask):
+            m = self.layer_norm_m(m)
+            return self.global_attention(m, mask, use_lma=use_lma)
+
         return chunk_layer(
-            partial(self.global_attention, use_lma=use_lma),
+            fn,
             mha_input,
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2]),

@@ -405,11 +422,12 @@ class MSAColumnGlobalAttention(nn.Module):
             mask = mask.transpose(-1, -2)

         # [*, N_res, N_seq, C_in]
-        m = self.layer_norm_m(m)
+        # m = self.layer_norm_m(m)

         if chunk_size is not None:
             m = self._chunk(m, mask, chunk_size, use_lma=use_lma)
         else:
+            m = self.layer_norm_m(m)
             m = self.global_attention(m=m, mask=mask, use_lma=use_lma)

         # [*, N_seq, N_res, C_in]
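Several of the msa.py changes move self.layer_norm_m from the callers into the function handed to chunk_layer, so normalization is computed on one chunk at a time instead of on the full [*, N_seq, N_res, C_m] tensor. A self-contained sketch of the idea with a generic chunked-apply loop (the real code uses chunk_layer from openfold.utils.chunk_utils; layer_norm and attention_stub here are toy stand-ins):

    import torch
    from torch import nn

    layer_norm = nn.LayerNorm(32)
    attention_stub = nn.Linear(32, 32)  # stand-in for self.mha / global_attention

    def fn(m: torch.Tensor) -> torch.Tensor:
        # Normalizing inside the chunked function keeps the peak activation
        # at chunk_size rows rather than the whole MSA.
        m = layer_norm(m)
        return attention_stub(m)

    def chunked_apply(fn, m: torch.Tensor, chunk_size: int) -> torch.Tensor:
        out = m.new_empty(m.shape)
        for i in range(0, m.shape[0], chunk_size):
            out[i:i + chunk_size] = fn(m[i:i + chunk_size])
        return out

    m = torch.randn(128, 64, 32)  # [N_seq, N_res, C_m]
    with torch.no_grad():
        out = chunked_apply(fn, m, chunk_size=16)
    print(out.shape)  # torch.Size([128, 64, 32])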
openfold/model/outer_product_mean.py

@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear
-from openfold.utils.tensor_utils import chunk_layer
+from openfold.utils.chunk_utils import chunk_layer


 class OuterProductMean(nn.Module):
openfold/model/pair_transition.py

@@ -18,7 +18,7 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear, LayerNorm
-from openfold.utils.tensor_utils import chunk_layer
+from openfold.utils.chunk_utils import chunk_layer


 class PairTransition(nn.Module):

@@ -46,6 +46,9 @@ class PairTransition(nn.Module):
         self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")

     def _transition(self, z, mask):
+        # [*, N_res, N_res, C_z]
+        z = self.layer_norm(z)
+
         # [*, N_res, N_res, C_hidden]
         z = self.linear_1(z)
         z = self.relu(z)

@@ -88,9 +91,6 @@ class PairTransition(nn.Module):
         # [*, N_res, N_res, 1]
         mask = mask.unsqueeze(-1)

-        # [*, N_res, N_res, C_z]
-        z = self.layer_norm(z)
-
         if chunk_size is not None:
             z = self._chunk(z, mask, chunk_size)
         else:
openfold/model/primitives.py

@@ -23,11 +23,11 @@ import torch.nn as nn
 from scipy.stats import truncnorm

 from openfold.utils.checkpointing import get_checkpoint_fn
+from openfold.utils.chunk_utils import _chunk_slice
 from openfold.utils.kernel.attention_core import attention_core
 from openfold.utils.tensor_utils import (
     permute_final_dims,
     flatten_final_dims,
-    _chunk_slice,
 )
openfold/model/template.py

@@ -34,14 +34,16 @@ from openfold.model.triangular_multiplicative_update import (
     TriangleMultiplicationIncoming,
 )
 from openfold.utils.checkpointing import checkpoint_blocks
+from openfold.utils.chunk_utils import (
+    chunk_layer,
+    ChunkSizeTuner,
+)
 from openfold.utils.feats import (
     build_template_angle_feat,
     build_template_pair_feat,
 )
 from openfold.utils.tensor_utils import (
     add,
-    chunk_layer,
-    ChunkSizeTuner,
     permute_final_dims,
     flatten_final_dims,
     tensor_tree_map,

@@ -381,7 +383,7 @@ class TemplatePairStack(nn.Module):
         if (chunk_size is not None and self.chunk_size_tuner is not None):
             tuned_chunk_size = self.chunk_size_tuner.tune_chunk_size(
                 representative_fn=blocks[0],
-                args=(t,),
+                args=(t.clone(),),
                 min_chunk_size=chunk_size,
             )
             blocks = [
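Aside from the import move, the functional change here is that the tensor handed to ChunkSizeTuner.tune_chunk_size is now cloned. The tuner actually executes the representative block while probing candidate chunk sizes, so without the clone those trial runs could disturb t before the real forward pass. A hedged usage sketch with a toy block (the tuner class itself is reproduced in chunk_utils.py below):

    import torch
    from openfold.utils.chunk_utils import ChunkSizeTuner

    def block(t: torch.Tensor, chunk_size: int = 4) -> torch.Tensor:
        # Stand-in for blocks[0] of the TemplatePairStack; a real block may
        # update t in place during inference.
        return t + 1.0

    tuner = ChunkSizeTuner(max_chunk_size=64)
    t = torch.randn(2, 16, 16, 8)

    # Pass a clone so the tuner's trial invocations cannot touch the tensor
    # the real forward pass is about to consume.
    tuned = tuner.tune_chunk_size(
        representative_fn=block,
        args=(t.clone(),),
        min_chunk_size=4,
    )
    print(tuned)  # largest candidate chunk size that ran without a RuntimeError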
openfold/model/triangular_attention.py

@@ -21,8 +21,8 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear, LayerNorm, Attention
+from openfold.utils.chunk_utils import chunk_layer
 from openfold.utils.tensor_utils import (
-    chunk_layer,
     permute_final_dims,
     flatten_final_dims,
 )

@@ -30,7 +30,7 @@ from openfold.utils.tensor_utils import (
 class TriangleAttention(nn.Module):
     def __init__(
-        self, c_in, c_hidden, no_heads, starting, inf=1e9
+        self, c_in, c_hidden, no_heads, starting=True, inf=1e9
     ):
         """
         Args:

@@ -62,24 +62,35 @@ class TriangleAttention(nn.Module):
         x: torch.Tensor,
         biases: List[torch.Tensor],
         chunk_size: int,
+        use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
     ) -> torch.Tensor:
+        "triangle! triangle!"
         mha_inputs = {
             "q_x": x,
             "kv_x": x,
             "biases": biases,
         }

+        inplace_safe = not (self.training or torch.is_grad_enabled())
+
         return chunk_layer(
-            partial(self.mha, use_lma=use_lma),
+            partial(self.mha, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma),
             mha_inputs,
             chunk_size=chunk_size,
             no_batch_dims=len(x.shape[:-2]),
+            _out=x if inplace_safe else None,
         )

     def forward(
         self,
         x: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
+        use_memory_efficient_kernel: bool = False,
         use_lma: bool = False,
     ) -> torch.Tensor:
         """

@@ -88,15 +99,14 @@ class TriangleAttention(nn.Module):
             [*, I, J, C_in] input tensor (e.g. the pair representation)
         Returns:
             [*, I, J, C_in] output tensor
         """
         if mask is None:
             # [*, I, J]
             mask = x.new_ones(
                 x.shape[:-1],
             )

-        # Shape annotations assume self.starting. Else, I and J are flipped
-        if not self.starting:
+        if (not self.starting):
             x = x.transpose(-2, -3)
             mask = mask.transpose(-1, -2)

@@ -115,27 +125,34 @@ class TriangleAttention(nn.Module):
         biases = [mask_bias, triangle_bias]

         if chunk_size is not None:
-            x = self._chunk(x, biases, chunk_size, use_lma=use_lma)
+            x = self._chunk(x, biases, chunk_size, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma)
         else:
-            x = self.mha(q_x=x, kv_x=x, biases=biases, use_lma=use_lma)
+            x = self.mha(q_x=x, kv_x=x, biases=biases, use_memory_efficient_kernel=use_memory_efficient_kernel, use_lma=use_lma)

-        if not self.starting:
+        if (not self.starting):
             x = x.transpose(-2, -3)

         return x


-class TriangleAttentionStartingNode(TriangleAttention):
-    """
-    Implements Algorithm 13.
-    """
-    __init__ = partialmethod(TriangleAttention.__init__, starting=True)
+# Implements Algorithm 13
+TriangleAttentionStartingNode = TriangleAttention


 class TriangleAttentionEndingNode(TriangleAttention):
     """
     Implements Algorithm 14.
     """
     __init__ = partialmethod(TriangleAttention.__init__, starting=False)
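With starting=True now the default, TriangleAttentionStartingNode can be a plain alias of TriangleAttention, while the ending-node variant keeps its partialmethod binding. A small illustration of the two patterns on a dummy class (not the OpenFold classes themselves):

    from functools import partialmethod

    class TriangleLike:
        def __init__(self, c_in: int, starting: bool = True):
            self.c_in = c_in
            self.starting = starting

    # Default arguments make the "starting" flavor a simple alias...
    StartingNode = TriangleLike

    # ...while the "ending" flavor pre-binds the flag via partialmethod.
    class EndingNode(TriangleLike):
        __init__ = partialmethod(TriangleLike.__init__, starting=False)

    print(StartingNode(8).starting)  # True
    print(EndingNode(8).starting)    # False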
openfold/model/triangular_multiplicative_update.py

@@ -20,7 +20,8 @@ import torch
 import torch.nn as nn

 from openfold.model.primitives import Linear, LayerNorm
-from openfold.utils.tensor_utils import add, chunk_layer, permute_final_dims
+from openfold.utils.chunk_utils import chunk_layer
+from openfold.utils.tensor_utils import add, permute_final_dims


 class TriangleMultiplicativeUpdate(nn.Module):
openfold/utils/chunk_utils.py  (new file, mode 0 → 100644)
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional

import torch

from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)


def _fetch_dims(tree):
    shapes = []
    tree_type = type(tree)
    if tree_type is dict:
        for v in tree.values():
            shapes.extend(_fetch_dims(v))
    elif tree_type is list or tree_type is tuple:
        for t in tree:
            shapes.extend(_fetch_dims(t))
    elif tree_type is torch.Tensor:
        shapes.append(tree.shape)
    else:
        raise ValueError("Not supported")

    return shapes


@torch.jit.ignore
def _flat_idx_to_idx(
    flat_idx: int,
    dims: Tuple[int],
) -> Tuple[int]:
    idx = []
    for d in reversed(dims):
        idx.append(flat_idx % d)
        flat_idx = flat_idx // d

    return tuple(reversed(idx))


@torch.jit.ignore
def _get_minimal_slice_set(
    start: Sequence[int],
    end: Sequence[int],
    dims: int,
    start_edges: Optional[Sequence[bool]] = None,
    end_edges: Optional[Sequence[bool]] = None,
) -> Sequence[Tuple[int]]:
    """
    Produces an ordered sequence of tensor slices that, when used in
    sequence on a tensor with shape dims, yields tensors that contain every
    leaf in the contiguous range [start, end]. Care is taken to yield a
    short sequence of slices, and perhaps even the shortest possible (I'm
    pretty sure it's the latter).

    end is INCLUSIVE.
    """
    # start_edges and end_edges both indicate whether, starting from any given
    # dimension, the start/end index is at the top/bottom edge of the
    # corresponding tensor, modeled as a tree
    def reduce_edge_list(l):
        tally = 1
        for i in range(len(l)):
            reversed_idx = -1 * (i + 1)
            l[reversed_idx] *= tally
            tally = l[reversed_idx]

    if start_edges is None:
        start_edges = [s == 0 for s in start]
        reduce_edge_list(start_edges)
    if end_edges is None:
        end_edges = [e == (d - 1) for e, d in zip(end, dims)]
        reduce_edge_list(end_edges)

    # Base cases. Either start/end are empty and we're done, or the final,
    # one-dimensional tensor can be simply sliced
    if len(start) == 0:
        return [tuple()]
    elif len(start) == 1:
        return [(slice(start[0], end[0] + 1),)]

    slices = []
    path = []

    # Dimensions common to start and end can be selected directly
    for s, e in zip(start, end):
        if s == e:
            path.append(slice(s, s + 1))
        else:
            break

    path = tuple(path)
    divergence_idx = len(path)

    # start == end, and we're done
    if divergence_idx == len(dims):
        return [tuple(path)]

    def upper():
        sdi = start[divergence_idx]
        return [
            path + (slice(sdi, sdi + 1),) + s
            for s in _get_minimal_slice_set(
                start[divergence_idx + 1:],
                [d - 1 for d in dims[divergence_idx + 1:]],
                dims[divergence_idx + 1:],
                start_edges=start_edges[divergence_idx + 1:],
                end_edges=[1 for _ in end_edges[divergence_idx + 1:]],
            )
        ]

    def lower():
        edi = end[divergence_idx]
        return [
            path + (slice(edi, edi + 1),) + s
            for s in _get_minimal_slice_set(
                [0 for _ in start[divergence_idx + 1:]],
                end[divergence_idx + 1:],
                dims[divergence_idx + 1:],
                start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
                end_edges=end_edges[divergence_idx + 1:],
            )
        ]

    # If both start and end are at the edges of the subtree rooted at
    # divergence_idx, we can just select the whole subtree at once
    if start_edges[divergence_idx] and end_edges[divergence_idx]:
        slices.append(
            path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
        )
    # If just start is at the edge, we can grab almost all of the subtree,
    # treating only the ragged bottom edge as an edge case
    elif start_edges[divergence_idx]:
        slices.append(
            path + (slice(start[divergence_idx], end[divergence_idx]),)
        )
        slices.extend(lower())
    # Analogous to the previous case, but the top is ragged this time
    elif end_edges[divergence_idx]:
        slices.extend(upper())
        slices.append(
            path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
        )
    # If both sides of the range are ragged, we need to handle both sides
    # separately. If there's contiguous meat in between them, we can index it
    # in one big chunk
    else:
        slices.extend(upper())
        middle_ground = end[divergence_idx] - start[divergence_idx]
        if middle_ground > 1:
            slices.append(
                path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
            )
        slices.extend(lower())

    return [tuple(s) for s in slices]


@torch.jit.ignore
def _chunk_slice(
    t: torch.Tensor,
    flat_start: int,
    flat_end: int,
    no_batch_dims: int,
) -> torch.Tensor:
    """
    Equivalent to

        t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]

    but without the need for the initial reshape call, which can be
    memory-intensive in certain situations. The only reshape operations
    in this function are performed on sub-tensors that scale with
    (flat_end - flat_start), the chunk size.
    """
    batch_dims = t.shape[:no_batch_dims]
    start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
    # _get_minimal_slice_set is inclusive
    end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))

    # Get an ordered list of slices to perform
    slices = _get_minimal_slice_set(
        start_idx,
        end_idx,
        batch_dims,
    )

    sliced_tensors = [t[s] for s in slices]

    return torch.cat(
        [s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
    )


def chunk_layer(
    layer: Callable,
    inputs: Dict[str, Any],
    chunk_size: int,
    no_batch_dims: int,
    low_mem: bool = False,
    _out: Any = None,
    _add_into_out: bool = False,
) -> Any:
    """
    Implements the "chunking" procedure described in section 1.11.8.

    Layer outputs and inputs are assumed to be simple "pytrees,"
    consisting only of (arbitrarily nested) lists, tuples, and dicts with
    torch.Tensor leaves.

    Args:
        layer:
            The layer to be applied chunk-wise
        inputs:
            A (non-nested) dictionary of keyworded inputs. All leaves must
            be tensors and must share the same batch dimensions.
        chunk_size:
            The number of sub-batches per chunk. If multiple batch
            dimensions are specified, a "sub-batch" is defined as a single
            indexing of all batch dimensions simultaneously (s.t. the
            number of sub-batches is the product of the batch dimensions).
        no_batch_dims:
            How many of the initial dimensions of each input tensor can
            be considered batch dimensions.
        low_mem:
            Avoids flattening potentially large input tensors. Unnecessary
            in most cases, and is ever so slightly slower than the default
            setting.
    Returns:
        The reassembled output of the layer on the inputs.
    """
    if not (len(inputs) > 0):
        raise ValueError("Must provide at least one input")

    initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

    def _prep_inputs(t):
        if not low_mem:
            if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
                t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
            t = t.reshape(-1, *t.shape[no_batch_dims:])
        else:
            t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
        return t

    prepped_inputs = tensor_tree_map(_prep_inputs, inputs)

    prepped_outputs = None
    if _out is not None:
        reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
        prepped_outputs = tensor_tree_map(reshape_fn, _out)

    flat_batch_dim = 1
    for d in orig_batch_dims:
        flat_batch_dim *= d

    no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)

    i = 0
    out = prepped_outputs
    for _ in range(no_chunks):
        # Chunk the input
        if not low_mem:
            select_chunk = (
                lambda t: t[i: i + chunk_size] if t.shape[0] != 1 else t
            )
        else:
            select_chunk = (
                partial(
                    _chunk_slice,
                    flat_start=i,
                    flat_end=min(flat_batch_dim, i + chunk_size),
                    no_batch_dims=len(orig_batch_dims),
                )
            )

        chunks = tensor_tree_map(select_chunk, prepped_inputs)

        # Run the layer on the chunk
        output_chunk = layer(**chunks)

        # Allocate space for the output
        if out is None:
            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
            out = tensor_tree_map(allocate, output_chunk)

        # Put the chunk in its pre-allocated space
        out_type = type(output_chunk)
        if out_type is dict:
            def assign(d1, d2):
                for k, v in d1.items():
                    if type(v) is dict:
                        assign(v, d2[k])
                    else:
                        if _add_into_out:
                            v[i: i + chunk_size] += d2[k]
                        else:
                            v[i: i + chunk_size] = d2[k]

            assign(out, output_chunk)
        elif out_type is tuple:
            for x1, x2 in zip(out, output_chunk):
                if _add_into_out:
                    x1[i: i + chunk_size] += x2
                else:
                    x1[i: i + chunk_size] = x2
        elif out_type is torch.Tensor:
            if _add_into_out:
                out[i: i + chunk_size] += output_chunk
            else:
                out[i: i + chunk_size] = output_chunk
        else:
            raise ValueError("Not supported")

        i += chunk_size

    reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
    out = tensor_tree_map(reshape, out)

    return out


class ChunkSizeTuner:
    def __init__(
        self,
        # Heuristically, runtimes for most of the modules in the network
        # plateau earlier than this on all GPUs I've run the model on.
        max_chunk_size=256,
    ):
        self.max_chunk_size = max_chunk_size
        self.cached_chunk_size = None
        self.cached_arg_data = None

    def _determine_favorable_chunk_size(self, fn, args, min_chunk_size):
        logging.info("Tuning chunk size...")

        if min_chunk_size >= self.max_chunk_size:
            return min_chunk_size

        candidates = [2**l for l in range(int(math.log(self.max_chunk_size, 2)) + 1)]
        candidates = [c for c in candidates if c > min_chunk_size]
        candidates = [min_chunk_size] + candidates

        def test_chunk_size(chunk_size):
            try:
                with torch.no_grad():
                    fn(*args, chunk_size=chunk_size)
                return True
            except RuntimeError:
                return False

        min_viable_chunk_size_index = 0
        i = len(candidates) - 1
        while i > min_viable_chunk_size_index:
            viable = test_chunk_size(candidates[i])
            if not viable:
                i = (min_viable_chunk_size_index + i) // 2
            else:
                min_viable_chunk_size_index = i
                i = (i + len(candidates) - 1) // 2

        return candidates[min_viable_chunk_size_index]

    def _compare_arg_caches(self, ac1, ac2):
        consistent = True
        for a1, a2 in zip(ac1, ac2):
            assert type(ac1) == type(ac2)
            if type(ac1) is list or type(ac1) is tuple:
                consistent &= self._compare_arg_caches(a1, a2)
            elif type(ac1) is dict:
                a1_items = [
                    v for _, v in sorted(a1.items(), key=lambda x: x[0])
                ]
                a2_items = [
                    v for _, v in sorted(a2.items(), key=lambda x: x[0])
                ]
                consistent &= self._compare_arg_caches(a1_items, a2_items)
            else:
                consistent &= a1 == a2

        return consistent

    def tune_chunk_size(
        self,
        representative_fn: Callable,
        args: Tuple[Any],
        min_chunk_size: int,
    ) -> int:
        consistent = True
        remove_tensors = lambda a: a.shape if type(a) is torch.Tensor else a
        arg_data = tree_map(remove_tensors, args, object)
        if self.cached_arg_data is not None:
            # If args have changed shape/value, we need to re-tune
            assert len(self.cached_arg_data) == len(arg_data)
            consistent = self._compare_arg_caches(self.cached_arg_data, arg_data)
        else:
            # Otherwise, we can reuse the precomputed value
            consistent = False

        print(consistent)

        if not consistent:
            self.cached_chunk_size = self._determine_favorable_chunk_size(
                representative_fn,
                args,
                min_chunk_size,
            )
            self.cached_arg_data = arg_data

        return self.cached_chunk_size
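chunk_layer is the workhorse the rest of the commit wires in: it flattens the batch dimensions of every input, applies the layer to chunk_size sub-batches at a time, and writes each result into pre-allocated (or, via _out, caller-provided) storage. A small usage sketch with a toy layer; note that the module as committed calls partial in its low_mem branch, so functools.partial is assumed to be in scope there (the default low_mem=False path used below does not need it):

    import torch
    from openfold.utils.chunk_utils import chunk_layer

    def toy_layer(q_x: torch.Tensor, kv_x: torch.Tensor) -> torch.Tensor:
        # Stand-in for an attention call; operates on the last two dims only.
        return q_x + kv_x.mean(dim=-2, keepdim=True)

    x = torch.randn(4, 128, 32)  # one batch dimension of size 4

    out = chunk_layer(
        toy_layer,
        {"q_x": x, "kv_x": x},
        chunk_size=1,
        no_batch_dims=1,
    )
    print(out.shape)  # torch.Size([4, 128, 32])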
openfold/utils/tensor_utils.py

@@ -15,10 +15,10 @@
 from functools import partial
 import logging
-import math
-from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional

 import torch
 import torch.nn as nn
+from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional


 def add(m1, m2, inplace):

@@ -119,374 +119,3 @@ def tree_map(fn, tree, leaf_type):
 tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)

The remaining deletions remove the chunking utilities that this commit relocates to openfold/utils/chunk_utils.py:

-def _fetch_dims(tree): ...                                                             (moved verbatim)
-@torch.jit.ignore
-def _flat_idx_to_idx(flat_idx: int, dims: Tuple[int]) -> Tuple[int]: ...                (moved verbatim)
-@torch.jit.ignore
-def _get_minimal_slice_set(start, end, dims, start_edges=None, end_edges=None): ...     (moved verbatim)
-@torch.jit.ignore
-def _chunk_slice(t, flat_start, flat_end, no_batch_dims): ...                           (moved verbatim)
-def chunk_layer(layer, inputs, chunk_size, no_batch_dims, low_mem=False): ...
-class ChunkSizeTuner: ...

The deleted chunk_layer is the older variant: it has no _out or _add_into_out parameters, always allocates fresh output storage, carries the comment "# TODO: make this more memory efficient. This sucks" in _prep_inputs, and assigns output chunks without the add-into-output branches. The deleted ChunkSizeTuner.tune_chunk_size caches argument data with a flat list comprehension (arg_data = [arg if type(arg) != torch.Tensor else arg.shape for arg in args]) and compares caches with a simple per-element equality loop rather than the recursive _compare_arg_caches added in chunk_utils.py.