Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
edb59a94
Unverified
Commit
edb59a94
authored
Nov 12, 2025
by
TJian
Committed by
GitHub
Nov 12, 2025
Browse files
[ROCm] [Bugfix] Fix `fused_qknorm_rope_kernel` rocm compatibility (#28500)
Signed-off-by:
tjtanaa
<
tunjian.tan@embeddedllm.com
>
parent
c5f10cc1
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
37 additions
and
38 deletions
+37
-38
csrc/fused_qknorm_rope_kernel.cu
csrc/fused_qknorm_rope_kernel.cu
+27
-27
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-2
csrc/type_convert.cuh
csrc/type_convert.cuh
+4
-3
tests/compile/test_qk_norm_rope_fusion.py
tests/compile/test_qk_norm_rope_fusion.py
+2
-2
tests/kernels/core/test_fused_qk_norm_rope.py
tests/kernels/core/test_fused_qk_norm_rope.py
+2
-2
vllm/config/compilation.py
vllm/config/compilation.py
+2
-2
No files found.
csrc/fused_qknorm_rope_kernel.cu
View file @
edb59a94
...
...
@@ -35,10 +35,12 @@
CHECK_TH_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define FINAL_MASK 0xffffffff
#ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL
#else
#define FINAL_MASK 0xffffffff
#endif
// TODO: suport for AMD ROCM platform
#ifndef USE_ROCM
namespace
tensorrt_llm
::
common
{
template
<
typename
T
,
int
num
>
struct
packed_as
;
...
...
@@ -60,7 +62,7 @@ struct packed_as<uint, 4> {
template
<
typename
T
>
__inline__
__device__
T
warpReduceSum
(
T
val
)
{
#pragma unroll
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
val
+=
__shfl_xor_sync
(
FINAL_MASK
,
val
,
mask
,
32
);
return
val
;
...
...
@@ -97,12 +99,12 @@ __global__ void fusedQKNormRopeKernel(
int64_t
const
*
position_ids
,
// Position IDs for RoPE
int
const
num_tokens
// Number of tokens
)
{
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
#if
(
!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
) && !defined(USE_ROCM)
if
constexpr
((
std
::
is_same_v
<
scalar_t_in
,
c10
::
BFloat16
>
)
||
std
::
is_same_v
<
scalar_t_cache
,
c10
::
BFloat16
>
)
{
return
;
}
else
{
#endif
#endif
using
Converter
=
vllm
::
_typeConvert
<
scalar_t_in
>
;
static_assert
(
Converter
::
exists
,
...
...
@@ -179,7 +181,7 @@ __global__ void fusedQKNormRopeKernel(
{
vec_T
vec
=
*
reinterpret_cast
<
vec_T
const
*>
(
&
qkv
[
offsetThread
]);
constexpr
int
num_packed_elems
=
elemSizeBytes
/
sizeof
(
T2_in
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
num_packed_elems
;
i
++
)
{
// Interpret the generic vector chunk as the specific packed type
T2_in
packed_val
=
*
(
reinterpret_cast
<
T2_in
*>
(
&
vec
)
+
i
);
...
...
@@ -200,7 +202,7 @@ __global__ void fusedQKNormRopeKernel(
float
rms_rcp
=
rsqrtf
(
sumOfSquares
/
static_cast
<
float
>
(
head_dim
)
+
eps
);
// Normalize elements
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
numElemsPerThread
;
i
++
)
{
int
dim
=
laneId
*
numElemsPerThread
+
i
;
float
weight
=
isQ
?
Converter
::
convert
(
q_weight
[
dim
])
...
...
@@ -222,7 +224,7 @@ __global__ void fusedQKNormRopeKernel(
if
constexpr
(
interleave
)
{
// Perform interleaving. Use pre-computed cos/sin values.
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
numElemsPerThread
/
2
;
++
i
)
{
int
const
idx0
=
2
*
i
;
int
const
idx1
=
2
*
i
+
1
;
...
...
@@ -245,9 +247,9 @@ __global__ void fusedQKNormRopeKernel(
__syncwarp
();
// Get the data from the other half of the warp. Use pre-computed cos/sin
// values.
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
numElemsPerThread
;
i
++
)
{
elements2
[
i
]
=
__shfl_xor_sync
(
0xffffffff
,
elements
[
i
],
16
);
elements2
[
i
]
=
__shfl_xor_sync
(
FINAL_MASK
,
elements
[
i
],
16
);
if
(
laneId
<
16
)
{
elements2
[
i
]
=
-
elements2
[
i
];
}
...
...
@@ -269,7 +271,7 @@ __global__ void fusedQKNormRopeKernel(
{
vec_T
vec
;
constexpr
int
num_packed_elems
=
elemSizeBytes
/
sizeof
(
T2_in
);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
num_packed_elems
;
i
++
)
{
// Convert from float2 back to the specific packed type
T2_in
packed_val
=
Converter
::
convert
(
...
...
@@ -280,21 +282,21 @@ __global__ void fusedQKNormRopeKernel(
*
reinterpret_cast
<
vec_T
*>
(
&
qkv
[
offsetThread
])
=
vec
;
}
#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
#if
(
!defined(__CUDA_ARCH__) || __CUDA_ARCH__ < 800
) && !defined(USE_ROCM)
}
#endif
#endif
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
// Borrowed from
// https://github.com/flashinfer-ai/flashinfer/blob/8125d079a43e9a0ba463a4ed1b639cefd084cec9/include/flashinfer/pos_enc.cuh#L568
#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}
template
<
typename
scalar_t_in
,
typename
scalar_t_cache
>
void
launchFusedQKNormRope
(
void
*
qkv
,
int
const
num_tokens
,
...
...
@@ -413,6 +415,4 @@ void fused_qk_norm_rope(
stream
);
});
});
}
#endif // not USE_ROCM
\ No newline at end of file
}
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
edb59a94
...
...
@@ -175,7 +175,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm"
,
torch
::
kCUDA
,
&
fused_add_rms_norm
);
#ifndef USE_ROCM
// Function for fused QK Norm and RoPE
ops
.
def
(
"fused_qk_norm_rope(Tensor! qkv, int num_heads_q, "
...
...
@@ -183,7 +182,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor q_weight, Tensor k_weight, Tensor cos_sin_cache, "
"bool is_neox, Tensor position_ids) -> ()"
);
ops
.
impl
(
"fused_qk_norm_rope"
,
torch
::
kCUDA
,
&
fused_qk_norm_rope
);
#endif
// Apply repetition penalties to logits in-place
ops
.
def
(
...
...
csrc/type_convert.cuh
View file @
edb59a94
...
...
@@ -67,9 +67,9 @@ struct _typeConvert<c10::Half> {
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if
(
defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
) || defined(USE_ROCM)
// CUDA_ARCH < 800 does not have BF16 support
//
TODO: Add in ROCm support once public headers handle bf16 maturely
//
ROCm 7.0+ supports bfloat16
template
<
>
struct
_typeConvert
<
c10
::
BFloat16
>
{
static
constexpr
bool
exists
=
true
;
...
...
@@ -89,7 +89,8 @@ struct _typeConvert<c10::BFloat16> {
return
__float22bfloat162_rn
(
x
);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800) ||
// defined(USE_ROCM)
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
...
...
tests/compile/test_qk_norm_rope_fusion.py
View file @
edb59a94
...
...
@@ -113,8 +113,8 @@ class QKNormRoPETestModel(torch.nn.Module):
@
pytest
.
mark
.
parametrize
(
"enable_rope_custom_op"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Only test on cuda platform"
,
not
current_platform
.
is_cuda
_alike
(),
reason
=
"Only test on cuda
and rocm
platform"
,
)
def
test_qk_norm_rope_fusion
(
eps
,
is_neox
,
enable_rms_norm_custom_op
,
enable_rope_custom_op
,
dtype
...
...
tests/kernels/core/test_fused_qk_norm_rope.py
View file @
edb59a94
...
...
@@ -44,8 +44,8 @@ def _apply_qk_norm_rope(
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"fused_qk_norm_rope custom op requires cuda platform"
,
not
current_platform
.
is_cuda
_alike
(),
reason
=
"fused_qk_norm_rope custom op requires cuda
and rocm
platform"
,
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
...
...
vllm/config/compilation.py
View file @
edb59a94
...
...
@@ -184,10 +184,10 @@ class PassConfig:
"Fusion enabled but reshape elimination disabled. "
"Allreduce + rms norm + quant (fp8) fusion might not work"
)
if
self
.
enable_qk_norm_rope_fusion
and
not
current_platform
.
is_cuda
():
if
self
.
enable_qk_norm_rope_fusion
and
not
current_platform
.
is_cuda
_alike
():
logger
.
warning_once
(
"QK Norm + RoPE fusion enabled but the current platform is not "
"CUDA. The fusion will be disabled."
"CUDA
or ROCm
. The fusion will be disabled."
)
self
.
enable_qk_norm_rope_fusion
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment