Commit b8240b7a authored by Yuxuan Hu, committed by GitHub

Add sparse fine-tuning kernel for deepseek sparse attention to example (#1296)

* [EXAMPLE] add example for dsa sparse finetuning

* [Refactor]
parent 6bae64f6
from typing import Optional
import torch
import torch.nn.functional as F
from indexer_topk_reducesum import indexer_topk_reducesum_interface
from indexer_bwd import indexer_bwd_interface
from sparse_mla_fwd import sparse_mla_fwd_interface
from sparse_mla_bwd import sparse_mla_bwd
from sparse_mla_topk_reducesum import sparse_mla_topk_reducesum_interface
from einops import einsum, repeat
from utils import get_abs_err, get_err_ratio
class RegisterLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, loss):
ctx.save_for_backward(loss)
return x
@staticmethod
def backward(ctx, grad):
loss = ctx.saved_tensors
return grad, torch.ones(1, dtype=loss[0].dtype, device=loss[0].device)
register_loss = RegisterLossFunction.apply
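# RegisterLossFunction is an identity on `x` in forward but stashes the auxiliary `loss` tensor;
# in backward it passes `grad` through to `x` unchanged and returns a gradient of one for `loss`.
# Calling `o.backward(do)` therefore also backpropagates the indexer KL loss, as if the objective
# were `<o, do> + loss`, without a separate backward pass for the loss.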
def ref_deepseek_sparse_attention_inner(
q: torch.Tensor,
kv: torch.Tensor,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
topk: int,
dim_v: int,
sm_scale: Optional[float] = None,
index_sm_scale: Optional[float] = None,
):
dtype = q.dtype
q, kv, index_q, index_k, weights = map(lambda x: x.to(torch.float32),
(q, kv, index_q, index_k, weights))
index_sm_scale = index_q.shape[-1]**-0.5
b, s = index_q.shape[:2]
# tl_topk_indices = tl_topk_indices.to(torch.int64)
# tl_topk_indices[tl_topk_indices == -1] = s
    causal_mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
index_logits = einsum(index_q, index_k, 'b s1 h k, b s2 k -> b s1 h s2')
index_logits = F.relu(index_logits)
index_logits = (index_logits * weights.unsqueeze(-1)).sum(
dim=-2, dtype=torch.float32) * index_sm_scale
    index_logits = torch.where(causal_mask, index_logits, float('-inf'))
topk_indices = torch.topk(index_logits, k=topk, dim=-1).indices
topk_logits = torch.gather(
F.pad(index_logits, (0, 1), value=float('-inf')), dim=-1, index=topk_indices)
topk_score = F.log_softmax(topk_logits, dim=-1, dtype=torch.float32)
index_topk_score = topk_score
if sm_scale is None:
sm_scale = kv.shape[-1]**-0.5
h = q.shape[-2]
index_mask = torch.zeros((b, s, s + 1), dtype=torch.bool, device="cuda")\
.scatter_(dim=-1, index=topk_indices, src=torch.ones_like(topk_indices, dtype=torch.bool))[:, :, :-1]
    mask = repeat(causal_mask & index_mask, 'b s1 s2 -> b s1 h s2', h=h)
k, v = kv, kv[..., :dim_v]
logits = einsum(q, k, 'b s1 h d, b s2 d -> b s1 h s2') * sm_scale
logits = torch.where(mask, logits, float('-inf'))
attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
o = einsum(attn_score, v, 'b s1 h s2, b s2 d -> b s1 h d')
attn_score = attn_score.sum(dim=-2) # [b, s1, s2]
attn_topk_score = torch.gather(F.pad(attn_score, (0, 1)), dim=-1, index=topk_indices)
attn_topk_score = attn_topk_score / attn_topk_score.sum(dim=-1, keepdim=True)
loss = F.kl_div(
index_topk_score.clip(-100, 0),
attn_topk_score.detach().log().clip(-100, 0),
log_target=True,
reduction="sum")
o = register_loss(o, loss)
return o.to(dtype), topk_indices
def ref_deepseek_sparse_attention(
q: torch.Tensor,
kv: torch.Tensor,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
offsets: torch.Tensor,
topk: int,
dim_v: int,
sm_scale: Optional[float] = None,
index_sm_scale: Optional[float] = None,
):
all_o, all_topk_indices = [], []
for i in range(offsets.shape[0] - 1):
        o, topk_indices = ref_deepseek_sparse_attention_inner(
q[None, offsets[i]:offsets[i + 1]],
kv[None, offsets[i]:offsets[i + 1]],
index_q[None, offsets[i]:offsets[i + 1]],
index_k[None, offsets[i]:offsets[i + 1]],
weights[None, offsets[i]:offsets[i + 1]],
topk,
dim_v,
sm_scale,
index_sm_scale,
)
all_o.append(o.squeeze(0))
all_topk_indices.append(topk_indices.squeeze(0))
o = torch.cat(all_o, dim=0)
topk_indices = torch.cat(all_topk_indices, dim=0)
return o, topk_indices
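# Variable-length sequences are packed along the first axis and delimited by `offsets`
# (cumulative sequence lengths). For example, offsets = [0, 1024, 2048] packs two sequences of
# 1024 tokens each, and every per-token tensor (q, kv, index_q, ...) is sliced as
# tensor[offsets[i]:offsets[i + 1]] for sequence i.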
class DSAFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q: torch.Tensor,
kv: torch.Tensor,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
offsets: torch.Tensor,
topk: int,
dim_v: int,
sm_scale: Optional[float] = None,
):
# topk_indices, index_score = ref_index_score(index_q, weights, index_k, topk)
topk_indices, index_score = indexer_topk_reducesum_interface(index_q, weights, index_k,
topk, offsets)
o, lse = sparse_mla_fwd_interface(
q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), offsets, sm_scale=sm_scale, d_v=dim_v)
ctx.save_for_backward(q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse,
offsets)
ctx.topk = topk
ctx.dim_v = dim_v
ctx.sm_scale = sm_scale
return o, topk_indices
@staticmethod
def backward(
ctx,
do: torch.Tensor,
_1: torch.Tensor,
):
q, kv, index_q, index_k, weights, topk_indices, index_score, o, lse, offsets = ctx.saved_tensors
attn_score = sparse_mla_topk_reducesum_interface(
q, kv.unsqueeze(-2), topk_indices.unsqueeze(-2), lse, offsets,
dim_v=ctx.dim_v).squeeze(-2)
dq, dkv = sparse_mla_bwd(
q,
kv.unsqueeze(-2),
o,
do,
topk_indices.unsqueeze(-2),
lse,
offsets,
sm_scale=ctx.sm_scale)
dindex_q, dweights, dindex_k = indexer_bwd_interface(index_q, weights, index_k, attn_score,
index_score, topk_indices, offsets)
return dq, dkv.squeeze(-2), dindex_q, dindex_k, dweights, None, None, None, None
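# deepseek_sparse_attention is the training entry point. Following the test below, the packed
# (varlen) shapes are: q [S, H, D + tail_D], kv [S, D + tail_D], index_q [S, H, index_D],
# index_k [S, index_D], weights [S, H], and offsets [num_seqs + 1] (int32 cu_seqlens). It returns
# the attention output o [S, H, D] and the selected topk_indices [S, topk]; in backward,
# indexer_bwd_interface produces the indexer gradients of the KL loss using the attention scores
# at the selected positions as the target distribution.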
def deepseek_sparse_attention(
q: torch.Tensor,
kv: torch.Tensor,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
offsets: torch.Tensor,
topk: int,
dim_v: int,
sm_scale: Optional[float] = None,
):
return DSAFunction.apply(q, kv, index_q, index_k, weights, offsets, topk, dim_v, sm_scale)
def test_kernel(
B=1,
S=2048,
H=16,
D=512,
tail_D=64,
index_D=128,
topk=64,
):
torch.manual_seed(42)
q = torch.randn((S, H, D + tail_D)).cuda().bfloat16().requires_grad_()
kv = torch.randn((S, D + tail_D)).cuda().bfloat16().requires_grad_()
index_q = torch.randn((S, H, index_D)).cuda().bfloat16().requires_grad_()
weights = torch.randn((S, H)).cuda().bfloat16().requires_grad_()
index_k = torch.randn((S, index_D)).cuda().bfloat16().requires_grad_()
do = torch.randn((S, H, D)).cuda().bfloat16().requires_grad_()
offsets = torch.tensor([0, S // 2, S], dtype=torch.int32).cuda()
o, topk_indices = deepseek_sparse_attention(q, kv, index_q, index_k, weights, offsets, topk, D)
o.backward(do)
q_grad, q.grad = q.grad, None
kv_grad, kv.grad = kv.grad, None
index_q_grad, index_q.grad = index_q.grad, None
index_k_grad, index_k.grad = index_k.grad, None
weights_grad, weights.grad = weights.grad, None
ref_o, ref_topk_indices = ref_deepseek_sparse_attention(q, kv, index_q, index_k, weights,
offsets, topk, D)
ref_o.backward(do)
ref_q_grad, q.grad = q.grad, None
ref_kv_grad, kv.grad = kv.grad, None
ref_index_q_grad, index_q.grad = index_q.grad, None
ref_index_k_grad, index_k.grad = index_k.grad, None
ref_weights_grad, weights.grad = weights.grad, None
print(f"o err: {get_abs_err(o, ref_o):.6f} ratio: {get_err_ratio(o, ref_o):.6f}")
print(
f"q.grad err: {get_abs_err(q_grad, ref_q_grad):.6f} ratio: {get_err_ratio(q_grad, ref_q_grad):.6f}"
)
print(
f"kv.grad err: {get_abs_err(kv_grad, ref_kv_grad):.6f} ratio: {get_err_ratio(kv_grad, ref_kv_grad):.6f}"
)
print(
f"index_q.grad err: {get_abs_err(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f} ratio: {get_err_ratio(index_q_grad[:, :64, :], ref_index_q_grad[:, :64, :]):.6f}"
)
print(
f"index_k.grad err: {get_abs_err(index_k_grad, ref_index_k_grad):.6f} ratio: {get_err_ratio(index_k_grad, ref_index_k_grad):.6f}"
)
print(
f"weights.grad err: {get_abs_err(weights_grad, ref_weights_grad):.6f} ratio: {get_err_ratio(weights_grad, ref_weights_grad):.6f}"
)
intersections = []
for j in range(S):
ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
trt_np = topk_indices[j].cpu().to(torch.int32).numpy()
mask = (trt_np != -1)
set_ref = set(ref_np[mask])
set_trt = set(trt_np[mask])
intersection = set_ref & set_trt
intersections.append(len(intersection) / len(set_ref))
print("average intersections: {:.4f}".format(sum(intersections) / len(intersections)))
test_kernel()
# Modified from: https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py
import torch
import torch.nn.functional as F
import functools
from typing import Callable, Any
def tensor_cache(fn: Callable[..., torch.Tensor],) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent result of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
If the function is called again with the same input tensors, it will return the cached result.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
last_args: tuple | None = None
last_kwargs: dict | None = None
last_result: Any = None
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal last_args, last_kwargs, last_result
if (last_args is not None and last_kwargs is not None) and \
(len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) and \
all(a is b for a, b in zip(args, last_args, strict=False)) and \
all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
return last_result
result = fn(*args, **kwargs)
last_args, last_kwargs, last_result = args, kwargs, result
return result
return wrapper
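# Illustrative usage (not part of the original file): the cache keys on tensor identity (`is`),
# so repeated calls with the same tensor object hit the cache, while an equal-valued copy is
# recomputed.
#
#   cu = torch.tensor([0, 2, 5])
#   prepare_lens(cu)          # computed
#   prepare_lens(cu)          # served from the single-entry cache
#   prepare_lens(cu.clone())  # recomputed: a different tensor object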
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return torch.diff(cu_seqlens)
@tensor_cache
def prepare_cu_seqlens_from_lens(
lens: torch.LongTensor,
dtype: torch.dtype | None = torch.int32,
) -> torch.LongTensor:
return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0))
@tensor_cache
def prepare_lens_from_cu_seqlens(cu_seqlens: torch.LongTensor,) -> torch.LongTensor:
return torch.diff(cu_seqlens)
@tensor_cache
def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return torch.cat([
torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device)
for n in prepare_lens(cu_seqlens).unbind()
])
@tensor_cache
def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1
@tensor_cache
def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
position_ids = prepare_position_ids(cu_seqlens)
return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens)
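# Worked example (illustrative): for cu_seqlens = [0, 2, 5] (two sequences of lengths 2 and 3),
# prepare_position_ids gives [0, 1, 0, 1, 2], prepare_sequence_ids gives [0, 0, 1, 1, 1], and
# prepare_token_indices stacks them into [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]], i.e. one
# (sequence_id, position_in_sequence) pair per packed token.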
import torch
import torch.nn.functional as F
from einops import einsum, repeat
import tilelang as tl
import tilelang.language as T
from typing import Optional
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"
pass_configs = {
tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tl.jit(pass_configs=pass_configs)
def tl_indexer_bwd_impl(
heads: int,
dim: int,
topk: int,
sm_scale: Optional[float] = None,
block_I: int = 32,
num_stages: int = 0,
num_threads: int = 128,
):
assert num_stages == 0
assert topk == tl.math.next_power_of_2(topk)
assert topk % block_I == 0
assert heads <= 64 and heads % 8 == 0
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
dtype: str = BF16
accum_dtype: str = FP32
index_q_shape = [seq_len, heads, dim]
weights_shape = [seq_len, heads]
index_k_shape = [seq_len, dim]
shape_p = [seq_len, topk]
topk_indices_shape = [seq_len, topk]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
if sm_scale is None:
sm_scale = dim**-0.5
@T.prim_func
def tl_indexer_bwd_kernel(
IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
dIndexQ: T.Tensor(index_q_shape, dtype),
dWeights: T.Tensor(weights_shape, dtype),
dIndexK: T.Tensor(index_k_shape, dtype),
AttnScore: T.Tensor(shape_p, FP32),
IndexScore: T.Tensor(shape_p, FP32),
TopkIndices: T.Tensor(topk_indices_shape, INT32),
Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32),
):
with T.Kernel(seq_len, threads=num_threads) as (bx):
i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
bos = Offsets[i_b]
num_blocks = T.ceildiv(topk, block_I)
index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
weights_shared = T.alloc_shared([heads], dtype=dtype)
d_index_q_frag = T.alloc_fragment([heads, dim], dtype=accum_dtype)
d_weights_frag = T.alloc_fragment([heads], dtype=accum_dtype)
T.copy(IndexQ[bos + i_t, :, :], index_q_shared)
T.copy(Weights[bos + i_t, :], weights_shared)
T.fill(d_index_q_frag, 0)
T.fill(d_weights_frag, 0)
for i, j in T.Parallel(heads, dim):
index_q_shared[i, j] = index_q_shared[i, j] * sm_scale
for bi_i in T.Pipelined(num_blocks, num_stages=num_stages):
i_st = bi_i * block_I
i_ed = (bi_i + 1) * block_I
indices_shared = T.alloc_shared([block_I], dtype=INT32)
T.copy(TopkIndices[bos + i_t, i_st:i_ed], indices_shared)
index_k_shared = T.alloc_shared([block_I, dim], dtype=dtype)
for i, j in T.Parallel(block_I, dim):
pos = indices_shared[i]
index_k_shared[i, j] = T.if_then_else((pos > -1) & (pos <= i_t),
IndexK[bos + pos, j], 0)
attn_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
index_score_shared = T.alloc_shared([block_I], dtype=accum_dtype)
for i in T.Parallel(block_I):
attn_score_shared[i] = AttnScore[bos + i_t, i_st + i]
index_score_shared[i] = IndexScore[bos + i_t, i_st + i]
logits = T.alloc_fragment((block_I, heads), accum_dtype)
T.gemm(
index_k_shared,
index_q_shared,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
for i, j in T.Parallel(block_I, heads):
logits[i, j] = T.max(logits[i, j], 0)
# dw
d_weights_i = T.alloc_fragment((block_I, heads), accum_dtype)
for i, j in T.Parallel(block_I, heads):
d_weights_i[i,
j] = (index_score_shared[i] - attn_score_shared[i]) * logits[i, j]
T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False)
d_logits_qk = T.alloc_shared((block_I, heads), accum_dtype)
d_logits_qk_cast1 = T.alloc_fragment((block_I, heads), dtype)
d_logits_qk_cast2 = T.alloc_fragment((block_I, heads), dtype)
for i, j in T.Parallel(block_I, heads):
d_relu = T.alloc_var(accum_dtype)
if logits[i, j] > 0:
d_relu = 1.0
else:
d_relu = 0.0
d_logits_qk[i, j] = (index_score_shared[i] -
attn_score_shared[i]) * d_relu * weights_shared[j]
# dq
T.copy(d_logits_qk, d_logits_qk_cast1)
T.gemm(
d_logits_qk_cast1, # [BS, HQ]
index_k_shared, # [BS, K]
d_index_q_frag, # [HQ, K]
transpose_A=True,
transpose_B=False,
clear_accum=False,
)
# dk
T.copy(d_logits_qk, d_logits_qk_cast2)
d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype)
T.gemm(
d_logits_qk_cast2, # [BS, HQ]
index_q_shared, # [HQ, K]
d_index_k_frag, # [BS, K]
transpose_A=False,
transpose_B=False,
clear_accum=True,
)
for i, j in T.Parallel(block_I, dim):
pos = indices_shared[i]
if ((pos > -1) & (pos <= i_t)):
T.atomic_add(dIndexK[bos + pos, j], d_index_k_frag[i, j])
for i, j in T.Parallel(heads, dim):
d_index_q_frag[i, j] = d_index_q_frag[i, j] * sm_scale
T.copy(d_index_q_frag, dIndexQ[bos + i_t, :, :])
T.copy(d_weights_frag, dWeights[bos + i_t, :])
return tl_indexer_bwd_kernel
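# Gradient note: for the indexer loss KL(attn || index) restricted to the selected top-k
# positions, the gradient with respect to the pre-softmax index logit of selection i is
# (index_score[i] - attn_score[i]). The kernel backpropagates this through the weighted-ReLU
# scoring function: dWeights accumulates (p - a) * relu(q @ k^T), while the q/k gradients use
# (p - a) * relu'(q @ k^T) * weights, matching ref_indexer_bwd below.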
def indexer_bwd_interface(
q: torch.Tensor,
weights: torch.Tensor,
k: torch.Tensor,
attn_score: torch.Tensor,
index_score: torch.Tensor,
topk_indices: torch.Tensor,
offsets: torch.Tensor,
):
_, heads, dim, topk = *q.shape, topk_indices.shape[-1]
token_indices = prepare_token_indices(offsets)
dq = torch.zeros_like(q)
dweights = torch.zeros_like(weights)
dk = torch.zeros_like(k)
kernel = tl_indexer_bwd_impl(heads, dim, topk)
kernel(q, weights, k, dq, dweights, dk, attn_score, index_score, topk_indices, offsets,
token_indices)
return dq, dweights, dk
def ref_indexer_bwd(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor,
TopkIndices: torch.Tensor, AttnScore: torch.Tensor,
offsets: torch.Tensor) -> torch.Tensor:
Q.requires_grad_(True)
Weights.requires_grad_(True)
K.requires_grad_(True)
softmax_scale = Q.shape[-1]**-0.5
all_loss = []
all_log_topk_prob = []
for i in range(offsets.shape[0] - 1):
assert (offsets[i + 1] - offsets[i]).item() >= TopkIndices.shape[-1]
q = Q[offsets[i]:offsets[i + 1]]
weights = Weights[offsets[i]:offsets[i + 1]]
k = K[offsets[i]:offsets[i + 1]]
topk_indices = TopkIndices[offsets[i]:offsets[i + 1]]
attn_score = AttnScore[offsets[i]:offsets[i + 1]]
s = q.shape[0]
mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2') * softmax_scale
logits = F.relu(logits)
score = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32)
score = torch.where(mask, score, float('-inf'))
topk_value = torch.gather(score, dim=-1, index=topk_indices.to(torch.int64))
log_topk_prob = F.log_softmax(topk_value, dim=-1, dtype=torch.float32)
loss = F.kl_div(
log_topk_prob.clip(-100, 0),
attn_score.log().clip(-100, 0),
log_target=True,
reduction="sum")
all_loss.append(loss)
all_log_topk_prob.append(log_topk_prob)
loss = torch.stack(all_loss).sum()
loss.backward()
log_topk_prob = torch.cat(all_log_topk_prob, dim=0)
return log_topk_prob.exp(), Q.grad, Weights.grad, K.grad
def test_kernel(
B=1,
S=2048,
H=16,
D=128,
topk=64,
):
torch.manual_seed(42)
q = torch.randn((S, H, D)).cuda().bfloat16()
w = torch.randn((S, H)).cuda().bfloat16()
k = torch.randn((S, D)).cuda().bfloat16()
offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()
all_attn_score = []
for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item()
mask = (torch.arange(seq_len)[:, None] >= torch.arange(topk)[None, :]).to(q.device)
logits = torch.ones(seq_len, topk).cuda()
logits = torch.where(mask, logits, float('-inf'))
attn_score = F.softmax(logits, dim=-1, dtype=torch.float32)
all_attn_score.append(attn_score)
attn_score = torch.cat(all_attn_score, dim=0)
topk_indices = repeat(
torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous()
index_score, ref_dq, ref_dw, ref_dk = ref_indexer_bwd(q, w, k, topk_indices, attn_score,
offsets)
dq, dw, dk = indexer_bwd_interface(q, w, k, attn_score, index_score, topk_indices, offsets)
print(f"dq err: {get_abs_err(dq, ref_dq):.6f} ratio: {get_err_ratio(dq, ref_dq):.6f}")
print(f"dq err: {get_abs_err(dw, ref_dw):.6f} ratio: {get_err_ratio(dw, ref_dw):.6f}")
print(f"dq err: {get_abs_err(dk, ref_dk):.6f} ratio: {get_err_ratio(dk, ref_dk):.6f}")
if __name__ == '__main__':
test_kernel()
import math
import torch
import torch.nn.functional as F
from einops import einsum
import tilelang as tl
import tilelang.language as T
from typing import Optional
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"
pass_configs = {
tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True,
tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tl.jit(pass_configs=pass_configs)
def tl_indexer_topk_reducesum_impl(
heads: int,
dim: int,
topk: int,
sm_scale: Optional[float] = None,
block_K: int = 32,
dtype: str = FP32,
num_stages: int = 0,
num_threads: int = 128,
):
assert topk == tl.math.next_power_of_2(topk)
assert topk % block_K == 0
assert heads <= 64 and heads % 8 == 0
assert num_stages == 0
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
index_q_shape = [seq_len, heads, dim]
weights_shape = [seq_len, heads]
index_k_shape = [seq_len, dim]
topk_indices_shape = [seq_len, topk]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
N = 2 * topk
num_iters = int(round(math.log2(N)))
if sm_scale is None:
sm_scale = dim**-0.5
@T.macro
def bitonic_sort(
topk_index_shared: T.SharedBuffer([N], dtype=INT32),
topk_value_shared: T.SharedBuffer([N], dtype=FP32),
):
T.sync_threads()
for i1 in T.serial(num_iters):
for i2 in T.serial(i1 + 1):
for i in T.Parallel(N):
ascending = (i & (1 << (i1 + 1))) != 0
j = i ^ (1 << (i1 - i2))
if i < j and \
((ascending and topk_value_shared[i] > topk_value_shared[j]) or (
not ascending and topk_value_shared[i] < topk_value_shared[j])):
val = topk_value_shared[i]
topk_value_shared[i] = topk_value_shared[j]
topk_value_shared[j] = val
idx = topk_index_shared[i]
topk_index_shared[i] = topk_index_shared[j]
topk_index_shared[j] = idx
T.sync_threads()
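# Streaming top-k: the kernel keeps a buffer of 2 * topk (index, value) pairs in shared memory.
# The first `topk` candidate scores fill the first half; later blocks are staged into the second
# half, and whenever that staging half fills up, a full (descending) bitonic sort moves the
# current best `topk` entries back into the first half. A final sort after the loop leaves the
# overall top-k in slots [0, topk).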
@T.prim_func
def tl_indexer_topk_reducesum_kernel(
IndexQ: T.Tensor(index_q_shape, dtype),
Weights: T.Tensor(weights_shape, dtype),
IndexK: T.Tensor(index_k_shape, dtype),
TopkIndices: T.Tensor(topk_indices_shape, INT32),
ReduceSum: T.Tensor(topk_indices_shape, FP32),
Offsets: T.Tensor(offsets_shape, INT32),
TokenIndices: T.Tensor(token_indices_shape, INT32),
):
with T.Kernel(seq_len, threads=num_threads) as (bx):
i_b, i_t = TokenIndices[bx, 0], TokenIndices[bx, 1]
bos, eos = Offsets[i_b], Offsets[i_b + 1]
num_blocks = T.ceildiv(i_t + 1, block_K)
topk_index_shared = T.alloc_shared([N], dtype=INT32)
topk_value_shared = T.alloc_shared([N], dtype=FP32)
T.fill(topk_index_shared, -1)
T.fill(topk_value_shared, float('-inf'))
T.sync_threads()
index_q_shared = T.alloc_shared([heads, dim], dtype=dtype)
T.copy(IndexQ[bos + i_t, :, :], index_q_shared)
T.sync_threads()
weights_frag = T.alloc_shared([heads], dtype=dtype)
T.copy(Weights[bos + i_t, :], weights_frag)
T.sync_threads()
for i, j in T.Parallel(heads, dim):
index_q_shared[i, j] = index_q_shared[i, j] * sm_scale
T.sync_threads()
for bk_i in T.Pipelined(num_blocks, num_stages=num_stages):
k_st = bk_i * block_K
k_ed = T.min((bk_i + 1) * block_K, eos - bos)
index_k_shared = T.alloc_shared([block_K, dim], dtype=dtype)
for i, j in T.Parallel(block_K, dim):
index_k_shared[i, j] = T.if_then_else(k_st + i < k_ed, IndexK[bos + k_st + i,
j], 0)
T.sync_threads()
logits = T.alloc_fragment((block_K, heads), FP32)
T.gemm(
index_k_shared,
index_q_shared,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
T.sync_threads()
for i, j in T.Parallel(block_K, heads):
logits[i, j] = T.max(logits[i, j], 0) * weights_frag[j]
T.sync_threads()
logits_sum = T.alloc_fragment(block_K, FP32)
T.reduce_sum(logits, logits_sum, dim=1)
T.sync_threads()
offset = T.alloc_var(INT32)
if k_st >= topk:
offset = topk + (k_st % topk)
else:
offset = k_st
T.sync_threads()
for i in T.Parallel(block_K):
if k_st + i > i_t:
logits_sum[i] = float('-inf')
j = offset + i
topk_index_shared[j] = k_st + i
topk_value_shared[j] = logits_sum[i]
T.sync_threads()
if k_ed > topk and k_ed % topk == 0:
bitonic_sort(topk_index_shared, topk_value_shared)
bitonic_sort(topk_index_shared, topk_value_shared)
logits_max_frag = T.alloc_fragment([1], dtype=FP32)
logits_frag = T.alloc_fragment([topk], dtype=FP32)
reducesum_shared = T.alloc_shared([topk], dtype=FP32)
T.copy(topk_value_shared[:topk], logits_frag)
T.sync_threads()
T.reduce_max(logits_frag, logits_max_frag, dim=-1)
T.sync_threads()
for i in T.Parallel(topk):
logits_frag[i] = T.exp(logits_frag[i] - logits_max_frag[0])
T.sync_threads()
lse_frag = T.alloc_fragment([1], dtype=FP32)
T.reduce_sum(logits_frag, lse_frag)
T.sync_threads()
for i in T.Parallel(topk):
reducesum_shared[i] = logits_frag[i] / lse_frag[0]
T.sync_threads()
# for i in T.Parallel(topk):
# reducesum_shared[i] = logits_frag[i]
# T.sync_threads()
for i in T.Parallel(topk):
if topk_index_shared[i] > i_t:
topk_index_shared[i] = -1
T.sync_threads()
T.copy(topk_index_shared[:topk], TopkIndices[bos + i_t, :])
T.copy(reducesum_shared[:topk], ReduceSum[bos + i_t, :])
return tl_indexer_topk_reducesum_kernel
def indexer_topk_reducesum_interface(
q: torch.Tensor,
weights: torch.Tensor,
k: torch.Tensor,
topk: int,
offsets: torch.Tensor,
dtype: str = BF16,
):
seq_len, heads, dim = q.shape
kernel = tl_indexer_topk_reducesum_impl(heads=heads, dim=dim, topk=topk, dtype=dtype)
token_indices = prepare_token_indices(offsets)
topk_indices = torch.zeros((seq_len, topk), device=q.device, dtype=torch.int32)
topk_score = torch.zeros((seq_len, topk), device=q.device, dtype=torch.float32)
kernel(q, weights, k, topk_indices, topk_score, offsets, token_indices)
return topk_indices, topk_score
def ref_index_score(Q: torch.Tensor, Weights: torch.Tensor, K: torch.Tensor, topk: int,
offsets: torch.Tensor) -> torch.Tensor:
all_topk_indices = []
all_topk_score = []
for i in range(offsets.shape[0] - 1):
assert (offsets[i + 1] - offsets[i]).item() >= topk
q = Q[offsets[i]:offsets[i + 1]]
weights = Weights[offsets[i]:offsets[i + 1]]
k = K[offsets[i]:offsets[i + 1]]
softmax_scale = q.shape[-1]**-0.5
s = q.shape[0]
mask = (torch.arange(s)[:, None] >= torch.arange(s)[None, :]).to(q.device)
logits = einsum(q, k, 's1 h k, s2 k -> s1 h s2')
logits = F.relu(logits)
logits = (logits * weights.unsqueeze(-1)).sum(dim=-2, dtype=torch.float32) * softmax_scale
logits = torch.where(mask, logits, float('-inf'))
topk_logits, topk_indices = torch.topk(logits, k=topk, dim=-1)
topk_score = F.softmax(topk_logits, dim=-1, dtype=torch.float32)
all_topk_indices.append(topk_indices)
all_topk_score.append(topk_score)
topk_indices = torch.cat(all_topk_indices, dim=0)
topk_score = torch.cat(all_topk_score, dim=0)
return topk_indices, topk_score
def test_kernel(
B=1,
S=2048,
H=64,
D=128,
topk=64,
):
torch.manual_seed(42)
q = torch.randn((S, H, D)).cuda().bfloat16()
weights = torch.randn((S, H)).cuda().bfloat16()
k = torch.randn((S, D)).cuda().bfloat16()
offsets = torch.tensor([0, S], dtype=torch.int32).cuda()
ref_topk_indices, ref_topk_score = ref_index_score(q, weights, k, topk, offsets)
topk_indices, topk_score = indexer_topk_reducesum_interface(q, weights, k, topk, offsets)
for j in range(S):
ref_np = ref_topk_indices[j].cpu().to(torch.int32).numpy()
trt_np = topk_indices[j].cpu().to(torch.int32).numpy()
ref_np_val = ref_topk_score[j]
trt_np_val = topk_score[j]
mask = (ref_np_val > 0).cpu().numpy()
set_ref = set(ref_np[mask])
set_trt = set(trt_np[mask])
intersection = set_ref & set_trt
print("idx:", j, "selected/all:", len(intersection), "/", len(set_ref), "=",
len(intersection) / len(set_ref))
print(
f"err: {get_abs_err(ref_np_val, trt_np_val):.6f} ratio: {get_err_ratio(ref_np_val, trt_np_val):.6f}"
)
if __name__ == '__main__':
test_kernel()
# ruff: noqa
import tilelang
from tilelang import language as T
import torch
from index import prepare_token_indices
from utils import assert_tensors_similar
@tilelang.jit(out_idx=[-1])
def preprocess(
H,
D,
block_ND=32,
num_stages=5,
dtype="bfloat16",
accum_dtype="float",
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
S = T.symbolic('S')
shape = [S, H, D]
@T.prim_func
def preprocess_kernel(
O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype),
Delta: T.Tensor([S, H], accum_dtype),
):
with T.Kernel(H, T.ceildiv(S, block_ND)) as (bx, by):
o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
do = T.alloc_fragment([block_ND, block_ND], accum_dtype)
delta = T.alloc_fragment([block_ND], accum_dtype)
acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
T.clear(acc)
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(O[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o)
T.copy(dO[by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND],
do)
for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[by * block_ND:(by + 1) * block_ND, bx])
return preprocess_kernel
@tilelang.jit(out_idx=[-1])
def postprocess(
D,
D_tail,
kv_group=1,
block_N=64,
threads=128,
dtype="bfloat16",
accum_dtype="float",
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
S_kv = T.symbolic('S_kv')
dkv_shape = [S_kv, kv_group, D + D_tail]
@T.prim_func
def postprocess_kernel(
dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype),
):
with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, threads=threads) as (bx, by):
T.copy(
dKV[bx * block_N:(bx + 1) * block_N, by, :],
dKV_out[bx * block_N:(bx + 1) * block_N, by, :],
)
return postprocess_kernel
@tilelang.jit(
out_idx=[-2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def bwd(
H,
D,
D_tail,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
block_size=32,
num_stages=0,
threads=128,
indices_dtype="int32",
dtype="bfloat16",
accum_dtype="float",
):
    assert is_causal == True, 'non-causal attention is not supported yet'
    assert topk % block_size == 0, 'otherwise index=0 entries would load the wrong KV'
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert indices_dtype == "int32"
if sm_scale is None:
sm_scale = (D + D_tail)**(-0.5)
B_plus_one = T.symbolic('B_plus_one')
S = T.symbolic('S')
H_kv = H // kv_group
q_shape = [S, H, D + D_tail]
k_shape = [S, kv_group, D + D_tail]
o_shape = [S, H, D]
indices_shape = [S, kv_group, topk]
delta_shape = [S, H]
lse_shape = [S, H]
offsets_shape = [B_plus_one]
token_indices_shape = [S, 2]
assert indices_dtype == "int32"
assert dtype == "bfloat16"
assert accum_dtype == "float"
H = H_kv
padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
BS = block_size
NS = tilelang.cdiv(topk, block_size)
split_store = 2
@T.prim_func
def sparse_mla_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype),
Offsets: T.Tensor(offsets_shape, indices_dtype),
TokenIndices: T.Tensor(token_indices_shape, indices_dtype),
dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype),
):
with T.Kernel(S, kv_group, threads=threads) as (b_s_i, bz):
Q_shared = T.alloc_shared([padded_H, D], dtype)
Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
KV_shared = T.alloc_shared([BS, D], dtype)
KV_tail_shared = T.alloc_shared([BS, D_tail], dtype)
dO_shared = T.alloc_shared([padded_H, D], dtype)
mask = T.alloc_fragment([BS], "bool")
P_shared_cast = T.alloc_shared([padded_H, BS], dtype)
dP_shared_cast = T.alloc_shared([padded_H, BS], dtype)
dQ_shared = T.alloc_shared([padded_H, D], dtype)
dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
acc_p = T.alloc_fragment([padded_H, BS], accum_dtype)
acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype)
acc_dq = T.alloc_fragment([padded_H, D], accum_dtype)
acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
acc_dkv_tail_shared = T.view(
KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
max_kv_i = s_i
T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared)
T.clear(acc_dq)
T.clear(acc_dq_tail)
T.annotate_layout({
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
})
# Process each block of indices
for i_i in T.Pipelined(NS, num_stages=num_stages):
# Check which indices are valid
for bi_i in T.Parallel(BS):
mask[bi_i] = (Indices[bos + s_i, bz, i_i * BS + bi_i] <= max_kv_i) & (
Indices[bos + s_i, bz, i_i * BS + bi_i] != -1)
# Compute attention scores
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype))
# Load KV, V for this block of indices
for bi_i, d_i in T.Parallel(BS, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i], bz,
d_i]
T.gemm(
Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for bi_i, d_i in T.Parallel(BS, D_tail):
KV_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i],
bz, D + d_i]
T.gemm(
Q_tail_shared,
KV_tail_shared,
acc_p,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.exp(acc_p[h_i, bi_i] * sm_scale -
Lse[bos + s_i, bz * padded_H + h_i])
T.copy(acc_p, P_shared_cast)
T.gemm(
dO_shared,
KV_shared,
acc_dp,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (
acc_dp[h_i, bi_i] - Delta[bos + s_i, bz * padded_H + h_i]) * sm_scale
T.copy(acc_dp, dP_shared_cast)
T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
dP_shared_cast,
Q_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
P_shared_cast,
dO_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
T.clear(acc_dkv_tail)
T.gemm(
dP_shared_cast,
Q_tail_shared,
acc_dkv_tail,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
for s in range(split_store):
for bi_i, d_i in T.Parallel(BS, D):
if bi_i < BS // split_store:
acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i]
for bi_i, d_i in T.Parallel(BS, D_tail):
if bi_i < BS // split_store:
acc_dkv_tail_shared[bi_i,
d_i] = acc_dkv_tail[bi_i + s * (BS // split_store),
d_i]
for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
T.atomic_addx4(
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s *
(BS // split_store)], bz, d_i * 4],
acc_dkv_shared[bi_i, d_i * 4])
# Atomically update dKV, dKV_tail tensors
for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
T.atomic_addx4(
dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s *
(BS // split_store)], bz, D + d_i * 4],
acc_dkv_tail_shared[bi_i, d_i * 4])
# Store the accumulated dQ
T.copy(acc_dq, dQ_shared)
T.copy(acc_dq_tail, dQ_tail_shared)
T.copy(dQ_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[bos + s_i, bz * padded_H:(bz + 1) * padded_H, D:])
return sparse_mla_bwd_kernel
def sparse_mla_bwd(q,
kv,
o,
do,
indices,
lse,
offsets,
sm_scale=None,
is_casual=True,
return_kernel=False,
delta=None):
assert q.is_contiguous()
assert kv.is_contiguous()
assert indices.is_contiguous()
assert lse.is_contiguous()
S, H, dim_plus_tail_dim = q.shape
S_kv, kv_group, _ = kv.shape
assert kv.shape[-1] == dim_plus_tail_dim
assert S == S_kv
    # The value head dim D is hard-coded below; assign it explicitly for other head sizes.
D = 512
D_tail = dim_plus_tail_dim - D
topk = indices.shape[-1]
assert indices.shape == (S, kv_group, topk)
assert lse.shape == (S, H)
token_indices = prepare_token_indices(offsets)
# Get kernels
preprocess_kernel = preprocess(H, D)
bwd_kernel = bwd(H, D, D_tail, topk, kv_group, sm_scale, is_casual)
postprocess_kernel = postprocess(D, D_tail, kv_group)
if delta is None:
delta = preprocess_kernel(o, do)
dkv = torch.zeros_like(kv, dtype=torch.float32)
dq = bwd_kernel(q, kv, do, indices, lse, delta, offsets, token_indices, dkv)
dkv = postprocess_kernel(dkv)
return dq, dkv
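# Backward pipeline: `preprocess` computes Delta = rowsum(O * dO) per (token, head); `bwd`
# recomputes the sparse attention probabilities from Q, KV and Lse, writes dQ directly and
# accumulates dKV in fp32 via atomic adds (several queries may select the same KV row);
# `postprocess` then casts the fp32 dKV accumulator back to bf16.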
def ref_sparse_mla_bwd_interface(q,
kv,
o,
do,
indices,
lse,
offsets,
sm_scale=None,
is_casual=True):
from sparse_mla_fwd import ref_sparse_mla_fwd_interface
q = q.detach().clone()
kv = kv.detach().clone()
q.requires_grad = True
kv.requires_grad = True
o = ref_sparse_mla_fwd_interface(q, kv, indices, offsets, sm_scale, is_casual)
o.backward(do)
return q.grad, kv.grad
def test_sparse_mla_bwd(B=1,
S=2048,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=512,
dtype=torch.bfloat16,
check_correctness=True):
# Prepare data
q = torch.randn((S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
kv = torch.randn((S, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((S, H, DV), dtype=dtype, device='cuda')
offsets = torch.tensor([0, S], dtype=torch.int32, device="cuda")
indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device='cuda')
for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item()
assert seq_len >= topk
for t in range(seq_len):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[offsets[i] + t, h, :len(i_i)] = i_i
# Forward
from sparse_mla_fwd import sparse_mla_fwd_interface
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, offsets)
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)
ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None, offsets)
if check_correctness:
assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq")
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
print("assert_tensors_similar passed")
per_token_flop = 2 * sum([
H * DV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DV * topk,
])
from tilelang.profiler import do_bench
def fn():
return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse, offsets)
ms = do_bench(fn, rep=100, warmup=250)
print(f"Average time: {ms:.3f} ms")
print(f'bwd io bandwidth = ',
(B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_bwd(
B=1,
S=2048,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=512,
dtype=torch.bfloat16,
check_correctness=True)
# ruff: noqa
import torch
import tilelang
from tilelang import language as T
from index import prepare_token_indices
from utils import assert_tensors_similar
@tilelang.jit(
out_idx=[-2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def sparse_mla_fwd(
heads,
dim,
tail_dim,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
CP0=True,
block_I=32,
num_stages=2,
threads=128,
):
    assert dim == tilelang.math.next_power_of_2(
        dim), f"padding correctness has not been verified yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim), f"padding correctness has not been verified yet, tail_dim={tail_dim}"
assert is_causal == True, "non-casual is not supported"
assert (topk %
block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim))**0.5
else:
sm_scale = sm_scale
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
head_kv = heads // kv_group
q_shape = [seq_len, heads, dim + tail_dim]
kv_shape = [seq_len, kv_group, dim + tail_dim]
o_shape = [seq_len, heads, dim]
indices_shape = [seq_len, kv_group, topk]
lse_shape = [seq_len, heads]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
        assert kv_group == 1, (
            "H padding is handled automatically only when kv_group == 1 (via the slice "
            "g_i * padded_H:(g_i + 1) * padded_H); otherwise handle the Q and Output copies "
            "with your own mask")
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
):
with T.Kernel(
seq_len * REPLICATE_H, kv_group, threads=threads) as (
bx,
by,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
mask = T.alloc_fragment([BI], "bool")
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(acc_o, 0)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
g_i = by
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (
Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i,
d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i],
g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
# Rescale
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o, Output[bos + s_i, H0:H1, :])
T.copy(sumexp, Lse[bos + s_i, H0:H1])
return main
def sparse_mla_fwd_interface(q,
kv,
indices,
offsets,
sm_scale=None,
return_p_sum: bool = False,
d_v=512,
block_I=32,
num_stages=2,
threads=128):
is_casual = True
assert return_p_sum == False, "This kernel file is for fwd only"
assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
seq_len, heads, dim_plus_tail_dim = q.shape
seq_len_kv, kv_group, _ = kv.shape
assert seq_len == seq_len_kv
    assert dim_plus_tail_dim == 576, "expected head dim 576; pass d_v explicitly otherwise"
dim = d_v
assert kv.shape[-1] == dim_plus_tail_dim
tail_dim = dim_plus_tail_dim - dim
_, _, topk = indices.shape
assert indices.shape == (seq_len, kv_group, topk)
token_indices = prepare_token_indices(offsets)
kernel = sparse_mla_fwd(
heads,
dim,
tail_dim,
topk,
kv_group,
sm_scale,
is_casual,
block_I=block_I,
num_stages=num_stages,
threads=threads)
out, lse = kernel(q, kv, indices, offsets, token_indices)
return out, lse
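# Usage sketch (shapes as assumed by the asserts above): q [S, H, 576] bf16,
# kv [S, kv_group, 576] bf16 (kv_group = 1 in this example), indices [S, kv_group, topk] int32
# (entries equal to -1 or beyond the query position are masked out), offsets an int32 cu_seqlens
# tensor. Returns the attention output [S, H, d_v] and the per-head log-sum-exp [S, H] consumed
# by the backward and top-k reduce-sum kernels.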
def ref_sparse_mla_fwd_interface(Q, KV, Indices, offsets, sm_scale=None, is_casual=True):
Q = Q.float()
KV = KV.float()
all_o = []
for i in range(offsets.shape[0] - 1):
q = Q[None, offsets[i]:offsets[i + 1]]
kv = KV[None, offsets[i]:offsets[i + 1]]
indices = Indices[None, offsets[i]:offsets[i + 1]].clone()
indices = indices.transpose(1, 2)
b, sq, h, dim_q = q.shape
b, sk, g, _ = kv.shape
        assert kv.shape[-1] == 576, "expected head dim 576; set dim below for other layouts"
dim = 512
k = kv
v = kv[..., :dim]
b, _, _, dim_v = v.shape
g_index = g
h_index = h // g
        # Causal mask over (query position, key position) for this sequence.
        causal_mask = torch.arange(
            0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
                0, sk, dtype=torch.int32, device="cuda").view(1, -1)
        indices[indices > sk] = sk
        # Scatter the selected indices into a dense boolean mask; the extra trailing column
        # absorbs out-of-range indices and is dropped immediately afterwards.
        mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
        mask = mask[..., :-1]
        mask = mask & causal_mask.view(1, 1, sq, sk)
mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q)
score = torch.einsum("bmghd,bngd->bghmn", q, k)
sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
p = score.softmax(dim=-1)
p = p.view(b, g_index, h_index, -1, sq, sk)
p = p.view(b, g, -1, sq, sk)
o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
o = o.reshape(b, sq, h, dim_v)
all_o.append(o.squeeze(0))
o = torch.cat(all_o, dim=0)
return o.to(torch.bfloat16)
def test_sparse_mla_fwd(B=1,
S=4096,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256):
torch.random.manual_seed(0)
q = torch.randn((S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((S, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
offsets = torch.tensor([0, S // 2 - 1, S], dtype=torch.int32, device="cuda")
indices = torch.full((S, HKV, topk), S, dtype=torch.int32, device="cuda")
for i in range(offsets.shape[0] - 1):
seq_len = (offsets[i + 1] - offsets[i]).item()
assert seq_len >= topk
for t in range(seq_len):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[offsets[i] + t, h, :len(i_i)] = i_i
tl_out, tl_lse = sparse_mla_fwd_interface(
q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
if check_correctness:
# otherwise may cause out of memory
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, offsets)
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
print("assert_tensors_similar passed")
def fn():
return sparse_mla_fwd_interface(
q, kv, indices, offsets, block_I=block_I, num_stages=num_stages, threads=threads)
from tilelang.profiler import do_bench
ms = do_bench(
fn,
rep=100,
warmup=250,
)
print(f"Average time: {ms:.3f} ms")
print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_fwd(
B=1,
S=4096,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=1024,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256)
# ruff: noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
import tilelang
from tilelang import language as T
from einops import repeat, rearrange, einsum
from index import prepare_token_indices
from utils import get_abs_err, get_err_ratio
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}
@tilelang.jit(pass_configs=pass_configs)
def tl_sparse_mla_topk_reducesum_impl(
heads,
dim,
tail_dim,
topk,
kv_group=1,
sm_scale=None,
block_I=32,
num_stages=2,
threads=128,
):
    assert dim == tilelang.math.next_power_of_2(
        dim), f"padding correctness has not been verified yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim), f"padding correctness has not been verified yet, tail_dim={tail_dim}"
    assert topk % block_I == 0, "otherwise index=0 entries would load the wrong KV"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim))**0.5
batch_plus_one = T.symbolic("batch_plus_one")
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
head_kv = heads // kv_group
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
if padded_H != H:
        assert kv_group == 1, (
            "H padding is handled automatically only when kv_group == 1 (via the slice "
            "g_i * padded_H:(g_i + 1) * padded_H); otherwise handle the Q and Output copies "
            "with your own mask")
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
q_shape = [seq_len, heads, dim + tail_dim]
kv_shape = [seq_len_kv, kv_group, dim + tail_dim]
indices_shape = [seq_len, kv_group, topk]
lse_shape = [seq_len, heads]
reducesum_shape = [seq_len, kv_group, REPLICATE_H, topk]
offsets_shape = [batch_plus_one]
token_indices_shape = [seq_len, 2]
@T.prim_func
def tl_sparse_mla_topk_reducesum_kernel(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
Offsets: T.Tensor(offsets_shape, indices_dtype), # type: ignore
TokenIndices: T.Tensor(token_indices_shape, indices_dtype), # type: ignore
ReduceSum: T.Tensor(reducesum_shape, accum_dtype), # type: ignore
):
with T.Kernel(
seq_len * REPLICATE_H, kv_group, threads=threads) as (
bx,
by,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
mask = T.alloc_fragment([BI], "bool")
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
reducesum = T.alloc_fragment([BI], accum_dtype)
lse = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(lse, 0)
b_s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
b_i, s_i = TokenIndices[b_s_i, 0], TokenIndices[b_s_i, 1]
bos, eos = Offsets[b_i], Offsets[b_i + 1]
r_i = bx % REPLICATE_H
g_i = by
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[bos + s_i, H0:H1, :D], Q_shared)
T.copy(Q[bos + s_i, H0:H1, D:], Q_tail_shared)
T.copy(Lse[bos + s_i, H0:H1], lse)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = (Indices[bos + s_i, g_i, i_i * BI + bi_i] <= max_kv_i) & (
Indices[bos + s_i, g_i, i_i * BI + bi_i] != -1)
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i], g_i,
d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[bos + Indices[bos + s_i, g_i, i_i * BI + bi_i],
g_i, D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp(acc_s[h_i, bi_i] * sm_scale - lse[h_i])
T.reduce_sum(acc_s, reducesum, dim=0)
T.copy(reducesum, ReduceSum[bos + s_i, g_i, r_i, i_i * BI:i_i * BI + BI])
return tl_sparse_mla_topk_reducesum_kernel
def sparse_mla_topk_reducesum_interface(
q: torch.Tensor,
kv: torch.Tensor,
topk_indices: torch.Tensor,
lse: torch.Tensor,
offsets: torch.Tensor,
dim_v: int,
):
assert kv.shape[-2] == 1
seq_len, heads, dim_plus_tail_dim, topk = *q.shape, topk_indices.shape[-1]
REPLICATE_H = max(heads // 64, 1)
tail_dim = dim_plus_tail_dim - dim_v
token_indices = prepare_token_indices(offsets)
reducesum = torch.zeros([seq_len, 1, REPLICATE_H, topk], dtype=torch.float32, device=q.device)
kernel = tl_sparse_mla_topk_reducesum_impl(heads=heads, dim=dim_v, tail_dim=tail_dim, topk=topk)
kernel(q, kv, topk_indices, lse, offsets, token_indices, reducesum)
    reducesum = reducesum.sum(dim=-2)  # [seq_len, 1, REPLICATE_H, topk] -> [seq_len, 1, topk]
attn_score = reducesum / reducesum.sum(dim=-1, keepdim=True)
return attn_score
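# This interface recomputes exp(q @ k^T * sm_scale - lse) at the selected top-k positions, sums
# it over the query heads, and normalizes over the top-k axis. The result is the attention mass
# that sparse MLA places on each selected key; the example's DSAFunction.backward uses it as the
# target distribution for the indexer's KL loss.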
def ref_mla_topk_softmax(Q: torch.Tensor, K: torch.Tensor, TopkIndices: torch.Tensor,
offsets: torch.Tensor):
    # Q: [seq_len, heads, dim]
    # K: [seq_len, dim]
sm_scale = Q.shape[-1]**-0.5
all_lse = []
all_topk_score = []
for i in range(offsets.shape[0] - 1):
q = Q[offsets[i]:offsets[i + 1]]
k = K[offsets[i]:offsets[i + 1]]
topk_indices = TopkIndices[offsets[i]:offsets[i + 1]]
seq_len = q.shape[0]
mask = (torch.arange(seq_len)[:, None]
>= torch.arange(seq_len)[None, :]).unsqueeze(-2).cuda()
logits = einsum(q, k, 's1 h d, s2 d -> s1 h s2') * sm_scale
logits = torch.where(mask, logits, float('-inf'))
score = F.softmax(logits, dim=-1, dtype=torch.float32)
score_sum = score.sum(dim=-2)
topk_score = torch.gather(score_sum, dim=-1, index=topk_indices.to(torch.int64))
topk_score = topk_score / topk_score.sum(dim=-1, keepdim=True)
max_logits = logits.amax(dim=-1).to(torch.float32)
lse = torch.log(
(logits - max_logits.unsqueeze(-1).to(torch.float32)).exp().sum(dim=-1)) + max_logits
all_lse.append(lse)
all_topk_score.append(topk_score)
lse = torch.cat(all_lse, dim=0)
topk_score = torch.cat(all_topk_score, dim=0)
return lse, topk_score
def test_kernel(
B=1,
S=2048,
H=16,
D=512,
tail_D=64,
topk=128,
):
torch.manual_seed(42)
q = torch.randn((S, H, D + tail_D)).cuda().bfloat16()
kv = torch.randn((S, D + tail_D)).cuda().bfloat16()
offsets = torch.tensor([0, 1023, S], dtype=torch.int32).cuda()
topk_indices = repeat(
torch.arange(topk, dtype=torch.int32).cuda(), 'k -> s k', s=S).contiguous()
lse, ref_attn_score = ref_mla_topk_softmax(q, kv, topk_indices, offsets)
kv = kv.unsqueeze(-2)
topk_indices = topk_indices.unsqueeze(-2)
attn_score = sparse_mla_topk_reducesum_interface(
q, kv, topk_indices, lse, offsets, dim_v=D).squeeze(-2)
print(
f"attn_score err: {get_abs_err(attn_score, ref_attn_score):.6f} ratio: {get_err_ratio(attn_score, ref_attn_score):.6f}"
)
if __name__ == '__main__':
test_kernel()
import torch
def get_abs_err(y, x):
x = x.to(torch.float32)
y = y.to(torch.float32)
return (x - y).flatten().abs().max().item()
def get_err_ratio(y, x):
x = x.to(torch.float32)
y = y.to(torch.float32)
err = (x - y).flatten().square().mean().sqrt().item()
base = (x).flatten().square().mean().sqrt().item()
return err / base
def calculate_tensor_similarity(x, y, name="tensor"):
"""
Calculate similarity between two tensors using a normalized dot product metric.
Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
element-wise differences, this function computes a global similarity score:
sim = 2 * <x, y> / (||x||^2 + ||y||^2)
This metric is scale-invariant and measures the cosine-like similarity normalized
by the magnitude of both tensors. It returns 1 for identical tensors and values
closer to 0 for dissimilar ones. This is particularly useful for comparing tensors
with varying magnitudes where relative errors matter more than absolute differences.
Args:
x: First tensor to compare
y: Second tensor to compare
name: Name of the tensor for logging purposes
Returns:
Similarity score in range [0, 1] where 1 means identical
"""
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print(f"\033[33mWARNING: {name} all zero\033[0m")
return 1
sim = 2 * (x * y).sum() / denominator
return sim
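# Worked example (illustrative): for y = x, sim = 2<x, x> / (||x||^2 + ||x||^2) = 1; for y = 2x,
# sim = 2 * 2||x||^2 / (||x||^2 + 4||x||^2) = 0.8; for orthogonal x and y the numerator is zero,
# so sim = 0. assert_tensors_similar below checks that 1 - sim <= eps.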
def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
"""
Assert that two tensors are similar using a global similarity metric.
Key differences from torch.testing.assert_close:
- torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
and requires all elements to satisfy the tolerance.
- assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the
normalized dot product. It's more robust to outliers and focuses on overall
tensor similarity rather than element-wise precision. This is better suited for
comparing large tensors where a few outlier elements shouldn't fail the test.
Args:
x: First tensor to compare
y: Second tensor to compare
eps: Maximum allowed difference (1 - similarity), default 1e-8
name: Name of the tensor for error messages
raise_assert: Whether to raise assertion error on failure
"""
sim = calculate_tensor_similarity(x, y, name)
diff = 1. - sim
if not (0 <= diff <= eps):
print(
f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m"
)
if raise_assert:
assert False # noqa: B011