Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98c89e16
Unverified
Commit
98c89e16
authored
May 07, 2025
by
Yong Hoon Shin
Committed by
GitHub
May 07, 2025
Browse files
Make key optional for rotary embedding (#17566)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
324a3119
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
221 additions
and
151 deletions
+221
-151
csrc/cpu/pos_encoding.cpp
csrc/cpu/pos_encoding.cpp
+24
-15
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+1
-1
csrc/ops.h
csrc/ops.h
+4
-4
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+55
-43
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+2
-2
tests/kernels/core/test_pos_encoding.py
tests/kernels/core/test_pos_encoding.py
+34
-15
tests/kernels/core/test_rotary_embedding.py
tests/kernels/core/test_rotary_embedding.py
+4
-3
tests/neuron/1_core/test_rotary_embedding.py
tests/neuron/1_core/test_rotary_embedding.py
+21
-12
vllm/_custom_ops.py
vllm/_custom_ops.py
+8
-6
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+68
-50
No files found.
csrc/cpu/pos_encoding.cpp
View file @
98c89e16
...
@@ -9,7 +9,8 @@ void rotary_embedding_impl(
...
@@ -9,7 +9,8 @@ void rotary_embedding_impl(
scalar_t
*
__restrict__
query
,
/// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
query
,
/// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
/// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
...
@@ -85,10 +86,13 @@ void rotary_embedding_impl(
...
@@ -85,10 +86,13 @@ void rotary_embedding_impl(
compute_loop
(
token_head
,
cache_ptr
,
query
);
compute_loop
(
token_head
,
cache_ptr
,
query
);
}
}
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
if
(
key
!=
nullptr
)
{
const
int
head_idx
=
i
;
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
head_idx
=
i
;
compute_loop
(
token_head
,
cache_ptr
,
key
);
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
compute_loop
(
token_head
,
cache_ptr
,
key
);
}
}
}
}
}
}
}
...
@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
...
@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
scalar_t
*
__restrict__
query
,
/// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
query
,
/// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
/// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
...
@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
...
@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
}
}
}
}
if
(
key
==
nullptr
)
{
return
;
}
#pragma omp parallel for collapse(2)
#pragma omp parallel for collapse(2)
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
...
@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
...
@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
};
// namespace
};
// namespace
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
)
{
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
)
{
int
num_tokens
=
positions
.
numel
();
int
num_tokens
=
positions
.
numel
();
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key
->
size
(
-
1
)
/
head_size
:
num_heads
;
int64_t
key_stride
=
key
.
stride
(
-
2
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
-
2
)
:
0
;
int64_t
query_stride
=
query
.
stride
(
-
2
);
int64_t
query_stride
=
query
.
stride
(
-
2
);
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
...
@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
...
@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
if
(
is_neox
)
{
if
(
is_neox
)
{
rotary_embedding_impl
(
rotary_embedding_impl
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
head_size
,
num_tokens
);
key_stride
,
num_heads
,
num_kv_heads
,
head_size
,
num_tokens
);
}
else
{
}
else
{
rotary_embedding_gptj_impl
(
rotary_embedding_gptj_impl
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
head_size
,
num_tokens
);
key_stride
,
num_heads
,
num_kv_heads
,
head_size
,
num_tokens
);
}
}
CPU_KERNEL_GUARD_OUT
(
rotary_embedding_impl
)
CPU_KERNEL_GUARD_OUT
(
rotary_embedding_impl
)
...
...
csrc/cpu/torch_bindings.cpp
View file @
98c89e16
...
@@ -117,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -117,7 +117,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
ops
.
def
(
"rotary_embedding(Tensor positions, Tensor! query,"
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCPU
,
&
rotary_embedding
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCPU
,
&
rotary_embedding
);
...
...
csrc/ops.h
View file @
98c89e16
...
@@ -86,13 +86,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
...
@@ -86,13 +86,13 @@ void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
std
::
optional
<
torch
::
Tensor
>
residual
);
std
::
optional
<
torch
::
Tensor
>
residual
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
int64_t
rot_dim
,
bool
is_neox
,
int64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
...
...
csrc/pos_encoding_kernels.cu
View file @
98c89e16
...
@@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding(
...
@@ -38,7 +38,8 @@ inline __device__ void apply_rotary_embedding(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
// head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
...
@@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding(
...
@@ -57,13 +58,15 @@ inline __device__ void apply_rotary_embedding(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
const
int
nk
=
num_kv_heads
*
embed_dim
;
if
(
key
!=
nullptr
)
{
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int
nk
=
num_kv_heads
*
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
for
(
int
i
=
threadIdx
.
x
;
i
<
nk
;
i
+=
blockDim
.
x
)
{
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
head_idx
=
i
/
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
const
int
rot_offset
=
i
%
embed_dim
;
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
}
}
}
...
@@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel(
...
@@ -74,7 +77,8 @@ __global__ void rotary_embedding_kernel(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
...
@@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel(
...
@@ -98,7 +102,8 @@ __global__ void batched_rotary_embedding_kernel(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
...
@@ -127,10 +132,12 @@ void rotary_embedding(
...
@@ -127,10 +132,12 @@ void rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
std
::
optional
<
torch
::
Tensor
>
key
,
// [num_tokens, num_kv_heads * head_size] or
// null or
// [batch_size, seq_len, num_heads, head_size] or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_heads, head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
bool
is_neox
)
{
...
@@ -138,40 +145,40 @@ void rotary_embedding(
...
@@ -138,40 +145,40 @@ void rotary_embedding(
int64_t
num_tokens
=
positions
.
numel
();
int64_t
num_tokens
=
positions
.
numel
();
int
positions_ndim
=
positions
.
dim
();
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
,
"query, key and positions must have the same number of tokens"
);
"query, key and positions must have the same number of tokens"
);
}
}
if
(
positions_ndim
==
2
)
{
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)
)
,
"query, key and positions must have the same batch_size and seq_len"
);
"query, key and positions must have the same batch_size and seq_len"
);
}
}
// Make sure head_size is valid for query and key
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
// hidden_size = num_heads * head_size
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have consistent number of heads
// Make sure query and key have consistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
@@ -181,15 +188,16 @@ void rotary_embedding(
...
@@ -181,15 +188,16 @@ void rotary_embedding(
if
(
is_neox
)
{
if
(
is_neox
)
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
head_size
);
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
}
});
});
}
}
...
@@ -204,10 +212,12 @@ void batched_rotary_embedding(
...
@@ -204,10 +212,12 @@ void batched_rotary_embedding(
// [num_tokens, num_heads * head_size] or
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
std
::
optional
<
torch
::
Tensor
>
// [num_tokens, num_kv_heads * head_size] or
key
,
// null or
// [batch_size, seq_len, num_heads, head_size] or
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_heads, head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int64_t
rot_dim
,
bool
is_neox
,
int64_t
rot_dim
,
...
@@ -221,38 +231,38 @@ void batched_rotary_embedding(
...
@@ -221,38 +231,38 @@ void batched_rotary_embedding(
"cos_sin_cache_offsets"
);
"cos_sin_cache_offsets"
);
int
positions_ndim
=
positions
.
dim
();
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key.
// Make sure num_tokens dim is consistent across positions, query, and key
TORCH_CHECK
(
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
,
"query, key and positions must have the same number of tokens"
);
"query, key and positions must have the same number of tokens"
);
}
}
if
(
positions_ndim
==
2
)
{
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
(
!
key
.
has_value
()
||
key
->
size
(
0
)
==
positions
.
size
(
0
)
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
(
!
key
.
has_value
()
||
key
->
size
(
1
)
==
positions
.
size
(
1
)
)
,
"query, key and positions must have the same batch_size and seq_len"
);
"query, key and positions must have the same batch_size and seq_len"
);
}
}
// Make sure head_size is valid for query and key
// Make sure head_size is valid for query and key
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
has_value
()
?
key
->
numel
()
/
num_tokens
:
0
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have consistent number of heads
// Make sure query and key have consistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
int
num_kv_heads
=
key
.
has_value
()
?
key_hidden_size
/
head_size
:
num_heads
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
seq_dim_idx
)
:
0
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
@@ -263,14 +273,16 @@ void batched_rotary_embedding(
...
@@ -263,14 +273,16 @@ void batched_rotary_embedding(
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
}
...
...
csrc/torch_bindings.cpp
View file @
98c89e16
...
@@ -176,7 +176,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -176,7 +176,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops
.
def
(
ops
.
def
(
"rotary_embedding(Tensor positions, Tensor! query,"
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
" Tensor cos_sin_cache, bool is_neox) -> ()"
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
ops
.
impl
(
"rotary_embedding"
,
torch
::
kCUDA
,
&
rotary_embedding
);
...
@@ -184,7 +184,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -184,7 +184,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// (supports multiple loras).
// (supports multiple loras).
ops
.
def
(
ops
.
def
(
"batched_rotary_embedding(Tensor positions, Tensor! query,"
"batched_rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor!? key, int head_size,"
" Tensor cos_sin_cache, bool is_neox,"
" Tensor cos_sin_cache, bool is_neox,"
" int rot_dim,"
" int rot_dim,"
" Tensor cos_sin_cache_offsets) -> ()"
);
" Tensor cos_sin_cache_offsets) -> ()"
);
...
...
tests/kernels/core/test_pos_encoding.py
View file @
98c89e16
...
@@ -21,6 +21,7 @@ SEEDS = [0]
...
@@ -21,6 +21,7 @@ SEEDS = [0]
CUDA_DEVICES
=
[
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
]
USE_KEY
=
[
True
,
False
]
def
_get_flat_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
def
_get_flat_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
...
@@ -46,6 +47,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
...
@@ -46,6 +47,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_key"
,
USE_KEY
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_rotary_embedding
(
def
test_rotary_embedding
(
is_neox_style
:
bool
,
is_neox_style
:
bool
,
...
@@ -58,6 +60,7 @@ def test_rotary_embedding(
...
@@ -58,6 +60,7 @@ def test_rotary_embedding(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
int
=
10000
,
)
->
None
:
)
->
None
:
...
@@ -74,7 +77,7 @@ def test_rotary_embedding(
...
@@ -74,7 +77,7 @@ def test_rotary_embedding(
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
...
@@ -85,10 +88,14 @@ def test_rotary_embedding(
...
@@ -85,10 +88,14 @@ def test_rotary_embedding(
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
if
use_key
:
ref_key
,
torch
.
testing
.
assert_close
(
out_key
,
atol
=
get_default_atol
(
out_key
),
ref_key
,
rtol
=
get_default_rtol
(
out_key
))
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
assert
ref_key
is
None
and
out_key
is
None
,
\
"expected returned key to be None"
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -101,6 +108,7 @@ def test_rotary_embedding(
...
@@ -101,6 +108,7 @@ def test_rotary_embedding(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_key"
,
USE_KEY
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding
(
def
test_batched_rotary_embedding
(
is_neox_style
:
bool
,
is_neox_style
:
bool
,
...
@@ -113,6 +121,7 @@ def test_batched_rotary_embedding(
...
@@ -113,6 +121,7 @@ def test_batched_rotary_embedding(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
int
=
10000
,
)
->
None
:
)
->
None
:
...
@@ -129,7 +138,7 @@ def test_batched_rotary_embedding(
...
@@ -129,7 +138,7 @@ def test_batched_rotary_embedding(
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
...
@@ -145,10 +154,14 @@ def test_batched_rotary_embedding(
...
@@ -145,10 +154,14 @@ def test_batched_rotary_embedding(
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
if
use_key
:
ref_key
,
torch
.
testing
.
assert_close
(
out_key
,
atol
=
get_default_atol
(
out_key
),
ref_key
,
rtol
=
get_default_rtol
(
out_key
))
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
assert
ref_key
is
None
and
out_key
is
None
,
\
"expected returned key to be None"
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
...
@@ -160,6 +173,7 @@ def test_batched_rotary_embedding(
...
@@ -160,6 +173,7 @@ def test_batched_rotary_embedding(
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"use_key"
,
USE_KEY
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding_multi_lora
(
def
test_batched_rotary_embedding_multi_lora
(
is_neox_style
:
bool
,
is_neox_style
:
bool
,
...
@@ -171,6 +185,7 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -171,6 +185,7 @@ def test_batched_rotary_embedding_multi_lora(
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
,
seed
:
int
,
device
:
str
,
device
:
str
,
use_key
:
bool
,
max_position
:
int
=
8192
,
max_position
:
int
=
8192
,
base
:
int
=
10000
,
base
:
int
=
10000
,
)
->
None
:
)
->
None
:
...
@@ -190,7 +205,7 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -190,7 +205,7 @@ def test_batched_rotary_embedding_multi_lora(
seq_len
,
seq_len
,
num_heads
*
head_size
,
num_heads
*
head_size
,
dtype
=
dtype
)
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
offset_map
=
torch
.
tensor
(
offset_map
=
torch
.
tensor
(
list
(
list
(
...
@@ -214,10 +229,14 @@ def test_batched_rotary_embedding_multi_lora(
...
@@ -214,10 +229,14 @@ def test_batched_rotary_embedding_multi_lora(
ref_query
,
ref_query
,
atol
=
get_default_atol
(
out_query
),
atol
=
get_default_atol
(
out_query
),
rtol
=
get_default_rtol
(
out_query
))
rtol
=
get_default_rtol
(
out_query
))
torch
.
testing
.
assert_close
(
out_key
,
if
use_key
:
ref_key
,
torch
.
testing
.
assert_close
(
out_key
,
atol
=
get_default_atol
(
out_key
),
ref_key
,
rtol
=
get_default_rtol
(
out_key
))
atol
=
get_default_atol
(
out_key
),
rtol
=
get_default_rtol
(
out_key
))
else
:
assert
ref_key
is
None
and
out_key
is
None
,
\
"expected returned key to be None"
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
tests/kernels/core/test_rotary_embedding.py
View file @
98c89e16
...
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
...
@@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
def
rotary_embedding_opcheck
(
rot
,
def
rotary_embedding_opcheck
(
rot
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
):
offsets
:
Optional
[
torch
.
Tensor
]
=
None
):
cos_sin_cache
=
rot
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
cos_sin_cache
=
rot
.
cos_sin_cache
.
to
(
query
.
device
,
dtype
=
query
.
dtype
)
...
@@ -37,9 +37,10 @@ def rotary_embedding_opcheck(rot,
...
@@ -37,9 +37,10 @@ def rotary_embedding_opcheck(rot,
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"rotary_dim"
,
[
32
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"head_size"
,
[
32
,
108
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
11
,
1024
])
@
pytest
.
mark
.
parametrize
(
"use_key"
,
[
True
,
False
])
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
def
test_rotary_embedding_opcheck
(
dist_init
,
device
,
max_position
,
is_neox_style
,
rotary_dim
,
head_size
,
is_neox_style
,
rotary_dim
,
head_size
,
seq_len
):
seq_len
,
use_key
):
batch_size
=
1
batch_size
=
1
base
=
10000
base
=
10000
num_heads
=
7
num_heads
=
7
...
@@ -54,7 +55,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
...
@@ -54,7 +55,7 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position,
num_heads
*
head_size
,
num_heads
*
head_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
device
)
device
=
device
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
rotary_embedding_opcheck
(
rot
,
positions
,
query
,
key
)
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
offsets
=
torch
.
zeros
(
batch_size
*
seq_len
,
...
...
tests/neuron/1_core/test_rotary_embedding.py
View file @
98c89e16
...
@@ -11,14 +11,16 @@ from vllm.platforms import current_platform
...
@@ -11,14 +11,16 @@ from vllm.platforms import current_platform
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"max_position,is_neox_style,rotary_dim,head_size,seq_len"
,
[
"max_position,is_neox_style,rotary_dim,head_size,seq_len,use_key"
,
[
(
16
,
False
,
32
,
32
,
1024
),
(
16
,
False
,
32
,
32
,
1024
,
True
),
(
16
,
False
,
32
,
128
,
1024
),
(
16
,
False
,
32
,
128
,
1024
,
True
),
(
16
,
True
,
32
,
32
,
1024
),
(
16
,
True
,
32
,
32
,
1024
,
True
),
(
16
,
True
,
32
,
128
,
1024
),
(
16
,
True
,
32
,
128
,
1024
,
True
),
(
16
,
False
,
32
,
128
,
1024
,
False
),
(
16
,
True
,
32
,
128
,
1024
,
False
),
])
])
def
test_rotary_embedding_opcheck
(
max_position
,
is_neox_style
,
rotary_dim
,
def
test_rotary_embedding_opcheck
(
max_position
,
is_neox_style
,
rotary_dim
,
head_size
,
seq_len
):
head_size
,
seq_len
,
use_key
):
import
torch_xla.core.xla_model
as
xm
import
torch_xla.core.xla_model
as
xm
device
=
xm
.
xla_device
()
device
=
xm
.
xla_device
()
...
@@ -40,19 +42,26 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
...
@@ -40,19 +42,26 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
num_heads
*
head_size
,
num_heads
*
head_size
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
device
=
"cpu"
)
device
=
"cpu"
)
key
=
torch
.
randn_like
(
query
)
key
=
torch
.
randn_like
(
query
)
if
use_key
else
None
assert
positions
.
is_cpu
,
\
assert
positions
.
is_cpu
,
\
"reference input tensor is expected to be CPU tensor."
"reference input tensor is expected to be CPU tensor."
ref_query
,
ref_key
=
rot
.
to
(
device
=
"cpu"
).
forward_native
(
ref_query
,
ref_key
=
rot
.
to
(
device
=
"cpu"
).
forward_native
(
positions
,
query
,
key
)
positions
,
query
,
key
)
out_query
,
out_key
=
rot
.
to
(
device
=
device
).
forward_neuron
(
out_query
,
out_key
=
rot
.
to
(
device
=
device
).
forward_neuron
(
positions
.
to
(
device
=
device
),
query
.
to
(
device
=
device
),
positions
.
to
(
device
=
device
),
query
.
to
(
device
=
device
),
key
.
to
(
device
=
device
))
key
.
to
(
device
=
device
)
if
key
is
not
None
else
None
)
assert
out_query
.
is_xla
and
out_key
.
is_xla
,
\
if
use_key
:
"output tensor is expected to be XLA tensor"
assert
out_query
.
is_xla
and
out_key
.
is_xla
,
\
"output tensor is expected to be XLA tensor"
torch
.
testing
.
assert_close
(
out_key
.
cpu
(),
ref_key
,
atol
=
1e-2
,
rtol
=
1e-2
)
else
:
assert
out_key
is
None
,
"expected returned key to be None"
assert
out_query
.
is_xla
,
\
"output tensor is expected to be XLA tensor"
torch
.
testing
.
assert_close
(
out_query
.
cpu
(),
torch
.
testing
.
assert_close
(
out_query
.
cpu
(),
ref_query
,
ref_query
,
atol
=
1e-2
,
atol
=
1e-2
,
rtol
=
1e-2
)
rtol
=
1e-2
)
torch
.
testing
.
assert_close
(
out_key
.
cpu
(),
ref_key
,
atol
=
1e-2
,
rtol
=
1e-2
)
vllm/_custom_ops.py
View file @
98c89e16
...
@@ -153,34 +153,36 @@ def merge_attn_states(output: torch.Tensor,
...
@@ -153,34 +153,36 @@ def merge_attn_states(output: torch.Tensor,
def
rotary_embedding
(
def
rotary_embedding
(
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
,
head_size
:
int
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
is_neox
:
bool
,
)
->
None
:
)
->
None
:
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous
=
query
.
contiguous
()
query_contiguous
=
query
.
contiguous
()
key_contiguous
=
key
.
contiguous
()
key_contiguous
=
key
.
contiguous
()
if
key
is
not
None
else
None
torch
.
ops
.
_C
.
rotary_embedding
(
positions
,
query_contiguous
,
key_contiguous
,
torch
.
ops
.
_C
.
rotary_embedding
(
positions
,
query_contiguous
,
key_contiguous
,
head_size
,
cos_sin_cache
,
is_neox
)
head_size
,
cos_sin_cache
,
is_neox
)
query
.
copy_
(
query_contiguous
)
query
.
copy_
(
query_contiguous
)
key
.
copy_
(
key_contiguous
)
if
key
is
not
None
:
key
.
copy_
(
key_contiguous
)
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
def
batched_rotary_embedding
(
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
head_size
:
int
,
key
:
Optional
[
torch
.
Tensor
]
,
head_size
:
int
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
cos_sin_cache
:
torch
.
Tensor
,
is_neox
:
bool
,
rot_dim
:
int
,
rot_dim
:
int
,
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
cos_sin_cache_offsets
:
torch
.
Tensor
)
->
None
:
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous
=
query
.
contiguous
()
query_contiguous
=
query
.
contiguous
()
key_contiguous
=
key
.
contiguous
()
key_contiguous
=
key
.
contiguous
()
if
key
is
not
None
else
None
torch
.
ops
.
_C
.
batched_rotary_embedding
(
positions
,
query_contiguous
,
torch
.
ops
.
_C
.
batched_rotary_embedding
(
positions
,
query_contiguous
,
key_contiguous
,
head_size
,
key_contiguous
,
head_size
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache
,
is_neox
,
rot_dim
,
cos_sin_cache_offsets
)
cos_sin_cache_offsets
)
query
.
copy_
(
query_contiguous
)
query
.
copy_
(
query_contiguous
)
key
.
copy_
(
key_contiguous
)
if
key
is
not
None
:
key
.
copy_
(
key_contiguous
)
# layer norm ops
# layer norm ops
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
98c89e16
...
@@ -138,9 +138,9 @@ class RotaryEmbedding(CustomOp):
...
@@ -138,9 +138,9 @@ class RotaryEmbedding(CustomOp):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]
:
"""A PyTorch-native implementation of forward()."""
"""A PyTorch-native implementation of forward()."""
if
offsets
is
not
None
:
if
offsets
is
not
None
:
positions
=
positions
+
offsets
positions
=
positions
+
offsets
...
@@ -157,22 +157,24 @@ class RotaryEmbedding(CustomOp):
...
@@ -157,22 +157,24 @@ class RotaryEmbedding(CustomOp):
self
.
is_neox_style
)
self
.
is_neox_style
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
# key may be None in some cases, e.g. cross-layer KV sharing
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
if
key
is
not
None
:
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_shape
=
key
.
shape
key_pass
=
key
[...,
self
.
rotary_dim
:]
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_rot
=
_apply_rotary_emb_torch
(
key_rot
,
cos
,
sin
,
key_rot
=
key
[...,
:
self
.
rotary_dim
]
self
.
is_neox_style
)
key_pass
=
key
[...,
self
.
rotary_dim
:]
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
key_rot
=
_apply_rotary_emb_torch
(
key_rot
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
return
query
,
key
def
forward_cuda
(
def
forward_cuda
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]
:
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
...
@@ -198,32 +200,39 @@ class RotaryEmbedding(CustomOp):
...
@@ -198,32 +200,39 @@ class RotaryEmbedding(CustomOp):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
as
ops
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
self
.
cos_sin_cache
=
self
.
cos_sin_cache
.
to
(
positions
.
device
,
dtype
=
query
.
dtype
)
dtype
=
query
.
dtype
)
# ops.rotary_embedding()/batched_rotary_embedding()
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
# are in-place operations that update the query and key tensors.
if
offsets
is
not
None
:
if
key
is
None
:
ops
.
batched_rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
# XPU kernel doesn't support key=None so fall back to native impl
self
.
cos_sin_cache
,
# TODO(sarckk): add support for optional key in
self
.
is_neox_style
,
self
.
rotary_dim
,
# ipex.llm.functional.rotary_embedding_batched
offsets
)
return
self
.
forward_native
(
positions
,
query
,
key
,
offsets
)
else
:
else
:
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
if
offsets
is
not
None
:
self
.
cos_sin_cache
,
self
.
is_neox_style
)
ops
.
batched_rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
,
self
.
rotary_dim
,
offsets
)
else
:
ops
.
rotary_embedding
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def
forward_hpu
(
def
forward_hpu
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]
:
from
habana_frameworks.torch.hpex.kernels
import
(
from
habana_frameworks.torch.hpex.kernels
import
(
RotaryPosEmbeddingMode
,
apply_rotary_pos_emb
)
RotaryPosEmbeddingMode
,
apply_rotary_pos_emb
)
if
offsets
is
not
None
:
if
offsets
is
not
None
:
...
@@ -265,21 +274,23 @@ class RotaryEmbedding(CustomOp):
...
@@ -265,21 +274,23 @@ class RotaryEmbedding(CustomOp):
rope_mode
)
rope_mode
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
key_shape
=
key
.
shape
if
key
is
not
None
:
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_shape
=
key
.
shape
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_pass
=
key
[...,
self
.
rotary_dim
:]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_rot
=
apply_rotary_pos_emb
(
key_rot
,
cos
,
sin
,
None
,
0
,
rope_mode
)
key_pass
=
key
[...,
self
.
rotary_dim
:]
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
key_rot
=
apply_rotary_pos_emb
(
key_rot
,
cos
,
sin
,
None
,
0
,
rope_mode
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
return
query
,
key
def
forward_neuron
(
def
forward_neuron
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]
:
def
_apply_rotary_emb_neuron
(
def
_apply_rotary_emb_neuron
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -319,14 +330,16 @@ class RotaryEmbedding(CustomOp):
...
@@ -319,14 +330,16 @@ class RotaryEmbedding(CustomOp):
query_shape
=
query
.
shape
query_shape
=
query
.
shape
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
query
=
query
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_shape
=
key
.
shape
if
key
is
not
None
:
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
key_shape
=
key
.
shape
key
=
key
.
view
(
num_tokens
,
-
1
,
self
.
head_size
)
if
self
.
rotary_dim
==
self
.
head_size
:
if
self
.
rotary_dim
==
self
.
head_size
:
query
=
_apply_rotary_emb
(
query
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
_apply_rotary_emb
(
query
,
cos
,
sin
,
self
.
is_neox_style
)
query
=
query
.
reshape
(
query_shape
)
query
=
query
.
reshape
(
query_shape
)
key
=
_apply_rotary_emb
(
key
,
cos
,
sin
,
self
.
is_neox_style
)
if
key
is
not
None
:
key
=
key
.
reshape
(
key_shape
)
key
=
_apply_rotary_emb
(
key
,
cos
,
sin
,
self
.
is_neox_style
)
key
=
key
.
reshape
(
key_shape
)
else
:
else
:
head_size
=
query
.
shape
[
-
1
]
head_size
=
query
.
shape
[
-
1
]
query_reshaped
=
query
.
view
(
-
1
,
head_size
)
query_reshaped
=
query
.
view
(
-
1
,
head_size
)
...
@@ -339,14 +352,15 @@ class RotaryEmbedding(CustomOp):
...
@@ -339,14 +352,15 @@ class RotaryEmbedding(CustomOp):
query
=
torch
.
cat
((
query_rot
,
query_pass
),
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
).
reshape
(
query_shape
)
dim
=-
1
).
reshape
(
query_shape
)
key_reshaped
=
key
.
view
(
-
1
,
head_size
)
if
key
is
not
None
:
key_pass
=
key_reshaped
[:,
self
.
rotary_dim
:].
view
(
key_reshaped
=
key
.
view
(
-
1
,
head_size
)
*
key
.
shape
[:
-
1
],
head_size
-
self
.
rotary_dim
)
key_pass
=
key_reshaped
[:,
self
.
rotary_dim
:].
view
(
key_rot
=
key_reshaped
[:,
:
self
.
rotary_dim
].
view
(
*
key
.
shape
[:
-
1
],
head_size
-
self
.
rotary_dim
)
*
key
.
shape
[:
-
1
],
self
.
rotary_dim
)
key_rot
=
key_reshaped
[:,
:
self
.
rotary_dim
].
view
(
key_rot
=
_apply_rotary_emb_neuron
(
key_rot
,
cos
,
sin
,
*
key
.
shape
[:
-
1
],
self
.
rotary_dim
)
self
.
is_neox_style
)
key_rot
=
_apply_rotary_emb_neuron
(
key_rot
,
cos
,
sin
,
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
self
.
is_neox_style
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
).
reshape
(
key_shape
)
return
query
,
key
return
query
,
key
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
@@ -672,9 +686,10 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
...
@@ -672,9 +686,10 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
assert
key
is
not
None
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
query
=
query
.
view
(
*
query
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
key
=
key
.
view
(
*
key
.
shape
[:
-
1
],
-
1
,
self
.
head_size
)
...
@@ -782,10 +797,11 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -782,10 +797,11 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]
:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
assert
key
is
not
None
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
if
self
.
rotary_dim
<
self
.
head_size
:
...
@@ -912,8 +928,9 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
...
@@ -912,8 +928,9 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
def
forward
(
def
forward
(
self
,
self
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
assert
key
is
not
None
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
query
.
device
)
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
query
.
device
)
query_
=
torch
.
view_as_complex
(
query
.
float
().
reshape
(
query_
=
torch
.
view_as_complex
(
query
.
float
().
reshape
(
*
query
.
shape
[:
-
1
],
-
1
,
2
))
*
query
.
shape
[:
-
1
],
-
1
,
2
))
...
@@ -957,8 +974,8 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -957,8 +974,8 @@ class MRotaryEmbedding(RotaryEmbedding):
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]
]
:
"""PyTorch-native implementation equivalent to forward().
"""PyTorch-native implementation equivalent to forward().
Args:
Args:
...
@@ -969,6 +986,7 @@ class MRotaryEmbedding(RotaryEmbedding):
...
@@ -969,6 +986,7 @@ class MRotaryEmbedding(RotaryEmbedding):
key: [num_tokens, num_kv_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
"""
assert
positions
.
ndim
==
1
or
positions
.
ndim
==
2
assert
positions
.
ndim
==
1
or
positions
.
ndim
==
2
assert
key
is
not
None
num_tokens
=
positions
.
shape
[
-
1
]
num_tokens
=
positions
.
shape
[
-
1
]
cos_sin
=
self
.
cos_sin_cache
[
positions
]
cos_sin
=
self
.
cos_sin_cache
[
positions
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment