Commit 02a0cf59 authored by Lei Wang, committed by LeiWang1999
Browse files

[Pass][Simplify] Introduce symbolic level simplify for condition expression (#634)

* [Enhancement] Add argument simplification option to StmtSimplifier

- Introduced a new `simplify_arguments` flag in the `StmtSimplifier::Apply` method to control argument simplification behavior.
- Updated the `Simplify` function to accept the new flag, allowing for enhanced flexibility in the simplification process.
- Adjusted the `LowerAndLegalize` and `_Simplify` functions to utilize the new argument, ensuring consistent behavior across the codebase.
- Added comments to clarify the purpose of the new flag and its impact on simplification logic.

* lint fix

* [Enhancement] Improve layout inference and reduce operation handling

- Updated `ParallelOp::InferLayout` to check for pure buffer stores, enhancing layout inference logic.
- Modified `ReduceOp::Lower` to include all threads in the AllReduce operation, improving performance on specific architectures.
- Added a TODO comment in `AllReduce` to consider merging synchronization barriers for optimization.

* lint fix

* [Enhancement] Add input validation for GEMM parameters

- Introduced checks to ensure that the dimensions M and N are divisible by their respective warp sizes (kMPerWarp and kNPerWarp) in the Gemm::ComputeWarpPartition method.
- Added informative error messages to assist in debugging when the input parameters do not meet the required conditions.

* bug fix
parent a0dfa516
......@@ -154,9 +154,16 @@ def main():
print(f"Sparsity Ratio: {sparsity}")
print(f"Best Kernel Latency: {best_latency:.6f} ms")
else:
kernel = blocksparse_matmul(M, N, K, DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K,
DEFAULT_NUM_STAGES, DEFAULT_THREAD_NUM,
DEFAULT_ENABLE_RASTERIZATION)
kernel = blocksparse_matmul(
M,
N,
K,
block_M=DEFAULT_BLOCK_M,
block_N=DEFAULT_BLOCK_N,
block_K=DEFAULT_BLOCK_K,
num_stages=DEFAULT_NUM_STAGES,
thread_num=DEFAULT_THREAD_NUM,
enable_rasteration=DEFAULT_ENABLE_RASTERIZATION)
block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K
print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})")
......
......@@ -65,6 +65,10 @@ std::pair<int, int> Gemm::ComputeWarpPartition(int num_warps, Target target,
constexpr int kNPerWarp = 8; // Columns processed by a single warp
bool allow_wgmma = TargetIsHopper(target) && maybe_hopper_wgmma &&
(this->M >= 64) && (num_warps % 4 == 0);
ICHECK(this->M % kMPerWarp == 0)
<< "M must be divisible by " << kMPerWarp << ", but got " << this->M;
ICHECK(this->N % kNPerWarp == 0)
<< "N must be divisible by " << kNPerWarp << ", but got " << this->N;
if (allow_wgmma) {
ICHECK(num_warps % 4 == 0) << "Warp-Group MMA requires 128×k threads.";
......
......@@ -197,7 +197,21 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
}
});
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access) {
// check if loop body contains a "pure" buffer store (i.e., direct
// assignment, not compound update)
bool has_pure_buffer_store = false;
PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if the value is a direct load from another buffer (i.e., b[i]
// = a[i])
if (const auto *load = store->value.as<BufferLoadNode>()) {
has_pure_buffer_store = true;
}
}
});
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
!has_pure_buffer_store) {
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
......
......@@ -225,8 +225,10 @@ Stmt ReduceOp::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
bool has_arch = T.target->attrs.count("arch") > 0;
if (has_arch && Downcast<String>(T.target->attrs["arch"]) == "sm_90") {
auto all_threads = T.thread_bounds->extent;
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ">::run_hopper";
<< reducing_threads << ", " << (*scale) << ", " << all_threads
<< ">::run_hopper";
} else {
ss << "tl::AllReduce<" << this->MakeCodegenReducer() << ", "
<< reducing_threads << ", " << (*scale) << ">::run";
......
......@@ -51,6 +51,7 @@ struct AllReduce {
if constexpr (offset >= 32) {
asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads));
red_buf[threadIdx.x] = x;
// TODO(lei): maybe we can merge the two bar.sync into one?
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
x = Reducer()(x, red_buf[threadIdx.x ^ offset]);
} else {
......
......@@ -207,7 +207,8 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
public:
static PrimFunc Apply(PrimFunc func, Analyzer *analyzer,
Optional<SimplifyConfig> config_opt = NullOpt) {
Optional<SimplifyConfig> config_opt = NullOpt,
bool simplify_arguments = false) {
auto config = config_opt.value_or(AttrsWithDefaultValues<SimplifyConfig>());
analyzer->rewrite_simplify.SetEnabledExtensions(
config->GetEnabledExtensions());
......@@ -243,8 +244,8 @@ public:
}
}
}
// return func;
if (param_updated) {
if (simplify_arguments && param_updated) {
return PrimFunc(new_params, func.CopyOnWrite()->body, func->ret_type,
new_buffer_map, func->attrs, func->span);
} else {
......@@ -437,6 +438,12 @@ private:
if (const int64_t *as_int = as_const_int(condition)) {
return Bool(*as_int);
} else {
// May have symbolic, need kSymbolicBound level prover.
if (analyzer_->CanProve(condition) ||
analyzer_->CanProve(condition,
arith::ProofStrength::kSymbolicBound)) {
return Bool(true);
}
return NullOpt;
}
}
......@@ -453,11 +460,11 @@ private:
using namespace tir::transform;
// Create the `tl.Simplify` PrimFunc pass.
//
// \param simplify_arguments When true, StmtSimplifier::Apply may also rewrite
//        the PrimFunc's parameter list / buffer map as part of simplification
//        (see the `param_updated` branch in StmtSimplifier); when false, only
//        the body is simplified. Defaults to true for backward compatibility
//        with the previous zero-argument interface.
// \return A PrimFunc-level pass registered under the name "tl.Simplify".
tvm::transform::Pass Simplify(bool simplify_arguments = true) {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    arith::Analyzer analyzer;
    // Pick up any user-supplied configuration for this pass (may be NullOpt,
    // in which case StmtSimplifier falls back to default-valued attrs).
    auto cfg = ctx->GetConfig<SimplifyConfig>("tl.Simplify");
    return StmtSimplifier::Apply(f, &analyzer, cfg, simplify_arguments);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}
......
......@@ -89,7 +89,10 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.AlignDynamicSharedMemoryAllocations(16)(mod)
# Simplify again to clean up any duplicated conditions
# that may have been introduced by safety checks
mod = tir.transform.Simplify()(mod)
# use an enhanced pass to simplify the dynamic symbolics
# TODO(lei): return to tir pass when kSymbolicBound simplification
# is merged into tvm.
mod = tilelang.transform.Simplify()(mod)
# Try to vectorize loop with dynamic shape
mod = tilelang.transform.LoopVectorizeDynamic()(mod)
return mod
......
......@@ -5,7 +5,7 @@ from typing import Union, Callable
from . import _ffi_api
def Simplify(simplify_arguments: bool = False):
    """Build the `tl.Simplify` transform pass.

    Parameters
    ----------
    simplify_arguments : bool
        Forwarded to the underlying FFI pass; when True the pass may also
        simplify the function's argument list, not just its body.
        Defaults to False so existing callers keep the old behavior.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.Simplify(simplify_arguments)  # type: ignore
def _Simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
    """Eagerly apply the Simplify pass (with argument simplification enabled)
    to a single PrimFunc or to every function in an IRModule.

    Parameters
    ----------
    stmt : Union[PrimFunc, IRModule]
        The function or module to simplify.

    Returns
    -------
    Union[PrimFunc, IRModule]
        The simplified PrimFunc (when given a PrimFunc) or IRModule.

    Raises
    ------
    ValueError
        If `stmt` is neither a PrimFunc nor an IRModule.
    """
    if isinstance(stmt, PrimFunc):
        # Wrap the lone function in a module, simplify, then unwrap it.
        mod = Simplify(simplify_arguments=True)(IRModule.from_expr(stmt))
        assert len(mod.functions) == 1, "Simplify should return a single function"
        return list(mod.functions.values()).pop()
    elif isinstance(stmt, IRModule):
        return Simplify(simplify_arguments=True)(stmt)
    else:
        raise ValueError(f"Unsupported type: {type(stmt)}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment