Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7151fd54
Commit
7151fd54
authored
Jul 25, 2025
by
zhuwenwen
Browse files
Triton-fused DeepseekScalingRotaryEmbedding
parent
8da1c576
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
138 additions
and
17 deletions
+138
-17
tests/neuron/1_core/test_rotary_embedding.py
tests/neuron/1_core/test_rotary_embedding.py
+60
-1
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+78
-16
No files found.
tests/neuron/1_core/test_rotary_embedding.py
View file @
7151fd54
...
@@ -7,7 +7,8 @@ Tests for miscellaneous utilities
...
@@ -7,7 +7,8 @@ Tests for miscellaneous utilities
import
pytest
import
pytest
import
torch
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
(
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -66,3 +67,61 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
...
@@ -66,3 +67,61 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
ref_query
,
ref_query
,
atol
=
1e-2
,
atol
=
1e-2
,
rtol
=
1e-2
)
rtol
=
1e-2
)
def
test_deepseek_rotary_embedding
():
device
=
torch
.
device
(
"cuda:0"
)
current_platform
.
seed_everything
(
0
)
torch
.
set_default_device
(
"cuda:0"
)
batch_size
=
10
base
=
10000
num_heads
=
8
max_position
=
4096
is_neox_style
=
False
rotary_dim
=
32
head_size
=
64
scaling_factor
=
40.0
rot
=
DeepseekScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
torch
.
float32
,
reference
=
False
).
to
(
device
)
rot_ref
=
DeepseekScalingRotaryEmbedding
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
scaling_factor
,
torch
.
float32
,
reference
=
True
).
to
(
device
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
),
device
=
device
)
# query is [batch, num_heads, head_size]
# key is [batch, 1, head_size]
# cos_sin is [batch, head_size]
query
=
torch
.
randn
(
batch_size
,
num_heads
,
head_size
,
dtype
=
torch
.
float32
,
device
=
device
)
key
=
torch
.
randn
(
batch_size
,
1
,
head_size
,
dtype
=
torch
.
float32
,
device
=
device
)
ref_query
,
ref_key
=
rot_ref
.
forward
(
positions
,
query
,
key
)
out_query
,
out_key
=
rot
.
forward
(
positions
,
query
,
key
)
torch
.
testing
.
assert_close
(
out_key
.
cpu
(),
ref_key
.
cpu
(),
atol
=
1e-4
,
rtol
=
1e-4
)
torch
.
testing
.
assert_close
(
out_query
.
cpu
(),
ref_query
.
cpu
(),
atol
=
1e-4
,
rtol
=
1e-4
)
vllm/model_executor/layers/rotary_embedding.py
View file @
7151fd54
...
@@ -30,6 +30,9 @@ from typing import Any, Optional, Union
...
@@ -30,6 +30,9 @@ from typing import Any, Optional, Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
triton
import
triton.language
as
tl
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
...
@@ -796,6 +799,34 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
...
@@ -796,6 +799,34 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return
0.1
*
mscale
*
math
.
log
(
scale
)
+
1.0
return
0.1
*
mscale
*
math
.
log
(
scale
)
+
1.0
@
triton
.
jit
def
deepseek_scaling_rotary_emb_kernel_gptj
(
cos_sin
,
q
,
stride1
:
int
,
stride2
:
int
,
stride_cs
:
int
,
dim1
:
int
,
dim2
:
int
,
dim3
:
int
,
BLOCK_SIZE
:
tl
.
constexpr
):
pid0
=
tl
.
program_id
(
0
)
pid1
=
tl
.
program_id
(
1
)
pid2
=
tl
.
program_id
(
2
)
offsets_cs
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
+
pid2
*
BLOCK_SIZE
offsets_q
=
tl
.
arange
(
0
,
BLOCK_SIZE
*
2
)
+
pid2
*
BLOCK_SIZE
*
2
offsets
=
pid0
*
stride1
+
pid1
*
stride2
+
offsets_q
mask
=
offsets_cs
<
dim3
mask2
=
offsets_q
<
dim3
*
2
v_cos
=
tl
.
load
(
cos_sin
+
pid0
*
stride_cs
+
offsets_cs
,
mask
=
mask
)
v_cos2
=
tl
.
interleave
(
v_cos
,
v_cos
)
v_sin
=
tl
.
load
(
cos_sin
+
pid0
*
stride_cs
+
dim3
+
offsets_cs
,
mask
=
mask
)
v_sin2
=
tl
.
interleave
(
v_sin
,
v_sin
)
x12
=
tl
.
load
(
q
+
offsets
,
mask
=
mask2
)
x1
,
x2
=
tl
.
split
(
x12
.
reshape
([
BLOCK_SIZE
,
2
]))
# we are both reading and writing 'q'; make sure all warps are in sync
tl
.
debug_barrier
()
x12_
=
tl
.
ravel
(
tl
.
join
(
-
x2
,
x1
))
x12
=
x12
*
v_cos2
+
x12_
*
v_sin2
tl
.
store
(
q
+
offsets
,
x12
,
mask
=
mask2
)
class
DeepseekScalingRotaryEmbedding
(
RotaryEmbedding
):
class
DeepseekScalingRotaryEmbedding
(
RotaryEmbedding
):
"""RotaryEmbedding extended with YaRN method.
"""RotaryEmbedding extended with YaRN method.
...
@@ -818,12 +849,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -818,12 +849,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow
:
int
=
1
,
beta_slow
:
int
=
1
,
mscale
:
float
=
1
,
mscale
:
float
=
1
,
mscale_all_dim
:
float
=
0
,
mscale_all_dim
:
float
=
0
,
reference
:
bool
=
False
,
)
->
None
:
)
->
None
:
self
.
scaling_factor
=
scaling_factor
self
.
scaling_factor
=
scaling_factor
self
.
extrapolation_factor
=
extrapolation_factor
self
.
extrapolation_factor
=
extrapolation_factor
self
.
attn_factor
=
attn_factor
self
.
attn_factor
=
attn_factor
self
.
beta_fast
=
beta_fast
self
.
beta_fast
=
beta_fast
self
.
beta_slow
=
beta_slow
self
.
beta_slow
=
beta_slow
self
.
reference
=
reference
# Get n-d magnitude scaling corrected for interpolation.
# Get n-d magnitude scaling corrected for interpolation.
self
.
mscale
=
float
(
self
.
mscale
=
float
(
yarn_get_mscale
(
self
.
scaling_factor
,
float
(
mscale
))
/
yarn_get_mscale
(
self
.
scaling_factor
,
float
(
mscale
))
/
...
@@ -874,30 +907,59 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -874,30 +907,59 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
assert
key
is
not
None
assert
key
is
not
None
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
if
self
.
cos_sin_cache
.
device
!=
positions
.
device
:
if
self
.
cos_sin_cache
.
device
!=
positions
.
device
:
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
self
.
cos_sin_cache
:
torch
.
Tensor
=
self
.
cos_sin_cache
.
to
(
positions
.
device
)
positions
.
device
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
cos_sin
=
self
.
cos_sin_cache
[
torch
.
add
(
positions
,
offsets
)
if
offsets
is
not
None
else
positions
]
if
offsets
is
not
None
else
positions
]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
query
.
device
.
type
==
'cuda'
and
not
self
.
is_neox_style
\
if
self
.
is_neox_style
:
and
not
self
.
reference
:
# NOTE(woosuk): Here we assume that the positions tensor has the
assert
len
(
query
.
shape
)
==
3
# shape [batch_size, seq_len].
cos
=
cos
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
def
call
(
q
):
sin
=
sin
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
BLOCK_SIZE
=
64
grid
=
(
q
.
shape
[
-
3
],
q
.
shape
[
-
2
],
triton
.
cdiv
(
self
.
rotary_dim
//
2
,
BLOCK_SIZE
),
)
deepseek_scaling_rotary_emb_kernel_gptj
[
grid
](
cos_sin
,
q
,
stride1
=
q
.
stride
()[
-
3
],
stride2
=
q
.
stride
()[
-
2
],
stride_cs
=
cos_sin
.
stride
()[
-
2
],
dim1
=
q
.
shape
[
0
],
dim2
=
q
.
shape
[
1
],
dim3
=
self
.
rotary_dim
//
2
,
BLOCK_SIZE
=
BLOCK_SIZE
,
num_warps
=
1
)
call
(
query
)
call
(
key
)
return
query
,
key
else
:
else
:
cos
=
cos
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
query_rot
=
query
[...,
:
self
.
rotary_dim
]
sin
=
sin
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
key_rot
=
key
[...,
:
self
.
rotary_dim
]
if
self
.
rotary_dim
<
self
.
head_size
:
query_pass
=
query
[...,
self
.
rotary_dim
:]
key_pass
=
key
[...,
self
.
rotary_dim
:]
cos
,
sin
=
cos_sin
.
chunk
(
2
,
dim
=-
1
)
if
self
.
is_neox_style
:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos
=
cos
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat
(
1
,
1
,
2
).
unsqueeze
(
-
2
)
else
:
cos
=
cos
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
sin
=
sin
.
repeat_interleave
(
2
,
dim
=-
1
).
unsqueeze
(
-
2
)
rotate_fn
=
_rotate_neox
if
self
.
is_neox_style
else
_rotate_gptj
query_rot
=
query_rot
*
cos
+
rotate_fn
(
query_rot
)
*
sin
key_rot
=
key_rot
*
cos
+
rotate_fn
(
key_rot
)
*
sin
rotate_fn
=
_rotate_neox
if
self
.
is_neox_style
else
_rotate_gptj
query_rot
=
query_rot
*
cos
+
rotate_fn
(
query_rot
)
*
sin
key_rot
=
key_rot
*
cos
+
rotate_fn
(
key_rot
)
*
sin
if
self
.
rotary_dim
<
self
.
head_size
:
if
self
.
rotary_dim
<
self
.
head_size
:
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
query
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment