Commit 02a0cf59 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Pass][Simplify] Introduce symbolic-level simplification for condition expressions (#634)

* [Enhancement] Add argument simplification option to StmtSimplifier

- Introduced a new `simplify_arguments` flag in the `StmtSimplifier::Apply` method to control argument simplification behavior.
- Updated the `Simplify` function to accept the new flag, allowing for enhanced flexibility in the simplification process.
- Adjusted the `LowerAndLegalize` and `_Simplify` functions to utilize the new argument, ensuring consistent behavior across the codebase.
- Added comments to clarify the purpose of the new flag and its impact on simplification logic.

* lint fix

* [Enhancement] Improve layout inference and reduce operation handling

- Updated `ParallelOp::InferLayout` to check for pure buffer stores, enhancing layout inference logic.
- Modified `ReduceOp::Lower` to include all threads in the AllReduce operation, improving performance on specific architectures.
- Added a TODO comment in `AllReduce` to consider merging synchronization barriers for optimization.

* lint fix

* [Enhancement] Add input validation for GEMM parameters

- Introduced checks to ensure that the dimensions M and N are divisible by their respective warp sizes (kMPerWarp and kNPerWarp) in the Gemm::ComputeWarpPartition method.
- Added informative error messages to assist in debugging when the input parameters do not meet the required conditions.

* bug fix
parent a0dfa516
...@@ -154,9 +154,16 @@ def main(): ...@@ -154,9 +154,16 @@ def main():
print(f"Sparsity Ratio: {sparsity}") print(f"Sparsity Ratio: {sparsity}")
print(f"Best Kernel Latency: {best_latency:.6f} ms") print(f"Best Kernel Latency: {best_latency:.6f} ms")
else: else:
kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K, kernel = blocksparse_matmul(
DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM, M,
DEFAULT_ENABLE_RASTERIZATION) N,
K,
block_M=DEFAULT_BLOCK_M,
block_N=DEFAULT_BLOCK_N,
block_K=DEFAULT_BLOCK_K,
num_stages=DEFAULT_NUM_STAGES,
thread_num=DEFAULT_THREAD_NUM,
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
......
...@@ -65,6 +65,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target, ...@@ -65,6 +65,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
constexpr int kNPerWarp = 8; // Columns processed by a single warp constexpr int kNPerWarp = 8; // Columns processed by a single warp
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma && bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
(this->M >= 64) && (num_warps % 4 == 0); (this->M >= 64) && (num_warps % 4 == 0);
ICHECK(this->M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << this->M;
ICHECK(this->N % kNPerWarp == 0)
<< "N must be divisible by " << kNPerWarp << ", but got " << this->N;
if (allow_wgmma) { if (allow_wgmma) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads."; ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
......
...@@ -197,7 +197,21 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) { ...@@ -197,7 +197,21 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
} }
}); });
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access) { // check if loop body contains a "pure" buffer store (i.e., direct
// assignment, not compound update)
bool has_pure_buffer_store = false;
PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if the value is a direct load from another buffer (i.e., b[i]
// = a[i])
if (const auto *load = store->value.as<BufferLoadNode>()) {
has_pure_buffer_store = true;
}
}
});
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
!has_pure_buffer_store) {
auto inv = loop_layout_->Inverse(); auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd; Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++) for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
......
...@@ -225,8 +225,10 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { ...@@ -225,8 +225,10 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
bool has_arch = T.target->attrs.count("arch") > 0; bool has_arch = T.target->attrs.count("arch") > 0;
if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") { if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ">::run_hopper"; << reducing_threads << ", " << (*scale) << ", " << all_threads
<< ">::run_hopper";
} else { } else {
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", " ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ">::run"; << reducing_threads << ", " << (*scale) << ">::run";
......
...@@ -51,6 +51,7 @@ struct AllReduce { ...@@ -51,6 +51,7 @@ struct AllReduce {
if constexpr (offset >= 32) { if constexpr (offset >= 32) {
asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads)); asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads));
red_buf[threadIdx.x] = x; red_buf[threadIdx.x] = x;
// TODO(lei): maybe we can merge the two bar.sync into one?
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads)); asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[threadIdx.x ^ offset]); x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
} else { } else {
......
...@@ -207,7 +207,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig); ...@@ -207,7 +207,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer { class StmtSimplifier : public IRMutatorWithAnalyzer {
public: public:
static PrimFunc Apply(PrimFunc func, Analyzer *analyzer, static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
Optional<SimplifyConfig> config_opt = NullOpt) { Optional<SimplifyConfig> config_opt = NullOpt,
bool simplify_arguments = false) {
auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>()); auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions( analyzer->rewrite_simplify.SetEnabledExtensions(
config->GetEnabledExtensions()); config->GetEnabledExtensions());
...@@ -243,8 +244,8 @@ public: ...@@ -243,8 +244,8 @@ public:
} }
} }
} }
// return func;
if (param_updated) { if (simplify_arguments && param_updated) {
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type, return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
new_buffer_map, func->attrs, func->span); new_buffer_map, func->attrs, func->span);
} else { } else {
...@@ -437,6 +438,12 @@ private: ...@@ -437,6 +438,12 @@ private:
if (const int64_t *as_int = as_const_int(condition)) { if (const int64_t *as_int = as_const_int(condition)) {
return Bool(*as_int); return Bool(*as_int);
} else { } else {
// May have symbolic, need kSymbolicBound level prover.
if (analyzer_->CanProve(condition) ||
analyzer_->CanProve(condition,
arith::ProofStrength::kSymbolicBound)) {
return Bool(true);
}
return NullOpt; return NullOpt;
} }
} }
...@@ -453,11 +460,11 @@ private: ...@@ -453,11 +460,11 @@ private:
using namespace tir::transform; using namespace tir::transform;
tvm::transform::Pass Simplify() { tvm::transform::Pass Simplify(bool simplify_arguments = true) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify"); auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
return StmtSimplifier::Apply(f, &analyzer, cfg); return StmtSimplifier::Apply(f, &analyzer, cfg, simplify_arguments);
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {}); return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
} }
......
...@@ -89,7 +89,10 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -89,7 +89,10 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AlignDynamicSharedMemoryAllocations(16)(mod) mod = tilelang.transform.AlignDynamicSharedMemoryAllocations(16)(mod)
# Simplify again to clean up any duplicated conditions # Simplify again to clean up any duplicated conditions
# that may have been introduced by safety checks # that may have been introduced by safety checks
mod = tir.transform.Simplify()(mod) # use an enhanced pass to simplify the dynamic symbolics
# TODO(lei): return to tir pass when kSymbolicBound simplification
# is merged into tvm.
mod = tilelang.transform.Simplify()(mod)
# Try to vectorize loop with dynamic shape # Try to vectorize loop with dynamic shape
mod = tilelang.transform.LoopVectorizeDynamic()(mod) mod = tilelang.transform.LoopVectorizeDynamic()(mod)
return mod return mod
......
...@@ -5,7 +5,7 @@ from typing import Union, Callable ...@@ -5,7 +5,7 @@ from typing import Union, Callable
from . import _ffi_api from . import _ffi_api
def Simplify(): def Simplify(simplify_arguments: bool = False):
"""Simplify """Simplify
Returns Returns
...@@ -13,16 +13,16 @@ def Simplify(): ...@@ -13,16 +13,16 @@ def Simplify():
fpass : tvm.transform.Pass fpass : tvm.transform.Pass
The result pass The result pass
""" """
return _ffi_api.Simplify() # type: ignore return _ffi_api.Simplify(simplify_arguments) # type: ignore
def _Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]: def _Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
if isinstance(stmt, PrimFunc): if isinstance(stmt, PrimFunc):
mod = Simplify()(IRModule.from_expr(stmt)) mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt))
assert len(mod.functions) == 1, "Simplify should return a single function" assert len(mod.functions) == 1, "Simplify should return a single function"
return list(mod.functions.values()).pop() return list(mod.functions.values()).pop()
elif isinstance(stmt, IRModule): elif isinstance(stmt, IRModule):
return Simplify()(stmt) return Simplify(simplify_arguments=True)(stmt)
else: else:
raise ValueError(f"Unsupported type: {type(stmt)}") raise ValueError(f"Unsupported type: {type(stmt)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment