"vscode:/vscode.git/clone" did not exist on "89725f7fda100d34097d7abd38adda3a69f617d8"
Unverified Commit 27701c3d authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Parallel] Support `T.Parallel` with dynamic extents (#990)

* Allow dynamic extents in loop partition; warn when layout inversion falls back to NoCheck

* add test and introduce predicate

* test fix

* fix

* enhance

* inverse with level

* test fix

* bug fix
parent d66b83c9
...@@ -229,11 +229,34 @@ Fragment FragmentNode::BindThreadRange(Range thread_range) const { ...@@ -229,11 +229,34 @@ Fragment FragmentNode::BindThreadRange(Range thread_range) const {
return Fragment(n); return Fragment(n);
} }
Layout LayoutNode::Inverse() const { std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
arith::Analyzer analyzer; arith::Analyzer analyzer;
auto collect_symbolic = [&](const Array<PrimExpr> &shape) {
Array<PrimExpr> symbolic_dims;
for (const auto &dim : shape) {
if (!as_const_int(dim)) {
symbolic_dims.push_back(dim);
}
}
return symbolic_dims;
};
Array<PrimExpr> symbolic_dims = collect_symbolic(input_size_);
Array<PrimExpr> output_shape = OutputShape();
symbolic_dims.insert(symbolic_dims.end(), output_shape.begin(),
output_shape.end());
symbolic_dims = collect_symbolic(symbolic_dims);
bool is_static_shape = symbolic_dims.empty();
auto level = is_static_shape ? arith::IterMapLevel::Bijective
: arith::IterMapLevel::NoCheck;
if (!is_static_shape) {
// Runtime guards keep dynamic tails safe, so we allow NoCheck here and
// warn.
LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
"NoCheck; symbolic dims: "
<< symbolic_dims;
}
arith::IterMapResult res = arith::IterMapResult res =
arith::DetectIterMap(forward_index_, getVarMap(), 1, arith::DetectIterMap(forward_index_, getVarMap(), 1, level, &analyzer);
arith::IterMapLevel::Bijective, &analyzer);
ICHECK(res->errors.empty()) ICHECK(res->errors.empty())
<< "Layout " << DebugOutput() << " has errors: " << res->errors; << "Layout " << DebugOutput() << " has errors: " << res->errors;
...@@ -254,9 +277,13 @@ Layout LayoutNode::Inverse() const { ...@@ -254,9 +277,13 @@ Layout LayoutNode::Inverse() const {
} }
} }
return Layout(outputs_shape, backward_index); return {Layout(outputs_shape, backward_index), level};
} }
Layout LayoutNode::Inverse() const {
  // Compatibility wrapper around InverseWithLevel(): callers that do not care
  // about the detected IterMapLevel just receive the inverted layout. The pair
  // is a temporary, so .first is an xvalue and moves into the return slot.
  return InverseWithLevel().first;
}
PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters, PrimExpr infer_fragment_index(const Map<Var, Range> &input_iters,
const PrimExpr &forward_thread, const PrimExpr &forward_thread,
arith::Analyzer *analyzer) { arith::Analyzer *analyzer) {
...@@ -366,6 +393,11 @@ PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars, ...@@ -366,6 +393,11 @@ PrimExpr FragmentNode::ForwardThread(const Array<PrimExpr> &vars,
} }
Layout FragmentNode::Inverse() const { Layout FragmentNode::Inverse() const {
auto result = InverseWithLevel();
return std::move(result.first);
}
std::pair<Layout, arith::IterMapLevel> FragmentNode::InverseWithLevel() const {
auto input_size_copy = input_size_; auto input_size_copy = input_size_;
input_size_copy.push_back(ReplicateExtent()); input_size_copy.push_back(ReplicateExtent());
auto forward_index_copy = forward_index_; auto forward_index_copy = forward_index_;
...@@ -373,8 +405,7 @@ Layout FragmentNode::Inverse() const { ...@@ -373,8 +405,7 @@ Layout FragmentNode::Inverse() const {
Substitute(forward_thread_, Substitute(forward_thread_,
{{ReplicationPlaceholder(), InputPlaceholder(InputDim())}})); {{ReplicationPlaceholder(), InputPlaceholder(InputDim())}}));
auto fwd = Layout(input_size_copy, forward_index_copy); auto fwd = Layout(input_size_copy, forward_index_copy);
auto bwd = fwd->Inverse(); return fwd->InverseWithLevel();
return bwd;
} }
Fragment FragmentNode::CondenseReplicateVar() const { Fragment FragmentNode::CondenseReplicateVar() const {
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
#define TVM_TL_LAYOUT_LAYOUT_H_ #define TVM_TL_LAYOUT_LAYOUT_H_
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/arith/iter_affine_map.h>
#include <utility>
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -36,6 +38,7 @@ public: ...@@ -36,6 +38,7 @@ public:
virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const; virtual Array<PrimExpr> Forward(const Array<PrimExpr> &vars) const;
virtual Layout Inverse() const; virtual Layout Inverse() const;
virtual std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const;
virtual std::string DebugOutput() const; virtual std::string DebugOutput() const;
...@@ -76,6 +79,7 @@ public: ...@@ -76,6 +79,7 @@ public:
Array<PrimExpr> GetForwardVars() const final; Array<PrimExpr> GetForwardVars() const final;
Layout Inverse() const final; Layout Inverse() const final;
std::pair<Layout, arith::IterMapLevel> InverseWithLevel() const final;
PrimExpr ThreadExtent() const; PrimExpr ThreadExtent() const;
......
...@@ -64,28 +64,88 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, ...@@ -64,28 +64,88 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
ICHECK(thread_var.defined()); ICHECK(thread_var.defined());
int old_loop_depth = loop_layout->InputDim(); int old_loop_depth = loop_layout->InputDim();
int new_loop_depth = loop_layout->OutputDim(); int new_loop_depth = loop_layout->OutputDim();
// Create the new loop iter var // Create the new loop iter var
Array<Var> vars; Array<Var> vars;
for (int i = 0; i < new_loop_depth; i++) { for (int i = 0; i < new_loop_depth; i++) {
Var var = Var(std::string{char('i' + i)}); Var var = Var(std::string{char('i' + i)});
analyzer->Bind(var, Range::FromMinExtent(make_zero(var->dtype),
loop_layout->OutputShape()[i]));
vars.push_back(var); vars.push_back(var);
} }
vars.push_back(thread_var); vars.push_back(thread_var);
// create the substitute map, and the loop body // create the substitute map, and the loop body
Map<Var, PrimExpr> vmap; Map<Var, PrimExpr> vmap;
Stmt body = std::move(op); Stmt body = std::move(op);
auto inv_loop = loop_layout->Inverse(); Array<PrimExpr> loop_mins;
Array<PrimExpr> loop_extents;
auto inverse_info = loop_layout->InverseWithLevel();
auto inv_loop = inverse_info.first;
// Must check the guard if the layout can not be proved as bijective
bool need_guard = inverse_info.second != arith::IterMapLevel::Bijective;
auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end())); auto indices = inv_loop->Forward(Array<PrimExpr>(vars.begin(), vars.end()));
// Normalize thread var once so we can reuse the same substitution later.
Map<Var, PrimExpr> thread_offset_map;
bool has_thread_offset = false;
if (loop_layout->ThreadRange().defined()) {
auto range = loop_layout->ThreadRange();
thread_offset_map.Set(thread_var, thread_var - range->min);
has_thread_offset = true;
}
for (int i = 0; i < old_loop_depth; i++) { for (int i = 0; i < old_loop_depth; i++) {
const ForNode *loop = body.as<ForNode>(); const ForNode *loop = body.as<ForNode>();
ICHECK(loop != nullptr); ICHECK(loop != nullptr);
vmap.Set(loop->loop_var, indices[i]); vmap.Set(loop->loop_var, indices[i]);
loop_mins.push_back(loop->min);
loop_extents.push_back(loop->extent);
body = loop->body; body = loop->body;
} }
// substitute and re-construct the serial loop // substitute and re-construct the serial loop
body = Substitute(body, vmap); body = Substitute(body, vmap);
// Guard executes the recovered loop body only if each inverse-mapped iterator
// falls back into the original For ranges. We first check every axis from the
// old loop nest (old_loop_depth) and then the extra index produced by inverse
// layouts that carry a replicate/thread component (`inv_output_shape`). Both
// must stay within bounds to ensure correctness. Example: layout([i, j]) =
// floor((i * 16 + j) / 32) may generate extra points when the new loop
// enumerates 0..31; the guard drops iterations whose inverse-mapped (i, j)
// or replicate index fall outside their original extents. This protects
// non-surjective loop_layout mappings that otherwise over-cover the parallel
// space.
PrimExpr guard = const_true();
if (need_guard) {
for (int i = 0; i < old_loop_depth; i++) {
PrimExpr index = indices[i];
if (has_thread_offset) {
index = Substitute(index, thread_offset_map);
}
PrimExpr lower_bound = analyzer->Simplify(index >= loop_mins[i]);
PrimExpr upper_bound =
analyzer->Simplify(index < loop_mins[i] + loop_extents[i]);
guard = And(guard, And(lower_bound, upper_bound));
}
auto inv_output_shape = inv_loop->OutputShape();
if (inv_output_shape.size() > static_cast<size_t>(old_loop_depth)) {
PrimExpr replicate_index = indices[old_loop_depth];
if (has_thread_offset) {
replicate_index = Substitute(replicate_index, thread_offset_map);
}
PrimExpr replicate_extent = inv_output_shape[old_loop_depth];
PrimExpr lower_bound = analyzer->Simplify(
replicate_index >= make_zero(replicate_index.dtype()));
PrimExpr upper_bound =
analyzer->Simplify(replicate_index < replicate_extent);
guard = And(guard, And(lower_bound, upper_bound));
}
PrimExpr simplified_guard = analyzer->Simplify(guard);
if (!analyzer->CanProve(simplified_guard)) {
body = IfThenElse(simplified_guard, body, Stmt());
}
}
for (int i = new_loop_depth - 1; i >= 0; i--) { for (int i = new_loop_depth - 1; i >= 0; i--) {
body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i], body = For(vars[i], make_zero(vars[i]->dtype), inv_loop->InputShape()[i],
ForKind::kSerial, body); ForKind::kSerial, body);
...@@ -94,13 +154,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer, ...@@ -94,13 +154,11 @@ For PartitionLoop(For op, Var thread_var, arith::Analyzer *analyzer,
body = BufferIndiceSimplify(analyzer)(body); body = BufferIndiceSimplify(analyzer)(body);
auto for_node = LoopPragmaUnroll(Downcast<For>(body)); if (has_thread_offset) {
if (loop_layout->ThreadRange().defined()) { body = Substitute(body, thread_offset_map);
auto range = loop_layout->ThreadRange();
auto thread_var_with_offset = thread_var - range->min;
for_node.CopyOnWrite()->body =
Substitute(for_node->body, {{thread_var, thread_var_with_offset}});
} }
auto for_node = LoopPragmaUnroll(Downcast<For>(body));
return for_node; return for_node;
} }
...@@ -111,6 +169,10 @@ public: ...@@ -111,6 +169,10 @@ public:
private: private:
Stmt VisitStmt_(const ForNode *node) final { Stmt VisitStmt_(const ForNode *node) final {
if (node->kind == ForKind::kSerial) { if (node->kind == ForKind::kSerial) {
auto analyzer = std::make_shared<arith::Analyzer>();
if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
return StmtExprMutator::VisitStmt_(node);
}
For new_for = GetRef<For>(node); For new_for = GetRef<For>(node);
auto for_ptr = new_for.CopyOnWrite(); auto for_ptr = new_for.CopyOnWrite();
for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false)); for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
...@@ -127,22 +189,20 @@ public: ...@@ -127,22 +189,20 @@ public:
Fragment Partition(const For &op, int num_thread, int vectorize_size) { Fragment Partition(const For &op, int num_thread, int vectorize_size) {
this->VisitStmt(op); this->VisitStmt(op);
int loop_size_full = 1; ICHECK(!loop_vars_.empty());
PrimExpr flattened = 0; DataType dtype = loop_vars_[0]->var.dtype();
PrimExpr flattened = make_const(dtype, 0);
PrimExpr vector_extent = make_const(dtype, vectorize_size);
PrimExpr thread_extent_const = make_const(dtype, num_thread);
for (size_t i = 0; i < loop_vars_.size(); i++) { for (size_t i = 0; i < loop_vars_.size(); i++) {
auto ext_ptr = as_const_int(loop_vars_[i]->dom->extent); PrimExpr extent = loop_vars_[i]->dom->extent;
ICHECK(ext_ptr)
<< "Loop partitioner only works with constant loop sizes, but got "
<< loop_vars_[i]->dom->extent;
int extent = *ext_ptr;
loop_size_full *= extent;
flattened = flattened * extent + loop_vars_[i]->var; flattened = flattened * extent + loop_vars_[i]->var;
} }
ICHECK(loop_size_full % vectorize_size == 0); PrimExpr access_idx = FloorDiv(flattened, vector_extent);
PrimExpr access_idx = FloorDiv(flattened, vectorize_size); PrimExpr thd = FloorMod(access_idx, thread_extent_const);
PrimExpr thd = FloorMod(access_idx, num_thread); PrimExpr idx = FloorDiv(access_idx, thread_extent_const) * vector_extent +
PrimExpr idx = FloorDiv(access_idx, num_thread) * vectorize_size + FloorMod(flattened, vector_extent);
FloorMod(flattened, vectorize_size);
auto fragment = Fragment(loop_vars_, {idx}, {thd}, {}); auto fragment = Fragment(loop_vars_, {idx}, {thd}, {});
if (has_fragment_) { if (has_fragment_) {
// for fragment buffer, we don't need to replicate the loop layout // for fragment buffer, we don't need to replicate the loop layout
......
...@@ -94,7 +94,7 @@ public: ...@@ -94,7 +94,7 @@ public:
private: private:
void VisitStmt_(const ForNode *node) final { void VisitStmt_(const ForNode *node) final {
inner_for_ = node; inner_for_ = node;
auto extent_ptr = as_const_int(node->extent); auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent));
// Here I disable dynamic shape completely, // Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with // In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the divisibility of vector size // arithmetic info outside to prove the divisibility of vector size
......
import tilelang
import tilelang.language as T
import torch
import tilelang.testing
import pytest
tilelang.testing.set_random_seed()
@tilelang.jit(out_idx=[1])
def parallel_elementwise_static(length=256, dtype="float32"):
    # Build a kernel whose T.Parallel extent (`length`) is a compile-time
    # constant, exercising the static loop-partition path. `out_idx=[1]`
    # makes the JIT allocate and return B automatically.

    @T.prim_func
    def main(
            A: T.Tensor((length,), dtype),
            B: T.Tensor((length,), dtype),
    ):
        # One thread per element; each lane writes A[i] + 1.0 into B[i].
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length):
                B[i] = A[i] + 1.0

    return main
@tilelang.jit(out_idx=[1])
def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):
    # Build a kernel with a *dynamic* T.Parallel extent: the second loop runs
    # over min(valid_len, max_len), which is only known at runtime. This is
    # the case the loop partitioner must guard with a runtime predicate.

    @T.prim_func
    def main(
            A: T.Tensor((max_len,), dtype),
            B: T.Tensor((max_len,), dtype),
            valid_len: T.int32,
    ):
        with T.Kernel(1, threads=threads) as _:
            # Zero-fill the whole output first so the tail past `valid_len`
            # has a defined value.
            for i in T.Parallel(max_len):
                B[i] = 0.0
            # Clamp so an out-of-range valid_len never reads past A.
            span = T.min(valid_len, max_len)
            for i in T.Parallel(span):
                B[i] = A[i] - 1.0

    return main
def _require_cuda_tensor(shape, dtype=torch.float32):
    """Allocate a random CUDA tensor, skipping the test when CUDA is unusable.

    ``pytest.skip`` raises, so both failure paths abort the calling test
    instead of returning.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    try:
        tensor = torch.randn(*shape, device="cuda", dtype=dtype)
    except RuntimeError as err:
        # CUDA can be "available" yet fail at allocation time (e.g. driver
        # mismatch); treat that as a skip as well.
        pytest.skip(f"CUDA runtime unavailable: {err}")
    return tensor
def test_parallel_static_extent():
    """The static-extent kernel must add 1.0 to every element."""
    kernel = parallel_elementwise_static(length=256)
    data = _require_cuda_tensor((256,), torch.float32)
    expected = data + 1.0
    output = kernel(data)
    torch.testing.assert_close(output, expected, atol=1e-5, rtol=1e-5)
def test_parallel_dynamic_extent():
    """Dynamic extents: only the first min(valid_len, max_len) entries get A - 1."""
    max_len = 512
    kernel = parallel_elementwise_dynamic(max_len=max_len, threads=256)
    data = _require_cuda_tensor((max_len,), torch.float32)
    # 600 deliberately exceeds max_len to exercise the clamp inside the kernel.
    for valid_len in (0, 13, 200, 600):
        clip = min(valid_len, max_len)
        expected = torch.zeros_like(data)
        expected[:clip] = data[:clip] - 1.0
        actual = kernel(data, valid_len)
        torch.testing.assert_close(actual, expected, atol=1e-5, rtol=1e-5)
if __name__ == "__main__":
    # Allow running this file directly; tilelang's test runner discovers and
    # executes the test_* functions above.
    tilelang.testing.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment