Unverified Commit b62e7e99 authored by Yineng Zhang, committed by GitHub

feat: adapt merge_state (#5337)

parent 7d3b7c87
......@@ -44,6 +44,12 @@ jobs:
cuda-version: '12.8'
name: Build Wheel (CUDA ${{ matrix.cuda-version }})
steps:
- name: Skip unnecessary builds on push to main
if: github.event_name == 'push' && (matrix.cuda-version == '11.8' || matrix.cuda-version == '12.8')
run: |
echo "Skipping CUDA ${{ matrix.cuda-version }} build on push to main"
exit 0
- name: Cleanup
run: |
sudo rm -rf $GITHUB_WORKSPACE/* || true
......@@ -87,7 +93,7 @@ jobs:
- name: Install
run: |
bash scripts/ci_install_dependency.sh
pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2
pip3 install torch==2.5.1 && pip3 install pytest
pip3 uninstall sgl-kernel -y || true
pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
pip3 list | grep sgl-kernel
......
......@@ -25,6 +25,8 @@ find_package(Torch REQUIRED)
# clean Torch Flag
clear_cuda_arches(CMAKE_FLAG)
set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
include(FetchContent)
# cutlass
......@@ -104,6 +106,7 @@ set(SGL_KERNEL_CUDA_FLAGS
    "--expt-relaxed-constexpr"
    "-Xcompiler=-Wconversion"
    "-Xcompiler=-fno-strict-aliasing"
    "--threads=16"
)
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
......@@ -160,6 +163,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
set(SOURCES
    "csrc/allreduce/custom_all_reduce.cu"
    "csrc/attention/cascade.cu"
    "csrc/attention/cutlass_mla_kernel.cu"
    "csrc/attention/lightning_attention_decode_kernel.cu"
    "csrc/elementwise/activation.cu"
......
// Adapted from
// https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/csrc/cascade.cu
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <flashinfer/attention/cascade.cuh>

#include "pytorch_extension_utils.h"

using namespace flashinfer;

void merge_state(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
  CHECK_INPUT(v_a);
  CHECK_INPUT(s_a);
  CHECK_INPUT(v_b);
  CHECK_INPUT(s_b);
  auto device = v_a.device();
  CHECK_EQ(s_a.device(), device);
  CHECK_EQ(v_b.device(), device);
  CHECK_EQ(s_b.device(), device);
  CHECK_DIM(3, v_a);
  CHECK_DIM(2, s_a);
  CHECK_DIM(3, v_b);
  CHECK_DIM(2, s_b);
  CHECK_SHAPE(v_a, v_b);
  CHECK_SHAPE(s_a, s_b);
  CHECK_EQ(v_a.size(0), s_a.size(0));
  CHECK_EQ(v_a.size(1), s_b.size(1));
  unsigned int seq_len = v_a.size(0);
  unsigned int num_heads = v_a.size(1);
  unsigned int head_dim = v_a.size(2);

  const c10::cuda::OptionalCUDAGuard device_guard(v_a.device());
  auto stream = at::cuda::getCurrentCUDAStream();
  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] {
    cudaError_t status = MergeState(
        static_cast<c_type*>(v_a.data_ptr()),
        static_cast<float*>(s_a.data_ptr()),
        static_cast<c_type*>(v_b.data_ptr()),
        static_cast<float*>(s_b.data_ptr()),
        static_cast<c_type*>(v_merged.data_ptr()),
        static_cast<float*>(s_merged.data_ptr()),
        seq_len,
        num_heads,
        head_dim,
        stream);
    TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status));
    return true;
  });

  TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type");
}
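Not part of the diff, but for orientation: a minimal usage sketch of the new op from Python, with shapes and dtypes inferred from the CHECK_* calls above (v_a/v_b are [seq_len, num_heads, head_dim] in fp16/bf16, s_a/s_b are [seq_len, num_heads] in fp32); the sizes below are placeholders, and the high-level `merge_state` wrapper is the one added later in this commit.

# Sketch only: assumes the `merge_state` wrapper exported by sgl_kernel (see the Python changes below).
import torch
from sgl_kernel import merge_state

seq_len, num_heads, head_dim = 4, 8, 128
v_a = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_a = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")
v_b = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
s_b = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")
v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)  # shapes match v_a and s_a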
......@@ -45,6 +45,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
      "new_kv) -> ()");
  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);

  m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
  m.impl("merge_state", torch::kCUDA, &merge_state);

  m.def(
      "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
      "page_table, Tensor workspace) -> ()");
......
......@@ -87,6 +87,8 @@ void lightning_attention_decode(
    const torch::Tensor& slope,
    torch::Tensor output,
    torch::Tensor new_kv);

void merge_state(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);

void cutlass_mla_decode(
    torch::Tensor const& out,
    torch::Tensor const& q_nope_and_q_pe,
......
......@@ -15,6 +15,7 @@ from sgl_kernel.attention import (
    cutlass_mla_decode,
    cutlass_mla_get_workspace_size,
    lightning_attention_decode,
    merge_state,
)
from sgl_kernel.elementwise import (
    apply_rope_with_cos_sin_cache_inplace,
......
from typing import Tuple

import torch
......@@ -7,6 +9,17 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
    )


def merge_state(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    v_merged = torch.empty_like(v_a)
    s_merged = torch.empty_like(s_a)
    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged


def cutlass_mla_decode(
    q_nope_and_q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
......@@ -54,7 +67,7 @@ def cutlass_mla_decode(
        (B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
    )

    torch.ops.sgl_kernel.cutlass_mla_decode(
    torch.ops.sgl_kernel.cutlass_mla_decode.default(
        out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
    )
    return out
......@@ -63,6 +76,6 @@ def cutlass_mla_decode(
def cutlass_mla_get_workspace_size(
    max_seq_len: int, num_batches: int, sm_count: int = 0
) -> int:
    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size(
    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
        max_seq_len, num_batches, sm_count
    )
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/triton/kernels/cascade.py
from typing import List

import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import merge_state


def check_input(x: torch.Tensor):
    assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
    assert x.is_contiguous(), f"{str(x)} must be contiguous"


def check_dim(d, x: torch.Tensor):
    assert x.dim() == d, f"{str(x)} must be a {d}D tensor"


def check_shape(a: torch.Tensor, b: torch.Tensor):
    assert a.dim() == b.dim(), "tensors should have same dim"
    for i in range(a.dim()):
        assert a.size(i) == b.size(
            i
        ), f"tensors shape mismatch, {a.size()} and {b.size()}"


def check_device(tensors: List[torch.Tensor]):
    device = tensors[0].device
    for t in tensors:
        assert (
            t.device == device
        ), f"All tensors should be on the same device, but got {device} and {t.device}"


@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    m_max = tl.maximum(m, other_m)
    d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
    o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
    return o, m_max, d


@triton.jit
def state_normalize(o, m, d):
    o = o / d
    return o, m, d


@triton.jit
def state_get_lse(o, m, d):
    return m + tl.log2(d)


@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,
    bdy: tl.constexpr,
):
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
            s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)
            offsets = (pos * num_heads + head_idx) * head_dim + tx
            v_a = tl.load(v_a_ptr + offsets)
            v_b = tl.load(v_b_ptr + offsets)
            v_merged, s_max, d = state_merge(
                o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
            )
            v_merged, s_max, d = state_normalize(v_merged, s_max, d)
            v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
            tl.store(v_merged_ptr + v_merged_offset, v_merged)
            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx,
                    tl.log2(d) + s_max,
                )


def merge_state_triton(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    assert v_a.size(1) == s_b.size(1)
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len = v_a.size(0)
    num_heads = v_a.size(1)
    head_dim = v_a.size(2)
    v_merged = torch.empty_like(v_a).to(s_a.device)
    s_merged = torch.empty((seq_len, num_heads)).to(s_a.device)
    bdx = head_dim
    bdy = num_heads

    merge_state_kernel[lambda meta: (seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )

    return v_merged, s_merged


@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_merge_state(seq_len, num_heads, head_dim):
    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")

    v_merged, s_merged = merge_state_triton(va, sa, vb, sb)
    v_merged_std, s_merged_std = merge_state(va, sa, vb, sb)

    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
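For reference, and not part of the commit: a minimal pure-PyTorch sketch of the merge rule that both the CUDA kernel and the Triton reference above implement, assuming s_a/s_b hold base-2 log-sum-exp values per (position, head) as in the Triton code; `merge_state_ref` is a hypothetical helper name used only for illustration.

import torch

def merge_state_ref(v_a, s_a, v_b, s_b):
    # Hypothetical reference, not part of sgl_kernel: merge two partial softmax
    # states by reweighting each with its base-2 log-sum-exp score.
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    s_max = torch.maximum(s_a, s_b)
    w_a = torch.exp2(s_a - s_max)  # [seq_len, num_heads]
    w_b = torch.exp2(s_b - s_max)
    d = w_a + w_b
    v = (v_a * w_a.unsqueeze(-1) + v_b * w_b.unsqueeze(-1)) / d.unsqueeze(-1)
    s = s_max + torch.log2(d)  # merged log-sum-exp
    return v.to(v_a.dtype), s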