import torch
import tilelang
import tilelang.language as T
from typing import Optional, Tuple, Union

tilelang.set_log_level("WARNING")

pass_configs = {
    tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True,
}

# TileLang dtype strings used by the kernels below.
# Note: E8M0 (MX) scales are not given a TileLang dtype here; the host wrappers
# convert them to/from FP32 around the kernels instead.
FP8 = "float8_e4m3"
FP4 = "float4_e2m1fn"
BF16 = "bfloat16"
FP32 = "float32"
INT32 = "int32"


def fast_log2_ceil(x):
    """Compute ceil(log2(x)) via IEEE 754 bit manipulation. Avoids slow log/ceil intrinsics.

    The unbiased exponent field gives floor(log2(x)). Any non-zero mantissa bit
    means x is not an exact power of two, so 1 is added. Valid for positive,
    normal FP32 inputs: the sign bit is ignored and subnormals/inf/NaN are not
    special-cased.
    """
    bits_x = T.reinterpret("uint32", x)
    exp_x = (bits_x >> 23) & 0xFF
    man_bits = bits_x & ((1 << 23) - 1)
    return T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))


def fast_pow2(x):
    """Compute 2^x for integer x via IEEE 754 bit manipulation.

    Writes the biased exponent (x + 127) directly into the FP32 exponent field
    with a zero mantissa. Valid for x in [-126, 127] (the normal FP32 range).
    """
    bits_x = (x + 127) << 23
    return T.reinterpret("float32", bits_x)


def fast_round_scale(amax, fp8_max_inv):
    """Return amax * fp8_max_inv rounded up to the next power of two (MXFP-style scale)."""
    return fast_pow2(fast_log2_ceil(amax * fp8_max_inv))


@tilelang.jit(pass_configs=pass_configs)
def act_quant_kernel(
    N, block_size=128, in_dtype=BF16, out_dtype=FP8, scale_dtype=FP32, round_scale=False, inplace=False
):
    """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.

    Each (row, ``block_size``-column group) of X gets one scale
    ``s = max(amax, 1e-4) / 448``, optionally rounded up to a power of two when
    ``round_scale`` is set. The kernel then stores ``clamp(X / s, -448, 448)``
    in ``out_dtype``. With ``inplace=True``, that value is cast to the FP8
    dtype, widened back to FP32, multiplied by ``s`` and stored in ``in_dtype``
    (a quantize-dequantize round trip), so Y has the input dtype.

    Args:
        N: Number of columns, fixed at compile time. The row count M is declared
            with ``T.symbolic``.
        block_size: Columns per quantization group (one scale per group).
        in_dtype: Dtype of X.
        out_dtype: FP8 dtype the values are quantized to.
        scale_dtype: Storage dtype of S. Scales are computed in FP32.
        round_scale: Round scales up to powers of two.
        inplace: Emit dequantized values in ``in_dtype`` instead of FP8 values.
    """
    M = T.symbolic("M")
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1 / fp8_max
    num_stages = 0 if round_scale or inplace else 2
    blk_m = 32
    group_size = block_size
    # Internal computation in FP32; scale_dtype controls output storage format.
    compute_dtype = FP32
    # Remember the quantization target before out_dtype is overridden: in inplace
    # mode Y holds dequantized in_dtype values, but the round trip must still go
    # through the FP8 dtype, otherwise nothing is actually quantized.
    quant_dtype = out_dtype
    out_dtype = in_dtype if inplace else out_dtype

    @T.prim_func
    def act_quant_kernel_(
        X: T.Tensor[(M, N), in_dtype],
        Y: T.Tensor[(M, N), out_dtype],
        S: T.Tensor[(M, T.ceildiv(N, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(M, blk_m), T.ceildiv(N, group_size), threads=128) as (
            pid_m,
            pid_n,
        ):
            x_shared = T.alloc_shared((blk_m, group_size), in_dtype)
            x_local = T.alloc_fragment((blk_m, group_size), in_dtype)
            amax_local = T.alloc_fragment((blk_m,), compute_dtype)
            s_local = T.alloc_fragment((blk_m,), compute_dtype)
            y_local = T.alloc_fragment((blk_m, group_size), out_dtype)
            y_shared = T.alloc_shared((blk_m, group_size), out_dtype)

            for _ in T.Pipelined(1, num_stages=num_stages):
                T.copy(X[pid_m * blk_m, pid_n * group_size], x_shared)
                T.copy(x_shared, x_local)
                T.reduce_absmax(x_local, amax_local, dim=1)
                for i in T.Parallel(blk_m):
                    # Floor amax so all-zero groups do not produce a zero scale.
                    amax_local[i] = T.max(amax_local[i], 1e-4)
                    if round_scale:
                        s_local[i] = fast_round_scale(amax_local[i], fp8_max_inv)
                    else:
                        s_local[i] = amax_local[i] * fp8_max_inv
                if inplace:
                    for i, j in T.Parallel(blk_m, group_size):
                        # quantize to FP8 -> widen to FP32 -> rescale -> store as in_dtype
                        y_local[i, j] = T.Cast(
                            out_dtype,
                            T.Cast(
                                compute_dtype,
                                T.Cast(quant_dtype, T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)),
                            )
                            * s_local[i],
                        )
                else:
                    for i, j in T.Parallel(blk_m, group_size):
                        y_local[i, j] = T.clamp(x_local[i, j] / s_local[i], fp8_min, fp8_max)
                for i in T.Parallel(blk_m):
                    S[pid_m * blk_m + i, pid_n] = T.Cast(scale_dtype, s_local[i])
                T.copy(y_local, y_shared)
                T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])

    return act_quant_kernel_


def act_quant(
    x: torch.Tensor,
    block_size: int = 128,
    scale_fmt: Optional[str] = None,
    scale_dtype: torch.dtype = torch.float32,
    inplace: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Block-wise FP8 quantization. inplace=True does fused quant+dequant back to BF16.

    When scale_fmt is set, scales are rounded to power-of-2 (MXFP).

    Args:
        x: Input tensor. The kernel is built with its default BF16 input dtype,
            and the last dimension must be divisible by ``block_size``.
        block_size: Number of consecutive last-dim elements sharing one scale.
        scale_fmt: Any non-None value enables power-of-two scale rounding; the
            value itself is not inspected.
        scale_dtype: Dtype of the returned scales. Scales are computed in FP32
            and converted afterwards. A format that only holds powers of two
            (e.g. E8M0) is exact only when ``scale_fmt`` is set.
        inplace: If True, overwrite ``x`` with its quantize-dequantize round trip
            and return ``x``.

    Returns:
        ``x`` when ``inplace`` is True. Otherwise ``(y, s)``, where ``y`` is
        ``float8_e4m3fn`` with ``x``'s shape and ``s`` has shape
        ``(*x.shape[:-1], N // block_size)``.
    """
    N = x.size(-1)
    assert N % block_size == 0, f"last dim {N} must be divisible by block_size {block_size}"
    z = x.contiguous()
    y = torch.empty_like(z) if inplace else torch.empty_like(z, dtype=torch.float8_e4m3fn)
    # The kernel always stores FP32 scales; convert afterwards if another dtype was requested.
    s = z.new_empty(*z.size()[:-1], N // block_size, dtype=torch.float32)
    kernel = act_quant_kernel(
        N,
        block_size,
        scale_dtype=FP32,
        round_scale=scale_fmt is not None,
        inplace=inplace,
    )
    kernel(z.view(-1, N), y.view(-1, N), s.view(-1, N // block_size))
    if inplace:
        x.copy_(y)
        return x
    if scale_dtype != torch.float32:
        s = s.to(scale_dtype)
    return y, s


@tilelang.jit(pass_configs=pass_configs)
def fp8_gemm_kernel(N, K, out_dtype=BF16, accum_dtype=FP32, scale_dtype=FP32):
    """Build C[M,N] = A[M,K] @ B[N,K]^T for FP8 A/B with 128-wide block scales.

    ``scales_a`` has one scale per (row, 128-wide K block). ``scales_b`` has one
    scale per (128-row N block, 128-wide K block). M is declared with
    ``T.symbolic``; N and K are compile-time parameters.
    """
    assert out_dtype in [BF16, FP32]

    M = T.symbolic("M")
    group_size = 128
    block_M = 32
    block_N = 128
    block_K = 128

    @T.prim_func
    def fp8_gemm_kernel_(
        A: T.Tensor[(M, K), FP8],
        B: T.Tensor[(N, K), FP8],
        C: T.Tensor[(M, N), out_dtype],
        scales_a: T.Tensor[(M, T.ceildiv(K, group_size)), scale_dtype],
        scales_b: T.Tensor[(T.ceildiv(N, group_size), T.ceildiv(K, group_size)), scale_dtype],
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (
            bx,
            by,
        ):
            A_shared = T.alloc_shared((block_M, block_K), FP8)
            B_shared = T.alloc_shared((block_N, block_K), FP8)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            Scale_C_shared = T.alloc_shared((block_M,), FP32)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_local_accum = T.alloc_fragment((block_M, block_N), accum_dtype)

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)
            T.clear(C_local_accum)
            K_iters = T.ceildiv(K, block_K)
            for k in T.Pipelined(K_iters, num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                # Cast scales to FP32 for computation; scales_b has one value per block_N group
                Scale_B = T.Cast(FP32, scales_b[bx * block_N // group_size, k])
                for i in T.Parallel(block_M):
                    Scale_C_shared[i] = T.Cast(FP32, scales_a[by * block_M + i, k]) * Scale_B

                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
                # Scales differ per K block, so each partial product is rescaled
                # and summed into a separate accumulator, then C_local is reset.
                for i, j in T.Parallel(block_M, block_N):
                    C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i]
                T.clear(C_local)
            T.copy(C_local_accum, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return fp8_gemm_kernel_


def fp8_gemm(
    a: torch.Tensor,
    a_s: torch.Tensor,
    b: torch.Tensor,
    b_s: torch.Tensor,
    scale_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """C[M,N] = A[M,K] @ B[N,K]^T with per-128 block FP8 scaling on both A and B.

    Args:
        a: FP8 activations with last dim K, contiguous. Leading dims are
            flattened into M.
        a_s: Scales for ``a``, viewed as ``(M, -1)``: one per row per 128-wide
            K block.
        b: FP8 weights of shape (N, K), contiguous.
        b_s: Block scales for ``b``, ``(ceil(N/128), ceil(K/128))``.
        scale_dtype: Kept for backward compatibility. The dtype of ``a_s`` /
            ``b_s`` decides handling: non-FP32 scales (e.g. E8M0) are widened to
            FP32 before the kernel, which casts scales to FP32 internally anyway.

    Returns:
        Tensor of shape ``(*a.shape[:-1], N)`` in ``torch.get_default_dtype()``.

    Raises:
        ValueError: If the default dtype is neither bfloat16 nor float32.
    """
    assert a.is_contiguous() and b.is_contiguous(), "Input tensors must be contiguous"
    assert a_s.is_contiguous() and b_s.is_contiguous(), (
        "Scaling factor tensors must be contiguous"
    )
    # Widening to FP32 is exact for FP32/E8M0 scales and avoids needing an E8M0 TileLang dtype.
    if a_s.dtype != torch.float32:
        a_s = a_s.float()
    if b_s.dtype != torch.float32:
        b_s = b_s.float()

    out_torch_dtype = torch.get_default_dtype()
    # Compile the kernel for the dtype of the buffer actually allocated below.
    tl_out_dtype = {torch.bfloat16: BF16, torch.float32: FP32}.get(out_torch_dtype)
    if tl_out_dtype is None:
        raise ValueError(
            f"fp8_gemm supports bfloat16 or float32 outputs, got default dtype {out_torch_dtype}"
        )

    K = a.size(-1)
    M = a.numel() // K
    N = b.size(0)
    c = a.new_empty(*a.size()[:-1], N, dtype=out_torch_dtype)
    kernel = fp8_gemm_kernel(N, K, out_dtype=tl_out_dtype, scale_dtype=FP32)
    kernel(a.view(M, K), b, c.view(M, N), a_s.view(M, -1), b_s)
    return c


@tilelang.jit(pass_configs=pass_configs)
def sparse_attn_kernel(h_orig: int, d: int, scale=None):
    """Sparse multi-head attention via index gathering + online softmax (FlashAttention-style).

    For each (batch, seq_pos), gathers top-k KV positions by index, computes attention with
    numerically stable running max/sum, and includes a learnable attn_sink bias.

    Heads are processed in tiles of ``h`` rows. If ``h_orig > 16`` it must be a
    multiple of 16, and each query position is split over ``h_orig // 16``
    thread blocks. Index entries equal to -1 (or beyond ``topk``) are masked
    out. ``scale`` defaults to ``1 / sqrt(d)``.

    NOTE(review): when ``h_orig < 16`` the tile height is 16 while the tensors
    declare ``h_orig`` heads, so callers must pad heads (``sparse_attn`` does).
    NOTE(review): if every index in the first KV block of a row is masked,
    ``scores_max`` is still -inf and ``exp(-inf - -inf)`` yields NaN. Confirm
    that callers always supply a valid index there.
    """
    b = T.symbolic("b")
    m = T.symbolic("m")
    n = T.symbolic("n")
    topk = T.symbolic("topk")
    if scale is None:
        scale = (1.0 / d) ** 0.5

    num_stages = 0
    threads = 256
    block = 32
    num_blocks = tilelang.cdiv(topk, block)

    padded_H = max(tilelang.math.next_power_of_2(h_orig), 16)
    max_block_m = 16
    if h_orig > max_block_m:
        assert h_orig % max_block_m == 0, f"h should be a multiple of {max_block_m}"
        REPLICATE_H = h_orig // max_block_m
    else:
        REPLICATE_H = 1
    # Rows (heads) handled per thread block.
    h = padded_H if REPLICATE_H == 1 else max_block_m

    @T.prim_func
    def sparse_attn_kernel_(
        q: T.Tensor[(b, m, h_orig, d), BF16],
        kv: T.Tensor[(b, n, d), BF16],
        o: T.Tensor[(b, m, h_orig, d), BF16],
        attn_sink: T.Tensor[(h_orig,), FP32],
        topk_idxs: T.Tensor[(b, m, topk), INT32],
    ):
        with T.Kernel(m * REPLICATE_H, b, threads=threads) as (bx, by):
            q_shared = T.alloc_fragment((h, d), BF16)
            kv_shared = T.alloc_shared((block, d), BF16)
            acc_s_cast = T.alloc_shared((h, block), BF16)
            idxs = T.alloc_fragment(block, INT32)
            acc_s = T.alloc_fragment((h, block), FP32)
            acc_o = T.alloc_fragment((h, d), FP32)
            scores_max = T.alloc_fragment(h, FP32)
            scores_max_prev = T.alloc_fragment(h, FP32)
            scores_scale = T.alloc_fragment(h, FP32)
            scores_sum = T.alloc_fragment(h, FP32)
            sum_exp = T.alloc_fragment(h, FP32)

            T.clear(acc_o)
            T.clear(sum_exp)
            T.fill(scores_max, -T.infinity(FP32))

            # Sequence position and the [H0, H1) head range this block handles.
            s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H)
            H0 = (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * h)
            H1 = H0 + h

            T.copy(q[by, s_i, H0:H1, :], q_shared)

            for t in T.Pipelined(num_blocks, num_stages=num_stages):
                # Gather this block's KV rows; out-of-range slots become -1 (masked).
                for i in T.Parallel(block):
                    idxs[i] = T.if_then_else(t * block + i < topk, topk_idxs[by, s_i, t * block + i], -1)
                for i, j in T.Parallel(block, d):
                    kv_shared[i, j] = T.if_then_else(idxs[i] != -1, kv[by, idxs[i], j], 0)
                # Seed scores with 0 (valid) / -inf (masked) so the GEMM accumulates onto the mask.
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] = T.if_then_else(idxs[j] != -1, 0, -T.infinity(FP32))
                T.gemm(q_shared, kv_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] *= scale
                # Online softmax: update the running max and rescale previous partial sums.
                T.copy(scores_max, scores_max_prev)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(h):
                    scores_scale[i] = T.exp(scores_max_prev[i] - scores_max[i])
                for i, j in T.Parallel(h, block):
                    acc_s[i, j] = T.exp(acc_s[i, j] - scores_max[i])
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(h):
                    sum_exp[i] = sum_exp[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(h, d):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, kv_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

            # The attention sink adds exp(sink) to the denominator only. Index with
            # the global head H0 + i so every head tile uses its own sinks.
            for i in T.Parallel(h):
                sum_exp[i] += T.exp(attn_sink[H0 + i] - scores_max[i])
            for i, j in T.Parallel(h, d):
                acc_o[i, j] /= sum_exp[i]
            o_shared = T.alloc_shared((h, d), BF16)
            T.copy(acc_o, o_shared)
            T.copy(o_shared, o[by, s_i, H0:H1, :])

    return sparse_attn_kernel_


def sparse_attn(
    q: torch.Tensor, kv: torch.Tensor, attn_sink: torch.Tensor, topk_idxs: torch.Tensor, softmax_scale: float
) -> torch.Tensor:
    """Run ``sparse_attn_kernel`` on q (b, s, h, d), padding heads as needed.

    The kernel processes heads in tiles of 16. The head count is padded up to a
    multiple of 16 (at least 16) with zero queries and zero sinks, and the
    padded heads are stripped from the output.

    Args:
        q: Queries of shape (b, s, h, d), BF16 per the kernel signature.
        kv: Shared keys/values of shape (b, n, d).
        attn_sink: Per-head sink logits of shape (h,), FP32.
        topk_idxs: (b, s, topk) int32 KV indices; -1 marks unused slots.
        softmax_scale: Multiplier applied to q·k scores.

    Returns:
        Attention output of shape (b, s, h, d).
    """
    b, s, h, d = q.size()
    padded_h = max(16, (h + 15) // 16 * 16)
    pad = padded_h - h
    if pad:
        q = torch.cat([q, q.new_zeros(b, s, pad, d)], dim=2)
        attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(pad)])
    o = torch.empty_like(q)
    kernel = sparse_attn_kernel(q.size(2), d, softmax_scale)
    kernel(q, kv, o, attn_sink, topk_idxs)
    if pad:
        o = o.narrow(2, 0, h).contiguous()
    return o


@tilelang.jit(pass_configs=pass_configs)
def hc_split_sinkhorn_kernel(hc: int, sinkhorn_iters: int, eps: float):
    """Split each row of mixing logits into pre/post gates and a Sinkhorn-normalized matrix.

    Each row of ``mixes`` holds ``(2 + hc) * hc`` logits laid out as
    ``[pre (hc) | post (hc) | comb (hc*hc, row-major)]``. Each segment is
    affinely transformed with ``hc_scale[seg]`` and ``hc_base``:

    - ``pre  = sigmoid(.) + eps``
    - ``post = 2 * sigmoid(.)``
    - ``comb``: row softmax + eps, then column normalization, then
      ``sinkhorn_iters - 1`` alternating row/column normalizations
      (eps added to every denominator).
    """
    n = T.symbolic("n")
    mix_hc = (2 + hc) * hc
    threads = 64

    @T.prim_func
    def hc_split_sinkhorn_kernel_(
        mixes: T.Tensor[(n, mix_hc), FP32],
        hc_scale: T.Tensor[(3,), FP32],
        hc_base: T.Tensor[(mix_hc,), FP32],
        pre: T.Tensor[(n, hc), FP32],
        post: T.Tensor[(n, hc), FP32],
        comb: T.Tensor[(n, hc, hc), FP32],
    ):
        with T.Kernel(n, threads=threads) as i:
            mixes_shared = T.alloc_shared(mix_hc, FP32)
            comb_frag = T.alloc_fragment((hc, hc), FP32)
            T.copy(mixes[i, :], mixes_shared)

            for j in T.Parallel(hc):
                pre[i, j] = T.sigmoid(mixes_shared[j] * hc_scale[0] + hc_base[j]) + eps
            for j in T.Parallel(hc):
                post[i, j] = 2 * T.sigmoid(mixes_shared[j + hc] * hc_scale[1] + hc_base[j + hc])
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = mixes_shared[j * hc + k + hc * 2] * hc_scale[2] + hc_base[j * hc + k + hc * 2]

            row_sum = T.alloc_fragment(hc, FP32)
            col_sum = T.alloc_fragment(hc, FP32)

            # comb = comb.softmax(-1) + eps
            row_max = T.alloc_fragment(hc, FP32)
            T.reduce_max(comb_frag, row_max, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = T.exp(comb_frag[j, k] - row_max[j])
            T.reduce_sum(comb_frag, row_sum, dim=1)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / row_sum[j] + eps

            # comb = comb / (comb.sum(-2) + eps)
            T.reduce_sum(comb_frag, col_sum, dim=0)
            for j, k in T.Parallel(hc, hc):
                comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)

            # The softmax + column pass above counts as the first Sinkhorn iteration.
            for _ in T.serial(sinkhorn_iters - 1):
                # comb = comb / (comb.sum(-1) + eps)
                T.reduce_sum(comb_frag, row_sum, dim=1)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (row_sum[j] + eps)
                # comb = comb / (comb.sum(-2) + eps)
                T.reduce_sum(comb_frag, col_sum, dim=0)
                for j, k in T.Parallel(hc, hc):
                    comb_frag[j, k] = comb_frag[j, k] / (col_sum[k] + eps)

            T.copy(comb_frag, comb[i, :, :])

    return hc_split_sinkhorn_kernel_


def hc_split_sinkhorn(mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, hc_mult: int = 4, sinkhorn_iters: int = 20, eps: float = 1e-6):
    """Apply ``hc_split_sinkhorn_kernel`` to mixes of shape (b, s, (2 + hc_mult) * hc_mult).

    All tensors are FP32, as declared by the kernel. ``mixes`` must be viewable
    as ``(-1, (2 + hc_mult) * hc_mult)``.

    Returns:
        ``(pre, post, comb)`` with shapes (b, s, hc_mult), (b, s, hc_mult) and
        (b, s, hc_mult, hc_mult).
    """
    b, s, _ = mixes.size()
    pre = mixes.new_empty(b, s, hc_mult)
    post = mixes.new_empty(b, s, hc_mult)
    comb = mixes.new_empty(b, s, hc_mult, hc_mult)
    kernel = hc_split_sinkhorn_kernel(hc_mult, sinkhorn_iters, eps)
    kernel(mixes.view(-1, (2 + hc_mult) * hc_mult), hc_scale, hc_base,
           pre.view(-1, hc_mult), post.view(-1, hc_mult), comb.view(-1, hc_mult, hc_mult))
    return pre, post, comb