Unverified commit c382dcbc, authored by Lei Wang and committed by GitHub

[Layout] Introduce Flexible Parallel to Support T.serial and local buffers inside T.Parallel loop (#844)

* Support T.serial and local buffers inside T.Parallel loop.

* Fix reducer layout in T.Parallel nested inside other loops

* Debug output with LOG(INFO)

* Add disable option for WGMMA.

* fix

* Use DLOG; fix missing registration for new pass config

* bug fix

* lint fix

* Enhance GEMM instruction set with UTCMMA and improve local buffer handling in casting example

* Update format.sh shebang, improve logging in layout inference, and enhance buffer store wrapper with detailed comments

* Enhance GEMM instantiation logic and improve layout inference for local buffer detection

- Updated the GEMM instantiation logic to include a check for WGMMA compatibility, ensuring that the conditions for using WGMMA are more robust.
- Refined the layout inference process to better identify when loops manipulate only local buffers, improving the accuracy of thread binding decisions in parallel loops.

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
parent bf67fb19
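The headline change allows a T.serial loop and register-local buffers to appear inside a T.Parallel loop without breaking layout inference. The kernel below is a minimal sketch of that pattern with invented names and shapes (row_sum, A_frag, acc); it is not taken from the repository, and the exact T.Tensor and T.Kernel signatures should be checked against your TileLang version.

import tilelang
import tilelang.language as T

@T.prim_func
def row_sum(A: T.Tensor((128, 64), "float32"), B: T.Tensor((128,), "float32")):
    # One thread block with 128 threads; shapes are illustrative only.
    with T.Kernel(1, threads=128) as bx:
        A_frag = T.alloc_fragment((128, 64), "float32")
        acc = T.alloc_fragment((128,), "float32")
        T.copy(A, A_frag)
        for i in T.Parallel(128):
            acc[i] = 0.0
            # Serial reduction nested inside T.Parallel: the case this commit
            # teaches layout inference and loop partitioning to handle.
            for k in T.serial(64):
                acc[i] += A_frag[i, k]
        T.copy(acc, B)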
@@ -42,6 +42,7 @@ Checks: >
   -cppcoreguidelines-pro-type-static-cast-downcast,
   -performance-unnecessary-value-param,
   -performance-enum-size,
+  -clang-analyzer-deadcode.DeadStores,
 WarningsAsErrors: '*'
@@ -29,7 +29,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
             y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
             y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
             y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
-            row_offset = T.alloc_local((1,), "int32")
+            row_offset = T.alloc_fragment((1,), "int32")
             T.annotate_layout({
                 y_local:
@@ -29,6 +29,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
 TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);

 DataType cuTensorMapType() { return DataType::UInt(8, 128); }
@@ -45,6 +45,7 @@ static constexpr const char *kPtxasRegisterUsageLevel =
     "tl.ptxas_register_usage_level";
 static constexpr const char *kEnablePTXASVerboseOutput =
     "tl.enable_ptxas_verbose_output";
+static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
 static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
 /*!
  * \brief Whether to disable dynamic tail split
@@ -92,10 +92,14 @@ TileOperator GemmNode::Clone() const {
 }

 GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
+  tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
   int warp_size = TargetGetWarpSize(target);
   int num_warps = block_size / warp_size;
-  bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
-                     (num_warps % 4 == 0) && CheckWGMMA();
+  bool allow_wgmma =
+      !ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
+      TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
+      CheckWGMMA();
   if (allow_wgmma) {
     return GemmInst::kWGMMA;
   } else if (TargetIsCDNA(target)) {
@@ -128,9 +128,13 @@ private:
  * visitor's reducer_info_map_. Continues traversal into the loop body.
  */
 void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
-  ICHECK(op->kind == ForKind::kParallel);
-  p->loop_vars_.push_back(
-      IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
+  if (op->kind == ForKind::kParallel)
+    p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
+                                    IterVarType::kDataPar));
+  else
+    p->inner_vars_.Set(op->loop_var,
+                       IterVar(Range(op->min, op->extent), op->loop_var,
+                               IterVarType::kOrdered));
   p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
   auto reducer_info_map =
       op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
@@ -244,17 +248,33 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   }
   auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
     Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
+    DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
+               << buffer << "` of layout " << src_layout->DebugOutput() << '\n';
+    Fragment result;
     if (IsCommonAccessIndice(buffer)) {
-      return src_layout;
+      result = src_layout;
     } else {
       Var rep;
       auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                               IterVarType::kDataPar);
       PrimExpr loop_var_to_thread =
           src_layout->ForwardThread(indice_map_[buffer], rep);
-      return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
-          ->BindThreadRange(T.thread_bounds);
+      loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
+      PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
+        if (auto opt_var = objref.as<Var>();
+            opt_var && inner_vars_.count(*opt_var)) {
+          std::ostringstream oss;
+          oss << "loop_var_to_thread = " << loop_var_to_thread
+              << "contains inner var" << *opt_var;
+          throw LayoutConflictException(oss.str());
+        }
+      });
+      result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
+                   ->BindThreadRange(T.thread_bounds);
     }
+    DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
+               << result->DebugOutput() << '\n';
+    return result;
   };
   if (source_buffer.defined()) {
     loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
@@ -317,15 +337,21 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
     int vector_size = GetVectorizeSize(maybe_remapped_root_);
+    DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
     PrimExpr loop_total_size = 1;
     for (Stmt l = root_; l.as<For>().has_value();
          l = l.as<For>().value()->body)
       loop_total_size = loop_total_size * l.as<For>().value()->extent;
+    DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size
+               << '\n';
     while (!analyzer_.CanProve(
                floormod(loop_total_size,
                         T.thread_bounds->extent * vector_size) == 0) &&
            vector_size > 1)
       vector_size /= 2;
+    DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = "
+               << vector_size << '\n';
     // Check if coalesced_width is defined
     if (auto coalesced_width =
@@ -342,7 +368,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
         LOG(FATAL) << "coalesced_width should be an IntImmNode.";
       }
     }
+    DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_
+               << " ############# vector_size = " << vector_size
+               << ", thread_bounds = " << T.thread_bounds << '\n';
     loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
+    DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
+               << loop_layout_->DebugOutput() << '\n';
   }
   } else {
     return {};
@@ -128,6 +128,7 @@ private:
   void AddPredicate(const PrimExpr &expr) const {
     predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
   }
+
   // Allow ParallelLoopNestVisitor to access private members.
   friend class ParallelLoopNestVisitor;
@@ -139,6 +140,8 @@ private:
   std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
   // The loop variables for the parallel loop nest.
   Array<IterVar> loop_vars_;
+  // Loop variables of non-parallel (e.g. serial) loops nested inside the parallel loop nest.
+  Map<Var, IterVar> inner_vars_;
   // Analyzer for simplifying and analyzing expressions, mutable for lazy use.
   mutable arith::Analyzer analyzer_;
   // Mapping from buffer to reducer info.
@@ -105,13 +105,16 @@ public:
            "required for layout inference.";
     // Run InferLayout
+    DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
     auto updates =
         next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
                                           &analyzer_, buffer_oob},
                           level);
     // Process the returned updates
     for (const auto &[buffer, layout] : updates) {
+      DLOG(INFO) << " consider update " << buffer << " as "
+                 << layout->DebugOutput() << '\n';
       // Basic validity checks
       ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
       ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
...@@ -140,6 +143,8 @@ public: ...@@ -140,6 +143,8 @@ public:
if (ProveFragmentContains(src_layout, dst_layout, indices, indices, if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
inner_analyzer)) { inner_analyzer)) {
layout_map.Set(buffer, layout); layout_map.Set(buffer, layout);
DLOG(INFO) << " layout broadcast from "
<< src_layout->DebugOutput() << ", accepted" << '\n';
continue; continue;
} }
} }
@@ -151,6 +156,7 @@ public:
       } else {
         // Otherwise, update map
         layout_map.Set(buffer, layout);
+        DLOG(INFO) << " new layout accepted" << '\n';
         if (!update_queue)
           continue;
@@ -210,6 +216,11 @@ public:
         << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
            "length.";
+    DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
+    for (int i = 0; i < infer_list_stmt_.size(); ++i) {
+      DLOG(INFO) << " op " << i << ":" << infer_list_stmt_[i] << '\n';
+    }
+
     // If needed, you can also check that annotated_layout_map_ is not empty, or
     // anything else relevant to your setup.
@@ -470,6 +481,13 @@ private:
   void InferInFreeMode(LayoutMap &layout_map,
                        const LayoutMap &strict_layout_map) {
+    DLOG(INFO) << "Enforced layout maps:" << '\n';
+    for (auto &&[k, v] : layout_map) {
+      DLOG(INFO) << " " << k << ": " << v->DebugOutput() << '\n';
+    }
+    DLOG(INFO) << '\n';
+
     // Group operators into connected components
     UnionFind<int> uf;
     for (int i = 0; i < infer_list_.size(); i++) {
@@ -505,52 +523,53 @@ private:
     std::vector<bool> in_queue(infer_list_.size(), false);
     for (auto &&[root, members] : components) {
+      DLOG(INFO) << "======================= processing component " << root
+                 << '\n';
       decltype(infer_list_) best_infer_list;
       LayoutMap best_layout_map;
       int64_t min_reg_num = INT64_MAX;
+      int min_reg_num_infer_root = -1;
+      // Try each member as the root of inference for this component
       for (int attempt_infer_root : members) {
-        // backup infer_list_ in class member
+        DLOG(INFO) << "----------------------- try root " << attempt_infer_root
+                   << '\n';
+        // Backup the current infer_list_ state
         auto back_infer_list = BackupInferList();
-        // create temporarily used layout_map, new handle so that it copies on
-        // write
+        // Copy the current layout_map for temporary use
         LayoutMap tmp_layout_map = layout_map;
-        // infer from attempt_infer_root in free mode
        bool do_update = true;
        try {
+          // Run inference starting from attempt_infer_root
          RunInferStep(attempt_infer_root, InferLevel::kFree, true,
                       tmp_layout_map, strict_layout_map, q, in_queue);
          FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
                           q, in_queue);
-          // Silly workaround: we have no clue if single root will iterate over
-          // the entire component, since the InferLayout implementations have
-          // complicated conditioning inside and we know nothing about it.
-          // This would constantly result in incomplete layouts for buffers in
-          // this component. Instead of trying all combinations of root
-          // selection order, we simply go through all other loops in order
-          // after the first search from attempt_infer_root.
+          // After the first search, run inference for all other members in
+          // order
          for (int other_infer_root : members) {
            if (other_infer_root != attempt_infer_root) {
              RunInferStep(other_infer_root, InferLevel::kFree, true,
                           tmp_layout_map, strict_layout_map, q, in_queue);
-              // must also be kFree here to avoid conflicts.
              FinishInferQueue(InferLevel::kFree, tmp_layout_map,
                               strict_layout_map, q, in_queue);
            }
          }
-        } catch (LayoutConflictException e) {
-          // such an order fails, try others
-          do_update = false;
-        } catch (NormalizeIterException e) {
-          // such an order encounters iterators that is not normalizable, try
-          // others e.g. i * 576 % 2048
-          do_update = false;
+        } catch (const LayoutConflictException &e) {
+          do_update = false;
+          DLOG(INFO) << "attempt failed due to LayoutConflictException "
+                     << e.what() << '\n';
+        } catch (const NormalizeIterException &e) {
+          do_update = false;
+          DLOG(INFO) << "attempt failed due to NormalizeIterException "
+                     << e.what() << '\n';
        }
        if (do_update) {
-          // compute total register number
+          // Compute the total register number for this layout
          int64_t reg_num = 0;
-          for (auto &&[buffer, layout] : tmp_layout_map) {
+          for (const auto &[buffer, layout] : tmp_layout_map) {
            if (auto frag = layout.as<Fragment>()) {
              int64_t frag_reg_num = 1;
              for (auto i : frag.value()->OutputShape()) {
@@ -561,21 +580,24 @@ private:
                reg_num += frag_reg_num;
              }
            }
-          // if it's any better, update the best_* storage
+          // Update the best plan if this one uses fewer registers
          if (reg_num < min_reg_num) {
-            best_infer_list = std::move(infer_list_);
+            best_infer_list =
+                BackupInferList(); // Use backup to avoid moving out infer_list_
            best_layout_map = tmp_layout_map;
            min_reg_num = reg_num;
+            min_reg_num_infer_root = attempt_infer_root;
          }
        }
-        // recover stateful infer_list_, head on next
+        // Restore infer_list_ state for the next attempt
        infer_list_ = std::move(back_infer_list);
      }
-      if (min_reg_num < INT64_MAX) {
-        // now apply the best plan for this component
+      ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
+      // Apply the best plan for this component
      infer_list_ = std::move(best_infer_list);
      layout_map = best_layout_map;
-      }
+      DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
+                 << min_reg_num_infer_root << '\n';
    }
  }
};
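The component loop above can be read as a small search: every member of a connected component is tried as the inference root, a failed attempt is rolled back from the saved infer_list_ backup, and the candidate layout map that needs the fewest registers per thread is kept (now enforced by the ICHECK instead of being silently skipped). The Python below is only a schematic restatement of that control flow; run_infer_from and count_regs are hypothetical stand-ins, not TileLang APIs.

def infer_component(members, layout_map, run_infer_from, count_regs):
    best_map, best_regs, best_root = None, float("inf"), None
    for root in members:
        candidate = dict(layout_map)  # work on a copy so failed attempts are discarded
        try:
            run_infer_from(root, candidate)  # may raise on layout conflicts
        except Exception:
            continue  # this root ordering fails, try the next one
        regs = count_regs(candidate)
        if regs < best_regs:
            best_map, best_regs, best_root = candidate, regs, root
    if best_map is None:
        raise RuntimeError("no available layout found")
    return best_map, best_root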
@@ -682,20 +704,25 @@ private:
     // Here, A_local is a register-local buffer held independently by each
     // thread, so explicit thread binding is not required.
     //
-    // We use PostOrderVisit to detect whether the buffer store targets a
-    // "local" buffer, which indicates register usage and justifies skipping
+    // We use PostOrderVisit to detect whether the loop only manipulates
+    // "local" buffers, which indicates register usage and justifies skipping
     // thread binding.
-    bool is_register_store = false;
+    bool local_register_only = true;
     PostOrderVisit(root, [&](const ObjectRef &obj) {
       if (const auto *store = obj.as<BufferStoreNode>()) {
-        if (store->buffer.scope() == "local") {
-          is_register_store = true;
+        if (store->buffer.scope() != "local") {
+          local_register_only = false;
+        }
+      } else if (const auto *load = obj.as<BufferLoadNode>()) {
+        if (load->buffer.scope() != "local") {
+          local_register_only = false;
         }
       }
     });
     auto loop_layout = result_.for_map[root];
-    bool parallel_loop = !is_register_store && !skip_thread_partition_;
+    // FIXME: tell in-Parallel and out-of-Parallel `local`s apart
+    bool parallel_loop = !skip_thread_partition_ && !local_register_only;
     if (parallel_loop) {
       for_node =
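For reference, the scope check above corresponds roughly to the following traversal written against TVM's public Python TIR API; it is an approximate analogue for illustration, not the code added by this commit, and it ignores the in-Parallel versus out-of-Parallel distinction flagged in the FIXME.

from tvm import tir

def loop_uses_only_local_buffers(stmt: tir.Stmt) -> bool:
    """Return True when every BufferStore/BufferLoad under stmt targets a "local" buffer."""
    only_local = True

    def visit(node):
        nonlocal only_local
        if isinstance(node, (tir.BufferStore, tir.BufferLoad)):
            if node.buffer.scope() != "local":
                only_local = False

    tir.stmt_functor.post_order_visit(stmt, visit)
    return only_local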
@@ -178,7 +178,8 @@ private:
   Stmt VisitStmt_(const ForNode *op) final {
     // only annotate the outermost loop
     bool should_annotate = false;
-    if (!inside_reducer_range_.empty() && !already_annotated_) {
+    if (!inside_reducer_range_.empty() && !already_annotated_ &&
+        op->kind == ForKind::kParallel) {
       should_annotate = true;
       already_annotated_ = true;
     }
@@ -639,13 +639,13 @@ private:
 };

 void PlanAlignment(const Stmt &stmt) {
-  LOG(INFO) << "PlanAlignment";
+  DLOG(INFO) << "PlanAlignment";
   PostOrderVisit(stmt, [&](const ObjectRef &node) {
     if (const auto *call = node.as<CallNode>()) {
       if (call->op.same_as(tl::tl_gemm()) ||
           call->op.same_as(tl::tl_gemm_sp())) {
-        LOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
-                  << call->op;
+        DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
+                   << call->op;
       }
     }
   });
@@ -1789,8 +1789,8 @@ public:
     PrimExpr last_extent = extents[extents.size() - 1];
     extents.Set(extents.size() - 1,
                 last_extent / make_const(last_extent.dtype(), info.factor()));
-    LOG(INFO) << "Allocate with " << new_buffer_var << " and "
-              << info.new_element_dtype << " extents: " << extents;
+    DLOG(INFO) << "Allocate with " << new_buffer_var << " and "
+               << info.new_element_dtype << " extents: " << extents;
     return Allocate(new_buffer_var, info.new_element_dtype, extents,
                     op->condition, op->body);
   }
@@ -45,6 +45,9 @@ class PassConfigKey(str, Enum):
     TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize"
     """Disable safe memory access optimization. Default: False"""

+    TL_DISABLE_WGMMA = "tl.disable_wgmma"
+    """Disable usage of Hopper WGMMA. Default: False"""
+
     TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
     """Enable debug information for merge shared memory allocations. Default: False"""
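A usage sketch for the new option: "tl.disable_wgmma" is registered as an ordinary TVM pass-config key and is read through PassContext::Current() in GemmNode::GetGemmInst, so setting it on the surrounding PassContext is the intended switch. Whether your TileLang entry point also accepts a pass_configs keyword directly, and the exact compile or lower call, are assumptions to verify against your version; the snippet only demonstrates the config plumbing.

import tvm
import tilelang  # loading TileLang registers the "tl.*" pass-config keys

with tvm.transform.PassContext(config={"tl.disable_wgmma": True}):
    # Inside this scope the option is visible to any lowering that runs under
    # the current PassContext.
    cfg = tvm.transform.PassContext.current().config
    print(cfg["tl.disable_wgmma"])
    # ... invoke TileLang's compile/lower entry point here; passing
    # pass_configs={"tl.disable_wgmma": True} to it instead is assumed to be
    # equivalent but should be checked against your TileLang version.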