Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7151fd54
"docs/vscode:/vscode.git/clone" did not exist on "abbfb6134dc73359cba015dbd1ad30fafd25a891"
Commit
7151fd54
authored
Jul 25, 2025
by
zhuwenwen
Browse files
Triton-fused DeepseekScalingRotaryEmbedding
parent
8da1c576
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
138 additions
and
17 deletions
+138
-17
tests/neuron/1_core/test_rotary_embedding.py
tests/neuron/1_core/test_rotary_embedding.py
+60
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+78
-16
No files found.
tests/neuron/1_core/test_rotary_embedding.py
View file @
7151fd54
...
...
@@ -7,7 +7,8 @@ Tests for miscellaneous utilities
import
pytest
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
(
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
from
vllm.platforms
import
current_platform
...
...
@@ -66,3 +67,61 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
ref_query
,
atol
=
1e-2
,
rtol
=
1e-2
)
def
test_deepseek_rotary_embedding
():
device
=
torch
.
device
(
"cuda:0"
)
current_platform
.
seed_everything
(
0
)
torch
.
set_default_device
(
"cuda:0"
)
batch_size
=
10
base
=
10000
num_heads
=
8
max_position
=
4096
is_neox_style
=
False
rotary_dim
=
32
head_size
=
64
scaling_factor
=
40.0
rot
=
DeepseekScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
torch
.
float32
,
reference
=
False
).
to
(
device
)
rot_ref
=
DeepseekScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
torch
.
float32
,
reference
=
True
).
to
(
device
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
),
device
=
device
)
# query is [batch, num_heads, head_size]
# key is [batch, 1, head_size]
# cos_sin is [batch, head_size]
query
=
torch
.
randn
(
batch_size
,
num_heads
,
head_size
,
dtype
=
torch
.
float32
,
device
=
device
)
key
=
torch
.
randn
(
batch_size
,
1
,
head_size
,
dtype
=
torch
.
float32
,
device
=
device
)
ref_query
,
ref_key
=
rot_ref
.
forward
(
positions
,
query
,
key
)
out_query
,
out_key
=
rot
.
forward
(
positions
,
query
,
key
)
torch
.
testing
.
assert_close
(
out_key
.
cpu
(),
ref_key
.
cpu
(),
atol
=
1e-4
,
rtol
=
1e-4
)
torch
.
testing
.
assert_close
(
out_query
.
cpu
(),
ref_query
.
cpu
(),
atol
=
1e-4
,
rtol
=
1e-4
)
vllm/model_executor/layers/rotary_embedding.py
View file @
7151fd54
...
...
@@ -30,6 +30,9 @@ from typing import Any, Optional, Union
import
numpy
as
np
import
torch
import
torch.nn
as
nn
import
triton
import
triton.language
as
tl
from
transformers
import
PretrainedConfig
from
vllm.model_executor.custom_op
import
CustomOp
...
...
@@ -796,6 +799,34 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return
0.1
*
mscale
*
math
.
log
(
scale
)
+
1.0
@
triton
.
jit
def
deepseek_scaling_rotary_emb_kernel_gptj
(
cos_sin
,
q
,
stride1
:
int
,
stride2
:
int
,
stride_cs
:
int
,
dim1
:
int
,
dim2
:
int
,
dim3
:
int
,
BLOCK_SIZE
:
tl
.
constexpr
):
pid0
=
tl
.
program_id
(
0
)
pid1
=
tl
.
program_id
(
1
)
pid2
=
tl
.
program_id
(
2
)
offsets_cs
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
+
pid2
*
BLOCK_SIZE
offsets_q
=
tl
.
arange
(
0
,
BLOCK_SIZE
*
2
)
+
pid2
*
BLOCK_SIZE
*
2
offsets
=
pid0
*
stride1
+
pid1
*
stride2
+
offsets_q
mask
=
offsets_cs
<
dim3
mask2
=
offsets_q
<
dim3
*
2
v_cos
=
tl
.
load
(
cos_sin
+
pid0
*
stride_cs
+
offsets_cs
,
mask
=
mask
)
v_cos2
=
tl
.
interleave
(
v_cos
,
v_cos
)
v_sin
=
tl
.
load
(
cos_sin
+
pid0
*
stride_cs
+
dim3
+
offsets_cs
,
mask
=
mask
)
v_sin2
=
tl
.
interleave
(
v_sin
,
v_sin
)
x12
=
tl
.
load
(
q
+
offsets
,
mask
=
mask2
)
x1
,
x2
=
tl
.
split
(
x12
.
reshape
([
BLOCK_SIZE
,
2
]))
# we are both reading and writing 'q'; make sure all warps are in sync
tl
.
debug_barrier
()
x12_
=
tl
.
ravel
(
tl
.
join
(
-
x2
,
x1
))
x12
=
x12
*
v_cos2
+
x12_
*
v_sin2
tl
.
store
(
q
+
offsets
,
x12
,
mask
=
mask2
)
class
DeepseekScalingRotaryEmbedding
(
RotaryEmbedding
):
"""RotaryEmbedding extended with YaRN method.
...
...
@@ -818,12 +849,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow
:
int
=
1
,
mscale
:
float
=
1
,
mscale_all_dim
:
float
=
0
,
reference
:
bool
=
False
,
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
extrapolation_factor
=
extrapolation_factor
self
.
attn_factor
=
attn_factor
self
.
beta_fast
=
beta_fast
self
.
beta_slow
=
beta_slow
self
.
reference
=
reference
# Get n-d magnitude scaling corrected for interpolation.
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
,
float
(
mscale
))
/
...
...
@@ -874,17 +907,45 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""PyTorch-native implementation equivalent to forward()."""
assert
key
is
not
None
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
if
self
.
cos_sin_cache
.
device
!=
positions
.
device
:
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
if
query
.
device
.
type
==
'cuda'
and
not
self
.
is_neox_style
\
and
not
self
.
reference
:
assert
len
(
query
.
shape
)
==
3
def
call
(
q
):
BLOCK_SIZE
=
64
grid
=
(
q
.
shape
[
-
3
],
q
.
shape
[
-
2
],
triton
.
cdiv
(
self
.
rotary_dim
//
2
,
BLOCK_SIZE
),
)
deepseek_scaling_rotary_emb_kernel_gptj
[
grid
](
cos_sin
,
q
,
stride1
=
q
.
stride
()[
-
3
],
stride2
=
q
.
stride
()[
-
2
],
stride_cs
=
cos_sin
.
stride
()[
-
2
],
dim1
=
q
.
shape
[
0
],
dim2
=
q
.
shape
[
1
],
dim3
=
self
.
rotary_dim
//
2
,
BLOCK_SIZE
=
BLOCK_SIZE
,
num_warps
=
1
)
call
(
query
)
call
(
key
)
return
query
,
key
else
:
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
self
.
is_neox_style
:
# NOTE(woosuk): Here we assume that the positions tensor has the
...
...
@@ -899,6 +960,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
query_rot
=
query_rot
*
cos
+
rotate_fn
(
query_rot
)
*
sin
key_rot
=
key_rot
*
cos
+
rotate_fn
(
key_rot
)
*
sin
if
self
.
rotary_dim
<
self
.
head_size
:
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
key
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment