Unverified Commit f201285c authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Add reduction fusion (#1614)

Automatically fuse multiple reductions and pointwise operations.
parent a123cb2e
......@@ -50,6 +50,7 @@ add_library(migraphx
env.cpp
file_buffer.cpp
fuse_pointwise.cpp
fuse_reduce.cpp
generate.cpp
inline_module.cpp
insert_pad.cpp
......
......@@ -106,6 +106,13 @@ cpp_generator::function& cpp_generator::function::set_generic_types(const module
return *this;
}
cpp_generator::function& cpp_generator::function::add_generic_param(const std::string& pname)
{
params.push_back({pname, "T" + pname});
tparams.push_back("class T" + pname);
return *this;
}
struct cpp_generator_impl
{
std::stringstream fs{};
......@@ -182,7 +189,8 @@ std::string cpp_generator::generate_point_op(const operation& op,
std::string cpp_generator::str() const { return impl->fs.str(); }
cpp_generator::function cpp_generator::generate_module(const module& m)
cpp_generator::function cpp_generator::generate_module(const module& m,
const generate_module_callback& g)
{
function f;
auto name = transform_string(m.name(), [](char c) {
......@@ -195,13 +203,7 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
if(ins->name() == "@literal")
return shape::cpp_type(ins->get_shape().type()) + "(" +
ins->get_literal().to_string() + ")";
std::vector<std::string> args;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(args),
[&](auto i) { return names.at(i); });
auto s = this->generate_point_op(ins->get_operator(), args);
auto s = g(ins, names);
if(impl->fresult)
return impl->fresult(ins->get_shape()) + '(' + s + ')';
else
......@@ -210,6 +212,24 @@ cpp_generator::function cpp_generator::generate_module(const module& m)
return f;
}
std::vector<std::string>
cpp_generator::to_args(const std::vector<instruction_ref>& inputs,
const std::unordered_map<instruction_ref, std::string>& names)
{
std::vector<std::string> args;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(args), [&](auto i) {
return names.at(i);
});
return args;
}
cpp_generator::function cpp_generator::generate_module(const module& m)
{
return this->generate_module(m, [&](auto ins, const auto& names) {
return this->generate_point_op(ins->get_operator(), to_args(ins->inputs(), names));
});
}
std::string cpp_generator::create_function(const cpp_generator::function& f)
{
impl->function_count++;
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/register_op.hpp>
#include <iterator>
#include <map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct fused_reduce
{
std::vector<std::int64_t> axes{};
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axes, "axes"));
}
shape compute_shape(const std::vector<shape>& inputs, std::vector<module_ref> mods) const
{
if(mods.size() != 1)
MIGRAPHX_THROW("should have one submodule.");
auto* sm = mods.front();
if(sm->get_output_shapes().size() != 1)
MIGRAPHX_THROW("Only one output supported");
auto names = sm->get_parameter_names();
check_shapes{inputs, *this}.has(names.size()).same_ndims();
std::sort(names.begin(), names.end());
auto shapes = sm->get_parameter_shapes();
// Check dimension matches for each input
if(not equal(names, inputs, [&](const auto& name, const auto& input) {
return shapes.at(name).lens() == input.lens();
}))
MIGRAPHX_THROW("Dimenstion does not match the submodule.");
const auto& s = inputs.at(0);
auto lens = s.lens();
if(lens != sm->get_output_shapes().front().lens())
{
for(const auto& axis : axes)
{
lens[axis] = 1;
}
}
return shape::from_permutation(
sm->get_output_shapes().front().type(), lens, find_permutation(inputs));
}
std::string name() const { return "fused_reduce"; }
};
MIGRAPHX_REGISTER_OP(fused_reduce);
static std::unordered_map<instruction_ref, instruction_ref>
get_ins_param_map(const std::vector<instruction_ref>& inputs, const_module_ref sm)
{
std::unordered_map<instruction_ref, instruction_ref> result;
auto names = sm->get_parameter_names();
std::sort(names.begin(), names.end());
assert(names.size() == inputs.size());
std::transform(names.begin(),
names.end(),
inputs.begin(),
std::inserter(result, result.end()),
[&](const auto& name, auto input) {
return std::make_pair(input, sm->get_parameter(name));
});
return result;
}
static void insert_params(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
auto n = sm->get_parameter_shapes().size();
for(auto input : ins->inputs())
{
if(contains(map_ins, input))
continue;
auto s = shape{input->get_shape().type(), input->get_shape().lens()};
map_ins[input] = sm->add_parameter("x" + std::to_string(n++), s);
}
}
static auto insert_ins_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
insert_params(sm, ins, map_ins);
return sm->add_instructions({ins}, map_ins);
}
static auto insert_ins_in_submodule(module_ref sm, instruction_ref ins)
{
std::unordered_map<instruction_ref, instruction_ref> map_ins;
return insert_ins_in_submodule(sm, ins, map_ins);
}
static auto
insert_module_in_submodule(module_ref sm,
instruction_ref ins,
std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
insert_params(sm, ins, map_ins);
auto* m = ins->module_inputs().front();
auto param_map = get_ins_param_map(ins->inputs(), m);
for(auto&& [input, param] : param_map)
{
map_ins[param] = map_ins.at(input);
}
return sm->add_instructions(m, map_ins);
}
static std::vector<instruction_ref>
find_inputs(module_ref sm,
const module& parent,
const std::unordered_map<instruction_ref, instruction_ref>& map_ins)
{
std::vector<instruction_ref> result;
std::map<std::string, instruction_ref> names;
for(auto&& [input, param] : map_ins)
{
if(not sm->has_instruction(param))
continue;
if(param->name() != "@param")
continue;
if(not parent.has_instruction(input))
continue;
auto v = param->get_operator().to_value();
auto name = v.at("parameter").to<std::string>();
names[name] = input;
}
std::transform(names.begin(), names.end(), std::back_inserter(result), [](const auto& p) {
return p.second;
});
assert(result.size() == sm->get_parameter_shapes().size());
return result;
}
static void create_reduce_modules(module_pass_manager& mpm)
{
std::size_t n = 0;
for(auto ins : iterator_for(mpm.get_module()))
{
if(not ins->get_operator().attributes().get("reduce", false))
continue;
if(ins->inputs().size() != 1)
continue;
auto* rm =
mpm.create_module(mpm.get_module().name() + ":" + ins->name() + std::to_string(n++));
rm->set_bypass();
rm->add_return(insert_ins_in_submodule(rm, ins));
auto v = ins->get_operator().to_value();
mpm.get_module().replace_instruction(
ins, make_op("fused_reduce", {{"axes", v["axes"]}}), ins->inputs(), {rm});
}
}
template <class... Ms>
static auto match_broadcast(Ms... ms)
{
return match::skip(match::name("contiguous"))(
match::name("multibroadcast")(match::arg(0)(ms...), match::used_once()).bind("broadcast"));
}
template <class... Ms>
static auto any_input(Ms... ms)
{
return match::any_of[match::inputs()](match::any(ms...).bind("input"));
}
static auto match_broadcastable_input(const std::string& op, const std::string& name)
{
auto match_op = match::name(op)(match::used_once()).bind(name);
auto match_op_input = any_input(match_op, match::used_once());
auto broadcast_match_op_input = any_input(match_broadcast(match_op), match::used_once());
return match::any_of(match_op_input, broadcast_match_op_input);
}
namespace {
struct find_pointwise_reduce
{
auto matcher() const
{
return match::name("fused_reduce")(match_broadcastable_input("pointwise", "pointwise"));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto reduce = r.result;
auto input = r.instructions["pointwise"];
const auto* pm = input->module_inputs().front();
const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(pm->name() + ":" + old_rm->name());
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Insert pointwise
auto rins = insert_ins_in_submodule(rm, input, map_ins).front();
map_ins[input] = rins;
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast] = insert_ins_in_submodule(rm, broadcast, map_ins).front();
}
// Insert fused_reduce
rm->add_return(insert_module_in_submodule(rm, reduce, map_ins));
auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce, reduce->get_operator(), new_inputs, {rm});
}
};
struct find_reduce_pointwise
{
auto matcher() const
{
return match::name("pointwise")(match_broadcastable_input("fused_reduce", "reduce"));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto pw = r.result;
auto reduce = r.instructions["reduce"];
auto input = r.instructions["input"];
const auto* pm = pw->module_inputs().front();
const auto* old_rm = reduce->module_inputs().front();
auto* rm = mpm.create_module(old_rm->name() + ":" + pm->name());
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy module instructions
insert_module_in_submodule(rm, reduce, map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = rm->get_returns().front();
}
auto out = insert_ins_in_submodule(rm, pw, map_ins);
rm->replace_return(out);
auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(pw, reduce->get_operator(), new_inputs, {rm});
}
};
struct find_reduce_reduce
{
auto matcher() const
{
return match::name("fused_reduce")(match_broadcastable_input("fused_reduce", "reduce"));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto reduce1 = r.result;
auto reduce2 = r.instructions["reduce"];
auto input = r.instructions["input"];
if(reduce1->get_operator() != reduce2->get_operator())
return;
const auto* rm1 = reduce1->module_inputs().front();
const auto* rm2 = reduce2->module_inputs().front();
auto* rm = mpm.create_module(rm1->name() + ":" + rm2->name());
rm->set_bypass();
std::unordered_map<instruction_ref, instruction_ref> map_ins;
// Copy reduce1 instructions
insert_module_in_submodule(rm, reduce2, map_ins);
if(contains(r.instructions, "broadcast"))
{
auto broadcast = r.instructions["broadcast"];
map_ins[broadcast->inputs().front()] = rm->get_returns().front();
auto bout = insert_ins_in_submodule(rm, broadcast, map_ins);
map_ins[input] = bout.front();
}
else
{
map_ins[input] = rm->get_returns().front();
}
auto out = insert_module_in_submodule(rm, reduce1, map_ins);
rm->replace_return(out);
auto new_inputs = find_inputs(rm, mpm.get_module(), map_ins);
mpm.get_module().replace_instruction(reduce1, reduce1->get_operator(), new_inputs, {rm});
}
};
} // namespace
void fuse_reduce::apply(module_pass_manager& mpm) const
{
create_reduce_modules(mpm);
mpm.run_pass(dead_code_elimination{});
for(int i = 0; i < 4; i++)
{
match::find_matches(
mpm, find_reduce_pointwise{}, find_pointwise_reduce{}, find_reduce_reduce{});
mpm.run_pass(dead_code_elimination{});
}
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -77,6 +77,7 @@ struct cpp_generator
function& set_types(const module& m);
function& set_types(const module& m, const std::function<std::string(shape)>& parse);
function& set_generic_types(const module& m);
function& add_generic_param(const std::string& pname);
};
cpp_generator();
......@@ -105,6 +106,10 @@ struct cpp_generator
std::string create_function(const function& f);
static std::vector<std::string>
to_args(const std::vector<instruction_ref>& inputs,
const std::unordered_map<instruction_ref, std::string>& names);
private:
std::unique_ptr<cpp_generator_impl> impl;
};
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_REDUCE_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
struct fuse_reduce
{
std::string name() const { return "fuse_reduce"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_POINTWISE_HPP
......@@ -178,6 +178,8 @@ struct module
bool has_instruction(instruction_ref ins) const;
std::vector<instruction_ref> get_returns() const;
std::size_t size() const;
instruction_ref begin() const;
instruction_ref end() const;
......
......@@ -91,7 +91,7 @@ struct reduce_op : op_name<Derived>
{
value normalize;
normalize["axes"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
return {{"normalize_axes", normalize}, {"reduce", true}};
}
std::vector<int64_t> tune_axes(std::size_t n_dim) const
......
......@@ -595,6 +595,14 @@ std::vector<shape> module::get_output_shapes() const
}
}
std::vector<instruction_ref> module::get_returns() const
{
auto last = std::prev(this->end());
if(last->name() == "@return")
return last->inputs();
return {last};
}
instruction_ref module::validate() const
{
return std::find_if(
......
......@@ -168,7 +168,7 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return join_strings(std::move(transformers), ", ");
}
std::string generate_pointwise(const module& pm, const std::string& name)
void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
{
module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
......@@ -184,8 +184,131 @@ std::string generate_pointwise(const module& pm, const std::string& name)
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
g.create_function(
g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
gg.create_function(g.generate_module(m)
.set_attributes({"__device__", "__attribute__((const))"})
.set_generic_types(m)
.set_name(name));
}
std::string generate_pointwise(const module& pm, const std::string& name)
{
cpp_generator g;
generate_pointwise(g, pm, name);
return g.str();
}
std::string reduce_op::str() const
{
return write + "(r.reduce(" + reduction + ", " + init + ", " + read + ")(" + input + "))";
}
void reduce_op::set(instruction_ref ins, const operation& op)
{
if(op.name() == "reduce_sum")
{
reduction = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
auto s = ins->inputs().front()->get_shape();
auto reduce_elements = s.elements() / ins->get_shape().elements();
auto reduce_type = s.type();
reduction = "op::sum{}";
std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384)
read = "compose(" + mean + ", op::convert_to<float>{})";
else if(contains({shape::float_type, shape::half_type, shape::double_type}, reduce_type))
read = mean;
else
write = mean;
}
else if(op.name() == "reduce_max")
{
reduction = "op::max{}";
init = "lowest{}";
}
else if(op.name() == "reduce_min")
{
reduction = "op::min{}";
init = "highest{}";
}
else if(op.name() == "reduce_prod")
{
reduction = "op::product{}";
init = "1";
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
}
}
std::string reduce_op::generate(instruction_ref ins, const std::string& x)
{
reduce_op r{x};
r.set(ins, ins->get_operator());
return r.str();
}
static bool use_lazy_inner(instruction_ref ins)
{
if(ins->outputs().size() != 1)
return false;
auto output = ins->outputs().front();
return contains(output->name(), "reduce") or output->name() == "@return";
}
std::string generate_reduce(const module& m, const std::string& name)
{
cpp_generator g;
auto ilens = m.get_parameter_shapes().begin()->second.lens();
std::size_t i = 0;
auto f = g.generate_module(m, [&](instruction_ref ins, const auto& names) {
if(contains(ins->name(), "reduce"))
{
return reduce_op::generate(ins, names.at(ins->inputs().front()));
}
else if(ins->name() == "pointwise")
{
auto pointwise_name = "pointwise" + std::to_string(i);
i++;
generate_pointwise(g, *ins->module_inputs().front(), pointwise_name);
std::vector<instruction_ref> tensors;
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(tensors),
[&](auto input) {
return input->get_shape().lens() == ilens and
not input->get_shape().broadcasted();
});
auto inner_names = names;
for(auto input : tensors)
inner_names[input] += "_lambda_param";
auto call_function =
pointwise_name + "(" +
join_strings(cpp_generator::to_args(ins->inputs(), inner_names), ", ") + ")";
if(tensors.empty())
return call_function;
const std::string inner_template =
"r.${inner}([=](${params}) { return ${call}; })(${args})";
std::string inner_name = use_lazy_inner(ins) ? "lazy_inner" : "inner";
auto args = cpp_generator::to_args(tensors, names);
auto params = cpp_generator::to_args(tensors, inner_names);
std::transform(
params.begin(), params.end(), params.begin(), [](auto s) { return "auto " + s; });
return interpolate_string(inner_template,
{{"inner", inner_name},
{"params", join_strings(params, ", ")},
{"args", join_strings(args, ", ")},
{"call", call_function}});
}
else if(ins->name() == "multibroadcast")
{
return names.at(ins->inputs().front());
}
MIGRAPHX_THROW("Unknown operator: " + ins->name());
});
f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
f.add_generic_param("r");
g.create_function(f);
return g.str();
}
......@@ -196,8 +319,18 @@ static std::vector<std::string> get_op_names(const module& m)
{
if(starts_with(ins.name(), "@"))
continue;
if(ins.name() == "multibroadcast")
continue;
if(ins.name() == "pointwise")
{
auto names = get_op_names(*ins.module_inputs().front());
result.insert(result.end(), names.begin(), names.end());
}
else
{
result.push_back(ins.name());
}
}
return result;
}
......
......@@ -26,6 +26,7 @@
#include <migraphx/config.hpp>
#include <migraphx/module_ref.hpp>
#include <migraphx/instruction_ref.hpp>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -34,6 +35,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct shape;
struct operation;
namespace gpu {
......@@ -72,8 +74,23 @@ std::string make_transformer_args(Ts... xs)
std::string generate_pointwise(const module& pm, const std::string& name);
std::string generate_reduce(const module& m, const std::string& name);
std::string generate_name_from_ops(const module& m);
struct reduce_op
{
std::string input = "";
std::string reduction = "";
std::string init = "0";
std::string read = "op::id{}";
std::string write = "op::id{}";
void set(instruction_ref ins, const operation& op);
std::string str() const;
static std::string generate(instruction_ref ins, const std::string& x);
};
} // namespace gen
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -71,6 +71,8 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
std::size_t compute_block_size(std::size_t n, std::size_t max_block_size = 1024);
std::string generate_make_shape(const shape& s);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
......@@ -60,15 +60,6 @@ __global__ void reduce_kernel(void* input_p, void* output_p)
)__migraphx__";
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
static std::size_t get_reduce_elements(const std::vector<instruction_ref>& inputs)
{
return get_reduce_elements(to_shapes(inputs));
}
static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>& input_lens,
const std::vector<std::size_t>& output_lens)
{
......@@ -86,9 +77,28 @@ static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>&
return reduce_lens;
}
static std::string get_reduce_algo(const std::vector<shape>& inputs)
template <class T>
static shape get_reduced_shape(const shape& s, const std::vector<T>& axes)
{
auto lens = s.lens();
std::fill(lens.begin(), lens.end(), 1);
for(const auto& axis : axes)
lens[axis] = s.lens()[axis];
return shape{s.type(), lens};
}
template <class T>
static shape get_output_shape(const shape& s, const std::vector<T>& axes)
{
auto lens = s.lens();
for(const auto& axis : axes)
lens[axis] = 1;
return shape{s.type(), lens};
}
template <class ReduceLens>
static std::string get_reduce_algo(const std::vector<shape>& inputs, ReduceLens rlens)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
const auto init = std::numeric_limits<std::size_t>::max();
// The minimum stride
auto min_stride = std::inner_product(
......@@ -103,11 +113,27 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs)
return "block";
}
struct reduce_compiler : compiler<reduce_compiler>
static std::string get_reduce_algo(const std::vector<shape>& inputs)
{
auto rlens = get_reduce_lens(inputs.front().lens(), inputs.back().lens());
return get_reduce_algo(inputs, rlens);
}
struct simple_reduce_compiler : compiler<simple_reduce_compiler>
{
std::vector<std::string> names() const
{
return {"reduce", "reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"};
return {"simple_reduce",
"reduce_sum",
"reduce_mean",
"reduce_max",
"reduce_min",
"reduce_prod"};
}
static std::size_t get_reduce_elements(const std::vector<shape>& inputs)
{
return inputs.front().elements() / inputs.back().elements();
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
......@@ -157,44 +183,108 @@ struct reduce_compiler : compiler<reduce_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
value v = value::object{};
if(op.name() == "reduce_sum")
{
v["reduction"] = "op::sum{}";
}
else if(op.name() == "reduce_mean")
{
auto reduce_elements = get_reduce_elements(ins->inputs());
auto reduce_type = ins->inputs().front()->get_shape().type();
v["reduction"] = "op::sum{}";
std::string mean = "op::mean<" + std::to_string(reduce_elements) + ">{}";
// Use float accumulator when reduction size is too large for half
if(reduce_type == shape::half_type and reduce_elements > 16384)
v["read"] = "compose(" + mean + ", op::convert_to<float>{})";
else if(contains({shape::float_type, shape::half_type, shape::double_type},
reduce_type))
v["read"] = mean;
else
v["write"] = mean;
reduce_op r{};
r.set(ins, op);
v["reduction"] = r.reduction;
v["read"] = r.read;
v["write"] = r.write;
v["init"] = r.init;
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
else if(op.name() == "reduce_max")
};
static const char* const fused_reduce_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
});
}
}
} // namespace migraphx
)__migraphx__";
struct fused_reduce_compiler : compiler<fused_reduce_compiler>
{
std::vector<std::string> names() const { return {"fused_reduce"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
v["reduction"] = "op::max{}";
v["init"] = "lowest{}";
}
else if(op.name() == "reduce_min")
auto axes = v.at("axes").to_vector<std::size_t>();
auto virtual_inputs = inputs;
virtual_inputs.push_back(get_reduced_shape(inputs.front(), axes));
virtual_inputs.push_back(get_output_shape(inputs.front(), axes));
virtual_inputs = reduce_dims(virtual_inputs);
auto reduce_output_shape = virtual_inputs.back();
virtual_inputs.pop_back();
auto reduction_shape = virtual_inputs.back();
virtual_inputs.pop_back();
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.virtual_inputs = virtual_inputs;
auto faxis = find_fast_axis({options.virtual_inputs.front()});
vectorize vec{};
auto nelements = reduce_output_shape.elements();
auto algo = v.get("algo", get_reduce_algo(options.virtual_inputs, reduction_shape.lens()));
if(algo == "block")
{
v["reduction"] = "op::min{}";
v["init"] = "highest{}";
// Vectorize if the axis is a reduction axis
if(reduce_output_shape.lens()[faxis] == 1)
vec = vectorize::elements(ctx, faxis, options.virtual_inputs);
auto relements = reduction_shape.elements() / vec.size;
auto block_size = compute_block_size(relements, 256);
if(relements >= block_size * 256)
algo = "block_large";
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
}
else if(op.name() == "reduce_prod")
else if(algo == "lane")
{
v["reduction"] = "op::product{}";
v["init"] = "1";
options.set_launch_params(v, compute_global_for(ctx, nelements, 256));
}
else
{
MIGRAPHX_THROW("Unsupported reduce");
MIGRAPHX_THROW("Unknown reduce algo: " + algo);
}
options.kernel_name = v.get("kernel", "reduce_kernel");
auto src = interpolate_string(
fused_reduce_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"algo", algo},
{"reduced", "decltype(" + generate_make_shape(reduce_output_shape) + ")"},
{"lambda", v.at("lambda").to<std::string>()},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})}});
options.params += "-Wno-float-equal";
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
assert(not ins->module_inputs().empty());
auto v = op.to_value();
auto* rm = ins->module_inputs().front();
v["preamble"] = generate_reduce(*rm, "fused_reduce_op");
v["lambda"] = "MIGRAPHX_LIFT(fused_reduce_op)";
v["kernel"] = generate_name_from_ops(*rm) + "_kernel";
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
......
......@@ -195,6 +195,14 @@ constexpr auto compose(Fs... fs)
})(fs...);
}
template <class F>
constexpr auto partial(F f)
{
return [=](auto... xs) {
return [=](auto&&... ys) { return f(xs..., static_cast<decltype(ys)>(ys)...); };
};
}
template <class... Ts>
constexpr auto pack(Ts... xs)
{
......
......@@ -233,6 +233,12 @@ struct index
}
};
#ifdef MIGRAPHX_NLOCAL
#define MIGRAPHX_GLOBAL \
__global__ __attribute__((amdgpu_flat_work_group_size(MIGRAPHX_NLOCAL, MIGRAPHX_NLOCAL)))
#else
#define MIGRAPHX_GLOBAL __global__
#endif
inline __device__ __attribute__((const)) index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
......
......@@ -174,6 +174,25 @@ struct inner_storage_tag
template <class T>
using is_inner_storage = is_base_of<inner_storage_tag, remove_cv_t<remove_reference_t<T>>>;
template <class Size, class F>
struct lazy_inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
constexpr lazy_inner_storage<Size, F> make_lazy_inner_storage(Size, F f)
{
return {{}, f};
}
template <class R, class F>
struct storage_access : F
{
......@@ -278,6 +297,14 @@ struct reducer_base
});
}
template <class F>
__device__ auto lazy_inner(F f) const
{
return this->inner_sliced([=](auto n, auto&&... xs) {
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
});
}
template <class Op, class T, class Read>
__device__ auto reduce(Op op, T init, Read read) const
{
......@@ -396,25 +423,6 @@ struct block_large
index idx;
Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class... Ts>
__device__ auto reduce_impl(Op op, T init, Read read, N n, Ts&&... xs) const
{
......@@ -439,7 +447,7 @@ struct block_large
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
......@@ -469,25 +477,6 @@ struct lane
index idx;
Slicer slice;
template <class Size, class F>
struct inner_storage : inner_storage_tag
{
using type = remove_reference_t<decltype(declval<F>()(0, _c<0>))>;
F f;
constexpr Size rsize() const { return {}; }
template <class U, class V>
constexpr auto operator()(U j, V d) const
{
return f(j, d);
}
};
template <class Size, class F>
static constexpr inner_storage<Size, F> make_inner_storage(Size, F f)
{
return {{}, {f}};
}
template <class Op, class T, class Read, class N, class U, class... Us>
__device__ auto reduce_impl(Op op, T init, Read read, N n, U&& x, Us&&... xs) const
{
......@@ -518,7 +507,7 @@ struct lane
template <class R, class F, class N, class... Ts>
__device__ auto inner_impl(F f, N n, Ts&&... xs) const
{
return make_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
return make_lazy_inner_storage(n, [=](auto j, auto d) { return f(xs(j, d)...); });
}
};
template <class Slicer>
......@@ -577,5 +566,21 @@ simple_reduce(Op op, T init, Input input, Output output, ReadInput read, WriteOu
});
}
template <class Algo, class Reduced, class Output, class F>
__device__ void fused_reduce(Output output, F f)
{
Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r);
if constexpr(reduce::is_inner_storage<decltype(result)>{})
{
r.inner([&](auto& y, auto x) { y = x; })(output, result);
}
else
{
r.outer([&] { output[out_idx] = implicit_conversion(result); });
}
});
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_REDUCE_HPP
......@@ -32,6 +32,7 @@
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/inline_module.hpp>
#include <migraphx/insert_pad.hpp>
#include <migraphx/layout_nhwc.hpp>
......@@ -72,6 +73,7 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass
{
......@@ -129,6 +131,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
optimize_module{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{},
fuse_mlir{&ctx},
dead_code_elimination{},
lowering{&ctx, options.offload_copy},
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_reduce.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <basic_ops.hpp>
#include <migraphx/make_op.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::fuse_reduce{}, migraphx::dead_code_elimination{}});
}
bool all_instructions_are_local(const migraphx::module& m)
{
return std::all_of(m.begin(), m.end(), [&](const auto& ins) {
return std::all_of(ins.inputs().begin(), ins.inputs().end(), [&](auto input) {
return m.has_instruction(input);
});
});
}
template <class F>
migraphx::instruction_ref add_reduce(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
const std::vector<int64_t>& axes,
F f)
{
auto* rm = p.create_module(name);
auto* mm = p.get_main_module();
rm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
return rm->add_parameter(
"x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type(), input->get_shape().lens()});
});
auto r = f(rm, params, axes);
rm->add_return({r});
EXPECT(all_instructions_are_local(*rm));
return mm->add_instruction(migraphx::make_op("fused_reduce", {{"axes", axes}}), inputs, {rm});
}
inline auto single_reduce(const std::string& name)
{
return [=](auto* rm, const auto& inputs, const auto& axes) {
return rm->add_instruction(migraphx::make_op(name, {{"axes", axes}}), inputs);
};
}
TEST_CASE(single)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), y);
mm->add_return({rsum1, rsum2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsum2 = add_reduce(p2, "main:reduce_sum1", {y}, {1}, single_reduce("reduce_sum"));
mm->add_return({rsum1, rsum2});
}
EXPECT(p1 == p2);
}
TEST_CASE(pointwise_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add);
mm->add_return({rsum});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum = add_reduce(
p2,
"main:pointwise0:main:reduce_sum0",
{x, y},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto add =
add_pointwise(p2, rm, "main:pointwise0", inputs, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add);
});
mm->add_return({rsum});
}
EXPECT(p1 == p2);
}
TEST_CASE(reduce_pointwise)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto add = add_pointwise(p1, "main:pointwise0", {rsumb, y}, single_pointwise("add"));
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_reduce(
p2,
"main:reduce_sum0:main:pointwise0",
{x, y},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
return add_pointwise(
p2, rm, "main:pointwise0", {rsumb, inputs[1]}, single_pointwise("add"));
});
mm->add_return({add});
}
EXPECT(p1 == p2);
}
TEST_CASE(reduce_reduce)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto rsumdiff = add_pointwise(p1, "main:pointwise0", {rsumb, x}, single_pointwise("sub"));
auto rsum2 =
mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), rsumdiff);
auto sqrt = add_pointwise(p1, "main:pointwise1", {rsum2}, single_pointwise("sqrt"));
mm->add_return({sqrt});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto sqrt = add_reduce(
p2,
"main:reduce_sum1:main:reduce_sum0:main:pointwise0:main:pointwise1",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum);
auto rsumdiff = add_pointwise(
p2, rm, "main:pointwise0", {rsumb, inputs[0]}, single_pointwise("sub"));
auto rsum2 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
rsumdiff);
return add_pointwise(p2, rm, "main:pointwise1", {rsum2}, single_pointwise("sqrt"));
});
mm->add_return({sqrt});
}
EXPECT(p1 == p2);
}
TEST_CASE(reduce_reduce_mismatch_axis)
{
migraphx::shape s{migraphx::shape::float_type, {4, 2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), rsum1);
mm->add_return({rsum2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = add_reduce(p2, "main:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsum2 = add_reduce(p2, "main:reduce_sum1", {rsum1}, {2}, single_reduce("reduce_sum"));
mm->add_return({rsum2});
}
EXPECT(p1 == p2);
}
TEST_CASE(pointwise_reduce_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), x);
auto sqrt = add_pointwise(p1, "main:pointwise0", {rsum1}, single_pointwise("sqrt"));
auto sqrtb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt);
auto add1 = add_pointwise(p1, "main:pointwise1", {sqrtb, x}, single_pointwise("add"));
auto rsum2 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {1}}}), add1);
auto add2 = add_pointwise(p1, "main:pointwise2", {rsum2, rsum1}, single_pointwise("add"));
mm->add_return({add2});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto add2 = add_reduce(
p2,
"main:pointwise0:main:pointwise1:main:reduce_sum1:main:pointwise2:main:reduce_sum0",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto sqrt =
add_pointwise(p2, rm, "main:pointwise0", {rsum1}, single_pointwise("sqrt"));
auto sqrtb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), sqrt);
auto add1 = add_pointwise(
p2, rm, "main:pointwise1", {sqrtb, inputs[0]}, single_pointwise("add"));
auto rsum2 =
rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add1);
return add_pointwise(
p2, rm, "main:pointwise2", {rsum2, rsum1}, single_pointwise("add"));
});
mm->add_return({add2});
}
EXPECT(p1 == p2);
}
TEST_CASE(reduce_reduce_broadcast)
{
migraphx::shape s{migraphx::shape::float_type, {4, 2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum1 = add_reduce(p1, "test:reduce_sum0", {x}, {1}, single_reduce("reduce_sum"));
auto rsumb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1);
auto add = add_reduce(
p1,
"test:reduce_sum1",
{rsumb, x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto add2 =
add_pointwise(p1, rm, "test:pointwise0", inputs, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add2);
});
mm->add_return({add});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto rsum = add_reduce(
p2,
"test:reduce_sum1:test:reduce_sum0",
{x},
{1},
[&](auto* rm, const auto& inputs, const auto& axes) {
auto rsum1 = rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}),
inputs[0]);
auto rsumb = rm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), rsum1);
auto add = add_pointwise(
p2, rm, "test:pointwise0", {rsumb, inputs[0]}, single_pointwise("add"));
return rm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", axes}}), add);
});
mm->add_return({rsum});
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -30,12 +30,12 @@
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
migraphx::module_ref mm,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = p.create_module(name);
auto* mm = p.get_main_module();
pm->set_bypass();
std::vector<migraphx::instruction_ref> params;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(params), [&](auto input) {
......@@ -47,6 +47,15 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
}
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
return add_pointwise(p, p.get_main_module(), name, inputs, f);
}
inline auto single_pointwise(const std::string& name)
{
return [=](auto* pm, const auto& inputs) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment