Commit be9abf18 authored by Yu Cheng, committed by GitHub

[Dev][Benchmark] Add MLA paged decoding example and benchmark script (#158)

* [Dev] Adjust computation logic to avoid precision loss when casting acc_s from float to float16

- Remove redundant `acc_s_0` fragment in flash attention kernel
- Simplify memory copy and reduction operations
- Reorder memory copy and scaling steps for improved performance
- Add Hopper-specific synchronization method in CUDA reduce template
- Update reduce operation to use architecture-specific synchronization
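
A rough, self-contained PyTorch sketch of the idea above (not the kernel code): exponentiate the fp32 scores first, and only cast the bounded probabilities to float16 for the P·V GEMM. The shapes and the `d = 64` scale here are illustrative values, not the MLA configuration.

```python
import torch

torch.manual_seed(0)
d = 64                                    # illustrative head dim, not the MLA config
scale = 1.44269504 / d**0.5               # log2(e) / sqrt(d), as in the kernel
scores = torch.randn(16, 128) * 40        # raw fp32 Q @ K^T scores can be large

def softmax2(s):
    # exp2-based softmax numerator, computed in fp32
    return torch.exp2(s * scale - (s * scale).amax(-1, keepdim=True))

p_ref = softmax2(scores)                        # exponentiate while still fp32
p_cast_late = p_ref.half()                      # probabilities in [0, 1] cast with little loss
p_cast_early = softmax2(scores.half().float())  # casting the raw scores first loses bits

print("scale-then-cast max error:", (p_cast_late.float() - p_ref).abs().max().item())
print("cast-then-scale max error:", (p_cast_early - p_ref).abs().max().item())
```

On typical inputs the cast-early variant shows a noticeably larger error, which is what the reordered scaling/casting steps avoid.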

* [Dev] Add DeepSeek MLA Decoding (Paged+Varlen) kernel and Performance Benchmark Script

- Implement comprehensive MLA (Multi-Head Latent Attention) decoding benchmark script
- Add support for multiple implementations: Torch, TileLang, FlashMLA, FlashInfer, and Triton
- Create flexible configuration for benchmarking different batch sizes, sequence lengths, and head configurations
- Implement performance comparison and CSV output for detailed performance analysis
- Add command-line argument support for targeted benchmarking and comparison
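
For orientation, this is the FLOP accounting the benchmark uses to turn a measured latency into TFLOPS (it mirrors the `total_flops` formula and the final print at the bottom of the new script); the latency value below is a placeholder, not a measurement.

```python
# Default configuration: batch=128, h_q=128, cache_seqlen=8192, d=576, dv=512, s_q=1.
b, h_q, s_q, d, dv = 128, 128, 1, 576, 512
cache_seqlens = [8192 + 2 * i for i in range(b)]        # per-request KV lengths, as in the script

total_seqlens = sum(cache_seqlens)
total_flops = s_q * total_seqlens * h_q * (d + dv) * 2  # QK^T (dim d) + PV (dim dv), 2 flops/MAC

latency_ms = 0.1                                        # placeholder latency in milliseconds
print(f"{total_flops / 1e12:.2f} TFLOPs per decode step")
print(f"{total_flops / latency_ms * 1e-9:.2f} TFLOPS at {latency_ms} ms")
```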

* [Dev] Refactor MLA Paged Decoding Kernel with Improved Block Handling and Precision

- Replace `d` parameter with `dv` to clarify value dimension in MLA decoding
- Enhance block distribution logic for split KV processing
- Improve handling of remaining blocks in split KV computation
- Add initialization of `lse_max_local` to prevent potential precision issues
- Optimize block start and range calculations for more accurate sequence processing
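
A plain-PyTorch sketch of what the combine step computes (mirroring the `combine` macro in the diff below): per-split partial outputs are merged with base-2 log-sum-exp weights, and the running maximum is seeded with `-inf` before scanning the splits, which is the initialization this commit adds. The shapes here are illustrative.

```python
import torch

num_split, dv = 4, 8
lse = torch.randn(num_split) * 5           # per-split log2-sum-exp values
o_partial = torch.randn(num_split, dv)     # per-split (already normalized) partial outputs

lse_max = torch.tensor(float("-inf"))      # must start at -inf, not at stale/zero data
for k in range(num_split):
    lse_max = torch.maximum(lse_max, lse[k])

lse_logsum = torch.exp2(lse - lse_max).sum().log2() + lse_max   # combined LSE, numerically stable
weights = torch.exp2(lse - lse_logsum)                          # per-split weights, sum to 1
out = (weights[:, None] * o_partial).sum(dim=0)

print(weights.sum().item())   # ~1.0
print(out.shape)              # torch.Size([8])
```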

* lint
parent 3c53297b
@@ -182,6 +182,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
+           lse_max_local[0] = -T.infinity(accum_dtype)
            for k in T.serial(num_split):
                lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
            for k in T.Pipelined(num_split, num_stages=1):
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
from tilelang.profiler import do_bench
import math


def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size):
    scale = (1.0 / (dv + dpe))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = h_q // h_kv
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert h_kv == 1, "h_kv must be 1"
    assert block_size >= block_N and block_size % block_N == 0, "block_size must be at least block_N and a multiple of block_N"

    @T.macro
    def flash_mla_kernel(
            Q: T.Buffer([batch, h_q, dv], dtype),
            Q_pe: T.Buffer([batch, h_q, dpe], dtype),
            KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
            K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
            BLOCK_TABLE: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
            CACHE_SEQLENS: T.Buffer([batch], "int32"),
            Output: T.Buffer([batch, h_q, dv], dtype),
    ):
        with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
            Q_shared = T.alloc_shared([block_H, dv], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
            KV_shared = T.alloc_shared([block_N, dv], dtype)
            K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
            O_shared = T.alloc_shared([block_H, dv], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = by // (kv_group_num // block_H)
            T.use_swizzle(10)
            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                S_shared: tilelang.layout.make_swizzled_layout(S_shared),
            })

            T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N)
            # Iterate over the KV blocks in reverse so only the first iteration
            # (the tail of the sequence) needs out-of-range masking.
            for kr in T.Pipelined(loop_range, num_stages=2):
                k = loop_range - 1 - kr
                kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size
                T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared)
                T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
                T.clear(acc_s)
                T.gemm(
                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(
                    Q_pe_shared,
                    K_pe_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                with T.If(kr == 0), T.Then():
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx],
                                                     -T.infinity(accum_dtype), acc_s[i, j])
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                # Exponentiate in fp32 first; the bounded probabilities are then
                # cast to fp16 (via S_shared -> acc_s_cast) for the P @ V GEMM.
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                T.copy(S_shared, acc_s_cast)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dv):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

            for i, j in T.Parallel(block_H, dv):
                acc_o[i, j] /= logsum[i]
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])

    @T.macro
    def flash_mla_split_kv_kernel(
            Q: T.Buffer([batch, h_q, dv], dtype),
            Q_pe: T.Buffer([batch, h_q, dpe], dtype),
            KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
            K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
            BLOCK_TABLE: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
            CACHE_SEQLENS: T.Buffer([batch], "int32"),
            glse: T.Buffer([batch, h_q, num_split], dtype),
            Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
    ):
        with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, dv], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
            KV_shared = T.alloc_shared([block_N, dv], dtype)
            K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
            O_shared = T.alloc_shared([block_H, dv], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = by // (kv_group_num // block_H)
            T.use_swizzle(10)
            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                S_shared: tilelang.layout.make_swizzled_layout(S_shared),
            })

            T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
            T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # Distribute KV blocks across the splits; the first `remaining_blocks`
            # splits each take one extra block.
            total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N)
            blocks_per_split = T.floordiv(total_blocks, num_split)
            remaining_blocks = T.floormod(total_blocks, num_split)
            loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0))
            start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N
            for k in T.Pipelined(loop_range, num_stages=2):
                kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size
                T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared)
                T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
                T.clear(acc_s)
                T.gemm(
                    Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.gemm(
                    Q_pe_shared,
                    K_pe_shared,
                    acc_s,
                    transpose_B=True,
                    policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx],
                                                 -T.infinity(accum_dtype), acc_s[i, j])
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                T.copy(acc_s, S_shared)
                T.copy(S_shared, acc_s_cast)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                for i, j in T.Parallel(block_H, dv):
                    acc_o[i, j] *= scores_scale[i]
                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)

            for i, j in T.Parallel(block_H, dv):
                acc_o[i, j] /= logsum[i]
            # Store the per-split log2-sum-exp so the combine kernel can merge splits.
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])

    @T.macro
    def combine(
            glse: T.Buffer([batch, h_q, num_split], dtype),
            Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
            Output: T.Buffer([batch, h_q, dv], dtype),
    ):
        with T.Kernel(h_q, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dv], dtype)
            o_accum_local = T.alloc_fragment([dv], accum_dtype)
            lse_local_split = T.alloc_local([1], accum_dtype)
            lse_logsum_local = T.alloc_local([1], accum_dtype)
            lse_max_local = T.alloc_local([1], accum_dtype)
            scale_local = T.alloc_local([1], accum_dtype)

            T.annotate_layout({
                lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
            })

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            # Seed the running max with -inf, then combine the per-split LSE values
            # and rescale each partial output accordingly.
            lse_max_local[0] = -T.infinity(accum_dtype)
            for k in T.serial(num_split):
                lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, by, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            for k in T.serial(num_split):
                for i in T.Parallel(dv):
                    po_local[i] = Output_partial[bz, by, k, i]
                lse_local_split[0] = glse[bz, by, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dv):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dv):
                Output[bz, by, i] = o_accum_local[i]

    @T.prim_func
    def main_split(
            Q: T.Buffer([batch, h_q, dv], dtype),
            Q_pe: T.Buffer([batch, h_q, dpe], dtype),
            KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
            K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
            block_table: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
            cache_seqlens: T.Buffer([batch], "int32"),
            glse: T.Buffer([batch, h_q, num_split], dtype),
            Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
            Output: T.Buffer([batch, h_q, dv], dtype),
    ):
        flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial)
        combine(glse, Output_partial, Output)

    @T.prim_func
    def main_no_split(
            Q: T.Buffer([batch, h_q, dv], dtype),
            Q_pe: T.Buffer([batch, h_q, dpe], dtype),
            KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
            K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
            block_table: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
            cache_seqlens: T.Buffer([batch], "int32"),
            glse: T.Buffer([batch, h_q, num_split], dtype),
            Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
            Output: T.Buffer([batch, h_q, dv], dtype),
    ):
        flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output)

    if num_split > 1:
        return main_split
    else:
        return main_no_split


def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
    query = query.float()
    key = key.float()
    value = value.float()
    key = key.repeat_interleave(h_q // h_kv, dim=0)
    value = value.repeat_interleave(h_q // h_kv, dim=0)
    attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if is_causal:
        s_q = query.shape[-2]
        s_k = key.shape[-2]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
        attn_weight += attn_bias
    lse = attn_weight.logsumexp(dim=-1)
    attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
    return attn_weight @ value, lse


@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    # q: [b, s_q, h_q, d]
    # block_table: [b, max_seqlen_pad // block_size]
    # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
    # cache_seqlens: [b]
    blocked_v = blocked_k[..., :dv]

    def ref_mla():
        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device)
        lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device)
        for i in range(b):
            begin = i * max_seqlen_pad
            end = begin + cache_seqlens[i]
            O, LSE = scaled_dot_product_attention(
                q[i].transpose(0, 1),
                blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
                blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
                h_q, h_kv,
                is_causal=causal,
            )
            out[i] = O.transpose(0, 1)
            lse[i] = LSE
        return out.to(dtype), lse.to(dtype)

    out_torch, _ = ref_mla()
    return out_torch


def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
    assert d > dv, "MLA head dim d (with RoPE) must be larger than the no-RoPE dim dv"
    q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
    blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
    dpe = d - dv
    num_kv_splits = 1
    BLOCK_N = 64
    BLOCK_H = 64
    out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
    glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
    out = torch.empty(b, h_q, dv, dtype=dtype, device=q.device)

    program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size)
    mod, params = tilelang.lower(program)
    mod = tilelang.Profiler(mod, params, [8], tilelang.TensorSupplyType.Randn)

    def flash_mla_tilelang():
        out = mod.func(
            q_nope.view(-1, h_q, dv),
            q_pe.view(-1, h_q, dpe),
            blocked_k_nope.view(-1, h_kv, dv),
            blocked_k_pe.view(-1, h_kv, dpe),
            block_table,
            cache_seqlens,
            glse,
            out_partial,
        )
        return out.view([b, s_q, h_q, dv])

    out_flash = flash_mla_tilelang()
    t = do_bench(flash_mla_tilelang)
    out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01)
    print("All close")
    return out_flash, t


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=128, help='batch size')
    parser.add_argument('--h_q', type=int, default=128, help='number of query heads')
    parser.add_argument('--h_kv', type=int, default=1, help='number of kv heads')
    parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length')
    parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe')
    parser.add_argument('--dv', type=int, default=512, help='value head dim')
    args = parser.parse_args()
    b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv

    device = "cuda"
    dtype = torch.float16
    s_q = 1  # for decode, s_q = 1
    block_size = 64
    cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device)
    dpe = d - dv
    causal = True

    total_seqlens = cache_seqlens.sum().item()
    mean_seqlens = cache_seqlens.float().mean().int().item()
    max_seqlen = cache_seqlens.max().item()
    max_seqlen_pad = math.ceil(max_seqlen / 256) * 256
    total_flops = s_q * total_seqlens * h_q * (d + dv) * 2

    q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size)
    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)

    out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
\ No newline at end of file
@@ -72,7 +72,7 @@ def get_configs(M, N, K, with_roller=False):
        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
        configs = []
        for hint in roller_hints:
            config = {}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
-"""The auto-tune module for tl programs."""
+"""The auto-tune module for tilelang programs."""
-import tilelang as tl
+import tilelang
from tilelang import tvm as tvm
import inspect
from functools import wraps
@@ -21,9 +21,9 @@ logging.basicConfig(
@dataclass(frozen=True)
class JITContext:
-    mod: tl.Profiler
+    mod: tilelang.Profiler
    out_idx: List[int]
-    supply_type: tl.TensorSupplyType
+    supply_type: tilelang.TensorSupplyType
    ref_prog: Callable
    rtol: float
    atol: float
@@ -144,7 +144,7 @@ def autotune(configs: Any,
             rep: int = 100,
             timeout: int = 100) -> Callable:
    """
-    Decorator for tl program
+    Decorator for tilelang program
    """

    def decorator(fn: Callable) -> Autotuner:
@@ -154,7 +154,7 @@ def autotune(configs: Any,
def jit(out_idx: List[int],
-        supply_type: tl.TensorSupplyType = tl.TensorSupplyType.Normal,
+        supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
        ref_prog: Callable = None,
        rtol: float = 1e-2,
        atol: float = 1e-2,
@@ -169,9 +169,9 @@ def jit(out_idx: List[int],
    def decorator(*args, **kwargs) -> float:
        # Enabling Efficient Fusion
        with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
-            mod, params = tl.lower(fn(*args, **kwargs), target=target)
+            mod, params = tilelang.lower(fn(*args, **kwargs), target=target)
-        mod = tl.Profiler(mod, params, out_idx, supply_type)
+        mod = tilelang.Profiler(mod, params, out_idx, supply_type)
        return JITContext(
            mod=mod,
@@ -3,7 +3,7 @@
from tvm import tir, IRModule
from tvm.target import Target
-import tilelang as tl
+import tilelang

def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
@@ -11,17 +11,17 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
    mod = tir.transform.BindTarget(target)(mod)
    # Legalize the frontend IR to make it compatible with TVM
-    mod = tl.transform.FrontendLegalize()(mod)
+    mod = tilelang.transform.FrontendLegalize()(mod)
    # Simplify the IR expressions
    mod = tir.transform.Simplify()(mod)
    # Infer memory layouts for fragments and shared memory
-    mod = tl.transform.LayoutInference()(mod)
+    mod = tilelang.transform.LayoutInference()(mod)
    # Lower high-level tile operations to low-level operations
-    mod = tl.transform.LowerTileOp()(mod)
+    mod = tilelang.transform.LowerTileOp()(mod)
    # Legalize vectorized loops to ensure they are valid
-    mod = tl.transform.LegalizeVectorizedLoop()(mod)
+    mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
    # Add safety checks for memory accesses
-    mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
+    mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
    # Simplify again to clean up any duplicated conditions
    # that may have been introduced by safety checks
    mod = tir.transform.Simplify()(mod)
@@ -32,23 +32,23 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    # which may be introduced by the LegalizeSafeMemoryAccess
    if target.arch == "sm_90":
-        mod = tl.transform.MultiVersionBuffer()(mod)
+        mod = tilelang.transform.MultiVersionBuffer()(mod)
-        mod = tl.transform.WarpSpecialized()(mod)
+        mod = tilelang.transform.WarpSpecialized()(mod)
-        mod = tl.transform.InjectSoftwarePipeline()(mod)
+        mod = tilelang.transform.InjectSoftwarePipeline()(mod)
        mod = tir.transform.LowerOpaqueBlock()(mod)
-        mod = tl.transform.RewriteWgmmaSync()(mod)
+        mod = tilelang.transform.RewriteWgmmaSync()(mod)
-        # mod = tl.transform.WarpSpecializedPipeline()(mod)
+        # mod = tilelang.transform.WarpSpecializedPipeline()(mod)
-        mod = tl.transform.InjectFenceProxy()(mod)
+        mod = tilelang.transform.InjectFenceProxy()(mod)
    else:
        mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
-        mod = tl.transform.PipelinePlanning()(mod)
+        mod = tilelang.transform.PipelinePlanning()(mod)
-        mod = tl.transform.InjectSoftwarePipeline()(mod)
+        mod = tilelang.transform.InjectSoftwarePipeline()(mod)
    mod = tir.transform.LowerOpaqueBlock()(mod)
    mod = tir.transform.FlattenBuffer()(mod)
    mod = tir.transform.NarrowDataType(32)(mod)
    mod = tir.transform.Simplify()(mod)
-    mod = tl.transform.VectorizeLoop()(mod)
+    mod = tilelang.transform.VectorizeLoop()(mod)
    mod = tir.transform.StorageRewrite()(mod)
    mod = tir.transform.UnrollLoop()(mod)
    mod = tir.transform.RenormalizeSplitPattern()(mod)
@@ -68,19 +68,19 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    # We can find a way better to create var instead
    # of putting the LowerThreadAllreduce before
    # the Legalization.
-    mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
+    mod = tilelang.transform.ThreadPartialSync("shared.dyn")(mod)
    mod = tir.transform.InferFragment()(mod)
    mod = tir.transform.LowerThreadAllreduce()(mod)
-    mod = tl.transform.LowerHopperIntrin()(mod)
+    mod = tilelang.transform.LowerHopperIntrin()(mod)
-    mod = tl.transform.ThreadSync("shared")(mod)
+    mod = tilelang.transform.ThreadSync("shared")(mod)
-    mod = tl.transform.ThreadSync("shared.dyn")(mod)
+    mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
    mod = tir.transform.InjectPTXAsyncCopy()(mod)
-    mod = tl.transform.AnnotateDeviceRegions()(mod)
+    mod = tilelang.transform.AnnotateDeviceRegions()(mod)
    mod = tir.transform.SplitHostDevice()(mod)
    mod = tir.transform.MergeSharedMemoryAllocations()(mod)
-    mod = tl.transform.MakePackedAPI()(mod)
+    mod = tilelang.transform.MakePackedAPI()(mod)
    mod = tir.transform.LowerDeviceKernelLaunch()(mod)
    return mod