"vscode:/vscode.git/clone" did not exist on "ad3ca3e74eaef86b5e061b27706849601db33430"
Unverified commit 95f789ad authored by Yineng Zhang, committed by GitHub

minor: cleanup sgl-kernel (#3143)

parent 4f118a39
@@ -40,6 +40,10 @@ Development build:
make build
```

Note:
The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`.

### Testing & Benchmarking

1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
...
@@ -82,10 +82,8 @@ sources = [
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu",
"src/sgl-kernel/csrc/fused_add_rms_norm.cu",
"3rdparty/flashinfer/csrc/activation.cu",
"3rdparty/flashinfer/csrc/bmm_fp8.cu",
"3rdparty/flashinfer/csrc/group_gemm.cu",
...
src/sgl-kernel/csrc/fused_add_rms_norm.cu (deleted in this commit)

// Adapted from
// https://github.com/InternLM/lmdeploy/blob/800b6010c0bf76aadf678bc38a507b749fb9774c/src/turbomind/kernels/norm/rms_norm.cu
#include <turbomind/kernels/core/array_ops.h>
#include <turbomind/kernels/core/common.h>

#include <cub/block/block_reduce.cuh>

using namespace turbomind;

template <class T, class Tacc, int block_dim, int vec_size>
__global__ void BiasResidualRMSNormKernel(T* __restrict__ residual, T* __restrict__ hidden_states,
                                          const T* __restrict__ weights, const T* __restrict__ bias, int dims, int num,
                                          float eps, float inv_dims) {
  const int ti = blockIdx.x;
  const int di = threadIdx.x * vec_size;

  if (ti >= num) {
    return;
  }

  residual += dims * ti;
  hidden_states += dims * ti;

  Array<Tacc, vec_size> accum{};

  Array<T, vec_size> r_vec;
  Array<T, vec_size> h_vec;
  Array<T, vec_size> b_vec;

  for (int i = di; i < dims; i += block_dim * vec_size) {
    Load(r_vec, &residual[i]);
    Load(h_vec, &hidden_states[i]);

    using namespace ops;
    r_vec = r_vec + h_vec;

    if (bias) {
      Ldg(b_vec, &bias[i]);
      r_vec = r_vec + b_vec;
    }

    Store(&residual[i], r_vec);

    Array<Tacc, vec_size> tmp = cast<Tacc>(r_vec);
    accum = accum + tmp * tmp;
  }

  float sum{};
  PRAGMA_UNROLL
  for (int i = 0; i < vec_size; ++i) {
    sum += accum[i];
  }

  using BlockReduce = cub::BlockReduce<Tacc, block_dim>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  sum = BlockReduce{temp_storage}.Sum(sum);

  __shared__ float shared_sum;
  if (threadIdx.x == 0) {
    shared_sum = rsqrtf(sum * inv_dims + eps);
  }
  __syncthreads();

  sum = shared_sum;

  Array<T, vec_size> w_vec;
  for (int i = di; i < dims; i += block_dim * vec_size) {
    Load(r_vec, &residual[i]);
    Ldg(w_vec, &weights[i]);

    PRAGMA_UNROLL
    for (int c = 0; c < vec_size; ++c) {
      r_vec[c] = (T)((float)r_vec[c] * sum) * w_vec[c];
    }

    Store(&hidden_states[i], r_vec);
  }
}

template <class T>
void invokeBiasResidualRMSNorm(T* residual, T* hidden_states, const T* weights, const T* bias, int dims, int num,
                               float eps, cudaStream_t st) {
  constexpr int vec_size = 16 / sizeof(T);
  constexpr int threads = 512;
  const int blocks = num;

  BiasResidualRMSNormKernel<T, float, threads, vec_size>
      <<<blocks, threads, 0, st>>>(residual, hidden_states, weights, bias, dims, num, eps, 1.f / dims);
}
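For readers skimming the diff: the deleted kernel fuses three steps into one pass over each row. It adds `hidden_states` (and an optional `bias`) into `residual`, computes the row's RMS in fp32, then writes the normalized row scaled by `weights` back into `hidden_states`. Below is a minimal PyTorch sketch of the same math, for reference only; the helper name, out-of-place style, and signature are illustrative and not part of the repository.

```python
import torch

def bias_residual_rms_norm_ref(residual, hidden_states, weights, bias, eps):
    # Mirrors BiasResidualRMSNormKernel: residual/hidden_states are [num, dims];
    # the CUDA kernel updates both buffers in place and accumulates in fp32.
    residual = residual + hidden_states
    if bias is not None:
        residual = residual + bias
    acc = residual.float()
    inv_rms = torch.rsqrt(acc.pow(2).mean(dim=-1, keepdim=True) + eps)
    hidden_states = (acc * inv_rms).to(residual.dtype) * weights
    return residual, hidden_states
```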
@@ -3,8 +3,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include "utils.h"
#define THREADS_PER_BLOCK 128
...
@@ -3,28 +3,14 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <THC/THCAtomics.cuh>
#include "utils.h"
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#ifndef USE_ROCM
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
  hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
@@ -39,7 +25,6 @@
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) {
  // don't worry about overflow because num_experts is relatively small
  return row * total_col + col;
}
...
src/sgl-kernel/csrc/sampling_scaling_penalties.cu (deleted in this commit)

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>

#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
                                                  scalar_t* output, const int32_t numel) {
  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t stride = blockDim.x * gridDim.x;
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int32_t num_vec_elems = numel / vec_size;

#pragma unroll 1
  for (int32_t i = tid; i < num_vec_elems; i += stride) {
    vec_t logits_vec, penalties_vec, out_vec;
    logits_vec.cast_load(logits + i * vec_size);
    penalties_vec.cast_load(scaling_penalties + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      out_vec[j] = logits_vec[j] > scalar_t(0.0f) ? logits_vec[j] / penalties_vec[j] : logits_vec[j] * penalties_vec[j];
    }

    out_vec.cast_store(output + i * vec_size);
  }

  // process the remaining elements
  const int32_t start_idx = num_vec_elems * vec_size;
  for (int32_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > scalar_t(0.0f) ? logit / penalty : logit * penalty;
  }
}

torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
  auto output = torch::empty_like(logits);
  const auto numel = logits.numel();
  const int threads = 512;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(logits.scalar_type(), scalar_t, [&] {
    uint32_t vec_size = 16 / sizeof(scalar_t);
    const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
    sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        static_cast<scalar_t*>(logits.data_ptr()), static_cast<scalar_t*>(scaling_penalties.data_ptr()),
        static_cast<scalar_t*>(output.data_ptr()), numel);
    return true;
  });

  return output;
}
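This whole file is removed by the commit. The element-wise rule it implements (divide positive logits by the penalty, multiply non-positive logits by it) can be expressed directly in PyTorch; the deleted test further below uses exactly that expression as its reference. A minimal sketch, with an illustrative helper name:

```python
import torch

def sampling_scaling_penalties_ref(logits: torch.Tensor, scaling_penalties: torch.Tensor) -> torch.Tensor:
    # Same rule as the removed CUDA kernel, applied element-wise.
    return torch.where(logits > 0, logits / scaling_penalties, logits * scaling_penalties)
```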
@@ -26,6 +26,7 @@
#include <tuple>
#include "trt_reduce_internal.cuh"
#include "utils.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -5,6 +5,7 @@
#include <cassert>
#include "trt_reduce_internal.cuh"
#include "utils.h"
using namespace trt_llm;
...
#pragma once
#include <Python.h>
#include <torch/extension.h>
#include <vector>
#include "utils.h"
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
@@ -36,9 +35,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);
// sampling_scaling_penalties
torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties);
// int8_scaled_mm
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
...
@@ -17,12 +17,11 @@
*/
#pragma once
#include <cuda_fp16.h>
#include <stdint.h>
#include <torch/all.h>
#include "utils.h"
namespace trt_llm {
constexpr size_t WARP_SIZE = 32;
constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36;
...
#pragma once
#include <cuda_runtime.h>
#include <pytorch_extension_utils.h>
#include <torch/extension.h>
#include <sstream>
#include "sgl_kernels_ops.h"
struct cuda_error : public std::runtime_error {
  /**
   * @brief Constructs a `cuda_error` object with the given `message`.
...
@@ -28,10 +28,6 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
  // sampling_scaling_penalties
  m.def("sampling_scaling_penalties(Tensor logits, Tensor scaling_penalties) -> Tensor");
  m.impl("sampling_scaling_penalties", torch::kCUDA, &sampling_scaling_penalties);
  // int8_scaled_mm
  m.def(
      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
...
tests/test_sampling_scaling_penalties.py (deleted in this commit)

import pytest
import torch
from sgl_kernel import sampling_scaling_penalties

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
dtypes = [torch.float32, torch.half, torch.bfloat16]


@pytest.mark.parametrize("batch_size", batch_sizes)
@pytest.mark.parametrize("vocab_size", vocab_sizes)
@pytest.mark.parametrize("dtype", dtypes)
def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
    device = torch.device("cuda")
    rtol = 1e-3
    atol = 1e-3

    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    scaling_penalties = (
        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    )

    ref_output = torch.where(
        logits > 0, logits / scaling_penalties, logits * scaling_penalties
    )
    kernel_output = sampling_scaling_penalties(logits, scaling_penalties)

    torch.testing.assert_close(
        kernel_output,
        ref_output,
        rtol=rtol,
        atol=atol,
        msg=f"Failed for batch_size={batch_size}, vocab_size={vocab_size}, dtype={dtype}",
    )


if __name__ == "__main__":
    pytest.main([__file__])