# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
import fla
if parse(fla.__version__) < parse("0.2.1"):
    # prepare_lens/prepare_chunk_indices are required by the variable-length (offsets) path below
    from fla.ops.common.utils import prepare_chunk_indices, prepare_lens, prepare_token_indices
else:
    from fla.ops.utils import prepare_chunk_indices, prepare_lens, prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
# if USE_BLOCK_COUNTS:
# NS = tl.load(block_counts + (bos + i_t) * H + i_h)
# else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
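# The selected-block loop above is a FlashAttention-style online softmax. The helper
# below is a minimal PyTorch sketch of that streaming update for one query position and
# one head group; it is illustrative only (the name and arguments are hypothetical) and
# is never called by the kernels.
def _online_softmax_sketch(q_blk, kv_blocks):
    # q_blk: [G, BK]; kv_blocks: list of (k_blk [BS, BK], v_blk [BS, BV]) pairs.
    G = q_blk.shape[0]
    b_o = torch.zeros(G, kv_blocks[0][1].shape[-1])
    b_m = torch.full((G,), float('-inf'))
    b_acc = torch.zeros(G)
    for k_blk, v_blk in kv_blocks:
        s = q_blk.float() @ k_blk.float().t()          # [G, BS] attention scores
        m_new = torch.maximum(b_m, s.max(-1).values)   # running row-wise max
        r = torch.exp(b_m - m_new)                     # rescaling factor for old stats
        p = torch.exp(s - m_new[:, None])              # unnormalized probabilities
        b_acc = b_acc * r + p.sum(-1)                  # running normalizer
        b_o = b_o * r[:, None] + p @ v_blk.float()     # running weighted value sum
        b_m = m_new
    o = b_o / b_acc[:, None]                           # normalized output (cf. b_o_slc)
    lse = b_m + torch.log(b_acc)                       # log-sum-exp (cf. lse_slc)
    return o, lse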
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(
q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
ctx.scale = scale
return o.to(q.dtype)
@staticmethod
@contiguous
@autocast_custom_bwd
def backward(ctx, do_slc, do_swa):
q, k, v, o_slc, lse_slc, o_swa, lse_swa = ctx.saved_tensors
dq, dk, dv = parallel_nsa_bwd(
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
do_slc=do_slc,
do_swa=do_swa,
block_indices=ctx.block_indices,
block_counts=ctx.block_counts,
block_size=ctx.block_size,
window_size=ctx.window_size,
scale=ctx.scale,
offsets=ctx.offsets,
token_indices=ctx.token_indices)
return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension cannot be larger than 256 (128 on pre-Hopper GPUs)"
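# Grid layout below: axis 0 walks query positions, axis 1 walks value blocks of size BV,
# and axis 2 walks (batch, KV-head) pairs, so each program handles one query token for
# one head group.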
grid = (T, NV, B * H)
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
@triton.heuristics({'USE_OFFSETS': lambda args: args['offsets'] is not None})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dkv(q, k, v, lse_slc, lse_swa, delta_slc, delta_swa, do_slc, do_swa, dk,
dv, block_mask, offsets, chunk_indices, scale, T, B: tl.constexpr,
H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr,
V: tl.constexpr, M: tl.constexpr, BS: tl.constexpr,
WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
USE_OFFSETS: tl.constexpr):
i_v, i_s, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_s = tl.load(chunk_indices + i_s * 2).to(tl.int32), tl.load(chunk_indices + i_s * 2 +
1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_s * BS, 0), (BS, BK),
(1, 0))
p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV),
(BS, BV), (1, 0))
p_dk = tl.make_block_ptr(dk + (i_v * B * T * H + bos * H + i_h) * K, (T, K), (H * K, 1),
(i_s * BS, 0), (BS, BK), (1, 0))
p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_s * BS, i_v * BV),
(BS, BV), (1, 0))
# [BS, BK]
b_k = tl.load(p_k, boundary_check=(0, 1))
b_dk = tl.zeros([BS, BK], dtype=tl.float32)
# [BS, BV]
b_v = tl.load(p_v, boundary_check=(0, 1))
b_dv = tl.zeros([BS, BV], dtype=tl.float32)
for i in range(i_s * BS, T):
b_m_slc = tl.load(block_mask + (bos + i) * H * M + i_h * M + i_s)
if b_m_slc:
p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_do_slc = tl.make_block_ptr(do_slc + (bos + i) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G)
p_delta_slc = delta_slc + (bos + i) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_do_slc = tl.load(p_do_slc, boundary_check=(0, 1))
# [G]
b_lse_slc = tl.load(p_lse_slc)
b_delta_slc = tl.load(p_delta_slc)
# [BS, G]
b_s_slc = tl.dot(b_k, tl.trans(b_q))
b_p_slc = tl.exp(b_s_slc - b_lse_slc[None, :])
b_p_slc = tl.where((i >= (i_s * BS + tl.arange(0, BS)))[:, None], b_p_slc, 0)
# [BS, G] @ [G, BV] -> [BS, BV]
b_dv += tl.dot(b_p_slc.to(b_do_slc.dtype), b_do_slc)
# [BS, BV] @ [BV, G] -> [BS, G]
b_dp_slc = tl.dot(b_v, tl.trans(b_do_slc))
# [BS, G]
b_ds_slc = b_p_slc * (b_dp_slc - b_delta_slc[None, :])
# [BS, G] @ [G, BK] -> [BS, BK]
b_dk += tl.dot(b_ds_slc.to(b_q.dtype), b_q)
if WS > 0:
o_s = i_s * BS + tl.arange(0, BS)
if max(i_s * BS, i - WS + 1) < min((i_s + 1) * BS, i + 1):
p_q = tl.make_block_ptr(q + (bos + i) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0),
(G, BK), (1, 0))
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_do_swa = tl.make_block_ptr(do_swa + (bos + i) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G)
p_delta_swa = delta_swa + (bos + i) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_do_swa = tl.load(p_do_swa, boundary_check=(0, 1))
# [G]
b_lse_swa = tl.load(p_lse_swa)
b_delta_swa = tl.load(p_delta_swa)
# [BS, G]
b_s_swa = tl.dot(b_k, tl.trans(b_q))
b_p_swa = tl.exp(b_s_swa - b_lse_swa[None, :])
b_p_swa = tl.where(((i >= o_s) & ((i - WS) < o_s))[:, None], b_p_swa, 0)
# [BS, G] @ [G, BV] -> [BS, BV]
b_dv += tl.dot(b_p_swa.to(b_do_swa.dtype), b_do_swa)
# [BS, BV] @ [BV, G] -> [BS, G]
b_dp_swa = tl.dot(b_v, tl.trans(b_do_swa))
# [BS, G]
b_ds_swa = b_p_swa * (b_dp_swa - b_delta_swa[None, :])
# [BS, G] @ [G, BK] -> [BS, BK]
b_dk += tl.dot(b_ds_swa.to(b_q.dtype), b_q)
tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics(
{'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)})
@triton.jit
def parallel_nsa_kernel_mask(block_indices, block_counts, block_mask, T: tl.constexpr,
H: tl.constexpr, S: tl.constexpr, BS: tl.constexpr, NS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
i_t, i_b, i_hs = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_h, i_s = i_hs // S, i_hs % S
b_i = tl.load(block_indices + i_b * T * H * S + i_t * H * S + i_h * S + i_s)
if USE_BLOCK_COUNTS:
b_m = b_i * BS <= i_t and i_s < tl.load(block_counts + i_b * T * H + i_t * H + i_h)
else:
b_m = b_i * BS <= i_t
if b_i < NS and b_i >= 0:
tl.store(block_mask + i_b * T * H * NS + i_t * H * NS + i_h * NS + b_i,
b_m.to(block_mask.dtype.element_ty))
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor)
})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
)
@triton.jit(do_not_specialize=['T'])
def parallel_nsa_bwd_kernel_dq(q, k, v, lse_slc, delta_slc, do_slc, lse_swa, delta_swa, do_swa, dq,
scale, block_indices, block_counts, offsets, token_indices, T,
B: tl.constexpr, H: tl.constexpr, HQ: tl.constexpr, G: tl.constexpr,
K: tl.constexpr, V: tl.constexpr, S: tl.constexpr, BS: tl.constexpr,
WS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,
USE_OFFSETS: tl.constexpr, USE_BLOCK_COUNTS: tl.constexpr):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 +
1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
q += (bos + i_t) * HQ * K
do_slc += (bos + i_t) * HQ * V
lse_slc += (bos + i_t) * HQ
delta_slc += (bos + i_t) * HQ
if WS > 0:
do_swa += (bos + i_t) * HQ * V
lse_swa += (bos + i_t) * HQ
delta_swa += (bos + i_t) * HQ
dq += (i_v * B * T + bos + i_t) * HQ * K
block_indices += (bos + i_t) * H * S + i_h * S
if USE_BLOCK_COUNTS:
NS = tl.load(block_counts + (bos + i_t) * H + i_h)
else:
NS = S
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
p_q = tl.make_block_ptr(q, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
p_dq = tl.make_block_ptr(dq, (HQ, K), (K, 1), (i_h * G, 0), (G, BK), (1, 0))
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_do_slc = tl.make_block_ptr(do_slc, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_slc = lse_slc + i_h * G + tl.arange(0, G)
p_delta_slc = delta_slc + i_h * G + tl.arange(0, G)
# [G, BV]
b_do_slc = tl.load(p_do_slc, boundary_check=(0, 1))
# [G]
b_lse_slc = tl.load(p_lse_slc)
b_delta_slc = tl.load(p_delta_slc)
# [G, BK]
b_dq_slc = tl.zeros([G, BK], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (V, T), (1, H * V), (i_v * BV, i_s), (BV, BS), (0, 1))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BV, BS]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_p_slc = tl.exp(b_s_slc - b_lse_slc[:, None])
b_p_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p_slc, 0)
# [G, BV] @ [BV, BS] -> [G, BS]
b_dp_slc = tl.dot(b_do_slc, b_v_slc)
b_ds_slc = b_p_slc * (b_dp_slc.to(tl.float32) - b_delta_slc[:, None])
# [G, BS] @ [BS, BK] -> [G, BK]
b_dq_slc += tl.dot(b_ds_slc.to(b_k_slc.dtype), tl.trans(b_k_slc))
b_dq_slc *= scale
if WS > 0:
p_do_swa = tl.make_block_ptr(do_swa, (HQ, V), (V, 1), (i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + i_h * G + tl.arange(0, G)
p_delta_swa = delta_swa + i_h * G + tl.arange(0, G)
# [G, BV]
b_do_swa = tl.load(p_do_swa, boundary_check=(0, 1))
# [G]
b_lse_swa = tl.load(p_lse_swa)
b_delta_swa = tl.load(p_delta_swa)
# [G, BK]
b_dq_swa = tl.zeros([G, BK], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_swa = tl.make_block_ptr(v, (V, T), (1, H * V), (i_v * BV, i_s), (BV, BS), (0, 1))
# [BK, BS]
b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1))
# [BV, BS]
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_p_swa = tl.exp(b_s_swa - b_lse_swa[:, None])
b_p_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_p_swa, 0)
# [G, BV] @ [BV, BS] -> [G, BS]
b_dp_swa = tl.dot(b_do_swa, b_v_swa)
b_ds_swa = b_p_swa * (b_dp_swa.to(tl.float32) - b_delta_swa[:, None])
# [G, BS] @ [BS, BK] -> [G, BK]
b_dq_swa += tl.dot(b_ds_swa.to(b_k_swa.dtype), tl.trans(b_k_swa))
b_dq_swa *= scale
if WS == 0:
tl.store(p_dq, b_dq_slc.to(p_dq.dtype.element_ty), boundary_check=(0, 1))
else:
tl.store(p_dq, (b_dq_slc + b_dq_swa).to(p_dq.dtype.element_ty), boundary_check=(0, 1))
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 +
1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
if USE_BLOCK_COUNTS:
NS = tl.load(block_counts + (bos + i_t) * H + i_h)
else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
if WS > 0:
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_swa = tl.zeros([G, BV], dtype=tl.float32)
b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_swa = tl.zeros([G], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1))
# [BS, BV]
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf'))
# [G]
b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa
b_r_swa = tl.exp(b_mp_swa - b_m_swa)
# [G, BS]
b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None])
# [G]
b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1)
# [G, BV]
b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa)
b_mp_swa = b_m_swa
b_o_swa = b_o_swa / b_acc_swa[:, None]
b_m_swa += tl.log(b_acc_swa)
tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty))
@triton.jit
def parallel_nsa_bwd_kernel_preprocess(o, do, delta, B: tl.constexpr, V: tl.constexpr):
i_n = tl.program_id(0)
o_d = tl.arange(0, B)
m_d = o_d < V
b_o = tl.load(o + i_n * V + o_d, mask=m_d, other=0)
b_do = tl.load(do + i_n * V + o_d, mask=m_d, other=0).to(tl.float32)
b_delta = tl.sum(b_o * b_do)
tl.store(delta + i_n, b_delta.to(delta.dtype.element_ty))
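# A rough PyTorch equivalent of the kernel above (illustrative only):
#     delta = (o.float() * do.float()).sum(-1)
# i.e. delta is the row-wise dot product of the output and its gradient, which the
# backward kernels use as the softmax-gradient correction term.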
def parallel_nsa_block_mask(
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
offsets: torch.LongTensor,
block_size: int,
):
B, T, H, S = block_indices.shape
BS = block_size
if offsets is not None:
NS = triton.cdiv(prepare_lens(offsets).max().item(), BS)
else:
NS = triton.cdiv(T, BS)
block_mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device)
parallel_nsa_kernel_mask[(T, B, H * S)](
block_indices=block_indices,
block_counts=block_counts,
block_mask=block_mask,
T=T,
H=H,
S=S,
BS=BS,
NS=NS)
return block_mask
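# A naive PyTorch sketch of what the mask kernel computes (illustrative only, with a
# hypothetical name and an explicit NS argument); not used by the Triton path:
def _naive_block_mask_sketch(block_indices, block_counts, block_size, NS):
    B, T, H, S = block_indices.shape
    mask = torch.zeros(B, T, H, NS, dtype=torch.bool, device=block_indices.device)
    for b in range(B):
        for t in range(T):
            for h in range(H):
                n = int(block_counts[b, t, h]) if isinstance(block_counts, torch.Tensor) else S
                for s in range(S):
                    i = int(block_indices[b, t, h, s])
                    # a block is visible if it exists, starts at or before the query token,
                    # and falls within this token's selected-block count
                    if 0 <= i < NS and i * block_size <= t and s < n:
                        mask[b, t, h, i] = True
    return mask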
def parallel_nsa_bwd_preprocess(o: torch.Tensor, do: torch.Tensor):
V = o.shape[-1]
delta = torch.empty_like(o[..., 0], dtype=torch.float32)
parallel_nsa_bwd_kernel_preprocess[(delta.numel(),)](
o=o,
do=do,
delta=delta,
B=triton.next_power_of_2(V),
V=V,
)
return delta
def parallel_nsa_bwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
o_slc: torch.Tensor,
lse_slc: torch.Tensor,
do_slc: torch.Tensor,
o_swa: torch.Tensor,
lse_swa: torch.Tensor,
do_swa: torch.Tensor,
block_indices: torch.Tensor,
block_counts: Union[torch.LongTensor, int],
block_size: int = 64,
window_size: int = 0,
scale: float = None,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
BK = triton.next_power_of_2(K)
BV = min(128, triton.next_power_of_2(v.shape[-1]))
NV = triton.cdiv(V, BV)
delta_slc = parallel_nsa_bwd_preprocess(o_slc, do_slc)
delta_swa = parallel_nsa_bwd_preprocess(o_swa, do_swa) if window_size > 0 else None
dq = torch.empty(NV, *q.shape, dtype=q.dtype if NV == 1 else torch.float, device=q.device)
grid = (T, NV, B * H)
parallel_nsa_bwd_kernel_dq[grid](
q=q,
k=k,
v=v,
lse_slc=lse_slc,
delta_slc=delta_slc,
do_slc=do_slc,
lse_swa=lse_swa,
delta_swa=delta_swa,
do_swa=do_swa,
dq=dq,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
scale=scale,
T=T,
B=B,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV)
dq = dq.sum(0)
if offsets is not None:
chunk_indices = prepare_chunk_indices(offsets, BS)
NS = len(chunk_indices)
else:
chunk_indices = None
NS = triton.cdiv(T, BS)
# [B, T, H, M]
block_mask = parallel_nsa_block_mask(block_indices, block_counts, offsets, block_size)
dk = torch.empty(NV, *k.shape, dtype=k.dtype if NV == 1 else torch.float, device=q.device)
dv = torch.empty(v.shape, dtype=v.dtype, device=q.device)
grid = (NV, NS, B * H)
parallel_nsa_bwd_kernel_dkv[grid](
q=q,
k=k,
v=v,
lse_slc=lse_slc,
lse_swa=lse_swa,
delta_slc=delta_slc,
delta_swa=delta_swa,
do_slc=do_slc,
do_swa=do_swa,
dk=dk,
dv=dv,
block_mask=block_mask,
offsets=offsets,
chunk_indices=chunk_indices,
scale=scale,
T=T,
B=B,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
M=block_mask.shape[-1],
BS=BS,
WS=WS,
BK=BK,
BV=BV)
dk = dk.sum(0)
return dq, dk, dv
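# Both backward kernels rely on the standard softmax-attention backward identities:
#     P = softmax(S),  dP = dO @ V^T,  delta = rowsum(P * dP) = rowsum(O * dO),
#     dS = P * (dP - delta),  dQ = scale * dS @ K,  dK = scale * dS^T @ Q,  dV = P^T @ dO,
# where S = scale * Q @ K^T and delta is produced by the preprocess kernel above.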
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
@staticmethod
@contiguous
@autocast_custom_bwd
def backward(ctx, do_slc, do_swa):
q, k, v, o_slc, lse_slc, o_swa, lse_swa = ctx.saved_tensors
dq, dk, dv = parallel_nsa_bwd(
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
do_slc=do_slc,
do_swa=do_swa,
block_indices=ctx.block_indices,
block_counts=ctx.block_counts,
block_size=ctx.block_size,
window_size=ctx.window_size,
scale=ctx.scale,
offsets=ctx.offsets,
token_indices=ctx.token_indices)
return dq.to(q), dk.to(k), dv.to(v), None, None, None, None, None, None, None, None
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor of shape `[B, T, H]` (`[B, H, T]` if `head_first=True`) is provided,
each token can select a different number of blocks.
If not provided, it defaults to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
return o
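# A hypothetical usage sketch (shapes are illustrative and not taken from the tests below):
#     q = torch.randn(1, 1024, 64, 128, dtype=torch.bfloat16, device='cuda')
#     k = torch.randn(1, 1024, 4, 128, dtype=torch.bfloat16, device='cuda')
#     v = torch.randn(1, 1024, 4, 128, dtype=torch.bfloat16, device='cuda')
#     g_slc = torch.rand(1, 1024, 64, dtype=torch.bfloat16, device='cuda')
#     g_swa = torch.rand(1, 1024, 64, dtype=torch.bfloat16, device='cuda')
#     block_indices = torch.randint(0, 16, (1, 1024, 4, 16), device='cuda')
#     o = parallel_nsa(q, k, v, g_slc, g_swa, block_indices, block_size=64)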
if __name__ == "__main__":
B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
torch.random.manual_seed(0)
q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
for b in range(B):
for t in range(T):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda')
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
)
ref.backward(do)
ref_dq, q.grad = q.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dv, v.grad = v.grad.clone(), None
ref_dg_slc, g_slc.grad = g_slc.grad.clone(), None
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_size=block_size,
block_counts=block_counts,
)
print("tri", tri)
print("ref", ref)
tri.backward(do)
tri_dq, q.grad = q.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dv, v.grad = v.grad.clone(), None
tri_dg_slc, g_slc.grad = g_slc.grad.clone(), None
# assert_close(" o", ref, tri, 0.004)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dq, tri_dq, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dk, tri_dk, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(ref_dg_slc, tri_dg_slc, atol=1e-2, rtol=1e-2)
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1]],
key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
# if USE_BLOCK_COUNTS:
# NS = tl.load(block_counts + (bos + i_t) * H + i_h)
# else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o, lse = parallel_nsa_fwd(
q=q, k=k, v=v, block_indices=block_indices, block_size=block_size, scale=scale)
ctx.save_for_backward(q, k, v, o, lse)
ctx.block_indices = block_indices
ctx.block_size = block_size
ctx.scale = scale
return o.to(q.dtype)
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension cannot be larger than 256 (128 on pre-Hopper GPUs)"
grid = (T, NV, B * H)
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor of shape `[B, T, H]` (`[B, H, T]` if `head_first=True`) is provided,
each token can select a different number of blocks.
If not provided, it defaults to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
return o
if __name__ == "__main__":
B, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 32, 1, 32, torch.float16
torch.random.manual_seed(0)
q = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda').requires_grad_(True)
k = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
v = torch.randn((B, T, H, D), dtype=dtype, device='cuda').requires_grad_(True)
g_slc = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.ones((B, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((B, T, HQ, D), dtype=dtype, device='cuda')
block_indices = torch.full((B, T, H, S), T, dtype=torch.long, device='cuda')
for b in range(B):
for t in range(T):
for h in range(H):
i_i = torch.randperm(max(1, (t // block_size)))[:S]
block_indices[b, t, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (B, T, H), device='cuda')
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
)
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_size=block_size,
block_counts=block_counts,
)
print("tri", tri)
print("ref", ref)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
@triton.heuristics({
'USE_OFFSETS': lambda args: args['offsets'] is not None,
'USE_BLOCK_COUNTS': lambda args: isinstance(args['block_counts'], torch.Tensor),
})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=['BS', 'BK', 'BV'],
)
@triton.jit
def parallel_nsa_fwd_kernel(q, k, v, o_slc, o_swa, lse_slc, lse_swa, scale, block_indices,
block_counts, offsets, token_indices, T, H: tl.constexpr,
HQ: tl.constexpr, G: tl.constexpr, K: tl.constexpr, V: tl.constexpr,
S: tl.constexpr, BS: tl.constexpr, WS: tl.constexpr, BK: tl.constexpr,
BV: tl.constexpr, USE_OFFSETS: tl.constexpr,
USE_BLOCK_COUNTS: tl.constexpr):
i_t, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if USE_OFFSETS:
i_n, i_t = tl.load(token_indices + i_t * 2).to(tl.int32), tl.load(token_indices + i_t * 2 +
1).to(tl.int32)
bos, eos = tl.load(offsets + i_n).to(tl.int32), tl.load(offsets + i_n + 1).to(tl.int32)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
k += (bos * H + i_h) * K
v += (bos * H + i_h) * V
block_indices += (bos + i_t) * H * S + i_h * S
if USE_BLOCK_COUNTS:
NS = tl.load(block_counts + (bos + i_t) * H + i_h)
else:
NS = S
p_q = tl.make_block_ptr(q + (bos + i_t) * HQ * K, (HQ, K), (K, 1), (i_h * G, 0), (G, BK),
(1, 0))
# the Q block is kept in the shared memory throughout the whole kernel
# [G, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
b_q = (b_q * scale).to(b_q.dtype)
p_o_slc = tl.make_block_ptr(o_slc + (bos + i_t) * HQ * V, (HQ, V), (V, 1), (i_h * G, i_v * BV),
(G, BV), (1, 0))
p_lse_slc = lse_slc + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_slc = tl.zeros([G, BV], dtype=tl.float32)
b_m_slc = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_slc = tl.zeros([G], dtype=tl.float32)
for i in range(NS):
i_s = tl.load(block_indices + i).to(tl.int32) * BS
if i_s <= i_t and i_s >= 0:
p_k_slc = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_slc = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_slc = tl.load(p_k_slc, boundary_check=(0, 1))
# [BS, BV]
b_v_slc = tl.load(p_v_slc, boundary_check=(0, 1))
# [G, BS]
b_s_slc = tl.dot(b_q, b_k_slc)
b_s_slc = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_slc, float('-inf'))
# [G]
b_m_slc, b_mp_slc = tl.maximum(b_m_slc, tl.max(b_s_slc, 1)), b_m_slc
b_r_slc = tl.exp(b_mp_slc - b_m_slc)
# [G, BS]
b_p_slc = tl.exp(b_s_slc - b_m_slc[:, None])
# [G]
b_acc_slc = b_acc_slc * b_r_slc + tl.sum(b_p_slc, 1)
# [G, BV]
b_o_slc = b_o_slc * b_r_slc[:, None] + tl.dot(b_p_slc.to(b_q.dtype), b_v_slc)
b_mp_slc = b_m_slc
b_o_slc = b_o_slc / b_acc_slc[:, None]
b_m_slc += tl.log(b_acc_slc)
tl.store(p_o_slc, b_o_slc.to(p_o_slc.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_slc, b_m_slc.to(p_lse_slc.dtype.element_ty))
if WS > 0:
p_o_swa = tl.make_block_ptr(o_swa + (bos + i_t) * HQ * V, (HQ, V), (V, 1),
(i_h * G, i_v * BV), (G, BV), (1, 0))
p_lse_swa = lse_swa + (bos + i_t) * HQ + i_h * G + tl.arange(0, G)
# [G, BV]
b_o_swa = tl.zeros([G, BV], dtype=tl.float32)
b_m_swa = tl.full([G], float('-inf'), dtype=tl.float32)
b_acc_swa = tl.zeros([G], dtype=tl.float32)
for i_s in range(max(0, i_t - WS + 1), i_t + 1, BS):
p_k_swa = tl.make_block_ptr(k, (K, T), (1, H * K), (0, i_s), (BK, BS), (0, 1))
p_v_swa = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_s, i_v * BV), (BS, BV), (1, 0))
# [BK, BS]
b_k_swa = tl.load(p_k_swa, boundary_check=(0, 1))
# [BS, BV]
b_v_swa = tl.load(p_v_swa, boundary_check=(0, 1))
# [G, BS]
b_s_swa = tl.dot(b_q, b_k_swa)
b_s_swa = tl.where((i_t >= (i_s + tl.arange(0, BS)))[None, :], b_s_swa, float('-inf'))
# [G]
b_m_swa, b_mp_swa = tl.maximum(b_m_swa, tl.max(b_s_swa, 1)), b_m_swa
b_r_swa = tl.exp(b_mp_swa - b_m_swa)
# [G, BS]
b_p_swa = tl.exp(b_s_swa - b_m_swa[:, None])
# [G]
b_acc_swa = b_acc_swa * b_r_swa + tl.sum(b_p_swa, 1)
# [G, BV]
b_o_swa = b_o_swa * b_r_swa[:, None] + tl.dot(b_p_swa.to(b_q.dtype), b_v_swa)
b_mp_swa = b_m_swa
b_o_swa = b_o_swa / b_acc_swa[:, None]
b_m_swa += tl.log(b_acc_swa)
tl.store(p_o_swa, b_o_swa.to(p_o_swa.dtype.element_ty), boundary_check=(0, 1))
tl.store(p_lse_swa, b_m_swa.to(p_lse_swa.dtype.element_ty))
def parallel_nsa_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Union[torch.LongTensor, int],
block_size: int,
window_size: int,
scale: float,
offsets: Optional[torch.LongTensor] = None,
token_indices: Optional[torch.LongTensor] = None,
):
B, T, H, K, V, S = *k.shape, v.shape[-1], block_indices.shape[-1]
HQ = q.shape[2]
G = HQ // H
BS = block_size
WS = window_size
if torch.cuda.get_device_capability()[0] >= 9:
BK = min(256, triton.next_power_of_2(K))
BV = min(256, triton.next_power_of_2(V))
else:
BK = min(128, triton.next_power_of_2(K))
BV = min(128, triton.next_power_of_2(V))
NK = triton.cdiv(K, BK)
NV = triton.cdiv(V, BV)
assert NK == 1, "The key dimension cannot be larger than 256 (128 on pre-Hopper GPUs)"
grid = (T, NV, B * H)
o_slc = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device)
o_swa = torch.empty(B, T, HQ, V, dtype=v.dtype, device=q.device) if window_size > 0 else None
lse_slc = torch.empty(B, T, HQ, dtype=torch.float, device=q.device)
lse_swa = torch.empty(B, T, HQ, dtype=torch.float, device=q.device) if window_size > 0 else None
parallel_nsa_fwd_kernel[grid](
q=q,
k=k,
v=v,
o_slc=o_slc,
o_swa=o_swa,
lse_slc=lse_slc,
lse_swa=lse_swa,
scale=scale,
block_indices=block_indices,
block_counts=block_counts,
offsets=offsets,
token_indices=token_indices,
T=T,
H=H,
HQ=HQ,
G=G,
K=K,
V=V,
S=S,
BS=BS,
WS=WS,
BK=BK,
BV=BV,
)
return o_slc, lse_slc, o_swa, lse_swa
@torch.compile
class ParallelNSAFunction(torch.autograd.Function):
@staticmethod
@contiguous
@autocast_custom_fwd
def forward(ctx, q, k, v, block_indices, block_counts, block_size, window_size, scale, offsets):
ctx.dtype = q.dtype
# 2-d sequence indices denoting the offsets of tokens in each sequence
# for example, if the passed `offsets` is [0, 2, 6],
# then there are 2 and 4 tokens in the 1st and 2nd sequences respectively, and `token_indices` will be
# [[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [1, 3]]
token_indices = prepare_token_indices(offsets) if offsets is not None else None
o_slc, lse_slc, o_swa, lse_swa = parallel_nsa_fwd(
q=q,
k=k,
v=v,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
window_size=window_size,
scale=scale,
offsets=offsets,
token_indices=token_indices)
ctx.save_for_backward(q, k, v, o_slc, lse_slc, o_swa, lse_swa)
ctx.block_indices = block_indices
ctx.block_counts = block_counts
ctx.offsets = offsets
ctx.token_indices = token_indices
ctx.block_size = block_size
ctx.window_size = window_size
ctx.scale = scale
return o_slc.to(q.dtype), o_swa.to(q.dtype) if o_swa is not None else o_swa
def parallel_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor of shape `[B, T, H]` (`[B, H, T]` if `head_first=True`) is provided,
each token can select a different number of blocks.
If not provided, it defaults to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
assert q.shape[2] % (k.shape[2] * 16) == 0, "Group size must be a multiple of 16 in NSA"
if isinstance(block_counts, int):
block_indices = block_indices[:, :, :, :block_counts]
block_counts = None
o_slc, o_swa = ParallelNSAFunction.apply(q, k, v, block_indices, block_counts, block_size,
window_size, scale, cu_seqlens)
if window_size > 0:
o = torch.addcmul(o_slc * g_slc.unsqueeze(-1), o_swa, g_swa.unsqueeze(-1))
else:
o = o_slc * g_slc.unsqueeze(-1)
if head_first:
o = rearrange(o, 'b t h d -> b h t d')
return o
if __name__ == "__main__":
N, T, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 64, 1, 32, torch.float16
torch.manual_seed(42)
# randomly split the sequence into N segments
offsets = torch.cat([
    torch.tensor([0], dtype=torch.long),
    # torch.arange(16, T) has T - 16 entries, so permute that many indices
    torch.arange(16, T)[torch.randperm(T - 16)[:N - 1]],
    torch.tensor([T], dtype=torch.long)
], 0).cuda().sort()[0]
# offsets.shape is [N+1]
# seq-first required for inputs with variable lengths
perm_q = torch.randperm(T, device='cuda')
perm_k = torch.randperm(T, device='cuda')
perm_v = torch.randperm(T, device='cuda')
q = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_q].view(1, T, 1, 1).expand(1, T, HQ, D).clone().requires_grad_(True)
k = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_k].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
v = torch.linspace(
0, 1, steps=T, dtype=dtype,
device='cuda')[perm_v].view(1, T, 1, 1).expand(1, T, H, D).clone().requires_grad_(True)
g_slc = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
g_swa = torch.rand((1, T, HQ), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((1, T, HQ, D), dtype=dtype, device='cuda')
token_indices = prepare_token_indices(offsets).tolist()
block_indices = torch.full((1, T, H, S), T, dtype=torch.long, device='cuda')
for i in range(T):
_, t = token_indices[i]
for h in range(H):
i_i = torch.randperm(max(1, triton.cdiv(t, block_size)))[:S]
block_indices[0, i, h, :len(i_i)] = i_i
block_indices = block_indices.sort(-1)[0]
block_counts = torch.randint(1, S + 1, (1, T, H), device='cuda')
ref = naive_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
tri = parallel_nsa(
q=q,
k=k,
v=v,
g_slc=g_slc,
g_swa=g_swa,
block_indices=block_indices,
block_counts=block_counts,
block_size=block_size,
cu_seqlens=offsets)
print("tri", tri)
print("ref", ref)
torch.testing.assert_close(ref, tri, atol=1e-2, rtol=1e-2)
# ruff: noqa
from typing import Optional
import torch
from typing import Union
from einops import rearrange, repeat
def naive_nsa(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g_slc: torch.Tensor,
g_swa: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: Optional[Union[torch.LongTensor, int]] = None,
block_size: int = 64,
window_size: int = 0,
scale: Optional[float] = None,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
Queries of shape `[B, T, HQ, K]` if `head_first=False` else `[B, HQ, T, K]`.
k (torch.Tensor):
Keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
GQA is enforced here. The ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >=16.
v (torch.Tensor):
Values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g_slc (torch.Tensor):
Gate score for selected attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
g_swa (torch.Tensor):
Gate score for sliding attention of shape `[B, T, HQ]` if `head_first=False` else `[B, HQ, T]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]` if `head_first=False` else `[B, H, T, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (Union[torch.LongTensor, int]):
Number of selected blocks for each token.
If a tensor of shape `[B, T, H]` (`[B, H, T]` if `head_first=True`) is provided,
each token can select a different number of blocks.
If not provided, it defaults to `S`. Default: `None`.
block_size (int):
Selected block size. Default: 64.
window_size (int):
Sliding window size. Default: 0.
scale (Optional[float]):
Scale factor for attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format. Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]` if `head_first=False` else `[B, HQ, T, V]`.
"""
if scale is None:
scale = k.shape[-1]**-0.5
if cu_seqlens is not None:
assert q.shape[0] == 1, "batch size must be 1 when cu_seqlens are provided"
if head_first:
raise RuntimeError(
"Sequences with variable lengths are not supported for head-first mode")
if head_first:
q, k, v, block_indices = map(lambda x: rearrange(x, 'b h t d -> b t h d'),
(q, k, v, block_indices))
g_slc, g_swa = map(lambda x: rearrange(x, 'b h t -> b t h'), (g_slc, g_swa))
if isinstance(block_counts, torch.Tensor):
block_counts = rearrange(block_counts, 'b h t -> b t h')
dtype = q.dtype
G = q.shape[2] // k.shape[2]
BS = block_size
S = block_indices.shape[-1]
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
if isinstance(block_counts, torch.Tensor):
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
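# c[n, h] is the selected-block slot (0..S-1) that position n of the flattened S*BS
# selection axis belongs to; it is compared against the per-token block count below to
# mask out slots beyond the number of actually selected blocks.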
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o_slc = torch.zeros_like(v)
o_swa = torch.zeros_like(v) if window_size > 0 else None
varlen = True
if cu_seqlens is None:
varlen = False
B, T = q.shape[:2]
cu_seqlens = torch.cat(
[block_indices.new_tensor(range(0, B * T, T)),
block_indices.new_tensor([B * T])])
for i in range(len(cu_seqlens) - 1):
if not varlen:
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = q[i], k[i], v[i], g_slc[i], g_swa[
i], block_indices[i]
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[i]
else:
s_b = block_counts
else:
T = cu_seqlens[i + 1] - cu_seqlens[i]
q_b, k_b, v_b, g_slc_b, g_swa_b, i_b = map(
lambda x: x[0][cu_seqlens[i]:cu_seqlens[i + 1]],
(q, k, v, g_slc, g_swa, block_indices))
if isinstance(block_counts, torch.Tensor):
s_b = block_counts[0][cu_seqlens[i]:cu_seqlens[i + 1]]
else:
s_b = block_counts
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [HQ]
g_slc_i = g_slc_b[i_q]
# [HQ]
g_swa_i = g_swa_b[i_q]
# [S*BS, HQ]
i_i = i_b[i_q]
# [HQ]
if isinstance(block_counts, torch.Tensor):
s_i = s_b[i_q]
else:
s_i = s_b
# [S*BS, HQ, -1]
k_i_slc, v_i_slc = map(
lambda x: x.gather(
0,
i_i.clamp(0, T - 1).unsqueeze(-1).expand(*i_i.shape, x.shape[-1])), (k_b, v_b))
# [S*BS, HQ]
attn_slc = torch.einsum('h d, n h d -> n h', q_i, k_i_slc).masked_fill(
torch.logical_or(i_i < 0, i_i > i_q) |
(c >= s_i if block_counts is not None else False), float('-inf')).softmax(0)
if not varlen:
o_slc[i, i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
else:
o_slc[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_slc,
v_i_slc) * g_slc_i.unsqueeze(-1)
if window_size > 0:
k_i_swa, v_i_swa = map(lambda x: x[max(0, i_q - window_size + 1):i_q + 1],
(k_b, v_b))
attn_swa = torch.einsum('h d, n h d -> n h', q_i, k_i_swa).softmax(0)
if not varlen:
o_swa[i, i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
else:
o_swa[0][cu_seqlens[i] + i_q] = torch.einsum('n h, n h v -> h v', attn_swa,
v_i_swa) * g_swa_i.unsqueeze(-1)
if head_first:
o_slc = rearrange(o_slc, 'b t h d -> b h t d')
o_swa = rearrange(o_swa, 'b t h d -> b h t d') if o_swa is not None else None
return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
def naive_nsa_simple(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: torch.LongTensor,
block_size: int = 64,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, HQ, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
GQA is enforced here: the ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >= 16.
v (torch.Tensor):
values of shape `[B, T, H, V]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, T, H, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (torch.LongTensor):
Block counts of shape `[B, T, H]`.
block_size (int):
Selected block size. Default: 64.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HQ, V]`.
"""
scale = k.shape[-1]**-0.5
dtype = q.dtype
HQ = q.shape[2]
H = k.shape[2]
D = k.shape[-1]
G = HQ // H
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(v)
B, T = q.shape[:2]
for i in range(B):
q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
# [T, HQ, S, BS] -> [T, HQ, S*BS]
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, HQ, S*BS] -> [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
for i_q in range(T):
# [HQ, D]
q_i = q_b[i_q] * scale
# [S*BS, HQ] -> represents selected blocks for each query token
i_i = i_b[i_q]
# [HQ] -> represents the number of selected blocks for each query token
s_i = s_b[i_q]
k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
v_i = torch.zeros((S * BS, HQ, v_b.shape[-1]), device=v_b.device, dtype=v_b.dtype)
for h in range(HQ):
for t in range(SELECTED_BLOCKS_SIZE):
selected_block_index = i_i[t, h]
k_i[t, h] = k_b[selected_block_index, h, :]
v_i[t, h] = v_b[selected_block_index, h, :]
# [S*BS, HQ]
attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
attn = attn.masked_fill((i_i > i_q) | (c >= s_i), float('-inf'))
attn = torch.softmax(attn, dim=0)
o[i, i_q] = torch.einsum('n h, n h v -> h v', attn, v_i)
return o.to(dtype)
def naive_nsa_simple_inference(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_indices: torch.LongTensor,
block_counts: torch.LongTensor,
block_size: int = 64,
) -> torch.Tensor:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, 1, HQ, K]` (a single decoding step).
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
GQA is enforced here: the ratio of query heads (HQ) to key/value heads (H) must be a power of 2 and >= 16.
v (torch.Tensor):
values of shape `[B, T, H, V]`.
block_indices (torch.LongTensor):
Block indices of shape `[B, 1, H, S]`.
`S` is the maximum number of selected blocks for each query token, which is set to 16 in the paper.
block_counts (torch.LongTensor):
Block counts of shape `[B, 1, H]`.
block_size (int):
Selected block size. Default: 64.
Returns:
o (torch.Tensor):
Outputs of shape `[B, 1, HQ, V]`.
"""
scale = k.shape[-1]**-0.5
dtype = q.dtype
HQ = q.shape[2]
H = k.shape[2]
D = k.shape[-1]
G = HQ // H
BS = block_size
S = block_indices.shape[-1]
SELECTED_BLOCKS_SIZE = S * BS
k, v, block_indices = (repeat(x, 'b t h d -> b t (h g) d', g=G) for x in (k, v, block_indices))
block_counts = repeat(block_counts, 'b t h -> b t (h g)', g=G)
c = torch.arange(S).repeat_interleave(BS).unsqueeze(1).expand(-1, q.shape[2]).to(q.device)
q, k, v = map(lambda x: x.float(), (q, k, v))
o = torch.zeros_like(q)
B, T = q.shape[:2]
for i in range(B):
q_b, k_b, v_b, i_b, s_b = q[i], k[i], v[i], block_indices[i], block_counts[i]
# [T, HQ, S, BS] -> [T, HQ, S*BS]
i_b = i_b.unsqueeze(-1) * BS + i_b.new_tensor(range(BS))
# [T, HQ, S*BS] -> [T, S*BS, HQ]
i_b = i_b.view(T, block_indices.shape[2], -1).transpose(1, 2)
# [HQ, D]
q_i = q_b[0] * scale
# [S*BS, HQ] -> represents selected blocks for each query token
i_i = i_b[0]
# [HQ] -> represents the number of selected blocks for each query token
s_i = s_b[0]
k_i = torch.zeros((S * BS, HQ, D), device=k_b.device, dtype=k_b.dtype)
v_i = torch.zeros((S * BS, HQ, v_b.shape[-1]), device=v_b.device, dtype=v_b.dtype)
for h in range(HQ):
for t in range(SELECTED_BLOCKS_SIZE):
selected_block_index = i_i[t, h]
k_i[t, h] = k_b[selected_block_index, h, :]
v_i[t, h] = v_b[selected_block_index, h, :]
# [S*BS, HQ]
attn = torch.einsum('h d, n h d -> n h', q_i, k_i)
attn = attn.masked_fill((c >= s_i), float('-inf'))
attn = torch.softmax(attn, dim=0)
o[i, 0] = torch.einsum('n h, n h v -> h v', attn, v_i)
return o.to(dtype)
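# Hedged usage sketch (assumed toy sizes, not part of the original file):
#   B, T, H, HQ, D, S, BS = 1, 128, 1, 16, 32, 4, 32
#   q = torch.randn(B, T, HQ, D)
#   k = torch.randn(B, T, H, D)
#   v = torch.randn(B, T, H, D)
#   block_indices = torch.arange(S).view(1, 1, 1, S).expand(B, T, H, S)  # always includes block 0
#   block_counts = torch.randint(1, S + 1, (B, T, H))
#   out = naive_nsa_simple(q, k, v, block_indices, block_counts, block_size=BS)  # [B, T, HQ, D]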
git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e
# ruff: noqa
import tilelang.testing
from example_tilelang_nsa_fwd import main as main_fwd
from example_tilelang_nsa_decode import main as main_fwd_decode
def test_example_tilelang_nsa_fwd():
main_fwd()
def test_example_tilelang_nsa_fwd_decode():
main_fwd_decode()
if __name__ == "__main__":
tilelang.testing.main()
## Directory Structure
```
deepseek_v32/
├── README.md # This file
├── figures/ # Figures and diagrams
├── inference/ # Inference implementation folder
├── fp8_lighting_indexer.py # FP8 lightning indexer
├── sparse_mla_bwd.py # Sparse MLA backward implementation
├── sparse_mla_fwd.py # Sparse MLA forward implementation
├── sparse_mla_fwd_pipelined.py # Pipelined implementation of sparse MLA forward pass
└── topk_selector.py # Top-k selector implementation
```
## File Descriptions
### Architecture Overview
![DeepSeek V3.2 Architecture](./figures/v32_arch.png)
The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations:
1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py`, `sparse_mla_fwd_pipelined.py`, and `sparse_mla_bwd.py`) - Core attention mechanism implementation with sparse MLA (Multi-head Latent Attention) forward and backward passes
### Lightning Indexer
Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector.
The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices:
```python
T.gemm(
index_k_shared,
index_q_shared,
s,
transpose_B=True,
clear_accum=True,
policy=T.GemmWarpPolicy.FullCol,
)
```
After the matmul, we apply ReLU and aggregate across heads with learned weights:
```python
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
s_reshaped[bn_i, bq_i, h_i] = (
T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
) * index_k_scale_fragment[bn_i]
T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
```
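In plain PyTorch, the same per-head ReLU and weighted reduction reads roughly as follows (cf. `ref_fp8_mqa_logits` in the accompanying test; the per-key scale only appears because the kernel works on FP8-quantized keys):
```python
# score: [heads, seq_len, seq_len_kv], weights: [seq_len, heads], k_scale: [seq_len_kv]
logits = (score.relu() * weights.t().unsqueeze(-1)).sum(dim=0) * k_scale[None, :]
```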
The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions:
```python
for bq_i in T.serial(block_Q):
cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
```
The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training.
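On the reference side, the same per-token bounds reduce to a boolean mask over KV positions. The snippet below is a simplified restatement of the masking done in `ref_fp8_mqa_logits` in the accompanying test, not kernel code:
```python
import torch

def mask_logits_by_bounds(logits: torch.Tensor, ks: torch.Tensor, ke: torch.Tensor) -> torch.Tensor:
    # logits: [seq_len, seq_len_kv]; ks/ke: [seq_len] per-query start/end KV positions
    kv_pos = torch.arange(logits.shape[-1], device=logits.device)[None, :]
    valid = (kv_pos >= ks[:, None]) & (kv_pos < ke[:, None])
    return logits.masked_fill(~valid, float("-inf"))
```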
### Top-k Selector
The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process.
The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence:
```python
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
input_idx = s*BLOCK_SIZE+tx
if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
inval_int16 = convert_to_uint16(input[bx, input_idx])
T.atomic_add(s_histogram[inval_int16], 1)
```
The `convert_to_uint16` function maps floats to uint16 such that larger floats map to larger integers. After building a histogram and doing a cumulative sum, we find the threshold bin:
```python
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
s_threshold_bin_id[0] = tx
```
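The body of `convert_to_uint16` is not shown here; radix selection on floats conventionally relies on an order-preserving reinterpretation of the bit pattern (set the sign bit for non-negative values, flip all bits for negative ones). A 32-bit sketch of that standard transform, as an illustration only:
```python
import struct

def order_preserving_u32(x: float) -> int:
    """Map a float32 to a uint32 such that larger floats map to larger integers."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Non-negative floats: set the sign bit; negative floats: flip every bit.
    return bits ^ 0x80000000 if bits < 0x80000000 else bits ^ 0xFFFFFFFF

assert order_preserving_u32(-1.0) < order_preserving_u32(-0.5) < order_preserving_u32(2.0)
```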
Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing:
```python
if l_bin_id32 > l_threshold_bin_id:
pos = T.atomic_add(s_histogram[l_bin_id32+1], 1, return_prev=True)
index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
s_input_idx[0, pos] = input_idx
```
Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence.
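Since the selection is exact, it can be validated against a dense `torch.topk` baseline. A hedged sketch of such a check, assuming the selector returns a `[seq_len, k]` index tensor (the actual entry point in `topk_selector.py` may differ):
```python
import torch

def check_against_dense_topk(logits: torch.Tensor, k: int, selected: torch.Tensor) -> bool:
    # logits: [seq_len, seq_len_kv]; selected: [seq_len, k] indices returned by the kernel.
    # Compare the selected values (sorted), which is robust to ties being broken differently.
    ref_vals = torch.topk(logits, k, dim=-1).values.sort(dim=-1).values
    sel_vals = logits.gather(-1, selected).sort(dim=-1).values
    return bool(torch.allclose(ref_vals, sel_vals))
```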
### Sparse MLA Forward
The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens.
Turning dense MLA into sparse MLA requires surprisingly few changes: essentially just modifying how we iterate over and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions:
```python
# Dense MLA: iterate over full sequence
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
# ... compute attention over this block
```
Sparse MLA only loads KV positions selected by the top-k selector:
```python
# Sparse MLA: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
# ... compute attention over selected tokens
```
This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid:
```python
for bi_i in T.Parallel(BI):
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
```
Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA.
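For intuition, the same computation is easy to write (inefficiently) in plain PyTorch: gather the selected rows per query, apply the validity mask, and run a regular softmax. A hedged single-query sketch, not the kernel itself:
```python
import torch

def sparse_attn_one_query(q, k, v, indices, t, sm_scale):
    # q: [H, D] query at position t; k: [T_kv, D], v: [T_kv, Dv] caches;
    # indices: [topk] KV positions chosen by the top-k selector for this query.
    valid = indices <= t                                   # same causal check as the kernel's mask
    k_sel, v_sel = k[indices.clamp(min=0)], v[indices.clamp(min=0)]
    scores = (q.float() @ k_sel.float().T) * sm_scale      # [H, topk]
    scores = scores.masked_fill(~valid[None, :], float("-inf"))
    return scores.softmax(dim=-1) @ v_sel.float()          # [H, Dv]
```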
### Sparse MLA Forward (Pipelined)
The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFlops on H800 SXM by carefully orchestrating memory and compute pipelines.
The key difference is splitting the warp groups into specialized roles:
```python
if tx < 128:
# Consumer 0: computes left half of output (D//2 dimensions)
# Handles QK matmul, softmax, and PV for left half
elif tx >= 128 and tx < 256:
# Consumer 1: computes right half of output (D//2 dimensions)
# Only does PV matmul for right half
elif tx >= 256:
# Producer: loads KV data from global memory
# Uses async copy with barriers to feed consumers
```
The producer thread group (tx >= 256) uses double buffering with barriers to keep consumers fed:
```python
# Producer alternates between two buffers
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
# ... load KV into buffer 0
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
# ... load KV into buffer 1
T.cp_async_barrier_noinc(bar_k_1_ready[0])
```
Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
### Sparse MLA Backward
The Sparse MLA backward kernel (`sparse_mla_bwd.py`) computes gradients with respect to queries (dQ) and key-values (dKV) for the sparse attention mechanism. Like the forward pass, it processes only the selected top-k indices, maintaining O(seq_len * topk) complexity.
The backward pass consists of three main stages:
**1. Preprocessing**: Computes delta values (row-wise dot products of output and output gradient):
```python
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], o)
T.copy(dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND], do)
for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
```
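In plain PyTorch this preprocessing is a single row-wise reduction (illustration only):
```python
# O, dO: [..., seq_len, D]  ->  delta: [..., seq_len]
delta = (O.float() * dO.float()).sum(dim=-1)
```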
**2. Main Backward Computation**: Computes gradients through sparse attention:
```python
# Sparse MLA backward: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
# Load KV data for selected indices
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BI + bi_i], bz, d_i]
# Recompute attention scores for backward
T.gemm(Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
# Apply softmax gradient: dP = P * (dP_raw - Delta)
for h_i, bi_i in T.Parallel(padded_H, BI):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
```
The key gradient computations (see the plain-PyTorch sketch after this list) are:
- **dQ = dP @ K** (query gradients)
- **dK = dP^T @ Q** (key gradients)
- **dV = P^T @ dO** (value gradients)
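Ignoring the sparse gather and the log-sum-exp bookkeeping, these products look as follows in plain PyTorch; a hedged single-head sketch, not the kernel itself:
```python
import torch

def attn_backward_core(q, k, v, p, do, delta, sm_scale):
    # q: [Tq, D], k: [Tk, D], v: [Tk, Dv], p: [Tq, Tk] softmax probs,
    # do: [Tq, Dv] output gradient, delta: [Tq] row-wise (O * dO) sums.
    dp = do.float() @ v.float().T                 # gradient w.r.t. the probabilities
    ds = p * (dp - delta[:, None]) * sm_scale     # softmax backward, as in the kernel above
    dq = ds @ k.float()                           # dQ = dP @ K
    dk = ds.T @ q.float()                         # dK = dP^T @ Q
    dv = p.T @ do.float()                         # dV = P^T @ dO
    return dq, dk, dv
```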
**3. Atomic Sparse Updates**: Uses atomic operations for dKV accumulation:
```python
# Atomically update dKV at selected indices
for bi_i, d_i in T.Parallel(BI // split_store, D // 4):
T.atomic_addx4(dKV[by, Indices[by, s_i, bz, i_i * BI + bi_i + s * (BI // split_store)], bz, d_i * 4],
acc_dkv_shared[bi_i, d_i * 4])
```
**Performance**: The sparse MLA backward kernel achieves the following throughput:
- **H800 SXM**: ~100 TFlops
- **H200 SXM**: ~115 TFlops
The implementation handles the irregular memory access patterns inherent in sparse attention through careful memory management and atomic update strategies, but it is still a relatively naive kernel with room for further optimization.
# ruff: noqa
import itertools
import tilelang
from tilelang import language as T
import torch
from utils import generate_random_cu_seqlens, per_custom_dims_cast_to_fp8
def display_error_message(msg):
print(f"\033[31mWARNING: {msg}\033[0m")
def compute_correlation(a, b, label="tensor"):
a, b = a.data.double(), b.data.double()
norm_sum = (a * a + b * b).sum()
if norm_sum == 0:
display_error_message(f"{label} all zero")
return 1
correlation = 2 * (a * b).sum() / norm_sum
return correlation
def validate_tensor_match(a, b, tolerance=1e-8, tensor_name="tensor", should_raise=True):
a_finite = torch.isfinite(a)
b_finite = torch.isfinite(b)
if not torch.all(a_finite == b_finite):
display_error_message(f"{tensor_name} Error: isfinite mask mismatch")
if should_raise:
assert False
if not torch.isclose(
a.masked_fill(a_finite, 0),
b.masked_fill(b_finite, 0),
rtol=0,
atol=0,
equal_nan=True,
).all():
display_error_message(f"{tensor_name} Error: nonfinite value mismatch")
if should_raise:
assert False
a = a.masked_fill(~a_finite, 0)
b = b.masked_fill(~b_finite, 0)
correlation = compute_correlation(a, b, tensor_name)
difference = 1.0 - correlation
if not (0 <= difference <= tolerance):
display_error_message(f"{tensor_name} Error: {difference}")
if should_raise:
assert False
return difference
def get_configs():
iter_params = dict(
block_N=[32, 64, 128],
num_stages=[0, 1, 2],
threads=[128, 256],
block_Q=[1, 2, 4],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]
class SupplyProg:
def __init__(self):
self.tensors_dict = {}
def get_key(self, shape, dtype) -> str:
return f"{shape}-{dtype}"
def supply_prog(self, params):
shapes = [p.shape for p in params]
dtypes = [p.dtype for p in params]
tensor_list = []
for shape, dtype in zip(shapes, dtypes):
key = self.get_key(shape, dtype)
if key not in self.tensors_dict:
self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda")
tensor_list.append(self.tensors_dict[key])
else:
tensor_list.append(self.tensors_dict[key])
return tensor_list
supply_prog = SupplyProg()
@tilelang.jit(
pass_configs={
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
},)
def mqa_attn_return_logits(
heads,
index_dim,
block_N=256,
num_stages=3,
threads=512,
block_Q=None,
):
if block_Q is None:
block_Q = 128 // heads
dtype = "float8_e4m3"
accum_dtype = "float"
index_dtype = "int32"
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")
index_q_shape = [seq_len * heads, index_dim]
index_k_shape = [seq_len_kv, index_dim]
index_k_scale_shape = [seq_len_kv]
logits_shape = [seq_len, seq_len_kv]
@T.prim_func
def mqa_attn_return_logits_kernel(
IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore
IndexK: T.Tensor(index_k_shape, dtype), # type: ignore
IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore
Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore
Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore
CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore
CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore
):
with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype)
index_k_shared = T.alloc_shared([block_N, index_dim], dtype)
index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype)
s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype)
s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype)
logits = T.alloc_fragment([block_N, block_Q], accum_dtype)
weights = T.alloc_fragment([block_Q, heads], accum_dtype)
seq_len_i = bx * block_Q
cu_k_s_min = T.alloc_local([1], index_dtype)
cu_k_e_max = T.alloc_local([1], index_dtype)
cu_k_s_min[0] = 2147483647
cu_k_e_max[0] = -2147483648
for bq_i in T.serial(block_Q):
cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i],
seq_len_kv))
for bq_i in T.serial(block_Q):
cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i],
seq_len_kv))
T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
T.copy(Weights[seq_len_i, 0], weights)
for nbn_i in T.Pipelined(
T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)
T.gemm(
index_k_shared,
index_q_shared,
s,
transpose_B=True,
clear_accum=True,
policy=T.GemmWarpPolicy.FullCol,
)
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
s_reshaped[bn_i, bq_i,
h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
for bq_i, bn_i in T.Parallel(block_Q, block_N):
Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
logits[bn_i, bq_i])
return mqa_attn_return_logits_kernel
@tilelang.jit
def clean_logits_(
threads: int = 512,
block_K: int = 4096,
):
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")
dtype = "float"
indices_dtype = "int32"
@T.prim_func
def clean_logits_kernel(
Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore
CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore
CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore
):
with T.Kernel(seq_len, threads=threads) as bx:
tx = T.thread_binding(0, threads, thread="threadIdx.x")
cu_k_s = T.alloc_local([1], indices_dtype)
cu_k_e = T.alloc_local([1], indices_dtype)
cu_k_s[0] = CuSeqLenKS[bx]
cu_k_e[0] = CuSeqLenKE[bx]
for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)):
for k_i in T.serial(block_K // threads):
idx = n_i * block_K + k_i * threads + tx
if idx < cu_k_s[0] or idx >= cu_k_e[0]:
Logits[bx, idx] = -T.infinity(dtype)
return clean_logits_kernel
def mqa_attn_return_logits_interface(q,
kv,
kv_scales,
weights,
cu_seqlen_ks,
cu_seqlen_ke,
clean_logits=True):
seq_len, heads, index_dim = q.shape
seq_len_kv = kv.shape[0]
clean_logits_kernel = clean_logits_()
mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32)
mqa_attn_return_logits_kernel(
q.view(seq_len * heads, index_dim),
kv,
kv_scales,
logits,
weights,
cu_seqlen_ks,
cu_seqlen_ke,
)
if clean_logits:
clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke)
return logits
def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor):
k = kv
q = q.float()
k = k.float()
seq_len_kv = kv.shape[0]
mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None]
mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None]
mask = mask_lo & mask_hi
score = torch.einsum('mhd,nd->hmn', q, k)
logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
logits = logits.masked_fill(~mask, float('-inf'))
cost = mask.sum()
return logits, cost
def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16)
kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16)
weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)
ks, ke = generate_random_cu_seqlens(
per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)
logits_ref, cost_ref = ref_fp8_mqa_logits(
q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
logits_tl = mqa_attn_return_logits_interface(
q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
diff = validate_tensor_match(
logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False)
print(f"diff: {diff}")
from tilelang.profiler import do_bench
def logits_fn():
return mqa_attn_return_logits_interface(
q=q_fp8,
kv=kv_fp8,
kv_scales=kv_scales,
weights=weights,
cu_seqlen_ks=ks,
cu_seqlen_ke=ke)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
logits_fn()
print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50))
logits_ms = do_bench(logits_fn, warmup=100, rep=100)
logits_flops = 2 * cost_ref * H * D
logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12
print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}")
print(f"cost_ref: {cost_ref}")
if __name__ == "__main__":
test_fp8_lighting_indexer()
# DeepSeek V3.2
First convert the Hugging Face model weights to the format required by our inference demo. Set `MP` to match your available GPU count:
```bash
cd inference
export EXPERTS=256
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
```
Launch the interactive chat interface and start exploring DeepSeek's capabilities:
```bash
export CONFIG=config_671B_v3.2.json
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
```
{
"vocab_size": 129280,
"dim": 7168,
"inter_dim": 18432,
"moe_inter_dim": 2048,
"n_layers": 61,
"n_dense_layers": 3,
"n_heads": 128,
"n_routed_experts": 256,
"n_shared_experts": 1,
"n_activated_experts": 8,
"n_expert_groups": 8,
"n_limited_groups": 4,
"route_scale": 2.5,
"score_func": "sigmoid",
"q_lora_rank": 1536,
"kv_lora_rank": 512,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"v_head_dim": 128,
"dtype": "fp8",
"scale_fmt": "ue8m0",
"index_n_heads": 64,
"index_head_dim": 128,
"index_topk": 2048
}
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
mapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
"q_proj": ("wq", 0),
"q_a_proj": ("wq_a", None),
"q_a_layernorm": ("q_norm", None),
"q_b_proj": ("wq_b", 0),
"kv_a_proj_with_mqa": ("wkv_a", None),
"kv_a_layernorm": ("kv_norm", None),
"kv_b_proj": ("wkv_b", 0),
"o_proj": ("wo", 1),
"gate": ("gate", None),
"gate_proj": ("w1", 0),
"down_proj": ("w2", 1),
"up_proj": ("w3", 0),
"norm": ("norm", None),
"lm_head": ("head", 0),
"scale": ("scale", None),
"wq_b": ("wq_b", None),
"wk": ("wk", None),
"k_norm": ("k_norm", None),
"weights_proj": ("weights_proj", None),
}
def main(hf_ckpt_path, save_path, n_experts, mp):
"""
Converts and saves model checkpoint files into a specified format.
Args:
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
save_path (str): Path to the directory where the converted checkpoint files will be saved.
n_experts (int): Total number of experts in the model.
mp (int): Model parallelism factor.
Returns:
None
"""
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping, f"Key {key} not found in mapping"
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
elif dim is not None:
assert param.size(
dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
os.makedirs(save_path, exist_ok=True)
for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
import os
import json
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from model import Transformer, ModelArgs
def sample(logits, temperature: float = 1.0):
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
Returns:
torch.Tensor: The sampled token.
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
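# Dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax is the
# Gumbel-max trick, so this draws an exact sample from the categorical distribution.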
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
@torch.inference_mode()
def generate(model: Transformer,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
Args:
model (Transformer): The transformer model used for token generation.
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
max_new_tokens (int): The maximum number of new tokens to generate.
eos_id (int): The end-of-sequence token ID.
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(
prompt_lens
) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
def main(
ckpt_path: str,
config: str,
input_file: str = "",
interactive: bool = True,
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
"""
Main function to load the model and perform interactive or batch text generation.
Args:
ckpt_path (str): Path to the model checkpoint directory.
config (str): Path to the model configuration file.
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
"""
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(33377335)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
print("load model")
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
print("I'm DeepSeek 👋")
if interactive:
messages = []
while True:
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens,
tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
with open(input_file) as f:
prompts = f.read().split("\n\n")
assert len(
prompts
) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
prompt_tokens = [
tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
add_generation_prompt=True) for prompt in prompts
]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id,
temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()
if __name__ == "__main__":
"""
Command-line interface for distributed text generation.
Arguments:
--ckpt-path (str): Path to the model checkpoint directory.
--config (str): Path to the model configuration file.
--input-file (str, optional): File containing prompts for batch processing.
--interactive (bool, optional): Enable interactive mode for generating text.
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
--temperature (float, optional): Temperature for sampling. Defaults to 0.6.
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.6)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens,
args.temperature)
import torch
import tilelang
import tilelang.language as T
from typing import Tuple, Optional
tilelang.set_log_level("WARNING")
pass_configs = {
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True,
}
FP8 = "float8_e4m3"
BF16 = "bfloat16"
FP32 = "float32"
def fast_log2_ceil(x):
bits_x = T.reinterpret("uint32", x)
exp_x = (bits_x >> 23) & 0xFF
man_bits = bits_x & ((1 << 23) - 1)
return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))
def fast_pow2(x):
bits_x = (x + 127) << 23
return T.reinterpret("float32", bits_x)
def fast_round_scale(amax, fp8_max_inv):
return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))
@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(N, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False):
M = T.dynamic("M")
fp8_min = -448.0
fp8_max = 448.0
fp8_max_inv = 1 / fp8_max
num_stages = 0 if round_scale else 2
blk_m = 32
group_size = 128
@T.prim_func
def act_quant_kernel_(
X: T.Tensor[(M, N), in_dtype],
Y: T.Tensor[(M, N), out_dtype],
S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
):
with T.Kernel(
T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
pid_m,
pid_n,
):
x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
amax_local = T.alloc_fragment((blk_m,), scale_dtype)
s_local = T.alloc_fragment((blk_m,), scale_dtype)
y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
y_shared = T.alloc_shared((blk_m, group_size), out_dtype)
for _ in T.Pipelined(1, num_stages=num_stages):
T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
T.copy(x_shared, x_local)
T.reduce_absmax(x_local, amax_local, dim=1)
for i in T.Parallel(blk_m):
amax_local[i] = T.max(amax_local[i], 1e-4)
if round_scale:
s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
else:
s_local[i] = amax_local[i] * fp8_max_inv
for i, j in T.Parallel(blk_m, group_size):
y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)
for i in T.Parallel(blk_m):
S[pid_m * blk_m + i, pid_n] = s_local[i]
T.copy(y_local, y_shared)
T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])
return act_quant_kernel_
def act_quant(x: torch.Tensor,
block_size: int = 128,
scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
scale_fmt (Optional[str], optional): The format of the scale. Default is None.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), "Input tensor must be contiguous"
assert x.size(-1) % block_size == 0, (
f"Last dimension size must be divisible by block_size (block_size={block_size})")
N = x.size(-1)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], N // block_size, dtype=torch.float32)
kernel = act_quant_kernel(N, round_scale=scale_fmt is not None)
kernel(x.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
return y, s
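# Hedged usage sketch (assumed toy shapes, not part of the original file):
#   x = torch.randn(4, 256, dtype=torch.bfloat16, device="cuda")
#   y, s = act_quant(x)  # y: [4, 256] float8_e4m3fn, s: [4, 2] float32 (one scale per 128-wide group)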
@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype="float32"):
assert out_dtype in [BF16, "float32"]
M = T.dynamic("M")
group_size = 128
block_M = 32
block_N = 128
block_K = 128
@T.prim_func
def fp8_gemm_kernel_(
A: T.Tensor[(M, K), FP8],
B: T.Tensor[(N, K), FP8],
C: T.Tensor[(M, N), out_dtype],
scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), FP32],
scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), FP32],
):
with T.Kernel(
T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
bx,
by,
):
A_shared = T.alloc_shared((block_M, block_K), FP8)
B_shared = T.alloc_shared((block_N, block_K), FP8)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)
Scale_C_shared = T.alloc_shared((block_M), FP32)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)
# Improve L2 Cache
T.use_swizzle(panel_size=10)
T.clear(C_local)
T.clear(C_local_accum)
K_iters = T.ceildiv(K, block_K)
for k in T.Pipelined(K_iters, num_stages=4):
# Load A into shared memory
T.copy(A[by * block_M, k * block_K], A_shared)
# Load B into shared memory
T.copy(B[bx * block_N, k * block_K], B_shared)
# Load scale into shared memory
Scale_B = scales_b[bx * block_N // group_size, k]
for i in T.Parallel(block_M):
Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B
T.gemm(A_shared, B_shared, C_local, transpose_B=True)
# Promote to enable 2xAcc
for i, j in T.Parallel(block_M, block_N):
C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
T.clear(C_local)
# TMA store
T.copy(C_local_accum, C_shared)
T.copy(C_shared, C[by * block_M, bx * block_N])
return fp8_gemm_kernel_
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor,
b_s: torch.Tensor) -> torch.Tensor:
"""
Perform a matrix multiplication using FP8 precision.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
assert a_s.is_contiguous() and b_s.is_contiguous(), (
"Scaling factor tensors must be contiguous")
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
kernel = fp8_gemm_kernel(N, K)
kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
return c
@tilelang.jit(out_idx=[4], pass_configs=pass_configs)
def fp8_index_kernel(h: int, d: int):
b = T.dynamic("b")
m = T.dynamic("m")
n = T.dynamic("n")
blk_n1 = 512
blk_n2 = 128
@T.prim_func
def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
with T.Kernel(b, m, T.ceildiv(n, blk_n1)) as (i_b, i_m, i1_n):
q_smem = T.alloc_shared((h, d), FP8)
T.copy(q[i_b, i_m, 0, 0], q_smem)
q_s_frag = T.alloc_fragment(h, FP32)
T.copy(q_s[i_b, i_m, 0], q_s_frag)
for i2_n in T.Pipelined(blk_n1 // blk_n2, num_stages=2):
k_smem = T.alloc_shared((blk_n2, d), FP8)
T.copy(k[i_b, i1_n * blk_n1 + i2_n * blk_n2, 0], k_smem)
k_s_frag = T.alloc_fragment(blk_n2, FP32)
T.copy(k_s[i_b, i1_n * blk_n1 + i2_n * blk_n2], k_s_frag)
logits = T.alloc_fragment((blk_n2, h), FP32)
T.gemm(
k_smem,
q_smem,
logits,
transpose_A=False,
transpose_B=True,
clear_accum=True,
)
for i_h, i3_n in T.Parallel(h, blk_n2):
logits[i3_n, i_h] = T.max(logits[i3_n, i_h], 0) * q_s_frag[i_h]
logits_sum = T.alloc_fragment(blk_n2, FP32)
T.reduce_sum(logits, logits_sum, dim=1)
for i3_n in T.Parallel(blk_n2):
logits_sum[i3_n] *= k_s_frag[i3_n]
T.copy(logits_sum, o[i_b, i_m, i1_n * blk_n1 + i2_n * blk_n2])
return fp8_index_kernel_
def fp8_index(
q: torch.Tensor,
q_s: torch.Tensor,
k: torch.Tensor,
k_s: torch.Tensor,
) -> torch.Tensor:
"""
Perform index score using FP8 precision.
Args:
q (torch.Tensor): The Q tensor, must be contiguous.
q_s (torch.Tensor): The scaling factor for Q (float), must be contiguous.
k (torch.Tensor): The K tensor, must be contiguous.
k_s (torch.Tensor): The scaling factor for K (e8m0 here), must be contiguous.
Computation:
fp8 q @ fp8 k -> fp32 logits
relu(fp32 logits) * q_s (weights) -> fp32 logits
fp32 logits -> fp32 logits_sum (sum over heads)
fp32 logits_sum * k_s (e8m0) -> fp32 index_score
Returns:
torch.Tensor: The index scores of shape `[b, m, n]` in FP32.
"""
return fp8_index_kernel(q.shape[2], q.shape[3])(q, q_s, k, k_s)
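# Hedged reference sketch (not part of the original file): the same index score in plain
# PyTorch, following the steps listed in the docstring above.
#   logits = torch.einsum("bmhd,bnd->bmhn", q.float(), k.float())   # fp8 q @ fp8 k -> fp32 logits
#   logits = logits.relu() * q_s.unsqueeze(-1)                      # relu, per-head weights
#   index_score = logits.sum(dim=2) * k_s.unsqueeze(1)              # sum over heads, apply k scale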
import math
from dataclasses import dataclass
from typing import Tuple, Optional, Literal
from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
from kernel import act_quant, fp8_gemm, fp8_index
world_size = 1
rank = 0
block_size = 128
@dataclass
class ModelArgs:
"""
Data class for defining model arguments and hyperparameters.
Attributes:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
dtype (Literal["bf16", "fp8"]): Data type for computations.
scale_fmt (Optional[str]): Format for quantization scale.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
moe_inter_dim (int): Intermediate dimension for MoE layers.
n_layers (int): Number of transformer layers.
n_dense_layers (int): Number of dense layers in the model.
n_heads (int): Number of attention heads.
n_routed_experts (int): Number of routed experts for MoE layers.
n_shared_experts (int): Number of shared experts for MoE layers.
n_activated_experts (int): Number of activated experts in MoE layers.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing.
route_scale (float): Scaling factor for routing scores.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
beta_slow (int): Slow beta correction factor.
mscale (float): Scaling factor for extended attention.
index_head_dim (int): Dimension for index head.
index_topk (int): Top-k for index head.
"""
max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
scale_fmt: Optional[str] = None
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
moe_inter_dim: int = 1408
n_layers: int = 27
n_dense_layers: int = 1
n_heads: int = 16
# moe
n_routed_experts: int = 64
n_shared_experts: int = 2
n_activated_experts: int = 6
n_expert_groups: int = 1
n_limited_groups: int = 1
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 1.
# mla
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
# yarn
original_seq_len: int = 4096
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.
# index
index_n_heads: int = 64
index_head_dim: int = 128
index_topk: int = 2048
class ParallelEmbedding(nn.Module):
"""
Embedding layer with parallelism support across distributed processes.
Args:
vocab_size (int): Vocabulary size.
dim (int): Embedding dimension.
"""
def __init__(self, vocab_size: int, dim: int):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
self.part_vocab_size = (vocab_size // world_size)
self.vocab_start_idx = rank * self.part_vocab_size
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for parallel embedding layer.
Args:
x (torch.Tensor): Input tensor containing token indices.
Returns:
torch.Tensor: Embedded representations.
Raises:
ValueError: If `world_size` is not defined.
"""
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx
x[mask] = 0
y = F.embedding(x, self.weight)
if world_size > 1:
y[mask] = 0
dist.all_reduce(y)
return y
def linear(x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
scale_fmt: Optional[str] = None) -> torch.Tensor:
"""
Applies a linear transformation to the incoming data: y = xA^T + b.
This function supports specialized implementations based on quantization
and tensor formats.
Args:
x (torch.Tensor): The input tensor.
weight (torch.Tensor): The weight tensor. It may be quantized and
requires dequantization for certain cases.
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
scale_fmt (Optional[str]): The format of scaling factors.
Returns:
torch.Tensor: The result of the linear transformation, which may involve
quantization-aware computations depending on the input parameters.
Notes:
- If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
is used for computation.
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
"""
assert bias is None
if weight.dtype != torch.float8_e4m3fn:
return F.linear(x, weight)
else:
x, scale = act_quant(x, block_size, scale_fmt)
return fp8_gemm(x, scale, weight, weight.scale)
class Linear(nn.Module):
"""
Custom linear layer with support for quantized weights and optional bias.
Args:
in_features (int): Number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
dtype = torch.bfloat16
scale_fmt: Optional[str] = None
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(
torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size - 1) // block_size
scale_in_features = (in_features + block_size - 1) // block_size
self.weight.scale = self.scale = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
else:
self.register_parameter("scale", None)
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the custom linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor after linear computation.
"""
return linear(x, self.weight, self.bias, self.scale_fmt)
class ColumnParallelLinear(Linear):
"""
Linear layer with column parallelism, splitting output features across distributed processes.
Args:
in_features (int): Number of input features.
out_features (int): Total number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype=None):
assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for column parallel linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with column-parallel computation.
"""
y = linear(x, self.weight, self.bias, self.scale_fmt)
return y
class RowParallelLinear(Linear):
"""
Linear layer with row parallelism, splitting input features across distributed processes.
Args:
in_features (int): Total number of input features.
out_features (int): Number of output features.
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self,
in_features: int,
out_features: int,
bias: bool = False,
reduce_output=True,
dtype=None):
assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
self.part_in_features = in_features // world_size
self.reduce_output = reduce_output
super().__init__(self.part_in_features, out_features, bias, dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for row parallel linear layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Transformed tensor with row-parallel computation.
"""
y = linear(x, self.weight, None, self.scale_fmt)
if self.reduce_output and world_size > 1:
y = y.float()
dist.all_reduce(y)
if self.bias is not None:
y += self.bias
return y.type_as(x)
class RMSNorm(nn.Module):
"""
Root Mean Square Layer Normalization (RMSNorm).
Args:
dim (int): Dimension of the input tensor.
eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
"""
Forward pass for RMSNorm.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Normalized tensor with the same shape as input.
"""
dtype = x.dtype
if residual is None:
x = x.float()
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return (self.weight * x).to(dtype)
else:
x = residual = x.float() + residual.float()
var = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(var + self.eps)
return (self.weight * x).to(dtype), residual.to(dtype)
class LayerNorm(nn.Module):
"""
Layer Normalization.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32))
def forward(self, x: torch.Tensor):
return F.layer_norm(x.float(), (self.dim,), self.weight, self.bias, self.eps).type_as(x)
def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
"""
Precomputes frequency-based complex exponential values for rotary positional embeddings.
Args:
args (ModelArgs): Model arguments containing positional embedding parameters.
Returns:
torch.Tensor: Precomputed complex exponential values for positional embeddings.
"""
dim = args.qk_rope_head_dim
seqlen = args.max_seq_len
beta_fast = args.beta_fast
beta_slow = args.beta_slow
base = args.rope_theta
factor = args.rope_factor
def find_correction_dim(num_rotations, dim, base, max_seq_len):
"""
Computes the correction dimension for a given number of rotations in the rotary positional embedding.
Args:
num_rotations (float): Number of rotations to compute the correction for.
dim (int): Dimensionality of the embedding space.
base (float): Base value for the exponential computation.
max_seq_len (int): Maximum sequence length.
Returns:
float: The correction dimension based on the input parameters.
"""
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
"""
Computes the range of correction dimensions for rotary positional embeddings.
Args:
low_rot (float): Lower bound for the number of rotations.
high_rot (float): Upper bound for the number of rotations.
dim (int): Dimensionality of the embedding space.
base (float): Base value for the exponential computation.
max_seq_len (int): Maximum sequence length.
Returns:
Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices.
"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min, max, dim):
"""
Computes a linear ramp function used to smooth values between a minimum and maximum range.
Args:
min (float): Minimum value for the ramp function.
max (float): Maximum value for the ramp function.
dim (int): Dimensionality of the ramp tensor.
Returns:
torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1,
clamped to the range [0, 1].
"""
if min == max:
max += 0.001
linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
freqs = 1.0 / (base**(torch.arange(0, dim, 2, dtype=torch.float32) / dim))
if seqlen > args.original_seq_len:
low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
freqs = freqs / factor * (1 - smooth) + freqs * smooth
t = torch.arange(seqlen)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""
Applies rotary positional embeddings to the input tensor.
Args:
x (torch.Tensor): Input tensor with positional embeddings to be applied.
freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings.
Returns:
torch.Tensor: Tensor with rotary embeddings applied.
"""
dtype = x.dtype
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1))
y = torch.view_as_real(x * freqs_cis).flatten(3)
return y.to(dtype)
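# A minimal sketch (hypothetical helper, not used by the model): for every (even, odd) feature
# pair and position, apply_rotary_emb performs exactly this 2-D rotation by the pair's angle,
# expressed above as a complex multiplication.
def _example_rotate_pair(x0: float, x1: float, theta: float):
    # (x0 + i * x1) * e^{i * theta}
    return (x0 * math.cos(theta) - x1 * math.sin(theta),
            x0 * math.sin(theta) + x1 * math.cos(theta))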
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
assert x.dtype == torch.bfloat16
from fast_hadamard_transform import hadamard_transform
hidden_size = x.size(-1)
return hadamard_transform(x, scale=hidden_size**-0.5)
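# A minimal reference sketch (hypothetical helper, not used by the model): a recursive
# Walsh-Hadamard transform in Sylvester ordering with the same hidden_size**-0.5 scaling that
# rotate_activation applies, assuming fast_hadamard_transform uses the standard construction.
def _example_hadamard_reference(x: torch.Tensor) -> torch.Tensor:
    def wht(v: torch.Tensor) -> torch.Tensor:
        n = v.size(-1)
        if n == 1:
            return v
        a, b = v[..., :n // 2], v[..., n // 2:]
        return torch.cat([wht(a + b), wht(a - b)], dim=-1)
    n = x.size(-1)
    assert n & (n - 1) == 0, "hidden size must be a power of two"
    return wht(x.float()) * n**-0.5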
class Indexer(torch.nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.dim: int = args.dim
self.n_heads: int = args.index_n_heads
self.n_local_heads = args.index_n_heads // world_size
self.head_dim: int = args.index_head_dim
self.rope_head_dim: int = args.qk_rope_head_dim
self.index_topk: int = args.index_topk
self.q_lora_rank: int = args.q_lora_rank
self.wq_b = Linear(self.q_lora_rank, self.n_heads * self.head_dim)
self.wk = Linear(self.dim, self.head_dim)
self.k_norm = LayerNorm(self.head_dim)
self.weights_proj = Linear(self.dim, self.n_heads, dtype=torch.get_default_dtype())
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = args.scale_fmt
self.register_buffer(
"k_cache",
torch.zeros(
args.max_batch_size, args.max_seq_len, self.head_dim, dtype=torch.float8_e4m3fn),
persistent=False)
self.register_buffer(
"k_scale_cache",
torch.zeros(
args.max_batch_size,
args.max_seq_len,
self.head_dim // block_size,
dtype=torch.float32),
persistent=False)
def forward(self, x: torch.Tensor, qr: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor]):
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
q = self.wq_b(qr)
q = rearrange(q, 'b s (h d) -> b s h d', d=self.head_dim)
q_pe, q_nope = torch.split(
q, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
q = torch.cat([q_pe, q_nope], dim=-1)
k = self.wk(x)
k = self.k_norm(k)
k_pe, k_nope = torch.split(
k, [self.rope_head_dim, self.head_dim - self.rope_head_dim], dim=-1)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis).squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1)
q = rotate_activation(q)
k = rotate_activation(k)
q_fp8, q_scale = act_quant(q, block_size, self.scale_fmt)
k_fp8, k_scale = act_quant(k, block_size, self.scale_fmt)
self.k_cache[:bsz, start_pos:end_pos] = k_fp8
self.k_scale_cache[:bsz, start_pos:end_pos] = k_scale
weights = self.weights_proj(x) * self.n_heads**-0.5
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
index_score = fp8_index(q_fp8.contiguous(), weights,
self.k_cache[:bsz, :end_pos].contiguous(),
self.k_scale_cache[:bsz, :end_pos].contiguous())
if mask is not None:
index_score += mask
topk_indices = index_score.topk(min(self.index_topk, end_pos), dim=-1)[1]
topk_indices_ = topk_indices.clone()
dist.broadcast(topk_indices_, src=0)
assert torch.all(topk_indices == topk_indices_), f"{topk_indices=} {topk_indices_=}"
return topk_indices
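# Note on the forward pass above: queries and cached keys are Hadamard-rotated, quantized to
# FP8 with per-block scales, and scored against every cached position; after the causal mask is
# added, only the top `index_topk` positions per query are kept. The broadcast plus assert only
# checks that all ranks selected identical indices.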
def weight_dequant(weight, scale):
shape = weight.shape
assert weight.dim() == 2
weight = weight.view(shape[0] // block_size, block_size, shape[1] // block_size,
block_size).transpose(1, 2).contiguous().view(-1, block_size * block_size)
weight = (weight.float() * scale.view(-1, 1).float()).to(torch.get_default_dtype()).view(
shape[0] // block_size, shape[1] // block_size, block_size,
block_size).transpose(1, 2).contiguous().view(shape)
return weight
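# A minimal sketch (hypothetical helper, not used by the model): the same block-wise
# dequantization written as an explicit loop over block_size x block_size tiles, assuming both
# dimensions of `weight` are multiples of block_size and `scale` holds one float32 factor per
# tile, laid out as (rows // block_size, cols // block_size).
def _example_blockwise_dequant(weight: torch.Tensor, scale: torch.Tensor, bs: int) -> torch.Tensor:
    rows, cols = weight.shape
    out = torch.empty(rows, cols, dtype=torch.float32)
    for i in range(rows // bs):
        for j in range(cols // bs):
            tile = weight[i * bs:(i + 1) * bs, j * bs:(j + 1) * bs].float()
            out[i * bs:(i + 1) * bs, j * bs:(j + 1) * bs] = tile * scale[i, j].float()
    return out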
class MLA(nn.Module):
"""
Multi-Head Latent Attention (MLA) Layer.
Attributes:
dim (int): Dimensionality of the input features.
n_heads (int): Number of attention heads.
n_local_heads (int): Number of local attention heads for distributed systems.
q_lora_rank (int): Rank for low-rank query projection.
kv_lora_rank (int): Rank for low-rank key/value projection.
qk_nope_head_dim (int): Dimensionality of non-positional query/key projections.
qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections.
qk_head_dim (int): Total dimensionality of query/key projections.
v_head_dim (int): Dimensionality of value projections.
softmax_scale (float): Scaling factor for softmax in attention computation.
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
self.n_heads = args.n_heads
self.n_local_heads = args.n_heads // world_size
self.q_lora_rank = args.q_lora_rank
self.kv_lora_rank = args.kv_lora_rank
self.qk_nope_head_dim = args.qk_nope_head_dim
self.qk_rope_head_dim = args.qk_rope_head_dim
self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
self.v_head_dim = args.v_head_dim
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank,
self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.softmax_scale = self.qk_head_dim**-0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale
self.indexer = Indexer(args)
self.register_buffer(
"kv_cache",
torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
persistent=False)
self.register_buffer(
"pe_cache",
torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim),
persistent=False)
self.dequant_wkv_b = None
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor]):
"""
Forward pass for the Multi-Head Latent Attention (MLA) Layer.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
start_pos (int): Starting position in the sequence for caching.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
Returns:
torch.Tensor: Output tensor with the same shape as the input.
"""
bsz, seqlen, _ = x.size()
end_pos = start_pos + seqlen
qr = self.q_norm(self.wq_a(x))
q = self.wq_b(qr)
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
kv = self.wkv_a(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv = self.kv_norm(kv)
k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
self.kv_cache[:bsz, start_pos:end_pos] = kv
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
if mask is not None: # MHA prefill
q = torch.cat([q_nope, q_pe], dim=-1)
kv = self.wkv_b(kv)
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
scores = torch.einsum("bshd,bthd->bsht", q.float(), k.float()) * self.softmax_scale
# indexer
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
index_mask = torch.full((bsz, seqlen, seqlen), float("-inf"),
device=x.device).scatter_(-1, topk_indices, 0)
index_mask += mask
scores += index_mask.unsqueeze(2)
scores = scores.softmax(dim=-1, dtype=torch.float32)
x = torch.einsum("bsht,bthd->bshd", scores.type_as(x), v)
else: # MHA decode
if self.dequant_wkv_b is None and self.wkv_b.scale is not None:
self.dequant_wkv_b = weight_dequant(self.wkv_b.weight, self.wkv_b.scale)
wkv_b = self.wkv_b.weight if self.dequant_wkv_b is None else self.dequant_wkv_b
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
scores = (torch.einsum("bshc,btc->bsht", q_nope.float(),
self.kv_cache[:bsz, :end_pos].float()) +
torch.einsum("bshr,btr->bsht", q_pe.float(),
self.pe_cache[:bsz, :end_pos].float())) * self.softmax_scale
# indexer
topk_indices = self.indexer(x, qr, start_pos, freqs_cis, mask)
index_mask = torch.full((bsz, 1, end_pos), float("-inf"),
device=x.device).scatter_(-1, topk_indices, 0)
scores += index_mask.unsqueeze(2)
scores = scores.softmax(dim=-1, dtype=torch.float32)
x = torch.einsum("bsht,btc->bshc", scores.type_as(x), self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = self.wo(x.flatten(2))
return x
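# A minimal sketch (hypothetical shapes, not used by the model): the decode branch above
# "absorbs" the per-head key up-projection into the query so that attention scores can be taken
# directly against the compressed kv cache; both orderings agree up to floating-point rounding.
def _example_absorbed_scores(q_nope: torch.Tensor, kv: torch.Tensor, w_k: torch.Tensor):
    # q_nope: (b, s, h, d), kv: (b, t, c), w_k: (h, d, c)
    k_expanded = torch.einsum("btc,hdc->bthd", kv, w_k)
    scores_expand_first = torch.einsum("bshd,bthd->bsht", q_nope, k_expanded)
    q_absorbed = torch.einsum("bshd,hdc->bshc", q_nope, w_k)
    scores_absorb_first = torch.einsum("bshc,btc->bsht", q_absorbed, kv)
    return scores_expand_first, scores_absorb_first  # allclose for float32 inputs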
class MLP(nn.Module):
"""
Multi-Layer Perceptron (MLP) used as a feed-forward layer.
Attributes:
w1 (nn.Module): Linear layer for input-to-hidden transformation.
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int, reduce_output: bool = True):
"""
Initializes the MLP layer.
Args:
dim (int): Input and output dimensionality.
            inter_dim (int): Hidden layer dimensionality.
            reduce_output (bool): Whether the output projection reduces across ranks
                (set to False when the caller performs its own reduction, as in MoE).
        """
super().__init__()
self.w1 = ColumnParallelLinear(dim, inter_dim)
self.w2 = RowParallelLinear(inter_dim, dim, reduce_output=reduce_output)
self.w3 = ColumnParallelLinear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MLP layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after MLP computation.
"""
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
class Gate(nn.Module):
"""
Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
Attributes:
dim (int): Dimensionality of input features.
topk (int): Number of top experts activated for each input.
n_groups (int): Number of groups for routing.
topk_groups (int): Number of groups to route inputs to.
score_func (str): Scoring function ('softmax' or 'sigmoid').
route_scale (float): Scaling factor for routing weights.
weight (torch.nn.Parameter): Learnable weights for the gate.
bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Gate module.
Args:
args (ModelArgs): Model arguments containing gating parameters.
"""
super().__init__()
self.dim = args.dim
self.topk = args.n_activated_experts
self.n_groups = args.n_expert_groups
self.topk_groups = args.n_limited_groups
self.score_func = args.score_func
self.route_scale = args.route_scale
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
self.bias = nn.Parameter(torch.empty(args.n_routed_experts,
dtype=torch.float32)) if self.dim == 7168 else None
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for the gating mechanism.
Args:
x (torch.Tensor): Input tensor.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices.
"""
scores = linear(x.float(), self.weight.float())
if self.score_func == "softmax":
scores = scores.softmax(dim=-1)
else:
scores = scores.sigmoid()
original_scores = scores
if self.bias is not None:
scores = scores + self.bias
if self.n_groups > 1:
scores = scores.view(x.size(0), self.n_groups, -1)
if self.bias is None:
group_scores = scores.amax(dim=-1)
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
indices = scores.topk(self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
if self.score_func == "sigmoid":
weights /= weights.sum(dim=-1, keepdim=True)
weights *= self.route_scale
return weights, indices
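# A minimal sketch (hypothetical helper, not used by the model): the bias-free group-limited
# routing performed above — keep only the `topk_groups` best expert groups, then take the
# global top-k among the surviving experts.
def _example_group_limited_topk(scores: torch.Tensor, n_groups: int, topk_groups: int, topk: int) -> torch.Tensor:
    grouped = scores.view(scores.size(0), n_groups, -1)
    kept_groups = grouped.amax(dim=-1).topk(topk_groups, dim=-1)[1]
    drop = scores.new_ones(scores.size(0), n_groups, dtype=torch.bool).scatter_(1, kept_groups, False)
    masked = grouped.masked_fill(drop.unsqueeze(-1), float("-inf")).flatten(1)
    return masked.topk(topk, dim=-1)[1]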
class Expert(nn.Module):
"""
Expert layer for Mixture-of-Experts (MoE) models.
Attributes:
w1 (nn.Module): Linear layer for input-to-hidden transformation.
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int):
"""
Initializes the Expert layer.
Args:
dim (int): Input and output dimensionality.
inter_dim (int): Hidden layer dimensionality.
"""
super().__init__()
self.w1 = Linear(dim, inter_dim)
self.w2 = Linear(inter_dim, dim)
self.w3 = Linear(dim, inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the Expert layer.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert computation.
"""
return self.w2((F.silu(self.w1(x).float()) * self.w3(x).float()).type_as(x))
class MoE(nn.Module):
"""
Mixture-of-Experts (MoE) module.
Attributes:
dim (int): Dimensionality of input features.
n_routed_experts (int): Total number of experts in the model.
n_local_experts (int): Number of experts handled locally in distributed systems.
n_activated_experts (int): Number of experts activated for each input.
gate (nn.Module): Gating mechanism to route inputs to experts.
experts (nn.ModuleList): List of expert modules.
shared_experts (nn.Module): Shared experts applied to all inputs.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the MoE module.
Args:
args (ModelArgs): Model arguments containing MoE parameters.
"""
super().__init__()
self.dim = args.dim
assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
self.n_routed_experts = args.n_routed_experts
self.n_local_experts = args.n_routed_experts // world_size
self.n_activated_experts = args.n_activated_experts
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(args)
self.experts = nn.ModuleList([
Expert(args.dim, args.moe_inter_dim)
if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)
])
self.shared_experts = MLP(
args.dim, args.n_shared_experts * args.moe_inter_dim, reduce_output=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the MoE module.
Args:
x (torch.Tensor): Input tensor.
Returns:
torch.Tensor: Output tensor after expert routing and computation.
"""
shape = x.size()
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x, dtype=torch.float32)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
expert = self.experts[i]
idx, top = torch.where(indices == i)
y[idx] += expert(x[idx]) * weights[idx, top, None]
y += self.shared_experts(x)
if world_size > 1:
dist.all_reduce(y)
return y.type_as(x).view(shape)
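# A minimal sketch (hypothetical helper, not used by the model): how the dispatch loop above
# finds, for one expert, which flattened tokens routed to it and at which top-k slot, so the
# expert output can be scaled by the matching routing weight.
def _example_tokens_for_expert(indices: torch.Tensor, weights: torch.Tensor, expert_id: int):
    token_idx, slot_idx = torch.where(indices == expert_id)
    return token_idx, weights[token_idx, slot_idx]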
class Block(nn.Module):
"""
Transformer block combining attention and feed-forward layers.
Attributes:
attn (nn.Module): Attention layer (MLA).
ffn (nn.Module): Feed-forward network (MLP or MoE).
attn_norm (nn.Module): Layer normalization for attention.
ffn_norm (nn.Module): Layer normalization for feed-forward network.
"""
def __init__(self, layer_id: int, args: ModelArgs):
"""
Initializes the Transformer block.
Args:
layer_id (int): Layer index in the transformer.
args (ModelArgs): Model arguments containing block parameters.
"""
super().__init__()
self.attn = MLA(args)
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
self.attn_norm = RMSNorm(args.dim)
self.ffn_norm = RMSNorm(args.dim)
def forward(self, x: torch.Tensor, residual: torch.Tensor, start_pos: int,
freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
"""
Forward pass for the Transformer block.
Args:
            x (torch.Tensor): Input tensor.
            residual (torch.Tensor): Residual stream carried between blocks (None for the first block).
            start_pos (int): Starting position in the sequence.
            freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
            mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Block output and the updated residual stream.
"""
if residual is None:
x, residual = self.attn_norm(x), x
else:
x, residual = self.attn_norm(x, residual)
x = self.attn(x, start_pos, freqs_cis, mask)
x, residual = self.ffn_norm(x, residual)
x = self.ffn(x)
return x, residual
class Transformer(nn.Module):
"""
Transformer model with positional embeddings, multiple layers, and output projection.
Attributes:
max_seq_len (int): Maximum sequence length for the transformer.
embed (nn.Module): Embedding layer for input tokens.
layers (torch.nn.ModuleList): List of transformer blocks.
norm (nn.Module): Layer normalization applied after all blocks.
head (nn.Module): Output projection layer mapping to vocabulary size.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Transformer model.
Args:
args (ModelArgs): Model arguments containing transformer parameters.
"""
global world_size, rank
world_size = dist.get_world_size() if dist.is_initialized() else 1
rank = dist.get_rank() if dist.is_initialized() else 0
Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
Linear.scale_fmt = args.scale_fmt
super().__init__()
self.max_seq_len = args.max_seq_len
self.embed = ParallelEmbedding(args.vocab_size, args.dim)
self.layers = torch.nn.ModuleList()
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
self.norm = RMSNorm(args.dim)
# lm_head in the checkpoint is stored in bf16, while the parameter here is stored in fp32 for easier computation of logits later.
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.float32)
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int = 0):
"""
Forward pass for the Transformer model.
Args:
tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len).
start_pos (int, optional): Starting position in the sequence for rotary embeddings. Defaults to 0.
Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
seqlen = tokens.size(1)
freqs_cis = self.freqs_cis[start_pos:start_pos + seqlen]
mask = torch.full(
(seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) if seqlen > 1 else None
h, residual = self.embed(tokens), None
for layer in self.layers:
h, residual = layer(h, residual, start_pos, freqs_cis, mask)
h, _ = self.norm(h, residual)
logits = self.head(h[:, -1].float())
if world_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(world_size)]
dist.all_gather(all_logits, logits)
logits = torch.cat(all_logits, dim=-1)
return logits
if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.manual_seed(0)
args = ModelArgs()
x = torch.randint(0, args.vocab_size, (2, 128))
model = Transformer(args)
print(model(x).size())
torch
transformers
safetensors
fast_hadamard_transform
tilelang==0.1.6
# ruff: noqa
import tilelang
from tilelang import language as T
import torch
from utils import assert_tensors_similar
@tilelang.jit(out_idx=[-1])
def preprocess(
B,
S,
H,
D,
block_ND=32,
num_stages=5,
dtype="bfloat16",
accum_dtype="float",
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
shape = [B, S, H, D]
@T.prim_func
def preprocess_kernel(
O: T.Tensor(shape, dtype),
dO: T.Tensor(shape, dtype),
Delta: T.Tensor([B, S, H], accum_dtype),
):
with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz):
o = T.alloc_fragment([block_ND, block_ND], accum_dtype)
do = T.alloc_fragment([block_ND, block_ND], accum_dtype)
delta = T.alloc_fragment([block_ND], accum_dtype)
acc = T.alloc_fragment([block_ND, block_ND], accum_dtype)
T.clear(acc)
for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages):
T.copy(
O[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND],
o)
T.copy(
dO[bz, by * block_ND:(by + 1) * block_ND, bx, k * block_ND:(k + 1) * block_ND],
do)
for i, j in T.Parallel(block_ND, block_ND):
acc[i, j] += o[i, j] * do[i, j]
T.reduce_sum(acc, delta, 1)
T.copy(delta, Delta[bz, by * block_ND:(by + 1) * block_ND, bx])
return preprocess_kernel
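# A minimal PyTorch reference (hypothetical helper, not used by the kernels): what
# preprocess_kernel produces, i.e. Delta[b, s, h] = sum_d O[b, s, h, d] * dO[b, s, h, d],
# the per-row dot product used by the FlashAttention-style backward pass.
def _ref_preprocess(o: torch.Tensor, do: torch.Tensor) -> torch.Tensor:
    # o, do: (B, S, H, D) -> (B, S, H) in float32
    return (o.float() * do.float()).sum(dim=-1)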
@tilelang.jit(out_idx=[-1])
def postprocess(
B,
S_kv,
D,
D_tail,
kv_group=1,
block_N=64,
threads=128,
dtype="bfloat16",
accum_dtype="float",
):
assert dtype == "bfloat16"
assert accum_dtype == "float"
dkv_shape = [B, S_kv, kv_group, D + D_tail]
@T.prim_func
def postprocess_kernel(
dKV: T.Tensor(dkv_shape, accum_dtype),
dKV_out: T.Tensor(dkv_shape, dtype),
):
with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as (bx, by, bz):
T.copy(
dKV[bz, bx * block_N:(bx + 1) * block_N, by, :],
dKV_out[bz, bx * block_N:(bx + 1) * block_N, by, :],
)
return postprocess_kernel
@tilelang.jit(
out_idx=[-2],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
})
def bwd(
B,
S,
S_kv,
H,
D,
D_tail,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
block_size=32,
num_stages=0,
threads=256,
indices_dtype="int32",
dtype="bfloat16",
accum_dtype="float",
):
    assert is_causal, 'non-causal attention is not supported yet'
    assert topk % block_size == 0, 'topk must be a multiple of block_size; otherwise padded index-0 entries would load the wrong kv'
assert dtype == "bfloat16"
assert accum_dtype == "float"
assert indices_dtype == "int32"
if sm_scale is None:
sm_scale = (D + D_tail)**(-0.5)
sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e)
H_kv = H // kv_group
q_shape = [B, S, H, D + D_tail]
k_shape = [B, S_kv, kv_group, D + D_tail]
o_shape = [B, S, H, D]
indices_shape = [B, S, kv_group, topk]
delta_shape = [B, S, H]
lse_shape = [B, S, H]
assert indices_dtype == "int32"
assert dtype == "bfloat16"
assert accum_dtype == "float"
H = H_kv
padded_H = max(tilelang.math.next_power_of_2(H_kv), 16)
BS = block_size
NS = tilelang.cdiv(topk, block_size)
split_store = 2
@T.prim_func
def sparse_mla_bwd_kernel(
Q: T.Tensor(q_shape, dtype),
KV: T.Tensor(k_shape, dtype),
dO: T.Tensor(o_shape, dtype),
Indices: T.Tensor(indices_shape, indices_dtype),
Lse: T.Tensor(lse_shape, accum_dtype),
Delta: T.Tensor(delta_shape, accum_dtype),
dQ: T.Tensor(q_shape, dtype),
dKV: T.Tensor(k_shape, accum_dtype),
):
with T.Kernel(S, B, kv_group, threads=threads) as (s_i, by, bz):
Q_shared = T.alloc_shared([padded_H, D], dtype)
Q_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
KV_shared = T.alloc_shared([BS, D], dtype)
KV_tail_shared = T.alloc_shared([BS, D_tail], dtype)
dO_shared = T.alloc_shared([padded_H, D], dtype)
mask = T.alloc_fragment([BS], "bool")
P_shared_cast = T.alloc_shared([padded_H, BS], dtype)
dP_shared_cast = T.alloc_shared([padded_H, BS], dtype)
dQ_shared = T.alloc_shared([padded_H, D], dtype)
dQ_tail_shared = T.alloc_shared([padded_H, D_tail], dtype)
acc_p = T.alloc_fragment([padded_H, BS], accum_dtype)
acc_dp = T.alloc_fragment([padded_H, BS], accum_dtype)
acc_dq = T.alloc_fragment([padded_H, D], accum_dtype)
acc_dq_tail = T.alloc_fragment([padded_H, D_tail], accum_dtype)
acc_dkv = T.alloc_fragment([BS, D], accum_dtype)
acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype)
acc_dkv_shared = T.view(KV_shared, shape=[BS // split_store, D], dtype=accum_dtype)
acc_dkv_tail_shared = T.view(
KV_tail_shared, shape=[BS // split_store, D_tail], dtype=accum_dtype)
max_kv_i = s_i
T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], Q_shared)
T.copy(Q[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:], Q_tail_shared)
T.copy(dO[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D], dO_shared)
T.clear(acc_dq)
T.clear(acc_dq_tail)
T.annotate_layout({
dQ_shared: tilelang.layout.make_swizzled_layout(dQ_shared),
dQ_tail_shared: tilelang.layout.make_swizzled_layout(dQ_tail_shared),
})
# Process each block of indices
for i_i in T.Pipelined(NS, num_stages=num_stages):
# Check which indices are valid
for bi_i in T.Parallel(BS):
mask[bi_i] = Indices[by, s_i, bz, i_i * BS + bi_i] <= max_kv_i
# Compute attention scores
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_p.dtype))
# Load KV, V for this block of indices
for bi_i, d_i in T.Parallel(BS, D):
KV_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz, d_i]
T.gemm(
Q_shared, KV_shared, acc_p, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
for bi_i, d_i in T.Parallel(BS, D_tail):
KV_tail_shared[bi_i, d_i] = KV[by, Indices[by, s_i, bz, i_i * BS + bi_i], bz,
D + d_i]
T.gemm(
Q_tail_shared,
KV_tail_shared,
acc_p,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_p[h_i, bi_i] = T.exp2(acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 -
Lse[by, s_i, bz * padded_H + h_i])
T.copy(acc_p, P_shared_cast)
T.gemm(
dO_shared,
KV_shared,
acc_dp,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
for h_i, bi_i in T.Parallel(padded_H, BS):
acc_dp[h_i, bi_i] = acc_p[h_i, bi_i] * (
acc_dp[h_i, bi_i] - Delta[by, s_i, bz * padded_H + h_i]) * sm_scale
T.copy(acc_dp, dP_shared_cast)
T.gemm(dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol)
T.gemm(dP_shared_cast, KV_tail_shared, acc_dq_tail, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
dP_shared_cast,
Q_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol,
clear_accum=True)
T.gemm(
P_shared_cast,
dO_shared,
acc_dkv,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
T.clear(acc_dkv_tail)
T.gemm(
dP_shared_cast,
Q_tail_shared,
acc_dkv_tail,
transpose_A=True,
policy=T.GemmWarpPolicy.FullCol)
for s in range(split_store):
for bi_i, d_i in T.Parallel(BS, D):
if bi_i < BS // split_store:
acc_dkv_shared[bi_i, d_i] = acc_dkv[bi_i + s * (BS // split_store), d_i]
for bi_i, d_i in T.Parallel(BS, D_tail):
if bi_i < BS // split_store:
acc_dkv_tail_shared[bi_i,
d_i] = acc_dkv_tail[bi_i + s * (BS // split_store),
d_i]
for bi_i, d_i in T.Parallel(BS // split_store, D // 4):
T.atomic_addx4(
dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)],
bz, d_i * 4], acc_dkv_shared[bi_i, d_i * 4])
# Atomically update dKV, dKV_tail tensors
for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4):
T.atomic_addx4(
dKV[by, Indices[by, s_i, bz, i_i * BS + bi_i + s * (BS // split_store)],
bz, D + d_i * 4], acc_dkv_tail_shared[bi_i, d_i * 4])
# Store the accumulated dQ
T.copy(acc_dq, dQ_shared)
T.copy(acc_dq_tail, dQ_tail_shared)
T.copy(dQ_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, :D])
T.copy(dQ_tail_shared, dQ[by, s_i, bz * padded_H:(bz + 1) * padded_H, D:])
return sparse_mla_bwd_kernel
def sparse_mla_bwd(q,
kv,
o,
do,
indices,
lse,
sm_scale=None,
is_casual=True,
return_kernel=False,
delta=None):
assert q.is_contiguous()
assert kv.is_contiguous()
assert indices.is_contiguous()
assert lse.is_contiguous()
B, S, H, dim_plus_tail_dim = q.shape
_, S_kv, kv_group, _ = kv.shape
assert kv.shape[-1] == dim_plus_tail_dim
assert kv.shape[0] == B
    # the value head dim is hard-coded to 512; the remaining channels form the RoPE tail
D = 512
D_tail = dim_plus_tail_dim - D
topk = indices.shape[-1]
assert indices.shape == (B, S, kv_group, topk)
assert lse.shape == (B, S, H)
# Get kernels
preprocess_kernel = preprocess(B, S, H, D)
bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual)
postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group)
if delta is None:
delta = preprocess_kernel(o, do)
dkv = torch.zeros_like(kv, dtype=torch.float32)
dq = bwd_kernel(q, kv, do, indices, lse, delta, dkv)
dkv = postprocess_kernel(dkv)
return dq, dkv
def ref_sparse_mla_bwd_interface(q, kv, o, do, indices, lse, sm_scale=None, is_casual=True):
from sparse_mla_fwd import ref_sparse_mla_fwd_interface
q = q.detach().clone()
kv = kv.detach().clone()
q.requires_grad = True
kv.requires_grad = True
o = ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale, is_casual)
o.backward(do)
return q.grad, kv.grad
def test_sparse_mla_bwd(B=1,
S=4096,
SKV=8192,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True):
# Prepare data
q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, device='cuda').requires_grad_(True)
do = torch.randn((B, S, H, DV), dtype=dtype, device='cuda')
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda')
for b in range(B):
for t in range(S):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[b, t, h, :len(i_i)] = i_i
# Forward
from sparse_mla_fwd import sparse_mla_fwd_interface
tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
tl_dq, tl_dkv = sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse)
ref_dq, ref_dkv = ref_sparse_mla_bwd_interface(q, kv, None, do, indices, None)
if check_correctness:
assert_tensors_similar(tl_dq, ref_dq, eps=1e-4, name="dq")
assert_tensors_similar(tl_dkv, ref_dkv, eps=1e-4, name="dkv")
print("assert_tensors_similar passed")
per_token_flop = 2 * sum([
H * DV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DQKV * topk,
H * DV * topk,
])
from tilelang.profiler import do_bench
def fn():
return sparse_mla_bwd(q, kv, tl_out, do, indices, tl_lse)
ms = do_bench(fn, rep=100, warmup=250)
print(f"Average time: {ms:.3f} ms")
    print("bwd io bandwidth =",
          (B * S * max(DQKV * 2, DQKV + DV) * topk * 2) / (ms * 1e-3) / 1e12)
    print("bwd tflops =", per_token_flop * S / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_bwd(
B=1,
S=4096,
SKV=8192,
H=64,
HKV=1,
DQKV=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True)
# ruff: noqa
import torch
import tilelang
from tilelang import language as T
from utils import assert_tensors_similar
@tilelang.jit(
out_idx=[-2, -1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
},
)
def sparse_mla_fwd(
heads,
dim,
tail_dim,
topk,
kv_group=1,
sm_scale=None,
is_causal=True,
CP0=True,
block_I=64,
num_stages=2,
threads=256,
):
assert dim == tilelang.math.next_power_of_2(
dim), f"haven't check padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(
tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
    assert is_causal, "non-causal attention is not supported"
    assert topk % block_I == 0, "topk must be a multiple of block_I; otherwise padded index-0 entries would load the wrong kv"
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e)
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
batch = T.dynamic("batch")
seq_len = T.dynamic("seq_len")
seq_len_kv = T.dynamic("seq_len_kv")
head_kv = heads // kv_group
q_shape = [batch, seq_len, heads, dim + tail_dim]
kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
o_shape = [batch, seq_len, heads, dim]
indices_shape = [batch, seq_len, kv_group, topk]
lse_shape = [batch, seq_len, heads]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1, (
            "H padding is handled automatically only when kv_group == 1; otherwise you must "
            "mask the Q and Output copies yourself (with kv_group == 1 the "
            "g_i * padded_H:(g_i + 1) * padded_H slices are handled automatically)")
BI = block_I
NI = tilelang.cdiv(topk, block_I)
D = dim
D_tail = tail_dim
if head_kv > 64:
assert head_kv % 64 == 0, "head_kv should be a multiple of 64"
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
):
with T.Kernel(
seq_len * REPLICATE_H, batch, kv_group, threads=threads) as (
bx,
by,
bz,
):
Q_shared = T.alloc_shared([H_per_block, D], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared = T.alloc_shared([BI, D], dtype)
K_tail_shared = T.alloc_shared([BI, D_tail], dtype)
O_shared = T.alloc_shared([H_per_block, D], dtype)
Lse_shared = T.alloc_shared([H_per_block], accum_dtype)
mask = T.alloc_fragment([BI], "bool")
acc_o = T.alloc_fragment([H_per_block, D], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
T.fill(acc_o, 0)
T.fill(sumexp, 0)
            T.fill(m_i, -(2**30))  # large negative instead of -inf, so m_i_prev - m_i never yields NaN
b_i, g_i = by, bz
s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
q_i = s_i
max_kv_i = q_i
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared)
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
for i_i in T.Pipelined(NI, num_stages=num_stages):
for bi_i in T.Parallel(BI):
mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
for bi_i, d_i in T.Parallel(BI, D):
KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
d_i]
for bi_i, d_i in T.Parallel(BI, D_tail):
K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i,
D + d_i]
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype))
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.gemm(
Q_tail_shared,
K_tail_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullRow,
)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
                T.reduce_sum(acc_s, sumexp_i, dim=1)  # per-block row sums; folded into the running sumexp below
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
# Rescale
for h_i, d_i in T.Parallel(H_per_block, D):
acc_o[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o, O_shared)
T.copy(acc_o, Output[b_i, s_i, H0:H1, :])
T.copy(sumexp, Lse_shared)
T.copy(sumexp, Lse[b_i, s_i, H0:H1])
return main
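# A minimal sketch (hypothetical helper, not used by the kernel): the kernel above keeps its
# softmax statistics in base 2 (sm_scale is pre-multiplied by log2(e)), so the stored Lse is
# the base-2 log-sum-exp of the scaled scores over the selected, unmasked positions.
def _example_base2_lse(scores: torch.Tensor, sm_scale: float) -> torch.Tensor:
    # scores: (..., topk) raw q.k values; returns log2(sum_j exp(sm_scale * s_j))
    return torch.logsumexp(scores.float() * sm_scale, dim=-1) * 1.44269504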
def sparse_mla_fwd_interface(q,
kv,
indices,
sm_scale=None,
return_p_sum: bool = False,
d_v=512,
block_I=64,
num_stages=2,
threads=256):
is_casual = True
assert return_p_sum == False, "This kernel file is for fwd only"
assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
batch, seq_len, heads, dim_plus_tail_dim = q.shape
_, seq_len_kv, kv_group, _ = kv.shape
    assert dim_plus_tail_dim == 576, "expected head dim 576 (d_v + RoPE tail); adjust d_v for other layouts"
dim = d_v
assert kv.shape[-1] == dim_plus_tail_dim
tail_dim = dim_plus_tail_dim - dim
assert kv.shape[0] == batch
_, _, _, topk = indices.shape
assert indices.shape == (batch, seq_len, kv_group, topk)
kernel = sparse_mla_fwd(
heads,
dim,
tail_dim,
topk,
kv_group,
sm_scale,
is_casual,
block_I=block_I,
num_stages=num_stages,
threads=threads)
out, lse = kernel(q, kv, indices)
return out, lse
def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True):
q = q.float()
kv = kv.float()
indices = indices.transpose(1, 2)
b, sq, h, dim_q = q.shape
b, sk, g, _ = kv.shape
    assert kv.shape[-1] == 576, "expected head dim 576; adjust the hard-coded dim below for other layouts"
dim = 512
k = kv
v = kv[..., :dim]
b, _, _, dim_v = v.shape
g_index = g
h_index = h // g
compressed_casual_mask = torch.arange(
0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(
1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1)
mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
mask = mask[..., :-1]
mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
mask[:, :, :1 - 1, 0] = True
mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q)
score = torch.einsum("bmghd,bngd->bghmn", q, k)
sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
p = score.softmax(dim=-1)
p = p.view(b, g_index, h_index, -1, sq, sk)
p = p.view(b, g, -1, sq, sk)
o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
o = o.reshape(b, sq, h, dim_v)
return o.to(torch.bfloat16)
def test_sparse_mla_fwd(B=1,
S=4096,
SKV=8192,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256):
torch.random.manual_seed(0)
q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
for b in range(B):
for t in range(S):
for h in range(HKV):
i_i = torch.randperm(max(1, t))[:topk]
indices[b, t, h, :len(i_i)] = i_i
tl_out, tl_lse = sparse_mla_fwd_interface(
q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
if check_correctness:
# otherwise may cause out of memory
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices)
assert_tensors_similar(tl_out, ref_out, eps=1e-2, name="out")
print("assert_tensors_similar passed")
def fn():
return sparse_mla_fwd_interface(
q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
from tilelang.profiler import do_bench
ms = do_bench(
fn,
rep=100,
warmup=250,
)
print(f"Average time: {ms:.3f} ms")
print("fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
test_sparse_mla_fwd(
B=1,
S=4096,
SKV=4096,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
check_correctness=True,
block_I=64,
num_stages=2,
threads=256)
# ruff: noqa
import torch
import tilelang
from tilelang import language as T
from tilelang.engine.callback import register_cuda_postproc_callback
import argparse
@tilelang.jit(
out_idx=[-2, -1],
compile_flags=[
"-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda",
"--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG"
],
)
def sparse_mla_fwd(
batch,
seq_len,
seq_len_kv,
heads,
dim,
tail_dim,
topk,
kv_stride,
kv_group=1,
sm_scale=None,
is_causal=True,
CP0=True,
block_I=64,
num_stages=0,
threads=384,
):
assert dim == tilelang.math.next_power_of_2(
dim), f"haven't check padding correctness yet, dim={dim}"
assert tail_dim == tilelang.math.next_power_of_2(
tail_dim), f"haven't check padding correctness yet, dim={tail_dim}"
    assert is_causal, 'non-causal attention is not supported'
    assert topk % block_I == 0, 'topk must be a multiple of block_I; otherwise padded index-0 entries would load the wrong kv'
if sm_scale is None:
sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e)
else:
sm_scale = sm_scale * 1.44269504 # log2(e)
head_kv = heads // kv_group
q_shape = [batch, seq_len, heads, dim + tail_dim]
kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim]
o_shape = [batch, seq_len, heads, dim]
indices_shape = [batch, seq_len, kv_group, topk]
lse_shape = [batch, seq_len, heads]
indices_dtype = "int32"
dtype = "bfloat16"
accum_dtype = "float"
G = kv_group
H = head_kv
padded_H = max(tilelang.math.next_power_of_2(head_kv), 16)
    if padded_H != H:
        assert kv_group == 1, ('H padding is handled automatically only when kv_group == 1; '
                               'otherwise you must mask the Q and Output copies yourself '
                               '(with kv_group == 1 the g_i * padded_H:(g_i + 1) * padded_H '
                               'slices are handled automatically)')
BI = block_I
NI = tilelang.cdiv(topk, block_I)
assert NI % 2 == 0, 'NI should be a multiple of 2'
D = dim
D_tail = tail_dim
KV_stride = kv_stride
if head_kv > 64:
assert head_kv % 64 == 0, 'head_kv should be a multiple of 64'
REPLICATE_H = head_kv // 64
else:
REPLICATE_H = 1
H_per_block = padded_H if REPLICATE_H == 1 else 64
@T.prim_func
def main(
Q: T.Tensor(q_shape, dtype), # type: ignore
KV: T.Tensor(kv_shape, dtype), # type: ignore
Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore
q_start_index_s: T.Tensor(1, indices_dtype),
Output: T.Tensor(o_shape, dtype), # type: ignore
Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore
):
with T.Kernel(
(seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H,
batch,
kv_group,
threads=threads) as (bx, by, bz):
Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype)
Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype)
Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype)
KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype)
KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype)
KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype)
KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype)
K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype)
K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype)
O_shared_l = Q_shared_l
O_shared_r = Q_shared_r
is_kv_valid = T.alloc_shared([BI], "bool", scope="shared")
acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype)
acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype)
S_shared = T.alloc_shared([H_per_block, BI], dtype)
sumexp = T.alloc_fragment([H_per_block], accum_dtype)
sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype)
sumexp_i = T.alloc_fragment([H_per_block], accum_dtype)
alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared")
alpha_local = T.alloc_fragment([H_per_block], accum_dtype)
m_i = T.alloc_fragment([H_per_block], accum_dtype)
m_i_prev = T.alloc_fragment([H_per_block], accum_dtype)
indices_local = T.alloc_local([1], indices_dtype)
# TODO: Multi buffer
bar_q = T.alloc_barrier(arrive_count=384)
bar_k_0_ready = T.alloc_barrier(arrive_count=128)
bar_k_1_ready = T.alloc_barrier(arrive_count=128)
bar_k_0_free = T.alloc_barrier(arrive_count=256)
bar_k_1_free = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256)
bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256)
b_i, g_i = by, bz
s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (
bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0))
q_i = q_start_index_s[0] + s_i
max_kv_i = (q_i + 1 - KV_stride) // KV_stride
H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64)
H1 = H0 + H_per_block
tx = T.get_thread_binding()
T.copy(Q[b_i, s_i, H0:H1, 0:D // 2], Q_shared_l)
T.copy(Q[b_i, s_i, H0:H1, D // 2:D], Q_shared_r)
T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared)
T.barrier_arrive(bar_q)
if tx < 128:
T.set_max_nreg(240, 1)
T.fill(sumexp, 0)
                T.fill(m_i, -2**30)  # large negative instead of -inf, so m_i_prev - m_i never yields NaN
T.fill(acc_o_l, 0)
T.barrier_wait(bar_q, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_ready[0], (i_i & 1))
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0,
-T.infinity(acc_s.dtype))
T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1)
T.wait_wgmma(0)
if i_i != 0:
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
                    T.reduce_sum(acc_s, sumexp_i, dim=1)  # per-block row sums; folded into the running sumexp below
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_0_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_0_free[0])
# Buffer 1
T.barrier_wait(bar_k_1_ready[0], (i_i & 1))
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0,
-T.infinity(acc_s.dtype))
T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1)
T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1)
T.wait_wgmma(0)
T.barrier_arrive(bar_sScale_and_sS_free)
T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1)
T.copy(m_i, m_i_prev)
T.reduce_max(acc_s, m_i, dim=1, clear=False)
for h_i in T.Parallel(H_per_block):
alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale)
for h_i, bi_i in T.Parallel(H_per_block, BI):
acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale)
                    T.reduce_sum(acc_s, sumexp_i, dim=1)  # per-block row sums; folded into the running sumexp below
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i]
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] *= alpha_local[h_i]
T.copy(alpha_local, alpha_shared)
T.copy(acc_s, S_shared)
T.gemm(S_shared, KV_shared_1_l, acc_o_l)
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_arrive(bar_k_1_free[0])
# Rescale
for h_i in T.Parallel(H_per_block):
sum_exp_shared[h_i] = sumexp[h_i]
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_l[h_i, d_i] /= sumexp[h_i]
for h_i in T.Parallel(H_per_block):
sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale
T.copy(acc_o_l, O_shared_l)
T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0:D // 2])
elif tx >= 128 and tx < 256:
T.set_max_nreg(168, 1)
T.fill(acc_o_r, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1))
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_0_r, acc_o_r)
T.barrier_arrive(bar_k_0_free[0])
T.barrier_arrive(bar_sScale_and_sS_free)
# Buffer 1
T.barrier_arrive(bar_sScale_and_sS_ready)
T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1))
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] *= alpha_shared[h_i]
T.gemm(S_shared, KV_shared_1_r, acc_o_r)
T.barrier_arrive(bar_k_1_free[0])
if i_i != T.ceildiv(NI, 2) - 1:
T.barrier_arrive(bar_sScale_and_sS_free)
# Rescale
for h_i, d_i in T.Parallel(H_per_block, D // 2):
acc_o_r[h_i, d_i] /= sum_exp_shared[h_i]
T.copy(acc_o_r, O_shared_r)
T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2:D])
elif tx >= 256:
# producer
T.set_max_nreg(80, 0)
for i_i in T.serial(T.ceildiv(NI, 2)):
# Buffer 0
T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
indices_local[0] = Indices[b_i, s_i, g_i,
(i_i * 2) * BI + r * 16 + (tx - 256) // 8]
is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i
if is_kv_valid[r * 16 + (tx - 256) // 8]:
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_0_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i,
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_0_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i, D // 2 +
64 * u + (tx - 256) % 8 * 8 + v]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i,
D + (tx - 256) % 8 * 8 + v]
T.cp_async_barrier_noinc(bar_k_0_ready[0])
# Buffer 1
T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
for r in T.serial(4):
indices_local[0] = Indices[b_i, s_i, g_i,
(i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8]
is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i
if is_kv_valid[r * 16 + (tx - 256) // 8]:
with T.attr("default", "async_scope", 1):
for u in T.serial(4):
for v in T.vectorized(8):
KV_shared_1_l[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i,
64 * u + (tx - 256) % 8 * 8 + v]
KV_shared_1_r[r * 16 + (tx - 256) // 8,
64 * u + (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i, D // 2 +
64 * u + (tx - 256) % 8 * 8 + v]
with T.attr("default", "async_scope", 1):
for v in T.vectorized(8):
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 +
v] = KV[b_i, indices_local[0], g_i,
D + (tx - 256) % 8 * 8 + v]
T.cp_async_barrier_noinc(bar_k_1_ready[0])
return main
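# Note on the kernel above: the 384 threads form three warp groups. Threads 0-127 compute the
# attention scores, the online-softmax statistics, and the left half of the output; threads
# 128-255 reuse the shared S tile and alpha factors to accumulate the right half; threads
# 256-383 act as the producer, gathering KV rows into the two double-buffered shared stages
# guarded by the barriers declared above.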
def sparse_mla_fwd_interface(q,
kv,
indices,
q_start_index_s,
kv_stride,
sm_scale=None,
is_casual=True,
return_kernel=False,
print_kernel=False):
assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
batch, seq_len, heads, dim_plus_tail_dim = q.shape
_, seq_len_kv, kv_group, _ = kv.shape
    assert dim_plus_tail_dim == 576, 'expected head dim 576; set dim explicitly for other layouts'
dim = 512
assert kv.shape[-1] == dim_plus_tail_dim
tail_dim = dim_plus_tail_dim - dim
assert kv.shape[0] == batch
_, _, _, topk = indices.shape
assert indices.shape == (batch, seq_len, kv_group, topk)
if q_start_index_s != 0:
        assert q_start_index_s > kv_stride, (
            "q_start_index_s must exceed kv_stride; if each context-parallel rank only holds a "
            "very short chunk, fix the CP0 (cp_rank == 0) logic so that queries with "
            "pos < kv_stride - 1 are masked (or ignore this if NaNs in those queries' outputs "
            "cannot affect other positions, which is reported to be the likely case)")
CP0 = q_start_index_s == 0
kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride,
kv_group, sm_scale, is_casual, CP0)
if print_kernel:
print(kernel.get_kernel_source())
out, lse = kernel(q, kv, indices,
torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda"))
if return_kernel:
return kernel
if q_start_index_s == 0 and kv_stride > 1:
out[:, :kv_stride - 1, :, :] = 0
return out, lse
def ref_sparse_mla_fwd_interface(q,
kv,
indices,
q_start_index_s,
kv_stride=4,
sm_scale=None,
is_casual=True):
q = q.float()
kv = kv.float()
indices = indices.transpose(1, 2)
b, sq, h, dim_q = q.shape
b, sk, g, _ = kv.shape
if q_start_index_s is None:
q_start_index_s = sk * kv_stride - sq
    assert kv.shape[-1] == 576, 'expected head dim 576; adjust the hard-coded dim below for other layouts'
dim = 512
k = kv
v = kv[..., :dim]
b, _, _, dim_v = v.shape
num_kv_per_index = 1
g_index = g
h_index = h // g
compressed_casual_mask = torch.arange(
q_start_index_s, sq + q_start_index_s, dtype=torch.int32,
device="cuda").view(-1, 1) >= torch.arange(
kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1)
mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
mask = mask[..., :-1]
mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
mask[:, :, :kv_stride - 1, 0] = True
mask = mask.view(b, g_index, 1, sq, sk)
q = q.view(b, sq, g, -1, dim_q)
score = torch.einsum("bmghd,bngd->bghmn", q, k)
sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
p = score.softmax(dim=-1)
p = p.view(b, g_index, h_index, -1, sq, sk)
p = p.view(b, g, -1, sq, sk)
o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
o = o.reshape(b, sq, h, dim_v)
return o.to(torch.bfloat16)
def test_sparse_mla_fwd_pipelined(B=1,
S=4096,
SKV=8192,
H=128,
HKV=1,
DQK=576,
DV=512,
topk=2048,
dtype=torch.bfloat16,
q_start_s_index=1024,
check_correctness=True):
KV_stride = 1
torch.random.manual_seed(0)
q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10
kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10
q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda")
q.clamp_(-10, 10)
kv.clamp_(-10, 10)
indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda')
for b in range(B):
for t in range(S):
for h in range(HKV):
i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk]
indices[b, t, h, :len(i_i)] = i_i
kernel = sparse_mla_fwd_interface(
q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True)
def fn():
out, lse = kernel(q, kv, indices, q_start_s_index_t)
if q_start_s_index == 0 and KV_stride > 1:
out[:, :KV_stride - 1, :, :] = 0
return out, lse
tl_out, tl_lse = fn()
ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride)
# print(f"tl_out: {tl_out}")
# print(f"ref_out: {ref_out}")
torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3)
from tilelang.profiler import do_bench
ms = do_bench(
fn,
rep=10,
warmup=10,
)
print(f"Average time: {ms:.3f} ms")
    print("fwd io bandwidth =", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12)
    print("fwd tflops =", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--test_correctness", action="store_true")
args = parser.parse_args()
if args.test_correctness:
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
else:
B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16
test_sparse_mla_fwd_pipelined(
B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=args.test_correctness)
# ruff: noqa
import tilelang.testing
from topk_selector import test_topk_selector
from fp8_lighting_indexer import test_fp8_lighting_indexer
from sparse_mla_fwd import test_sparse_mla_fwd
from sparse_mla_fwd_pipelined import test_sparse_mla_fwd_pipelined
from sparse_mla_bwd import test_sparse_mla_bwd
def test_example_topk_selector():
test_topk_selector()
def test_example_fp8_lighting_indexer():
test_fp8_lighting_indexer(S=1024, SKV=2048, H=32, HKV=1, D=64, kv_stride=1)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd():
# small shapes for testing
test_sparse_mla_fwd(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_fwd_pipelined():
# small shapes for testing
test_sparse_mla_fwd_pipelined(
S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
def test_example_sparse_mla_bwd():
test_sparse_mla_bwd(
S=256, SKV=1024, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False)
if __name__ == "__main__":
tilelang.testing.main()