OpenDAS / OpenFold / Commits / a8601529

Commit a8601529 authored Jan 10, 2022 by Gustaf Ahdritz

Prep for bfloat16 training

parent df89bb28

Showing 4 changed files with 271 additions and 140 deletions (+271 -140):

openfold/config.py              +5    -5
openfold/model/evoformer.py     +31   -17
openfold/model/msa.py           +99   -38
openfold/model/primitives.py    +136  -80
openfold/config.py

@@ -148,7 +148,7 @@ config = mlc.ConfigDict(
                 "same_prob": 0.1,
                 "uniform_prob": 0.1,
             },
-            "max_extra_msa": 1024,
+            "max_extra_msa": 2048,
             "max_recycling_iters": 3,
             "msa_cluster_features": True,
             "reduce_msa_clusters_by_max_templates": False,
@@ -211,12 +211,12 @@ config = mlc.ConfigDict(
             "fixed_size": True,
             "subsample_templates": True,
             "masked_msa_replace_fraction": 0.15,
-            "max_msa_clusters": 128,
+            "max_msa_clusters": 512,
             "max_template_hits": 4,
             "max_templates": 4,
             "shuffle_top_k_prefiltered": 20,
             "crop": True,
-            "crop_size": 256,
+            "crop_size": 384,
             "supervised": True,
             "clamp_prob": 0.9,
             "subsample_recycling": True,
@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
             "use_small_bfd": False,
             "data_loaders": {
                 "batch_size": 1,
-                "num_workers": 8,
+                "num_workers": 1,
             },
         },
     },
@@ -340,7 +340,7 @@ config = mlc.ConfigDict(
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
-            "clear_cache_between_blocks": False,
+            "clear_cache_between_blocks": True,
             "inf": 1e9,
             "eps": eps,  # 1e-10,
         },
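The new training defaults (2048 extra MSA sequences, 512 MSA clusters, 384-residue crops, a single dataloader worker, and cache clearing between extra-MSA blocks) raise memory pressure, which the bfloat16 and checkpointing changes in the other files are meant to absorb. Since the tree is an ml_collections.ConfigDict, individual fields can still be overridden per run; a minimal sketch follows (the model_config helper and its train keyword are assumptions about the surrounding repo, not part of this diff):

# Hypothetical usage sketch; assumes openfold.config exposes a model_config()
# helper that returns the ConfigDict edited above.
from openfold.config import model_config

cfg = model_config("model_1", train=True)   # the train=True keyword is assumed
cfg.data.common.max_extra_msa = 1024        # revert to the pre-commit default
cfg.data.train.crop_size = 256              # smaller crops for quick smoke tests
cfg.data.data_loaders.num_workers = 4       # more workers if data loading is CPU-bound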
openfold/model/evoformer.py

@@ -185,7 +185,7 @@ class EvoformerBlockCore(nn.Module):
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
@@ -229,7 +229,7 @@ class EvoformerBlock(nn.Module):
         inf: float,
         eps: float,
     ):
-        super().__init__()
+        super(EvoformerBlock, self).__init__()
         self.msa_att_row = MSARowAttentionWithPairBias(
             c_m=c_m,
@@ -246,7 +246,6 @@ class EvoformerBlock(nn.Module):
             inf=inf,
         )
         self.msa_dropout_layer = DropoutRowwise(msa_dropout)
         self.core = EvoformerBlockCore(
@@ -310,7 +309,7 @@ class ExtraMSABlock(nn.Module):
         eps: float,
         ckpt: bool,
     ):
-        super().__init__()
+        super(ExtraMSABlock, self).__init__()
         self.ckpt = ckpt
@@ -352,16 +351,16 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        checkpoint_chunk_size: Optional[int] = 512,
+        _chunk_logits: Optional[int] = 1024,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        checkpoint_chunk_size = checkpoint_chunk_size if self.ckpt else None
         m = m + self.msa_dropout_layer(
             self.msa_att_row(
                 m,
                 z=z,
                 mask=msa_mask,
                 chunk_size=chunk_size,
-                _chunk_and_checkpoint=checkpoint_chunk_size,
+                _chunk_logits=_chunk_logits,
+                _checkpoint_chunks=self.ckpt,
             )
         )
@@ -370,6 +369,7 @@ class ExtraMSABlock(nn.Module):
         m, z = self.core(
             m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
         )

         return m, z

         if(self.ckpt):
@@ -521,11 +521,8 @@ class EvoformerStack(nn.Module):
                 blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
             )

-        seq_dim = -3
-        index = torch.tensor([0], device=m.device)
-        s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
-        s = s.squeeze(seq_dim)
+        s = self.linear(m[..., 0, :, :])

         return m, z, s
@@ -574,7 +571,7 @@ class ExtraMSAStack(nn.Module):
                 pair_dropout=pair_dropout,
                 inf=inf,
                 eps=eps,
-                ckpt=ckpt,
+                ckpt=False,
             )
             self.blocks.append(block)
@@ -599,10 +596,27 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        for b in self.blocks:
-            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
-            if(self.clear_cache_between_blocks):
-                torch.cuda.empty_cache()
+        checkpoint_fn = get_checkpoint_fn()
+        blocks = [
+            partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None)
+            for b in self.blocks
+        ]
+
+        if(self.clear_cache_between_blocks):
+            def dodo(b, *args):
+                torch.cuda.empty_cache()
+                return b(*args)
+
+            blocks = [partial(dodo, b) for b in blocks]
+
+        for b in blocks:
+            if(torch.is_grad_enabled()):
+                m, z = checkpoint_fn(b, m, z)
+            else:
+                m, z = b(m, z)
+
+        #for b in self.blocks:
+        #    m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+
+        #    if(self.clear_cache_between_blocks):
+        #        torch.cuda.empty_cache()
+
         return z
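The rewritten ExtraMSAStack.forward binds each block's non-tensor arguments with partial and, whenever gradients are enabled, runs the block through the function returned by get_checkpoint_fn(), so extra-MSA activations are recomputed during backward instead of stored. A standalone sketch of that pattern, written against plain torch.utils.checkpoint rather than OpenFold's get_checkpoint_fn (which may dispatch to DeepSpeed's checkpointing); the function and argument names below are illustrative, not from the commit:

# Per-block gradient checkpointing sketch (assumes each block maps (m, z) -> (m, z)).
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def run_blocks(blocks, m, z, msa_mask, pair_mask, chunk_size=None):
    # Bind the keyword arguments up front so each block is a plain (m, z) callable,
    # mirroring ExtraMSAStack.forward above.
    bound = [
        partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size)
        for b in blocks
    ]
    for b in bound:
        if torch.is_grad_enabled():
            # Activations inside b are discarded and recomputed in the backward pass.
            m, z = checkpoint(b, m, z)
        else:
            m, z = b(m, z)
    return m, z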
openfold/model/msa.py

@@ -16,9 +16,16 @@
 import math

 import torch
 import torch.nn as nn
-from typing import Optional, List
+from typing import Optional, List, Tuple

-from openfold.model.primitives import Linear, Attention, GlobalAttention
+from openfold.model.primitives import (
+    Linear,
+    LayerNorm,
+    Attention,
+    GlobalAttention,
+    _attention_chunked_trainable,
+)
+from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -61,16 +68,16 @@ class MSAAttention(nn.Module):
         self.c_z = c_z
         self.inf = inf

-        self.layer_norm_m = nn.LayerNorm(self.c_in)
+        self.layer_norm_m = LayerNorm(self.c_in)

         self.layer_norm_z = None
         self.linear_z = None
         if self.pair_bias:
-            self.layer_norm_z = nn.LayerNorm(self.c_z)
+            self.layer_norm_z = LayerNorm(self.c_z)
             self.linear_z = Linear(
                 self.c_z, self.no_heads, bias=False, init="normal"
             )

         self.mha = Attention(
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
         )
@@ -83,33 +90,16 @@ class MSAAttention(nn.Module):
     ) -> torch.Tensor:
         return chunk_layer(
             self.mha,
-            {"q_x": m, "k_x": m, "v_x": m, "biases": biases},
+            {"q_x": m, "kv_x": m, "biases": biases},
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2]),
         )

-    def forward(self,
+    def _prep_inputs(self,
         m: torch.Tensor,
-        z: Optional[torch.Tensor] = None,
-        mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None,
-        _chunk_and_checkpoint: Optional[int] = None
-    ) -> torch.Tensor:
-        """
-        Args:
-            m:
-                [*, N_seq, N_res, C_m] MSA embedding
-            z:
-                [*, N_res, N_res, C_z] pair embedding. Required only if
-                pair_bias is True
-            mask:
-                [*, N_seq, N_res] MSA mask
-            chunk_size:
-                Size of chunks into which the inputs are split along their
-                batch dimensions. A low value decreases memory overhead at the
-                cost of slower execution. Chunking is not performed by default.
-        """
+        z: Optional[torch.Tensor],
+        mask: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # [*, N_seq, N_res, C_m]
         m = self.layer_norm_m(m)
@@ -121,7 +111,7 @@ class MSAAttention(nn.Module):
             )

         # [*, N_seq, 1, 1, N_res]
-        bias = (self.inf * (mask - 1))[..., :, None, None, :]
+        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

         # This step simply returns a larger view of the bias, and does not
         # consume additional memory.
@@ -129,9 +119,7 @@ class MSAAttention(nn.Module):
         #bias = bias.expand(
         #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
         #)
-        biases = [bias]

         if (self.pair_bias and
             z is not None and                       # For the
             self.layer_norm_z is not None and       # benefit of
@@ -139,13 +127,88 @@ class MSAAttention(nn.Module):
         ):
             # [*, N_res, N_res, C_z]
             z = self.layer_norm_z(z)

             # [*, N_res, N_res, no_heads]
             z = self.linear_z(z)

             # [*, 1, no_heads, N_res, N_res]
             z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

+        return m, mask_bias, z
+
+    @torch.jit.ignore
+    def _chunked_msa_attn(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor],
+        mask: Optional[torch.Tensor],
+        chunk_logits: int,
+        checkpoint: bool,
+    ) -> torch.Tensor:
+        MSA_DIM = -4
+
+        def _get_qkv(m, z):
+            m, mask_bias, z = self._prep_inputs(m, z, mask)
+            q, k, v = self.mha._prep_qkv(m, m)
+            return q, k, v, mask_bias, z
+
+        checkpoint_fn = get_checkpoint_fn()
+        if(checkpoint):
+            q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
+        else:
+            q, k, v, mask_bias, z = _get_qkv(m, z)
+
+        o = _attention_chunked_trainable(
+            query=q,
+            key=k,
+            value=v,
+            biases=[mask_bias, z],
+            chunk_size=chunk_logits,
+            chunk_dim=MSA_DIM,
+            checkpoint=checkpoint,
+        )
+
+        if(checkpoint):
+            # Storing an additional m here is far from ideal
+            m = checkpoint_fn(self.mha._wrap_up, o, m)
+        else:
+            m = self.mha._wrap_up(o, m)
+
+        return m
+
+    def forward(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        _chunk_logits: Optional[int] = None,
+        _checkpoint_chunks: Optional[bool] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            m:
+                [*, N_seq, N_res, C_m] MSA embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding. Required only if
+                pair_bias is True
+            mask:
+                [*, N_seq, N_res] MSA mask
+            chunk_size:
+                Size of chunks into which the inputs are split along their
+                batch dimensions. A low value decreases memory overhead at the
+                cost of slower execution. Chunking is not performed by default.
+        """
+        if(_chunk_logits is not None):
+            return self._chunked_msa_attn(
+                m=m, z=z, mask=mask,
+                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
+            )
+
+        m, mask_bias, z = self._prep_inputs(m, z, mask)
+
+        biases = [mask_bias]
+        if(z is not None):
             biases.append(z)

         if chunk_size is not None:
@@ -153,10 +216,8 @@ class MSAAttention(nn.Module):
         else:
             m = self.mha(
                 q_x=m,
-                k_x=m,
-                v_x=m,
+                kv_x=m,
                 biases=biases,
-                _chunk_and_checkpoint=_chunk_and_checkpoint
             )

         return m
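The new _chunked_msa_attn path splits the attention along the MSA dimension (MSA_DIM = -4) and can gradient-checkpoint each slice, so the full stack of attention logits never has to be resident at once. An illustrative sketch of how the new forward keywords would be exercised; msa_att stands for an already-constructed MSAAttention-derived module (e.g. MSARowAttentionWithPairBias) and the tensor sizes are arbitrary, not values from this commit:

import torch

N_seq, N_res, C_m, C_z = 1024, 256, 64, 128
m = torch.randn(N_seq, N_res, C_m)
z = torch.randn(N_res, N_res, C_z)
msa_mask = torch.ones(N_seq, N_res)

# Plain path: the full [*, N_seq, no_heads, N_res, N_res] logits are materialized.
out = msa_att(m, z=z, mask=msa_mask)

# Chunked path introduced here: attention is computed in slices of 1024 rows along
# the MSA dimension, and each slice is gradient-checkpointed, so only one slice of
# logits needs to be stored at a time.
out = msa_att(m, z=z, mask=msa_mask, _chunk_logits=1024, _checkpoint_chunks=True)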
openfold/model/primitives.py

@@ -18,6 +18,7 @@ import math
 from typing import Optional, Callable, List, Tuple, Sequence

 import numpy as np
+import deepspeed
 import torch
 import torch.nn as nn
 from scipy.stats import truncnorm
@@ -166,65 +167,126 @@ class Linear(nn.Linear):
             raise ValueError("Invalid init string.")


+class LayerNorm(nn.Module):
+    def __init__(self, c_in, eps=1e-5):
+        super(LayerNorm, self).__init__()
+
+        self.c_in = (c_in,)
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(c_in))
+        self.bias = nn.Parameter(torch.zeros(c_in))
+
+    def forward(self, x):
+        d = x.dtype
+        if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+            with torch.cuda.amp.autocast(enabled=False):
+                out = nn.functional.layer_norm(
+                    x,
+                    self.c_in,
+                    self.weight.to(dtype=d),
+                    self.bias.to(dtype=d),
+                    self.eps
+                )
+        elif(d == torch.bfloat16):
+            raise NotImplementedError
+
+        return out
+
+
+def softmax(t, dim=-1):
+    """
+    Softmax, but without automatic casting to fp32 when the input is of
+    type bfloat16
+    """
+    d = t.dtype
+    if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+        with torch.cuda.amp.autocast(enabled=False):
+            s = torch.nn.functional.softmax(t, dim=dim)
+    elif(d == torch.bfloat16):
+        raise NotImplementedError
+
+    return s
+
+
 def _attention(query, key, value, biases):
-    # [*, H, Q, C_hidden]
-    query = permute_final_dims(query, (1, 0, 2))
-
-    # [*, H, C_hidden, K]
-    key = permute_final_dims(key, (1, 2, 0))
-
-    # [*, H, V, C_hidden]
-    value = permute_final_dims(value, (1, 0, 2))
-
-    # [*, H, Q, K]
     a = torch.matmul(query, key)

     for b in biases:
         a += b

-    a = torch.nn.functional.softmax(a, dim=-1)
+    a = softmax(a, dim=-1)

     # [*, H, Q, C_hidden]
-    o = torch.matmul(a, value)
+    a = torch.matmul(a, value)

     # [*, Q, H, C_hidden]
-    o = o.transpose(-2, -3)
+    a = a.transpose(-2, -3)

-    return o
+    return a


 @torch.jit.ignore
-def _attention_chunk_and_checkpoint(query, key, value, biases, chunk_size):
-    if(len(biases) > 2):
+def _attention_chunked_trainable(
+    query, key, value, biases, chunk_size, chunk_dim, checkpoint,
+):
+    if(checkpoint and len(biases) > 2):
         raise ValueError(
-            "_chunk_and_checkpoint only permits two bias terms"
+            "Checkpointed version permits only permits two bias terms"
         )

-    biases = biases + [None, None]
-    bias_1, bias_2 = biases[:2]
-
     def _checkpointable_attention(q, k, v, b1, b2):
-        bs = [b1, b2]
+        bs = [b for b in [b1, b2] if b is not None]
         return _attention(q, k, v, bs)

-    batch_dims = query.shape[:-3]
-    no_batch_dims = len(query.shape[:-3])
-
-    # q, k, and v are assumed to have no singleton dimensions
-    flat_q = query.reshape(-1, *query.shape[-3:])
-    flat_k = key.reshape(-1, *key.shape[-3:])
-    flat_v = value.reshape(-1, *value.shape[-3:])
-
     o_chunks = []
     checkpoint_fn = get_checkpoint_fn()
-    count = flat_q.shape[0]
+    count = query.shape[chunk_dim]
     for start in range(0, count, chunk_size):
         end = start + chunk_size
-        q_chunk = flat_q[start:end, ...]
-        k_chunk = flat_k[start:end, ...]
-        v_chunk = flat_v[start:end, ...]
-        bias_1_chunk = _chunk_slice(bias_1, start, end, no_batch_dims)
-        bias_2_chunk = _chunk_slice(bias_2, start, end, no_batch_dims)
-
-        o_chunk = checkpoint_fn(_checkpointable_attention,
-            q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
-        )
-
-        o_chunks.append(o_chunk)
-
-    o_flat = torch.cat(o_chunks, dim=0)
-
-    return o_flat.reshape(batch_dims + o_flat.shape[1:])
+        idx = [slice(None)] * len(query.shape)
+        idx[chunk_dim] = slice(start, end)
+        idx_tup = tuple(idx)
+
+        q_chunk = query[idx_tup]
+        k_chunk = key[idx_tup]
+        v_chunk = value[idx_tup]
+
+        def _slice_bias(b):
+            idx[chunk_dim] = (
+                slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
+            )
+            return b[tuple(idx)]
+
+        if(checkpoint):
+            bias_1_chunk, bias_2_chunk = [
+                _slice_bias(b) if b is not None else None
+                for b in (biases + [None, None])[:2]
+            ]
+
+            o_chunk = checkpoint_fn(_checkpointable_attention,
+                q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
+            )
+        else:
+            bias_chunks = [
+                _slice_bias(b) for b in biases
+            ]
+
+            o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+
+        o_chunks.append(o_chunk)
+
+    o = torch.cat(o_chunks, dim=chunk_dim)
+
+    return o


 class Attention(nn.Module):
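The new LayerNorm and softmax wrappers exist because CUDA autocast promotes both of these ops to float32; when training in bfloat16 without DeepSpeed managing precision, the wrappers disable autocast and keep the kernels in bfloat16 (casting the LayerNorm weights to the activation dtype). A small illustration of the default behaviour being avoided; this is a sketch, assumes a CUDA device with bfloat16 support, and is not code from the commit:

import torch

x = torch.randn(8, 16, device="cuda", dtype=torch.bfloat16)

with torch.cuda.amp.autocast():
    # Under autocast, softmax is on the run-in-float32 list, so the result is upcast.
    print(torch.nn.functional.softmax(x, dim=-1).dtype)  # torch.float32

with torch.cuda.amp.autocast(enabled=False):
    # With autocast disabled, the op stays in the input dtype.
    print(torch.nn.functional.softmax(x, dim=-1).dtype)  # torch.bfloat16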
@@ -289,16 +351,50 @@ class Attention(nn.Module):
         self.sigmoid = nn.Sigmoid()

+    def _prep_qkv(self,
+        q_x: torch.Tensor,
+        kv_x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # [*, Q/K/V, H * C_hidden]
+        q = self.linear_q(q_x)
+        k = self.linear_k(kv_x)
+        v = self.linear_v(kv_x)
+
+        # [*, Q/K, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+        q /= math.sqrt(self.c_hidden)
+
+        return q, k, v
+
+    def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
+        if(self.linear_g is not None):
+            g = self.sigmoid(self.linear_g(q_x))
+            # [*, Q, H, C_hidden]
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
+            o = o * g
+
+        # [*, Q, H * C_hidden]
+        o = flatten_final_dims(o, 2)
+
+        # [*, Q, C_q]
+        o = self.linear_o(o)
+
+        return o
+
     def forward(
         self,
         q_x: torch.Tensor,
-        k_x: torch.Tensor,
-        v_x: torch.Tensor,
+        kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
         use_lma: bool = False,
         q_chunk_size: Optional[int] = None,
         kv_chunk_size: Optional[int] = None,
-        _chunk_and_checkpoint: Optional[int] = None
     ) -> torch.Tensor:
         """
         Args:
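The projection, scaling, gating, and output steps move into _prep_qkv and _wrap_up so that callers such as MSAAttention._chunked_msa_attn can run (and checkpoint) those stages separately from the attention itself; at ordinary call sites the only visible change is that the separate key/value arguments collapse into a single kv_x, since self-attention passes the same tensor for both anyway. A sketch of that call-site change, where attn and m are illustrative names rather than code from the commit:

# Before this commit:
#   out = attn(q_x=m, k_x=m, v_x=m, biases=biases)
# After:
out = attn(q_x=m, kv_x=m, biases=biases)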
@@ -318,59 +414,20 @@ class Attention(nn.Module):
                 "If use_lma is specified, q_chunk_size and kv_chunk_size must "
                 "be provided"
             )

-        if(use_lma and _chunk_and_checkpoint is not None):
-            raise ValueError(
-                "use_lma and _chunk_and_checkpoint are mutually exclusive"
-            )
-
-        # [*, Q/K/V, H * C_hidden]
-        q = self.linear_q(q_x)
-        k = self.linear_k(k_x)
-        v = self.linear_v(v_x)
-
-        # [*, Q/K, H, C_hidden]
-        q = q.view(q.shape[:-1] + (self.no_heads, -1))
-        k = k.view(k.shape[:-1] + (self.no_heads, -1))
-        v = v.view(v.shape[:-1] + (self.no_heads, -1))
-
-        q = q / math.sqrt(self.c_hidden)
+        q, k, v = self._prep_qkv(q_x, kv_x)

         if(use_lma):
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
                 for b in biases
             ]
             o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
         else:
-            if(_chunk_and_checkpoint):
-                # REMEMBER THAT THE K, Q, V COMPUTATION AND GATING ARE *NOT*
-                # CHECKPOINTED HERE
-                o = _attention_chunk_and_checkpoint(
-                    q, k, v, biases, _chunk_and_checkpoint
-                )
-            else:
-                o = _attention(q, k, v, biases)
-
-        if(self.linear_g is not None):
-            g = self.sigmoid(self.linear_g(q_x))
-            # [*, Q, H, C_hidden]
-            g = g.view(g.shape[:-1] + (self.no_heads, -1))
-            o = o * g
-
-        # [*, Q, H * C_hidden]
-        o = flatten_final_dims(o, 2)
-
-        # [*, Q, C_q]
-        o = self.linear_o(o)
+            # [*, H, Q, C_hidden]
+            q = permute_final_dims(q, (1, 0, 2))
+            # [*, H, C_hidden, K]
+            k = permute_final_dims(k, (1, 2, 0))
+            # [*, H, V, C_hidden]
+            v = permute_final_dims(v, (1, 0, 2))
+
+            o = _attention(q, k, v, biases)
+
+        o = self._wrap_up(o, q_x)

         return o
@@ -399,7 +456,6 @@ class GlobalAttention(nn.Module):
         self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

         self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)

     def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
         # [*, N_res, C_in]
@@ -425,7 +481,7 @@ class GlobalAttention(nn.Module):
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
         a += bias
-        a = self.softmax(a)
+        a = softmax(a)

         # [*, N_res, H, C_hidden]
         o = torch.matmul(
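The _attention_chunked_trainable helper relies on the fact that attention only mixes information over the last two dimensions, so slicing query/key/value (and any broadcastable biases) along one batch-like dimension and concatenating the per-slice outputs reproduces the unchunked result while dividing the peak size of the [*, H, Q, K] logits tensor by the number of chunks. A toy sketch of that equivalence (not OpenFold code, and without the bias and checkpointing handling of the real helper):

import torch

def chunked(fn, q, k, v, chunk_size, dim):
    # Apply fn to matching slices of q/k/v along `dim` and stitch the results back.
    outs = []
    for start in range(0, q.shape[dim], chunk_size):
        sl = [slice(None)] * q.dim()
        sl[dim] = slice(start, start + chunk_size)
        sl = tuple(sl)
        outs.append(fn(q[sl], k[sl], v[sl]))
    return torch.cat(outs, dim=dim)

def toy_attn(q, k, v):
    a = torch.softmax(q @ k.transpose(-1, -2), dim=-1)
    return a @ v

q = k = v = torch.randn(8, 4, 16, 32)   # e.g. [N_seq, H, N_res, C_hidden]
full = toy_attn(q, k, v)
piecewise = chunked(toy_attn, q, k, v, chunk_size=2, dim=0)
print(torch.allclose(full, piecewise, atol=1e-6))  # True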