"docs/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "b49d6d0fee3cf83d72ed658bd9f514bd87fcaa56"
Commit 8df45c9d authored by Lei Wang, committed by LeiWang1999

[Bugfix] Avoid duplicate data access when a cross-thread buffer meets a replicated register (#606)

* [Enhancement] Improve debug output formatting in layout and fragment nodes

- Updated the `DebugOutput` methods in `LayoutNode` and `FragmentNode` to provide more structured and informative output, including transformation details and thread range information.
- Enhanced layout inference logic in `ParallelOp` to add predicates for cross-thread shared memory access, improving layout handling in parallel operations.
- Minor adjustment in `layout_inference.cc` to ensure clarity in parallel loop handling.

* lint fix
parent a8811d9b
@@ -370,19 +370,22 @@ Fragment FragmentNode::CondenseReplicateVar() const {

 std::string LayoutNode::DebugOutput() const {
   std::stringstream ss;
-  ss << "Layout Shape: " << InputShape() << " -> " << OutputShape() << " -> "
-     << GetForwardIndex();
+  ss << "Layout(" << InputShape() << " -> " << OutputShape()
+     << ", transform: " << GetForwardVars() << " -> " << GetForwardIndex()
+     << ")";
   return ss.str();
 }

 std::string FragmentNode::DebugOutput() const {
   std::stringstream ss;
-  ss << "Fragment Shape: " << InputShape() << " -> " << OutputShape();
-  ss << " -> replicate: " << ReplicateExtent();
-  ss << " -> thread: " << ThreadExtent();
-  ss << " -> forward_thread: " << forward_thread_;
-  ss << " -> forward_index: " << GetForwardIndex();
-  ss << " -> thread_range: " << thread_range_;
+  ss << "Fragment(" << InputShape() << " -> " << OutputShape()
+     << ", replicate: " << ReplicateExtent() << ", thread: " << ThreadExtent()
+     << ", forward_thread: " << forward_thread_
+     << ", forward_index: " << GetForwardIndex();
+  if (thread_range_.defined()) {
+    ss << ", thread_range: " << thread_range_;
+  }
+  ss << ")";
   return ss.str();
 }
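For context on what the new format buys: the old dump chained every field with `->` and printed `thread_range` even when it was undefined, while the new form is a single parenthesized record with `thread_range` as an optional field. A standalone sketch of the difference (the shapes and extents are invented for illustration; this is a toy stand-in, not the real `FragmentNode`):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Toy stand-in for FragmentNode::DebugOutput(): one parenthesized
    // record, with thread_range appended only when it is defined.
    std::string DebugOutput(bool thread_range_defined) {
      std::stringstream ss;
      ss << "Fragment([16, 16] -> [8]"   // invented shapes for the demo
         << ", replicate: 2, thread: 32";
      if (thread_range_defined)          // mirrors thread_range_.defined()
        ss << ", thread_range: [0, 32)";
      ss << ")";
      return ss.str();
    }

    int main() {
      std::cout << DebugOutput(false) << "\n";  // no dangling undefined field
      std::cout << DebugOutput(true) << "\n";   // field shown only when set
    }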
@@ -174,16 +174,38 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
     // // Loop don't need to be replicated.
     // if (!is_one(loop_layout_->ReplicateExtent()))
     //   loop_layout_ = loop_layout_->DeReplicate();
-    // // if still has replication, add a condition
-    // if (!is_one(loop_layout_->ReplicateExtent())) {
-    //   auto inv = loop_layout_->Inverse();
-    //   Array<PrimExpr> fwd;
-    //   for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
-    //     fwd.push_back(0);
-    //   fwd.push_back(InputPlaceholder(0));
-    //   auto rep = inv->Forward(fwd).back();
-    //   AddPredicate(EQ(rep, 0));
-    // }
+    // For free layout inference
+    // If replication exists and buffer has cross-thread shared memory access,
+    // add predicate
+    bool has_cross_thread_access = false;
+    PostOrderVisit(root_, [&](const ObjectRef &obj) {
+      if (const auto *store = obj.as<BufferStoreNode>()) {
+        // check if scope is shared or global
+        if (store->buffer.scope() == "shared" ||
+            store->buffer.scope() == "shared.dyn" ||
+            store->buffer.scope() == "global") {
+          has_cross_thread_access = true;
+        }
+      } else if (const auto *load = obj.as<BufferLoadNode>()) {
+        // check if scope is shared or global
+        if (load->buffer.scope() == "shared" ||
+            load->buffer.scope() == "shared.dyn" ||
+            load->buffer.scope() == "global") {
+          has_cross_thread_access = true;
+        }
+      }
+    });
+    if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access) {
+      auto inv = loop_layout_->Inverse();
+      Array<PrimExpr> fwd;
+      for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
+        fwd.push_back(0);
+      fwd.push_back(InputPlaceholder(0));
+      auto rep = inv->Forward(fwd).back();
+      AddPredicate(EQ(rep, 0));
+    }
   } else {
     // Vectorize Size must be aware of the buffer_remap
     // As the pass will do post processing to the layout
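The predicate matters because a replicated loop layout makes several threads hold the same register element; when the loop body also stores to shared or global memory, every replica issues the same store, duplicating the access. A standalone toy model of the fix (thread count and replicate factor invented for illustration; in the pass the replica index comes from the inverse layout, via `inv->Forward(fwd).back()` above):

    #include <cstdio>

    // 32 simulated "threads", each element replicated across 2 of them.
    // Predicating on replica id 0 leaves exactly one writer per element,
    // which is what AddPredicate(EQ(rep, 0)) achieves in the pass.
    int main() {
      constexpr int kThreads = 32, kRep = 2, kElems = kThreads / kRep;
      int shared_buf[kElems];
      int stores = 0;
      for (int tid = 0; tid < kThreads; ++tid) {  // simulate the thread grid
        int rep_id = tid / kElems;  // which replica this thread is
        int elem = tid % kElems;    // which element this thread holds
        if (rep_id == 0) {          // the added predicate: EQ(rep, 0)
          shared_buf[elem] = elem;  // one store per element, not kRep
          ++stores;
        }
      }
      std::printf("stores: %d, elements: %d\n", stores, kElems);
      return 0;
    }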
@@ -598,6 +598,7 @@ private:
   auto loop_layout = result_.for_map[root];
   bool parallel_loop = !is_register_store && !skip_thread_partition_;
   if (parallel_loop) {
-    for_node = PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);
+    for_node =
+        PartitionLoop(for_node, thread_var_->var, analyzer_, loop_layout);