Commit d297cda2 authored by lizhigong
Browse files

add merge_state_v2 in sgl_kernel

parent 63c8d8d0
...@@ -185,6 +185,7 @@ elif _is_hip: ...@@ -185,6 +185,7 @@ elif _is_hip:
from sglang.srt.layers.quantization.awq_triton import ( from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_triton as awq_dequantize, awq_dequantize_triton as awq_dequantize,
) )
from sgl_kernel import merge_state_v2
elif _is_npu: elif _is_npu:
import custom_ops # noqa: F401 import custom_ops # noqa: F401
import sgl_kernel_npu # noqa: F401 import sgl_kernel_npu # noqa: F401
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <algorithm> #include <algorithm>
#include <optional> #include <optional>
#include "pytorch_extension_utils.h" #include "pytorch_extension_utils_rocm.h"
// Helper functions to convert between different data types // Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel. // (float, half, bfloat16) for the merge attention states kernel.
...@@ -27,6 +27,19 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) { ...@@ -27,6 +27,19 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) {
d = __float2bfloat16(s); d = __float2bfloat16(s);
} }
// Verify that two tensors agree in rank and in size along every dimension.
// On mismatch, raises a torch error that names the tensors (via a_name /
// b_name) and the first offending dimension.
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
  const int64_t ndim = a.dim();
  TORCH_CHECK(ndim == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", ndim, " vs ", b.dim());
  for (int64_t dim_idx = 0; dim_idx < ndim; ++dim_idx) {
    TORCH_CHECK(a.size(dim_idx) == b.size(dim_idx), a_name, ".size(", dim_idx, ") != ", b_name, ".size(", dim_idx, ")");
  }
}
// Convenience wrappers that stringify their arguments so error messages
// show the caller's expressions, not just values.
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
template <typename scalar_t, const uint NUM_THREADS> template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel( __global__ void merge_attn_states_kernel(
......
...@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { ...@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("gelu_quick(Tensor! out, Tensor input) -> ()"); m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
m.impl("gelu_quick", torch::kCUDA, &gelu_quick); m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
/*
* From csrc/attention
*/
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
/* /*
* From csrc/allreduce * From csrc/allreduce
*/ */
......
...@@ -50,6 +50,7 @@ sources = [ ...@@ -50,6 +50,7 @@ sources = [
"csrc/moe/moe_topk_softmax_kernels.cu", "csrc/moe/moe_topk_softmax_kernels.cu",
"csrc/speculative/eagle_utils.cu", "csrc/speculative/eagle_utils.cu",
"csrc/kvcacheio/transfer.cu", "csrc/kvcacheio/transfer.cu",
"csrc/attention/merge_attn_states.cu",
] ]
cxx_flags = ["-O3"] cxx_flags = ["-O3"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment