Commit 698bc661 authored by lizhigong

Merge branch 'v0.5.4_dev_lzg' into 'v0.5.4_dev'

Fix bug in flash attention use for chunked prefill and radix cache

See merge request OpenDAS/sglang!14
parents 1eaad6d1 8d0b2f15
@@ -86,9 +86,10 @@ def flash_attn_varlen_func(
         k=k,
         v=v,
         cu_seqlens_q=cu_seqlens_q,
-        cu_seqlens_k=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
-        max_seqlen_k=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
         softmax_scale=softmax_scale,
         causal=causal,
+        return_attn_probs=return_softmax_lse,
     )
\ No newline at end of file
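Why this fix matters: with chunked prefill and radix-cache prefix reuse, the query tokens are only the new chunk while the keys/values cover the cached prefix plus the chunk, so the key-side cumulative lengths and maximum length differ from the query-side ones. Passing the q values, as the old code did, silently truncated the key range that each query could attend to. A minimal illustrative call, assuming flash_attn_varlen_func is importable from this wrapper and q/k/v use the usual varlen packing of [total_tokens, num_heads, head_dim]; the lengths and shapes below are made up for the example:

    import torch

    H, H_kv, D = 8, 8, 128
    q_len, k_len = 32, 96                 # 64 prefix tokens already cached + 32 new tokens to prefill
    q = torch.randn(q_len, H, D, dtype=torch.half, device="cuda")
    k = torch.randn(k_len, H_kv, D, dtype=torch.half, device="cuda")
    v = torch.randn(k_len, H_kv, D, dtype=torch.half, device="cuda")
    cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device="cuda")
    cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device="cuda")
    out = flash_attn_varlen_func(
        q=q, k=k, v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,        # was cu_seqlens_q before the fix: keys cut to 32 tokens
        max_seqlen_q=q_len,
        max_seqlen_k=k_len,               # was max_seqlen_q before the fix
        causal=True,
    )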
@@ -185,6 +185,7 @@ elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+    from sgl_kernel import merge_state_v2
 elif _is_npu:
     import custom_ops  # noqa: F401
     import sgl_kernel_npu  # noqa: F401
@@ -4,7 +4,7 @@
 #include <algorithm>
 #include <optional>
-#include "pytorch_extension_utils.h"
+#include "pytorch_extension_utils_rocm.h"
 // Helper functions to convert between different data types
 // (float, half, bfloat16) for the merge attention states kernel.
@@ -27,6 +27,19 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) {
   d = __float2bfloat16(s);
 }
+inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
+  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
+  for (int i = 0; i < a.dim(); ++i) {
+    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
+  }
+}
+#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
+#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
 template <typename scalar_t, const uint NUM_THREADS>
 __global__ void merge_attn_states_kernel(
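The kernel comment above points at section 2.2 of https://www.arxiv.org/pdf/2501.01005: merging two partial attention results, each computed over a disjoint subset of the keys (e.g. the cached prefix and the new chunk), using their log-sum-exp values. A plain PyTorch reference of that merge, for intuition only; the function name and the [tokens, heads, head_dim] / [tokens, heads] layout are assumptions, not the kernel's actual interface:

    import torch

    def merge_attn_states_ref(o_a, lse_a, o_b, lse_b):
        # o_a, o_b: [tokens, heads, head_dim] partial attention outputs
        # lse_a, lse_b: [tokens, heads] log-sum-exp of the corresponding softmax denominators
        m = torch.maximum(lse_a, lse_b)           # subtract the max for numerical stability
        w_a = torch.exp(lse_a - m)
        w_b = torch.exp(lse_b - m)
        lse = m + torch.log(w_a + w_b)            # merged log-sum-exp
        o = (w_a.unsqueeze(-1) * o_a + w_b.unsqueeze(-1) * o_b) / (w_a + w_b).unsqueeze(-1)
        return o, lse

Each partial output is re-weighted by its share of the combined softmax denominator, so the merged result equals attention computed over the union of the two key sets.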
@@ -34,6 +34,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
   m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
+  /*
+   * From csrc/attention
+   */
+  m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   /*
    * From csrc/allreduce
    */
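With the definition above registered, the kernel is also reachable through the dispatcher and fills the caller-provided v_merged / s_merged tensors in place, matching the schema in the m.def line. A sketch under stated assumptions: TORCH_LIBRARY_EXPAND resolves to the sgl_kernel namespace, importing sgl_kernel loads the extension, and the tensors follow the [tokens, heads, head_dim] / [tokens, heads], fp16 / fp32 convention common to merge-state kernels:

    import torch
    import sgl_kernel  # assumption: importing the package loads the extension and registers the op

    T, H, D = 16, 8, 128
    v_a = torch.randn(T, H, D, dtype=torch.half, device="cuda")
    s_a = torch.randn(T, H, dtype=torch.float32, device="cuda")
    v_b = torch.randn(T, H, D, dtype=torch.half, device="cuda")
    s_b = torch.randn(T, H, dtype=torch.float32, device="cuda")
    v_merged = torch.empty_like(v_a)
    s_merged = torch.empty_like(s_a)
    torch.ops.sgl_kernel.merge_state_v2(v_a, s_a, v_b, s_b, v_merged, s_merged)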
@@ -50,6 +50,7 @@ sources = [
     "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/speculative/eagle_utils.cu",
     "csrc/kvcacheio/transfer.cu",
+    "csrc/attention/merge_attn_states.cu",
 ]
 cxx_flags = ["-O3"]