norm / vllm · Commits · 320a622e

Unverified commit 320a622e, authored Sep 06, 2023 by Woosuk Kwon, committed by GitHub on Sep 06, 2023.

[BugFix] Implement RoPE for GPT-J (#941)

Parent: c9927c1a
Showing 5 changed files with 122 additions and 72 deletions (+122, -72):

  csrc/pos_encoding.cpp                      +6   -5
  csrc/pos_encoding_kernels.cu               +68  -45
  tests/kernels/test_pos_encoding.py         +38  -18
  vllm/model_executor/layers/attention.py    +5   -2
  vllm/model_executor/models/gpt_j.py        +5   -2
csrc/pos_encoding.cpp  (view file @ 320a622e)

 #include <torch/extension.h>

-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
   int head_size,
-  torch::Tensor& cos_sin_cache);
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
-    "rotary_embedding_neox",
-    &rotary_embedding_neox,
-    "Apply GPT-NeoX style rotary embedding to query and key");
+    "rotary_embedding",
+    &rotary_embedding,
+    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 }
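For context, the renamed binding now takes a trailing boolean that selects the RoPE style, and it rotates the query and key tensors in place. A minimal call sketch from Python (illustrative only; the shapes and values below are made up, the cache is random just to exercise the call, and running it requires a CUDA build of vLLM's extensions):

import torch
from vllm import pos_encoding_ops

# Hypothetical sizes for illustration only.
num_tokens, num_heads, head_size, rot_dim, max_position = 16, 8, 64, 64, 2048

positions = torch.randint(0, max_position, (num_tokens,), device="cuda")
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.half, device="cuda")
key = torch.randn(num_tokens, num_heads * head_size, dtype=torch.half, device="cuda")
cos_sin_cache = torch.randn(max_position, rot_dim, dtype=torch.half, device="cuda")

# Arguments are positional; the final False selects GPT-J (interleaved) style,
# True selects GPT-NeoX style. query and key are modified in place.
pos_encoding_ops.rotary_embedding(positions, query, key, head_size,
                                  cos_sin_cache, False)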
csrc/pos_encoding_kernels.cu  (view file @ 320a622e)

@@ -5,8 +5,38 @@
 namespace vllm {

-template<typename scalar_t>
-__global__ void rotary_embedding_neox_kernel(
+template<typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+  scalar_t* __restrict__ arr,
+  const scalar_t* __restrict__ cos_ptr,
+  const scalar_t* __restrict__ sin_ptr,
+  int rot_offset,
+  int embed_dim)
+{
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding.
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+    cos = __ldg(cos_ptr + x_index);
+    sin = __ldg(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = __ldg(cos_ptr + x_index / 2);
+    sin = __ldg(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+template<typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
   const int64_t* __restrict__ positions,  // [num_tokens]
   scalar_t* __restrict__ query,           // [num_tokens, num_heads, head_size]
   scalar_t* __restrict__ key,             // [num_tokens, num_kv_heads, head_size]

@@ -23,58 +53,37 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
   const int embed_dim = rot_dim / 2;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;

   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
     const int token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-
-    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
-
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
-
-    const scalar_t q_x = query[token_head + x_index];
-    const scalar_t q_y = query[token_head + y_index];
-    query[out_x] = q_x * cos - q_y * sin;
-    query[out_y] = q_y * cos + q_x * sin;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
   }

   const int nk = num_kv_heads * embed_dim;
   for (int i = threadIdx.x; i < nk; i += blockDim.x) {
     const int head_idx = i / embed_dim;
     const int token_head = token_idx * key_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-
-    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
-
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
-
-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
   }
 }

 } // namespace vllm

-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,      // [num_tokens]
   torch::Tensor& query,          // [num_tokens, num_heads * head_size]
   torch::Tensor& key,            // [num_tokens, num_kv_heads * head_size]
   int head_size,
-  torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
+  torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
+  bool is_neox)
 {
   int num_tokens = query.size(0);
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;

@@ -87,18 +96,32 @@ void rotary_embedding_neox(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
-    "rotary_embedding_neox",
+    "rotary_embedding",
     [&] {
-      vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        positions.data_ptr<int64_t>(),
-        query.data_ptr<scalar_t>(),
-        key.data_ptr<scalar_t>(),
-        cos_sin_cache.data_ptr<scalar_t>(),
-        rot_dim,
-        query_stride,
-        key_stride,
-        num_heads,
-        num_kv_heads,
-        head_size);
+      if (is_neox) {
+        vllm::rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        vllm::rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
     });
 }
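As a sanity check on the index math above, the following PyTorch sketch (illustrative only, not part of the commit; the function name apply_rotary_embedding_ref is mine) applies both styles to a single head vector the same way the device function does: NeoX pairs element i with element i + embed_dim, while GPT-J pairs adjacent elements 2i and 2i + 1 and reuses one cos/sin entry for each pair. With is_neox=False it should agree with the rotate_gptj reference added to the test below.

import torch

def apply_rotary_embedding_ref(arr: torch.Tensor, cos: torch.Tensor,
                               sin: torch.Tensor, is_neox: bool) -> torch.Tensor:
    # arr: [head_size]; cos/sin: [embed_dim], where embed_dim = rot_dim // 2.
    out = arr.clone()
    embed_dim = cos.shape[-1]
    for rot_offset in range(embed_dim):
        if is_neox:
            # GPT-NeoX style: rotate the first half against the second half.
            x_index, y_index = rot_offset, embed_dim + rot_offset
            c, s = cos[x_index], sin[x_index]
        else:
            # GPT-J style: rotate adjacent (even, odd) channel pairs.
            x_index, y_index = 2 * rot_offset, 2 * rot_offset + 1
            c, s = cos[x_index // 2], sin[x_index // 2]
        x, y = arr[x_index], arr[y_index]
        out[x_index] = x * c - y * s
        out[y_index] = y * c + x * s
    return out

# Quick check: a zero rotation angle (cos=1, sin=0) leaves the vector unchanged.
vec = torch.arange(8, dtype=torch.float32)
cos = torch.ones(4)
sin = torch.zeros(4)
assert torch.allclose(apply_rotary_embedding_ref(vec, cos, sin, is_neox=True), vec)
assert torch.allclose(apply_rotary_embedding_ref(vec, cos, sin, is_neox=False), vec)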
tests/kernels/test_pos_encoding.py  (view file @ 320a622e)

@@ -7,49 +7,64 @@ import torch.nn.functional as F
 from vllm import pos_encoding_ops

+IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
+NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
 SEEDS = [0]


-def rotate_half(x: torch.Tensor) -> torch.Tensor:
+def rotate_neox(x: torch.Tensor) -> torch.Tensor:
     x1 = x[..., :x.shape[-1] // 2]
     x2 = x[..., x.shape[-1] // 2:]
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(
+def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
+def apply_rope(
     q: torch.Tensor,
     k: torch.Tensor,
     cos: torch.Tensor,
     sin: torch.Tensor,
+    is_neox_style: bool,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
+    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
+    q_embed = (q * cos) + (rotate_fn(q) * sin)
+    k_embed = (k * cos) + (rotate_fn(k) * sin)
     return q_embed, k_embed


-class RefRotaryEmbeddingNeox(nn.Module):
-    """Reference implementation of the GPT-NeoX style rotary embedding."""
+class RefRotaryEmbedding(nn.Module):
+    """Reference implementation of rotary embedding."""

     def __init__(
         self,
         dim: int,
-        max_position_embeddings: int = 2048,
+        is_neox_style: bool,
+        max_position_embeddings: int = 8192,
         base: int = 10000,
     ) -> None:
         super().__init__()
         self.rotary_dim = dim
+        self.is_neox_style = is_neox_style
         self.max_position_embeddings = max_position_embeddings

         # Create cos and sin embeddings.
         inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
         t = torch.arange(max_position_embeddings).float()
         freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        emb = torch.cat((freqs, freqs), dim=-1)
+        if is_neox_style:
+            emb = torch.cat((freqs, freqs), dim=-1)
+        else:
+            emb = torch.repeat_interleave(freqs, 2, -1)
         cos = emb.cos().to(dtype=inv_freq.dtype)
         sin = emb.sin().to(dtype=inv_freq.dtype)
         self.register_buffer("cos_cached", cos, persistent=False)

@@ -61,7 +76,6 @@ class RefRotaryEmbeddingNeox(nn.Module):
         query: torch.Tensor,  # [num_tokens, num_heads, head_size]
         key: torch.Tensor,  # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         query_rot = query[..., :self.rotary_dim]
         query_pass = query[..., self.rotary_dim:]
         key_rot = key[..., :self.rotary_dim]

@@ -71,7 +85,9 @@ class RefRotaryEmbeddingNeox(nn.Module):
         key_rot = key_rot.transpose(0, 1)
         cos = F.embedding(positions, self.cos_cached)
         sin = F.embedding(positions, self.sin_cached)
-        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
+                                        self.is_neox_style)
         query_rot = query_rot.transpose(0, 1).contiguous()
         key_rot = key_rot.transpose(0, 1).contiguous()

@@ -82,6 +98,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
         return query, key


+@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)

@@ -89,7 +106,8 @@ class RefRotaryEmbeddingNeox(nn.Module):
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
-def test_rotary_embedding_neox(
+def test_rotary_embedding(
+    is_neox_style: bool,
     num_tokens: int,
     num_heads: int,
     head_size: int,

@@ -104,15 +122,15 @@ def test_rotary_embedding_neox(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    positions = torch.randint(0, max_position, (num_tokens, ), device='cuda')
+    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
     query = torch.randn(num_tokens,
                         num_heads * head_size,
                         dtype=dtype,
-                        device='cuda')
+                        device="cuda")
     key = torch.randn(num_tokens,
                       num_heads * head_size,
                       dtype=dtype,
-                      device='cuda')
+                      device="cuda")

     # Create the rotary embedding.
     inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))

@@ -126,20 +144,22 @@ def test_rotary_embedding_neox(
     # Run the kernel. The kernel is in-place, so we need to clone the inputs.
     out_query = query.clone()
     out_key = key.clone()
-    pos_encoding_ops.rotary_embedding_neox(
+    pos_encoding_ops.rotary_embedding(
         positions,
         out_query,
         out_key,
         head_size,
         cos_sin_cache,
+        is_neox_style,
     )

     # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbeddingNeox(
+    ref_rotary_embedding = RefRotaryEmbedding(
         dim=rotary_dim,
+        is_neox_style=is_neox_style,
         max_position_embeddings=max_position,
         base=base,
-    ).to(dtype=dtype, device='cuda')
+    ).to(dtype=dtype, device="cuda")
     ref_query, ref_key = ref_rotary_embedding(
         positions,
         query.view(num_tokens, num_heads, head_size),
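The hunks above stop short of how the test assembles cos_sin_cache before passing it to the op. Based on the kernel's pointer split (cos_ptr = cache_ptr, sin_ptr = cache_ptr + embed_dim), the op expects a [max_position, rot_dim] tensor whose first half along the last dimension holds cos values and whose second half holds sin values. A hedged sketch of one such construction, an assumption rather than the test's exact code (the helper name is mine):

import torch

def build_cos_sin_cache(rotary_dim: int, max_position: int,
                        base: int = 10000) -> torch.Tensor:
    # Inverse frequencies over the even rotary dimensions, as in the test above.
    inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))
    t = torch.arange(max_position).float()
    freqs = torch.einsum("i,j->ij", t, inv_freq.float())  # [max_position, rotary_dim // 2]
    # cos in the first half of the last dim, sin in the second half, matching
    # cos_ptr = cache_ptr and sin_ptr = cache_ptr + embed_dim in the kernel.
    return torch.cat((freqs.cos(), freqs.sin()), dim=-1)  # [max_position, rotary_dim]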
vllm/model_executor/layers/attention.py  (view file @ 320a622e)

@@ -242,7 +242,7 @@ class PagedAttention(nn.Module):
 class PagedAttentionWithRoPE(PagedAttention):
-    """PagedAttention with GPT-NeoX style rotary embedding."""
+    """PagedAttention with rotary embedding."""

     def __init__(
         self,

@@ -253,8 +253,10 @@ class PagedAttentionWithRoPE(PagedAttention):
         max_position: int = 8192,
         base: int = 10000,
         num_kv_heads: Optional[int] = None,
+        is_neox_style: bool = True,
     ) -> None:
         super().__init__(num_heads, head_size, scale, num_kv_heads)
+        self.is_neox_style = is_neox_style

         # Create the cos and sin cache.
         inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim))

@@ -303,12 +305,13 @@ class PagedAttentionWithRoPE(PagedAttention):
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        pos_encoding_ops.rotary_embedding_neox(
+        pos_encoding_ops.rotary_embedding(
             positions,
             query,
             key,
             self.head_size,
             self.cos_sin_cache,
+            self.is_neox_style,
         )
         return super().forward(
             query,
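A model that uses interleaved (GPT-J style) RoPE now constructs the layer with the extra flag. A minimal sketch mirroring the GPT-J model change below; the sizes are illustrative stand-ins (roughly GPT-J-6B-shaped), not values taken from this commit, and a real model derives them from its config:

from vllm.model_executor.layers.attention import PagedAttentionWithRoPE

# Illustrative values only.
num_heads, head_size, rotary_dim = 16, 256, 64
scaling = head_size**-0.5

attn = PagedAttentionWithRoPE(num_heads,
                              head_size,
                              scaling,
                              rotary_dim,
                              is_neox_style=False)  # GPT-J interleaved RoPE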
vllm/model_executor/models/gpt_j.py  (view file @ 320a622e)

@@ -67,8 +67,11 @@ class GPTJAttention(nn.Module):
         scaling = self.head_size**-0.5
         assert getattr(config, "rotary", True)
         assert config.rotary_dim % 2 == 0
-        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
-                                           scaling, config.rotary_dim)
+        self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                           self.head_size,
+                                           scaling,
+                                           config.rotary_dim,
+                                           is_neox_style=False)
         self.warmup = False

     def forward(
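Note on the fix: GPT-J applies rotary embeddings to interleaved even/odd channel pairs, whereas GPT-NeoX rotates the first half of the rotary dimensions against the second half. Before this commit vLLM only shipped the NeoX-style kernel, so GPT-J checkpoints were rotated with the wrong channel pairing; passing is_neox_style=False selects the interleaved variant implemented above.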