Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bfe93801
Unverified
Commit
bfe93801
authored
Sep 17, 2025
by
Aidyn-A
Committed by
GitHub
Sep 17, 2025
Browse files
Apply fixes for CUDA 13 (#24599)
Signed-off-by:
Aidyn-A
<
aidyn.b.aitzhan@gmail.com
>
parent
9fccd04e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
47 additions
and
56 deletions
+47
-56
CMakeLists.txt
CMakeLists.txt
+10
-0
csrc/cub_helpers.h
csrc/cub_helpers.h
+17
-0
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+4
-9
csrc/layernorm_quant_kernels.cu
csrc/layernorm_quant_kernels.cu
+4
-9
csrc/moe/topk_softmax_kernels.cu
csrc/moe/topk_softmax_kernels.cu
+3
-13
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+2
-9
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+2
-7
csrc/quantization/fused_kernels/layernorm_utils.cuh
csrc/quantization/fused_kernels/layernorm_utils.cuh
+5
-9
No files found.
CMakeLists.txt
View file @
bfe93801
...
@@ -175,6 +175,16 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -175,6 +175,16 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
list
(
APPEND VLLM_GPU_FLAGS
"--threads=
${
NVCC_THREADS
}
"
)
list
(
APPEND VLLM_GPU_FLAGS
"--threads=
${
NVCC_THREADS
}
"
)
endif
()
endif
()
#
# Set CUDA include flags for CXX compiler.
#
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-I
${
CUDA_TOOLKIT_ROOT_DIR
}
/include"
)
if
(
CUDA_VERSION VERSION_GREATER_EQUAL 13.0
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-I
${
CUDA_TOOLKIT_ROOT_DIR
}
/include/cccl"
)
endif
()
endif
()
#
#
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
...
...
csrc/cub_helpers.h
0 → 100644
View file @
bfe93801
#pragma once
#ifndef USE_ROCM
#include <cub/cub.cuh>
#if CUB_VERSION >= 200800
#include <cuda/std/functional>
using
CubAddOp
=
cuda
::
std
::
plus
<>
;
using
CubMaxOp
=
cuda
::
maximum
<>
;
#else // if CUB_VERSION < 200800
using
CubAddOp
=
cub
::
Sum
;
using
CubMaxOp
=
cub
::
Max
;
#endif // CUB_VERSION
#else
#include <hipcub/hipcub.hpp>
using
CubAddOp
=
cub
::
Sum
;
using
CubMaxOp
=
cub
::
Max
;
#endif // USE_ROCM
csrc/layernorm_kernels.cu
View file @
bfe93801
#include "type_convert.cuh"
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include <torch/cuda.h>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
...
@@ -30,7 +25,7 @@ __global__ void rms_norm_kernel(
...
@@ -30,7 +25,7 @@ __global__ void rms_norm_kernel(
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
c
ub
::
Sum
{},
blockDim
.
x
);
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
C
ub
AddOp
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
...
@@ -85,7 +80,7 @@ fused_add_rms_norm_kernel(
...
@@ -85,7 +80,7 @@ fused_add_rms_norm_kernel(
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
c
ub
::
Sum
{},
blockDim
.
x
);
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
C
ub
AddOp
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
...
@@ -126,7 +121,7 @@ fused_add_rms_norm_kernel(
...
@@ -126,7 +121,7 @@ fused_add_rms_norm_kernel(
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
c
ub
::
Sum
{},
blockDim
.
x
);
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
C
ub
AddOp
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
...
...
csrc/layernorm_quant_kernels.cu
View file @
bfe93801
...
@@ -8,16 +8,11 @@
...
@@ -8,16 +8,11 @@
#include "type_convert.cuh"
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include <torch/cuda.h>
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
...
@@ -39,7 +34,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
...
@@ -39,7 +34,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
c
ub
::
Sum
{},
blockDim
.
x
);
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
C
ub
AddOp
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
...
@@ -100,7 +95,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -100,7 +95,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
c
ub
::
Sum
{},
blockDim
.
x
);
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
C
ub
AddOp
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
...
@@ -149,7 +144,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
...
@@ -149,7 +144,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
c
ub
::
Sum
{},
blockDim
.
x
);
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
C
ub
AddOp
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
...
...
csrc/moe/topk_softmax_kernels.cu
View file @
bfe93801
...
@@ -20,17 +20,7 @@
...
@@ -20,17 +20,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include "../cuda_compat.h"
#include "../cuda_compat.h"
#include "../cub_helpers.h"
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#include <cuda/std/functional>
using
AddOp
=
cuda
::
std
::
plus
<
float
>
;
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using
AddOp
=
cub
::
Sum
;
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
...
@@ -79,7 +69,7 @@ __launch_bounds__(TPB) __global__
...
@@ -79,7 +69,7 @@ __launch_bounds__(TPB) __global__
threadData
=
max
(
static_cast
<
float
>
(
input
[
idx
]),
threadData
);
threadData
=
max
(
static_cast
<
float
>
(
input
[
idx
]),
threadData
);
}
}
const
float
maxElem
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
c
ub
::
Max
());
const
float
maxElem
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
C
ubMax
Op
());
if
(
threadIdx
.
x
==
0
)
if
(
threadIdx
.
x
==
0
)
{
{
float_max
=
maxElem
;
float_max
=
maxElem
;
...
@@ -94,7 +84,7 @@ __launch_bounds__(TPB) __global__
...
@@ -94,7 +84,7 @@ __launch_bounds__(TPB) __global__
threadData
+=
exp
((
static_cast
<
float
>
(
input
[
idx
])
-
float_max
));
threadData
+=
exp
((
static_cast
<
float
>
(
input
[
idx
])
-
float_max
));
}
}
const
auto
Z
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
AddOp
());
const
auto
Z
=
BlockReduce
(
tmpStorage
).
Reduce
(
threadData
,
Cub
AddOp
());
if
(
threadIdx
.
x
==
0
)
if
(
threadIdx
.
x
==
0
)
{
{
...
...
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
bfe93801
...
@@ -7,17 +7,10 @@
...
@@ -7,17 +7,10 @@
#include <cmath>
#include <cmath>
#include "../../cub_helpers.h"
#include "../../dispatch_utils.h"
#include "../../dispatch_utils.h"
#include "../vectorization_utils.cuh"
#include "../vectorization_utils.cuh"
#ifndef USE_ROCM
#include <cub/cub.cuh>
#include <cub/util_type.cuh>
#else
#include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp>
#endif
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
static
inline
__device__
int8_t
float_to_int8_rn
(
float
x
)
{
#ifdef USE_ROCM
#ifdef USE_ROCM
static
constexpr
auto
i8_min
=
static
constexpr
auto
i8_min
=
...
@@ -173,7 +166,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
...
@@ -173,7 +166,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
});
});
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
float
block_max
=
BlockReduce
(
tmp
).
Reduce
(
thread_max
,
c
ub
::
Max
{},
blockDim
.
x
);
float
block_max
=
BlockReduce
(
tmp
).
Reduce
(
thread_max
,
C
ubMax
Op
{},
blockDim
.
x
);
__shared__
float
absmax
;
__shared__
float
absmax
;
if
(
tid
==
0
)
{
if
(
tid
==
0
)
{
absmax
=
block_max
;
absmax
=
block_max
;
...
...
csrc/quantization/fp8/common.cu
View file @
bfe93801
#include "common.cuh"
#include "common.cuh"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#include "../../cub_helpers.h"
#include "../vectorization_utils.cuh"
#include "../vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cuda/Exceptions.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace
vllm
{
namespace
vllm
{
template
<
typename
scalar_t
,
typename
fp8_type
>
template
<
typename
scalar_t
,
typename
fp8_type
>
...
@@ -116,7 +111,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
...
@@ -116,7 +111,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel_strided(
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
__shared__
typename
BlockReduce
::
TempStorage
tmp
;
const
float
block_max
=
const
float
block_max
=
BlockReduce
(
tmp
).
Reduce
(
absmax_val
,
c
ub
::
Max
{},
blockDim
.
x
);
BlockReduce
(
tmp
).
Reduce
(
absmax_val
,
C
ubMax
Op
{},
blockDim
.
x
);
__shared__
float
token_scale
;
__shared__
float
token_scale
;
if
(
tid
==
0
)
{
if
(
tid
==
0
)
{
...
...
csrc/quantization/fused_kernels/layernorm_utils.cuh
View file @
bfe93801
...
@@ -8,11 +8,7 @@
...
@@ -8,11 +8,7 @@
#include "quantization/utils.cuh"
#include "quantization/utils.cuh"
#include "quant_conversions.cuh"
#include "quant_conversions.cuh"
#ifndef USE_ROCM
#include "../../cub_helpers.h"
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace
vllm
{
namespace
vllm
{
...
@@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
...
@@ -36,7 +32,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
ss
=
BlockReduce
(
reduceStore
).
Reduce
(
ss
,
c
ub
::
Sum
{},
blockDim
.
x
);
ss
=
BlockReduce
(
reduceStore
).
Reduce
(
ss
,
C
ub
AddOp
{},
blockDim
.
x
);
__shared__
float
s_rms
;
__shared__
float
s_rms
;
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
...
@@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -73,7 +69,7 @@ __device__ void compute_dynamic_per_token_scales(
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
block_absmax_val_maybe
=
block_absmax_val_maybe
=
BlockReduce
(
reduceStore
)
BlockReduce
(
reduceStore
)
.
Reduce
(
block_absmax_val_maybe
,
c
ub
::
Max
{},
blockDim
.
x
);
.
Reduce
(
block_absmax_val_maybe
,
C
ubMax
Op
{},
blockDim
.
x
);
__shared__
float
s_token_scale
;
__shared__
float
s_token_scale
;
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
...
@@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
...
@@ -169,7 +165,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
ss
=
BlockReduce
(
reduceStore
).
Reduce
(
ss
,
c
ub
::
Sum
{},
blockDim
.
x
);
ss
=
BlockReduce
(
reduceStore
).
Reduce
(
ss
,
C
ub
AddOp
{},
blockDim
.
x
);
__shared__
float
s_rms
;
__shared__
float
s_rms
;
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
...
@@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -240,7 +236,7 @@ __device__ void compute_dynamic_per_token_scales(
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
block_absmax_val_maybe
=
block_absmax_val_maybe
=
BlockReduce
(
reduceStore
)
BlockReduce
(
reduceStore
)
.
Reduce
(
block_absmax_val_maybe
,
c
ub
::
Max
{},
blockDim
.
x
);
.
Reduce
(
block_absmax_val_maybe
,
C
ubMax
Op
{},
blockDim
.
x
);
__shared__
float
s_token_scale
;
__shared__
float
s_token_scale
;
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment