Unverified Commit 42245551 authored by Yuan Luo, committed by GitHub

[sgl-kernel] Optimize concat_mla_k kernel (#10543)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: PGFLMG <1106310035@qq.com>
parent 2a9d995c
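# Benchmark and correctness check for sgl-kernel's concat_mla_k.
# Compares a plain PyTorch strided copy, torch.compile, a Triton kernel,
# a flat-view "hack" copy used as a bandwidth reference, and the CUDA kernel.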
import torch
import triton
import triton.language as tl
from sgl_kernel import concat_mla_k as concat_mla_k_cuda
DEVICE = triton.runtime.driver.active.get_active_torch_device()
num_local_heads = 128
qk_nope_head_dim = 128
qk_rope_head_dim = 64
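# create_data builds strided inputs: k_nope and k_rope are views into larger
# buffers (k_rope holds a single head that is broadcast across all local heads),
# plus an empty output k with head dim qk_nope_head_dim + qk_rope_head_dim.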
def create_data(num_tokens):
k_nope_container = torch.randn(
(num_tokens, num_local_heads, qk_nope_head_dim + 128),
dtype=torch.bfloat16,
device="cuda",
)
k_nope = k_nope_container[:, :, :qk_nope_head_dim]
k_rope_container = torch.randn(
(num_tokens, 1, 128 + qk_rope_head_dim), dtype=torch.bfloat16, device="cuda"
)
k_rope = k_rope_container[:, :, -qk_rope_head_dim:]
k = torch.empty(
(num_tokens, num_local_heads, qk_nope_head_dim + qk_rope_head_dim),
dtype=torch.bfloat16,
device="cuda",
)
return dict(k=k, k_nope=k_nope, k_rope=k_rope)
def fn_torch(k, k_nope, k_rope):
k[..., :qk_nope_head_dim] = k_nope
k[..., qk_nope_head_dim:] = k_rope
def fn_hack_non_strided(k, k_nope, k_rope):
    # Bandwidth reference: write through flat/reshaped views instead of the strided
    # destination layout. It moves the same number of bytes but does NOT reproduce
    # the interleaved [nope | rope] layout of k, so it is only used for timing.
    k_flatten_view = k.flatten()
    k_flatten_view[: k_nope.numel()] = k_nope.flatten()
    k2 = k_flatten_view[k_nope.numel() :].view(k_rope.numel(), -1)
    k2[:] = k_rope.flatten()[:, None]
@torch.compile(dynamic=True)
def fn_torch_compiled(k, k_nope, k_rope):
return fn_torch(k, k_nope, k_rope)
def fn_cuda(k, k_nope, k_rope):
concat_mla_k_cuda(k, k_nope, k_rope)
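# Triton reference kernel: each program handles BLOCK_ROWS tokens; the nope part
# is copied per head, and the single rope head is broadcast to every local head.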
@triton.jit
def fn_triton_kernel(
k_ptr,
k_nope_ptr,
k_rope_ptr,
num_tokens,
QK_NOPE_HEAD_DIM: tl.constexpr,
QK_ROPE_HEAD_DIM: tl.constexpr,
NUM_LOCAL_HEADS: tl.constexpr,
K_NOPE_STRIDE_0: tl.constexpr,
K_NOPE_STRIDE_1: tl.constexpr,
K_STRIDE_0: tl.constexpr,
K_STRIDE_1: tl.constexpr,
K_ROPE_STRIDE_0: tl.constexpr,
BLOCK_ROWS: tl.constexpr,
):
pid = tl.program_id(axis=0)
token_id = pid * BLOCK_ROWS + tl.arange(0, BLOCK_ROWS)
token_mask = token_id < num_tokens
head_id = tl.arange(0, NUM_LOCAL_HEADS)
# nope
nope_sub_id = tl.arange(0, QK_NOPE_HEAD_DIM)
offs_nope = (
token_id[:, None, None] * K_NOPE_STRIDE_0
+ head_id[None, :, None] * K_NOPE_STRIDE_1
+ nope_sub_id[None, None, :]
)
offs_k = (
token_id[:, None, None] * K_STRIDE_0
+ head_id[None, :, None] * K_STRIDE_1
+ nope_sub_id[None, None, :]
)
vals_nope = tl.load(k_nope_ptr + offs_nope, mask=token_mask[:, None, None])
tl.store(k_ptr + offs_k, vals_nope, mask=token_mask[:, None, None])
# rope
rope_sub_id = tl.arange(0, QK_ROPE_HEAD_DIM)
offs_rope = token_id[:, None, None] * K_ROPE_STRIDE_0 + rope_sub_id[None, None, :]
offs_k = (
token_id[:, None, None] * K_STRIDE_0
+ head_id[None, :, None] * K_STRIDE_1
+ rope_sub_id[None, None, :]
+ QK_NOPE_HEAD_DIM
)
vals_rope = tl.load(k_rope_ptr + offs_rope, mask=token_mask[:, None, None])
tl.store(k_ptr + offs_k, vals_rope, mask=token_mask[:, None, None])
def fn_triton(k, k_nope, k_rope):
assert k.device == DEVICE and k_nope.device == DEVICE and k_rope.device == DEVICE
num_tokens, _, _ = k.shape
grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_ROWS"]),)
fn_triton_kernel[grid](
k,
k_nope,
k_rope,
num_tokens,
QK_NOPE_HEAD_DIM=qk_nope_head_dim,
QK_ROPE_HEAD_DIM=qk_rope_head_dim,
NUM_LOCAL_HEADS=num_local_heads,
K_NOPE_STRIDE_0=k_nope.stride(0),
K_NOPE_STRIDE_1=k_nope.stride(1),
K_STRIDE_0=k.stride(0),
K_STRIDE_1=k.stride(1),
K_ROPE_STRIDE_0=k_rope.stride(0),
BLOCK_ROWS=16,
)
def execute_and_get_output(f, data):
data["k"].zero_()
f(**data)
assert data["k"].sum().item() != 0
return data["k"].clone()
torch.manual_seed(0)
data = create_data(num_tokens=32768)
output_ref = execute_and_get_output(fn_torch, data)
output_exp = execute_and_get_output(fn_cuda, data)
# print(output_ref)
# print(output_exp)
if not torch.all(output_ref == output_exp):
abs_delta = torch.abs(output_ref - output_exp)
raise AssertionError(
f"{output_ref=} {output_exp=} "
f"{abs_delta=} "
f"{torch.argwhere(abs_delta != 0.0)=} "
)
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens"], # Argument names to use as an x-axis for the plot.
x_vals=[
2048,
4096,
8192,
16384,
32768,
], # Different possible values for `x_name`.
        x_log=False,  # Use a linear x axis.
line_arg="provider", # Argument name whose value corresponds to a different line in the plot.
line_vals=[
"torch",
"torch_compiled",
"triton",
"hack_non_strided",
"cuda",
], # Possible values for `line_arg`.
line_names=[
"torch",
"torch_compiled",
"triton",
"hack_non_strided",
"cuda",
], # Label name for the lines.
plot_name="vector-add-performance", # Name for the plot. Used also as a file name for saving the plot.
args={}, # Values for function arguments not in `x_names` and `y_name`.
)
)
def benchmark(num_tokens, provider):
data = create_data(num_tokens=num_tokens)
quantiles = [0.5, 0.2, 0.8]
fn = {
"torch": fn_torch,
"torch_compiled": fn_torch_compiled,
"triton": fn_triton,
"hack_non_strided": fn_hack_non_strided,
"cuda": fn_cuda,
}[provider]
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: fn(**data), quantiles=quantiles
)
return ms, min_ms, max_ms
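# Wrap the benchmark run in a CUDA profiler range so external profilers capture just this region.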
torch.cuda.cudart().cudaProfilerStart()
benchmark.run(print_data=True, show_plots=True)
torch.cuda.cudart().cudaProfilerStop()
@@ -3,6 +3,7 @@
 #include <cuda_runtime.h>
 
 #include "pytorch_extension_utils.h"
+#include "utils.cuh"
 
 constexpr int NUM_LOCAL_HEADS = 128;
 constexpr int QK_NOPE_HEAD_DIM = 128;
@@ -12,20 +13,10 @@ constexpr int K_HEAD_DIM = QK_NOPE_HEAD_DIM + QK_ROPE_HEAD_DIM;
 constexpr int HEAD_CHUNK_SIZE = 16;
 constexpr int NUM_HEAD_CHUNKS = NUM_LOCAL_HEADS / HEAD_CHUNK_SIZE;
 
-__forceinline__ __device__ int get_lane_id() {
-  int lane_id;
-  asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
-  return lane_id;
-}
-
-int ceil_div(int a, int b) {
-  return (a + b - 1) / b;
-}
-
 __global__ void concat_mla_k_kernel(
-    nv_bfloat16* k,
-    nv_bfloat16* k_nope,
-    nv_bfloat16* k_rope,
+    nv_bfloat16* __restrict__ k,
+    const nv_bfloat16* __restrict__ k_nope,
+    const nv_bfloat16* __restrict__ k_rope,
     const int num_tokens,
     const int k_stride_0,
     const int k_stride_1,
@@ -36,43 +27,50 @@ __global__ void concat_mla_k_kernel(
   const int token_id = flat_warp_id / NUM_HEAD_CHUNKS;
   const int head_chunk_id = flat_warp_id % NUM_HEAD_CHUNKS;
   const int lane_id = get_lane_id();
+  if (token_id >= num_tokens) return;
 
-  if (token_id >= num_tokens) {
-    return;
-  }
+  using NopeVec = int2;  // 8 B/thread, 32 threads = 256 B/row
+  using RopeVec = int;   // 4 B/thread, 32 threads = 128 B/row
+  static_assert(sizeof(NopeVec) * 32 == QK_NOPE_HEAD_DIM * sizeof(nv_bfloat16), "nope vec mismatch");
+  static_assert(sizeof(RopeVec) * 32 == QK_ROPE_HEAD_DIM * sizeof(nv_bfloat16), "rope vec mismatch");
 
-  using KNopeBufType = int2;
-  static_assert(sizeof(KNopeBufType) == QK_NOPE_HEAD_DIM * sizeof(k[0]) / 32);
-  KNopeBufType k_nope_buf[HEAD_CHUNK_SIZE];
+  const int head_row0 = head_chunk_id * HEAD_CHUNK_SIZE;
 
-  using KRopeBufType = int;
-  static_assert(sizeof(KRopeBufType) == QK_ROPE_HEAD_DIM * sizeof(k[0]) / 32);
-  KRopeBufType k_rope_buf;
+  const int2* __restrict__ nope_src =
+      reinterpret_cast<const int2*>(k_nope + token_id * k_nope_stride_0 + head_row0 * k_nope_stride_1) + lane_id;
+
+  int2* __restrict__ nope_dst = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_row0 * k_stride_1) + lane_id;
 
-  {
-    const int* base_addr = reinterpret_cast<int*>(k_rope + token_id * k_rope_stride_0);
-    k_rope_buf = *(base_addr + lane_id);
-  }
+  int* __restrict__ rope_dst =
+      reinterpret_cast<int*>(k + token_id * k_stride_0 + head_row0 * k_stride_1 + QK_NOPE_HEAD_DIM) + lane_id;
+
+  const int nope_src_stride_v = (k_nope_stride_1 >> 2);  // int2 covers 4 bf16
+  const int nope_dst_stride_v = (k_stride_1 >> 2);
+  const int rope_dst_stride_v = (k_stride_1 >> 1);  // int covers 2 bf16
 
-#pragma unroll
-  for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
-    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
-    const int2* base_addr = reinterpret_cast<int2*>(k_nope + token_id * k_nope_stride_0 + head_id * k_nope_stride_1);
-    k_nope_buf[i] = *(base_addr + lane_id);
-  }
+  const int* rope_base = reinterpret_cast<const int*>(k_rope + token_id * k_rope_stride_0);
+  const RopeVec rope_val = ld_na_global_v1(rope_base + lane_id);
+
+  prefetch_L2(nope_src);
+  NopeVec cur = ld_na_global_v2(nope_src);
 
 #pragma unroll
   for (int i = 0; i < HEAD_CHUNK_SIZE; ++i) {
-    const int head_id = head_chunk_id * HEAD_CHUNK_SIZE + i;
-
-    {
-      int2* base_addr = reinterpret_cast<int2*>(k + token_id * k_stride_0 + head_id * k_stride_1);
-      *(base_addr + lane_id) = k_nope_buf[i];
-    }
-
-    {
-      int* base_addr = reinterpret_cast<int*>(k + token_id * k_stride_0 + head_id * k_stride_1 + QK_NOPE_HEAD_DIM);
-      *(base_addr + lane_id) = k_rope_buf;
-    }
+    NopeVec next;
+    if (i + 1 < HEAD_CHUNK_SIZE) {
+      const int2* next_src = nope_src + nope_src_stride_v;
+      prefetch_L2(next_src);
+      next = ld_na_global_v2(next_src);
+    }
+    st_na_global_v2(nope_dst, cur);
+    st_na_global_v1(rope_dst, rope_val);
+    nope_src += nope_src_stride_v;
+    nope_dst += nope_dst_stride_v;
+    rope_dst += rope_dst_stride_v;
+    cur = next;
   }
 }
// Adapted from https://github.com/deepseek-ai/DeepEP/blob/main/csrc/kernels/utils.cuh
#pragma once
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cstdint>
__forceinline__ __device__ int get_lane_id() {
int lane_id;
asm("mov.s32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
inline int ceil_div(int a, int b) {
return (a + b - 1) / b;
}
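// PTX helpers: ld_na_global_* / st_na_global_* issue global loads/stores with the
// L1::no_allocate hint so streamed data does not pollute L1. Defining USE_L2_HINT
// adds an L2::128B hint to the loads; ENABLE_L2_PREFETCH enables prefetch_L2.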
__device__ __forceinline__ void st_na_global_v1(const int* ptr, int v) {
asm volatile("st.global.L1::no_allocate.s32 [%0], %1;" ::"l"(ptr), "r"(v) : "memory");
}
__device__ __forceinline__ void st_na_global_v2(const int2* ptr, const int2& v) {
asm volatile("st.global.L1::no_allocate.v2.s32 [%0], {%1, %2};" ::"l"(ptr), "r"(v.x), "r"(v.y) : "memory");
}
__device__ __forceinline__ void st_na_global_v4(const int4* ptr, const int4& v) {
asm volatile(
"st.global.L1::no_allocate.v4.s32 [%0], {%1, %2, %3, %4};" ::"l"(ptr), "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w)
: "memory");
}
__device__ __forceinline__ int ld_na_global_v1(const int* ptr) {
int r;
#ifdef USE_L2_HINT
asm volatile("ld.global.nc.L1::no_allocate.L2::128B.s32 %0, [%1];" : "=r"(r) : "l"(ptr));
#else
asm volatile("ld.global.nc.L1::no_allocate.s32 %0, [%1];" : "=r"(r) : "l"(ptr));
#endif
return r;
}
__device__ __forceinline__ int2 ld_na_global_v2(const int2* ptr) {
int2 r;
#ifdef USE_L2_HINT
asm volatile("ld.global.nc.L1::no_allocate.L2::128B.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr));
#else
asm volatile("ld.global.nc.L1::no_allocate.v2.s32 {%0, %1}, [%2];" : "=r"(r.x), "=r"(r.y) : "l"(ptr));
#endif
return r;
}
__device__ __forceinline__ int4 ld_na_global_v4(const int4* ptr) {
int4 r;
#ifdef USE_L2_HINT
asm volatile("ld.global.nc.L1::no_allocate.L2::128B.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(r.x), "=r"(r.y), "=r"(r.z), "=r"(r.w)
: "l"(ptr));
#else
asm volatile("ld.global.nc.L1::no_allocate.v4.s32 {%0, %1, %2, %3}, [%4];"
: "=r"(r.x), "=r"(r.y), "=r"(r.z), "=r"(r.w)
: "l"(ptr));
#endif
return r;
}
__device__ __forceinline__ void prefetch_L2(const void* p) {
#if defined(ENABLE_L2_PREFETCH)
asm volatile("prefetch.global.L2 [%0];" ::"l"(p));
#endif
}