Commit 2ea45ae9 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Bugfix] Fix layout inference for free fragment buffer (#443)

* [Enhancement] Improve layout inference accuracy in ParallelOp (#441)

* Added logic to use non-replicated buffers as source buffers for more accurate layout inference.
* Enhanced comments to clarify the rationale behind buffer selection in layout inference process.

* [Enhancement] Add error handling macros and refactor loop partitioning logic

* Introduced TILELANG_CHECK macro for improved error handling in CUDA and HIP code, providing detailed error messages for kernel launches.
* Enhanced loop partitioning logic to handle fragment buffers more effectively, ensuring correct replication based on thread extent.
* Added logging for thread range in PlanLoopPartition to aid in debugging and performance analysis.
* Updated pass configuration management to streamline vectorization control in the optimization process.

* lint fix

* remove debug print
parent 734c7fbe
...@@ -25,6 +25,16 @@ using int4_t = int4; ...@@ -25,6 +25,16 @@ using int4_t = int4;
#define TL_DEVICE_NOINLINE __noinline__ __device__ #define TL_DEVICE_NOINLINE __noinline__ __device__
#define TL_PATCH #define TL_PATCH
// Check a CUDA runtime call.
//
// On failure, formats "<file>:<line>: <error-name> - <error-string>" into the
// caller-provided `error_buf` (of capacity ERROR_BUF_SIZE) and returns -1 from
// the enclosing function. Wrapped in do { } while (0) so it behaves as a
// single statement (safe inside unbraced if/else).
//
// Preconditions: only usable inside a function returning int, with `error_buf`
// and `ERROR_BUF_SIZE` in scope. Note: kernel launches themselves do not
// return cudaError_t; pair them with cudaGetLastError() when needed.
#define TILELANG_CHECK(stmt)                                                   \
  do {                                                                         \
    /* local name avoids the reserved double-underscore identifier `__err` */  \
    cudaError_t tl_err_ = (stmt);                                              \
    if (tl_err_ != cudaSuccess) {                                              \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,          \
               __LINE__, cudaGetErrorName(tl_err_),                            \
               cudaGetErrorString(tl_err_));                                   \
      return -1;                                                               \
    }                                                                          \
  } while (0)
// abs function for bfloat_t and half_t since there is no implicit convertion // abs function for bfloat_t and half_t since there is no implicit convertion
// method // method
TL_PATCH TL_DEVICE half_t __habs(const half_t x) { TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
......
...@@ -24,6 +24,17 @@ ...@@ -24,6 +24,17 @@
#define ushort unsigned short #define ushort unsigned short
#define TL_DEVICE __forceinline__ __device__ #define TL_DEVICE __forceinline__ __device__
#define TL_DEVICE_NOINLINE __noinline__ __device__
// Check a HIP runtime call (mirrors the CUDA variant).
//
// On failure, formats "<file>:<line>: <error-name> - <error-string>" into the
// caller-provided `error_buf` (of capacity ERROR_BUF_SIZE) and returns -1 from
// the enclosing function. Wrapped in do { } while (0) so it behaves as a
// single statement (safe inside unbraced if/else).
//
// Preconditions: only usable inside a function returning int, with `error_buf`
// and `ERROR_BUF_SIZE` in scope.
#define TILELANG_CHECK(stmt)                                                   \
  do {                                                                         \
    /* local name avoids the reserved double-underscore identifier `__err` */  \
    hipError_t tl_err_ = (stmt);                                               \
    if (tl_err_ != hipSuccess) {                                               \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,          \
               __LINE__, hipGetErrorName(tl_err_),                             \
               hipGetErrorString(tl_err_));                                    \
      return -1;                                                               \
    }                                                                          \
  } while (0)
#define half _Float16 #define half _Float16
#define __float2half_rn(x) half(x) #define __float2half_rn(x) half(x)
......
...@@ -141,10 +141,31 @@ public: ...@@ -141,10 +141,31 @@ public:
PrimExpr thd = FloorMod(access_idx, num_thread); PrimExpr thd = FloorMod(access_idx, num_thread);
PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size + PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
FloorMod(flattened, vectorize_size); FloorMod(flattened, vectorize_size);
return Fragment(loop_vars_, {idx}, {thd}, {}); auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
if (has_fragment_) {
// for fragment buffer, we don't need to replicate the loop layout
auto thread_extent = *as_const_int(fragment->ThreadExtent());
auto num_thread_fragment = num_thread / thread_extent;
fragment = fragment->Replicate(num_thread_fragment);
}
return fragment;
} }
private: private:
// Mark that this loop nest reads a "local.fragment" buffer; the flag is
// consumed when building the loop layout (fragment buffers change how the
// layout is replicated). Always recurses into the load's sub-expressions.
void VisitExpr_(const BufferLoadNode *op) final {
  has_fragment_ = has_fragment_ || (op->buffer.scope() == "local.fragment");
  StmtExprVisitor::VisitExpr_(op);
}
// Store-side counterpart of the BufferLoad check: flag writes into
// "local.fragment" buffers, then continue the default traversal.
void VisitStmt_(const BufferStoreNode *op) final {
  has_fragment_ = has_fragment_ || (op->buffer.scope() == "local.fragment");
  StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const ForNode *node) final { void VisitStmt_(const ForNode *node) final {
if (node->kind == ForKind::kParallel) { if (node->kind == ForKind::kParallel) {
body_ = node->body; body_ = node->body;
...@@ -157,6 +178,7 @@ private: ...@@ -157,6 +178,7 @@ private:
Stmt body_; Stmt body_;
PrimExpr flattened = 0; PrimExpr flattened = 0;
bool has_fragment_ = false;
Array<IterVar> loop_vars_; Array<IterVar> loop_vars_;
}; };
......
...@@ -96,6 +96,7 @@ from . import ( ...@@ -96,6 +96,7 @@ from . import (
language, # noqa: F401 language, # noqa: F401
engine, # noqa: F401 engine, # noqa: F401
) )
from .transform import PassConfigKey # noqa: F401
from .engine import lower, register_cuda_postproc, register_hip_postproc # noqa: F401 from .engine import lower, register_cuda_postproc, register_hip_postproc # noqa: F401
......
from tvm import tir, IRModule from tvm import tir, IRModule
from tvm.target import Target from tvm.target import Target
import tilelang import tilelang
from tilelang.transform import PassContext
from typing import Optional
def allow_tma_and_warp_specialized(target: Target) -> bool: def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if target.arch not in {"sm_90"}: if target.arch not in {"sm_90"}:
return False return False
cur_pass_ctx = tilelang.transform.get_pass_context() disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
disable_tma_lower = cur_pass_ctx.config.get("tl.disable_tma_lower", False) disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
disable_warp_specialized = cur_pass_ctx.config.get("tl.disable_warp_specialized", False) disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
return not (disable_tma_lower and disable_warp_specialized) return not (disable_tma_lower and disable_warp_specialized)
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
    """Report whether loop vectorization is enabled.

    Args:
        pass_ctx: Pass context to consult; when ``None``, the current
            tilelang pass context is used.

    Returns:
        ``False`` if the ``tir.disable_vectorize`` config flag is set,
        ``True`` otherwise (the flag defaults to ``False``).
    """
    ctx = tilelang.transform.get_pass_context() if pass_ctx is None else pass_ctx
    return not ctx.config.get("tir.disable_vectorize", False)
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Bind the target device information to the module # Bind the target device information to the module
mod = tir.transform.BindTarget(target)(mod) mod = tir.transform.BindTarget(target)(mod)
...@@ -38,8 +50,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -38,8 +50,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
pass_ctx = tilelang.transform.get_pass_context()
# which may be introduced by the LegalizeSafeMemoryAccess # which may be introduced by the LegalizeSafeMemoryAccess
if allow_tma_and_warp_specialized(target): if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
mod = tilelang.transform.IfStmtBinding()(mod) mod = tilelang.transform.IfStmtBinding()(mod)
mod = tilelang.transform.MultiVersionBuffer()(mod) mod = tilelang.transform.MultiVersionBuffer()(mod)
mod = tilelang.transform.WarpSpecialized()(mod) mod = tilelang.transform.WarpSpecialized()(mod)
...@@ -57,11 +70,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -57,11 +70,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.PipelinePlanning()(mod) mod = tilelang.transform.PipelinePlanning()(mod)
mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tilelang.transform.InjectSoftwarePipeline()(mod)
mod = tilelang.transform.MergeIfStmt()(mod) mod = tilelang.transform.MergeIfStmt()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod) mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.FlattenBuffer()(mod) mod = tilelang.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod) mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod) mod = tir.transform.Simplify()(mod)
mod = tilelang.transform.VectorizeLoop()(mod)
mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
mod = tir.transform.StorageRewrite()(mod) mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod) mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod) mod = tir.transform.RenormalizeSplitPattern()(mod)
......
...@@ -42,7 +42,7 @@ extern "C" int init() {{ ...@@ -42,7 +42,7 @@ extern "C" int init() {{
PREDEF_HOST_FUNC = """ PREDEF_HOST_FUNC = """
extern "C" int call({}) {{ extern "C" int call({}) {{
{} {}
return 0; \treturn 0;
}} }}
""" """
...@@ -193,7 +193,7 @@ class TLCUDASourceWrapper(object): ...@@ -193,7 +193,7 @@ class TLCUDASourceWrapper(object):
p = int(p) p = int(p)
return str(p).replace("//", "/") return str(p).replace("//", "/")
_call_str = """""" kernel_launch_code = """"""
desc_name_map: Dict[str, str] = {} desc_name_map: Dict[str, str] = {}
for function_name, function_info in function_informations.items(): for function_name, function_info in function_informations.items():
block_info = function_info["block_info"] block_info = function_info["block_info"]
...@@ -218,14 +218,18 @@ class TLCUDASourceWrapper(object): ...@@ -218,14 +218,18 @@ class TLCUDASourceWrapper(object):
grid_str = "dim3({}, {}, {})".format( grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
_call_str += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(function_name, grid_str, kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
block_str, smem_str, function_name, grid_str, block_str, smem_str, call_args)
call_args) kernel_launch_code += "\tcudaError_t err = cudaGetLastError();\n"
kernel_launch_code += "\tif (err != cudaSuccess) {{\n"
kernel_launch_code += f"\t\tsnprintf(error_buf, ERROR_BUF_SIZE, \"{function_name}: %s - %s\", cudaGetErrorName(err), cudaGetErrorString(err));\n"
kernel_launch_code += "\t\treturn -1;\n"
kernel_launch_code += "\t}}\n"
_call_str = self.generate_tma_descriptor_args(desc_name_map) + _call_str kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
# Wrap the kernel dispatch logic in an external C function # Wrap the kernel dispatch logic in an external C function
host_func = PREDEF_HOST_FUNC.format(def_args, _call_str) host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
return host_func return host_func
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str: def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
......
...@@ -3,12 +3,14 @@ ...@@ -3,12 +3,14 @@
from . import _ffi_api from . import _ffi_api
from .simplify import Simplify, simplify_prim_func # noqa: F401 from .simplify import Simplify, simplify_prim_func # noqa: F401
from .pass_config import PassConfigKey # noqa: F401
from tilelang import tvm as tvm # noqa: F401
from tvm.ir.transform import PassContext # noqa: F401
def get_pass_context(): def get_pass_context():
"""Get the current pass context""" """Get the current pass context"""
from tilelang import tvm as tvm return PassContext.current()
return tvm.transform.PassContext.current()
def ClusterPlanning(): def ClusterPlanning():
......
# TODO: Add more documentation for each pass config
from enum import Enum
class PassConfigKey(str, Enum):
    """String-valued keys accepted in the TileLang compiler's PassContext config.

    Subclassing ``str`` lets each member compare equal to (and be used as)
    its raw config-key string, e.g. ``ctx.config.get(PassConfigKey.TL_SIMPLIFY)``.
    """

    # --- TileLang-specific configs ---
    # Enable/disable TileLang simplification passes. Default: True
    TL_SIMPLIFY = "tl.Simplify"
    # Memory alignment requirement for dynamic shapes. Default: 16
    TL_DYNAMIC_ALIGNMENT = "tl.dynamic_alignment"
    # Disable dynamic tail splitting optimization. Default: False
    TL_DISABLE_DYNAMIC_TAIL_SPLIT = "tl.disable_dynamic_tail_split"
    # Disable warp specialization optimization. Default: False
    TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
    # Bitwidth for configuration indices. Default: 32
    TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
    # Disable TMA (Tensor Memory Access) lowering. Default: False
    TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower"

    # --- TIR-related configs ---
    # Enable equivalent terms in TIR Common Subexpression Elimination. Default: True
    TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
    # Disable TIR Common Subexpression Elimination. Default: False
    TIR_DISABLE_CSE = "tir.disable_cse_tir"
    # Enable/disable TIR simplification passes. Default: True
    TIR_SIMPLIFY = "tir.Simplify"
    # Disable storage rewrite optimization. Default: False
    TIR_DISABLE_STORAGE_REWRITE = "tir.disable_storage_rewrite"
    # Disable vectorization optimization. Default: False
    TIR_DISABLE_VECTORIZE = "tir.disable_vectorize"
    # Enable asynchronous memory copy operations. Default: True
    TIR_USE_ASYNC_COPY = "tir.use_async_copy"
    # Enable debug information in generated code. Default: False
    TIR_ENABLE_DEBUG = "tir.enable_debug"
    # Merge static shared memory allocations. Default: True
    TIR_MERGE_STATIC_SMEM = "tir.merge_static_smem"
    # Additional lowering passes to be applied. Default: None
    TIR_ADD_LOWER_PASS = "tir.add_lower_pass"
    # Enable pointer non-aliasing assumptions. Default: True
    TIR_NOALIAS = "tir.noalias"

    # --- Codegen configs ---
    # Output directory for generated CUDA kernels. Default: empty string
    CUDA_KERNELS_OUTPUT_DIR = "cuda.kernels_output_dir"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment