Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
85ac82d2
Unverified
Commit
85ac82d2
authored
Feb 07, 2025
by
Isotr0py
Committed by
GitHub
Feb 06, 2025
Browse files
[Kernel] Make rotary_embedding ops more flexible with input shape (#12777)
parent
1e57b1ee
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
115 additions
and
57 deletions
+115
-57
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+89
-14
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+22
-9
vllm/attention/backends/mla/utils.py
vllm/attention/backends/mla/utils.py
+3
-22
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+1
-12
No files found.
csrc/pos_encoding_kernels.cu
View file @
85ac82d2
...
...
@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
int64_t
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
// num_tokens = batch_size * seq_len
int64_t
num_tokens
=
positions
.
numel
();
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
"query, key and positions must have the same number of tokens"
);
}
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
"query, key and positions must have the same batch_size and seq_len"
);
}
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have consistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
int64_t
query_stride
=
query
.
stride
(
-
2
);
int64_t
key_stride
=
key
.
stride
(
-
2
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
...
@@ -165,19 +201,58 @@ and process in batched manner.
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
// [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
// [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int64_t
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
or [batch_size]
)
{
// num_tokens = batch_size * seq_len
int64_t
num_tokens
=
cos_sin_cache_offsets
.
size
(
0
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
int64_t
query_stride
=
query
.
stride
(
-
2
);
int64_t
key_stride
=
key
.
stride
(
-
2
);
TORCH_CHECK
(
positions
.
size
(
0
)
==
num_tokens
||
positions
.
numel
()
==
num_tokens
,
"positions must have the same num_tokens or batch_size as "
"cos_sin_cache_offsets"
);
int
positions_ndim
=
positions
.
dim
();
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK
(
positions_ndim
==
1
||
positions_ndim
==
2
,
"positions must have shape [num_tokens] or [batch_size, seq_len]"
);
if
(
positions_ndim
==
1
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
),
"query, key and positions must have the same number of tokens"
);
}
if
(
positions_ndim
==
2
)
{
TORCH_CHECK
(
query
.
size
(
0
)
==
positions
.
size
(
0
)
&&
key
.
size
(
0
)
==
positions
.
size
(
0
)
&&
query
.
size
(
1
)
==
positions
.
size
(
1
)
&&
key
.
size
(
1
)
==
positions
.
size
(
1
),
"query, key and positions must have the same batch_size and seq_len"
);
}
// Make sure head_size is valid for query and key
int
query_hidden_size
=
query
.
numel
()
/
num_tokens
;
int
key_hidden_size
=
key
.
numel
()
/
num_tokens
;
TORCH_CHECK
(
query_hidden_size
%
head_size
==
0
);
TORCH_CHECK
(
key_hidden_size
%
head_size
==
0
);
// Make sure query and key have consistent number of heads
int
num_heads
=
query_hidden_size
/
head_size
;
int
num_kv_heads
=
key_hidden_size
/
head_size
;
TORCH_CHECK
(
num_heads
%
num_kv_heads
==
0
);
int
seq_dim_idx
=
positions_ndim
-
1
;
int64_t
query_stride
=
query
.
stride
(
seq_dim_idx
);
int64_t
key_stride
=
key
.
stride
(
seq_dim_idx
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
<
int64_t
>
(
num_heads
*
rot_dim
/
2
,
512
));
...
...
tests/kernels/test_pos_encoding.py
View file @
85ac82d2
# SPDX-License-Identifier: Apache-2.0
from
itertools
import
accumulate
,
product
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Callable
,
Dict
,
List
,
Optional
import
pytest
import
torch
...
...
@@ -24,7 +24,21 @@ CUDA_DEVICES = [
]
def
_get_flat_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
)
->
tuple
[
int
,
...]:
return
(
batch_size
,
seq_len
,
num_heads
*
head_size
)
def
_get_batch_tensor_shape
(
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
head_size
:
int
)
->
tuple
[
int
,
...]:
return
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
TENSORS_SHAPES_FN
=
[
_get_batch_tensor_shape
,
_get_flat_tensor_shape
]
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"tensor_shape_fn"
,
TENSORS_SHAPES_FN
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
...
@@ -36,6 +50,7 @@ CUDA_DEVICES = [
@
torch
.
inference_mode
()
def
test_rotary_embedding
(
is_neox_style
:
bool
,
tensor_shape_fn
:
Callable
[[
int
,
int
,
int
,
int
],
tuple
[
int
]],
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
...
...
@@ -58,10 +73,8 @@ def test_rotary_embedding(
rope
=
rope
.
to
(
dtype
=
dtype
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
# NOTE(woosuk): The reference implementation should be executed first
...
...
@@ -80,6 +93,7 @@ def test_rotary_embedding(
@
pytest
.
mark
.
parametrize
(
"is_neox_style"
,
IS_NEOX_STYLE
)
@
pytest
.
mark
.
parametrize
(
"tensor_shape_fn"
,
TENSORS_SHAPES_FN
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
BATCH_SIZES
)
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
SEQ_LENS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
...
@@ -91,6 +105,7 @@ def test_rotary_embedding(
@
torch
.
inference_mode
()
def
test_batched_rotary_embedding
(
is_neox_style
:
bool
,
tensor_shape_fn
:
Callable
[[
int
,
int
,
int
,
int
],
tuple
[
int
]],
batch_size
:
int
,
seq_len
:
int
,
num_heads
:
int
,
...
...
@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
rope
=
rope
.
to
(
dtype
=
dtype
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
query_shape
=
tensor_shape_fn
(
batch_size
,
seq_len
,
num_heads
,
head_size
)
query
=
torch
.
randn
(
query_shape
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
# NOTE(woosuk): The reference implementation should be executed first
...
...
vllm/attention/backends/mla/utils.py
View file @
85ac82d2
...
...
@@ -424,24 +424,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
apply_pure_rope
(
self
,
input_positions
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
seq_len
=
input_positions
.
size
(
0
)
ori_q_pe_shape
,
ori_k_pe_shape
=
q_pe
.
shape
,
k_pe
.
shape
q_pe
,
k_pe
=
self
.
rotary_emb
(
input_positions
,
q_pe
.
reshape
(
seq_len
,
-
1
),
k_pe
.
reshape
(
seq_len
,
-
1
),
)
q_pe
,
k_pe
=
q_pe
.
view
(
ori_q_pe_shape
),
k_pe
.
view
(
ori_k_pe_shape
)
return
q_pe
,
k_pe
def
forward
(
self
,
layer
:
AttentionLayer
,
...
...
@@ -466,14 +448,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# Restore head dim (for rotary embedding)
k_pe
=
k_pe
.
unsqueeze
(
1
)
assert
hasattr
(
attn_metadata
,
"input_positions"
)
rope_fn
=
(
self
.
rotary_emb
if
self
.
use_yarn_rope
else
self
.
apply_pure_rope
)
if
is_decode
:
q_nope
=
self
.
_q_proj_and_k_up_proj
(
hidden_states_or_q_c
)
q_pe
=
torch
.
matmul
(
hidden_states_or_q_c
,
self
.
W_QR
)
\
.
view
(
-
1
,
self
.
num_heads
,
self
.
qk_rope_head_dim
)
q_pe
,
k_pe
=
rope_fn
(
attn_metadata
.
input_positions
,
q_pe
,
k_pe
)
q_pe
,
k_pe
=
self
.
rotary_emb
(
attn_metadata
.
input_positions
,
q_pe
,
k_pe
)
else
:
assert
is_prefill
q
=
self
.
q_proj
(
hidden_states_or_q_c
)[
0
]
\
...
...
@@ -481,7 +462,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# TODO(lucas): there must be a nicer way to write this line
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
=
\
rope_fn
(
self
.
rotary_emb
(
attn_metadata
.
input_positions
,
q
[...,
self
.
qk_nope_head_dim
:],
k_pe
)
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
85ac82d2
...
...
@@ -257,9 +257,7 @@ class DeepseekV2Attention(nn.Module):
prefix
=
f
"
{
prefix
}
.o_proj"
)
if
rope_scaling
:
rope_scaling
[
"rope_type"
]
=
'deepseek_yarn'
self
.
use_normal_rope
=
False
else
:
self
.
use_normal_rope
=
True
self
.
rotary_emb
=
get_rope
(
qk_rope_head_dim
,
rotary_dim
=
qk_rope_head_dim
,
max_position
=
max_position_embeddings
,
...
...
@@ -309,17 +307,8 @@ class DeepseekV2Attention(nn.Module):
k_nope
,
v
=
kv
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k_pe
=
latent_cache
[:,
:,
self
.
kv_lora_rank
:]
if
self
.
use_normal_rope
:
seq_len
=
positions
.
size
(
0
)
ori_q_pe_shape
,
ori_k_pe_shape
=
q_pe
.
shape
,
k_pe
.
shape
q_pe
=
q_pe
.
reshape
(
seq_len
,
-
1
)
k_pe
=
k_pe
.
reshape
(
seq_len
,
-
1
)
q_pe
,
k_pe
=
self
.
rotary_emb
(
positions
,
q_pe
,
k_pe
)
if
self
.
use_normal_rope
:
q_pe
,
k_pe
=
q_pe
.
view
(
ori_q_pe_shape
),
k_pe
.
view
(
ori_k_pe_shape
)
q
[...,
self
.
qk_nope_head_dim
:]
=
q_pe
k
=
torch
.
empty_like
(
q
)
k
[...,
:
self
.
qk_nope_head_dim
]
=
k_nope
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment