Unverified commit 5c11d245 authored by Lei Wang, committed by GitHub

[Refactor] Merge bulk copy into copy and improve layout inference for bulk copy (#746)

* [Refactor] Merge bulk copy into copy and refactor layout inference for bulk copy

* Deleted the `bulk_copy` operator implementation and its header file as it is no longer needed.
* Introduced a new function `cuTensorMapType()` to return the data type for CUDA tensor mapping.
* Updated related files to reflect these changes, ensuring that the codebase remains clean and maintainable.

* lint fix

* Fix typos in intrinsic names and remove unused print statement in block_sparse_attn_tilelang.py. Updated references from `ptx_ldmatirx` to `ptx_ldmatrix` across multiple files for consistency.

* remove bulk copy

* Refactor copy and atomic add operations to support TMA lower configuration

- Updated `GetCopyInst` to accept a `disable_tma_lower` parameter, allowing for conditional usage of TMA in bulk load/store operations.
- Modified `Lower` method in `Copy` to incorporate the new TMA configuration.
- Refactored `AtomicAdd::Lower` to streamline layout inference and vectorization logic.
- Removed unused `disable_tma_lower` field from `LowerArgs` structure for clarity.
- Enhanced atomic add vectorization by replacing the buggy implementation with a more robust loop vectorization approach.
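The TMA gating described above can be sketched in plain Python (an illustrative mirror of the selection logic, not the actual C++ `GetCopyInst` signature; the names and return values here are hypothetical):

```python
def get_copy_inst(arch: int, is_bulk_candidate: bool, disable_tma_lower: bool) -> str:
    """Hypothetical sketch of how a copy instruction could be selected
    once `disable_tma_lower` is threaded through as a parameter."""
    # TMA bulk copies require Hopper (sm_90+) and can be switched off
    # via the `tl.disable_tma_lower` pass config.
    if is_bulk_candidate and arch >= 90 and not disable_tma_lower:
        return "bulk"  # lowered to a TMA bulk load/store
    return "simt"      # plain per-thread (SIMT) copy fallback
```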

* Enhance TMA bulk copy logic in `LowerBulkCopy` method

- Added a condition to set `desc.swizzle` to `CU_TENSOR_MAP_SWIZZLE_NONE` when `shared_layout` matches `linear_layout`, improving clarity in layout handling.
- Updated warning log to provide more detailed information about fallback scenarios, including source and destination buffer names and shapes, enhancing debugging capabilities.
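The swizzle-selection branch above can be sketched as follows (the enum values are the real CUDA `CUtensorMapSwizzle` constants; the layout comparison and the fallback choice are simplified placeholders):

```python
# CUDA CUtensorMapSwizzle enum values (from cuda.h)
CU_TENSOR_MAP_SWIZZLE_NONE = 0
CU_TENSOR_MAP_SWIZZLE_32B = 1
CU_TENSOR_MAP_SWIZZLE_64B = 2
CU_TENSOR_MAP_SWIZZLE_128B = 3

def select_swizzle(shared_layout, linear_layout,
                   swizzled_choice=CU_TENSOR_MAP_SWIZZLE_128B):
    # When the shared-memory layout is identical to the plain linear
    # layout, no swizzling is needed in the TMA descriptor.
    if shared_layout == linear_layout:
        return CU_TENSOR_MAP_SWIZZLE_NONE
    return swizzled_choice  # placeholder for the swizzled cases
```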

* lint fix

* Remove fallback logging for non-swizzled global layout in `LowerBulkCopy` method to streamline the bulk copy logic. This change enhances code clarity by eliminating unnecessary warning messages related to inner box dimensions.

* Enhance reshape kernel compilation in `run_reshape` and `run_reshape_smem_1d_2_2d` functions

- Updated the `tl.compile` method to include `pass_configs` that disable TMA lower and warp specialization, addressing shared memory layout transformation limitations.
- Added TODO comments to indicate the need for further improvements in shared memory handling.

* Update `native_sparse_attention` function to include TMA configuration options

- Added `pass_configs` to the JIT decorator to disable TMA lower and warp specialization, addressing potential issues with shared memory layout transformations.
- Updated comments to clarify modifications in tensor shapes for inference, specifically setting `q` sequence length to 1.

* Refactor JIT decorator formatting in `native_sparse_attention` function

- Improved readability by reformatting the JIT decorator parameters for `native_sparse_attention`, ensuring consistent style across the codebase.
- No functional changes were made; this update focuses on code clarity and maintainability.

* Enhance thread management and logging in TileLang compilation

- Added a method to check if printing is enabled during compilation, improving control over logging behavior.
- Updated the JIT kernel class to utilize the new method for logging compilation status, ensuring consistent and clear output.
- Added comments to clarify the purpose of changes and improve code readability.

* Add warp specialization scope and refactor register management in TileLang

- Introduced a new constant `kWarpSpecializationScope` in `builtin.h` for better attribute management.
- Removed the `SetMaxNRegCollector` class and its related logic from `warp_specialized_rewriter.cc`, streamlining the warp specialization process.
- Added functions `annotate_producer_reg_dealloc` and `annotate_consumer_reg_alloc` in `builtin.py` to facilitate register management.
- Implemented `AnnotateWarpGroupRegAlloc` in `__init__.py` to inject register allocation calls into warp-specialized functions, enhancing the overall register handling in the compilation process.
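The collector half of this pass can be mimicked in a few lines of plain Python (a behavioral sketch of the semantics described above, not the C++ implementation):

```python
def collect_reg_hints(calls):
    """Fold `set_max_nreg`/`no_set_max_nreg` calls into a [dealloc, alloc]
    hint pair; [-1, -1] disables injection entirely.
    `calls` is a sequence of ("set_max_nreg", reg_hint, is_inc) or
    ("no_set_max_nreg",) tuples standing in for IR intrinsic calls."""
    nreg = [0, 0]  # index 0: producer dealloc hint, index 1: consumer alloc hint
    disabled = False
    for call in calls:
        if call[0] == "set_max_nreg":
            _, reg_hint, is_inc = call
            assert 24 <= reg_hint <= 240, f"Invalid reg hint: {reg_hint}"
            assert is_inc in (0, 1), f"Invalid is_inc: {is_inc}"
            # producers decrease the register budget, consumers increase it
            nreg[is_inc] = reg_hint
        elif call[0] == "no_set_max_nreg":
            disabled = True
    return [-1, -1] if disabled else nreg
```

This mirrors `SetMaxNRegCollector::Collect` in the new pass: the injector then emits a `set_max_nreg` call at the head of the producer and consumer branches only when both hints are non-negative.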

* Refactor test for InjectSetMaxNReg pass in TileLang

- Improved readability by restructuring conditional checks and assertions in the test cases.
- Enhanced clarity in the collection of `set_max_nreg` calls by simplifying the logic.
- Ensured consistent formatting and spacing throughout the test functions for better maintainability.

* Enhance bulk copy and store checks in `Copy` class

- Updated scope validation for source and destination tensors in `CheckBulkLoad` and `CheckBulkStore` methods to include both `shared.dyn` and `shared` as valid options.
- Modified `CheckLDSMCopy` and `CheckSTSMCopy` methods to accommodate the new scope validation, ensuring compatibility with shared memory configurations.
- Improved logging in `LowerBulkCopy` to provide clearer warnings regarding unsupported swizzle layouts, including source and destination names for better debugging.
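The widened scope check amounts to accepting both shared-memory scopes (a simplified stand-in for the C++ predicates, with hypothetical function names):

```python
VALID_SHARED_SCOPES = ("shared.dyn", "shared")

def check_bulk_load(src_scope: str, dst_scope: str) -> bool:
    # Bulk (TMA) loads copy global -> shared; previously only
    # "shared.dyn" was accepted on the destination side.
    return src_scope == "global" and dst_scope in VALID_SHARED_SCOPES

def check_bulk_store(src_scope: str, dst_scope: str) -> bool:
    # Bulk (TMA) stores copy shared -> global.
    return src_scope in VALID_SHARED_SCOPES and dst_scope == "global"
```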

* lint fix
parent cb37bfef
...@@ -104,6 +104,13 @@ bool TargetHasStmatrix(Target target) {
   return arch >= 90;
 }

+bool TargetHasBulkCopy(Target target) {
+  if (!TargetIsCuda(target))
+    return false;
+  int arch = GetArchInt(target);
+  return arch >= 90;
+}

 int TargetGetWarpSize(Target target) {
   int res = 32;
   if (TargetIsCDNA(target))
......
...@@ -25,6 +25,7 @@ bool TargetIsCDNA(Target target);
 bool TargetHasAsyncCopy(Target target);
 bool TargetHasLdmatrix(Target target);
 bool TargetHasStmatrix(Target target);
+bool TargetHasBulkCopy(Target target);
 int TargetGetWarpSize(Target target);

 } // namespace tl
......
/*!
* \file annotate_warp_group_reg_alloc.cc
* \brief Annotate warp group reg alloc for warp specialization
*/
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_set>
#include <vector>
#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
class SetMaxNRegCollector : public StmtExprVisitor {
public:
static Array<IntImm> Collect(const PrimFunc &f) {
SetMaxNRegCollector collector;
collector(f->body);
return collector.has_no_set_max_nreg_
? Array<IntImm>({IntImm(DataType::Int(32), -1),
IntImm(DataType::Int(32), -1)})
: collector.nreg_;
}
private:
void VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(set_max_nreg())) {
int reg_hint = call->args[0].as<IntImmNode>()->value;
int is_inc = call->args[1].as<IntImmNode>()->value;
ICHECK(reg_hint <= 240 && reg_hint >= 24)
<< "Invalid reg hint: " << reg_hint;
ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;
// producer should decrease register hint while consumer should increase
// register hint
nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
} else if (call->op.same_as(no_set_max_nreg())) {
has_no_set_max_nreg_ = true;
}
}
StmtExprVisitor::VisitStmt_(op);
}
Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), 0)};
bool has_no_set_max_nreg_ = false;
};
class SetMaxNRegInjector : public StmtExprMutator {
public:
static PrimFunc Inject(PrimFunc f) {
auto T = SetMaxNRegInjector();
T.nreg_ = SetMaxNRegCollector::Collect(f);
f.CopyOnWrite()->body = T(f->body);
return f;
}
private:
Stmt VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(set_max_nreg()) ||
call->op.same_as(no_set_max_nreg())) {
// Remove the original set_max_nreg calls as they will be re-inserted
// at appropriate locations
return Evaluate(0);
}
}
return StmtExprMutator::VisitStmt_(op);
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent &&
Downcast<IterVar>(op->node)->thread_tag == "threadIdx.x") {
thread_iv_ = Downcast<IterVar>(op->node);
need_update_thread_extent_ = false;
AttrStmt attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(op));
if (need_update_thread_extent_) {
thread_iv_.CopyOnWrite()->dom = {0, updated_thread_extent_.value()};
attr_stmt.CopyOnWrite()->node = thread_iv_;
attr_stmt.CopyOnWrite()->value = updated_thread_extent_.value();
}
thread_iv_ = {};
return attr_stmt;
} else if (op->attr_key == attr::kWarpSpecializationScope) {
auto if_then_else = Downcast<IfThenElse>(op->body);
if (!if_then_else.defined()) {
return StmtExprMutator::VisitStmt_(op);
}
auto producer_body = if_then_else->then_case;
Optional<Stmt> consumer_body = if_then_else->else_case;
ICHECK(consumer_body.defined()) << "Consumer body is undefined";
int dec_reg = nreg_[0].as<IntImmNode>()->value;
int inc_reg = nreg_[1].as<IntImmNode>()->value;
auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0);
// Only inject if we have valid register hints and no SIMT copy
// For now, we assume no SIMT copy detection is available here
// TODO: Add SIMT copy detection if needed
bool has_simt_copy = false; // Placeholder
if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{inc_reg == 0 ? 240 : inc_reg, 1}));
dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
{dec_reg == 0 ? 24 : dec_reg, 0}));
}
// Inject register setting statements
Array<Stmt> producer_stmts;
producer_stmts.push_back(dec_reg_stmt);
producer_stmts.push_back(producer_body);
auto new_producer_body = SeqStmt(producer_stmts);
Array<Stmt> consumer_stmts;
consumer_stmts.push_back(inc_reg_stmt);
consumer_stmts.push_back(consumer_body.value());
auto new_consumer_body = SeqStmt(consumer_stmts);
auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_consumer_body);
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
return new_attr;
} else {
return StmtExprMutator::VisitStmt_(op);
}
}
Array<IntImm> nreg_;
IterVar thread_iv_;
Optional<PrimExpr> updated_thread_extent_;
bool need_update_thread_extent_ = false;
};
using namespace tir::transform;
tvm::transform::Pass AnnotateWarpGroupRegAlloc() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) -> PrimFunc {
return SetMaxNRegInjector::Inject(f);
};
return CreatePrimFuncPass(pass_func, 0, "tl.AnnotateWarpGroupRegAlloc", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.AnnotateWarpGroupRegAlloc",
AnnotateWarpGroupRegAlloc);
});
} // namespace tl
} // namespace tvm
...@@ -203,12 +203,9 @@ private:
       Stmt body = Substitute(fnode->body, vmap);
       return For(outer_var, 0, extent / vector_size_, fnode->kind, body,
                  fnode->thread_binding, fnode->annotations, fnode->span);
-    } else {
-      return fnode;
     }
-  } else {
-    return ret;
   }
+  return ret;
 }

 PrimExpr VisitExpr_(const CallNode *node) final {
......
...@@ -57,7 +57,7 @@ public:
 void VisitStmt_(const EvaluateNode *op) final {
   Proxy proxy = Proxy::kAsync;
   if (auto call = op->value.as<CallNode>()) {
-    if (call->op.same_as(ptx_ldmatirx()) ||
+    if (call->op.same_as(ptx_ldmatrix()) ||
         call->op.same_as(ptx_stmatrix())) {
       proxy = Proxy::kGeneric;
     }
......
...@@ -11,7 +11,6 @@
 #include <tvm/tir/transform.h>

 #include "../op/builtin.h"
-#include "../op/bulk_copy.h"
 #include "../runtime/runtime.h"

 namespace tvm {
......
...@@ -10,7 +10,6 @@
 #include <tvm/tir/transform.h>

 #include "../op/builtin.h"
-#include "../op/bulk_copy.h"
 #include "../runtime/runtime.h"

 namespace tvm {
......
...@@ -430,11 +430,6 @@ private:
     return workspace.access_ptr(2); // write
   };

-  // Get pass config `tl.disable_tma_lower`
-  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
-  Optional<Bool> opt_disable_tma_lower =
-      ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
-  bool disable_tma_lower = opt_disable_tma_lower.value_or(Bool(false));

   Range thread_bounds;
   if (analyzer_->const_int_bound.IsBound(thread_var_->var)) {
...@@ -449,9 +444,9 @@ private:
     thread_bounds = Range::FromMinExtent(0, 1);
   }

-  auto lowered = tile_op->Lower(
-      LowerArgs{target_, thread_bounds, thread_var_->var, callback,
-                layout_map_, buffer_remap_, disable_tma_lower},
-      analyzer_);
+  auto lowered =
+      tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
+                               callback, layout_map_, buffer_remap_},
+                     analyzer_);
   return IRMutatorWithAnalyzer::VisitStmt(lowered);
 }
......
...@@ -10,7 +10,6 @@
 #include <tvm/tir/transform.h>

 #include "../op/builtin.h"
-#include "../op/bulk_copy.h"
 #include "../runtime/runtime.h"

 namespace tvm {
......
...@@ -1146,42 +1146,6 @@ private:
   bool has_simt_copy_ = false;
 };

-class SetMaxNRegCollector : public StmtExprVisitor {
-public:
-  static Array<IntImm> Collect(const PrimFunc &f) {
-    SetMaxNRegCollector collector;
-    collector(f->body);
-    return collector.has_no_set_max_nreg_
-               ? Array<IntImm>({IntImm(DataType::Int(32), -1),
-                                IntImm(DataType::Int(32), -1)})
-               : collector.nreg_;
-  }
-
-private:
-  void VisitStmt_(const EvaluateNode *op) final {
-    if (const CallNode *call = op->value.as<CallNode>()) {
-      if (call->op.same_as(set_max_nreg())) {
-        int reg_hint = call->args[0].as<IntImmNode>()->value;
-        int is_inc = call->args[1].as<IntImmNode>()->value;
-        ICHECK(reg_hint <= 240 && reg_hint >= 24)
-            << "Invalid reg hint: " << reg_hint;
-        ICHECK(is_inc == 0 || is_inc == 1) << "Invalid is_inc: " << is_inc;
-        // producer should decrease register hint while consumer should
-        // increase register hint
-        nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
-      } else if (call->op.same_as(no_set_max_nreg())) {
-        has_no_set_max_nreg_ = true;
-      }
-    }
-    StmtExprVisitor::VisitStmt_(op);
-  }
-
-  Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
-                      IntImm(DataType::Int(32), 0)};
-  bool has_no_set_max_nreg_ = false;
-};
 class WarpSpecializedRewriter : public StmtExprMutator {
 public:
   WarpSpecializedRewriter(bool disable_warp_specialized,
...@@ -1202,7 +1166,6 @@ public:
     auto T = WarpSpecializedRewriter(disable_warp_specialized,
                                      disable_shuffle_elect);
-    T.nreg_ = SetMaxNRegCollector::Collect(f);
     T.buffer_lca_ = DetectBufferAccessLCA(f);
     for (auto [buffer, _] : T.buffer_lca_)
       T.buffer_data_to_buffer_.Set(buffer->data, buffer);
...@@ -1229,16 +1192,6 @@ private:
     }
   }

-  Stmt VisitStmt_(const EvaluateNode *op) final {
-    if (const CallNode *call = op->value.as<CallNode>()) {
-      if (call->op.same_as(set_max_nreg()) ||
-          call->op.same_as(no_set_max_nreg())) {
-        return Evaluate(0);
-      }
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }

   // If users define a thread binding, we will replace the thread binding with
   // threadIdx.x We require the thread binding is threadIdx.x, and the extent is
   // the same as the thread extent
...@@ -1334,22 +1287,6 @@ private:
     if (!marker.HasSimtCopy())
       producer_thread_extent = 128;

-    // TODO: estimate the correct reg usage.
-    int dec_reg = nreg_[0].as<IntImmNode>()->value;
-    int inc_reg = nreg_[1].as<IntImmNode>()->value;
-    auto inc_reg_stmt = Evaluate(0);
-    auto dec_reg_stmt = Evaluate(0);
-    if (dec_reg >= 0 && inc_reg >= 0 && !marker.HasSimtCopy()) {
-      inc_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
-                                   {inc_reg == 0 ? 240 : inc_reg, 1}));
-      dec_reg_stmt = Evaluate(Call(DataType::Handle(), set_max_nreg(),
-                                   {dec_reg == 0 ? 24 : dec_reg, 0}));
-    }
-    producer_code = SeqStmt({dec_reg_stmt, producer_code});
-    consumer_code = SeqStmt({inc_reg_stmt, consumer_code});

     updated_thread_extent_ = consumer_thread_extent + producer_thread_extent;
     producer_code = ThreadIdxRewriter::Rewrite(
...@@ -1382,7 +1319,7 @@ private:
     // Add an attr here to handle the partial thread count in ThreadSync pass.
     Array<IntImm> ws_partition = {Downcast<IntImm>(producer_thread_extent),
                                   Downcast<IntImm>(consumer_thread_extent)};
-    body = AttrStmt(ws_partition, "kWarpSpecializationScope", 0, body);
+    body = AttrStmt(ws_partition, attr::kWarpSpecializationScope, 0, body);
     block.CopyOnWrite()->body = SeqStmt({init_barrier, body});
     block_realize.CopyOnWrite()->block = block;
...@@ -1399,17 +1336,26 @@ private:
   bool need_update_thread_extent_ = false;
   bool disable_warp_specialized_ = false;
   bool disable_shuffle_elect_ = false;
-  Array<IntImm> nreg_;
   bool only_has_wgmma_ = false;
 };

 class WarpSpecializedDetector : public IRVisitorWithAnalyzer {
 public:
+  // Returning true means auto warp specialization will be disabled.
   static bool Detect(Stmt stmt, bool skip_thread_partition = false) {
     WarpSpecializedDetector detector;
     detector.VisitStmt(stmt);
-    return detector.has_warp_specialization_ ||
-           (detector.has_tma_op_ && detector.has_mbarrier_op_);
+    if (detector.has_warp_specialization_) {
+      LOG(WARNING) << "Auto warp specialization will be disabled because warp "
+                      "specialization is manually enabled";
+      return true;
+    }
+    if (detector.has_tma_op_ && detector.has_mbarrier_op_) {
+      LOG(WARNING) << "Auto warp specialization will be disabled because TMA "
+                      "and mbarrier are both present";
+      return true;
+    }
+    return false;
   }

 WarpSpecializedDetector() {
......
...@@ -20,7 +20,15 @@ def reshape_test(N, M, dtype):

 def run_reshape(N, M, dtype):
     program = reshape_test(N, M, dtype)
-    jit_kernel = tl.compile(program, out_idx=-1)
+    # TODO(lei): reshape cannot apply shared memory
+    # layout transform propagation
+    jit_kernel = tl.compile(
+        program,
+        out_idx=-1,
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
     profiler = jit_kernel.get_profiler()

     def ref_program(A):
...@@ -56,7 +64,15 @@ def reshape_test_smem_1d_2_2d(N, M, dtype):

 def run_reshape_smem_1d_2_2d(N, M, dtype):
     program = reshape_test_smem_1d_2_2d(N, M, dtype)
-    jit_kernel = tl.compile(program, out_idx=-1)
+    # TODO(lei): reshape cannot apply shared memory
+    # layout transform propagation
+    jit_kernel = tl.compile(
+        program,
+        out_idx=-1,
+        pass_configs={
+            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        })
     profiler = jit_kernel.get_profiler()

     def ref_program(A):
......
...@@ -70,7 +70,7 @@ def matmul_sp(
             backend="cutlass",
             block_k=block_K),
     })
-    T.no_set_max_nreg()
+    T.disable_warp_group_reg_alloc()
     T.clear(C_local)
     for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
         T.copy(E[by * block_M, k * block_K // E_factor], E_shared)
......
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
import tilelang.testing
from tvm import tir
tilelang.disable_cache()
def test_inject_set_max_nreg():
"""Test the InjectSetMaxNReg pass"""
@T.prim_func
def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16")):
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 128)
with T.block(""):
T.reads(A[by * 64, 0:512], B[0:512, bx * 64])
T.writes()
# Add set_max_nreg hints
T.annotate_producer_reg_dealloc(24) # Producer: decrease to 24
T.annotate_consumer_reg_alloc(240) # Consumer: increase to 240
A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local")
T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if v >= 128:
# Producer branch - should have set_max_nreg(24, 0)
for k in range(16):
T.mbarrier_wait_parity(T.get_mbarrier(k % 3 + 3), T.bitwise_xor(k // 3 % 2, 1))
if v - 128 == 0:
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1,
0, 2, 2, 0), T.get_mbarrier(k % 3),
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, by * 64)
T.evaluate(
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3)]))
else:
# Consumer branch - should have set_max_nreg(240, 1)
for k in range(16):
T.mbarrier_wait_parity(T.get_mbarrier(k % 3), k // 3 % 2)
T.call_extern(
"handle", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(
T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(
T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3))
T.evaluate(
tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
# Apply the InjectSetMaxNReg pass
func = before
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# Check that set_max_nreg calls are properly injected
main_func = mod["main"]
set_max_nreg_calls = []
def collect_set_max_nreg(stmt):
if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
set_max_nreg_calls.append(stmt.value)
tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)
# We should have at least 2 set_max_nreg calls (one for producer, one for consumer)
assert len(set_max_nreg_calls
) >= 2, f"Expected at least 2 set_max_nreg calls, got {len(set_max_nreg_calls)}"
# Check that we have the expected register values
reg_values = [call[0] for call in set_max_nreg_calls]
assert 24 in reg_values, f"Expected register value 24 in {reg_values}"
assert 240 in reg_values, f"Expected register value 240 in {reg_values}"
print("InjectSetMaxNReg test passed!")
def test_inject_set_max_nreg_no_set_max_nreg():
"""Test the InjectSetMaxNReg pass with no_set_max_nreg"""
@T.prim_func
def before_no_set_max_nreg(A: T.Tensor((512, 512), "float16")):
bx = T.launch_thread("blockIdx.x", 8)
v = T.launch_thread("threadIdx.x", 128)
with T.block(""):
T.reads(A[bx * 64, 0:64])
T.writes()
# Add no_set_max_nreg to disable register hinting
T.disable_warp_group_reg_alloc()
T.create_list_of_mbarrier(128, 128)
T.attr([128, 128], "kWarpSpecializationScope", 0)
if v >= 128:
# Producer branch - should not have set_max_nreg calls
T.evaluate(0)
else:
# Consumer branch - should not have set_max_nreg calls
T.evaluate(0)
# Apply the InjectSetMaxNReg pass
func = before_no_set_max_nreg
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# Check that no set_max_nreg calls are injected when no_set_max_nreg is present
main_func = mod["main"]
set_max_nreg_calls = []
def collect_set_max_nreg(stmt):
if (isinstance(stmt, tvm.tir.Evaluate) and hasattr(stmt.value, 'op') and
hasattr(stmt.value.op, 'name') and stmt.value.op.name == "tl.set_max_nreg"):
set_max_nreg_calls.append(stmt.value)
tvm.tir.stmt_functor.post_order_visit(main_func.body, collect_set_max_nreg)
# Should have no set_max_nreg calls when no_set_max_nreg is present
assert len(
set_max_nreg_calls
) == 0, f"Expected 0 set_max_nreg calls when no_set_max_nreg is present, got {len(set_max_nreg_calls)}"
print("InjectSetMaxNReg with no_set_max_nreg test passed!")
if __name__ == "__main__":
tilelang.testing.main()
...@@ -101,6 +101,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.MultiVersionBuffer()(mod)
     mod = tilelang.transform.WarpSpecialized()(mod)
     mod = tilelang.transform.InjectTmaBarrier()(mod)
+    mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
     # if tma is not enabled, we can also do pipeline planning
     # to get better performance with async copy
     mod = tilelang.transform.PipelinePlanning()(mod)
......
...@@ -232,6 +232,9 @@ class Environment:

     def disable_cache(self) -> None:
         CacheState.disable()

+    def is_print_on_compilation_enabled(self) -> bool:
+        return self.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on")

 # Instantiate as a global configuration object
 env = Environment()
......
...@@ -117,7 +117,7 @@ class JITKernel(object):
         # Print log on compilation starts
         # NOTE(Chenggang): printing could let the training/inference framework easier to know
         # whether the communication timeout is from compilation
-        if env.TILELANG_PRINT_ON_COMPILATION.lower() in ("1", "true", "yes", "on"):
+        if env.is_print_on_compilation_enabled():
             # assert func must have "global_symbol"
             func_name = func.attrs.get("global_symbol")
             assert func_name is not None, "func must have global_symbol"
...@@ -126,6 +126,11 @@ class JITKernel(object):
         # Compile the TileLang function and create a kernel adapter for execution.
         adapter = self._compile_and_create_adapter(func, out_idx)

+        if env.is_print_on_compilation_enabled():
+            func_name = func.attrs.get("global_symbol")
+            assert func_name is not None, "func must have global_symbol"
+            logger.info(f"TileLang completes to compile kernel `{func_name}`")

         # The adapter's function is assigned as the callable function for this instance.
         self.adapter = adapter
         self.torch_function = adapter.func
......
...@@ -142,12 +142,30 @@ def dec_max_nreg(reg_count: int):
     return set_max_nreg(reg_count, 0)

+def annotate_producer_reg_dealloc(reg_count: int = 24):
+    """Annotate the producer reg dealloc."""
+    return dec_max_nreg(reg_count)

+def annotate_consumer_reg_alloc(reg_count: int = 240):
+    """Annotate the consumer reg alloc."""
+    return inc_max_nreg(reg_count)

 def no_set_max_nreg():
     """Disable the maximum register limit setting."""
     return tir.call_intrin("handle", tir.op.Op.get("tl.no_set_max_nreg"))

+def disable_warp_group_reg_alloc():
+    """Disable the warp group reg alloc."""
+    return no_set_max_nreg()

 def mbarrier_wait_parity(mbarrier: Union[int, PrimExpr, tir.Call], parity: Union[int, Var]):
     """Wait for memory barrier parity condition.
......
...@@ -189,6 +189,21 @@ def WarpSpecialized():
     return _ffi_api.WarpSpecialized()  # type: ignore

+def AnnotateWarpGroupRegAlloc():
+    """Inject set_max_nreg calls into warp-specialized functions.
+
+    This pass analyzes the function to collect register hints from set_max_nreg
+    and no_set_max_nreg calls, then injects appropriate set_max_nreg calls into
+    producer and consumer branches of warp-specialized code.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.AnnotateWarpGroupRegAlloc()  # type: ignore

 def InjectTmaBarrier():
     """InjectTmaBarrier
......