gaoqiong / flash-attention · commit e45a46a5

[Rotary] Implement GPT-J style (interleaved) rotary

Authored Mar 14, 2023 by Tri Dao
Parent: f28d61cb

Showing 4 changed files with 183 additions and 26 deletions
csrc/rotary/rotary.cpp        +4   −0
csrc/rotary/rotary_cuda.cu    +4   −0
flash_attn/layers/rotary.py   +61  −26
tests/layers/test_rotary.py   +114 −0
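For orientation, the two conventions named in the commit title differ only in how head dimensions are paired before rotation. A minimal pure-PyTorch sketch (ours, not part of the commit) of the two pairings:

import torch

def rotate_half_neox(x):
    # GPT-NeoX style: pair dimension i with dimension i + d/2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_half_gptj(x):
    # GPT-J style: pair adjacent even/odd dimensions (0,1), (2,3), ...
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

x = torch.arange(8.0)
print(rotate_half_neox(x))  # tensor([-4., -5., -6., -7.,  0.,  1.,  2.,  3.])
print(rotate_half_gptj(x))  # tensor([-1.,  0., -3.,  2., -5.,  4., -7.,  6.])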
csrc/rotary/rotary.cpp  +4 −0

+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
 #include <torch/extension.h>
 #include <c10/cuda/CUDAGuard.h>
 ...
csrc/rotary/rotary_cuda.cu  +4 −0

+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
 #include <torch/python.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Loops.cuh>
 ...
flash_attn/layers/rotary.py  +61 −26

 # Inspired by https://github.com/facebookresearch/xformers/blob/main/xformers/components/positional_embedding/rotary.py
 # Copyright (c) 2023, Tri Dao.

 from typing import Tuple
 import math
 ...
@@ -10,31 +10,37 @@ from einops import rearrange, repeat
 import rotary_emb


-def rotate_half(x):
-    x1, x2 = x.chunk(2, dim=-1)
-    return torch.cat((-x2, x1), dim=-1)
+def rotate_half(x, interleaved=False):
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)


-def apply_rotary_emb_torch(x, cos, sin):
+def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
     """
     x: (batch_size, seqlen, nheads, headdim)
     cos, sin: (seqlen, rotary_dim / 2)
     """
-    rotary_dim = cos.shape[-1] * 2
-    assert rotary_dim <= x.shape[-1]
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
     cos = repeat(cos, 's d -> s 1 (2 d)')
     sin = repeat(sin, 's d -> s 1 (2 d)')
-    return torch.cat([x[..., :rotary_dim] * cos + rotate_half(x[..., :rotary_dim]) * sin,
-                      x[..., rotary_dim:]], dim=-1)
+    return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+                      x[..., ro_dim:]], dim=-1)
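As a sanity check on the reference path (our sketch, not part of the commit, assuming the apply_rotary_emb_torch defined above is importable): for the non-interleaved layout, applying cos/sin this way is exactly a complex rotation of each (x[..., i], x[..., i + d/2]) pair by a per-position, per-pair angle.

import torch
from einops import rearrange

seqlen, nheads, headdim = 4, 1, 8
x = torch.randn(1, seqlen, nheads, headdim)
theta = torch.rand(seqlen, headdim // 2)  # one angle per (position, pair)
# Complex view: pair x[..., i] with x[..., i + d/2] and rotate by e^{i*theta}.
xc = torch.complex(x[..., :headdim // 2], x[..., headdim // 2:])
yc = xc * rearrange(torch.polar(torch.ones_like(theta), theta), 's d -> s 1 d')
y_ref = torch.cat([yc.real, yc.imag], dim=-1)
assert torch.allclose(apply_rotary_emb_torch(x, theta.cos(), theta.sin()), y_ref, atol=1e-6)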
 class ApplyRotaryEmb(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, x, cos, sin, inplace=False):
+    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
         """
         x: (batch_size, seqlen, nheads, headdim)
         cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
         rotary_dim must be <= headdim
         Apply rotary embedding to the first rotary_dim of x.
         """
 ...
@@ -44,14 +50,21 @@ class ApplyRotaryEmb(torch.autograd.Function):
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
         assert sin.shape == (rotary_seqlen, rotary_dim // 2)
-        x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
+        x_ro = x[..., :rotary_dim]
+        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
         out = torch.empty_like(x) if not inplace else x
-        o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
+        out_ro = out[..., :rotary_dim]
+        if inplace:
+            o1, o2 = x1, x2
+        else:
+            o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
+                      else (out_ro[..., ::2], out_ro[..., 1::2]))
         rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
         if not inplace and rotary_dim < headdim:
             out[..., rotary_dim:].copy_(x[..., rotary_dim:])
         ctx.save_for_backward(cos, sin)
+        ctx.interleaved = interleaved
         ctx.inplace = inplace
         return out if not inplace else x
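The final boolean argument to rotary_emb.apply_rotary (False here, True in the backward pass below) selects the rotation direction. A pure-PyTorch sketch of that contract as we read the CUDA kernel (our assumption, the kernel body is not shown in this diff):

def apply_rotary_reference(x1, x2, cos, sin, conj=False):
    # conj=False: rotate each (x1, x2) pair by +theta (forward pass).
    # conj=True: rotate by -theta, the inverse rotation the backward pass needs,
    # since the transpose of a rotation matrix is its inverse.
    if not conj:
        return x1 * cos - x2 * sin, x1 * sin + x2 * cos
    else:
        return x1 * cos + x2 * sin, -x1 * sin + x2 * cos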
 ...
@@ -62,14 +75,21 @@ class ApplyRotaryEmb(torch.autograd.Function):
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
         inplace = ctx.inplace
-        do1, do2 = do[..., :rotary_dim].chunk(2, dim=-1)
+        do_ro = do[..., :rotary_dim]
+        do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (do_ro[..., ::2], do_ro[..., 1::2]))
         dx = torch.empty_like(do) if not inplace else do
-        dx1, dx2 = dx[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (do1, do2)
+        if inplace:
+            dx1, dx2 = do1, do2
+        else:
+            dx_ro = dx[..., :rotary_dim]
+            dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
+                        else (dx_ro[..., ::2], dx_ro[..., 1::2]))
         rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
         if not inplace and rotary_dim < headdim:
             dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
-        return dx, None, None, None
+        return dx, None, None, None, None


 apply_rotary_emb_func = ApplyRotaryEmb.apply
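A usage sketch for the new signature (ours, assuming the rotary_emb CUDA extension is built). Note that interleaved was inserted before inplace, so positional callers of apply_rotary_emb_func must be updated:

import torch
from flash_attn.layers.rotary import apply_rotary_emb_func

batch, seqlen, nheads, headdim, rotary_dim = 2, 128, 16, 64, 32
x = torch.randn(batch, seqlen, nheads, headdim, device='cuda', dtype=torch.float16)
inv_freq = 1.0 / 10000 ** (torch.arange(0, rotary_dim, 2, device='cuda').float() / rotary_dim)
angles = torch.outer(torch.arange(seqlen, device='cuda').float(), inv_freq)
cos, sin = angles.cos().half(), angles.sin().half()   # (seqlen, rotary_dim / 2)
# Rotate the first rotary_dim dims of each head, GPT-J (interleaved) style:
out = apply_rotary_emb_func(x, cos, sin, True, False)  # interleaved=True, inplace=False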
 ...
@@ -78,11 +98,13 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
 class ApplyRotaryEmbQKV_(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
+    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
         """
         qkv: (batch_size, seqlen, 3, nheads, headdim)
         cos, sin: (seqlen, rotary_dim / 2)
         cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
         rotary_dim must be <= headdim
         Apply rotary embedding *inplace* to the first rotary_dim of q and k.
         """
 ...
@@ -95,13 +117,16 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         cos_k = cos if cos_k is None else cos_k
         sin_k = sin if sin_k is None else sin_k
         assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-        q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+        q_ro = qkv[:, :, 0, :, :rotary_dim]
+        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
         rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
-        k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
+        k_ro = qkv[:, :, 1, :, :rotary_dim]
+        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
         rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
         ctx.save_for_backward(cos, sin, cos_k, sin_k)
+        ctx.interleaved = interleaved
         return qkv

     @staticmethod
 ...
@@ -110,13 +135,17 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         _, seqlen, _, _, headdim = dqkv.shape
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
-        dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
+        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
+        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
         rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
-        dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
+        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
+        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
+                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
         rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
-        return dqkv, None, None, None, None
+        return dqkv, None, None, None, None, None


 apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
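And the fused-QKV variant, again our own hedged sketch: cos_k/sin_k stay None unless separate key tables (for XPos-style scaling) are wanted.

import torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_

seqlen, rotary_dim = 128, 32
inv_freq = 1.0 / 10000 ** (torch.arange(0, rotary_dim, 2, device='cuda').float() / rotary_dim)
angles = torch.outer(torch.arange(seqlen, device='cuda').float(), inv_freq)
cos, sin = angles.cos().half(), angles.sin().half()
qkv = torch.randn(2, seqlen, 3, 16, 64, device='cuda', dtype=torch.float16)
# Rotates q and k inplace; the new trailing argument selects the interleaved (GPT-J) layout.
qkv = apply_rotary_emb_qkv_(qkv, cos, sin, None, None, True)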
 ...
@@ -135,22 +164,25 @@ class RotaryEmbedding(torch.nn.Module):
     .. _repo: https://github.com/ZhuiyiTechnology/roformer
     .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

-    If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
+    If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
     A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """

-    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
+    def __init__(self, dim: int, base=10000, interleaved=False, scale_base=None, device=None):
+        """
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
                                                 dtype=torch.float32) / dim))
         self.register_buffer("inv_freq", inv_freq)
+        self.interleaved = interleaved
         self.scale_base = scale_base
         scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
-                 / (1.4 * dim) if scale_base > 0 else None)
+                 / (1.4 * dim) if scale_base is not None else None)
         self.register_buffer("scale", scale)

         self._seq_len_cached = 0
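For reference, our reading of how the scale buffer is used in _update_cos_sin_cache (which this diff does not show, so the centering below is an assumption): with scale_base set, the cached cos/sin tables for queries and keys are multiplied by reciprocal per-position powers of the scale vector, which is what turns plain RoPE into XPos.

import torch

dim, base, scale_base, seqlen = 64, 10000, 512, 128
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
scale = (torch.arange(0, dim, 2).float() + 0.4 * dim) / (1.4 * dim)
t = torch.arange(seqlen).float()
freqs = torch.outer(t, inv_freq)          # rotation angles, (seqlen, dim/2)
power = (t - seqlen // 2) / scale_base    # assumed centering, per the XPos reference
pos_scale = scale ** power.unsqueeze(-1)  # (seqlen, dim/2)
cos_q, sin_q = freqs.cos() * pos_scale, freqs.sin() * pos_scale  # query tables
cos_k, sin_k = freqs.cos() / pos_scale, freqs.sin() / pos_scale  # key tables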
 ...
@@ -187,16 +219,19 @@ class RotaryEmbedding(torch.nn.Module):
     def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         qkv: (batch, seqlen, 3, nheads, headdim)
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
         token in the batch.
         """
         self._update_cos_sin_cache(qkv, seqlen_offset)
         if self.scale is None:
             return apply_rotary_emb_qkv_(
-                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
+                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                None, None, self.interleaved
             )
         else:
             return apply_rotary_emb_qkv_(
                 qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:]
+                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
+                self.interleaved
             )
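Putting it together, a hedged end-to-end sketch of the module (requires a CUDA device and the built rotary_emb extension):

import torch
from flash_attn.layers.rotary import RotaryEmbedding

rotary = RotaryEmbedding(dim=64, interleaved=True, device='cuda')  # GPT-J convention
qkv = torch.randn(2, 128, 3, 16, 64, device='cuda', dtype=torch.float16)
out = rotary(qkv)  # rotates q and k inplace
# During generation, pass only the newest token plus its absolute position:
step = torch.randn(2, 1, 3, 16, 64, device='cuda', dtype=torch.float16)
out_step = rotary(step, seqlen_offset=128)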
tests/layers/test_rotary.py  0 → 100644  +114 −0
# Copyright (c) 2023, Tri Dao.

import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from transformers.models.gpt_neox.modeling_gpt_neox import RotaryEmbedding as RotaryEmbeddingNeoX
from transformers.models.gpt_neox.modeling_gpt_neox import apply_rotary_pos_emb as apply_rotary_pos_emb_neox
from transformers.models.gptj.modeling_gptj import fixed_pos_embedding
from transformers.models.gptj.modeling_gptj import apply_rotary_pos_emb as apply_rotary_pos_emb_gptj

from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_qkv_
from flash_attn.layers.rotary import RotaryEmbedding


# NeoX-style rotary embedding
@pytest.mark.parametrize('seqlen_offset', [0, 711])
@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0])
def test_rotary(rotary_emb_fraction, seqlen_offset):
    device = 'cuda'
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                      requires_grad=True)
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, device=device)
    rotary_neox = RotaryEmbeddingNeoX(rotary_dim, seqlen_total, device=device)
    # Doesn't matter what tensor we pass in, rotary_neox only uses the device of the tensor
    cos_neox, sin_neox = rotary_neox(qkv, seq_len=seqlen_total)
    cos_neox, sin_neox = cos_neox.to(dtype=dtype), sin_neox.to(dtype=dtype)
    q_pt = rearrange(qkv[:, :, 0, :, :rotary_dim],
                     'b s h d -> b h s d').detach().clone().requires_grad_(True)
    k_pt = rearrange(qkv[:, :, 1, :, :rotary_dim],
                     'b s h d -> b h s d').detach().clone().requires_grad_(True)
    q_neox, k_neox = apply_rotary_pos_emb_neox(q_pt, k_pt, cos_neox, sin_neox, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(rotary._cos_cached, cos_neox[..., :rotary_dim // 2].to(dtype=dtype),
                          rtol=rtol, atol=atol)
    assert torch.allclose(rotary._sin_cached, sin_neox[..., :rotary_dim // 2].to(dtype=dtype),
                          rtol=rtol, atol=atol)
    assert torch.allclose(rearrange(q_neox, 'b h s d -> b s h d'), out[:, :, 0, :, :rotary_dim],
                          rtol=rtol, atol=atol)
    assert torch.allclose(rearrange(k_neox, 'b h s d -> b s h d'), out[:, :, 1, :, :rotary_dim],
                          rtol=rtol, atol=atol)
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])

    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_neox.backward(rearrange(g_og[:, :, 0, :, :rotary_dim], 'b s h d -> b h s d'))
    k_neox.backward(rearrange(g_og[:, :, 1, :, :rotary_dim], 'b s h d -> b h s d'))
    assert torch.allclose(rearrange(q_pt.grad, 'b h s d -> b s h d'),
                          qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(rearrange(k_pt.grad, 'b h s d -> b s h d'),
                          qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])


# GPT-J-style rotary embedding
@pytest.mark.parametrize('seqlen_offset', [0, 711])
@pytest.mark.parametrize('rotary_emb_fraction', [0.5, 1.0])
def test_rotary_interleaved(rotary_emb_fraction, seqlen_offset):
    device = 'cuda'
    dtype = torch.float16
    rtol, atol = (1e-3, 5e-3)
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen_total = 2048
    seqlen = seqlen_total - seqlen_offset
    nheads = 16
    headdim = 128
    rotary_dim = int(headdim * rotary_emb_fraction)
    qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
                      requires_grad=True)
    qkv_og = qkv.clone().detach()  # Our implementation modifies qkv inplace
    rotary = RotaryEmbedding(rotary_dim, interleaved=True, device=device)
    sincos_gptj = fixed_pos_embedding(qkv[..., :rotary_dim], seq_dim=1, seq_len=seqlen_total)
    sincos_gptj = tuple(x.to(dtype=dtype) for x in sincos_gptj)
    q_pt = qkv[:, :, 0, :, :rotary_dim].detach().clone().requires_grad_(True)
    k_pt = qkv[:, :, 1, :, :rotary_dim].detach().clone().requires_grad_(True)
    q_gptj = apply_rotary_pos_emb_gptj(q_pt, sincos_gptj, offset=seqlen_offset)
    k_gptj = apply_rotary_pos_emb_gptj(k_pt, sincos_gptj, offset=seqlen_offset)
    out = rotary(qkv, seqlen_offset=seqlen_offset)
    assert torch.allclose(rotary._cos_cached, sincos_gptj[1], rtol=rtol, atol=atol)
    assert torch.allclose(rotary._sin_cached, sincos_gptj[0], rtol=rtol, atol=atol)
    assert torch.allclose(q_gptj, out[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(k_gptj, out[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(out[:, :, 0:2, :, rotary_dim:], qkv_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(out[:, :, 2], qkv_og[:, :, 2])

    g = torch.randn_like(out)
    g_og = g.clone().detach()  # Our implementation modifies g inplace
    out.backward(g)
    q_gptj.backward(g_og[:, :, 0, :, :rotary_dim])
    k_gptj.backward(g_og[:, :, 1, :, :rotary_dim])
    assert torch.allclose(q_pt.grad, qkv.grad[:, :, 0, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.allclose(k_pt.grad, qkv.grad[:, :, 1, :, :rotary_dim], rtol=rtol, atol=atol)
    assert torch.equal(qkv.grad[:, :, 0:2, :, rotary_dim:], g_og[:, :, 0:2, :, rotary_dim:])
    assert torch.equal(qkv.grad[:, :, 2], g_og[:, :, 2])
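Note (ours, not part of the commit): these tests assume a CUDA device, the compiled rotary_emb extension, and a transformers version that still exports fixed_pos_embedding from modeling_gptj (newer releases have refactored it away, so an early-2023 transformers may be needed). They can be run with pytest tests/layers/test_rotary.py.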