Unverified commit 6b39f9cf, authored by Rain Jiang, committed by GitHub
Browse files

Support compile sgl-kernel on cuda 13.0 (#9721)

parent 07c9d8fb
...@@ -78,7 +78,7 @@ FetchContent_Populate(repo-triton) ...@@ -78,7 +78,7 @@ FetchContent_Populate(repo-triton)
FetchContent_Declare( FetchContent_Declare(
repo-flashinfer repo-flashinfer
GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7 GIT_TAG 018b551825c8e5579206e6eb9d3229fa679202b3
GIT_SHALLOW OFF GIT_SHALLOW OFF
) )
FetchContent_Populate(repo-flashinfer) FetchContent_Populate(repo-flashinfer)
...@@ -174,11 +174,28 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A) ...@@ -174,11 +174,28 @@ if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "12.8" OR SGL_KERNEL_ENABLE_SM100A)
list(APPEND SGL_KERNEL_CUDA_FLAGS list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_100,code=sm_100" "-gencode=arch=compute_100,code=sm_100"
"-gencode=arch=compute_100a,code=sm_100a" "-gencode=arch=compute_100a,code=sm_100a"
"-gencode=arch=compute_101,code=sm_101"
"-gencode=arch=compute_101a,code=sm_101a"
"-gencode=arch=compute_120,code=sm_120" "-gencode=arch=compute_120,code=sm_120"
"-gencode=arch=compute_120a,code=sm_120a" "-gencode=arch=compute_120a,code=sm_120a"
) )
  # See https://github.com/pytorch/pytorch/pull/156176 for a description of sm_121, sm_110, and sm_101
if ("${CUDA_VERSION}" VERSION_GREATER_EQUAL "13.0")
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_103,code=sm_103"
"-gencode=arch=compute_103a,code=sm_103a"
"-gencode=arch=compute_110,code=sm_110"
"-gencode=arch=compute_110a,code=sm_110a"
"-gencode=arch=compute_121,code=sm_121"
"-gencode=arch=compute_121a,code=sm_121a"
"--compress-mode=size"
)
else()
list(APPEND SGL_KERNEL_CUDA_FLAGS
"-gencode=arch=compute_101,code=sm_101"
"-gencode=arch=compute_101a,code=sm_101a"
)
endif()
else() else()
list(APPEND SGL_KERNEL_CUDA_FLAGS list(APPEND SGL_KERNEL_CUDA_FLAGS
"-use_fast_math" "-use_fast_math"
...@@ -261,12 +278,6 @@ set(SOURCES ...@@ -261,12 +278,6 @@ set(SOURCES
"csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_moe_data.cu"
"csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu" "csrc/moe/cutlass_moe/w4a8/w4a8_grouped_mm_c3x.cu"
"csrc/moe/marlin_moe_wna16/ops.cu" "csrc/moe/marlin_moe_wna16/ops.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_bf16_ku8b128.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku4b8.cu"
"csrc/moe/marlin_moe_wna16/kernel_fp16_ku8b128.cu"
"csrc/moe/moe_align_kernel.cu" "csrc/moe/moe_align_kernel.cu"
"csrc/moe/moe_fused_gate.cu" "csrc/moe/moe_fused_gate.cu"
"csrc/moe/moe_topk_softmax_kernels.cu" "csrc/moe/moe_topk_softmax_kernels.cu"
......
...@@ -9,6 +9,7 @@ import jinja2 ...@@ -9,6 +9,7 @@ import jinja2
FILE_HEAD = """ FILE_HEAD = """
// auto generated by generate.py // auto generated by generate.py
// clang-format off // clang-format off
#pragma once
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
...@@ -33,6 +34,17 @@ TEMPLATE = ( ...@@ -33,6 +34,17 @@ TEMPLATE = (
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
# Jinja2 template for the umbrella header: a single generated file that
# #includes every per-dtype/per-scalar-type kernel_*.cuh, so consumers only
# need one #include to pull in all generated Marlin kernels.
KERNEL_FILE_TEMPLATE = (
    "// auto generated by generate.py\n"
    "// clang-format off\n"
    "#pragma once\n\n"
    "{% for kernel_file in kernel_files %}"
    '#include "{{ kernel_file }}"\n'
    "{% endfor %}"
)
# File name of the generated umbrella header (written next to this script).
KERNEL_FILE_NAME = "kernel_marlin.cuh"
# int8 with zero point case (sglang::kU8) is also supported, # int8 with zero point case (sglang::kU8) is also supported,
# we don't add it to reduce wheel size. # we don't add it to reduce wheel size.
SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"] SCALAR_TYPES = ["sglang::kU4", "sglang::kU4B8", "sglang::kU8B128"]
...@@ -48,11 +60,12 @@ DTYPES = ["fp16", "bf16"] ...@@ -48,11 +60,12 @@ DTYPES = ["fp16", "bf16"]
def remove_old_kernels(directory=None):
    """Delete previously generated Marlin kernel headers (``kernel_*.cuh``).

    Args:
        directory: Directory to clean. Defaults to the directory containing
            this script, matching the original hard-coded behavior.

    Notes:
        Uses ``os.path.join`` instead of string concatenation so an empty
        dirname (script run from its own directory) resolves to the current
        directory rather than the filesystem root. Uses ``os.remove`` instead
        of spawning an ``rm -f`` subprocess — portable and faster; the
        ``rm -f`` best-effort semantics are preserved by ignoring files that
        vanish between the glob and the removal.
    """
    if directory is None:
        directory = os.path.dirname(__file__)
    for filename in glob.glob(os.path.join(directory, "kernel_*.cuh")):
        try:
            os.remove(filename)
        except FileNotFoundError:
            # Mirror `rm -f`: a file already gone is not an error.
            pass
def generate_new_kernels(): def generate_new_kernels():
kernel_files = set()
for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES):
has_zp = "B" not in scalar_type has_zp = "B" not in scalar_type
all_template_str_list = [] all_template_str_list = []
...@@ -95,10 +108,20 @@ def generate_new_kernels(): ...@@ -95,10 +108,20 @@ def generate_new_kernels():
file_content = FILE_HEAD + "\n\n" file_content = FILE_HEAD + "\n\n"
file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" file_content += "\n\n".join(all_template_str_list) + "\n\n}\n"
filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cu" filename = f"kernel_{dtype}_{scalar_type[8:].lower()}.cuh"
with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: with open(os.path.join(os.path.dirname(__file__), filename), "w") as f:
f.write(file_content) f.write(file_content)
kernel_files.add(filename)
kernel_files = list(kernel_files)
kernel_files.sort()
file_content = jinja2.Template(KERNEL_FILE_TEMPLATE).render(
kernel_files=kernel_files
)
with open(os.path.join(os.path.dirname(__file__), KERNEL_FILE_NAME), "w") as f:
f.write(file_content)
if __name__ == "__main__": if __name__ == "__main__":
......
#pragma once
#ifndef MARLIN_NAMESPACE_NAME #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
......
// auto generated by generate.py // auto generated by generate.py
// clang-format off // clang-format off
#pragma once
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
......
// auto generated by generate.py // auto generated by generate.py
// clang-format off // clang-format off
#pragma once
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
......
// auto generated by generate.py // auto generated by generate.py
// clang-format off // clang-format off
#pragma once
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
......
// auto generated by generate.py // auto generated by generate.py
// clang-format off // clang-format off
#pragma once
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
......
// auto generated by generate.py // auto generated by generate.py
// clang-format off // clang-format off
#pragma once
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
......
// auto generated by generate.py // auto generated by generate.py
// clang-format off // clang-format off
#pragma once
#include "kernel.h" #include "kernel.h"
#include "marlin_template.h" #include "marlin_template.h"
......
// auto generated by generate.py
// clang-format off
#pragma once
#include "kernel_bf16_ku4.cuh"
#include "kernel_bf16_ku4b8.cuh"
#include "kernel_bf16_ku8b128.cuh"
#include "kernel_fp16_ku4.cuh"
#include "kernel_fp16_ku4b8.cuh"
#include "kernel_fp16_ku8b128.cuh"
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
/* /*
* Adapted from https://github.com/IST-DASLab/marlin * Adapted from https://github.com/IST-DASLab/marlin
*/ */
#pragma once
#ifndef MARLIN_NAMESPACE_NAME #ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif #endif
......
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#endif #endif
#include "kernel.h" #include "kernel.h"
#include "kernel_marlin.cuh"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert( \ static_assert( \
......
...@@ -23,6 +23,7 @@ limitations under the License. ...@@ -23,6 +23,7 @@ limitations under the License.
#ifndef USE_ROCM #ifndef USE_ROCM
#include <cub/cub.cuh> #include <cub/cub.cuh>
#include <cub/util_type.cuh> #include <cub/util_type.cuh>
#include <cuda/functional>
#else #else
#include <hipcub/hipcub.hpp> #include <hipcub/hipcub.hpp>
#include <hipcub/util_type.hpp> #include <hipcub/util_type.hpp>
...@@ -33,6 +34,16 @@ limitations under the License. ...@@ -33,6 +34,16 @@ limitations under the License.
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
// Define reduction operators based on CUDA version
// CUDA 12.9+ deprecates cub::Max/cub::Min in favor of cuda::maximum/cuda::minimum (removed in CUDA 13)
#if CUDA_VERSION >= 12090
using MaxReduceOp = cuda::maximum<>;
using MinReduceOp = cuda::minimum<>;
#else
using MaxReduceOp = cub::Max;
using MinReduceOp = cub::Min;
#endif
/// Aligned array type /// Aligned array type
template < template <
typename T, typename T,
...@@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__ ...@@ -72,7 +83,6 @@ __launch_bounds__(TPB) __global__
const int thread_row_offset = blockIdx.x * num_cols; const int thread_row_offset = blockIdx.x * num_cols;
cub::Sum sum;
float threadData(-FLT_MAX); float threadData(-FLT_MAX);
// Don't touch finished rows. // Don't touch finished rows.
...@@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__ ...@@ -85,7 +95,7 @@ __launch_bounds__(TPB) __global__
threadData = max(convert_to_float<T>(input[idx]), threadData); threadData = max(convert_to_float<T>(input[idx]), threadData);
} }
const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
float_max = maxElem; float_max = maxElem;
...@@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__ ...@@ -99,7 +109,7 @@ __launch_bounds__(TPB) __global__
threadData += exp((convert_to_float<T>(input[idx]) - float_max)); threadData += exp((convert_to_float<T>(input[idx]) - float_max));
} }
const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); const auto Z = BlockReduce(tmpStorage).Sum(threadData);
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
normalizing_factor = 1.f / Z; normalizing_factor = 1.f / Z;
......
Markdown is supported
Attach a file by dragging &amp; dropping or selecting it.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment