Unverified Commit 04f0b4cb authored by Yineng Zhang, committed by GitHub

minor: update sgl-kernel setup (#3107)

parent 4505a436
sgl-kernel/setup.py

@@ -38,10 +38,10 @@ def _get_version():
                 return line.split("=")[1].strip().strip('"')

-cutlass = root / "3rdparty" / "cutlass"
 cutlass_default = root / "3rdparty" / "cutlass"
 cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
 flashinfer = root / "3rdparty" / "flashinfer"
+turbomind = root / "3rdparty" / "turbomind"
 include_dirs = [
     cutlass.resolve() / "include",
     cutlass.resolve() / "tools" / "util" / "include",
@@ -49,6 +49,8 @@ include_dirs = [
     flashinfer.resolve() / "include",
     flashinfer.resolve() / "include" / "gemm",
     flashinfer.resolve() / "csrc",
+    turbomind.resolve(),
+    turbomind.resolve() / "src",
 ]
 nvcc_flags = [
     "-DNDEBUG",
@@ -63,6 +65,11 @@ nvcc_flags = [
     "-use_fast_math",
     "-DFLASHINFER_ENABLE_F16",
 ]
+nvcc_flags_fp8 = [
+    "-DFLASHINFER_ENABLE_FP8",
+    "-DFLASHINFER_ENABLE_FP8_E4M3",
+    "-DFLASHINFER_ENABLE_FP8_E5M2",
+]
 sources = [
     "src/sgl-kernel/csrc/trt_reduce_internal.cu",
@@ -73,6 +80,7 @@ sources = [
     "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
     "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
     "src/sgl-kernel/csrc/rotary_embedding.cu",
+    "src/sgl-kernel/csrc/fused_add_rms_norm.cu",
     "3rdparty/flashinfer/csrc/activation.cu",
     "3rdparty/flashinfer/csrc/bmm_fp8.cu",
     "3rdparty/flashinfer/csrc/group_gemm.cu",
@@ -92,13 +100,7 @@ if torch.cuda.is_available():
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
         sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
     if sm_version >= 90:
-        nvcc_flags.extend(
-            [
-                "-DFLASHINFER_ENABLE_FP8",
-                "-DFLASHINFER_ENABLE_FP8_E4M3",
-                "-DFLASHINFER_ENABLE_FP8_E5M2",
-            ]
-        )
+        nvcc_flags.extend(nvcc_flags_fp8)
     if sm_version >= 80:
         nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
 else:
@@ -107,13 +109,7 @@ else:
         nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
         sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu")
     if enable_fp8:
-        nvcc_flags.extend(
-            [
-                "-DFLASHINFER_ENABLE_FP8",
-                "-DFLASHINFER_ENABLE_FP8_E4M3",
-                "-DFLASHINFER_ENABLE_FP8_E5M2",
-            ]
-        )
+        nvcc_flags.extend(nvcc_flags_fp8)
     if enable_bf16:
         nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
......
New file: src/sgl-kernel/csrc/fused_add_rms_norm.cu

// Adapted from
// https://github.com/InternLM/lmdeploy/blob/800b6010c0bf76aadf678bc38a507b749fb9774c/src/turbomind/kernels/norm/rms_norm.cu
#include <turbomind/kernels/core/array_ops.h>
#include <turbomind/kernels/core/common.h>
#include <cub/block/block_reduce.cuh>
using namespace turbomind;
// Fused residual add + RMSNorm: residual += hidden_states (+ optional bias),
// then hidden_states = RMSNorm(residual) * weights. One thread block per token row.
template <class T, class Tacc, int block_dim, int vec_size>
__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual, T* __restrict__ hidden_states,
                                          const T* __restrict__ weights, const T* __restrict__ bias, int dims,
                                          int num, float eps, float inv_dims) {
  const int ti = blockIdx.x;
  const int di = threadIdx.x * vec_size;
  if (ti >= num) {
    return;
  }
  residual += dims * ti;
  hidden_states += dims * ti;

  // First pass: add hidden_states (and bias, if given) into residual, write it back,
  // and accumulate the per-thread sum of squares in Tacc precision.
  Array<Tacc, vec_size> accum{};
  Array<T, vec_size> r_vec;
  Array<T, vec_size> h_vec;
  Array<T, vec_size> b_vec;
  for (int i = di; i < dims; i += block_dim * vec_size) {
    Load(r_vec, &residual[i]);
    Load(h_vec, &hidden_states[i]);
    using namespace ops;
    r_vec = r_vec + h_vec;
    if (bias) {
      Ldg(b_vec, &bias[i]);
      r_vec = r_vec + b_vec;
    }
    Store(&residual[i], r_vec);
    Array<Tacc, vec_size> tmp = cast<Tacc>(r_vec);
    accum = accum + tmp * tmp;
  }

  // Block-wide reduction of the squared sums; the normalizer
  // rsqrt(mean_square + eps) is then broadcast through shared memory.
  float sum{};
  PRAGMA_UNROLL
  for (int i = 0; i < vec_size; ++i) {
    sum += accum[i];
  }
  using BlockReduce = cub::BlockReduce<Tacc, block_dim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;
  sum = BlockReduce{temp_storage}.Sum(sum);
  __shared__ float shared_sum;
  if (threadIdx.x == 0) {
    shared_sum = rsqrtf(sum * inv_dims + eps);
  }
  __syncthreads();
  sum = shared_sum;

  // Second pass: re-read the updated residual, normalize, apply the weight,
  // and write the result into hidden_states.
  Array<T, vec_size> w_vec;
  for (int i = di; i < dims; i += block_dim * vec_size) {
    Load(r_vec, &residual[i]);
    Ldg(w_vec, &weights[i]);
    PRAGMA_UNROLL
    for (int c = 0; c < vec_size; ++c) {
      r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c];
    }
    Store(&hidden_states[i], r_vec);
  }
}

// Host-side launcher: one block per row, 512 threads, 16-byte vectorized accesses.
template <class T>
void invokeBiasResidualRMSNorm(T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num,
                               float eps, cudaStream_t st) {
  constexpr int vec_size = 16 / sizeof(T);
  constexpr int threads = 512;
  const int blocks = num;

  BiasResidualRMSNormKernel<T, float, threads, vec_size>
      <<<blocks, threads, 0, st>>>(residual, hidden_states, weights, bias, dims, num, eps, 1.f / dims);
}
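The visible diff ends with the launcher; the binding that exposes this kernel to PyTorch is not part of what is shown here. Below is a minimal, hypothetical host-side sketch of how a torch-extension wrapper could call invokeBiasResidualRMSNorm for fp16 tensors. The function name fused_add_rmsnorm_sketch, the fp16-only restriction, and the placement alongside sgl_kernel_ops.cu are illustrative assumptions, not code from this commit.

// Hypothetical sketch only: how a torch-extension binding (e.g. next to
// sgl_kernel_ops.cu) might call the launcher above for fp16 tensors.
// The name, argument order, and checks are assumptions, not the real binding.
#include <ATen/cuda/CUDAContext.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

void fused_add_rmsnorm_sketch(torch::Tensor hidden_states, torch::Tensor residual,
                              torch::Tensor weight, double eps) {
  TORCH_CHECK(hidden_states.is_cuda() && residual.is_cuda() && weight.is_cuda(),
              "expected CUDA tensors");
  TORCH_CHECK(hidden_states.scalar_type() == at::kHalf, "this sketch covers fp16 only");
  const int dims = static_cast<int>(hidden_states.size(-1));        // hidden size per token
  const int num = static_cast<int>(hidden_states.numel()) / dims;   // number of token rows
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // bias is optional in the kernel; passing nullptr skips the bias add.
  invokeBiasResidualRMSNorm<half>(
      reinterpret_cast<half*>(residual.data_ptr<at::Half>()),
      reinterpret_cast<half*>(hidden_states.data_ptr<at::Half>()),
      reinterpret_cast<const half*>(weight.data_ptr<at::Half>()),
      /*bias=*/nullptr, dims, num, static_cast<float>(eps), stream);
}

Note that the launcher fixes vec_size = 16 / sizeof(T), i.e. 128-bit loads and stores per thread, so a wrapper only has to select the element type T; rows whose length is not a multiple of vec_size would appear to need a separate fallback path.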