Commit 473e2905 authored by Paul's avatar Paul
Browse files

Add lazy inner

parent c05e10f7
...@@ -260,6 +260,14 @@ struct reduce_op ...@@ -260,6 +260,14 @@ struct reduce_op
} }
}; };
static bool use_lazy_inner(instruction_ref ins)
{
if (ins->outputs().size() != 1)
return false;
auto output = ins->outputs().front();
return contains(output->name(), "reduce") or output->name() == "@return";
}
std::string generate_reduce(const module& rm, const std::string& name) std::string generate_reduce(const module& rm, const std::string& name)
{ {
module m = rm; module m = rm;
...@@ -293,13 +301,15 @@ std::string generate_reduce(const module& rm, const std::string& name) ...@@ -293,13 +301,15 @@ std::string generate_reduce(const module& rm, const std::string& name)
if(tensors.empty()) if(tensors.empty())
return call_function; return call_function;
const std::string inner_template = const std::string inner_template =
"r.inner([=](${params}) { return ${call}; })(${args})"; "r.${inner}([=](${params}) { return ${call}; })(${args})";
std::string inner_name = use_lazy_inner(ins) ? "lazy_inner" : "inner";
auto args = cpp_generator::to_args(tensors, names); auto args = cpp_generator::to_args(tensors, names);
auto params = cpp_generator::to_args(tensors, inner_names); auto params = cpp_generator::to_args(tensors, inner_names);
std::transform( std::transform(
params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; }); params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
return interpolate_string(inner_template, return interpolate_string(inner_template,
{{"params", join_strings(params, ", ")}, {{"inner", inner_name},
{"params", join_strings(params, ", ")},
{"args", join_strings(args, ", ")}, {"args", join_strings(args, ", ")},
{"call", call_function}}); {"call", call_function}});
} }
...@@ -322,6 +332,8 @@ static std::vector<std::string> get_op_names(const module& m) ...@@ -322,6 +332,8 @@ static std::vector<std::string> get_op_names(const module& m)
{ {
if(starts_with(ins.name(), "@")) if(starts_with(ins.name(), "@"))
continue; continue;
if(ins.name() == "multibroadcast")
continue;
if(ins.name() == "pointwise") if(ins.name() == "pointwise")
{ {
auto names = get_op_names(*ins.module_inputs().front()); auto names = get_op_names(*ins.module_inputs().front());
......
...@@ -174,6 +174,25 @@ struct inner_storage_tag ...@@ -174,6 +174,25 @@ struct inner_storage_tag
template <class T> template <class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>; using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
template <class Size, class F>
struct lazy_inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
constexpr lazy_inner_storage<Size, F> make_lazy_inner_storage(Size, F f)
{
return {{}, f};
}
template <class R, class F> template <class R, class F>
struct storage_access : F struct storage_access : F
{ {
...@@ -278,6 +297,14 @@ struct reducer_base ...@@ -278,6 +297,14 @@ struct reducer_base
}); });
} }
template <class F>
__device__ auto lazy_inner(F f) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
});
}
template <class Op, class T, class Read> template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const __device__ auto reduce(Op op, T init, Read read) const
{ {
...@@ -396,25 +423,6 @@ struct lane ...@@ -396,25 +423,6 @@ struct lane
index idx; index idx;
Slicer slice; Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {f};
}
template <class Op, class T, class Read, class N, class U, class... Us> template <class Op, class T, class Read, class N, class U, class... Us>
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const __device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{ {
...@@ -445,7 +453,7 @@ struct lane ...@@ -445,7 +453,7 @@ struct lane
template <class R, class F, class N, class... Ts> template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const __device__ auto inner_impl(F f, N n, Ts&&... xs) const
{ {
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); }); return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
} }
}; };
template <class Slicer> template <class Slicer>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment