sglang commit d631290e (unverified)
Remove annoying warnings in sgl kernel build (#9905)

Authored Sep 02, 2025 by Lianmin Zheng; committed via GitHub on Sep 02, 2025.
Parent commit: 37565b7f
Showing 5 changed files with 43 additions and 36 deletions:

  sgl-kernel/CMakeLists.txt                          +34  -32
  sgl-kernel/Makefile                                 +1   -2
  sgl-kernel/csrc/attention/cutlass_mla_kernel.cu     +2   -2
  sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu           +1   -0
  sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu          +5   -0
sgl-kernel/CMakeLists.txt

```diff
@@ -3,6 +3,7 @@ project(sgl-kernel LANGUAGES CXX CUDA)
 # CMake
 cmake_policy(SET CMP0169 OLD)
+cmake_policy(SET CMP0177 NEW)
 include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
 set(CMAKE_COLOR_DIAGNOSTICS ON)
 set(CMAKE_VERBOSE_MAKEFILE ON CACHE BOOL "ON")
```
```diff
@@ -50,14 +51,7 @@ FetchContent_Declare(
 )
 FetchContent_Populate(repo-cutlass)
 
-FetchContent_Declare(
-  repo-fmt
-  GIT_REPOSITORY https://github.com/fmtlib/fmt
-  GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
-  GIT_SHALLOW OFF
-)
-FetchContent_Populate(repo-fmt)
-
+# DeepGEMM
 FetchContent_Declare(
   repo-deepgemm
   GIT_REPOSITORY https://github.com/sgl-project/DeepGEMM
@@ -66,6 +60,14 @@ FetchContent_Declare(
 )
 FetchContent_Populate(repo-deepgemm)
 
+FetchContent_Declare(
+  repo-fmt
+  GIT_REPOSITORY https://github.com/fmtlib/fmt
+  GIT_TAG 553ec11ec06fbe0beebfbb45f9dc3c9eabd83d28
+  GIT_SHALLOW OFF
+)
+FetchContent_Populate(repo-fmt)
+
 # Triton
 FetchContent_Declare(
   repo-triton
```
```diff
@@ -148,21 +150,40 @@ set(SGL_KERNEL_CUDA_FLAGS
   "--expt-extended-lambda"
   "--threads=32"
-  # Suppress warnings
-  "-Xcompiler=-Wconversion"
-  "-Xcompiler=-Wno-clang-format-violations"
-  "-Xcompiler=-fno-strict-aliasing"
+  # Supress warnings
+  "-Xcompiler=-Wno-conversion"
+  "-Xcompiler=-Wno-deprecated-declarations"
+  "-Xcompiler=-Wno-terminate"
+  "-Xcompiler=-Wfatal-errors"
+  "-Xcompiler=-ftemplate-backtrace-limit=1"
+  "-Xcudafe=--diag_suppress=177" # variable was declared but never referenced
   # uncomment to debug
   # "--ptxas-options=-v"
   # "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage"
 )
 
+option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
+option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
 option(SGL_KERNEL_ENABLE_BF16 "Enable BF16" ON)
 option(SGL_KERNEL_ENABLE_FP8 "Enable FP8" ON)
 option(SGL_KERNEL_ENABLE_FP4 "Enable FP4" OFF)
 option(SGL_KERNEL_ENABLE_FA3 "Enable FA3" OFF)
-option(SGL_KERNEL_ENABLE_SM90A "Enable SM90A" OFF)
-option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
+
+if(SGL_KERNEL_ENABLE_BF16)
+  list(APPEND SGL_KERNEL_CUDA_FLAGS "-DFLASHINFER_ENABLE_BF16")
+endif()
+
+if(SGL_KERNEL_ENABLE_FP8)
+  list(APPEND SGL_KERNEL_CUDA_FLAGS
+    "-DFLASHINFER_ENABLE_FP8"
+    "-DFLASHINFER_ENABLE_FP8_E4M3"
+    "-DFLASHINFER_ENABLE_FP8_E5M2"
+  )
+endif()
 
 if(ENABLE_BELOW_SM90)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
```
```diff
@@ -210,31 +231,12 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.4" OR SGL_KERNEL_ENABLE_SM90A)
   )
 endif()
 
-if(SGL_KERNEL_ENABLE_BF16)
-  list(APPEND SGL_KERNEL_CUDA_FLAGS "-DFLASHINFER_ENABLE_BF16")
-endif()
-
-if(SGL_KERNEL_ENABLE_FP8)
-  list(APPEND SGL_KERNEL_CUDA_FLAGS
-    "-DFLASHINFER_ENABLE_FP8"
-    "-DFLASHINFER_ENABLE_FP8_E4M3"
-    "-DFLASHINFER_ENABLE_FP8_E5M2"
-  )
-endif()
-
 if("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_FP4)
   list(APPEND SGL_KERNEL_CUDA_FLAGS
     "-DENABLE_NVFP4=1"
   )
 endif()
 
-string(REPLACE "-D__CUDA_NO_HALF_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
-string(REPLACE "-D__CUDA_NO_HALF_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
-string(REPLACE "-D__CUDA_NO_BFLOAT16_CONVERSIONS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
-string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
-
 set(SOURCES
   "csrc/allreduce/custom_all_reduce.cu"
   "csrc/allreduce/mscclpp_allreduce.cu"
```
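For context, a minimal CUDA sketch (illustrative only; the kernel and function names are hypothetical, not from the repo) of the two kinds of diagnostics the new flag set targets: host-side narrowing, which warned under the old `-Xcompiler=-Wconversion` and is quiet under the new `-Xcompiler=-Wno-conversion`, and nvcc's diagnostic 177 for variables that are declared but never referenced.

```cuda
#include <cstdint>

__global__ void scale(float* out, const float* in, int64_t n) {
  float unused = 0.0f;  // nvcc: warning #177-D "variable was declared but never
                        // referenced"; silenced by -Xcudafe=--diag_suppress=177
  int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (tid < n) out[tid] = in[tid] * 2.0f;
}

void launch(float* out, const float* in, int64_t n) {
  // int64_t -> int narrowing: warned when the host compiler ran with
  // -Wconversion, quiet now that it runs with -Wno-conversion.
  int blocks = (n + 255) / 256;
  scale<<<blocks, 256>>>(out, in, n);
}
```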
sgl-kernel/Makefile

The build target no longer uses a separate `export MAX_JOBS=$(nproc) &&` step; MAX_JOBS is now passed inline alongside the other per-command variables.

```diff
@@ -21,12 +21,11 @@ submodule: ## Initialize and update git submodules
 ln: submodule ## Create compilation database
 	@rm -rf build && mkdir build && cd build && cmake .. -DCMAKE_EXPORT_COMPILE_COMMANDS=YES -DCMAKE_POLICY_VERSION_MINIMUM=3.5
 
 install: submodule ## Install package in development mode
 	@pip install -e . --no-build-isolation
 
 build: install-deps submodule ## Build and install wheel package
-	@rm -rf dist/* || true && export MAX_JOBS=$(nproc) && CMAKE_POLICY_VERSION_MINIMUM=3.5 CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
+	@rm -rf dist/* || true && CMAKE_POLICY_VERSION_MINIMUM=3.5 MAX_JOBS=$(nproc) CMAKE_BUILD_PARALLEL_LEVEL=$(nproc) uv build --wheel -Cbuild-dir=build . --verbose --color=always --no-build-isolation && pip3 install dist/*whl --force-reinstall --no-deps
 
 clean: ## Remove build artifacts
 	@rm -rf build dist *.egg-info
```
sgl-kernel/csrc/attention/cutlass_mla_kernel.cu

```diff
@@ -162,7 +162,7 @@ typename T::Fmha::Arguments args_from_options(
       // TODO(trevor-m): Change split_kv back to -1 when
       // https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
       // perform worse with larger context length and smaller batch sizes.
-      num_kv_splits,  // split_kv
+      static_cast<int>(num_kv_splits),  // split_kv
       nullptr,  // is_var_split_kv
   };
 
   // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
@@ -259,7 +259,7 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
   // Assumes device 0 when getting sm_count.
   arguments.hw_info.sm_count =
       sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
-  arguments.split_kv = num_kv_splits;
+  arguments.split_kv = static_cast<int>(num_kv_splits);
   MlaSm100Type::Fmha::set_split_kv(arguments);
   return MlaSm100Type::Fmha::get_workspace_size(arguments);
```
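Both call sites pass an `int64_t num_kv_splits` into an `int`-typed `split_kv` field; the explicit `static_cast<int>` marks the narrowing as intentional, so the conversion warning goes away. A minimal sketch of the pattern (the `Args` struct here is a hypothetical stand-in, not the CUTLASS type):

```cuda
#include <cstdint>

struct Args {
  int split_kv;  // stand-in for the int-typed CUTLASS arguments field
};

Args make_args(int64_t num_kv_splits) {
  // Implicit int64_t -> int narrowing draws a conversion warning; the
  // explicit cast states the truncation is intended and compiles quietly.
  return Args{static_cast<int>(num_kv_splits)};
}
```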
sgl-kernel/csrc/gemm/dsv3_fused_a_gemm.cu

```diff
@@ -131,6 +131,7 @@ __device__ bool try_wait_barrier(uint64_t* smem_ptr, int phase_bit) {
       : "r"(smem_int_ptr), "r"(phase_bit));
   return static_cast<bool>(wait_complete);
 #endif
+  return false;
 }
 
 // Barrier arrive
```
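The body of `try_wait_barrier` sits behind a preprocessor guard; when the guard is false, control previously fell off the end of a non-void function, which the compiler flags with a missing-return warning. The added `return false;` gives the unguarded path a value. A simplified sketch of the shape (the `__CUDA_ARCH__ >= 900` guard is an assumption for illustration, and the inline PTX is elided):

```cuda
#include <cstdint>

__device__ bool try_wait_barrier_sketch(uint64_t* smem_ptr, int phase_bit) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  uint32_t wait_complete = 0;
  // ... inline PTX mbarrier try_wait would set wait_complete here ...
  return static_cast<bool>(wait_complete);
#endif
  // Fallback added by the commit: every compilation path now returns a
  // value, so the compiler no longer warns about a missing return statement.
  return false;
}
```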
sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu

```diff
@@ -541,6 +541,11 @@ void quant_impl(
   }
 }
 
+// Avoid redefinition warnings
+#undef CHECK_CONTIGUOUS
+#undef CHECK_TH_CUDA
+#undef CHECK_INPUT
+
 /*Quantization entry for fp4 experts quantization*/
 #define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x, m) TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
```
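The file redefines `CHECK_TH_CUDA`, `CHECK_CONTIGUOUS`, and `CHECK_INPUT`, which are evidently already defined earlier in the translation unit, and redefining a macro with a different body triggers a macro-redefinition warning; `#undef`-ing first keeps the build quiet. The general pattern, with a hypothetical macro not from the repo:

```cuda
// First definition, standing in for one picked up from an earlier header.
#define CHECK_POSITIVE(x) ((x) > 0)

// Redefining CHECK_POSITIVE directly would warn ("macro redefined");
// removing the old definition first makes the new one clean.
#undef CHECK_POSITIVE
#define CHECK_POSITIVE(x) ((x) >= 1)
```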