Unverified commit 7e8d1f82 authored by Lei Wang, committed by GitHub

[Enhancement] Enhance let binding handling in layout inference and warp specialized pass (#1484)

* [Feature] Add FullyReplicated Fragment Layout and Enhance Layout Inference

* Introduced a new static method `FullyReplicated` in the `Fragment` class to create fully replicated fragment layouts, ensuring all threads hold identical copies of the buffer.
* Updated `CopyNode` to collect fragment layouts and mark them as fully replicated during layout inference.
* Enhanced `ParallelOpNode` to expand let bindings for fragment buffer accesses, improving layout inference accuracy.
* Added documentation for new methods and updated existing methods to support the new layout features.

* lint fix

* Remove debug logging statements from layout inference process to streamline output and improve performance.
parent 168aec7b
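
For orientation, the kernel pattern these changes target is a fragment-resident index buffer whose values reach a bulk copy through a let binding. The sketch below is condensed from the test added at the bottom of this diff; the function name and hard-coded shapes are illustrative only. The changes let layout inference follow the let binding back to block_mask_f and assign it a fully replicated layout, so every thread can read the index uniformly.

import tilelang.language as T

@T.prim_func
def gather_blocks(
    A: T.Tensor((1024, 1024), T.float16),
    B: T.Tensor((1024, 1024), T.float16),
    BlockMask: T.Tensor((8, 8), T.int32),
):
    with T.Kernel(T.ceildiv(1024, 128), T.ceildiv(1024, 128)) as (bx, by):
        A_shared = T.alloc_shared((128, 128), T.float16)
        block_mask_f = T.alloc_fragment((8,), T.int32)  # "local.fragment" scope
        T.copy(BlockMask[by, :], block_mask_f)
        for i in T.Pipelined(8):
            a = block_mask_f[i]  # lowered to a LetStmt binding a fragment load
            if a >= 0:
                T.copy(A[a, 0], A_shared)  # let-bound index feeds the bulk (TMA) copy
            T.copy(A_shared, B[by * 128 : (by + 1) * 128, i * 128 : (i + 1) * 128])
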
@@ -549,6 +549,12 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
data_ = std::move(n);
}
Fragment Fragment::FullyReplicated(Array<PrimExpr> shape,
PrimExpr thread_extent) {
return Fragment(shape, {}, ReplicationPlaceholder(), thread_extent,
std::nullopt);
}
// which means the forward_thread is rep_var -> lambda i, rep: rep
bool FragmentNode::IsCompletedReplicated() const {
arith::Analyzer analyzer;
......
@@ -175,6 +175,20 @@ public:
PrimExpr forward_thread, PrimExpr replicate_size,
Optional<Var> replicate_var);
/*!
* \brief Create a fully replicated fragment layout.
*
* A fully replicated fragment means all threads hold identical copies of the
* entire buffer. This is useful for index buffers or masks that need to be
* accessed uniformly across all threads.
*
* \param shape The shape of the buffer.
* \param thread_extent The number of threads.
* \return A Fragment where each thread has a complete copy of all elements.
*/
TVM_DLL static Fragment FullyReplicated(Array<PrimExpr> shape,
PrimExpr thread_extent);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Fragment, Layout, FragmentNode);
};
......
@@ -555,14 +555,34 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
copy_inst == CopyInst::kBulkLoad || copy_inst == CopyInst::kBulkLoad1D;
Buffer global_tensor = is_load ? src : dst;
Buffer shared_tensor = is_load ? dst : src;
Map<Buffer, Layout> result_map;
// Collect fragment buffers from indices and mark them as fully replicated
// For Bulk Load/Store, fragment buffers used as indices should be
// replicated across all threads
PrimExpr thread_extent = T.thread_bounds->extent;
for (const auto &range : src_range) {
CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
}
for (const auto &range : dst_range) {
CollectFragmentLayouts(range->min, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
CollectFragmentLayouts(range->extent, T.let_var_to_expr, T.layout_map,
thread_extent, T.thread_bounds, result_map);
}
// check shared layout is non-swizzle
// skip layout inference if shared layout is already annotated
if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) {
// create a new layout map for tma linear layout
Layout linear_layout = ComputeLinearLayout(shared_tensor);
- return Map<Buffer, Layout>({{shared_tensor, linear_layout}});
+ result_map.Set(shared_tensor, linear_layout);
}
- return {};
+ return result_map;
}
// for LDSM/STSM, the layout was deduced from register layout
// so we can directly apply the layout of normal copy
@@ -571,7 +591,8 @@ LayoutMap CopyNode::InferLayout(const LayoutInferArgs &T,
arith::Analyzer analyzer;
par_op_ = ParallelOp((MakeSIMTLoop(&analyzer)));
}
- return par_op_->InferLayout(T, level);
+ auto layout_map = par_op_->InferLayout(T, level);
+ return layout_map;
}
/**
* @brief Determine whether this CopyNode can be lowered to a Bulk Load (TMA)
@@ -940,8 +961,13 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
- par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                      false, T.buffer_remap},
+ par_op->InferLayout({T.target,
+                      T.thread_bounds,
+                      T.layout_map,
+                      analyzer,
+                      false,
+                      T.buffer_remap,
+                      {}},
level);
}
auto loop_layout = par_op->GetLoopLayout();
@@ -2034,6 +2060,31 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
return args;
}
void CopyNode::CollectFragmentLayouts(const PrimExpr &expr,
const Map<Var, PrimExpr> &let_var_to_expr,
const LayoutMap &existing_layouts,
PrimExpr thread_extent,
Range thread_bounds,
Map<Buffer, Layout> &result_map) const {
PostOrderVisit(expr, [&](const ObjectRef &node) {
if (auto bl = node.as<BufferLoadNode>()) {
if (bl->buffer.scope() == "local.fragment" &&
!existing_layouts.count(bl->buffer) &&
!result_map.count(bl->buffer)) {
auto f = Fragment::FullyReplicated(bl->buffer->shape, thread_extent);
result_map.Set(bl->buffer, f->BindThreadRange(thread_bounds));
}
} else if (auto var_node = node.as<VarNode>()) {
auto var = tvm::ffi::GetRef<Var>(var_node);
if (let_var_to_expr.count(var)) {
CollectFragmentLayouts(let_var_to_expr[var], let_var_to_expr,
existing_layouts, thread_extent, thread_bounds,
result_map);
}
}
});
}
// Register the Copy operation with TVM's TIR system
// This makes the copy operation available for use in TVM programs
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
......
@@ -269,6 +269,28 @@ protected:
* @return Reference to the singleton TVM Op representing this operator.
*/
TileOperator Clone() const;
private:
/*!
* \brief Collect fragment buffers from expression and create fully replicated
* layouts.
*
* Recursively searches the expression for BufferLoad nodes with
* "local.fragment" scope, following let bindings. For each found fragment
* buffer, creates a fully replicated layout and adds it to result_map.
*
* \param expr Expression to search.
* \param let_var_to_expr Map from let variables to their bound expressions.
* \param existing_layouts Existing layout map to check for already-inferred
*   layouts.
* \param thread_extent Number of threads for replication.
* \param thread_bounds Thread bounds for binding the layout.
* \param result_map Output map to store collected fragment layouts.
*/
void CollectFragmentLayouts(const PrimExpr &expr,
const Map<Var, PrimExpr> &let_var_to_expr,
const LayoutMap &existing_layouts,
PrimExpr thread_extent, Range thread_bounds,
Map<Buffer, Layout> &result_map) const;
};
class Copy : public TileOperator {
......
@@ -158,8 +158,13 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (dst.scope() == "local.fragment") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
- par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                      false, T.buffer_remap},
+ par_op->InferLayout({T.target,
+                      T.thread_bounds,
+                      T.layout_map,
+                      analyzer,
+                      false,
+                      T.buffer_remap,
+                      {}},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
@@ -176,8 +181,13 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
dst.scope() == "global") {
auto par_op = ParallelOp(MakeSIMTLoop(analyzer));
- par_op->InferLayout({T.target, T.thread_bounds, T.layout_map, analyzer,
-                      false, T.buffer_remap},
+ par_op->InferLayout({T.target,
+                      T.thread_bounds,
+                      T.layout_map,
+                      analyzer,
+                      false,
+                      T.buffer_remap,
+                      {}},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
......
@@ -39,6 +39,9 @@ struct LowerArgs {
AddWorkspaceCallback AddWorkspace;
LayoutMap layout_map;
Map<Buffer, Buffer> buffer_remap;
// Map from LetStmt variable to its bound expression, for resolving
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
};
struct LayoutInferArgs {
@@ -48,6 +51,9 @@ struct LayoutInferArgs {
arith::Analyzer *analyzer;
bool buffer_oob = false;
Map<Buffer, Buffer> buffer_remap;
// Map from LetStmt variable to its bound expression, for resolving
// fragment buffer accesses through let bindings
Map<Var, PrimExpr> let_var_to_expr;
};
class TileOperator;
......
@@ -182,6 +182,34 @@ TileOperator ParallelOpNode::Clone() const {
return ParallelOp(op);
}
void ParallelOpNode::ExpandLetBindings(
const Map<Var, PrimExpr> &let_var_to_expr) {
if (let_var_to_expr.empty())
return;
// Helper function to recursively find BufferLoads through let bindings
std::function<void(const PrimExpr &)> expand = [&](const PrimExpr &expr) {
PostOrderVisit(expr, [&](const ObjectRef &node) {
if (auto bl = node.as<BufferLoadNode>()) {
if (bl->buffer.scope() == "local.fragment" &&
!indice_map_.count(bl->buffer)) {
indice_map_.Set(bl->buffer, bl->indices);
}
} else if (auto var_node = node.as<VarNode>()) {
auto var = tvm::ffi::GetRef<Var>(var_node);
if (let_var_to_expr.count(var)) {
expand(let_var_to_expr[var]);
}
}
});
};
// Scan all let bindings
for (const auto &[var, expr] : let_var_to_expr) {
expand(expr);
}
}
Stmt ParallelOpNode::Lower(const LowerArgs &T,
arith::Analyzer *analyzer) const {
return root_;
@@ -215,6 +243,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
if (loop_layout_.defined())
return {};
// Expand let bindings to find fragment buffer accesses
if (!T.let_var_to_expr.empty()) {
const_cast<ParallelOpNode *>(this)->ExpandLetBindings(T.let_var_to_expr);
}
if (level == InferLevel::kStrict) {
LayoutMap results;
// Deduce buffers that should be complicated replicated.
......
@@ -105,6 +105,10 @@ private:
void AddPredicate(const PrimExpr &expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}
// Expand let bindings to find fragment buffer accesses and add them to
// indice_map_. This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0],
// ...)
void ExpandLetBindings(const Map<Var, PrimExpr> &let_var_to_expr);
// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;
......
@@ -110,9 +110,13 @@ public:
"required for layout inference.";
// Run InferLayout
- auto updates =
-     next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
-                                       cur_analyzer, buffer_oob},
+ auto updates = next->InferLayout(LayoutInferArgs{target_,
+                                                  thread_bounds,
+                                                  layout_map,
+                                                  cur_analyzer,
+                                                  buffer_oob,
+                                                  {},
+                                                  let_var_to_expr_},
level);
// Process the returned updates
@@ -479,6 +483,10 @@ private:
} else if (auto buffer = getBufferFromRegion(arg)) {
addToUseList(buffer.value());
}
// Check if the argument uses any LetStmt variables that reference
// fragment buffers. If so, add those buffers to the use list.
// This handles cases like: a = block_mask_f[i]; T.copy(A[a, 0], ...)
CollectFragmentBuffersFromExpr(arg);
}
// Compute thread_var_ and thread_bounds_
thread_var_vec_.push_back(thread_var_);
@@ -754,6 +762,30 @@ private:
IRVisitorWithAnalyzer::VisitStmt_(op);
}
void VisitStmt_(const LetStmtNode *op) final {
// Record Let variable to its bound expression.
// This enables tracking fragment buffer accesses through let bindings.
let_var_to_expr_.Set(op->var, op->value);
IRVisitorWithAnalyzer::VisitStmt_(op);
}
// Helper: recursively collect fragment buffers from an expression,
// following let bindings chain.
void CollectFragmentBuffersFromExpr(const PrimExpr &expr) {
PostOrderVisit(expr, [this](const ObjectRef &node) {
if (auto bl = node.as<BufferLoadNode>()) {
if (bl->buffer.defined() && bl->buffer.scope() == "local.fragment") {
addToUseList(bl->buffer);
}
} else if (auto var_node = node.as<VarNode>()) {
auto var = tvm::ffi::GetRef<Var>(var_node);
if (let_var_to_expr_.count(var)) {
CollectFragmentBuffersFromExpr(let_var_to_expr_[var]);
}
}
});
}
void VisitExpr_(const BufferLoadNode *op) final {
// Collect buffer from BufferLoad
if (op->buffer.defined() && op->buffer->data.defined()) {
@@ -815,6 +847,8 @@ private:
}
Map<Var, Array<Buffer>> buffer_data_to_buffers_;
// Map from LetStmt variable to its bound expression
Map<Var, PrimExpr> let_var_to_expr_;
std::vector<ObjectRef> infer_list_stmt_;
std::vector<TileOperator> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
......
@@ -213,8 +213,7 @@ private:
const auto &buffer = opt_buffer.value();
Fragment f;
if (info->rep == ReducerRepType::ALL) {
- f = Fragment(buffer->shape, {}, ReplicationPlaceholder(),
-              thread_extent, std::nullopt);
+ f = Fragment::FullyReplicated(buffer->shape, thread_extent);
} else if (info->rep == ReducerRepType::NONE) {
PrimExpr flatten_idx = InputPlaceholder(0);
for (int i = 1; i < buffer->shape.size(); ++i)
......
@@ -638,9 +638,15 @@ private:
thread_bounds = Range::FromMinExtent(0, 1);
}
- auto lowered =
-     tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
-                              callback, layout_map_, buffer_remap_},
+ // Convert let_bindings_ to Map<Var, PrimExpr> for LowerArgs
+ Map<Var, PrimExpr> let_var_to_expr;
+ for (const auto &[var, expr] : let_bindings_) {
+   let_var_to_expr.Set(var, expr);
+ }
+ auto lowered = tile_op->Lower(
+     LowerArgs{target_, thread_bounds, thread_var_->var, callback,
+               layout_map_, buffer_remap_, let_var_to_expr},
analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered);
}
......
@@ -50,6 +50,7 @@ class ProducerUsedBufferFinder : public StmtExprVisitor {
public:
auto FindProducerusedBuffer(const Stmt &stmt) {
producer_buffers_.clear();
let_var_to_expr_.clear();
std::unordered_set<const BufferNode *> last_producer_buffers_;
for (;;) {
VisitStmt(stmt);
@@ -68,6 +69,28 @@ public:
for (const auto &buffer : usage.buffer_use_count_) {
producer_buffers_.insert(buffer.first);
}
// Also collect buffers through let bindings
CollectBuffersFromExpr(expr);
}
// Collect buffers from expression, following let bindings
void CollectBuffersFromExpr(const PrimExpr &expr) {
PostOrderVisit(expr, [this](const ObjectRef &node) {
if (auto bl = node.as<BufferLoadNode>()) {
producer_buffers_.insert(bl->buffer.get());
} else if (auto var_node = node.as<VarNode>()) {
auto var = tvm::ffi::GetRef<Var>(var_node);
auto it = let_var_to_expr_.find(var.get());
if (it != let_var_to_expr_.end()) {
CollectBuffersFromExpr(it->second);
}
}
});
}
void VisitStmt_(const LetStmtNode *op) final {
let_var_to_expr_[op->var.get()] = op->value;
StmtExprVisitor::VisitStmt_(op);
}
void VisitStmt_(const IfThenElseNode *op) final {
@@ -102,15 +125,15 @@ public:
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
for (auto arg : op->args) {
- if (auto buffer_load = arg.as<BufferLoadNode>()) {
-   producer_buffers_.insert(buffer_load->buffer.get());
- }
+ // Collect buffers from args, including through let bindings
+ CollectBuffersFromExpr(arg);
}
}
}
private:
std::unordered_set<const BufferNode *> producer_buffers_;
std::unordered_map<const VarNode *, PrimExpr> let_var_to_expr_;
};
class WarpSpecializedRoleMarker : public StmtVisitor {
......
"""
Test layout inference for LetStmt expressions.
This test validates that TileLang correctly handles layout inference when
fragment buffer accesses occur through let bindings. For example:
block_mask_f = T.alloc_fragment((N_S,), T.int32)
T.copy(BlockMask[by, :], block_mask_f)
for i in T.Pipelined(N_S):
a = block_mask_f[i] # LetStmt: a is bound to fragment buffer load
T.copy(A[a, 0], A_shared) # a is used as index in TMA copy
Key scenarios tested:
1. Fragment buffer layout inference through let bindings
2. TMA (Tensor Memory Accelerator) copy with let-bound indices
3. CP.ASYNC copy with let-bound indices
4. Warp specialization with let-bound fragment accesses
"""
import tilelang
import tilelang.language as T
import tilelang.testing
import torch
def blocksparse_copy_kernel(M, N, N_S, block_M, block_N, dtype=T.float16):
    """BlockSparse copy kernel using fragment for block mask indices."""
    block_mask_shape = (M // block_M, N_S)

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M, N), dtype),
        BlockMask: T.Tensor(block_mask_shape, T.int32),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M)) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_N), dtype)
            B_shared = T.alloc_shared((block_M, block_N), dtype)
            block_mask_f = T.alloc_fragment((N_S,), T.int32)
            T.clear(B_shared)
            T.copy(BlockMask[by, :], block_mask_f)
            for i in T.Pipelined(N_S):
                a = block_mask_f[i]  # LetStmt: fragment buffer access
                if a >= 0:
                    T.copy(A[a, 0], A_shared)
                T.copy(A_shared, B[by * block_M : (by + 1) * block_M, i * block_N : (i + 1) * block_N])

    return main

def ref_blocksparse_copy(A, B, BlockMask, M, N, N_S, block_M, block_N):
    """Reference implementation for blocksparse copy."""
    ref_B = B.clone()
    num_row_blocks = M // block_M
    for by in range(num_row_blocks):
        for i in range(N_S):
            src_row_start = BlockMask[by, i].item()
            ref_B[by * block_M : (by + 1) * block_M, i * block_N : (i + 1) * block_N] = A[
                src_row_start : src_row_start + block_M, 0:block_N
            ]
    return ref_B

def run_blocksparse_copy(M, N, block_M, block_N, pass_configs=None):
    """Run blocksparse copy test with given parameters."""
    N_S = N // block_N
    program = blocksparse_copy_kernel(M, N, N_S, block_M, block_N)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs=pass_configs or {},
    )

    # Initialize tensors
    a = torch.randn(M, N, device="cuda", dtype=torch.float16)
    b = torch.zeros(M, N, device="cuda", dtype=torch.float16)

    # Create BlockMask with valid row indices
    num_row_blocks = M // block_M
    block_mask = torch.zeros((num_row_blocks, N_S), dtype=torch.int32, device="cuda")
    for by in range(num_row_blocks):
        for i in range(N_S):
            max_row_block = (M - block_M) // block_M
            block_mask[by, i] = torch.randint(0, max_row_block + 1, (1,)).item() * block_M

    # Run kernel
    c = kernel(a, block_mask)

    # Compute reference
    ref_c = ref_blocksparse_copy(a, b, block_mask, M, N, N_S, block_M, block_N)

    # Verify
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

@tilelang.testing.requires_cuda
def test_blocksparse_copy_tma():
    """Test blocksparse copy with TMA (Tensor Memory Accelerator)."""
    run_blocksparse_copy(M=1024, N=1024, block_M=128, block_N=128, pass_configs={})


@tilelang.testing.requires_cuda
def test_blocksparse_copy_cp_async():
    """Test blocksparse copy with CP.ASYNC (without TMA)."""
    run_blocksparse_copy(
        M=1024,
        N=1024,
        block_M=128,
        block_N=128,
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        },
    )

if __name__ == "__main__":
    tilelang.testing.main()
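
To check by eye which copy path a given configuration takes, one option is to dump the generated CUDA after compilation, reusing blocksparse_copy_kernel from the test above. This is only a sketch under the assumption that the object returned by tilelang.compile exposes a get_kernel_source() accessor; adjust to your TileLang version if it differs.

import tilelang

# Assumed helper: get_kernel_source() is not shown in this diff.
program = blocksparse_copy_kernel(1024, 1024, 1024 // 128, 128, 128)
kernel = tilelang.compile(program, out_idx=[1], target="cuda")
print(kernel.get_kernel_source())  # inspect whether a TMA bulk copy or cp.async was emitted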