"...composable_kernel.git" did not exist on "d807d05e3a96acbc3dd7134cdab213e9f8168338"
Unverified Commit 0eb33f28 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Dependency] Update apache-tvm-ffi version to >=0.1.2 (#1400)

* [Dependency] Update apache-tvm-ffi version to >=0.1.2 in project files

* [Dependency] Update subproject commit for TVM to latest version afc07935

* [Enhancement] Add support for optional step parameter in loop constructs

- Updated loop creation functions to accept an optional step parameter, enhancing flexibility in loop definitions.
- Modified ForFrame implementations to utilize the new step parameter across various loop types including serial, parallel, and pipelined loops.
- Adjusted related vectorization transformations to accommodate the step parameter, ensuring consistent behavior in loop vectorization processes.

* lint fix
parent 79d381d1
Subproject commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0
Subproject commit afc079350def46a78931c6edeb7bad3fb248b4e1
......@@ -31,7 +31,7 @@ dependencies = [
# Extra constraint to tvm-ffi for abi issue,
# should be removed after our tvm's update.
# See discussion in tilelang#1373 and apache/tvm-ffi#307
"apache-tvm-ffi<=0.1.1",
"apache-tvm-ffi>=0.1.2",
"cloudpickle",
"ml-dtypes",
"numpy>=1.23.5",
......
# Requirements to run local build with `--no-build-isolation` or other developments
apache-tvm-ffi~=0.1.0
apache-tvm-ffi>=0.1.2
build
cmake>=3.26
cython>=3.0.0
......
......@@ -44,16 +44,22 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
const Stmt &body) -> Stmt {
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
/*thread_binding=*/std::nullopt,
/*annotations=*/tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any>{},
/*step=*/step);
};
return ForFrame(n);
}
ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) {
const Map<String, tvm::ffi::Any> &annotations) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size());
......@@ -63,16 +69,19 @@ ForFrame ParallelFor(const Array<PrimExpr> &extents,
n->vars.push_back(Var("v", extent.dtype()));
n->doms.push_back(Range(make_const(dtype, 0), extent));
}
n->f_make_for_loop = [annotations](const Array<Var> &vars,
const Array<Range> &doms,
Stmt body) -> Stmt {
n->f_make_for_loop =
[annotations](const Array<Var> &vars, const Array<Range> &doms,
const Array<Optional<PrimExpr>> &steps, Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
for (int i = n - 1; i >= 0; --i) {
Range dom = doms[i];
Var var = vars[i];
Optional<PrimExpr> step =
i < steps.size() ? steps[i] : Optional<PrimExpr>(std::nullopt);
body = For(var, dom->min, dom->extent, ForKind::kParallel, body,
/*thread_binding=*/std::nullopt, /*annotations=*/annotations);
/*thread_binding=*/std::nullopt, /*annotations=*/annotations,
/*step=*/step);
}
return body;
};
......@@ -90,11 +99,12 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop));
n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
ICHECK(n == 1);
Map<String, ObjectRef> anno;
Map<String, tvm::ffi::Any> anno;
if (num_stages > 0)
anno.Set("num_stages", PrimExpr(num_stages));
if (!order.empty())
......@@ -105,8 +115,11 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
anno.Set("tl_pipeline_sync", sync);
if (!groups.empty())
anno.Set("tl_pipeline_group", groups);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
/*thread_binding=*/std::nullopt, /*annotations=*/anno);
/*thread_binding=*/std::nullopt, /*annotations=*/anno,
/*step=*/step);
return body;
};
return ForFrame(n);
......@@ -145,9 +158,10 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
grouped_domain.push_back(group_size);
n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
const Stmt &body) -> Stmt {
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
Map<String, ObjectRef> anno;
Map<String, tvm::ffi::Any> anno;
Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr());
PrimExpr rem = loop_var * wave_size + index;
......@@ -168,8 +182,11 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
if (analyzer.CanProveGreaterEqual(waves, 2)) {
new_body = SeqStmt({out_if, body});
}
Stmt outer =
For(loop_var, 0, waves, ForKind::kSerial, new_body, std::nullopt, anno);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, new_body,
/*thread_binding=*/std::nullopt, /*annotations=*/anno,
/*step=*/step);
for (int i = 0; i < vars.size() - 1; ++i) {
outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
}
......
......@@ -203,7 +203,8 @@ private:
vmap.Set(old_var, new_var * vector_size_);
Stmt body = Substitute(fnode->body, vmap);
return For(new_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
fnode->thread_binding, fnode->annotations, fnode->step,
fnode->span);
}
}
return ret;
......
......@@ -232,7 +232,8 @@ private:
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
fnode->thread_binding, fnode->annotations, fnode->step,
fnode->span);
return body;
}
} else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment