Commit c30904ea authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Bugfix] Fix dynamic axis with variable extent (#311)

* [Enhancement] Improve error message for RampNode in CUDA codegen

- Updated the error message in the VisitExpr_ method for RampNode to include the specific Ramp node and lane count when the lane count exceeds the limit of 4. This change enhances debugging by providing clearer context for the error.
- Refactored the loop vectorization logic in loop_vectorize_dynamic.cc to improve readability and maintainability, ensuring that dynamic vectorization checks are performed correctly and efficiently.

* lint fix
parent 66c7f6a1
...@@ -1336,7 +1336,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) { ...@@ -1336,7 +1336,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) { void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value); int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed."; CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
<< lanes << " lanes is not allowed.";
os << "(make_"; os << "(make_";
PrintType(op->dtype, os); PrintType(op->dtype, os);
os << "("; os << "(";
......
...@@ -272,13 +272,19 @@ private: ...@@ -272,13 +272,19 @@ private:
Stmt VisitStmt_(const ForNode *node) final { Stmt VisitStmt_(const ForNode *node) final {
inner_for_ = node; inner_for_ = node;
auto ret = StmtExprMutator::VisitStmt_(node); auto ret = StmtExprMutator::VisitStmt_(node);
if (inner_for_ == node) { if (inner_for_ != node) {
return ret;
}
For fnode = ret.as<For>().value(); For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var; auto old_var = fnode->loop_var;
auto extent_ptr = as_const_int(fnode->extent); if (!fnode->extent.as<IntImmNode>()) {
int extent = *extent_ptr; return ret;
if (dynamic_) { // only vectorize with dynamic }
ICHECK(extent_ptr) << fnode->extent; int extent = Downcast<IntImm>(fnode->extent)->value;
if (!dynamic_) {
return fnode;
}
ICHECK(extent % vector_size_ == 0) ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_; << "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min)); ICHECK(is_zero(fnode->min));
...@@ -307,20 +313,13 @@ private: ...@@ -307,20 +313,13 @@ private:
VectorizedBodyMutator mutator(inner_var, vector_size_, conditions); VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
Stmt vectorize_body = mutator(body); Stmt vectorize_body = mutator(body);
For vectorize_for = For(inner_var, 0, vector_size_, For vectorize_for =
ForKind::kVectorized, vectorize_body); For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
For serial_for = For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
For(inner_var, 0, vector_size_, ForKind::kSerial, body);
body = IfThenElse(condition_bound, vectorize_for, serial_for); body = IfThenElse(condition_bound, vectorize_for, serial_for);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span); fnode->thread_binding, fnode->annotations, fnode->span);
return body; return body;
} else {
return fnode;
}
} else {
return ret;
}
} }
const ForNode *inner_for_; const ForNode *inner_for_;
...@@ -341,6 +340,7 @@ class LoopVectorizerDynamic : public IRMutatorWithAnalyzer { ...@@ -341,6 +340,7 @@ class LoopVectorizerDynamic : public IRMutatorWithAnalyzer {
public: public:
static Stmt Substitute(Stmt stmt) { static Stmt Substitute(Stmt stmt) {
arith::Analyzer analyzer; arith::Analyzer analyzer;
LOG(INFO) << "LoopVectorizerDynamic Substitute";
LoopVectorizerDynamic substituter(&analyzer); LoopVectorizerDynamic substituter(&analyzer);
stmt = substituter.VisitStmt(stmt); stmt = substituter.VisitStmt(stmt);
return stmt; return stmt;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment