Commit 5ee35bf0 authored by Paul's avatar Paul
Browse files

Format

parent d1dc3e35
...@@ -27,11 +27,11 @@ struct fused_concat ...@@ -27,11 +27,11 @@ struct fused_concat
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{ {
check_shapes{inputs, *this}.same_ndims(); check_shapes{inputs, *this}.same_ndims();
if ((inputs.size() + 1) == mods.size()) if((inputs.size() + 1) == mods.size())
MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules"); MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");
auto input_iter = inputs.begin(); auto input_iter = inputs.begin();
std::vector<shape> concat_inputs; std::vector<shape> concat_inputs;
for(module_ref mod:range(mods.begin(), mods.end()-1)) for(module_ref mod : range(mods.begin(), mods.end() - 1))
{ {
concat_inputs.push_back(*input_iter); concat_inputs.push_back(*input_iter);
input_iter += mod->get_parameter_names().size(); input_iter += mod->get_parameter_names().size();
...@@ -39,14 +39,22 @@ struct fused_concat ...@@ -39,14 +39,22 @@ struct fused_concat
module_ref post_mod = mods.back(); module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type(); auto type = std::prev(post_mod->end())->get_shape().type();
const auto& first_shape_lens = concat_inputs.front().lens(); const auto& first_shape_lens = concat_inputs.front().lens();
if(not std::all_of(concat_inputs.begin()+1, concat_inputs.end(), [&](auto s) { if(not std::all_of(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens(); const auto& lens = s.lens();
return std::equal(lens.begin(), lens.begin()+axis, first_shape_lens.begin(), first_shape_lens.begin()+axis) and return std::equal(lens.begin(),
std::equal(lens.begin()+axis+1, lens.end(), first_shape_lens.begin()+axis+1, first_shape_lens.end()); lens.begin() + axis,
first_shape_lens.begin(),
first_shape_lens.begin() + axis) and
std::equal(lens.begin() + axis + 1,
lens.end(),
first_shape_lens.begin() + axis + 1,
first_shape_lens.end());
})) }))
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " + std::to_string(axis)); MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " +
std::to_string(axis));
std::size_t new_dim_axis = transform_accumulate(concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) { std::size_t new_dim_axis = transform_accumulate(
concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
return input.lens()[axis]; return input.lens()[axis];
}); });
auto new_lens = concat_inputs.front().lens(); auto new_lens = concat_inputs.front().lens();
...@@ -63,7 +71,9 @@ struct find_pointwise_concat_pointwise ...@@ -63,7 +71,9 @@ struct find_pointwise_concat_pointwise
{ {
auto matcher() const auto matcher() const
{ {
auto concat = match::name("concat")(match::used_once(), match::any_of[match::inputs()](match::name("pointwise")(match::used_once()))); auto concat = match::name("concat")(
match::used_once(),
match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat"))); return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
} }
...@@ -72,17 +82,22 @@ struct find_pointwise_concat_pointwise ...@@ -72,17 +82,22 @@ struct find_pointwise_concat_pointwise
auto ins = r.result; auto ins = r.result;
auto concat_ins = r.instructions["concat"]; auto concat_ins = r.instructions["concat"];
auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) - ins->inputs().begin(); auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
ins->inputs().begin();
std::vector<instruction_ref> inputs; std::vector<instruction_ref> inputs;
for(auto input:concat_ins->inputs()) for(auto input : concat_ins->inputs())
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end()); inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto input) { std::copy_if(ins->inputs().begin(),
return input != concat_ins; ins->inputs().end(),
}); std::back_inserter(inputs),
[&](auto input) { return input != concat_ins; });
std::vector<module_ref> module_inputs; std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(), concat_ins->inputs().end(), std::back_inserter(module_inputs), [&](instruction_ref input) { std::transform(concat_ins->inputs().begin(),
if (input->name() == "pointwise") concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise")
{ {
auto* pm = input->module_inputs().front(); auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm); return mpm.create_module("concat:" + pm->name(), *pm);
...@@ -107,11 +122,15 @@ struct find_pointwise_concat_pointwise ...@@ -107,11 +122,15 @@ struct find_pointwise_concat_pointwise
module_inputs.push_back(rm); module_inputs.push_back(rm);
mpm.get_module().replace_instruction(ins, make_op("fused_concat", concat_ins->normalized_operator().to_value()), inputs, module_inputs); mpm.get_module().replace_instruction(
ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
} }
}; };
} } // namespace
void fuse_concat::apply(module_pass_manager& mpm) const void fuse_concat::apply(module_pass_manager& mpm) const
{ {
......
...@@ -240,6 +240,7 @@ struct MIGRAPHX_EXPORT module ...@@ -240,6 +240,7 @@ struct MIGRAPHX_EXPORT module
friend bool operator!=(const module& x, const module& y) { return not(x == y); } friend bool operator!=(const module& x, const module& y) { return not(x == y); }
friend struct program; friend struct program;
private: private:
void set_name(const std::string& name); void set_name(const std::string& name);
void assign(const module& m); void assign(const module& m);
......
...@@ -134,10 +134,7 @@ module& module::operator=(module m) ...@@ -134,10 +134,7 @@ module& module::operator=(module m)
std::string module::name() const { return impl->name; } std::string module::name() const { return impl->name; }
void module::set_name(const std::string& name) void module::set_name(const std::string& name) { impl->name = name; }
{
impl->name = name;
}
bool module::bypass() const { return impl->bypass; } bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; } void module::set_bypass(bool b) { impl->bypass = b; }
......
...@@ -61,7 +61,6 @@ MIGRAPHX_GLOBAL void ${kernel}(${params}) ...@@ -61,7 +61,6 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__"; )__migraphx__";
struct concat_compiler : compiler<concat_compiler> struct concat_compiler : compiler<concat_compiler>
{ {
std::vector<std::string> names() const { return {"concat"}; } std::vector<std::string> names() const { return {"concat"}; }
...@@ -158,11 +157,10 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -158,11 +157,10 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
if(axis != v.at("axis").to<std::size_t>()) if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs); vec = vectorize::elements(ctx, axis, options.inputs);
auto nelements_per_op = options.inputs.back().elements() / op_names.size(); auto nelements_per_op = options.inputs.back().elements() / op_names.size();
options.set_launch_params( options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
std::vector<std::string> concat_params; std::vector<std::string> concat_params;
std::vector<std::string> concat_args; std::vector<std::string> concat_args;
for(const auto& name:op_names) for(const auto& name : op_names)
{ {
auto n = args.at(name).to<std::size_t>(); auto n = args.at(name).to<std::size_t>();
auto prefix = name + "_concat_x"; auto prefix = name + "_concat_x";
...@@ -175,8 +173,7 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -175,8 +173,7 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
}); });
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")"); concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
} }
auto src = interpolate_string( auto src = interpolate_string(fused_concat_kernel,
fused_concat_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
...@@ -193,26 +190,43 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -193,26 +190,43 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
{ {
auto v = op.to_value(); auto v = op.to_value();
std::unordered_map<std::string, std::string> mod_names_lookup; std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()), std::inserter(mod_names_lookup, mod_names_lookup.end()), [&](auto i) { transform(range(ins->module_inputs().size()),
return std::make_pair(ins->module_inputs()[i]->name(), "pointwise" + std::to_string(i)); std::inserter(mod_names_lookup, mod_names_lookup.end()),
[&](auto i) {
return std::make_pair(ins->module_inputs()[i]->name(),
"pointwise" + std::to_string(i));
}); });
v["preamble"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end(), std::string{}, std::plus<>{}, [&](module_ref mod) { v["preamble"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end(),
std::string{},
std::plus<>{},
[&](module_ref mod) {
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n"; return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
}); });
std::vector<std::string> mod_names; std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::back_inserter(mod_names), [&](module_ref mod) { std::transform(ins->module_inputs().begin(),
return mod_names_lookup.at(mod->name()); ins->module_inputs().end(),
}); std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names; v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args; std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::inserter(mod_args, mod_args.end()), [&](module_ref mod) { std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(),
std::inserter(mod_args, mod_args.end()),
[&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name()); const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size()); return std::make_pair(name, mod->get_parameter_names().size());
}); });
v["args"] = mod_args; v["args"] = mod_args;
v["kernel"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end()-1, std::string{}, std::plus<>{}, [&](module_ref mod) { v["kernel"] = transform_accumulate(
return generate_name_from_ops(*mod) + "_"; ins->module_inputs().begin(),
}) + "concat_" + generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel"; ins->module_inputs().end() - 1,
std::string{},
std::plus<>{},
[&](module_ref mod) { return generate_name_from_ops(*mod) + "_"; }) +
"concat_" + generate_name_from_ops(*(ins->module_inputs().back())) +
"_kernel";
return compile_op(ctx, to_shapes(ins->inputs()), v); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
......
...@@ -37,7 +37,7 @@ void run_pass(migraphx::program& p) ...@@ -37,7 +37,7 @@ void run_pass(migraphx::program& p)
migraphx::run_passes(p, {migraphx::fuse_concat{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(p, {migraphx::fuse_concat{}, migraphx::dead_code_elimination{}});
} }
template<class F> template <class F>
struct concat_arg struct concat_arg
{ {
std::string name; std::string name;
...@@ -45,35 +45,40 @@ struct concat_arg ...@@ -45,35 +45,40 @@ struct concat_arg
F f; F f;
}; };
template<class F> template <class F>
concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> inputs, F f) concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> inputs, F f)
{ {
return {std::move(name), std::move(inputs), std::move(f)}; return {std::move(name), std::move(inputs), std::move(f)};
} }
template <class Arg, class... Args> template <class Arg, class... Args>
migraphx::instruction_ref add_concat(migraphx::program& p, migraphx::instruction_ref
std::size_t axis, add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
Arg post_arg,
Args... args)
{ {
std::vector<migraphx::module_ref> module_inputs; std::vector<migraphx::module_ref> module_inputs;
std::vector<migraphx::instruction_ref> ins_inputs; std::vector<migraphx::instruction_ref> ins_inputs;
migraphx::each_args([&](auto arg) { migraphx::each_args(
[&](auto arg) {
module_inputs.push_back(create_pointwise_module(p, arg.name, arg.inputs, arg.f)); module_inputs.push_back(create_pointwise_module(p, arg.name, arg.inputs, arg.f));
ins_inputs.insert(ins_inputs.end(), arg.inputs.begin(), arg.inputs.end()); ins_inputs.insert(ins_inputs.end(), arg.inputs.begin(), arg.inputs.end());
}, args...); },
args...);
module_inputs.push_back(create_pointwise_module(p, post_arg.name, {}, [&](auto* pm, auto&&) { module_inputs.push_back(create_pointwise_module(p, post_arg.name, {}, [&](auto* pm, auto&&) {
std::vector<migraphx::instruction_ref> params; std::vector<migraphx::instruction_ref> params;
params.push_back(pm->add_parameter("!x0", migraphx::shape{ins_inputs.back()->get_shape().type()})); params.push_back(
std::transform(post_arg.inputs.begin(), post_arg.inputs.end(), std::back_inserter(params), [&](auto input) { pm->add_parameter("!x0", migraphx::shape{ins_inputs.back()->get_shape().type()}));
std::transform(post_arg.inputs.begin(),
post_arg.inputs.end(),
std::back_inserter(params),
[&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()), return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()}); migraphx::shape{input->get_shape().type()});
}); });
return post_arg.f(pm, params); return post_arg.f(pm, params);
})); }));
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
return mm->add_instruction(migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs); return mm->add_instruction(
migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
} }
TEST_CASE(simple_pointwise_concat) TEST_CASE(simple_pointwise_concat)
...@@ -96,7 +101,12 @@ TEST_CASE(simple_pointwise_concat) ...@@ -96,7 +101,12 @@ TEST_CASE(simple_pointwise_concat)
auto* mm = p2.get_main_module(); auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto fused_concat = add_concat(p2, 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat}); mm->add_return({fused_concat});
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment