Unverified Commit ac2dc35d authored by Xiaoyu Zhang, committed by GitHub

support lightning_attention_decode in sgl-kernel for MiniMax-Text-01 (#3030)

parent 3e032c07
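For reference, here is a minimal PyTorch sketch (not part of this commit; the function name is illustrative) of the decode-step recurrence the new kernel computes, mirroring the naive reference implementation included further down in this diff: the per-head KV state is decayed by exp(-slope) and updated with the outer product of k and v, then the query is applied to the updated state.

import torch

def lightning_attention_decode_sketch(q, k, v, past_kv, slope):
    # q, k: [b, h, 1, d]; v: [b, h, 1, e]; past_kv: [b, h, d, e] (fp32); slope: [h, 1, 1]
    ratio = torch.exp(-slope)  # per-head decay factor
    new_kv = ratio * past_kv.float() + k.float().transpose(-1, -2) @ v.float()  # [b, h, d, e]
    out = (q.float() @ new_kv).to(q.dtype)  # [b, h, 1, e]
    return out, new_kv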
@@ -9,6 +9,7 @@ import torch.nn.functional as F
import triton
import triton.language as tl
from einops import rearrange
from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode
@triton.jit
@@ -332,7 +333,6 @@ def test_lightning_attention_implementations(model_params):
model_params["num_attention_heads"],
d,
d,
dtype=dtype,
device=device,
)
with torch.no_grad():
@@ -350,7 +350,13 @@ def test_lightning_attention_implementations(model_params):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
past_kv = past_kv.contiguous()
slope_rate = slope_rate.contiguous()
# Test Triton implementation
triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate)
triton_output = triton_output.transpose(1, 2).contiguous()
triton_output = triton_output.view(batch_size, seq_len, -1)
@@ -358,22 +364,50 @@ def test_lightning_attention_implementations(model_params):
triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output
triton_output = model_attn.out_proj(triton_output)
# Test SGL implementation
sgl_output = torch.empty_like(v)
sgl_new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv)
sgl_output = sgl_output.transpose(1, 2).contiguous()
sgl_output = sgl_output.view(batch_size, seq_len, -1)
sgl_output = model_attn.norm(sgl_output)
sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output
sgl_output = model_attn.out_proj(sgl_output)
# Verify Triton implementation results
torch.testing.assert_close(
model_output,
triton_output,
rtol=1e-3,
atol=1e-2,
msg="Triton lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
triton_new_kv,
rtol=1e-3,
atol=1e-2,
msg="Triton lightning attention implementation produces different kv results",
)
# Verify SGL implementation results
torch.testing.assert_close(
model_output,
sgl_output,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different output results",
)
torch.testing.assert_close(
new_kv,
sgl_new_kv,
rtol=1e-3,
atol=1e-2,
msg="SGL lightning attention implementation produces different kv results",
)
print("✅ All implementations match")
def _build_slope_tensor(n_attention_heads: int):
@@ -408,12 +442,13 @@ def get_benchmark():
x_names=["batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["Original", "Triton", "SGL"],
line_names=[
"Original PyTorch Implementation",
"Triton Implementation",
"SGL Implementation",
],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name="lightning-attention-decode-performance",
args={},
@@ -446,7 +481,6 @@ def get_benchmark():
params["num_attention_heads"],
d,
d,
dtype=dtype,
device=device,
)
@@ -461,7 +495,7 @@ def get_benchmark():
),
quantiles=quantiles,
)
elif provider == "Triton":
def run_triton():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
@@ -483,6 +517,33 @@ def get_benchmark():
run_triton,
quantiles=quantiles,
)
else: # SGL
def run_sgl():
qkv = model_attn.act(model_attn.qkv_proj(hidden_states))
new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1)
qkv = qkv.view(*new_shape)
q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1)
q = q.transpose(1, 2).contiguous()
k = k.transpose(1, 2).contiguous()
v = v.transpose(1, 2).contiguous()
output = torch.empty_like(v)
new_kv = torch.empty_like(past_kv)
sgl_lightning_attention_decode(
q, k, v, past_kv, slope_rate, output, new_kv
)
output = output.transpose(1, 2).contiguous()
output = output.view(batch_size, seq_len, -1)
output = model_attn.norm(output)
output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output
return model_attn.out_proj(output)
ms, min_ms, max_ms = triton.testing.do_bench(
run_sgl,
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
...
import itertools
import math
import torch
import triton
import triton.language as tl
from sgl_kernel import lightning_attention_decode
def next_power_of_2(n):
return 2 ** (int(math.ceil(math.log(n, 2))))
@triton.jit
def _decode_kernel(
Q,
K,
V,
KV,
Out,
S,
b: tl.constexpr,
h: tl.constexpr,
n: tl.constexpr,
d: tl.constexpr,
d_original: tl.constexpr,
e: tl.constexpr,
e_original: tl.constexpr,
):
off_bh = tl.program_id(0)
off_h = off_bh % h
qk_offset = off_bh * n * d
v_offset = off_bh * n * e
o_offset = off_bh * n * e
kv_offset = off_bh * d * e
s = tl.load(S + off_h)
ratio = tl.exp(-s)
d_idx = tl.arange(0, d)
e_idx = tl.arange(0, e)
# Create masks for original dimensions
d_mask = d_idx < d_original
e_mask = e_idx < e_original
# Load with masking
q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0)
k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0)
v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0)
# Load KV with 2D masking
kv = tl.load(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
mask=(d_mask[:, None] & e_mask[None, :]),
other=0.0,
)
# Compute outer product using element-wise operations
k_v_prod = k[:, None] * v[None, :]
kv = ratio * kv + k_v_prod
# Store KV with 2D masking
tl.store(
KV + kv_offset + d_idx[:, None] * e + e_idx[None, :],
kv.to(KV.dtype.element_ty),
mask=(d_mask[:, None] & e_mask[None, :]),
)
# Compute matrix-vector multiplication using element-wise operations and reduction
o = tl.sum(q[:, None] * kv, axis=0)
# Store output with masking
tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask)
def triton_lightning_attn_decode(q, k, v, kv, s):
"""Triton implementation of Lightning Attention decode operation"""
b, h, n, d = q.shape
e = v.shape[-1]
assert n == 1, "Sequence length must be 1 in decode mode"
# Get padded dimensions (power of 2)
d_padded = next_power_of_2(d)
e_padded = next_power_of_2(e)
# Create output tensor (padded)
o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
# Create padded tensors without actually padding the data
q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device)
k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device)
v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device)
kv_padded = torch.empty(
b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device
)
# Copy data to padded tensors
q_padded[..., :d] = q
k_padded[..., :d] = k
v_padded[..., :e] = v
kv_padded[..., :d, :e] = kv
# Launch kernel
grid = (b * h, 1)
_decode_kernel[grid](
q_padded,
k_padded,
v_padded,
kv_padded,
o_padded,
s,
b=b,
h=h,
n=n,
d=d_padded,
d_original=d,
e=e_padded,
e_original=e,
)
# Get unpadded outputs
o = o_padded[..., :e]
kv_out = kv_padded[..., :d, :e]
return o, kv_out
def lightning_attention_decode_naive(q, k, v, past_kv, slope):
"""Naive implementation of lightning attention decode"""
original_dtype = q.dtype
ratio = torch.exp(-slope) # [h, 1, 1]
kv = past_kv
b, h, n, d = q.shape
output = []
for i in range(n):
kv = ratio * kv.to(torch.float32) + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
qkv = torch.einsum(
"... n e, ... e d -> ... n d",
q[:, :, i : i + 1].to(torch.float32),
kv.to(torch.float32),
)
output.append(qkv)
output = torch.concat(output, dim=-2)
return output.to(original_dtype), kv
def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv):
return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
def calculate_diff(batch_size):
dtype = torch.bfloat16
device = torch.device("cuda")
num_heads = 64
head_dim = 96
seq_len = 1
q = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
k = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
v = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
slope = torch.randn(num_heads, 1, 1, device=device)
output_naive, new_kv_naive = lightning_attention_decode_naive(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
)
output_kernel = torch.empty_like(output_naive)
new_kv_kernel = torch.empty_like(new_kv_naive)
lightning_attention_decode_kernel(
q.clone(),
k.clone(),
v.clone(),
past_kv.clone(),
slope.clone(),
output_kernel,
new_kv_kernel,
)
output_triton, new_kv_triton = triton_lightning_attn_decode(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
)
if (
torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2)
and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2)
and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2)
and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2)
):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [i for i in range(1, 65)]  # 1 to 64
configs = [(bs,) for bs in batch_size_range]
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["naive", "kernel", "triton"],
line_names=["PyTorch Naive", "SGL Kernel", "Triton"],
styles=[("blue", "-"), ("red", "-"), ("green", "-")],
ylabel="us",
plot_name="lightning-attention-decode-performance",
args={},
)
)
def benchmark(batch_size, provider):
dtype = torch.bfloat16
device = torch.device("cuda")
num_heads = 64
head_dim = 96
seq_len = 1
q = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
k = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
v = torch.randn(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device)
slope = torch.randn(num_heads, 1, 1, device=device)
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: lightning_attention_decode_naive(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),
quantiles=quantiles,
)
elif provider == "kernel":
output = torch.empty(
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: lightning_attention_decode_kernel(
q.clone(),
k.clone(),
v.clone(),
past_kv.clone(),
slope.clone(),
output,
new_kv,
),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: triton_lightning_attn_decode(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/lightning_attention_decode_sgl/",
help="Path to save lightning attention decode benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(batch_size=4)
# Run performance benchmark
benchmark.run(print_data=True)
@@ -100,6 +100,7 @@ ext_modules = [
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/int8_gemm_kernel.cu",
"src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
"src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
"src/sgl-kernel/csrc/rotary_embedding.cu", "src/sgl-kernel/csrc/rotary_embedding.cu",
"3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/activation.cu",
......
...@@ -10,6 +10,7 @@ from sgl_kernel.ops import ( ...@@ -10,6 +10,7 @@ from sgl_kernel.ops import (
get_graph_buffer_ipc_meta, get_graph_buffer_ipc_meta,
init_custom_reduce, init_custom_reduce,
int8_scaled_mm, int8_scaled_mm,
lightning_attention_decode,
moe_align_block_size,
register_graph_buffers,
rmsnorm,
@@ -35,5 +36,6 @@ __all__ = [
"rmsnorm",
"rotary_embedding",
"sampling_scaling_penalties",
"lightning_attention_decode",
"silu_and_mul", "silu_and_mul",
] ]
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "utils.h"
#define THREADS_PER_BLOCK 128
template <typename T>
__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d]
const T* __restrict__ k, // [b, h, 1, d]
const T* __restrict__ v, // [b, h, 1, e]
const float* __restrict__ past_kv, // [b, h, d, e]
const float* __restrict__ slope, // [h, 1, 1]
T* __restrict__ output, // [b, h, 1, e]
float* __restrict__ new_kv, // [b, h, d, e]
const int batch_size, const int num_heads, const int qk_dim,
const int v_dim) {
extern __shared__ char smem[];
T* q_shared = reinterpret_cast<T*>(smem);
T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
T* output_shared =
reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
const int32_t tid = threadIdx.x;
const int32_t current_head = blockIdx.x;
const int32_t b = current_head / num_heads;
const int32_t h = current_head % num_heads;
if (b >= batch_size) return;
const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim;
const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
for (int d = tid; d < qk_dim; d += blockDim.x) {
q_shared[d] = q[qk_offset + d];
k_shared[d] = k[qk_offset + d];
}
for (int e = tid; e < v_dim; e += blockDim.x) {
v_shared[e] = v[v_offset + e];
}
__syncthreads();
const float ratio = expf(-1.0f * slope[h]);
for (int d = tid; d < qk_dim; d += blockDim.x) {
T k_val = k_shared[d];
for (int e = 0; e < v_dim; ++e) {
int past_kv_idx = kv_offset + d * v_dim + e;
T v_val = v_shared[e];
float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
int shared_idx = d * (v_dim + 1) + e;
new_kv_shared[shared_idx] = new_val;
}
}
__syncthreads();
for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
int d = idx / v_dim;
int e = idx % v_dim;
int shared_idx = d * (v_dim + 1) + e;
int global_idx = kv_offset + idx;
new_kv[global_idx] = new_kv_shared[shared_idx];
}
__syncthreads();
for (int e = tid; e < v_dim; e += blockDim.x) {
float sum = 0.0f;
for (int d = 0; d < qk_dim; ++d) {
int shared_idx = d * (v_dim + 1) + e;
sum += q_shared[d] * new_kv_shared[shared_idx];
}
output_shared[e] = static_cast<T>(sum);
}
__syncthreads();
if (tid == 0) {
for (int e = 0; e < v_dim; ++e) {
output[v_offset + e] = output_shared[e];
}
}
}
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
torch::Tensor new_kv) {
TORCH_CHECK(q.is_contiguous(), "q must be contiguous");
TORCH_CHECK(k.is_contiguous(), "k must be contiguous");
TORCH_CHECK(v.is_contiguous(), "v must be contiguous");
TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous");
auto batch_size = q.size(0);
auto num_heads = q.size(1);
auto qk_dim = q.size(3);
auto v_dim = v.size(3);
dim3 block(THREADS_PER_BLOCK);
dim3 grid(batch_size * num_heads);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] {
size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float);
lightning_attention_decode_kernel<scalar_t><<<grid, block, smem_size, stream>>>(
q.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), v.data_ptr<scalar_t>(), past_kv.data_ptr<float>(),
slope.data_ptr<float>(), output.data_ptr<scalar_t>(), new_kv.data_ptr<float>(), batch_size, num_heads,
qk_dim, v_dim);
}));
}
@@ -26,6 +26,11 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
const c10::optional<torch::Tensor>& bias);
// lightning_attention_decode
void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v,
const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output,
torch::Tensor new_kv);
// rotary embedding
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
@@ -69,6 +74,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
// int8_scaled_mm
m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
// lightning_attention_decode
m.def("lightning_attention_decode", &lightning_attention_decode, "Lightning Attention Ddecode (CUDA)");
// rotary embedding
m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
// rms norm
...
@@ -14,6 +14,9 @@ from sgl_kernel.ops._kernels import (
)
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import (
lightning_attention_decode as _lightning_attention_decode,
)
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import rmsnorm as _rmsnorm
@@ -86,6 +89,10 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
)
def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
_lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
...
import pytest
import torch
from sgl_kernel import lightning_attention_decode
def naive_lightning_attention_decode(q, k, v, past_kv, slope):
"""Naive implementation of lightning attention decode"""
original_dtype = q.dtype
ratio = torch.exp(-slope) # [h, 1, 1]
kv = past_kv
b, h, n, d = q.shape
output = []
for i in range(n):
kv = ratio * kv.to(torch.float32) + torch.einsum(
"... n d, ... n e -> ... d e",
k[:, :, i : i + 1],
v[:, :, i : i + 1],
)
qkv = torch.einsum(
"... n e, ... e d -> ... n d",
q[:, :, i : i + 1].to(torch.float32),
kv.to(torch.float32),
)
output.append(qkv)
output = torch.concat(output, dim=-2)
return output.to(original_dtype), kv
configs = [
# (batch_size, num_heads, dim, embed_dim)
(1, 8, 64, 64),
(2, 8, 64, 64),
(1, 32, 32, 64),
(2, 32, 32, 64),
(4, 32, 64, 64),
(4, 32, 64, 64),
(16, 64, 96, 96),
(64, 64, 96, 96),
]
dtypes = [torch.float32, torch.float16, torch.bfloat16]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs)
def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim):
device = torch.device("cuda")
q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype)
v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype)
past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device)
slope = torch.randn(num_heads, 1, 1, device=device)
ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope)
output = torch.empty_like(ref_output)
new_kv = torch.empty_like(ref_new_kv)
lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv)
rtol = 1e-2
atol = 1e-2
torch.testing.assert_close(
output,
ref_output,
rtol=rtol,
atol=atol,
msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, "
f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
)
torch.testing.assert_close(
new_kv,
ref_new_kv,
rtol=rtol,
atol=atol,
msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, "
f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}",
)