Commit c30904ea authored by Lei Wang, committed by LeiWang1999
Browse files

[Bugfix] Fix dynamic axis with variable extent (#311)

* [Enhancement] Improve error message for RampNode in CUDA codegen

- Updated the error message in the VisitExpr_ method for RampNode to include the specific Ramp node and lane count when the lane count exceeds the limit of 4. This change enhances debugging by providing clearer context for the error.
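- Refactored the ForNode visitor in loop_vectorize_dynamic.cc to use early returns instead of nested if/else, improving readability and maintainability. The visitor now returns the loop unchanged when its extent is not a constant integer, instead of dereferencing the constant-extent pointer before checking it, so loops with a variable extent no longer break dynamic vectorization.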

* lint fix
parent 66c7f6a1
......@@ -1336,7 +1336,8 @@ void CodeGenTileLangCUDA::VisitStmt_(const AllocateNode *op) {
void CodeGenTileLangCUDA::VisitExpr_(const RampNode *op, std::ostream &os) {
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not allowed.";
CHECK_LE(lanes, 4) << "Translate Ramp Node " << GetRef<Ramp>(op) << " with "
<< lanes << " lanes is not allowed.";
os << "(make_";
PrintType(op->dtype, os);
os << "(";
......
......@@ -272,55 +272,54 @@ private:
// Visitor for For loops. inner_for_ is set to this node before the recursive
// visit; the same visitor overwrites it for any nested For, so it still equals
// `node` afterwards only when no nested For was visited.
// NOTE(review): this span is a rendered diff hunk and appears to contain both
// the pre-change body (nested if/else) and the post-change body (guard
// clauses) back to back, which is why the braces do not balance as shown.
Stmt VisitStmt_(const ForNode *node) final {
inner_for_ = node;
auto ret = StmtExprMutator::VisitStmt_(node);
// ---- nested if/else form of the body ----
if (inner_for_ == node) {
For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var;
auto extent_ptr = as_const_int(fnode->extent);
// NOTE(review): extent_ptr is dereferenced here before the
// ICHECK(extent_ptr) below tests it, and regardless of dynamic_.
int extent = *extent_ptr;
if (dynamic_) { // only vectorize with dynamic
ICHECK(extent_ptr) << fnode->extent;
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
VectorizedConditionExtracter extracter;
std::vector<PrimExpr> conditions = extracter.GetConditions(body);
// Set vectorize variable to the max value of the extent (i.e.
// vector_size_ - 1)
PrimExpr condition = conditions[0];
for (int i = 1; i < conditions.size(); ++i) {
condition = condition && conditions[i];
}
// add condition ifthenelse here
Map<Var, PrimExpr> vmap_condition;
vmap_condition.Set(inner_var, vector_size_ - 1);
PrimExpr condition_bound = Substitute(condition, vmap_condition);
// modify body in the vectorized loop
VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
Stmt vectorize_body = mutator(body);
For vectorize_for = For(inner_var, 0, vector_size_,
ForKind::kVectorized, vectorize_body);
For serial_for =
For(inner_var, 0, vector_size_, ForKind::kSerial, body);
body = IfThenElse(condition_bound, vectorize_for, serial_for);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
} else {
return fnode;
}
} else {
// ---- guard-clause form of the body ----
// A nested For was visited during recursion, so this is not the innermost
// loop: return the recursively mutated statement unchanged.
if (inner_for_ != node) {
return ret;
}
For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var;
// Loops whose extent is not an integer immediate are left untouched.
if (!fnode->extent.as<IntImmNode>()) {
return ret;
}
int extent = Downcast<IntImm>(fnode->extent)->value;
// Only loops handled with dynamic_ set are rewritten.
if (!dynamic_) {
return fnode;
}
// The loop is split by vector_size_, so the extent must divide evenly and
// the loop must start at zero.
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
// Rewrite the loop variable as outer * vector_size_ + vec in the body.
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
VectorizedConditionExtracter extracter;
std::vector<PrimExpr> conditions = extracter.GetConditions(body);
// Set vectorize variable to the max value of the extent (i.e.
// vector_size_ - 1)
// All extracted conditions are AND-ed into one expression.
// NOTE(review): conditions[0] is read without an emptiness check;
// presumably dynamic_ implies at least one condition -- confirm.
PrimExpr condition = conditions[0];
for (int i = 1; i < conditions.size(); ++i) {
condition = condition && conditions[i];
}
// add condition ifthenelse here
// The combined condition is evaluated with vec fixed to vector_size_ - 1.
Map<Var, PrimExpr> vmap_condition;
vmap_condition.Set(inner_var, vector_size_ - 1);
PrimExpr condition_bound = Substitute(condition, vmap_condition);
// modify body in the vectorized loop
VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
Stmt vectorize_body = mutator(body);
// When condition_bound holds, run the vectorized inner loop over the mutated
// body; otherwise fall back to a serial inner loop over the substituted body.
For vectorize_for =
For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
body = IfThenElse(condition_bound, vectorize_for, serial_for);
// Wrap in the outer loop of extent / vector_size_ iterations, carrying over
// the original loop's kind, thread binding, annotations and span.
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
const ForNode *inner_for_;
......@@ -341,6 +340,7 @@ class LoopVectorizerDynamic : public IRMutatorWithAnalyzer {
public:
static Stmt Substitute(Stmt stmt) {
arith::Analyzer analyzer;
LOG(INFO) << "LoopVectorizerDynamic Substitute";
LoopVectorizerDynamic substituter(&analyzer);
stmt = substituter.VisitStmt(stmt);
return stmt;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment