Unverified Commit 95e7bc37 authored by Yu Cheng, committed by GitHub

[Benchmark] Update triton and helion baselines in mamba-chunk-scan (#1131)

* [Benchmark] Update triton and helion baselines in mamba-chunk-scan

* lint

* update mamba baseline version
parent 6e1dc6a1
@@ -45,6 +45,12 @@ PY
| 16384 | 2.531 | 135.711 |
| 32768 | 5.076 | 135.379 |
## Comparison with Baselines
- Triton: v3.5.0, mamba-ssm: v2.2.6.post3
- Helion: v0.2.1
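The pinned baseline versions can be sanity-checked before benchmarking. A minimal sketch, assuming the packages are installed under the PyPI distribution names `triton`, `mamba-ssm`, and `helion`:

```python
from importlib.metadata import version

# Should print 3.5.0, 2.2.6.post3, and 0.2.1 for the baselines above.
for dist in ("triton", "mamba-ssm", "helion"):
    print(dist, version(dist))
```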
<figure style="text-align: center">
<a href="mamba_benchmark_result.png">
<img src="mamba_benchmark_result.png" alt="Mamba2_chunk_scan Performance Comparison on H100">
@@ -5,6 +5,20 @@ from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, repeat
import itertools
import math
from tilelang.profiler import do_bench
try:
from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
except ImportError as err:
raise ImportError("Please install mamba-ssm to use the triton chunk scan operator.") from err
try:
import helion
from helion._testing import run_example
import helion.language as hl
except ImportError as err:
raise ImportError("Please install helion to use the helion chunk scan operator.") from err
def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
@@ -54,6 +68,119 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
return out
def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
return out
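# Illustrative sketch (not part of the shipped benchmark): before timing the
# Triton baseline, its output can be checked against ref_program; the fp16
# tolerances below are assumptions, not values from this PR.
def check_triton_against_ref(cb, x, dt, dA_cumsum, C, states, D):
    out_triton = chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D)
    out_ref = ref_program(cb, x, dt, dA_cumsum, C, states, D)
    torch.testing.assert_close(out_triton, out_ref, rtol=1e-2, atol=1e-2)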
def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
@helion.kernel()
def helion_mamba2_chunk_scan_kernel(
cb: torch.Tensor,
x: torch.Tensor,
dt: torch.Tensor,
dA_cumsum: torch.Tensor,
C: torch.Tensor,
prev_states: torch.Tensor,
D: torch.Tensor,
) -> torch.Tensor:
"""
        Arguments:
cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
x: (batch, seqlen, nheads, headdim)
dt: (batch, nheads, nchunks, chunk_size)
dA_cumsum: (batch, nheads, nchunks, chunk_size)
C: (batch, seqlen, ngroups, dstate)
prev_states: (batch, nchunks, nheads, headdim, dstate)
D: (nheads,)
Return:
out: (batch, seqlen, nheads, headdim)
"""
batch, nchunks, ngroups, chunk_size, _ = cb.shape
_, seqlen, nheads, headdim = x.shape
_, _, _, dstate = C.shape
assert nchunks == (seqlen + chunk_size - 1) // chunk_size
block_m = hl.register_block_size(chunk_size)
block_n = hl.register_block_size(headdim)
block_k = hl.register_block_size(64, 64)
dstate = hl.specialize(dstate)
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert C.shape == (batch, seqlen, ngroups, dstate)
assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
assert D.shape == (nheads,)
dtype = cb.dtype
accum_dtype = torch.float32
assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype ==
dtype)
out = torch.empty_like(x)
        p = 1.44269504  # log2(e): exp(x) is computed as exp2(x * p)
for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
[nheads, chunk_size, headdim, batch, nchunks],
block_size=[1, block_m, block_n, 1, 1],
):
acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
tile_m].to(torch.float32)
scale_m_local = torch.exp2(dA_cumsum_local_m * p)
            # Inter-chunk term: contribution of the previous chunk's state,
            # out += exp(dA_cumsum) * (C @ prev_states^T).
            C_local = C[
tile_b.begin,
tile_m.index + tile_c.begin * chunk_size,
tile_h.begin // (nheads // ngroups),
:,
]
prev_states_local = prev_states[tile_b.begin, tile_c.begin, tile_h.begin, tile_n, :]
acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o)
acc_o *= scale_m_local[:, None]
            # Intra-chunk term: iterate over K tiles up to and including the
            # current M tile (causal within the chunk).
            for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k):
cb_local = cb[
tile_b.begin,
tile_c.begin,
tile_h.begin // (nheads // ngroups),
tile_m,
tile_k,
]
dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
tile_k].to(torch.float32)
cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p -
dA_cumsum_local_k[None, :] * p)
dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
cb_local = (cb_local * dt_local[None, :]).to(dtype)
                # Causal mask inside the tile: position m attends only to k <= m.
                pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local))
x_local = x[
tile_b.begin,
tile_c.begin * chunk_size + tile_k.index,
tile_h.begin,
tile_n,
]
acc_o = hl.dot(cb_local, x_local, acc=acc_o)
            # Skip connection: add the D-weighted residual of x.
            D_local = D[tile_h.begin].to(torch.float32)
x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
tile_n].to(torch.float32)
acc_o += x_residual * D_local
out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
tile_n] = acc_o.to(dtype=dtype)
return out
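    # run_example validates the Helion kernel against ref_program and reports
    # timings for both, so this wrapper does not need to return anything.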
args = (cb, x, dt, dA_cumsum, C, states, D)
run_example(helion_mamba2_chunk_scan_kernel, ref_program, args)
def get_configs():
iter_params = dict(
block_M=[64, 128, 256],
@@ -212,8 +339,10 @@ if __name__ == "__main__":
parser.add_argument('--tune', action='store_true', help='tune configs')
args = parser.parse_args()
batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
nchunks = math.ceil(seq_len / chunk_size)
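    # FLOP model for the line below: the intra-chunk matmul costs
    # 2*batch*seq_len*chunk_size*heads*dim, halved by the causal mask (the 0.5
    # factor); the inter-chunk state matmul adds 2*batch*seq_len*heads*dim*dstate.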
total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate
print("Benchmarking TileLang...")
kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
best_latency = kernel.latency
best_config = kernel.config
@@ -221,3 +350,19 @@ if __name__ == "__main__":
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
cb = torch.randn(batch, nchunks, groups, chunk_size, chunk_size).half().cuda()
x = torch.randn(batch, seq_len, heads, dim).half().cuda()
dt = torch.randn(batch, heads, nchunks, chunk_size).half().cuda()
dA_cumsum = torch.randn(batch, heads, nchunks, chunk_size).half().cuda()
C = torch.randn(batch, seq_len, groups, dstate).half().cuda()
states = torch.randn(batch, nchunks, heads, dim, dstate).half().cuda()
D = torch.randn(heads).half().cuda()
print("Benchmarking Triton...")
    # do_bench reports latency in milliseconds (hence the 1e-9 factor in the
    # TFlops conversion below).
    triton_latency = do_bench(
        lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")
print("Benchmarking Helion...")
chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D)
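    # Example invocation (script name assumed for illustration; the flags are
    # defined by the argparser above, partly outside this diff):
    #   python mamba_chunk_scan.py --tune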