".github/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6cef7d2366c05a72f6b1e034e9260636d1eccd8d"
Unverified Commit 67cc8611 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Add thread count validation for ReduceOp fragment layout inference (#1225)

* [Enhancement] Add thread count validation for ReduceOp fragment layout inference

* Introduced a check in ReduceOpNode layout inference that the thread count is divisible by the replicate extent. Instead of failing later with an opaque layout-inference error, the check raises a detailed message that guides users toward a compatible thread block size or fragment layout.
* Updated tests to remove unsupported configurations that could lead to layout inference errors, ensuring more robust testing scenarios.

* lint fix
parent eb6e8973
...@@ -389,6 +389,35 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -389,6 +389,35 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
} }
auto thd = src_layout->ForwardThread( auto thd = src_layout->ForwardThread(
fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent));
// Ensure the thread count is compatible with the replicate extent.
// Otherwise, we cannot infer a valid fragment<->fragment layout.
{
  arith::Analyzer analyzer;
  PrimExpr num_threads = T.thread_bounds->extent;
  // Although dest_buffer_rep_extent will later be compressed by
  // CondenseReplicateVar, check divisibility here so we fail early with a
  // clear diagnostic instead of producing an invalid layout.
  bool threads_divisible =
      analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) == 0);
  bool extent_divisible =
      analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) == 0);
  if (!threads_divisible && !extent_divisible) {
    // NOTE: report dest_buffer_rep_extent -- the value actually tested above
    // (the original message printed indice_rep_extent, which can differ).
    ICHECK(false) << "ReduceOp fragment layout inference failed: the thread "
                     "count and the replicate extent do not divide each "
                     "other. "
                  << "This mapping requires the block's thread count to be "
                     "divisible by the replicate extent (or vice versa). "
                  << "Try one of: (1) choose a thread block size divisible "
                     "by replicate_extent; "
                  << "(2) pick a different reduce dimension or adjust the "
                     "source fragment layout. "
                  << "Details: num_threads=" << num_threads
                  << ", replicate_extent=" << dest_buffer_rep_extent
                  << ", src=" << src << ", dst=" << dst;
  }
}
Fragment dst_layout = Fragment dst_layout =
Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt) Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
->CondenseReplicateVar() ->CondenseReplicateVar()
......
...@@ -116,7 +116,6 @@ def test_reduce_sum(): ...@@ -116,7 +116,6 @@ def test_reduce_sum():
def test_reduce_sum_shared(): def test_reduce_sum_shared():
run_reduce_sum(64, 64, mode="ss") run_reduce_sum(64, 64, mode="ss")
run_reduce_sum(32, 96, mode="ss")
def test_reduce_max(): def test_reduce_max():
...@@ -127,7 +126,6 @@ def test_reduce_max(): ...@@ -127,7 +126,6 @@ def test_reduce_max():
def test_reduce_max_shared(): def test_reduce_max_shared():
run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32") run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32")
run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 96, 48, "float32")
def test_reduce_min_shared(): def test_reduce_min_shared():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment