gaoqiong / flash-attention · Commits · 496e4f52

Commit 496e4f52, authored Dec 21, 2022 by Tri Dao

    Implement XPos (Sun et al.)

Parent: c2407dec

Showing 3 changed files with 47 additions and 16 deletions (+47 −16)
flash_attn/layers/rotary.py   +43 −14
flash_attn/models/gpt.py       +2 −1
flash_attn/modules/mha.py      +2 −1
flash_attn/layers/rotary.py
```diff
@@ -78,10 +78,11 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
 class ApplyRotaryEmbQKV_(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, qkv, cos, sin):
+    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
         """
             qkv: (batch_size, seqlen, 3, nheads, headdim)
             cos, sin: (seqlen, rotary_dim / 2)
+            cos_k, sin_k: (seqlen, rotary_dim / 2), optional
             rotary_dim must be <= headdim
         Apply rotary embedding *inplace* to the first rotary_dim of q and k.
         """
```
```diff
@@ -91,19 +92,21 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         rotary_dim *= 2
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
-        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        cos_k = cos if cos_k is None else cos_k
+        sin_k = sin if sin_k is None else sin_k
+        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
         q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
         rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
         k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False)
-        ctx.save_for_backward(cos, sin)
+        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
+        ctx.save_for_backward(cos, sin, cos_k, sin_k)
         return qkv
 
     @staticmethod
     def backward(ctx, dqkv):
-        cos, sin = ctx.saved_tensors
+        cos, sin, cos_k, sin_k = ctx.saved_tensors
         _, seqlen, _, _, headdim = dqkv.shape
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
```
```diff
@@ -111,9 +114,9 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
         dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
-        return dqkv, None, None
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        return dqkv, None, None, None, None
 
 
 apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
```
```diff
@@ -134,15 +137,24 @@ class RotaryEmbedding(torch.nn.Module):
     """
 
-    def __init__(self, dim: int, base=10000, *_, **__):
+    def __init__(self, dim: int, base=10000, scale_base=0, *_, **__):
+        """
+            If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
+            Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
+        """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
+        self.scale_base = scale_base
+        scale = ((torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+                 if scale_base > 0 else None)
+        self.register_buffer("scale", scale)
 
         self._seq_len_cached = 0
         self._cos_cached = None
         self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
 
     def _update_cos_sin_cache(self, x, seqlen_offset=0):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
```
```diff
@@ -157,8 +169,18 @@ class RotaryEmbedding(torch.nn.Module):
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, self.inv_freq)
-            self._cos_cached = torch.cos(freqs).to(x.dtype)
-            self._sin_cached = torch.sin(freqs).to(x.dtype)
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(x.dtype)
+                self._sin_cached = torch.sin(freqs).to(x.dtype)
+            else:
+                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                          - seqlen // 2) / self.scale_base)
+                scale = self.scale ** rearrange(power, 's -> s 1')
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
     def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         """
```
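A note on the scaling above (my reading of the diff, not text from the commit): the per-dimension scale ξ is raised to a power that grows linearly with the position, q is multiplied by it, and k is divided by it, so the attention logit only sees the relative offset. A hedged sketch, treating ξ as applied element-wise per even dimension:

```latex
% Per even dimension 2i of a head of size d: xi_i = (2i + 0.4 d) / (1.4 d), which lies in (0, 1).
% With scale_base s, sequence length L, and the usual rotary rotation R_{theta n}:
\[
\tilde q_n = \xi^{\frac{n - L/2}{s}} \odot R_{\theta n} q_n,
\qquad
\tilde k_m = \xi^{-\frac{m - L/2}{s}} \odot R_{\theta m} k_m,
\]
\[
\tilde q_n^\top \tilde k_m \;=\; \xi^{\frac{n - m}{s}}\text{-weighted } (R_{\theta n} q_n)^\top (R_{\theta m} k_m),
\]
% so the standard rotary logit is damped by a factor that depends only on the relative
% distance n - m; the L/2 centering cancels and merely keeps the cached scales near 1.
```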
```diff
@@ -166,5 +188,12 @@ class RotaryEmbedding(torch.nn.Module):
         token in the batch.
         """
         self._update_cos_sin_cache(qkv, seqlen_offset)
-        return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
-                                     self._sin_cached[seqlen_offset:])
+        if self.scale is None:
+            return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
+                                         self._sin_cached[seqlen_offset:])
+        else:
+            return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
+                                         self._sin_cached[seqlen_offset:],
+                                         self._cos_k_cached[seqlen_offset:],
+                                         self._sin_k_cached[seqlen_offset:])
```
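For intuition, here is a minimal pure-PyTorch sketch of the same XPos-scaled rotary update, without the fused rotary_emb CUDA kernel that apply_rotary_emb_qkv_ uses above. The function name xpos_rotary_ref, the rotation of the full head dimension, and the example shapes are assumptions made for illustration only.

```python
import torch
from einops import rearrange

def xpos_rotary_ref(qkv, base=10000, scale_base=512):
    """Reference (non-fused) XPos rotary for qkv of shape (batch, seqlen, 3, nheads, headdim).
    Rotates the full head dimension of q and k; v is left untouched."""
    _, seqlen, _, _, headdim = qkv.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, headdim, 2).float() / headdim))
    t = torch.arange(seqlen, dtype=inv_freq.dtype)
    freqs = torch.outer(t, inv_freq)                       # (seqlen, headdim / 2)
    # XPos per-dimension scale, raised to a power linear in the (centered) position
    scale = (torch.arange(0, headdim, 2) + 0.4 * headdim) / (1.4 * headdim)
    power = (t - seqlen // 2) / scale_base
    scale = scale ** rearrange(power, 's -> s 1')          # (seqlen, headdim / 2)
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    cos_q, sin_q = cos * scale, sin * scale                # q is multiplied by the scale
    cos_k, sin_k = cos / scale, sin / scale                # k is divided by it

    def rotate(x, cos, sin):
        x1, x2 = x.chunk(2, dim=-1)
        cos, sin = rearrange(cos, 's d -> s 1 d'), rearrange(sin, 's d -> s 1 d')
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

    q = rotate(qkv[:, :, 0], cos_q, sin_q)
    k = rotate(qkv[:, :, 1], cos_k, sin_k)
    return torch.stack([q, k, qkv[:, :, 2]], dim=2)

qkv = torch.randn(2, 16, 3, 4, 32)
print(xpos_rotary_ref(qkv).shape)  # torch.Size([2, 16, 3, 4, 32])
```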
flash_attn/models/gpt.py
```diff
@@ -36,11 +36,12 @@ def create_mixer_cls(config, layer_idx=None):
             softmax_scale /= float(layer_idx + 1)
     dwconv = getattr(config, 'attn_dwconv', False)
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
+    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
     mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
-                        rotary_emb_dim=rotary_emb_dim,
+                        rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                         fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
     return mixer_cls
```
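create_mixer_cls only reads rotary_emb_scale_base off the config, so enabling XPos in a model built from this file comes down to setting that attribute. A hedged sketch, assuming a transformers GPT2Config-style config object with a head dimension of 64; the specific values are illustrative:

```python
# Sketch: enabling XPos via the config attributes read in create_mixer_cls above.
from transformers import GPT2Config

config = GPT2Config(n_embd=1024, n_head=16, n_layer=24)
config.rotary_emb_fraction = 1.0     # rotary_emb_dim = 1.0 * head_dim = 64
config.rotary_emb_scale_base = 512   # > 0 switches plain rotary to XPos scaling
config.use_flash_attn = True
```

With rotary_emb_scale_base left at its default of 0, behaviour is unchanged: RotaryEmbedding registers scale = None and the pre-existing code path runs.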
flash_attn/modules/mha.py
```diff
@@ -283,6 +283,7 @@ class MHA(nn.Module):
     def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
                  softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
+                 rotary_emb_scale_base=0,
                  fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                  checkpointing=False, device=None, dtype=None) -> None:
         """
```
```diff
@@ -308,7 +309,7 @@ class MHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim)
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
 
         if fused_bias_fc and FusedDenseTD is None:
             raise ImportError('fused_dense is not installed')
```
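Putting it together, constructing an MHA block with XPos-style rotary now only requires the new keyword. A minimal sketch, assuming a CUDA device and that the flash-attn rotary and attention extensions are installed; the shapes and values are illustrative:

```python
# Sketch: MHA with the XPos scaling added in this commit (assumes CUDA + flash-attn extensions).
import torch
from flash_attn.modules.mha import MHA

mha = MHA(embed_dim=1024, num_heads=16, causal=True,
          rotary_emb_dim=64,           # rotary over the full head dimension (1024 / 16)
          rotary_emb_scale_base=512,   # > 0 enables XPos; 0 keeps plain rotary
          use_flash_attn=True).to('cuda', dtype=torch.float16)

x = torch.randn(2, 2048, 1024, device='cuda', dtype=torch.float16)
out = mha(x)  # (2, 2048, 1024)
```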