Commit be9abf18 authored by Yu Cheng, committed by GitHub

[Dev][Benchmark] Add MLA paged decoding example and benchmark script (#158)

* [Dev] Adjust computation logic to avoid precision loss when casting acc_s from float to float16 (see the sketch after this list)

- Remove redundant `acc_s_0` fragment in flash attention kernel
- Simplify memory copy and reduction operations
- Reorder memory copy and scaling steps for improved performance
- Add Hopper-specific synchronization method in CUDA reduce template
- Update reduce operation to use architecture-specific synchronization
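
The cast issue above can be reproduced in a few lines of standalone PyTorch (an illustration of the failure mode, not the kernel code itself):

import torch

# fp16 tops out near 65504, so unnormalized exponentials can overflow;
# subtracting the running max first keeps every value in (0, 1].
scores = torch.tensor([20.0, 10.0, 2.0], dtype=torch.float32)
naive = torch.exp2(scores).to(torch.float16)                # exp2(20) -> inf
safe = torch.exp2(scores - scores.max()).to(torch.float16)  # values in (0, 1]
print(naive)  # tensor([inf, 1024., 4.], dtype=torch.float16)
print(safe)   # tensor([1.0000e+00, 9.7656e-04, 3.8147e-06], dtype=torch.float16)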

* [Dev] Add DeepSeek MLA Decoding (Paged+Varlen) kernel and Performance Benchmark Script

- Implement comprehensive MLA (Multi-Head Latent Attention) decoding benchmark script
- Add support for multiple implementations: Torch, TileLang, FlashMLA, FlashInfer, and Triton
- Create flexible configuration for benchmarking different batch sizes, sequence lengths, and head configurations
- Implement performance comparison and CSV output for detailed analysis
- Add command-line argument support for targeted benchmarking and comparison

* [Dev] Refactor MLA Paged Decoding Kernel with Improved Block Handling and Precision

- Replace `d` parameter with `dv` to clarify value dimension in MLA decoding
- Enhance block distribution logic for split KV processing (see the sketch after this list)
- Improve handling of remaining blocks in split KV computation
- Add initialization of `lse_max_local` to prevent potential precision issues
- Optimize block start and range calculations for more accurate sequence processing
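
The split-KV partition described above amounts to the following standalone Python (illustrative helper names, not part of the kernel):

# Each of num_split workers gets floor(total_blocks / num_split) blocks of
# block_N tokens; the first (total_blocks % num_split) workers take one extra.
def split_ranges(total_blocks: int, num_split: int, block_N: int):
    per_split, remainder = divmod(total_blocks, num_split)
    ranges = []
    for bz in range(num_split):
        n_blocks = per_split + (1 if bz < remainder else 0)
        start = (per_split * bz + min(bz, remainder)) * block_N
        ranges.append((start, start + n_blocks * block_N))
    return ranges

print(split_ranges(total_blocks=10, num_split=4, block_N=64))
# [(0, 192), (192, 384), (384, 512), (512, 640)] -> 3, 3, 2, 2 blocks per split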

* lint
parent 3c53297b
......@@ -182,6 +182,7 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
T.clear(lse_logsum_local)
T.clear(o_accum_local)
+lse_max_local[0] = -T.infinity(accum_dtype)
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
......
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from einops import rearrange, einsum
import argparse
from tilelang.profiler import do_bench
import math
def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size):
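# fold log2(e) into the softmax scale so the kernels can use exp2/log2, which lower to fast hardware instructions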
scale = (1.0 / (dv + dpe))**0.5 * 1.44269504 # log2(e)
dtype = "float16"
accum_dtype = "float"
kv_group_num = h_q // h_kv
VALID_BLOCK_H = min(block_H, kv_group_num)
assert h_kv == 1, "h_kv must be 1"
assert block_size >= block_N and block_size % block_N == 0, "block_size must be at least block_N and a multiple of block_N"
@T.macro
def flash_mla_kernel(
Q: T.Buffer([batch, h_q, dv], dtype),
Q_pe: T.Buffer([batch, h_q, dpe], dtype),
KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Buffer([batch], "int32"),
Output: T.Buffer([batch, h_q, dv], dtype),
):
with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by):
Q_shared = T.alloc_shared([block_H, dv], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
KV_shared = T.alloc_shared([block_N, dv], dtype)
K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
O_shared = T.alloc_shared([block_H, dv], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N)
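# walk the KV blocks newest-to-oldest: only the first iteration (kr == 0) covers the ragged tail, so the out-of-range mask below is applied just once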
for kr in T.Pipelined(loop_range, num_stages=2):
k = loop_range - 1 - kr
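# page-table lookup: map the logical token offset k * block_N to a physical row in the paged KV cache; block_size % block_N == 0 guarantees a block_N tile never straddles two pages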
kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size
T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
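# mask scores past cache_seqlens (only reachable on the first reversed iteration)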
with T.If(kr == 0), T.Then():
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
T.copy(S_shared, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])
@T.macro
def flash_mla_split_kv_kernel(
Q: T.Buffer([batch, h_q, dv], dtype),
Q_pe: T.Buffer([batch, h_q, dpe], dtype),
KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
BLOCK_TABLE: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
CACHE_SEQLENS: T.Buffer([batch], "int32"),
glse: T.Buffer([batch, h_q, num_split], dtype),
Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
):
with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
Q_shared = T.alloc_shared([block_H, dv], dtype)
S_shared = T.alloc_shared([block_H, block_N], dtype)
Q_pe_shared = T.alloc_shared([block_H, dpe], dtype)
KV_shared = T.alloc_shared([block_N, dv], dtype)
K_pe_shared = T.alloc_shared([block_N, dpe], dtype)
O_shared = T.alloc_shared([block_H, dv], dtype)
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
acc_o = T.alloc_fragment([block_H, dv], accum_dtype)
scores_max = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
scores_scale = T.alloc_fragment([block_H], accum_dtype)
scores_sum = T.alloc_fragment([block_H], accum_dtype)
logsum = T.alloc_fragment([block_H], accum_dtype)
cur_kv_head = by // (kv_group_num // block_H)
T.use_swizzle(10)
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
S_shared: tilelang.layout.make_swizzled_layout(S_shared),
})
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
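# partition this sequence's KV blocks across the num_split grid dimension; the first remaining_blocks splits each take one extra block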
total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N)
blocks_per_split = T.floordiv(total_blocks, num_split)
remaining_blocks = T.floormod(total_blocks, num_split)
loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0))
start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N
for k in T.Pipelined(loop_range, num_stages=2):
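# page-table lookup as in the non-split kernel; the within-page offset must use the full logical position start + k * block_N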
kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (start + k * block_N) % block_size
T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared)
T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared)
T.clear(acc_s)
T.gemm(
Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
T.gemm(
Q_pe_shared,
K_pe_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j])
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
for i in T.Parallel(block_H):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_H, block_N):
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
T.copy(acc_s, S_shared)
T.copy(S_shared, acc_s_cast)
for i in T.Parallel(block_H):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] *= scores_scale[i]
T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
for i, j in T.Parallel(block_H, dv):
acc_o[i, j] /= logsum[i]
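# convert the running softmax sum into a base-2 log-sum-exp so the combine kernel can merge partial results across splits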
for i in T.Parallel(block_H):
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz])
T.copy(acc_o, O_shared)
T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :])
@T.macro
def combine(
glse: T.Buffer([batch, h_q, num_split], dtype),
Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
Output: T.Buffer([batch, h_q, dv], dtype),
):
with T.Kernel(h_q, batch, threads=128) as (by, bz):
po_local = T.alloc_fragment([dv], dtype)
o_accum_local = T.alloc_fragment([dv], accum_dtype)
lse_local_split = T.alloc_local([1], accum_dtype)
lse_logsum_local = T.alloc_local([1], accum_dtype)
lse_max_local = T.alloc_local([1], accum_dtype)
scale_local = T.alloc_local([1], accum_dtype)
T.annotate_layout({
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
})
T.clear(lse_logsum_local)
T.clear(o_accum_local)
lse_max_local[0] = -T.infinity(accum_dtype)
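# two passes over the splits: first take the max LSE for numerical stability, then accumulate exp2(lse_k - lse_max) into the global normalizer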
for k in T.serial(num_split):
lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k])
for k in T.Pipelined(num_split, num_stages=1):
lse_local_split[0] = glse[bz, by, k]
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
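# rescale each split's partial output by exp2(lse_k - lse_global) and accumulate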
for k in T.serial(num_split):
for i in T.Parallel(dv):
po_local[i] = Output_partial[bz, by, k, i]
lse_local_split[0] = glse[bz, by, k]
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
for i in T.Parallel(dv):
o_accum_local[i] += po_local[i] * scale_local[0]
for i in T.Parallel(dv):
Output[bz, by, i] = o_accum_local[i]
@T.prim_func
def main_split(
Q: T.Buffer([batch, h_q, dv], dtype),
Q_pe: T.Buffer([batch, h_q, dpe], dtype),
KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Buffer([batch], "int32"),
glse: T.Buffer([batch, h_q, num_split], dtype),
Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
Output: T.Buffer([batch, h_q, dv], dtype),
):
flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial)
combine(glse, Output_partial, Output)
@T.prim_func
def main_no_split(
Q: T.Buffer([batch, h_q, dv], dtype),
Q_pe: T.Buffer([batch, h_q, dpe], dtype),
KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype),
K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype),
block_table: T.Buffer([batch, max_seqlen_pad // block_size], "int32"),
cache_seqlens: T.Buffer([batch], "int32"),
glse: T.Buffer([batch, h_q, num_split], dtype),
Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype),
Output: T.Buffer([batch, h_q, dv], dtype),
):
flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output)
if num_split > 1:
return main_split
else:
return main_no_split
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False):
query = query.float()
key = key.float()
value = value.float()
key = key.repeat_interleave(h_q // h_kv, dim=0)
value = value.repeat_interleave(h_q // h_kv, dim=0)
attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
if is_causal:
s_q = query.shape[-2]
s_k = key.shape[-2]
attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device)
temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias = attn_bias.to(query.dtype)
attn_weight += attn_bias
lse = attn_weight.logsumexp(dim=-1)
attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32)
return attn_weight @ value, lse
@torch.inference_mode()
def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
# q: [b, s_q, h_q, d]
# block_table: [b, max_seqlen_pad // block_size]
# blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d]
# cache_seqlens: [b]
blocked_v = blocked_k[..., :dv]
def ref_mla():
out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device)
lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device)
for i in range(b):
begin = i * max_seqlen_pad
end = begin + cache_seqlens[i]
O, LSE = scaled_dot_product_attention(
q[i].transpose(0, 1),
blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1),
blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1),
h_q, h_kv,
is_causal=causal,
)
out[i] = O.transpose(0, 1)
lse[i] = LSE
return out.to(dtype), lse.to(dtype)
out_torch, _ = ref_mla()
return out_torch
def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
assert d > dv, "MLA requires a positive RoPE dim: d (= dv + dpe) must be larger than dv"
q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous()
blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous()
dpe = d - dv
num_kv_splits = 1
BLOCK_N = 64
BLOCK_H = 64
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
out = torch.empty(b, h_q, dv, dtype=dtype, device=q.device)
program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size)
mod, params = tilelang.lower(program)
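# out_idx=[8]: the ninth kernel argument (Output) is allocated by the profiler and returned by mod.func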
mod = tilelang.Profiler(mod, params, [8], tilelang.TensorSupplyType.Randn)
def flash_mla_tilelang():
out = mod.func(
q_nope.view(-1, h_q, dv),
q_pe.view(-1, h_q, dpe),
blocked_k_nope.view(-1, h_kv, dv),
blocked_k_pe.view(-1, h_kv, dpe),
block_table,
cache_seqlens,
glse,
out_partial,
)
return out.view([b, s_q, h_q, dv])
out_flash = flash_mla_tilelang()
t = do_bench(flash_mla_tilelang)
out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01)
print("All close")
return out_flash, t
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=128, help='batch size')
parser.add_argument('--h_q', type=int, default=128, help='q heads number')
parser.add_argument('--h_kv', type=int, default=1, help='kv heads number')
parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length')
parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe')
parser.add_argument('--dv', type=int, default=512, help='value head dim')
args = parser.parse_args()
b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv
device = "cuda"
dtype = torch.float16
s_q = 1 # for decode, s_q = 1
block_size = 64
cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device)
dpe = d - dv
causal = True
total_seqlens = cache_seqlens.sum().item()
mean_seqlens = cache_seqlens.float().mean().int().item()
max_seqlen = cache_seqlens.max().item()
max_seqlen_pad = math.ceil(max_seqlen / 256) * 256
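# 2 FLOPs per multiply-add; each (query, key) pair costs d dims in Q.K^T plus dv dims in P.V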
total_flops = s_q * total_seqlens * h_q * (d + dv) * 2
q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size)
blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)
out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
\ No newline at end of file
......@@ -72,7 +72,7 @@ def get_configs(M, N, K, with_roller=False):
if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
configs = []
for hint in roller_hints:
config = {}
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""The auto-tune module for tl programs."""
"""The auto-tune module for tilelang programs."""
import tilelang as tl
import tilelang
from tilelang import tvm as tvm
import inspect
from functools import wraps
......@@ -21,9 +21,9 @@ logging.basicConfig(
@dataclass(frozen=True)
class JITContext:
-mod: tl.Profiler
+mod: tilelang.Profiler
out_idx: List[int]
-supply_type: tl.TensorSupplyType
+supply_type: tilelang.TensorSupplyType
ref_prog: Callable
rtol: float
atol: float
......@@ -144,7 +144,7 @@ def autotune(configs: Any,
rep: int = 100,
timeout: int = 100) -> Callable:
"""
-Decorator for tl program
+Decorator for tilelang program
"""
def decorator(fn: Callable) -> Autotuner:
......@@ -154,7 +154,7 @@ def autotune(configs: Any,
def jit(out_idx: List[int],
-supply_type: tl.TensorSupplyType = tl.TensorSupplyType.Normal,
+supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal,
ref_prog: Callable = None,
rtol: float = 1e-2,
atol: float = 1e-2,
......@@ -169,9 +169,9 @@ def jit(out_idx: List[int],
def decorator(*args, **kwargs) -> float:
# Enabling Efficient Fusion
with tvm.transform.PassContext(config={"tir.merge_static_smem": True}):
-mod, params = tl.lower(fn(*args, **kwargs), target=target)
+mod, params = tilelang.lower(fn(*args, **kwargs), target=target)
-mod = tl.Profiler(mod, params, out_idx, supply_type)
+mod = tilelang.Profiler(mod, params, out_idx, supply_type)
return JITContext(
mod=mod,
......
......@@ -3,7 +3,7 @@
from tvm import tir, IRModule
from tvm.target import Target
-import tilelang as tl
+import tilelang
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
......@@ -11,17 +11,17 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tir.transform.BindTarget(target)(mod)
# Legalize the frontend IR to make it compatible with TVM
-mod = tl.transform.FrontendLegalize()(mod)
+mod = tilelang.transform.FrontendLegalize()(mod)
# Simplify the IR expressions
mod = tir.transform.Simplify()(mod)
# Infer memory layouts for fragments and shared memory
-mod = tl.transform.LayoutInference()(mod)
+mod = tilelang.transform.LayoutInference()(mod)
# Lower high-level tile operations to low-level operations
-mod = tl.transform.LowerTileOp()(mod)
+mod = tilelang.transform.LowerTileOp()(mod)
# Legalize vectorized loops to ensure they are valid
-mod = tl.transform.LegalizeVectorizedLoop()(mod)
+mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
# Add safety checks for memory accesses
-mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
+mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
# Simplify again to clean up any duplicated conditions
# that may have been introduced by safety checks
mod = tir.transform.Simplify()(mod)
......@@ -32,23 +32,23 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# which may be introduced by the LegalizeSafeMemoryAccess
if target.arch == "sm_90":
-mod = tl.transform.MultiVersionBuffer()(mod)
-mod = tl.transform.WarpSpecialized()(mod)
-mod = tl.transform.InjectSoftwarePipeline()(mod)
+mod = tilelang.transform.MultiVersionBuffer()(mod)
+mod = tilelang.transform.WarpSpecialized()(mod)
+mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
-mod = tl.transform.RewriteWgmmaSync()(mod)
-# mod = tl.transform.WarpSpecializedPipeline()(mod)
-mod = tl.transform.InjectFenceProxy()(mod)
+mod = tilelang.transform.RewriteWgmmaSync()(mod)
+# mod = tilelang.transform.WarpSpecializedPipeline()(mod)
+mod = tilelang.transform.InjectFenceProxy()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
-mod = tl.transform.PipelinePlanning()(mod)
-mod = tl.transform.InjectSoftwarePipeline()(mod)
+mod = tilelang.transform.PipelinePlanning()(mod)
+mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
-mod = tl.transform.VectorizeLoop()(mod)
+mod = tilelang.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
......@@ -68,19 +68,19 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# We can find a way better to create var instead
# of putting the LowerThreadAllreduce before
# the Legalization.
-mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
+mod = tilelang.transform.ThreadPartialSync("shared.dyn")(mod)
mod = tir.transform.InferFragment()(mod)
mod = tir.transform.LowerThreadAllreduce()(mod)
-mod = tl.transform.LowerHopperIntrin()(mod)
-mod = tl.transform.ThreadSync("shared")(mod)
-mod = tl.transform.ThreadSync("shared.dyn")(mod)
+mod = tilelang.transform.LowerHopperIntrin()(mod)
+mod = tilelang.transform.ThreadSync("shared")(mod)
+mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
-mod = tl.transform.AnnotateDeviceRegions()(mod)
+mod = tilelang.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
-mod = tl.transform.MakePackedAPI()(mod)
+mod = tilelang.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
return mod
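
A minimal sketch of how the two phases above compose (the driver function name lower_pipeline is an assumption; LowerAndLegalize and OptimizeForTarget are the entry points shown in the diff):

from tvm import IRModule
from tvm.target import Target

def lower_pipeline(mod: IRModule, target: Target) -> IRModule:
    # Phase 1: bind the target, legalize frontend IR, infer layouts,
    # and lower tile-level ops to TIR
    mod = LowerAndLegalize(mod, target)
    # Phase 2: target-specific optimization (warp specialization and WGMMA
    # sync rewriting on sm_90, software pipelining, vectorization, shared
    # memory sync insertion, host/device splitting)
    mod = OptimizeForTarget(mod, target)
    return mod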