Commit 2b97e98a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Example] Add GQA Example (#118)

* Add DeepSeek MLA decode example with Flash Attention implementation

* Add GEMM SplitK and StreamK example implementations

This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques:
- `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang
- `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang

Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations.

* Refactor GEMM SplitK and StreamK example implementations

Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts:
- Remove unused import (Profiler) in splitk example
- Simplify line breaks and improve code readability
- Standardize indentation and remove unnecessary whitespace
- Optimize atomic add and copy operations for better clarity

* Add block sparse attention benchmarks for multiple libraries

This commit introduces comprehensive block sparse attention benchmarks for different libraries:
- TileLang block sparse FMHA implementation
- Triton block sparse FMHA implementation
- PyTorch reference block sparse FMHA implementation
- FlashAttention dense FMHA reference implementation

The benchmarks include:
- Configurable benchmark parameters (batch size, heads, sequence length, etc.)
- Sparse mask generation using top-k and threshold methods
- Performance measurement for different sparse attention configurations
- Utility functions for mask generation and benchmarking

* Refactor block sparse attention benchmarks with code style improvements

- Add Ruff linter ignore comments to benchmark files
- Improve code formatting and line breaks
- Remove unused imports
- Standardize print statement formatting
- Enhance code readability across multiple library benchmarks

* lint fix

* Add CUDA atomic operations for BFLOAT16 and update function naming

- Implement AtomicAdd functions for BFLOAT16 and BFLOAT16x2 in CUDA common header
- Rename existing atomic add functions to use PascalCase (atomicAdd -> AtomicAdd)
- Add a new __pack_nv_bfloat162 function for packing BFLOAT16 values
- Update kernel and language customization to use new function names
- Add return type annotations in profiler module

* lint fix

* Add example for Group Query Attention (GQA) forward pass using Flash Attention in TileLang

This commit introduces a new example script `example_gqa_fwd_bshd.py` that demonstrates:
- Group Query Attention (GQA) implementation
- Flash Attention forward pass
- Performance benchmarking
- Configurable parameters for batch, heads, sequence length, and dimension
- Autotuning support
- Reference implementation comparison

* Refactor IR lowering pipeline into modular phases

This commit introduces a new module `phase.py` to modularize the IR lowering process by splitting the complex lowering pipeline into two distinct phases:
- `LowerAndLegalize`: Handles initial IR legalization and transformation
- `OptimizeForTarget`: Applies target-specific optimizations

The changes simplify the lowering logic in multiple files by extracting the transformation steps into reusable functions, improving code readability and maintainability.

* lintfix
parent b7ca76f1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial
def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))
configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs
def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"
def kernel_func(block_M, block_N, num_stages, threads):
@T.macro
def MMA0(
K: T.Buffer(kv_shape, dtype),
Q_shared: T.Buffer([block_M, dim], dtype),
K_shared: T.Buffer([block_N, dim], dtype),
acc_s: T.Buffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def MMA1(
V: T.Buffer(kv_shape, dtype),
V_shared: T.Buffer([block_M, dim], dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
acc_o: T.Buffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
@T.macro
def Softmax(
acc_s: T.Buffer([block_M, block_N], accum_dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
scores_max: T.Buffer([block_M], accum_dtype),
scores_max_prev: T.Buffer([block_M], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
scores_sum: T.Buffer([block_M], accum_dtype),
logsum: T.Buffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)
@T.macro
def Rescale(
acc_o: T.Buffer([block_M, dim], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]
@T.prim_func
def main(
Q: T.Buffer(q_shape, dtype),
K: T.Buffer(kv_shape, dtype),
V: T.Buffer(kv_shape, dtype),
Output: T.Buffer(q_shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))
for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
return main
if tune:
@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@jit(
out_idx=[3],
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=None,
profiler="auto")
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel()
else:
def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)
return kernel
def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D]
# K: [B, T, HK, D]
# V: [B, T, HV, D]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"
dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument('--groups', type=int, default=8, help='groups')
args = parser.parse_args()
batch, heads, seq_len, dim, is_causal, groups = args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5
if (not args.tune):
program = flashattn(
batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)(
block_M=128, block_N=128, num_stages=1, threads=128)
ref_program = partial(ref_program, is_causal=is_causal, groups=groups)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, _ = flashattn(
batch, heads, seq_len, dim, is_causal, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# Licensed under the MIT License. # Licensed under the MIT License.
"""The compiler for TL programs.""" """The compiler for TL programs."""
import tilelang as tl
import os import os
import os.path as osp import os.path as osp
from typing import Union, Optional, Callable from typing import Union, Optional, Callable
...@@ -12,6 +11,10 @@ from tvm.ir import CallingConv ...@@ -12,6 +11,10 @@ from tvm.ir import CallingConv
from tvm.target import Target from tvm.target import Target
from tilelang.contrib import hipcc, nvcc from tilelang.contrib import hipcc, nvcc
from tilelang.utils.target import determine_target from tilelang.utils.target import determine_target
from tilelang.engine.phase import (
LowerAndLegalize,
OptimizeForTarget,
)
def is_cpu_device_backend(target: Target): def is_cpu_device_backend(target: Target):
...@@ -152,68 +155,12 @@ def lower( ...@@ -152,68 +155,12 @@ def lower(
_is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
_is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))
mod = tir.transform.BindTarget(target)(mod) # Phase 1: Lower and legalize the IR
mod = LowerAndLegalize(mod, target)
mod = tl.transform.FrontendLegalize()(mod)
mod = tir.transform.Simplify()(mod) # Phase 2: Optimize the IR for the target
mod = tl.transform.LayoutInference()(mod) mod = OptimizeForTarget(mod, target)
mod = tl.transform.LowerTileOp()(mod)
mod = tl.transform.LegalizeVectorizedLoop()(mod)
mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
# Inject Simplify to remove the duplicated conditions
mod = tir.transform.Simplify()(mod)
# which may be introduced by the LegalizeSafeMemoryAccess
if target.arch == "sm_90":
mod = tl.transform.MultiVersionBuffer()(mod)
mod = tl.transform.WarpSpecialized()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# mod = tl.transform.WarpSpecializedPipeline()(mod)
mod = tl.transform.InjectFenceProxy()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tl.transform.PipelinePlanning()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
mod = tl.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
mod = tir.transform.RewriteUnsafeSelect()(mod)
mod = tir.transform.HoistIfThenElse()(mod)
mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
# TODO(lei): This is a hack to make sure the
# thread level allreduce pass can be applied
# in TL. As Tl only use one thread dimension
# the var binding information will be lost
# in the lowering process with Legalization
# and Simplify pass.
# We can find a way better to create var instead
# of putting the LowerThreadAllreduce before
# the Legalization.
mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
mod = tir.transform.InferFragment()(mod)
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tl.transform.LowerHopperIntrin()(mod)
mod = tl.transform.ThreadSync("shared")(mod)
mod = tl.transform.ThreadSync("shared.dyn")(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
mod = tl.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tl.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
host_mod = tir.transform.Filter(_is_host_call)(mod) host_mod = tir.transform.Filter(_is_host_call)(mod)
host_mod = tir.transform.BindTarget(target_host)(host_mod) host_mod = tir.transform.BindTarget(target_host)(host_mod)
host_mod = tir.transform.FP8StorageLegalize()(host_mod) host_mod = tir.transform.FP8StorageLegalize()(host_mod)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import tir, IRModule
from tvm.target import Target
import tilelang as tl
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Bind the target device information to the module
mod = tir.transform.BindTarget(target)(mod)
# Legalize the frontend IR to make it compatible with TVM
mod = tl.transform.FrontendLegalize()(mod)
# Simplify the IR expressions
mod = tir.transform.Simplify()(mod)
# Infer memory layouts for fragments and shared memory
mod = tl.transform.LayoutInference()(mod)
# Lower high-level tile operations to low-level operations
mod = tl.transform.LowerTileOp()(mod)
# Legalize vectorized loops to ensure they are valid
mod = tl.transform.LegalizeVectorizedLoop()(mod)
# Add safety checks for memory accesses
mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
# Simplify again to clean up any duplicated conditions
# that may have been introduced by safety checks
mod = tir.transform.Simplify()(mod)
return mod
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# which may be introduced by the LegalizeSafeMemoryAccess
if target.arch == "sm_90":
mod = tl.transform.MultiVersionBuffer()(mod)
mod = tl.transform.WarpSpecialized()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# mod = tl.transform.WarpSpecializedPipeline()(mod)
mod = tl.transform.InjectFenceProxy()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tl.transform.PipelinePlanning()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
mod = tl.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
mod = tir.transform.RewriteUnsafeSelect()(mod)
mod = tir.transform.HoistIfThenElse()(mod)
mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
# TODO(lei): This is a hack to make sure the
# thread level allreduce pass can be applied
# in TL. As Tl only use one thread dimension
# the var binding information will be lost
# in the lowering process with Legalization
# and Simplify pass.
# We can find a way better to create var instead
# of putting the LowerThreadAllreduce before
# the Legalization.
mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
mod = tir.transform.InferFragment()(mod)
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tl.transform.LowerHopperIntrin()(mod)
mod = tl.transform.ThreadSync("shared")(mod)
mod = tl.transform.ThreadSync("shared.dyn")(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
mod = tl.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tl.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
return mod
...@@ -5,12 +5,15 @@ from typing import Union, Optional ...@@ -5,12 +5,15 @@ from typing import Union, Optional
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import IRModule, tir from tvm import IRModule, tir
from tvm.target import Target from tvm.target import Target
import tilelang.transform
from tilelang.engine.lower import ( from tilelang.engine.lower import (
is_device_call, is_device_call,
determine_target, determine_target,
canon_target_host, canon_target_host,
) )
from tilelang.engine.phase import (
LowerAndLegalize,
OptimizeForTarget,
)
def match_global_kernel(source: str) -> int: def match_global_kernel(source: str) -> int:
...@@ -47,58 +50,8 @@ def get_annotated_device_mod( ...@@ -47,58 +50,8 @@ def get_annotated_device_mod(
target_host = tvm.target.Target.canon_target(target_host) target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host) target = tvm.target.Target(target, target_host)
mod = tir.transform.BindTarget(target)(mod) mod = LowerAndLegalize(mod, target)
mod = OptimizeForTarget(mod, target)
mod = tilelang.transform.FrontendLegalize()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.LayoutInference()(mod)
mod = tilelang.transform.LowerTileOp()(mod)
mod = tir.transform.Simplify()(mod)
if target.arch == "sm_90":
mod = tilelang.transform.WarpSpecializedPipeline()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
mod = tir.transform.RewriteUnsafeSelect()(mod)
mod = tir.transform.HoistIfThenElse()(mod)
mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
mod = tir.transform.ThreadSync("shared")(mod)
# TODO(lei): This is a hack to make sure the
# thread level allreduce pass can be applied
# in TL. As Tl only use one thread dimension
# the var binding information will be lost
# in the lowering process with Legalization
# and Simplify pass.
# We can find a way better to create var instead
# of putting the LowerThreadAllreduce before
# the Legalization.
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tir.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.LowerHopperIntrin()(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
mod = tir.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tir.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
device_mod = tir.transform.Filter(is_device_call)(mod) device_mod = tir.transform.Filter(is_device_call)(mod)
return device_mod return device_mod
...@@ -5,12 +5,15 @@ from typing import Union, Optional ...@@ -5,12 +5,15 @@ from typing import Union, Optional
from tilelang import tvm as tvm from tilelang import tvm as tvm
from tvm import IRModule, tir from tvm import IRModule, tir
from tvm.target import Target from tvm.target import Target
import tilelang.transform
from tilelang.engine.lower import ( from tilelang.engine.lower import (
is_device_call, is_device_call,
determine_target, determine_target,
canon_target_host, canon_target_host,
) )
from tilelang.engine.phase import (
LowerAndLegalize,
OptimizeForTarget,
)
def match_global_kernel(source: str) -> int: def match_global_kernel(source: str) -> int:
...@@ -47,57 +50,8 @@ def get_annotated_device_mod( ...@@ -47,57 +50,8 @@ def get_annotated_device_mod(
target_host = tvm.target.Target.canon_target(target_host) target_host = tvm.target.Target.canon_target(target_host)
target = tvm.target.Target(target, target_host) target = tvm.target.Target(target, target_host)
mod = tir.transform.BindTarget(target)(mod) mod = LowerAndLegalize(mod, target)
mod = OptimizeForTarget(mod, target)
mod = tilelang.transform.FrontendLegalize()(mod)
mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.LayoutInference()(mod)
mod = tilelang.transform.LowerTileOp()(mod)
mod = tir.transform.Simplify()(mod)
if target.arch == "sm_90":
mod = tilelang.transform.WarpSpecializedPipeline()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
mod = tir.transform.RewriteUnsafeSelect()(mod)
mod = tir.transform.HoistIfThenElse()(mod)
mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
mod = tir.transform.ThreadSync("shared")(mod)
# TODO(lei): This is a hack to make sure the
# thread level allreduce pass can be applied
# in TL. As Tl only use one thread dimension
# the var binding information will be lost
# in the lowering process with Legalization
# and Simplify pass.
# We can find a way better to create var instead
# of putting the LowerThreadAllreduce before
# the Legalization.
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tir.transform.ThreadSync("shared.dyn")(mod)
mod = tilelang.transform.LowerHopperIntrin()(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)
mod = tir.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)
mod = tir.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
device_mod = tir.transform.Filter(is_device_call)(mod) device_mod = tir.transform.Filter(is_device_call)(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment