OpenDAS / OpenFold · Commit a8601529

Prep for bfloat16 training

Authored Jan 10, 2022 by Gustaf Ahdritz
Parent: df89bb28
Showing 4 changed files with 271 additions and 140 deletions:

- openfold/config.py (+5 / -5)
- openfold/model/evoformer.py (+31 / -17)
- openfold/model/msa.py (+99 / -38)
- openfold/model/primitives.py (+136 / -80)
openfold/config.py

```diff
@@ -148,7 +148,7 @@ config = mlc.ConfigDict(
                 "same_prob": 0.1,
                 "uniform_prob": 0.1,
             },
-            "max_extra_msa": 1024,
+            "max_extra_msa": 2048,
             "max_recycling_iters": 3,
             "msa_cluster_features": True,
             "reduce_msa_clusters_by_max_templates": False,
@@ -211,12 +211,12 @@ config = mlc.ConfigDict(
                 "fixed_size": True,
                 "subsample_templates": True,
                 "masked_msa_replace_fraction": 0.15,
-                "max_msa_clusters": 128,
+                "max_msa_clusters": 512,
                 "max_template_hits": 4,
                 "max_templates": 4,
                 "shuffle_top_k_prefiltered": 20,
                 "crop": True,
-                "crop_size": 256,
+                "crop_size": 384,
                 "supervised": True,
                 "clamp_prob": 0.9,
                 "subsample_recycling": True,
@@ -226,7 +226,7 @@ config = mlc.ConfigDict(
                 "use_small_bfd": False,
                 "data_loaders": {
                     "batch_size": 1,
-                    "num_workers": 8,
+                    "num_workers": 1,
                 },
             },
         },
@@ -340,7 +340,7 @@ config = mlc.ConfigDict(
             "msa_dropout": 0.15,
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
-            "clear_cache_between_blocks": False,
+            "clear_cache_between_blocks": True,
             "inf": 1e9,
             "eps": eps,  # 1e-10,
         },
```
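Note: the hunks above bump the training-time MSA and crop sizes. Since `config` is an `ml_collections.ConfigDict`, the same values can be read or overridden from code. A minimal sketch; the attribute paths are inferred from the nesting shown in the hunks and should be checked against your copy of `config.py`:

```python
# Hypothetical override of the training values touched by this commit.
# Assumes openfold.config exposes the module-level `config` ConfigDict.
import copy

from openfold.config import config

cfg = copy.deepcopy(config)
cfg.data.common.max_extra_msa = 2048      # was 1024
cfg.data.train.max_msa_clusters = 512     # was 128
cfg.data.train.crop_size = 384            # was 256
cfg.data.data_module.data_loaders.num_workers = 1  # path assumed from the diff

print(cfg.data.train.crop_size)  # 384
```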
openfold/model/evoformer.py

```diff
@@ -229,7 +229,7 @@ class EvoformerBlock(nn.Module):
         inf: float,
         eps: float,
     ):
-        super().__init__()
+        super(EvoformerBlock, self).__init__()

         self.msa_att_row = MSARowAttentionWithPairBias(
             c_m=c_m,
@@ -246,7 +246,6 @@ class EvoformerBlock(nn.Module):
             inf=inf,
         )

         self.msa_dropout_layer = DropoutRowwise(msa_dropout)
         self.core = EvoformerBlockCore(
@@ -310,7 +309,7 @@ class ExtraMSABlock(nn.Module):
         eps: float,
         ckpt: bool,
     ):
-        super().__init__()
+        super(ExtraMSABlock, self).__init__()

         self.ckpt = ckpt
@@ -352,16 +351,16 @@ class ExtraMSABlock(nn.Module):
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        checkpoint_chunk_size: Optional[int] = 512,
+        _chunk_logits: Optional[int] = 1024,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        checkpoint_chunk_size = checkpoint_chunk_size if self.ckpt else None
-
         m = m + self.msa_dropout_layer(
             self.msa_att_row(
                 m,
                 z=z,
                 mask=msa_mask,
                 chunk_size=chunk_size,
-                _chunk_and_checkpoint=checkpoint_chunk_size,
+                _chunk_logits=_chunk_logits,
+                _checkpoint_chunks=self.ckpt,
             )
         )
@@ -370,6 +369,7 @@ class ExtraMSABlock(nn.Module):
         m, z = self.core(
             m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
         )

         return m, z

+        if(self.ckpt):
@@ -521,10 +521,7 @@ class EvoformerStack(nn.Module):
             blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
         )

-        seq_dim = -3
-        index = torch.tensor([0], device=m.device)
-        s = self.linear(torch.index_select(m, dim=seq_dim, index=index))
-        s = s.squeeze(seq_dim)
+        s = self.linear(m[..., 0, :, :])

         return m, z, s
@@ -574,7 +571,7 @@ class ExtraMSAStack(nn.Module):
             pair_dropout=pair_dropout,
             inf=inf,
             eps=eps,
-            ckpt=ckpt,
+            ckpt=False,
         )
         self.blocks.append(block)
@@ -599,10 +596,27 @@ class ExtraMSAStack(nn.Module):
         Returns:
             [*, N_res, N_res, C_z] pair update
         """
-        for b in self.blocks:
-            m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+        checkpoint_fn = get_checkpoint_fn()
+        blocks = [
+            partial(
+                b,
+                msa_mask=msa_mask,
+                pair_mask=pair_mask,
+                chunk_size=chunk_size,
+                _chunk_logits=None,
+            )
+            for b in self.blocks
+        ]
+
+        if(self.clear_cache_between_blocks):
+            def dodo(b, *args):
+                torch.cuda.empty_cache()
+                return b(*args)
+
+            blocks = [partial(dodo, b) for b in blocks]
+
+        for b in blocks:
+            if(torch.is_grad_enabled()):
+                m, z = checkpoint_fn(b, m, z)
+            else:
+                m, z = b(m, z)
+
+        #for b in self.blocks:
+        #    m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
+
+        #    if(self.clear_cache_between_blocks):
+        #        torch.cuda.empty_cache()

         return z
```
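Note: the rewritten `ExtraMSAStack.forward` binds the per-block keyword arguments with `functools.partial`, optionally wraps each block so it clears the CUDA cache first, and then checkpoints every block whenever gradients are enabled. A stripped-down sketch of the same pattern, written against plain `torch.utils.checkpoint` instead of OpenFold's `get_checkpoint_fn()`; the names here are illustrative, not OpenFold's:

```python
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def run_stack(blocks, m, z, msa_mask, pair_mask, chunk_size, clear_cache=False):
    # Bind the per-call kwargs so every block becomes an fn(m, z) closure.
    fns = [
        partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size)
        for b in blocks
    ]

    if clear_cache:
        def with_cache_clear(fn, *args):
            torch.cuda.empty_cache()  # release cached allocator blocks between blocks
            return fn(*args)

        fns = [partial(with_cache_clear, fn) for fn in fns]

    for fn in fns:
        if torch.is_grad_enabled():
            # Store only the block inputs; recompute activations in backward.
            m, z = checkpoint(fn, m, z)
        else:
            m, z = fn(m, z)

    return m, z
```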
openfold/model/msa.py

```diff
@@ -16,9 +16,16 @@
 import math

 import torch
 import torch.nn as nn
-from typing import Optional, List
-from openfold.model.primitives import Linear, Attention, GlobalAttention
+from typing import Optional, List, Tuple
+from openfold.model.primitives import (
+    Linear,
+    LayerNorm,
+    Attention,
+    GlobalAttention,
+    _attention_chunked_trainable,
+)
+from openfold.utils.checkpointing import get_checkpoint_fn
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -61,12 +68,12 @@ class MSAAttention(nn.Module):
         self.c_z = c_z
         self.inf = inf

-        self.layer_norm_m = nn.LayerNorm(self.c_in)
+        self.layer_norm_m = LayerNorm(self.c_in)

         self.layer_norm_z = None
         self.linear_z = None
         if self.pair_bias:
-            self.layer_norm_z = nn.LayerNorm(self.c_z)
+            self.layer_norm_z = LayerNorm(self.c_z)
             self.linear_z = Linear(
                 self.c_z, self.no_heads, bias=False, init="normal"
             )
@@ -83,33 +90,16 @@ class MSAAttention(nn.Module):
     ) -> torch.Tensor:
         return chunk_layer(
             self.mha,
-            {"q_x": m, "k_x": m, "v_x": m, "biases": biases},
+            {"q_x": m, "kv_x": m, "biases": biases},
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2]),
         )

-    def forward(self,
+    def _prep_inputs(self,
         m: torch.Tensor,
-        z: Optional[torch.Tensor] = None,
-        mask: Optional[torch.Tensor] = None,
-        chunk_size: Optional[int] = None,
-        _chunk_and_checkpoint: Optional[int] = None
-    ) -> torch.Tensor:
-        """
-        Args:
-            m:
-                [*, N_seq, N_res, C_m] MSA embedding
-            z:
-                [*, N_res, N_res, C_z] pair embedding. Required only if
-                pair_bias is True
-            mask:
-                [*, N_seq, N_res] MSA mask
-            chunk_size:
-                Size of chunks into which the inputs are split along their
-                batch dimensions. A low value decreases memory overhead at the
-                cost of slower execution. Chunking is not performed by default.
-        """
+        z: Optional[torch.Tensor],
+        mask: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         # [*, N_seq, N_res, C_m]
         m = self.layer_norm_m(m)
@@ -121,7 +111,7 @@ class MSAAttention(nn.Module):
         )

         # [*, N_seq, 1, 1, N_res]
-        bias = (self.inf * (mask - 1))[..., :, None, None, :]
+        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

         # This step simply returns a larger view of the bias, and does not
         # consume additional memory.
@@ -130,8 +120,6 @@ class MSAAttention(nn.Module):
         #    ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
         #)

-        biases = [bias]
         if (self.pair_bias and
             z is not None and                       # For the
             self.layer_norm_z is not None and       # benefit of
@@ -146,6 +134,81 @@ class MSAAttention(nn.Module):
             # [*, 1, no_heads, N_res, N_res]
             z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

+        return m, mask_bias, z
+
+    @torch.jit.ignore
+    def _chunked_msa_attn(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor],
+        mask: Optional[torch.Tensor],
+        chunk_logits: int,
+        checkpoint: bool,
+    ) -> torch.Tensor:
+        MSA_DIM = -4
+
+        def _get_qkv(m, z):
+            m, mask_bias, z = self._prep_inputs(m, z, mask)
+            q, k, v = self.mha._prep_qkv(m, m)
+            return q, k, v, mask_bias, z
+
+        checkpoint_fn = get_checkpoint_fn()
+
+        if(checkpoint):
+            q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
+        else:
+            q, k, v, mask_bias, z = _get_qkv(m, z)
+
+        o = _attention_chunked_trainable(
+            query=q,
+            key=k,
+            value=v,
+            biases=[mask_bias, z],
+            chunk_size=chunk_logits,
+            chunk_dim=MSA_DIM,
+            checkpoint=checkpoint,
+        )
+
+        if(checkpoint):
+            # Storing an additional m here is far from ideal
+            m = checkpoint_fn(self.mha._wrap_up, o, m)
+        else:
+            m = self.mha._wrap_up(o, m)
+
+        return m
+
+    def forward(self,
+        m: torch.Tensor,
+        z: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        chunk_size: Optional[int] = None,
+        _chunk_logits: Optional[int] = None,
+        _checkpoint_chunks: Optional[bool] = None,
+    ) -> torch.Tensor:
+        """
+        Args:
+            m:
+                [*, N_seq, N_res, C_m] MSA embedding
+            z:
+                [*, N_res, N_res, C_z] pair embedding. Required only if
+                pair_bias is True
+            mask:
+                [*, N_seq, N_res] MSA mask
+            chunk_size:
+                Size of chunks into which the inputs are split along their
+                batch dimensions. A low value decreases memory overhead at the
+                cost of slower execution. Chunking is not performed by default.
+        """
+        if(_chunk_logits is not None):
+            return self._chunked_msa_attn(
+                m=m, z=z, mask=mask,
+                chunk_logits=_chunk_logits,
+                checkpoint=_checkpoint_chunks
+            )
+
+        m, mask_bias, z = self._prep_inputs(m, z, mask)
+
+        biases = [mask_bias]
+        if(z is not None):
+            biases.append(z)
+
         if chunk_size is not None:
@@ -153,10 +216,8 @@ class MSAAttention(nn.Module):
         else:
             m = self.mha(
                 q_x=m,
-                k_x=m,
-                v_x=m,
-                biases=biases,
-                _chunk_and_checkpoint=_chunk_and_checkpoint
+                kv_x=m,
+                biases=biases
             )

         return m
```
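Note: with the new keyword arguments, `MSAAttention.forward` routes to `_chunked_msa_attn` whenever `_chunk_logits` is set, computing the attention over slices of the MSA dimension and, if `_checkpoint_chunks` is true, recomputing each slice during the backward pass. A hedged usage sketch; the constructor arguments and tensor sizes below are illustrative assumptions, not taken from the diff:

```python
import torch

from openfold.model.msa import MSAAttention

# Hypothetical dimensions; real values come from the model config.
attn = MSAAttention(
    c_in=256, c_z=128, c_hidden=32, no_heads=8, pair_bias=True, inf=1e9
)

m = torch.randn(1, 512, 64, 256)   # [*, N_seq, N_res, C_m]
z = torch.randn(1, 64, 64, 128)    # [*, N_res, N_res, C_z]
mask = torch.ones(1, 512, 64)      # [*, N_seq, N_res]

# Compute the attention 128 MSA rows at a time and checkpoint each chunk,
# which is what ExtraMSABlock now requests when self.ckpt is True.
out = attn(m, z=z, mask=mask, _chunk_logits=128, _checkpoint_chunks=True)
print(out.shape)  # same shape as m
```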
openfold/model/primitives.py

```diff
@@ -18,6 +18,7 @@ import math
 from typing import Optional, Callable, List, Tuple, Sequence
 import numpy as np

+import deepspeed
 import torch
 import torch.nn as nn
 from scipy.stats import truncnorm
@@ -166,65 +167,126 @@ class Linear(nn.Linear):
             raise ValueError("Invalid init string.")


+class LayerNorm(nn.Module):
+    def __init__(self, c_in, eps=1e-5):
+        super(LayerNorm, self).__init__()
+
+        self.c_in = (c_in,)
+        self.eps = eps
+
+        self.weight = nn.Parameter(torch.ones(c_in))
+        self.bias = nn.Parameter(torch.zeros(c_in))
+
+    def forward(self, x):
+        d = x.dtype
+        if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+            with torch.cuda.amp.autocast(enabled=False):
+                out = nn.functional.layer_norm(
+                    x,
+                    self.c_in,
+                    self.weight.to(dtype=d),
+                    self.bias.to(dtype=d),
+                    self.eps
+                )
+        elif(d == torch.bfloat16):
+            raise NotImplementedError
+
+        return out
+
+
+def softmax(t, dim=-1):
+    """
+        Softmax, but without automatic casting to fp32 when the input is of
+        type bfloat16
+    """
+    d = t.dtype
+    if(d == torch.bfloat16 and not deepspeed.utils.is_initialized()):
+        with torch.cuda.amp.autocast(enabled=False):
+            s = torch.nn.functional.softmax(t, dim=dim)
+    elif(d == torch.bfloat16):
+        raise NotImplementedError
+
+    return s
+
+
 def _attention(query, key, value, biases):
     # [*, H, Q, C_hidden]
     query = permute_final_dims(query, (1, 0, 2))

     # [*, H, C_hidden, K]
     key = permute_final_dims(key, (1, 2, 0))

     # [*, H, V, C_hidden]
     value = permute_final_dims(value, (1, 0, 2))

     # [*, H, Q, K]
     a = torch.matmul(query, key)

     for b in biases:
         a += b

-    a = torch.nn.functional.softmax(a, dim=-1)
+    a = softmax(a, dim=-1)

     # [*, H, Q, C_hidden]
-    o = torch.matmul(a, value)
+    a = torch.matmul(a, value)

     # [*, Q, H, C_hidden]
-    o = o.transpose(-2, -3)
+    a = a.transpose(-2, -3)

-    return o
+    return a


 @torch.jit.ignore
-def _attention_chunk_and_checkpoint(query, key, value, biases, chunk_size):
-    if(len(biases) > 2):
-        raise ValueError(
-            "_chunk_and_checkpoint only permits two bias terms"
-        )
-
-    biases = biases + [None, None]
-    bias_1, bias_2 = biases[:2]
-
-    def _checkpointable_attention(q, k, v, b1, b2):
-        bs = [b1, b2]
-        return _attention(q, k, v, bs)
-
-    batch_dims = query.shape[:-3]
-    no_batch_dims = len(query.shape[:-3])
-
-    # q, k, and v are assumed to have no singleton dimensions
-    flat_q = query.reshape(-1, *query.shape[-3:])
-    flat_k = key.reshape(-1, *key.shape[-3:])
-    flat_v = value.reshape(-1, *value.shape[-3:])
-
-    o_chunks = []
-    checkpoint_fn = get_checkpoint_fn()
-    count = flat_q.shape[0]
-    for start in range(0, count, chunk_size):
-        end = start + chunk_size
-        q_chunk = flat_q[start: end, ...]
-        k_chunk = flat_k[start: end, ...]
-        v_chunk = flat_v[start: end, ...]
-        bias_1_chunk = _chunk_slice(bias_1, start, end, no_batch_dims)
-        bias_2_chunk = _chunk_slice(bias_2, start, end, no_batch_dims)
-
-        o_chunk = checkpoint_fn(_checkpointable_attention,
-            q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
-        )
-
-        o_chunks.append(o_chunk)
-
-    o_flat = torch.cat(o_chunks, dim=0)
-    return o_flat.reshape(batch_dims + o_flat.shape[1:])
+def _attention_chunked_trainable(
+    query, key, value, biases, chunk_size, chunk_dim, checkpoint,
+):
+    if(checkpoint and len(biases) > 2):
+        raise ValueError(
+            "Checkpointed version permits only two bias terms"
+        )
+
+    def _checkpointable_attention(q, k, v, b1, b2):
+        bs = [b for b in [b1, b2] if b is not None]
+        return _attention(q, k, v, bs)
+
+    no_batch_dims = len(query.shape[:-3])
+
+    o_chunks = []
+    checkpoint_fn = get_checkpoint_fn()
+    count = query.shape[chunk_dim]
+    for start in range(0, count, chunk_size):
+        end = start + chunk_size
+        idx = [slice(None)] * len(query.shape)
+        idx[chunk_dim] = slice(start, end)
+        idx_tup = tuple(idx)
+        q_chunk = query[idx_tup]
+        k_chunk = key[idx_tup]
+        v_chunk = value[idx_tup]
+
+        def _slice_bias(b):
+            idx[chunk_dim] = (
+                slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
+            )
+            return b[tuple(idx)]
+
+        if(checkpoint):
+            bias_1_chunk, bias_2_chunk = [
+                _slice_bias(b) if b is not None else None
+                for b in (biases + [None, None])[:2]
+            ]
+
+            o_chunk = checkpoint_fn(_checkpointable_attention,
+                q_chunk, k_chunk, v_chunk, bias_1_chunk, bias_2_chunk
+            )
+        else:
+            bias_chunks = [_slice_bias(b) for b in biases]
+            o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+
+        o_chunks.append(o_chunk)
+
+    o = torch.cat(o_chunks, dim=chunk_dim)
+
+    return o


 class Attention(nn.Module):
@@ -289,16 +351,50 @@ class Attention(nn.Module):
         self.sigmoid = nn.Sigmoid()

+    def _prep_qkv(self,
+        q_x: torch.Tensor,
+        kv_x: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # [*, Q/K/V, H * C_hidden]
+        q = self.linear_q(q_x)
+        k = self.linear_k(kv_x)
+        v = self.linear_v(kv_x)
+
+        # [*, Q/K, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+        k = k.view(k.shape[:-1] + (self.no_heads, -1))
+        v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+        q /= math.sqrt(self.c_hidden)
+
+        return q, k, v
+
+    def _wrap_up(self,
+        o: torch.Tensor,
+        q_x: torch.Tensor
+    ) -> torch.Tensor:
+        if(self.linear_g is not None):
+            g = self.sigmoid(self.linear_g(q_x))
+
+            # [*, Q, H, C_hidden]
+            g = g.view(g.shape[:-1] + (self.no_heads, -1))
+            o = o * g
+
+        # [*, Q, H * C_hidden]
+        o = flatten_final_dims(o, 2)
+
+        # [*, Q, C_q]
+        o = self.linear_o(o)
+
+        return o
+
     def forward(self,
         q_x: torch.Tensor,
-        k_x: torch.Tensor,
-        v_x: torch.Tensor,
+        kv_x: torch.Tensor,
         biases: Optional[List[torch.Tensor]] = None,
         use_lma: bool = False,
         q_chunk_size: Optional[int] = None,
         kv_chunk_size: Optional[int] = None,
-        _chunk_and_checkpoint: Optional[int] = None
     ) -> torch.Tensor:
         """
         Args:
@@ -318,59 +414,20 @@ class Attention(nn.Module):
                 "If use_lma is specified, q_chunk_size and kv_chunk_size must "
                 "be provided"
             )

-        if(use_lma and _chunk_and_checkpoint is not None):
-            raise ValueError(
-                "use_lma and _chunk_and_checkpoint are mutually exclusive"
-            )
-
-        # [*, Q/K/V, H * C_hidden]
-        q = self.linear_q(q_x)
-        k = self.linear_k(k_x)
-        v = self.linear_v(v_x)
-
-        # [*, Q/K, H, C_hidden]
-        q = q.view(q.shape[:-1] + (self.no_heads, -1))
-        k = k.view(k.shape[:-1] + (self.no_heads, -1))
-        v = v.view(v.shape[:-1] + (self.no_heads, -1))
-
-        q = q / math.sqrt(self.c_hidden)
+        q, k, v = self._prep_qkv(q_x, kv_x)

         if(use_lma):
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (k_x.shape[-2],))
                 for b in biases
             ]

             o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
         else:
-            # [*, H, Q, C_hidden]
-            q = permute_final_dims(q, (1, 0, 2))
-
-            # [*, H, C_hidden, K]
-            k = permute_final_dims(k, (1, 2, 0))
-
-            # [*, H, V, C_hidden]
-            v = permute_final_dims(v, (1, 0, 2))
-
-            if(_chunk_and_checkpoint):
-                # REMEMBER THAT THE K, Q, V COMPUTATION AND GATING ARE *NOT*
-                # CHECKPOINTED HERE
-                o = _attention_chunk_and_checkpoint(
-                    q, k, v, biases, _chunk_and_checkpoint
-                )
-            else:
-                o = _attention(q, k, v, biases)
+            o = _attention(q, k, v, biases)

-        if(self.linear_g is not None):
-            g = self.sigmoid(self.linear_g(q_x))
-
-            # [*, Q, H, C_hidden]
-            g = g.view(g.shape[:-1] + (self.no_heads, -1))
-            o = o * g
-
-        # [*, Q, H * C_hidden]
-        o = flatten_final_dims(o, 2)
-
-        # [*, Q, C_q]
-        o = self.linear_o(o)
+        o = self._wrap_up(o, q_x)

         return o
@@ -399,7 +456,6 @@ class GlobalAttention(nn.Module):
         self.linear_o = Linear(
             c_hidden * no_heads, c_in, init="final"
         )

         self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)

     def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
         # [*, N_res, C_in]
@@ -425,7 +481,7 @@ class GlobalAttention(nn.Module):
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
         a += bias
-        a = self.softmax(a)
+        a = softmax(a)

         # [*, N_res, H, C_hidden]
         o = torch.matmul(
```
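Note: the motivation behind the new `LayerNorm` and `softmax` wrappers is that, under `torch.cuda.amp.autocast`, PyTorch executes `layer_norm` and `softmax` in fp32, which silently upcasts bfloat16 activations. Briefly disabling autocast and casting the affine parameters to the input dtype keeps the whole computation in bfloat16. A standalone sketch of the trick, not the OpenFold module itself (assumes a CUDA device):

```python
import torch
import torch.nn as nn


def bf16_layer_norm(x, weight, bias, eps=1e-5):
    # Keep bfloat16 inputs in bfloat16 instead of letting autocast promote
    # the op to fp32; other dtypes fall through to the normal call.
    if x.dtype == torch.bfloat16:
        with torch.cuda.amp.autocast(enabled=False):
            return nn.functional.layer_norm(
                x, x.shape[-1:], weight.to(x.dtype), bias.to(x.dtype), eps
            )
    return nn.functional.layer_norm(x, x.shape[-1:], weight, bias, eps)


if torch.cuda.is_available():
    x = torch.randn(4, 8, device="cuda")
    w = torch.ones(8, device="cuda")
    b = torch.zeros(8, device="cuda")
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        out = bf16_layer_norm(x.to(torch.bfloat16), w, b)
    print(out.dtype)  # torch.bfloat16 rather than torch.float32
```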