Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
097978a1
Unverified
Commit
097978a1
authored
Dec 22, 2025
by
Jee Jee Li
Committed by
GitHub
Dec 21, 2025
Browse files
[Kernel] Enable fused_qknorm_rope_kernel to support partial rope (#30821)
Signed-off-by:
Jee Jee Li
<
pandaleefree@gmail.com
>
parent
7e065eba
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
55 deletions
+67
-55
csrc/fused_qknorm_rope_kernel.cu
csrc/fused_qknorm_rope_kernel.cu
+62
-53
tests/kernels/core/test_fused_qk_norm_rope.py
tests/kernels/core/test_fused_qk_norm_rope.py
+5
-2
No files found.
csrc/fused_qknorm_rope_kernel.cu
View file @
097978a1
...
...
@@ -107,7 +107,8 @@ __global__ void fusedQKNormRopeKernel(
void
const
*
k_weight_void
,
// RMSNorm weights for key
void
const
*
cos_sin_cache_void
,
// Pre-computed cos/sin cache
int64_t
const
*
position_ids
,
// Position IDs for RoPE
int
const
num_tokens
// Number of tokens
int
const
num_tokens
,
// Number of tokens
int
const
rotary_dim
// Dimension for RoPE
)
{
#if (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800) && !defined(USE_ROCM)
if
constexpr
((
std
::
is_same_v
<
scalar_t_in
,
c10
::
BFloat16
>
)
||
...
...
@@ -227,22 +228,24 @@ __global__ void fusedQKNormRopeKernel(
// Calculate cache pointer for this position - similar to
// pos_encoding_kernels.cu
T_cache
const
*
cache_ptr
=
cos_sin_cache
+
pos_id
*
head
_dim
;
int
const
embed_dim
=
head
_dim
/
2
;
T_cache
const
*
cache_ptr
=
cos_sin_cache
+
pos_id
*
rotary
_dim
;
int
const
embed_dim
=
rotary
_dim
/
2
;
T_cache
const
*
cos_ptr
=
cache_ptr
;
T_cache
const
*
sin_ptr
=
cache_ptr
+
embed_dim
;
int
const
rotary_lanes
=
rotary_dim
/
numElemsPerThread
;
// rotary range
if
(
laneId
<
rotary_lanes
)
{
if
constexpr
(
interleave
)
{
// Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
for
(
int
i
=
0
;
i
<
numElemsPerThread
/
2
;
++
i
)
{
int
const
idx0
=
2
*
i
;
int
const
idx1
=
2
*
i
+
1
;
// Global dimension index in the head
int
const
dim_idx
=
laneId
*
numElemsPerThread
+
idx0
;
float
const
val0
=
elements
[
idx0
];
float
const
val1
=
elements
[
idx1
];
int
const
dim_idx
=
laneId
*
numElemsPerThread
+
idx0
;
int
const
half_dim
=
dim_idx
/
2
;
float
const
cos_val
=
CacheConverter
::
convert
(
VLLM_LDG
(
cos_ptr
+
half_dim
));
...
...
@@ -255,19 +258,20 @@ __global__ void fusedQKNormRopeKernel(
}
else
{
// Before exchanging data within the warp, we need to sync.
__syncwarp
();
// Get the data from the other half of the warp. Use pre-computed cos/sin
// values.
int
pairOffset
=
(
rotary_dim
/
2
)
/
numElemsPerThread
;
// Get the data from the other half of the warp. Use pre-computed
// cos/sin values.
#pragma unroll
for
(
int
i
=
0
;
i
<
numElemsPerThread
;
i
++
)
{
elements2
[
i
]
=
__shfl_xor_sync
(
FINAL_MASK
,
elements
[
i
],
16
);
if
(
laneId
<
16
)
{
elements2
[
i
]
=
__shfl_xor_sync
(
FINAL_MASK
,
elements
[
i
],
pairOffset
);
if
(
laneId
<
pairOffset
)
{
elements2
[
i
]
=
-
elements2
[
i
];
}
int
dim_idx
=
laneId
*
numElemsPerThread
+
i
;
dim_idx
=
(
dim_idx
*
2
)
%
head_dim
;
dim_idx
=
(
dim_idx
*
2
)
%
rotary_dim
;
int
half_dim
=
dim_idx
/
2
;
// Use pre-computed cos/sin from cache
float
cos_val
=
CacheConverter
::
convert
(
VLLM_LDG
(
cos_ptr
+
half_dim
));
float
sin_val
=
CacheConverter
::
convert
(
VLLM_LDG
(
sin_ptr
+
half_dim
));
...
...
@@ -276,7 +280,7 @@ __global__ void fusedQKNormRopeKernel(
// __shfl_xor_sync does not provide memfence. Need to sync again.
__syncwarp
();
}
}
// Store.
{
vec_T
vec
;
...
...
@@ -312,10 +316,10 @@ template <typename scalar_t_in, typename scalar_t_cache>
void
launchFusedQKNormRope
(
void
*
qkv
,
int
const
num_tokens
,
int
const
num_heads_q
,
int
const
num_heads_k
,
int
const
num_heads_v
,
int
const
head_dim
,
floa
t
const
eps
,
void
const
*
q_weight
,
void
const
*
k
_weight
,
void
const
*
cos_sin_cache
,
bool
const
interleave
,
int64_t
const
*
position_ids
,
cudaStream_t
stream
)
{
in
t
const
rotary_dim
,
float
const
eps
,
void
const
*
q
_weight
,
void
const
*
k_weight
,
void
const
*
cos_sin_cache
,
bool
const
interleave
,
int64_t
const
*
position_ids
,
cudaStream_t
stream
)
{
constexpr
int
blockSize
=
256
;
int
const
warpsPerBlock
=
blockSize
/
32
;
...
...
@@ -332,7 +336,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
64
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
,
rotary_dim
);
});
break
;
case
128
:
...
...
@@ -340,7 +344,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
128
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
,
rotary_dim
);
});
break
;
case
256
:
...
...
@@ -348,7 +352,7 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens,
fusedQKNormRopeKernel
<
scalar_t_in
,
scalar_t_cache
,
256
,
INTERLEAVE
>
<<<
gridDim
,
blockDim
,
0
,
stream
>>>
(
qkv
,
num_heads_q
,
num_heads_k
,
num_heads_v
,
eps
,
q_weight
,
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
);
k_weight
,
cos_sin_cache
,
position_ids
,
num_tokens
,
rotary_dim
);
});
break
;
default:
...
...
@@ -392,12 +396,16 @@ void fused_qk_norm_rope(
"Query weights size must match head dimension"
);
TORCH_CHECK
(
k_weight
.
size
(
0
)
==
head_dim
,
"Key weights size must match head dimension"
);
TORCH_CHECK
(
cos_sin_cache
.
size
(
1
)
==
head_dim
,
"Cos/sin cache dimension must match head_dim"
);
TORCH_CHECK
(
cos_sin_cache
.
size
(
1
)
%
2
==
0
,
"rotary_dim must be even"
);
TORCH_CHECK
(
cos_sin_cache
.
size
(
1
)
<=
head_dim
,
"rotary_dim must be less than or equal to head_dim"
);
TORCH_CHECK
(
qkv
.
scalar_type
()
==
q_weight
.
scalar_type
()
&&
qkv
.
scalar_type
()
==
k_weight
.
scalar_type
(),
"qkv, q_weight and k_weight must have the same dtype"
);
int64_t
rotary_dim
=
cos_sin_cache
.
size
(
1
);
int64_t
num_tokens
=
qkv
.
size
(
0
);
TORCH_CHECK
(
position_ids
.
size
(
0
)
==
num_tokens
,
"Number of tokens in position_ids must match QKV"
);
...
...
@@ -419,7 +427,8 @@ void fused_qk_norm_rope(
qkv
.
data_ptr
(),
static_cast
<
int
>
(
num_tokens
),
static_cast
<
int
>
(
num_heads_q
),
static_cast
<
int
>
(
num_heads_k
),
static_cast
<
int
>
(
num_heads_v
),
static_cast
<
int
>
(
head_dim
),
static_cast
<
float
>
(
eps
),
q_weight
.
data_ptr
(),
k_weight
.
data_ptr
(),
static_cast
<
int
>
(
cos_sin_cache
.
size
(
1
)),
static_cast
<
float
>
(
eps
),
q_weight
.
data_ptr
(),
k_weight
.
data_ptr
(),
cos_sin_cache
.
data_ptr
(),
!
is_neox
,
reinterpret_cast
<
int64_t
const
*>
(
position_ids
.
data_ptr
()),
stream
);
...
...
tests/kernels/core/test_fused_qk_norm_rope.py
View file @
097978a1
...
...
@@ -13,6 +13,7 @@ DTYPES = [torch.bfloat16, torch.float16]
IS_NEOX
=
[
True
,
False
]
EPS_VALUES
=
[
1e-5
,
1e-6
]
SEEDS
=
[
13
]
PARTIAL_ROPE
=
[
True
,
False
]
CUDA_DEVICES
=
[
"cuda:0"
]
...
...
@@ -52,6 +53,7 @@ def _apply_qk_norm_rope(
@
pytest
.
mark
.
parametrize
(
"is_neox"
,
IS_NEOX
)
@
pytest
.
mark
.
parametrize
(
"eps"
,
EPS_VALUES
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
@
pytest
.
mark
.
parametrize
(
"rotary_ratio"
,
[
1.0
,
0.5
,
0.25
])
@
torch
.
inference_mode
()
def
test_fused_qk_norm_rope_matches_reference
(
device
:
str
,
...
...
@@ -59,6 +61,7 @@ def test_fused_qk_norm_rope_matches_reference(
is_neox
:
bool
,
eps
:
float
,
seed
:
int
,
rotary_ratio
:
float
,
):
torch
.
set_default_device
(
device
)
current_platform
.
seed_everything
(
seed
)
...
...
@@ -76,10 +79,10 @@ def test_fused_qk_norm_rope_matches_reference(
k_norm
.
weight
.
data
.
normal_
(
mean
=
1.0
,
std
=
0.1
)
q_weight
=
q_norm
.
weight
.
data
k_weight
=
k_norm
.
weight
.
data
rotary_dim
=
int
(
head_dim
*
rotary_ratio
)
rope
=
RotaryEmbedding
(
head_size
=
head_dim
,
rotary_dim
=
head
_dim
,
rotary_dim
=
rotary
_dim
,
max_position_embeddings
=
4096
,
base
=
10000.0
,
is_neox_style
=
is_neox
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment