"git@developer.sourcefind.cn:sugon_wxj/megatron-lm.git" did not exist on "27ecc17a1ea93b5b6b68145df06094de0aa53356"
Unverified commit 65ac7454 authored by Lei Wang, committed by GitHub

[Example] Add sparse mla examples (#896)

* Update README.md to include directory structure and file descriptions for deepseek_v32 example

* Refactor and clean up deepseek_v32 example scripts

- Removed unused imports and functions from `fp8_mqa_logits.py` to streamline the code.
- Improved formatting and readability in `sparse_mla_fwd_pipelined.py` and `sparse_mla_fwd.py` by adjusting function signatures and indentation.
- Added `# ruff: noqa` comments to suppress linting warnings in multiple files.
- Enhanced the `generate_random_cu_seqlens` function in `utils.py` for better clarity and organization.
- Updated print statements for consistency in output formatting.
parent 78664e24
Coming Soon.
## Directory Structure
```
deepseek_v32/
├── README.md # This file
├── fp8_mqa_logits.py # FP8 Indexer
├── sparse_mla_fwd.py # Sparse MLA forward implementation
├── sparse_mla_fwd_pipelined.py # Pipelined implementation of sparse MLA forward pass
└── utils.py # Shared helpers (cu_seqlens generation, FP8 casting, correctness checks)
```
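## Quick Start

Each kernel script runs a small benchmark when executed directly (e.g. `python fp8_mqa_logits.py`). The sketch below shows one way to call the FP8 indexer from Python; it mirrors the script's own `__main__` block and assumes a CUDA GPU with FP8 support, an installed `tilelang`, and purely illustrative shapes.

```python
import torch
from fp8_mqa_logits import mqa_attn_return_logits_interface
from utils import generate_random_cu_seqlens, per_custom_dims_cast_to_fp8

S, SKV, H, D = 4096, 8192, 32, 64  # illustrative sizes
q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16)
kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16)
weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, average_q_len=2048)

q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), use_ue8m0=False)

# Returns [S, SKV] float32 logits; positions outside each query's [ks, ke) window are -inf.
logits = mqa_attn_return_logits_interface(
    q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights,
    cu_seqlen_ks=ks, cu_seqlen_ke=ke)
```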
# ruff: noqa
import itertools
import tilelang
from tilelang import language as T
import torch
from utils import generate_random_cu_seqlens, per_custom_dims_cast_to_fp8
def display_error_message(msg):
print(f"\033[31mWARNING: {msg}\033[0m")
def compute_correlation(a, b, label="tensor"):
a, b = a.data.double(), b.data.double()
norm_sum = (a * a + b * b).sum()
if norm_sum == 0:
display_error_message(f"{label} all zero")
return 1
correlation = 2 * (a * b).sum() / norm_sum
return correlation
def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_raise=True):
a_finite = torch.isfinite(a)
b_finite = torch.isfinite(b)
if not torch.all(a_finite == b_finite):
display_error_message(f"{tensor_name} Error: isfinite mask mismatch")
if should_raise:
assert False
if not torch.isclose(
a.masked_fill(a_finite, 0),
b.masked_fill(b_finite, 0),
rtol=0,
atol=0,
equal_nan=True,
).all():
display_error_message(f"{tensor_name} Error: nonfinite value mismatch")
if should_raise:
assert False
a = a.masked_fill(~a_finite, 0)
b = b.masked_fill(~b_finite, 0)
correlation = compute_correlation(a, b, tensor_name)
difference = 1.0 - correlation
if not (0 <= difference <= tolerance):
display_error_message(f"{tensor_name} Error: {difference}")
if should_raise:
assert False
return difference
def get_configs():
iter_params = dict(
block_N=[32, 64, 128],
num_stages=[0, 1, 2],
threads=[128, 256],
block_Q=[1, 2, 4],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
class SupplyProg:
def __init__(self):
self.tensors_dict = {}
def get_key(self, shape, dtype) -> str:
return f"{shape}-{dtype}"
def supply_prog(self, params):
shapes = [p.shape for p in params]
dtypes = [p.dtype for p in params]
tensor_list = []
for shape, dtype in zip(shapes, dtypes):
key = self.get_key(shape, dtype)
if key not in self.tensors_dict:
self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda")
tensor_list.append(self.tensors_dict[key])
else:
tensor_list.append(self.tensors_dict[key])
return tensor_list
supply_prog = SupplyProg()
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},)
def mqa_attn_return_logits(
heads,
index_dim,
block_N=256,
num_stages=3,
threads=512,
block_Q=None,
):
if block_Q is None:
block_Q = 128 // heads
dtype = "float8_e4m3"
accum_dtype = "float"
index_dtype = "int32"
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
index_q_shape = [seq_len * heads, index_dim]
index_k_shape = [seq_len_kv, index_dim]
index_k_scale_shape = [seq_len_kv]
logits_shape = [seq_len, seq_len_kv]
@T.prim_func
def mqa_attn_return_logits_kernel(
IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore
IndexK: T.Tensor(index_k_shape, dtype), # type: ignore
IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore
Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore
Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore
CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore
CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype)
index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
weights = T.alloc_fragment([block_Q, heads], accum_dtype)
seq_len_i = bx * block_Q
cu_k_s_min = T.alloc_local([1], index_dtype)
cu_k_e_max = T.alloc_local([1], index_dtype)
T.no_set_max_nreg()
cu_k_s_min[0] = 2147483647
cu_k_e_max[0] = -2147483648
for bq_i in T.serial(block_Q):
cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i],
seq_len_kv))
for bq_i in T.serial(block_Q):
cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i],
seq_len_kv))
T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
T.copy(Weights[seq_len_i, 0], weights)
for nbn_i in T.Pipelined(
T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)
T.gemm(
index_k_shared,
index_q_shared,
s,
transpose_B=True,
clear_accum=True,
policy=T.GemmWarpPolicy.FullCol,
)
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
s_reshaped[bn_i, bq_i,
h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
for bq_i, bn_i in T.Parallel(block_Q, block_N):
Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
logits[bn_i, bq_i])
return mqa_attn_return_logits_kernel
@tilelang.jit
def clean_logits_(
threads: int = 512,
block_K: int = 4096,
):
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
dtype = "float"
indices_dtype = "int32"
@T.prim_func
def clean_logits_kernel(
Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore
CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore
CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore
):
with T.Kernel(seq_len, threads=threads) as bx:
tx = T.thread_binding(0, threads, thread="threadIdx.x")
cu_k_s = T.alloc_local([1], indices_dtype)
cu_k_e = T.alloc_local([1], indices_dtype)
cu_k_s[0] = CuSeqLenKS[bx]
cu_k_e[0] = CuSeqLenKE[bx]
for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
for k_i in T.serial(block_K // threads):
idx = n_i * block_K + k_i * threads + tx
if idx < cu_k_s[0] or idx >= cu_k_e[0]:
Logits[bx, idx] = -T.infinity(dtype)
return clean_logits_kernel
def mqa_attn_return_logits_interface(q,
kv,
kv_scales,
weights,
cu_seqlen_ks,
cu_seqlen_ke,
clean_logits=True):
seq_len, heads, index_dim = q.shape
seq_len_kv = kv.shape[0]
clean_logits_kernel = clean_logits_()
mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32)
mqa_attn_return_logits_kernel(
q.view(seq_len * heads, index_dim),
kv,
kv_scales,
logits,
weights,
cu_seqlen_ks,
cu_seqlen_ke,
)
if clean_logits:
clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke)
return logits
def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor):
k = kv
q = q.float()
k = k.float()
seq_len_kv = kv.shape[0]
mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
mask = mask_lo & mask_hi
score = torch.einsum('mhd,nd->hmn', q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float('-inf'))
cost = mask.sum()
return logits, cost
if __name__ == "__main__":
torch.manual_seed(0)
S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1
q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)
ks, ke = generate_random_cu_seqlens(
per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)
logits_ref, cost_ref = ref_fp8_mqa_logits(
q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits_tl = mqa_attn_return_logits_interface(
q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
diff = validate_tensor_match(
logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False)
print(f"diff: {diff}")
from tilelang.profiler import do_bench
def logits_fn():
return mqa_attn_return_logits_interface(
q=q_fp8,
kv=kv_fp8,
kv_scales=kv_scales,
weights=weights,
cu_seqlen_ks=ks,
cu_seqlen_ke=ke)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
logits_fn()
print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50))
logits_ms = do_bench(logits_fn, warmup=100, rep=100)
logits_flops = 2 * cost_ref * H * D
logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12
print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}")
print(f"cost_ref: {cost_ref}")
# ruff: noqa
import torch
import tilelang
from tilelang import language as T
@tilelang.jit(
out_idx=[-2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def sparse_mla_fwd(
heads,
dim,
tail_dim,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
CP0=True,
block_I=64,
num_stages=2,
threads=256,
):
    assert dim == tilelang.math.next_power_of_2(
        dim), f"padding correctness has not been verified yet, dim={dim}"
    assert tail_dim == tilelang.math.next_power_of_2(
        tail_dim), f"padding correctness has not been verified yet, tail_dim={tail_dim}"
    assert is_causal, "non-causal attention is not supported"
    assert (topk % block_I == 0), (
        "topk must be a multiple of block_I; otherwise some index=0 entries would be "
        "loaded, causing the wrong KV rows to be read")
    if sm_scale is None:
        sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504  # fold in log2(e) since exp2 is used below
    else:
        sm_scale = sm_scale * 1.44269504  # fold in log2(e) since exp2 is used below
batch = T.symbolic("batch")
seq_len = T.symbolic("seq_len")
seq_len_kv = T.symbolic("seq_len_kv")
head_kv = heads // kv_group
q_shape = [batch, seq_len, heads, dim + tail_dim]
kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
o_shape = [batch, seq_len, heads, dim]
indices_shape = [batch, seq_len, kv_group, topk]
lse_shape = [batch, seq_len, heads]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert (
            kv_group == 1
        ), ("H padding is handled automatically only when kv_group == 1 (the slice "
            "g_i * padded_H:(g_i + 1) * padded_H then works as-is); otherwise handle the "
            "Q copy and Output copy with your own mask")
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
):
with T.Kernel(
seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
bx,
by,
bz,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
O_shared = T.alloc_shared([H_per_block, D], dtype)
Lse_shared = T.alloc_shared([H_per_block], accum_dtype)
mask = T.alloc_fragment([BI], "bool")
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(acc_o, 0)
T.fill(sumexp, 0)
T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan
b_i, g_i = by, bz
s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator?
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
# Rescale
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o, O_shared)
T.copy(acc_o, Output[b_i, s_i, H0:H1, :])
T.copy(sumexp, Lse_shared)
T.copy(sumexp, Lse[b_i, s_i, H0:H1])
return main
def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512):
    is_causal = True
    assert not return_p_sum, "this kernel file implements the forward pass only"
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    batch, seq_len, heads, dim_plus_tail_dim = q.shape
    _, seq_len_kv, kv_group, _ = kv.shape
    assert dim_plus_tail_dim == 576, "expected dim + tail_dim == 576; pass d_v explicitly for other shapes"
    dim = d_v
    assert kv.shape[-1] == dim_plus_tail_dim
    tail_dim = dim_plus_tail_dim - dim
    assert kv.shape[0] == batch
    _, _, _, topk = indices.shape
    assert indices.shape == (batch, seq_len, kv_group, topk)
    kernel = sparse_mla_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_causal)
out, lse = kernel(q, kv, indices)
return out, lse
def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_causal=True):
    q = q.float()
    kv = kv.float()
    indices = indices.transpose(1, 2)
    b, sq, h, dim_q = q.shape
    b, sk, g, _ = kv.shape
    assert kv.shape[-1] == 576, "expected dim + tail_dim == 576; adjust dim for other shapes"
    dim = 512
k = kv
v = kv[..., :dim]
b, _, _, dim_v = v.shape
g_index = g
h_index = h // g
    compressed_causal_mask = torch.arange(
        0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
            1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1)
    mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
    mask = mask[..., :-1]
    mask = mask & compressed_causal_mask.view(1, 1, sq, sk)
mask[:, :, :1 - 1, 0] = True
mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q)
score = torch.einsum("bmghd,bngd->bghmn", q, k)
sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
p = score.softmax(dim=-1)
p = p.view(b, g_index, h_index, -1, sq, sk)
p = p.view(b, g, -1, sq, sk)
o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
o = o.reshape(b, sq, h, dim_v)
return o.to(torch.bfloat16)
def test_sparse_mla_fwd():
B, S, SKV, H, HKV, DQK, DV, topk, dtype = (
1,
4096,
32768,
128,
1,
576,
512,
2048,
torch.bfloat16,
)
torch.random.manual_seed(0)
q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
for b in range(B):
for t in range(S):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[b, t, h, :len(i_i)] = i_i
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
def fn():
return sparse_mla_fwd_interface(q, kv, indices)
from tilelang.profiler import do_bench
ms = do_bench(
fn,
rep=100,
warmup=250,
)
print(f"Average time: {ms:.3f} ms")
print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_fwd()
This diff is collapsed.
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import contextlib
import functools
import logging
import os
import sys
from enum import Enum
from functools import lru_cache
from typing import Any, Callable, Dict, Literal, Optional, Tuple
from packaging import version
def _is_equal(a, b):
if isinstance(a, torch.Tensor):
return a is b
# Whitelist of types that are safe to compare by value for caching.
if isinstance(a, (int, float, str, bool, type(None))) and isinstance(
b, (int, float, str, bool, type(None))):
return a == b
# For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check.
return False
def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent result of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
If the function is called again with the same input tensors, it will return the cached result.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
last_args: Optional[Tuple] = None
last_kwargs: Optional[Dict] = None
last_result: Any = None
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal last_args, last_kwargs, last_result
if last_args is not None and last_kwargs is not None:
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
# For Tensors, check for object identity. For other types, check for equality.
# Python caches small integers, so `is` works for them but not for large integers like 4096.
if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \
set(kwargs.keys()) == set(last_kwargs.keys()) and \
all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()):
return last_result
result = fn(*args, **kwargs)
last_args, last_kwargs, last_result = args, kwargs, result
return result
return wrapper
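# Illustrative use of @tensor_cache (hypothetical helper, not part of this module):
# calling a decorated function again with the same tensor objects and equal scalar
# arguments returns the cached result instead of recomputing it.
#
#     @tensor_cache
#     def causal_mask(seq_len: int) -> torch.Tensor:
#         return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device="cuda"))
#
#     m1 = causal_mask(4096)
#     m2 = causal_mask(4096)  # cache hit: m2 is m1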
@tensor_cache
def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int):
seq_idx = cu_seqlens.new_zeros(seq_len + 1)
seq_idx.scatter_add_(0, cu_seqlens[1:].long(), torch.ones_like(seq_idx))
seq_idx.cumsum_(0)
return seq_idx[:-1]
@tensor_cache
def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
seq_len: int) -> torch.IntTensor:
seq_idx_for_q = torch.full((seq_len,),
len(cu_seqlens_qs),
dtype=torch.int32,
device=cu_seqlens_qs.device)
for i in range(len(cu_seqlens_qs)):
seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i
return seq_idx_for_q
@tensor_cache
def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor:
cu_seqlen_ks_for_each_q = torch.gather(
input=torch.cat([
cu_seqlens_ks,
torch.full((1,),
torch.iinfo(torch.int32).max,
dtype=torch.int32,
device=cu_seqlens_qs.device)
]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
return cu_seqlen_ks_for_each_q.int()
@tensor_cache
def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor,
cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor,
q_start_idxs: torch.LongTensor, seq_len: int,
kv_stride: int) -> torch.IntTensor:
cu_seqlen_ke_for_each_q = torch.gather(
input=torch.cat(
[cu_seqlens_ke,
torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
dim=0,
index=cal_seq_idx_for_q(
cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long())
    causal_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,),
                                                 dtype=torch.int32,
                                                 device=cu_seqlens_qs.device)
    for i in range(len(cu_seqlens_qs)):
        causal_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange(
            q_start_idxs[i],
            q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i],
            dtype=torch.int32,
            device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i]
    cu_seqlen_ke_for_each_q = torch.minimum(causal_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q)
return cu_seqlen_ke_for_each_q.int()
@tensor_cache
def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor,
cu_seqlens_k: torch.LongTensor = None,
offs_q: torch.LongTensor = None,
*,
seq_len: int,
kv_stride: int = 1,
cp_rank: int = 0,
cp_size: int = 1,
balanced_cp=False):
    '''
    seq_len: sequence length per context-parallel (CP) rank.
    Balanced CP splits the full sequence into 2 * cp_size chunks and assigns them to
    ranks in the order 0 1 2 3 3 2 1 0 (shown for cp_size=4), so every rank gets one
    early chunk and one late chunk.
    '''
n_seq = len(cu_seqlens_q) - 1
assert n_seq > 0
assert cu_seqlens_q.shape == (n_seq + 1,)
seq_idx = cal_seq_idx_from_cu_seqlens(cu_seqlens_q.long(), seq_len * cp_size)
qs = cu_seqlens_q.gather(0, seq_idx)
pos = torch.arange(len(qs), dtype=qs.dtype, device=qs.device) - qs
if offs_q is not None:
assert offs_q.shape == (n_seq,), offs_q.shape
qoff = offs_q.gather(0, seq_idx)
pos += qoff
if cu_seqlens_k is None or cu_seqlens_k is cu_seqlens_q:
ks = qs
else:
assert cu_seqlens_k.shape == (n_seq + 1,)
ks = cu_seqlens_k.gather(0, seq_idx)
ke = ks + (pos + 1) // kv_stride
if cp_size == 1:
pass
elif balanced_cp:
assert cp_size % 2 == 0, cp_size
        def f(x: torch.Tensor):
            # rank r keeps chunks r and 2 * cp_size - 1 - r of the 2 * cp_size chunks
            # (the "0 1 2 3 3 2 1 0" assignment described in the docstring)
            chunks = x.chunk(cp_size * 2)
            return torch.cat([
                chunks[cp_rank],
                chunks[cp_size * 2 - cp_rank - 1],
            ])
ks = f(ks)
ke = f(ke)
else:
ks = ks.chunk(cp_size)[cp_rank]
ke = ke.chunk(cp_size)[cp_rank]
return ks, ke
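# Illustrative call (hypothetical shapes): two documents of 6 and 10 tokens packed into
# one sequence of length 16, no context parallelism.
#
#     cu_seqlens_q = torch.tensor([0, 6, 16], device="cuda")
#     ks, ke = cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q, seq_len=16)
#
# ks[i] and ke[i] then bound the KV positions visible to query token i: each query
# attends causally within its own document, and with cp_size > 1 the per-rank slices
# follow the chunk assignment described in the docstring above.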
def ceil_to_ue8m0(x: torch.Tensor):
assert x.view(-1).amax().item() > 0
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int],
use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled, sf.squeeze()
def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512):
total_seqlen = per_cp_seqlen * cp_size
cu_seqlens = torch.randint(0, average_q_len * 2, (total_seqlen // average_q_len * 2,)).cuda()
last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0]
cu_seqlens = cu_seqlens[:last_seq_id]
if cu_seqlens.sum() < total_seqlen:
cu_seqlens = torch.cat([cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()])
cu_seqlens_cumsum = torch.cumsum(cu_seqlens, dim=0)
cu_seqlens_k_cumsum = torch.cumsum(cu_seqlens // kv_stride, dim=0)
cu_seqlens_qs = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_cumsum[:-1]])
cu_seqlens_ks = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_k_cumsum[:-1]])
cu_seqlens_qe = cu_seqlens_cumsum.clone()
cu_seqlens_ke = cu_seqlens_k_cumsum.clone()
cu_seqlens_ks_for_each_q = cal_cu_seqlen_ks_for_q(
cu_seqlens_qs=cu_seqlens_qs,
cu_seqlens_qe=cu_seqlens_qe,
cu_seqlens_ks=cu_seqlens_ks,
seq_len=total_seqlen,
)
cu_seqlens_ke_for_each_q = cal_cu_seqlen_ke_for_q(
cu_seqlens_qs=cu_seqlens_qs,
cu_seqlens_qe=cu_seqlens_qe,
cu_seqlens_ks=cu_seqlens_ks,
cu_seqlens_ke=cu_seqlens_ke,
q_start_idxs=torch.zeros_like(cu_seqlens_qs),
seq_len=total_seqlen,
kv_stride=kv_stride,
)
assert per_cp_seqlen % 2 == 0
per_chunk_seqlen = per_cp_seqlen // 2
slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen)
slice_long = slice(
total_seqlen - (cp_rank + 1) * per_chunk_seqlen,
total_seqlen - cp_rank * per_chunk_seqlen,
)
ks = torch.cat([
cu_seqlens_ks_for_each_q[slice_short],
cu_seqlens_ks_for_each_q[slice_long],
])
ke = torch.cat([
cu_seqlens_ke_for_each_q[slice_short],
cu_seqlens_ke_for_each_q[slice_long],
])
assert len(ks) == len(ke) == per_cp_seqlen
return ks, ke
def print_red_warning(message):
print(f"\033[31mWARNING: {message}\033[0m")
def calc_sim(x, y, name="tensor"):
x, y = x.data.double(), y.data.double()
denominator = (x * x + y * y).sum()
if denominator == 0:
print_red_warning(f'{name} all zero')
return 1
sim = 2 * (x * y).sum() / denominator
return sim
def assert_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
sim = calc_sim(x, y, name)
diff = 1. - sim
if not (0 <= diff <= eps):
print_red_warning(f'{name} Error: {diff}')
if raise_assert:
assert False # noqa: B011
if __name__ == "__main__":
seq_len = 32768
cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda")
last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0]
cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0)
cu_seqlens_qs = torch.cat(
[torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum])
cu_seqlens_qe = torch.cat(
[cu_seqlens_cumsum,
torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len])
from tilelang.profiler import do_bench
fn = lambda: cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len) # noqa: E731
    ms = do_bench(fn, warmup=25, rep=100)
    print(f"cal_seq_idx_for_q: {ms:.3f} ms")