"sgl-kernel/python/sgl_kernel/__init__.py" did not exist on "b02da24a5b8cc0b8e4971f59a7e0f8afcfeab9b3"
Unverified commit 6619f48e, authored by Ke Bao, committed by GitHub

Fix cu118 group gemm compile issue (#3097)

parent 3ed0a547
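
FlashInfer's group_gemm_sm90.cu must be compiled with -gencode=arch=compute_90a,code=sm_90a, a target introduced in CUDA 12.0 that the cu118 (CUDA 11.8) toolchain rejects. This commit hoists the extension's source list into a top-level `sources` variable and appends the SM90a group GEMM kernel only when it can actually be built: when the toolkit is CUDA >= 12.0 and an SM >= 90 device is present, or when a GPU-less build environment opts in via SGL_KERNEL_ENABLE_SM90A=1. A hedged sketch of the resulting logic follows the diff below.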
@@ -62,6 +62,23 @@ nvcc_flags = [
     "-DFLASHINFER_ENABLE_F16",
 ]
+sources = [
+    "src/sgl-kernel/csrc/trt_reduce_internal.cu",
+    "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
+    "src/sgl-kernel/csrc/moe_align_kernel.cu",
+    "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
+    "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
+    "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
+    "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
+    "src/sgl-kernel/csrc/rotary_embedding.cu",
+    "3rdparty/flashinfer/csrc/activation.cu",
+    "3rdparty/flashinfer/csrc/bmm_fp8.cu",
+    "3rdparty/flashinfer/csrc/group_gemm.cu",
+    "3rdparty/flashinfer/csrc/norm.cu",
+    "3rdparty/flashinfer/csrc/sampling.cu",
+    "3rdparty/flashinfer/csrc/renorm.cu",
+]
 enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1"
 enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1"
 enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1"
@@ -71,6 +88,7 @@ sm_version = _get_device_sm()
 if torch.cuda.is_available():
     if cuda_version >= (12, 0) and sm_version >= 90:
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+        sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
     if sm_version >= 90:
         nvcc_flags.extend(
             [
@@ -85,6 +103,7 @@ else:
     # compilation environment without GPU
     if enable_sm90a:
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
+        sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
     if enable_fp8:
         nvcc_flags.extend(
             [
@@ -110,26 +129,11 @@ for flag in [
 cxx_flags = ["-O3"]
 libraries = ["c10", "torch", "torch_python", "cuda"]
 extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
 ext_modules = [
     CUDAExtension(
         name="sgl_kernel.ops._kernels",
-        sources=[
-            "src/sgl-kernel/csrc/trt_reduce_internal.cu",
-            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
-            "src/sgl-kernel/csrc/moe_align_kernel.cu",
-            "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
-            "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
-            "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
-            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
-            "src/sgl-kernel/csrc/rotary_embedding.cu",
-            "3rdparty/flashinfer/csrc/activation.cu",
-            "3rdparty/flashinfer/csrc/bmm_fp8.cu",
-            "3rdparty/flashinfer/csrc/group_gemm.cu",
-            "3rdparty/flashinfer/csrc/group_gemm_sm90.cu",
-            "3rdparty/flashinfer/csrc/norm.cu",
-            "3rdparty/flashinfer/csrc/sampling.cu",
-            "3rdparty/flashinfer/csrc/renorm.cu",
-        ],
+        sources=sources,
         include_dirs=include_dirs,
         extra_compile_args={
             "nvcc": nvcc_flags,
...
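
For context when skimming the diff above, here is a minimal, self-contained sketch of the build logic this commit produces. It is an illustration, not the file itself: the source list is abbreviated, the repo's `_get_device_sm()` helper is replaced by a direct `torch.cuda.get_device_capability()` query, and the `cuda_version` parsing is an assumption about the surrounding code, which is truncated in this view.

```python
# Sketch of the post-commit setup.py logic (abridged; see the diff for
# the full source list). group_gemm_sm90.cu needs the sm_90a gencode,
# which the CUDA 11.8 toolchain rejects, so it is appended conditionally.
import os
import torch
from torch.utils.cpp_extension import CUDAExtension

sources = [
    "src/sgl-kernel/csrc/trt_reduce_internal.cu",
    # ... other baseline kernels from the diff ...
    "3rdparty/flashinfer/csrc/group_gemm.cu",
]
nvcc_flags = []

# torch.version.cuda is a string like "12.1", or None on CPU-only builds.
cuda_version = (
    tuple(int(v) for v in torch.version.cuda.split(".")[:2])
    if torch.version.cuda
    else (0, 0)
)

if torch.cuda.is_available():
    # Stand-in for the repo's _get_device_sm(): 90 on Hopper (sm_90).
    major, minor = torch.cuda.get_device_capability()
    sm_version = major * 10 + minor
    if cuda_version >= (12, 0) and sm_version >= 90:
        # Only CUDA >= 12.0 understands compute_90a/sm_90a.
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
        sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
else:
    # GPU-less build hosts opt in explicitly via the environment.
    if os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1":
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
        sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")

ext_modules = [
    CUDAExtension(
        name="sgl_kernel.ops._kernels",
        sources=sources,  # SM90a group GEMM included only when buildable
        extra_compile_args={"nvcc": nvcc_flags},
    ),
]
```

Either branch leaves `sources` free of group_gemm_sm90.cu on cu118 hosts, which is what fixes the compile failure: the file is never handed to nvcc unless the toolchain and target device (or an explicit opt-in) support sm_90a.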