norm / vllm · Commits

Commit 9b294976 (Unverified)
Authored Dec 02, 2023 by Woosuk Kwon; committed by GitHub on Dec 02, 2023

Add PyTorch-native implementation of custom layers (#1898)

Parent: 5313c2cb

Showing 6 changed files with 150 additions and 185 deletions (+150 -185)
tests/kernels/test_activation.py                  +10  -17
tests/kernels/test_layernorm.py                   +24  -35
tests/kernels/test_pos_encoding.py                +24 -133
vllm/model_executor/layers/activation.py          +18   -0
vllm/model_executor/layers/layernorm.py           +20   -0
vllm/model_executor/layers/rotary_embedding.py    +54   -0
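The same pattern repeats in every file below: each custom layer keeps its CUDA-kernel forward() and gains a pure-PyTorch _forward() that computes the same result, and the kernel tests now use _forward() as the numerical reference. A minimal sketch of that shape, using a hypothetical ScaledReLU layer rather than vLLM code:

# Minimal sketch of the forward()/_forward() pairing this commit introduces.
# ScaledReLU is a hypothetical stand-in, not a vLLM layer; its "kernel" path
# simply reuses the native math so the example runs anywhere.
import torch
import torch.nn as nn


class ScaledReLU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 2.0 * torch.relu(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # In vLLM this is where the custom CUDA op would be called.
        return self._forward(x)


layer = ScaledReLU()
x = torch.randn(7, 16)
assert torch.allclose(layer(x), layer._forward(x))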
tests/kernels/test_activation.py  (view file @ 9b294976)

 import pytest
 import torch
-import torch.nn.functional as F
-from transformers.activations import get_activation

-from vllm._C import ops
+from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
...
@@ -11,11 +9,6 @@ D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]

-def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    x1, x2 = x.chunk(chunks=2, dim=1)
-    return F.silu(x1) * x2

 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
...
@@ -30,9 +23,9 @@ def test_silu_and_mul(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.silu_and_mul(out, x)
-    ref_out = ref_silu_and_mul(x)
+    layer = SiluAndMul()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
...
@@ -50,9 +43,9 @@ def test_gelu_new(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_new(out, x)
-    ref_out = get_activation("gelu_new")(x)
+    layer = NewGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
...
@@ -69,7 +62,7 @@ def test_gelu_fast(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_fast(out, x)
-    ref_out = get_activation("gelu_fast")(x)
+    layer = FastGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
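For readers skimming the diff above, the rewritten test boils down to the following usage sketch; it assumes a CUDA-enabled build of vLLM at this commit and mirrors test_silu_and_mul for a single parameter combination:

# Usage sketch mirroring the rewritten test (assumes CUDA and vLLM built from
# this commit): the custom kernel runs via forward(), the PyTorch-native path
# via _forward(), and the two are compared elementwise.
import torch
from vllm.model_executor.layers.activation import SiluAndMul

x = torch.randn(7, 2 * 512, dtype=torch.half, device="cuda")
layer = SiluAndMul()
out = layer(x)           # CUDA kernel path (ops.silu_and_mul under the hood)
ref = layer._forward(x)  # PyTorch-native implementation added by this commit
assert torch.allclose(out, ref, atol=1e-5, rtol=1e-5)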
tests/kernels/test_layernorm.py  (view file @ 9b294976)

 import pytest
 import torch
-import torch.nn as nn

-from vllm._C import ops
+from vllm.model_executor.layers.layernorm import RMSNorm

 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
+HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
+ADD_RESIDUAL = [False, True]
 SEEDS = [0]

-class RefRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        weight = torch.empty(hidden_size)
-        weight.normal_(mean=1.0, std=0.1)
-        self.weight = nn.Parameter(weight)
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)

 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
+    add_residual: bool,
     dtype: torch.dtype,
     seed: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    scale = float(hidden_size**-0.5)
-    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
-    x.uniform_(-scale, scale)
-    ref = RefRMSNorm(hidden_size).to(dtype).cuda()
-
-    out = torch.empty_like(x)
-    ops.rms_norm(
-        out,
-        x,
-        ref.weight.data,
-        ref.variance_epsilon,
-    )
-    ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
+    layer = RMSNorm(hidden_size).to(dtype).cuda()
+    layer.weight.data.normal_(mean=1.0, std=0.1)
+    scale = 1 / (2 * hidden_size)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x *= scale
+    residual = torch.randn_like(x) * scale if add_residual else None
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_out = layer._forward(x, residual)
+    out = layer(x, residual)
+    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
+    # numerical errors than other operators because they involve reductions.
+    # Therefore, we use a larger tolerance.
+    if add_residual:
+        assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
+        assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
+    else:
+        assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
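The NOTE about running the reference first matters because the kernel path mutates its inputs in place. A toy, CPU-only illustration of the hazard (not vLLM code):

# Toy illustration of why the reference must be computed before calling an
# in-place kernel: after the in-place op runs, the original input is gone.
import torch


def inplace_double(x: torch.Tensor) -> torch.Tensor:
    x.mul_(2.0)  # stands in for a custom in-place CUDA kernel
    return x


x = torch.ones(4)
ref = x * 2.0            # reference computed first, from the untouched input
out = inplace_double(x)  # now x itself has been overwritten
assert torch.allclose(out, ref)
# Computing the reference *after* the call would instead compare against 4.0.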
tests/kernels/test_pos_encoding.py  (view file @ 9b294976)

-from typing import Optional, Tuple
+from typing import Optional

 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F

-from vllm._C import ops
+from vllm.model_executor.layers.rotary_embedding import get_rope

 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
-NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
+NUM_HEADS = [7, 17]  # Arbitrary values for testing
+BATCH_SIZES = [1, 5]  # Arbitrary values for testing
+SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]

-def rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)

-def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)

-def apply_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    is_neox_style: bool,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
-    q_embed = (q * cos) + (rotate_fn(q) * sin)
-    k_embed = (k * cos) + (rotate_fn(k) * sin)
-    return q_embed, k_embed

-class RefRotaryEmbedding(nn.Module):
-    """Reference implementation of rotary embedding."""
-
-    def __init__(
-        self,
-        dim: int,
-        is_neox_style: bool,
-        max_position_embeddings: int = 8192,
-        base: int = 10000,
-    ) -> None:
-        super().__init__()
-        self.rotary_dim = dim
-        self.is_neox_style = is_neox_style
-        self.max_position_embeddings = max_position_embeddings
-
-        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        if is_neox_style:
-            emb = torch.cat((freqs, freqs), dim=-1)
-        else:
-            emb = torch.repeat_interleave(freqs, 2, -1)
-        cos = emb.cos().to(dtype=inv_freq.dtype)
-        sin = emb.sin().to(dtype=inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., :self.rotary_dim]
-        query_pass = query[..., self.rotary_dim:]
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-
-        query_rot = query_rot.transpose(0, 1)
-        key_rot = key_rot.transpose(0, 1)
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
-                                        self.is_neox_style)
-        query_rot = query_rot.transpose(0, 1).contiguous()
-        key_rot = key_rot.transpose(0, 1).contiguous()
-
-        query = torch.cat((query_rot, query_pass), dim=-1)
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Output query/key shape: [num_tokens, num_tokens, head_size]
-        return query, key

 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
...
@@ -108,7 +26,8 @@ class RefRotaryEmbedding(nn.Module):
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
-    num_tokens: int,
+    batch_size: int,
+    seq_len: int,
     num_heads: int,
     head_size: int,
     rotary_dim: Optional[int],
...
@@ -122,53 +41,25 @@ def test_rotary_embedding(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
-    query = torch.randn(num_tokens,
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+    rope = rope.to(dtype).cuda()
+
+    positions = torch.randint(0, max_position, (batch_size, seq_len), device="cuda")
+    query = torch.randn(batch_size,
+                        seq_len,
                         num_heads * head_size,
                         dtype=dtype,
                         device="cuda")
-    key = torch.randn(num_tokens,
-                      num_heads * head_size,
-                      dtype=dtype,
-                      device="cuda")
-
-    # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(
-        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
-    t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq)
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
-    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
-    out_query = query.clone()
-    out_key = key.clone()
-    ops.rotary_embedding(
-        positions,
-        out_query,
-        out_key,
-        head_size,
-        cos_sin_cache,
-        is_neox_style,
-    )
-
-    # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbedding(
-        dim=rotary_dim,
-        is_neox_style=is_neox_style,
-        max_position_embeddings=max_position,
-        base=base,
-    ).to(dtype=dtype, device="cuda")
-    ref_query, ref_key = ref_rotary_embedding(
-        positions,
-        query.view(num_tokens, num_heads, head_size),
-        key.view(num_tokens, num_heads, head_size),
-    )
-    ref_query = ref_query.view(num_tokens, num_heads * head_size)
-    ref_key = ref_key.view(num_tokens, num_heads * head_size)
+    key = torch.randn_like(query)
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope._forward(positions, query, key)
+    out_query, out_key = rope.forward(positions, query, key)

     # Compare the results.
     assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
     assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
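A condensed version of the new test flow, reduced to one configuration; it assumes a CUDA-enabled build of vLLM at this commit and passes the get_rope arguments in the positional order used by the test above:

# Condensed test_rotary_embedding for a single configuration
# (assumes CUDA and vLLM built from this commit).
import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

head_size, rotary_dim, max_position, base = 128, 128, 8192, 10000
rope = get_rope(head_size, rotary_dim, max_position, base, True)  # neox style
rope = rope.to(torch.half).cuda()

positions = torch.randint(0, max_position, (1, 11), device="cuda")
query = torch.randn(1, 11, 4 * head_size, dtype=torch.half, device="cuda")
key = torch.randn_like(query)

# Reference (PyTorch-native) first, because the kernel path is in-place.
ref_query, ref_key = rope._forward(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)
assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)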
vllm/model_executor/layers/activation.py  (view file @ 9b294976)

 """Custom activation functions."""
+import math
 from typing import Optional

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

 from vllm._C import ops
 from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -22,6 +24,11 @@ class SiluAndMul(nn.Module):
         return: (batch_size, seq_len, d) or (num_tokens, d)
     """

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return F.silu(x[..., :d]) * x[..., d:]
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = (x.shape[:-1] + (d, ))
...
@@ -32,6 +39,12 @@ class SiluAndMul(nn.Module):
 class NewGELU(nn.Module):

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        c = math.sqrt(2.0 / math.pi)
+        return 0.5 * x * (1.0 + torch.tanh(c *
+                                           (x + 0.044715 * torch.pow(x, 3.0))))
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_new(out, x)
...
@@ -40,6 +53,11 @@ class NewGELU(nn.Module):
 class FastGELU(nn.Module):

+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
+                                           (1.0 + 0.044715 * x * x)))
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty_like(x)
         ops.gelu_fast(out, x)
...
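As a sanity check on the constants above (not part of the commit): the formula in NewGELU._forward is the standard tanh approximation of GELU, and the 0.7978845608 literal in FastGELU is sqrt(2/pi) folded into the same expression. Recent PyTorch versions expose the same approximation directly, which makes it easy to cross-check:

# Cross-check (not part of the commit): the NewGELU._forward formula matches
# PyTorch's built-in tanh approximation of GELU (requires PyTorch >= 1.12).
import math
import torch
import torch.nn.functional as F

x = torch.randn(7, 512)
c = math.sqrt(2.0 / math.pi)  # ~0.7978845608, the constant used by FastGELU
new_gelu = 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
assert torch.allclose(new_gelu, F.gelu(x, approximate="tanh"), atol=1e-5)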
vllm/model_executor/layers/layernorm.py  (view file @ 9b294976)

...
@@ -23,6 +23,26 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

+    def _forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
     def forward(
         self,
         x: torch.Tensor,
...
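For reference, the return convention of the new _forward (and of the fused kernel path it mirrors) is: a single tensor when no residual is passed, and an (output, updated residual) pair when a residual stream is supplied. A usage sketch, assuming a working vLLM build at this commit (_forward itself is pure PyTorch and runs on CPU):

# Return-convention sketch for the new RMSNorm._forward.
import torch
from vllm.model_executor.layers.layernorm import RMSNorm

hidden_size = 4096
layer = RMSNorm(hidden_size)
x = torch.randn(7, hidden_size)
residual = torch.randn_like(x)

out = layer._forward(x)                           # plain RMSNorm: one tensor
out2, new_residual = layer._forward(x, residual)  # fused add + norm: a pair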
vllm/model_executor/layers/rotary_embedding.py  (view file @ 9b294976)

...
@@ -30,6 +30,19 @@ import torch.nn as nn
 from vllm._C import ops

+def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
 class RotaryEmbedding(nn.Module):
     """Original rotary positional embedding."""
...
@@ -81,6 +94,47 @@ class RotaryEmbedding(nn.Module):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

+    def _forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        query = query.flatten(-2)
+        key = key.flatten(-2)
+        return query, key
+
     def forward(
         self,
         positions: torch.Tensor,
...
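A toy illustration (not part of the commit) of the two rotation conventions the new helpers implement: GPT-NeoX style rotates the two halves of the vector, while GPT-J style rotates interleaved even/odd pairs.

# Toy comparison of the two rotation helpers added above; copied from the
# diff and applied to a tiny vector so the difference is visible.
import torch


def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)


v = torch.tensor([1., 2., 3., 4.])
print(_rotate_neox(v))  # tensor([-3., -4.,  1.,  2.])
print(_rotate_gptj(v))  # tensor([-2.,  1., -4.,  3.])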