Commit 2ea45ae9 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix layout inference for free fragment buffer (#443)

* [Enhancement] Improve layout inference accuracy in ParallelOp (#441)

* Added logic to use non-replicated buffers as source buffers for more accurate layout inference.
* Enhanced comments to clarify the rationale behind buffer selection in the layout inference process.

* [Enhancement] Add error handling macros and refactor loop partitioning logic

* Introduced the TILELANG_CHECK macro for improved error handling in CUDA and HIP code, providing detailed error messages for failed kernel launches.
* Enhanced loop partitioning logic to handle fragment buffers more effectively, ensuring correct replication based on the thread extent.
* Added logging for thread range in PlanLoopPartition to aid in debugging and performance analysis.
* Updated pass configuration management to streamline vectorization control in the optimization process.

* lint fix

* remove debug print
parent 734c7fbe
@@ -25,6 +25,16 @@ using int4_t = int4;
#define TL_DEVICE_NOINLINE __noinline__ __device__
#define TL_PATCH
#define TILELANG_CHECK(stmt)                                                   \
  do {                                                                         \
    cudaError_t __err = (stmt);                                                \
    if (__err != cudaSuccess) {                                                \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,          \
               __LINE__, cudaGetErrorName(__err), cudaGetErrorString(__err));  \
      return -1;                                                               \
    }                                                                          \
  } while (0)
// abs function for bfloat_t and half_t since there is no implicit
// conversion method
TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
......
@@ -24,6 +24,17 @@
#define ushort unsigned short
#define TL_DEVICE __forceinline__ __device__
#define TL_DEVICE_NOINLINE __noinline__ __device__
#define TILELANG_CHECK(stmt)                                                   \
  do {                                                                         \
    hipError_t __err = (stmt);                                                 \
    if (__err != hipSuccess) {                                                 \
      snprintf(error_buf, ERROR_BUF_SIZE, "%s:%d: %s - %s", __FILE__,          \
               __LINE__, hipGetErrorName(__err), hipGetErrorString(__err));    \
      return -1;                                                               \
    }                                                                          \
  } while (0)
#define half _Float16
#define __float2half_rn(x) half(x)
......
@@ -141,10 +141,31 @@ public:
    PrimExpr thd = FloorMod(access_idx, num_thread);
    PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size +
                   FloorMod(flattened, vectorize_size);
    return Fragment(loop_vars_, {idx}, {thd}, {});
    auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
    if (has_fragment_) {
      // for a fragment buffer, replicate the loop layout only by the factor
      // still needed to cover the full thread extent
      auto thread_extent = *as_const_int(fragment->ThreadExtent());
      auto num_thread_fragment = num_thread / thread_extent;
      fragment = fragment->Replicate(num_thread_fragment);
    }
    return fragment;
  }

private:
  void VisitExpr_(const BufferLoadNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitStmt_(const BufferStoreNode *op) final {
    if (op->buffer.scope() == "local.fragment") {
      has_fragment_ = true;
    }
    StmtExprVisitor::VisitStmt_(op);
  }

  void VisitStmt_(const ForNode *node) final {
    if (node->kind == ForKind::kParallel) {
      body_ = node->body;
@@ -157,6 +178,7 @@ private:
  Stmt body_;
  PrimExpr flattened = 0;
  bool has_fragment_ = false;
  Array<IterVar> loop_vars_;
};
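To make the replication arithmetic above concrete, a small Python sketch with illustrative numbers (none of them come from the commit):

# Suppose the block has 128 threads and the inferred loop layout for the
# free fragment buffer already binds a thread extent of 32.
num_thread = 128
thread_extent = 32

# Replicate only by the remaining factor, so the final layout spans all
# 128 threads instead of being replicated across every thread.
num_thread_fragment = num_thread // thread_extent
assert num_thread_fragment == 4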
......
@@ -96,6 +96,7 @@ from . import (
    language,  # noqa: F401
    engine,  # noqa: F401
)

from .transform import PassConfigKey  # noqa: F401
from .engine import lower, register_cuda_postproc, register_hip_postproc  # noqa: F401
......
from tvm import tir, IRModule
from tvm.target import Target
import tilelang
from tilelang.transform import PassContext
from typing import Optional
def allow_tma_and_warp_specialized(target: Target) -> bool:
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
                                   target: Optional[Target] = None) -> bool:
    if pass_ctx is None:
        pass_ctx = tilelang.transform.get_pass_context()
    if target.arch not in {"sm_90"}:
        return False
    cur_pass_ctx = tilelang.transform.get_pass_context()
    disable_tma_lower = cur_pass_ctx.config.get("tl.disable_tma_lower", False)
    disable_warp_specialized = cur_pass_ctx.config.get("tl.disable_warp_specialized", False)
    disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
    disable_warp_specialized = pass_ctx.config.get("tl.disable_warp_specialized", False)
    return not (disable_tma_lower and disable_warp_specialized)


def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
    if pass_ctx is None:
        pass_ctx = tilelang.transform.get_pass_context()
    disable_vectorize = pass_ctx.config.get("tir.disable_vectorize", False)
    return not disable_vectorize
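As an illustration, a minimal sketch of steering allow_vectorize through the pass context (assumes this module's allow_vectorize is in scope; tir.disable_vectorize is a standard TVM config option):

from tvm.ir.transform import PassContext
import tilelang

# Inside a context that disables vectorization, allow_vectorize() returns
# False, so VectorizeLoop below is invoked with enable_vectorize=False.
with PassContext(config={"tir.disable_vectorize": True}):
    ctx = tilelang.transform.get_pass_context()
    assert not allow_vectorize(pass_ctx=ctx)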
def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
# Bind the target device information to the module
mod = tir.transform.BindTarget(target)(mod)
@@ -38,8 +50,9 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:


def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    pass_ctx = tilelang.transform.get_pass_context()
    # which may be introduced by the LegalizeSafeMemoryAccess
    if allow_tma_and_warp_specialized(target):
    if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
        mod = tilelang.transform.IfStmtBinding()(mod)
        mod = tilelang.transform.MultiVersionBuffer()(mod)
        mod = tilelang.transform.WarpSpecialized()(mod)
@@ -57,11 +70,14 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
    mod = tilelang.transform.PipelinePlanning()(mod)
    mod = tilelang.transform.InjectSoftwarePipeline()(mod)
    mod = tilelang.transform.MergeIfStmt()(mod)
    mod = tir.transform.LowerOpaqueBlock()(mod)
    mod = tilelang.transform.FlattenBuffer()(mod)
    mod = tir.transform.NarrowDataType(32)(mod)
    mod = tir.transform.Simplify()(mod)
    mod = tilelang.transform.VectorizeLoop()(mod)
    mod = tilelang.transform.VectorizeLoop(enable_vectorize=allow_vectorize(pass_ctx=pass_ctx))(mod)
    mod = tir.transform.StorageRewrite()(mod)
    mod = tir.transform.UnrollLoop()(mod)
    mod = tir.transform.RenormalizeSplitPattern()(mod)
......
@@ -42,7 +42,7 @@ extern "C" int init() {{
PREDEF_HOST_FUNC = """
extern "C" int call({}) {{
{}
return 0;
\treturn 0;
}}
"""
@@ -193,7 +193,7 @@ class TLCUDASourceWrapper(object):
            p = int(p)
            return str(p).replace("//", "/")

        _call_str = """"""
        kernel_launch_code = """"""
        desc_name_map: Dict[str, str] = {}
        for function_name, function_info in function_informations.items():
            block_info = function_info["block_info"]
@@ -218,14 +218,18 @@ class TLCUDASourceWrapper(object):
            grid_str = "dim3({}, {}, {})".format(
                legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
            smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
            _call_str += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(function_name, grid_str,
                                                                      block_str, smem_str,
                                                                      call_args)
            kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
                function_name, grid_str, block_str, smem_str, call_args)
            kernel_launch_code += "\tcudaError_t err = cudaGetLastError();\n"
            kernel_launch_code += "\tif (err != cudaSuccess) {{\n"
            kernel_launch_code += f"\t\tsnprintf(error_buf, ERROR_BUF_SIZE, \"{function_name}: %s - %s\", cudaGetErrorName(err), cudaGetErrorString(err));\n"
            kernel_launch_code += "\t\treturn -1;\n"
            kernel_launch_code += "\t}}\n"

        _call_str = self.generate_tma_descriptor_args(desc_name_map) + _call_str
        kernel_launch_code = self.generate_tma_descriptor_args(desc_name_map) + kernel_launch_code
        # Wrap the kernel dispatch logic in an external C function
        host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
        host_func = PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
        return host_func
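For orientation, the host wrapper these strings assemble looks roughly like the following (a sketch; the kernel name, argument list, and launch dimensions are hypothetical and are actually derived from function_informations):

# Approximate expansion of PREDEF_HOST_FUNC.format(def_args, kernel_launch_code)
# for a single hypothetical kernel named "main_kernel":
expected_host_func = """
extern "C" int call(float* A, float* B, cudaStream_t stream) {
\tmain_kernel<<<dim3(1, 1, 1), dim3(128, 1, 1), 0, stream>>>(A, B);
\tcudaError_t err = cudaGetLastError();
\tif (err != cudaSuccess) {
\t\tsnprintf(error_buf, ERROR_BUF_SIZE, "main_kernel: %s - %s",
\t\t         cudaGetErrorName(err), cudaGetErrorString(err));
\t\treturn -1;
\t}
\treturn 0;
}
"""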
def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
......
@@ -3,12 +3,14 @@

from . import _ffi_api
from .simplify import Simplify, simplify_prim_func  # noqa: F401
from .pass_config import PassConfigKey  # noqa: F401
from tilelang import tvm as tvm  # noqa: F401
from tvm.ir.transform import PassContext  # noqa: F401


def get_pass_context():
    """Get the current pass context"""
    from tilelang import tvm as tvm
    return tvm.transform.PassContext.current()
    return PassContext.current()
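For completeness, a tiny usage sketch (PassContext.current() falls back to TVM's default context outside any with-block):

from tilelang.transform import get_pass_context

# Outside an explicit `with PassContext(...)` block this returns the
# default context, whose config carries no user-set options.
ctx = get_pass_context()
print(ctx.config)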
def ClusterPlanning():
......
# TODO: Add more documentation for each pass config
from enum import Enum


class PassConfigKey(str, Enum):
    """Pass configuration keys for the TileLang compiler."""

    # TileLang specific configs
    TL_SIMPLIFY = "tl.Simplify"
    """Enable/disable TileLang simplification passes. Default: True"""

    TL_DYNAMIC_ALIGNMENT = "tl.dynamic_alignment"
    """Memory alignment requirement for dynamic shapes. Default: 16"""

    TL_DISABLE_DYNAMIC_TAIL_SPLIT = "tl.disable_dynamic_tail_split"
    """Disable dynamic tail splitting optimization. Default: False"""

    TL_DISABLE_WARP_SPECIALIZED = "tl.disable_warp_specialized"
    """Disable warp specialization optimization. Default: False"""

    TL_CONFIG_INDEX_BITWIDTH = "tl.config_index_bitwidth"
    """Bitwidth for configuration indices. Default: 32"""

    TL_DISABLE_TMA_LOWER = "tl.disable_tma_lower"
    """Disable TMA (Tensor Memory Accelerator) lowering. Default: False"""

    # TIR related configs
    TIR_ENABLE_EQUIV_TERMS_IN_CSE = "tir.enable_equiv_terms_in_cse_tir"
    """Enable equivalent terms in TIR Common Subexpression Elimination. Default: True"""

    TIR_DISABLE_CSE = "tir.disable_cse_tir"
    """Disable TIR Common Subexpression Elimination. Default: False"""

    TIR_SIMPLIFY = "tir.Simplify"
    """Enable/disable TIR simplification passes. Default: True"""

    TIR_DISABLE_STORAGE_REWRITE = "tir.disable_storage_rewrite"
    """Disable storage rewrite optimization. Default: False"""

    TIR_DISABLE_VECTORIZE = "tir.disable_vectorize"
    """Disable vectorization optimization. Default: False"""

    TIR_USE_ASYNC_COPY = "tir.use_async_copy"
    """Enable asynchronous memory copy operations. Default: True"""

    TIR_ENABLE_DEBUG = "tir.enable_debug"
    """Enable debug information in generated code. Default: False"""

    TIR_MERGE_STATIC_SMEM = "tir.merge_static_smem"
    """Merge static shared memory allocations. Default: True"""

    TIR_ADD_LOWER_PASS = "tir.add_lower_pass"
    """Additional lowering passes to be applied. Default: None"""

    TIR_NOALIAS = "tir.noalias"
    """Enable pointer non-aliasing assumptions. Default: True"""

    CUDA_KERNELS_OUTPUT_DIR = "cuda.kernels_output_dir"
    """Output directory for generated CUDA kernels. Default: empty string"""