Unverified Commit 17fafc1b authored by Lei Wang, committed by GitHub

[Smem Reuse] Perform memory alignment on identical buffers. (#693)

* [Enhancement] Refactor GEMM operations for improved warp partitioning and target instruction handling

- Introduced a new `GetGemmInst` method to determine the appropriate GEMM instruction based on block size and target architecture.
- Updated `ComputeWarpPartition` to accept the GEMM instruction type, enhancing flexibility in warp partitioning logic.
- Added `TargetGetWarpSize` utility to streamline warp size retrieval based on target architecture (see the sketch below).
- Refactored layout inference and lowering methods to utilize the new GEMM instruction handling, improving clarity and maintainability of the codebase.
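
As context for the `TargetGetWarpSize` bullet above, here is a minimal hypothetical sketch, not the committed implementation: it assumes a `tvm::Target` parameter and that the warp size is 32 on CUDA-class targets and 64 (a wavefront) on HIP/CDNA targets.

```cpp
// Hypothetical sketch of a TargetGetWarpSize-style helper; the committed
// utility may differ in location, dispatch logic, and covered targets.
#include <tvm/target/target.h>

namespace tvm {
namespace tl {

inline int TargetGetWarpSize(const Target &target) {
  // NVIDIA GPUs schedule 32-thread warps; AMD CDNA GPUs use 64-thread
  // wavefronts, which HIP exposes as the warp size.
  return target->kind->name == "hip" ? 64 : 32;
}

}  // namespace tl
}  // namespace tvm
```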

* bug fix

* test fix

* lint fix

* phase out Canonialize

* add option --expt-relaxed-constexpr

* [Enhancement] Introduce tilelang intrinsic operations for GEMM

- Added `tl_gemm` and `tl_gemm_sp` built-in operations to support general and sparse matrix multiplication in tilelang.
- Updated the lowering logic in `Gemm` and `GemmSP` to utilize the new tilelang operations.
- Enhanced CUDA and HIP code generation to handle the new GEMM operations, ensuring proper argument validation and external call printing.
- Implemented shared memory alignment planning for GEMM operations to optimize performance on supported architectures.

* lint fix

* lint fix

* test fix

* test fix

* rebase

* Update builtin.cc
parent fdbf4d6c
@@ -131,5 +131,14 @@ TIR_DEFINE_TL_BUILTIN(loop_break)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tl_gemm).set_num_inputs(4).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
TIR_DEFINE_TL_BUILTIN(tl_gemm_sp)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
} // namespace tl
} // namespace tvm
@@ -279,6 +279,21 @@ TVM_DLL const Op &tvm_rdna_wmma();
*/
TVM_DLL const Op &tvm_rdna_wmma_store();
/*!
* \brief tilelang intrinsic for general matrix multiplication (GEMM).
*
* This op is used to represent a generic GEMM operation in tilelang.
*/
TVM_DLL const Op &tl_gemm();
/*!
* \brief tilelang intrinsic for sparse matrix multiplication (GEMM with
* sparsity).
*
* This op is used to represent a sparse GEMM operation in tilelang.
*/
TVM_DLL const Op &tl_gemm_sp();
} // namespace tl
} // namespace tvm
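The declarations above are what later passes and codegens key on. A minimal sketch of that pattern, assuming the `../op/builtin.h` include path used elsewhere in this diff (the class name and counting logic are illustrative only):

```cpp
// Sketch only: count tl_gemm / tl_gemm_sp calls by op identity, using the
// same same_as() pattern as the CUDA codegen and alignment planner below.
#include <tvm/tir/stmt_functor.h>

#include "../op/builtin.h"  // assumed include path, as used later in this diff

namespace tvm {
namespace tl {

class GemmCallCounter : public tir::StmtExprVisitor {
 public:
  static int Count(const tir::Stmt &stmt) {
    GemmCallCounter counter;
    counter(stmt);
    return counter.count_;
  }

 private:
  void VisitExpr_(const tir::CallNode *op) final {
    // tl_gemm carries 4 arguments, tl_gemm_sp carries 5 (see builtin.cc above).
    if (op->op.same_as(tl_gemm()) || op->op.same_as(tl_gemm_sp())) {
      ++count_;
    }
    tir::StmtExprVisitor::VisitExpr_(op);
  }

  int count_{0};
};

}  // namespace tl
}  // namespace tvm
```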
@@ -311,16 +311,9 @@ Stmt Gemm::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
ss << ", " << wg_wait;
}
ss << ">";
auto A_buffer = T.buffer_remap.count(A) ? T.buffer_remap[A] : A;
auto B_buffer = T.buffer_remap.count(B) ? T.buffer_remap[B] : B;
auto C_buffer = T.buffer_remap[C];
Array<PrimExpr> new_args;
new_args.push_back(StringImm(ss.str()));
new_args.push_back(Aptr);
new_args.push_back(Bptr);
new_args.push_back(Cptr);
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
auto new_call = Call(DataType::Handle(), tl::tl_gemm(),
Array<PrimExpr>{StringImm(ss.str()), Aptr, Bptr, Cptr});
return Evaluate(new_call);
}
@@ -248,13 +248,11 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto C_buffer = T.buffer_remap[C];
auto E_buffer = T.buffer_remap.count(E) ? T.buffer_remap[E] : E;
Array<PrimExpr> new_args;
new_args.push_back(StringImm(ss.str()));
new_args.push_back(A_buffer.access_ptr(1));
new_args.push_back(B_buffer.access_ptr(1));
new_args.push_back(C_buffer.access_ptr(3));
new_args.push_back(E_buffer.access_ptr(1));
auto new_call = Call(DataType::Handle(), builtin::call_extern(), new_args);
auto new_call =
Call(DataType::Handle(), tl::tl_gemm_sp(),
Array<PrimExpr>{StringImm(ss.str()), A_buffer.access_ptr(1),
B_buffer.access_ptr(1), C_buffer.access_ptr(3),
E_buffer.access_ptr(1)});
return Evaluate(new_call);
}
@@ -79,11 +79,6 @@ Stmt Operator::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return Evaluate(0);
}
Stmt Operator::Canonialize(const CanonializeArgs &T,
arith::Analyzer *analyzer) const {
return {};
}
LayoutMap Operator::InferLayout(const LayoutInferArgs &T, InferLevel level) {
return {};
}
@@ -59,15 +59,9 @@ struct LayoutInferArgs {
Map<Buffer, Buffer> buffer_remap;
};
struct CanonializeArgs {
Target target;
};
class Operator {
public:
virtual Stmt Lower(const LowerArgs &T, arith::Analyzer *analyzer) const;
virtual Stmt Canonialize(const CanonializeArgs &T,
arith::Analyzer *analyzer) const;
virtual LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level);
virtual ~Operator() = default;
};
@@ -991,6 +991,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
print_extern_call_stmt("tl::mbarrier_wait");
} else if (op->op.same_as(tl::sync_thread_partial())) {
print_extern_call_stmt("tl::syncthreads_partial");
} else if (op->op.same_as(tl::no_set_max_nreg())) {
return;
} else if (op->op.same_as(tl::tma_load())) {
std::ostringstream ss;
ICHECK_GE(op->args.size(), 2);
@@ -1519,6 +1521,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
EndScope(ssa_scope);
} else if (op->op.same_as(builtin::thread_return())) {
os << "return";
} else if (op->op.same_as(tl::tl_gemm())) {
ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
ICHECK(op->args.size() == 5)
<< "tl_gemm_sp expects 5 arguments <op_instance, A_ptr, B_ptr, C_ptr, "
"E_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
enable_sparse_gemm_ = true;
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
} else {
CodeGenC::VisitExpr_(op, os);
}
@@ -1634,14 +1652,6 @@ void CodeGenTileLangCUDA::VisitStmt_(const EvaluateNode *op) {
stream << " " << vid_global_barrier_expect_ << " = 0;\n";
PrintIndent();
stream << "}\n";
} else if (call && call->op.same_as(builtin::call_extern())) {
ICHECK(call->args.size() >= 1)
<< "call_extern must have at least 1 argument";
std::string func_name = call->args[0].as<StringImmNode>()->value;
if (func_name.find("tl::gemm_sp") == 0) {
enable_sparse_gemm_ = true;
}
CodeGenC::VisitStmt_(op);
} else {
CodeGenC::VisitStmt_(op);
}
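For intuition on the `PrintCallExtern(..., true, os)` calls above: the boolean flag asks `CodeGenC` to skip `args[0]`, which holds the `op_instance` string, so only the pointer arguments are printed as the call's argument list. A standalone (non-TVM) sketch of that shape, with a made-up `tl::gemm_example<...>` instance name:

```cpp
// Standalone illustration (plain C++, not TVM code) of the emitted call shape
// when the first argument is skipped: args[0] holds the op_instance template
// string and the remaining args become the call arguments.
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Hypothetical lowered tl_gemm arguments: <op_instance, A_ptr, B_ptr, C_ptr>.
  const std::vector<std::string> args = {"tl::gemm_example<...>", "A_ptr",
                                         "B_ptr", "C_ptr"};
  std::cout << args[0] << "(";
  for (std::size_t i = 1; i < args.size(); ++i) {  // skip args[0]
    std::cout << (i > 1 ? ", " : "") << args[i];
  }
  // Prints: tl::gemm_example<...>(A_ptr, B_ptr, C_ptr);
  std::cout << ");\n";
  return 0;
}
```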
@@ -946,6 +946,17 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
replacer.register_rule("{c_ref}", c_ref);
replacer.register_rule("{c_bias}", c_bias);
os << replacer.rewrite(call_mfma_code);
} else if (op->op.same_as(builtin::thread_return())) {
os << "return";
} else if (op->op.same_as(tl::tl_gemm())) {
ICHECK(op->args.size() == 4) << "tl_gemm expects 4 arguments <op_instance, "
"A_ptr, B_ptr, C_ptr>, but got "
<< op->args.size();
auto op_instance = Downcast<StringImm>(op->args[0]);
this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_instance->value,
op->args, true, os);
} else if (op->op.same_as(tl::tl_gemm_sp())) {
LOG(FATAL) << "tl_gemm_sp is not supported on HIP";
} else {
CodeGenC::VisitExpr_(op, os);
}
@@ -35,9 +35,11 @@
#include <unordered_set>
#include "../op/builtin.h"
#include "../target/utils.h"
#include "runtime/thread_storage_scope.h"
#include "support/arena.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/tir/function.h"
namespace tvm {
namespace tl {
@@ -315,6 +317,46 @@ private:
size_t scope_level_{0};
};
class SharedMemoryAlignmentPlanner : public StmtExprVisitor {
public:
static std::unordered_map<const VarNode *, int> Plan(const Stmt &stmt) {
SharedMemoryAlignmentPlanner planner;
planner(stmt);
return planner.shmem_alignment_map_;
}
private:
void VisitExpr_(const CallNode *op) {
if (op->op.same_as(tl::tl_gemm()) || op->op.same_as(tl::tl_gemm_sp()) ||
op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store())) {
under_alignment_scope_ = true;
StmtExprVisitor::VisitExpr_(op);
under_alignment_scope_ = false;
} else {
StmtExprVisitor::VisitExpr_(op);
}
}
void VisitExpr_(const VarNode *op) {
auto ptr_type = op->type_annotation.as<PointerTypeNode>();
if (ptr_type && under_alignment_scope_) {
auto scope = GetPtrStorageScope(GetRef<Var>(op));
if (scope == "shared" || scope == "shared.dyn") {
auto target = Target::Current();
ICHECK(target.defined()) << "Target is not defined";
const int alignment = TargetIsHopper(target) ? 1024 : 16;
shmem_alignment_map_[op] = alignment;
}
}
StmtExprVisitor::VisitExpr_(op);
}
bool under_alignment_scope_{false};
std::unordered_map<const VarNode *, int> shmem_alignment_map_;
};
/*!
* \brief merge the buffers whose live range has no intersection and rewrite the
* body
@@ -342,6 +384,7 @@ public:
SharedMemLinearAccessPatternFinder finder(is_dynamic,
enable_aggressive_merge, verbose);
finder(stmt);
shmem_alignment_map_ = SharedMemoryAlignmentPlanner::Plan(stmt);
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
}
@@ -359,6 +402,14 @@ private:
for (const StorageEntry *e : sym_free_list_) {
all_entry.push_back(e);
}
// Sort the storage entries in descending order of their total allocation
// size (in bits). This ensures that larger allocations are placed first,
// which can help minimize fragmentation and improve memory packing
// efficiency when merging shared memory buffers.
std::sort(all_entry.begin(), all_entry.end(),
[](const StorageEntry *a, const StorageEntry *b) {
return a->const_nbits > b->const_nbits;
});
for (const StorageEntry *e : all_entry) {
max_layer_num =
std::max(max_layer_num, static_cast<int>(e->allocs.size()));
@@ -375,18 +426,28 @@ private:
}
}
}
// Calculate the offset of each buffer based on the alignment of each layer.
for (const StorageEntry *e : all_entry) {
PrimExpr max_inner_offset = 0;
for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
PrimExpr inner_offset = 0;
for (const VarNode *buffer : e->allocs[i]) {
const AllocateNode *alloc = shmem_allocs_[buffer];
buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
inner_offset +=
auto alignment = align[i];
// Modern NVIDIA architectures that perform hardware swizzling (Hopper
// wgmma/TMA, for example) require dynamic shared memory addresses to be
// aligned to 1024 bytes; for other devices, we align to 16 bytes.
if (shmem_alignment_map_.find(buffer) !=
shmem_alignment_map_.end()) {
alignment = std::max(align[i], shmem_alignment_map_[buffer]);
}
PrimExpr start_offset = merged_alloc_size_ + inner_offset;
PrimExpr aligned_offset =
indexdiv(start_offset + alignment - 1, alignment) * alignment;
buffer_byte_offsets_[buffer] = aligned_offset;
inner_offset =
aligned_offset - merged_alloc_size_ +
alloc->extents[0] * alloc->dtype.bytes() * alloc->dtype.lanes();
inner_offset +=
indexmod(align[i] - indexmod(inner_offset, align[i]), align[i]);
}
max_inner_offset = max(max_inner_offset, inner_offset);
}
@@ -576,6 +637,18 @@ private:
std::vector<const VarNode *> kill;
};
void PlanAlignment(const Stmt &stmt) {
LOG(INFO) << "PlanAlignment";
PostOrderVisit(stmt, [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(tl::tl_gemm()) ||
call->op.same_as(tl::tl_gemm_sp())) {
LOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
<< call->op;
}
}
});
}
/*!
* \brief Liveness analysis to find gen and kill point of each variable.
* \param seq the linear pattern of storage access
@@ -1004,6 +1077,8 @@ private:
std::unordered_map<const VarNode *, StorageEntry *> alloc_map_;
/*! \brief allocator of all the StorageEntry*/
support::Arena arena_;
// The mapping from buffer variable to its required byte alignment
std::unordered_map<const VarNode *, int> shmem_alignment_map_;
};
Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem,
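The planner's offset math above is the standard align-up idiom, `aligned = indexdiv(start + alignment - 1, alignment) * alignment`. A small self-contained sketch of the same arithmetic with the two alignments used here (16 bytes by default, 1024 bytes for buffers feeding Hopper wgmma/TMA):

```cpp
// Standalone sketch of the align-up arithmetic used by the planner above.
#include <cassert>
#include <cstdint>

// Mirrors indexdiv(start_offset + alignment - 1, alignment) * alignment.
static int64_t AlignUp(int64_t start_offset, int64_t alignment) {
  return (start_offset + alignment - 1) / alignment * alignment;
}

int main() {
  // Default 16-byte alignment for shared memory buffers.
  assert(AlignUp(40, 16) == 48);
  // 1024-byte alignment for buffers touched by tl_gemm / TMA on Hopper.
  assert(AlignUp(1040, 1024) == 2048);
  // Offsets that are already aligned are left unchanged.
  assert(AlignUp(2048, 1024) == 2048);
  return 0;
}
```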
@@ -164,12 +164,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# MergeSharedMemoryAllocations must be applied after SplitHostDevice
# because the merged allocation site is at the beginning of each device function
enable_aggressive_merge = should_enable_aggressive_merge(pass_ctx=pass_ctx, target=target)
# Hopper Swizzling requires dynamic shared memory address to be aligned to 1024 bytes
# For other devices, we align to 16 bytes
smem_align_bytes = 1024 if have_tma(target) else 16
# Workaround: wait for an element-wise synchronization pass
mod = tilelang.transform.MergeSharedMemoryAllocations(
enable_aggressive_merge=enable_aggressive_merge, align_bytes=smem_align_bytes)(
enable_aggressive_merge=enable_aggressive_merge)(
mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)