import argparse
import itertools

import torch
from einops import rearrange, repeat

import tilelang
import tilelang.language as T
from tilelang.autotuner import autotune


def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
    from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
    out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
    return out


def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
    """
    Argument:
        cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, nheads, nchunks, chunk_size)
        dA_cumsum: (batch, nheads, nchunks, chunk_size)
        C: (batch, seqlen, ngroups, dstate)
        prev_states: (batch, nchunks, nheads, headdim, dstate)
        D: (nheads, headdim) or (nheads,)
    Return:
        out: (batch, seqlen, nheads, headdim)
    """
    _, _, ngroups, _, _ = cb.shape
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    assert seqlen == nchunks * chunk_size
    # Broadcast the per-group tensors across all heads in each group.
    C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups)
    cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups)
    # Intra-chunk term: causally masked, decay-weighted (C B^T) scores applied to x.
    dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :]
    decay = torch.exp(dt_segment_sum)
    scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s")
    causal_mask = torch.tril(
        torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0)
    scores_decay = scores_decay.masked_fill(~causal_mask, 0)
    out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype),
                       rearrange(x, "b (c s) h p -> b c s h p", c=nchunks))
    # Inter-chunk term: carry in the previous chunk's state, decayed to each position.
    state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1"))
    out_prev = torch.einsum('bclhn,bchpn->bclhp',
                            rearrange(C, "b (c l) h n -> b c l h n", c=nchunks),
                            prev_states.to(C.dtype)) * state_decay_out
    out = out + out_prev
    out = rearrange(out, "b c l h p -> b (c l) h p")
    if D is not None:
        if D.dim() == 1:
            D = rearrange(D, "h -> h 1")
        out = out + x * D
    return out


def get_configs():
    iter_params = dict(
        block_M=[64, 128, 256],
        block_N=[32, 64],
        block_K=[64, 128, 256],
        block_Dstate=[128],
        num_stages=[1, 2, 3, 4, 5])
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=10, rep=10)
@tilelang.jit(
    out_idx=[7],
    pass_configs={
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
    },
)
def chunk_scan_fwd(batch,
                   seqlen,
                   chunk_size,
                   ngroups,
                   nheads,
                   headdim,
                   dstate,
                   block_M=64,
                   block_N=64,
                   block_K=64,
                   block_Dstate=128,
                   num_stages=2,
                   threads=128):
    dtype = "float16"
    accum_dtype = "float"
    nchunks = T.ceildiv(seqlen, chunk_size)
    p = 1.44269504  # log2(e): lets exp(x) be computed as exp2(x * log2(e))

    @T.prim_func
    def main(
            cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype),  # type: ignore
            x: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
            dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
            dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype),  # type: ignore
            C: T.Tensor((batch, seqlen, ngroups, dstate), dtype),  # type: ignore
            prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype),  # type: ignore
            D: T.Tensor((nheads), dtype),  # type: ignore
            Output: T.Tensor((batch, seqlen, nheads, headdim), dtype),  # type: ignore
    ):
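        # Grid mapping (mirrors ref_program above): bz -> head, bx -> one
        # (block_M x block_N) tile of the (chunk_size, headdim) output of a chunk,
        # by -> flattened (batch, chunk) index. Each block computes
        #   out = exp(dA_cumsum) * (C @ prev_states^T)   (inter-chunk state carry-in)
        #       + tril(cb * decay * dt) @ x              (intra-chunk scan)
        #       + D * x                                  (skip connection)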
        with T.Kernel(
                nheads,
                T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N),
                batch * nchunks,
                threads=threads) as (bz, bx, by):
            acc_o = T.alloc_fragment((block_M, block_N), accum_dtype)
            acc_o_shared = T.alloc_shared((block_M, block_N), dtype)
            cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn")
            cb_local = T.alloc_fragment((block_M, block_K), dtype)
            dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared")
            dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype)
            dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype)
            dt_shared = T.alloc_shared((block_K), dtype, scope="shared")
            dt_local = T.alloc_fragment((block_K), accum_dtype)
            x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn")
            dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared")
            scale_m_local = T.alloc_fragment((block_M), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_Dstate), dtype)
            prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype)
            D_local = T.alloc_fragment((1), accum_dtype)
            x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn")
            x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            batch_idx = by % batch
            chunk_idx = by // batch
            # m indexes the chunk_size dimension, n indexes headdim.
            m_idx = bx // T.ceildiv(headdim, block_N)
            n_idx = bx % T.ceildiv(headdim, block_N)

            T.annotate_layout({
                acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared),
                cb_shared: tilelang.layout.make_swizzled_layout(cb_shared),
                x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared)
            })
            T.no_set_max_nreg()

            T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M],
                   dA_cs_m_shared)
            T.copy(dA_cs_m_shared, dA_cs_m_local)
            T.clear(acc_o)
            for i in T.Parallel(block_M):
                scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p)

            # Inter-chunk term: acc_o = exp(dA_cumsum) * (C @ prev_states^T).
            T.copy(
                C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
                  (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared)
            T.copy(
                prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N,
                            0:block_Dstate], prev_state_shared)
            T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True)
            for i, j in T.Parallel(block_M, block_N):
                acc_o[i, j] *= scale_m_local[i]

            # Intra-chunk term: accumulate tril(cb * decay * dt) @ x over K tiles.
            # Causality means only tiles at or below the diagonal contribute,
            # hence the shortened loop bound.
            loop_range = T.ceildiv((m_idx + 1) * block_M, block_K)
            for k in T.Pipelined(loop_range, num_stages=num_stages):
                T.copy(
                    cb[batch_idx, chunk_idx, bz // (nheads // ngroups),
                       m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K],
                    cb_shared)
                T.copy(cb_shared, cb_local)
                T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K],
                       dA_cs_k_shared)
                T.copy(dA_cs_k_shared, dA_cs_k_local)
                # decay[l, s] = exp(dA_cumsum[l] - dA_cumsum[s]), computed via exp2.
                for i, j in T.Parallel(block_M, block_K):
                    cb_local[i, j] = cb_local[i, j] * T.exp2(dA_cs_m_local[i] * p -
                                                             dA_cs_k_local[j] * p)
                T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared)
                T.copy(dt_shared, dt_local)
                for i, j in T.Parallel(block_M, block_K):
                    cb_local[i, j] *= dt_local[j]
                # Causal mask: zero out scores above the diagonal.
                for i, j in T.Parallel(block_M, block_K):
                    cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j,
                                                    cb_local[i, j], 0)
                T.copy(
                    x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size +
                      (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared)
                T.gemm(cb_local, x_shared, acc_o)

            # Skip-connection inputs: D[bz] and the x tile aligned with this output tile.
            D_local[0] = D[bz]
            T.copy(
                x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
                  (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N],
                x_residual_shared)
            T.copy(x_residual_shared, x_residual_local)
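            # Epilogue: fold in the D * x residual, then stage acc_o through shared
            # memory so the final store to global memory is swizzled and coalesced.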
            for i, j in T.Parallel(block_M, block_N):
                acc_o[i, j] += x_residual_local[i, j] * D_local[0]
            T.copy(acc_o, acc_o_shared)
            T.copy(
                acc_o_shared,
                Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size +
                       (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N])

    return main


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=80, help='number of heads')
    parser.add_argument('--groups', type=int, default=1, help='number of groups')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--chunk_size', type=int, default=256, help='chunk size')
    parser.add_argument('--dim', type=int, default=64, help='head dimension')
    parser.add_argument('--dstate', type=int, default=128, help='state dimension')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    batch, heads, groups, seq_len = args.batch, args.heads, args.groups, args.seq_len
    chunk_size, dim, dstate = args.chunk_size, args.dim, args.dstate
    # The intra-chunk GEMM is causal (hence the 0.5 factor); the inter-chunk GEMM is dense.
    total_flops = (2 * batch * seq_len * chunk_size * heads * dim * 0.5 +
                   2 * batch * seq_len * heads * dim * dstate)

    if not args.tune:
        kernel = chunk_scan_fwd(
            batch,
            seq_len,
            chunk_size,
            groups,
            heads,
            dim,
            dstate,
            block_M=64,
            block_N=64,
            block_K=64,
            block_Dstate=128,
            num_stages=2,
            threads=128)
        profiler = kernel.get_profiler(tilelang.TensorSupplyType.Normal)
        profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = profiler.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = profiler.do_bench(warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
        best_latency = kernel.latency
        best_config = kernel.config
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
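# Example invocations (file name illustrative; requires a CUDA-capable GPU, and
# mamba_ssm only if you call the Triton baseline in chunk_scan_triton):
#   python example_mamba_chunk_scan_fwd.py            # correctness check + benchmark
#   python example_mamba_chunk_scan_fwd.py --tune     # autotune over get_configs()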