Unverified Commit c382dcbc authored by Lei Wang, committed by GitHub

[Layout] Introduce Flexible Parallel to Support T.serial and local buffers inside T.Parallel loop (#844)

* Support T.serial and local buffers inside T.Parallel loop.

* Fix reducer layout in T.Parallel nested inside other loops

* Debug output with LOG(INFO)

* Add disable option for WGMMA.

* fix

* Use DLOG; fix missing registration for new pass config

* bug fix

* lint fix

* Enhance GEMM instruction set with UTCMMA and improve local buffer handling in casting example

* Update format.sh shebang, improve logging in layout inference, and enhance buffer store wrapper with detailed comments

* Enhance GEMM instantiation logic and improve layout inference for local buffer detection

- Updated the GEMM instantiation logic to include a check for WGMMA compatibility, ensuring that the conditions for using WGMMA are more robust.
- Refined the layout inference process to better identify when loops manipulate only local buffers, improving the accuracy of thread binding decisions in parallel loops.
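
As an illustration of the newly supported pattern, here is a minimal hedged sketch (not part of this commit: the kernel name, shapes, and dtypes are hypothetical, assuming the usual tilelang.language API) of a T.serial loop and a per-thread local buffer nested inside a T.Parallel loop:

    import tilelang.language as T

    # Hypothetical shapes; only the loop structure matters here.
    M, N, BLK = 1024, 256, 128

    @T.prim_func
    def row_sum(A: T.Tensor((M, N), "float32"), out: T.Tensor((M,), "float32")):
        with T.Kernel(T.ceildiv(M, BLK), threads=128) as bx:
            A_frag = T.alloc_fragment((BLK, N), "float32")
            acc = T.alloc_local((1,), "float32")  # per-thread register buffer
            T.copy(A[bx * BLK, 0], A_frag)
            for i in T.Parallel(BLK):  # thread-parallel loop
                acc[0] = 0.0
                # serial loop and local-buffer updates inside T.Parallel,
                # the pattern named in this PR's title
                for j in T.serial(N):
                    acc[0] = acc[0] + A_frag[i, j]
                out[bx * BLK + i] = acc[0]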

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
parent bf67fb19
......@@ -42,6 +42,7 @@ Checks: >
-cppcoreguidelines-pro-type-static-cast-downcast,
-performance-unnecessary-value-param,
-performance-enum-size,
-clang-analyzer-deadcode.DeadStores,
WarningsAsErrors: '*'
......
......@@ -29,7 +29,7 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "float8_e4m3")
row_offset = T.alloc_local((1,), "int32")
row_offset = T.alloc_fragment((1,), "int32")
T.annotate_layout({
y_local:
......
......@@ -29,6 +29,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION(kDisableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnableFastMath, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kPtxasRegisterUsageLevel, Integer);
TVM_REGISTER_PASS_CONFIG_OPTION(kEnablePTXASVerboseOutput, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableWGMMA, Bool);
TVM_REGISTER_PASS_CONFIG_OPTION(kDisableShuffleElect, Bool);
DataType cuTensorMapType() { return DataType::UInt(8, 128); }
......
......@@ -45,6 +45,7 @@ static constexpr const char *kPtxasRegisterUsageLevel =
"tl.ptxas_register_usage_level";
static constexpr const char *kEnablePTXASVerboseOutput =
"tl.enable_ptxas_verbose_output";
static constexpr const char *kDisableWGMMA = "tl.disable_wgmma";
static constexpr const char *kDisableShuffleElect = "tl.disable_shuffle_elect";
/*!
* \brief Whether to disable dynamic tail split
......
......@@ -92,10 +92,14 @@ TileOperator GemmNode::Clone() const {
}
GemmNode::GemmInst GemmNode::GetGemmInst(int block_size, Target target) const {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
int warp_size = TargetGetWarpSize(target);
int num_warps = block_size / warp_size;
bool allow_wgmma = TargetIsHopper(target) && (this->M >= 64) &&
(num_warps % 4 == 0) && CheckWGMMA();
bool allow_wgmma =
!ctxt->GetConfig(kDisableWGMMA, Optional<Bool>()).value_or(false) &&
TargetIsHopper(target) && (this->M >= 64) && (num_warps % 4 == 0) &&
CheckWGMMA();
if (allow_wgmma) {
return GemmInst::kWGMMA;
} else if (TargetIsCDNA(target)) {
......
......@@ -128,9 +128,13 @@ private:
* visitor's reducer_info_map_. Continues traversal into the loop body.
*/
void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) {
ICHECK(op->kind == ForKind::kParallel);
p->loop_vars_.push_back(
IterVar(Range(op->min, op->extent), op->loop_var, IterVarType::kDataPar));
if (op->kind == ForKind::kParallel)
p->loop_vars_.push_back(IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kDataPar));
else
p->inner_vars_.Set(op->loop_var,
IterVar(Range(op->min, op->extent), op->loop_var,
IterVarType::kOrdered));
p->analyzer_.Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
auto reducer_info_map =
op->annotations.Get(attr::kReducerInfo)->as<Map<Var, ReducerInfo>>();
......@@ -244,17 +248,33 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
}
auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) {
Fragment src_layout = T.layout_map[buffer].as<Fragment>().value();
DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `"
<< buffer << "` of layout " << src_layout->DebugOutput() << '\n';
Fragment result;
if (IsCommonAccessIndice(buffer)) {
return src_layout;
result = src_layout;
} else {
Var rep;
auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
IterVarType::kDataPar);
PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep);
return Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
if (auto opt_var = objref.as<Var>();
opt_var && inner_vars_.count(*opt_var)) {
std::ostringstream oss;
oss << "loop_var_to_thread = " << loop_var_to_thread
<< "contains inner var" << *opt_var;
throw LayoutConflictException(oss.str());
}
});
result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
->BindThreadRange(T.thread_bounds);
}
DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
<< result->DebugOutput() << '\n';
return result;
};
if (source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
......@@ -317,15 +337,21 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);
DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';
PrimExpr loop_total_size = 1;
for (Stmt l = root_; l.as<For>().has_value();
l = l.as<For>().value()->body)
loop_total_size = loop_total_size * l.as<For>().value()->extent;
DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size
<< '\n';
while (!analyzer_.CanProve(
floormod(loop_total_size,
T.thread_bounds->extent * vector_size) == 0) &&
vector_size > 1)
vector_size /= 2;
DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = "
<< vector_size << '\n';
// Check if coalesced_width is defined
if (auto coalesced_width =
......@@ -342,7 +368,12 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
LOG(FATAL) << "coalesced_width should be an IntImmNode.";
}
}
DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_
<< " ############# vector_size = " << vector_size
<< ", thread_bounds = " << T.thread_bounds << '\n';
loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds);
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
<< loop_layout_->DebugOutput() << '\n';
}
} else {
return {};
......
......@@ -128,6 +128,7 @@ private:
void AddPredicate(const PrimExpr &expr) const {
predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
}
// Allow ParallelLoopNestVisitor to access private members.
friend class ParallelLoopNestVisitor;
......@@ -139,6 +140,8 @@ private:
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
// The loop variables for the parallel loop nest.
Array<IterVar> loop_vars_;
// Loop variables of non-parallel (e.g. serial) loops nested inside the parallel loop nest.
Map<Var, IterVar> inner_vars_;
// Analyzer for simplifying and analyzing expressions, mutable for lazy use.
mutable arith::Analyzer analyzer_;
// Mapping from buffer to reducer info.
......
......@@ -105,13 +105,16 @@ public:
"required for layout inference.";
// Run InferLayout
DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob},
level);
// Process the returned updates
for (const auto &[buffer, layout] : updates) {
DLOG(INFO) << " consider update " << buffer << " as "
<< layout->DebugOutput() << '\n';
// Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
......@@ -140,6 +143,8 @@ public:
if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
inner_analyzer)) {
layout_map.Set(buffer, layout);
DLOG(INFO) << " layout broadcast from "
<< src_layout->DebugOutput() << ", accepted" << '\n';
continue;
}
}
......@@ -151,6 +156,7 @@ public:
} else {
// Otherwise, update map
layout_map.Set(buffer, layout);
DLOG(INFO) << " new layout accepted" << '\n';
if (!update_queue)
continue;
......@@ -210,6 +216,11 @@ public:
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
"length.";
DLOG(INFO) << "[InferLayout] all participating operators:" << '\n';
for (int i = 0; i < infer_list_stmt_.size(); ++i) {
DLOG(INFO) << " op " << i << ":" << infer_list_stmt_[i] << '\n';
}
// If needed, you can also check that annotated_layout_map_ is not empty, or
// anything else relevant to your setup.
......@@ -470,6 +481,13 @@ private:
void InferInFreeMode(LayoutMap &layout_map,
const LayoutMap &strict_layout_map) {
DLOG(INFO) << "Enforced layout maps:" << '\n';
for (auto &&[k, v] : layout_map) {
DLOG(INFO) << " " << k << ": " << v->DebugOutput() << '\n';
}
DLOG(INFO) << '\n';
// Group operators into connected components
UnionFind<int> uf;
for (int i = 0; i < infer_list_.size(); i++) {
......@@ -505,52 +523,53 @@ private:
std::vector<bool> in_queue(infer_list_.size(), false);
for (auto &&[root, members] : components) {
DLOG(INFO) << "======================= processing component " << root
<< '\n';
decltype(infer_list_) best_infer_list;
LayoutMap best_layout_map;
int64_t min_reg_num = INT64_MAX;
int min_reg_num_infer_root = -1;
// Try each member as the root of inference for this component
for (int attempt_infer_root : members) {
// backup infer_list_ in class member
DLOG(INFO) << "----------------------- try root " << attempt_infer_root
<< '\n';
// Backup the current infer_list_ state
auto back_infer_list = BackupInferList();
// create temporarily used layout_map, new handle so that it copies on
// write
// Copy the current layout_map for temporary use
LayoutMap tmp_layout_map = layout_map;
// infer from attempt_infer_root in free mode
bool do_update = true;
try {
// Run inference starting from attempt_infer_root
RunInferStep(attempt_infer_root, InferLevel::kFree, true,
tmp_layout_map, strict_layout_map, q, in_queue);
FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
q, in_queue);
// Silly workaround: we have no clue if single root will iterate over
// the entire component, since the InferLayout implementations have
// complicated conditioning inside and we know nothing about it.
// This would constantly result in incomplete layouts for buffers in
// this component. Instead of trying all combinations of root
// selection order, we simply go through all other loops in order
// after the first search from attempt_infer_root.
// After the first search, run inference for all other members in
// order
for (int other_infer_root : members) {
if (other_infer_root != attempt_infer_root) {
RunInferStep(other_infer_root, InferLevel::kFree, true,
tmp_layout_map, strict_layout_map, q, in_queue);
// must also be kFree here to avoid conflicts.
FinishInferQueue(InferLevel::kFree, tmp_layout_map,
strict_layout_map, q, in_queue);
}
}
} catch (LayoutConflictException e) {
// such an order fails, try others
} catch (const LayoutConflictException &e) {
do_update = false;
} catch (NormalizeIterException e) {
// such an order encounters iterators that is not normalizable, try
// others e.g. i * 576 % 2048
DLOG(INFO) << "attempt failed due to LayoutConflictException "
<< e.what() << '\n';
} catch (const NormalizeIterException &e) {
do_update = false;
DLOG(INFO) << "attempt failed due to NormalizeIterException "
<< e.what() << '\n';
}
if (do_update) {
// compute total register number
// Compute the total register number for this layout
int64_t reg_num = 0;
for (auto &&[buffer, layout] : tmp_layout_map) {
for (const auto &[buffer, layout] : tmp_layout_map) {
if (auto frag = layout.as<Fragment>()) {
int64_t frag_reg_num = 1;
for (auto i : frag.value()->OutputShape()) {
......@@ -561,21 +580,24 @@ private:
reg_num += frag_reg_num;
}
}
// if it's any better, update the best_* storage
// Update the best plan if this one uses fewer registers
if (reg_num < min_reg_num) {
best_infer_list = std::move(infer_list_);
best_infer_list =
BackupInferList(); // Use backup to avoid moving out infer_list_
best_layout_map = tmp_layout_map;
min_reg_num = reg_num;
min_reg_num_infer_root = attempt_infer_root;
}
}
// recover stateful infer_list_, head on next
// Restore infer_list_ state for the next attempt
infer_list_ = std::move(back_infer_list);
}
if (min_reg_num < INT64_MAX) {
// now apply the best plan for this component
ICHECK(min_reg_num < INT64_MAX) << "no available layout found" << '\n';
// Apply the best plan for this component
infer_list_ = std::move(best_infer_list);
layout_map = best_layout_map;
}
DLOG(INFO) << "[InferInFreeMode] Final selection is attempt_infer_root = "
<< min_reg_num_infer_root << '\n';
}
}
};
......@@ -682,20 +704,25 @@ private:
// Here, A_local is a register-local buffer held independently by each
// thread, so explicit thread binding is not required.
//
// We use PostOrderVisit to detect whether the buffer store targets a
// "local" buffer, which indicates register usage and justifies skipping
// We use PostOrderVisit to detect whether the loop only manipulates
// "local" buffers, which indicates register usage and justifies skipping
// thread binding.
bool is_register_store = false;
bool local_register_only = true;
PostOrderVisit(root, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
if (store->buffer.scope() == "local") {
is_register_store = true;
if (store->buffer.scope() != "local") {
local_register_only = false;
}
} else if (const auto *load = obj.as<BufferLoadNode>()) {
if (load->buffer.scope() != "local") {
local_register_only = false;
}
}
});
auto loop_layout = result_.for_map[root];
bool parallel_loop = !is_register_store && !skip_thread_partition_;
// FIXME: tell in-Parallel and out-of-Parallel `local`s apart
bool parallel_loop = !skip_thread_partition_ && !local_register_only;
if (parallel_loop) {
for_node =
......
......@@ -178,7 +178,8 @@ private:
Stmt VisitStmt_(const ForNode *op) final {
// only annotate the outermost loop
bool should_annotate = false;
if (!inside_reducer_range_.empty() && !already_annotated_) {
if (!inside_reducer_range_.empty() && !already_annotated_ &&
op->kind == ForKind::kParallel) {
should_annotate = true;
already_annotated_ = true;
}
......
......@@ -639,12 +639,12 @@ private:
};
void PlanAlignment(const Stmt &stmt) {
LOG(INFO) << "PlanAlignment";
DLOG(INFO) << "PlanAlignment";
PostOrderVisit(stmt, [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(tl::tl_gemm()) ||
call->op.same_as(tl::tl_gemm_sp())) {
LOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
DLOG(INFO) << "PostOrderVisit CallNode tl_gemm and tl_gemm_sp: "
<< call->op;
}
}
......
......@@ -1789,7 +1789,7 @@ public:
PrimExpr last_extent = extents[extents.size() - 1];
extents.Set(extents.size() - 1,
last_extent / make_const(last_extent.dtype(), info.factor()));
LOG(INFO) << "Allocate with " << new_buffer_var << " and "
DLOG(INFO) << "Allocate with " << new_buffer_var << " and "
<< info.new_element_dtype << " extents: " << extents;
return Allocate(new_buffer_var, info.new_element_dtype, extents,
op->condition, op->body);
......
......@@ -45,6 +45,9 @@ class PassConfigKey(str, Enum):
TL_DISABLE_SAFE_MEMORY_ACCESS = "tl.disable_safe_memory_legalize"
"""Disable safe memory access optimization. Default: False"""
TL_DISABLE_WGMMA = "tl.disable_wgmma"
"""Disable usage of Hopper WGMMA. Default: False"""
TL_DEBUG_MERGE_SHARED_MEMORY_ALLOCATIONS = "tl.debug_merge_shared_memory_allocations"
"""Enable debug information for merge shared memory allocations. Default: False"""
......
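Usage note (an illustrative sketch, not part of this diff): the new tl.disable_wgmma key is a regular pass config, so, assuming tilelang.compile forwards a pass_configs mapping into the TVM PassContext, steering GetGemmInst away from WGMMA on Hopper might look like:

    import tilelang

    # Illustrative only: `matmul_kernel` stands for any T.prim_func that would
    # otherwise lower to a Hopper WGMMA GEMM.
    compiled = tilelang.compile(
        matmul_kernel,
        pass_configs={"tl.disable_wgmma": True},  # PassConfigKey.TL_DISABLE_WGMMA
    )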