Commit 5ee35bf0 authored by Paul
Browse files

Format

parent d1dc3e35
...@@ -27,29 +27,37 @@ struct fused_concat ...@@ -27,29 +27,37 @@ struct fused_concat
/// Computes the output shape of the fused concat instruction.
///
/// `mods` holds one module per concatenated operand followed by the post
/// module as its last entry. `inputs` is consumed in module order: each operand
/// module takes as many entries as it has parameters, and the first entry of
/// each group is used as that operand's shape for the concat.
///
/// Throws when modules or their inputs are missing, or when the operand shapes
/// differ on any dimension other than `axis`.
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
    check_shapes{inputs, *this}.same_ndims();
    // At least one operand module plus the post module are required; without
    // this guard `mods.end() - 1` and `concat_inputs.front()` below are invalid.
    if(mods.size() < 2)
        MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");
    // Every operand module reads at least one input, so at least
    // `mods.size() - 1` inputs are needed. Having exactly that many is valid
    // (single-input operand modules and a post module without extra inputs),
    // so only a shortfall is an error.
    if((inputs.size() + 1) < mods.size())
        MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");
    auto input_iter = inputs.begin();
    std::vector<shape> concat_inputs;
    for(module_ref mod : range(mods.begin(), mods.end() - 1))
    {
        const std::size_t nparams = mod->get_parameter_names().size();
        const auto remaining =
            static_cast<std::size_t>(std::distance(input_iter, inputs.end()));
        // Never dereference or advance the iterator beyond the inputs.
        if(remaining == 0 or remaining < nparams)
            MIGRAPHX_THROW("FUSED_CONCAT: Missing inputs for fused modules");
        concat_inputs.push_back(*input_iter);
        input_iter += nparams;
    }
    module_ref post_mod = mods.back();
    // The output type is taken from the last instruction of the post module.
    auto type                    = std::prev(post_mod->end())->get_shape().type();
    const auto& first_shape_lens = concat_inputs.front().lens();
    // All operands must agree with the first one on every dimension except `axis`.
    auto matches_off_axis = [&](const auto& s) {
        const auto& lens = s.lens();
        return std::equal(lens.begin(),
                          lens.begin() + axis,
                          first_shape_lens.begin(),
                          first_shape_lens.begin() + axis) and
               std::equal(lens.begin() + axis + 1,
                          lens.end(),
                          first_shape_lens.begin() + axis + 1,
                          first_shape_lens.end());
    };
    if(not std::all_of(concat_inputs.begin() + 1, concat_inputs.end(), matches_off_axis))
        MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " +
                       std::to_string(axis));
    // The concatenated dimension is the sum of the operands' sizes along `axis`.
    std::size_t new_dim_axis = transform_accumulate(
        concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
            return input.lens()[axis];
        });
    auto new_lens  = concat_inputs.front().lens();
    new_lens[axis] = new_dim_axis;
    return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
...@@ -63,55 +71,66 @@ struct find_pointwise_concat_pointwise ...@@ -63,55 +71,66 @@ struct find_pointwise_concat_pointwise
{ {
auto matcher() const auto matcher() const
{ {
auto concat = match::name("concat")(match::used_once(), match::any_of[match::inputs()](match::name("pointwise")(match::used_once()))); auto concat = match::name("concat")(
match::used_once(),
match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat"))); return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
} }
/// Replaces the matched `pointwise(concat(...))` with one "fused_concat"
/// instruction.
///
/// The new instruction receives one module per concat operand followed by a
/// copy of the trailing pointwise module. Its arguments are the operands'
/// arguments in operand order, followed by the trailing pointwise's arguments
/// other than the concat itself.
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
    auto ins        = r.result;
    auto concat_ins = r.instructions["concat"];
    // Position of the concat among the arguments of the trailing pointwise.
    auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
                      ins->inputs().begin();
    std::vector<instruction_ref> inputs;
    for(auto input : concat_ins->inputs())
    {
        if(input->name() == "pointwise")
        {
            // A fused pointwise operand contributes its own arguments.
            inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
        }
        else
        {
            // Any other operand is wrapped in a one-parameter identity module
            // below, so the operand itself is the argument for that module.
            inputs.push_back(input);
        }
    }
    std::copy_if(ins->inputs().begin(),
                 ins->inputs().end(),
                 std::back_inserter(inputs),
                 [&](auto input) { return input != concat_ins; });
    std::vector<module_ref> module_inputs;
    std::transform(concat_ins->inputs().begin(),
                   concat_ins->inputs().end(),
                   std::back_inserter(module_inputs),
                   [&](instruction_ref input) {
                       if(input->name() == "pointwise")
                       {
                           // Copy the operand's pointwise module under a new name.
                           auto* pm = input->module_inputs().front();
                           return mpm.create_module("concat:" + pm->name(), *pm);
                       }
                       // Non-pointwise operand: generate an identity module with a
                       // single scalar parameter of the operand's type.
                       auto* pm = mpm.create_module("concat" + std::to_string(counter++));

                       auto x  = pm->add_parameter("x", shape{input->get_shape().type()});
                       auto id = pm->add_instruction(make_op("identity"), x);
                       pm->add_return({id});
                       return pm;
                   });

    auto* post_pm = ins->module_inputs().front();
    auto* rm      = mpm.create_module(post_pm->name() + ":concat", *post_pm);
    // NOTE(review): presumably the sorted parameter names line up with the
    // argument order of `ins`, so `concat_arg` selects the concat's parameter;
    // confirm against how pointwise modules name their parameters.
    std::vector<std::string> names = rm->get_parameter_names();
    std::sort(names.begin(), names.end());
    auto concat_param_name = names[concat_arg];
    auto concat_param      = rm->get_parameter(concat_param_name);
    // Swap the concat parameter for one whose name is prefixed with "!".
    // NOTE(review): presumably the prefix makes it sort ahead of the remaining
    // parameters; confirm.
    auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());
    rm->replace_instruction(concat_param, param);
    rm->remove_instruction(concat_param);
    module_inputs.push_back(rm);
    mpm.get_module().replace_instruction(
        ins,
        make_op("fused_concat", concat_ins->normalized_operator().to_value()),
        inputs,
        module_inputs);
}
}; };
} } // namespace
void fuse_concat::apply(module_pass_manager& mpm) const void fuse_concat::apply(module_pass_manager& mpm) const
{ {
......
...@@ -240,6 +240,7 @@ struct MIGRAPHX_EXPORT module ...@@ -240,6 +240,7 @@ struct MIGRAPHX_EXPORT module
/// Inequality is the negation of operator== on the same two modules.
friend bool operator!=(const module& x, const module& y)
{
    const bool same = (x == y);
    return not same;
}
friend struct program; friend struct program;
private: private:
void set_name(const std::string& name); void set_name(const std::string& name);
void assign(const module& m); void assign(const module& m);
......
...@@ -134,10 +134,7 @@ module& module::operator=(module m) ...@@ -134,10 +134,7 @@ module& module::operator=(module m)
// Returns a copy of the name stored in the implementation object.
std::string module::name() const
{
    return impl->name;
}

// Overwrites the name stored in the implementation object.
void module::set_name(const std::string& name)
{
    impl->name = name;
}

// Returns the bypass flag stored in the implementation object.
bool module::bypass() const
{
    return impl->bypass;
}

// Overwrites the bypass flag stored in the implementation object.
void module::set_bypass(bool b)
{
    impl->bypass = b;
}
......
...@@ -61,7 +61,6 @@ MIGRAPHX_GLOBAL void ${kernel}(${params}) ...@@ -61,7 +61,6 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__"; )__migraphx__";
struct concat_compiler : compiler<concat_compiler> struct concat_compiler : compiler<concat_compiler>
{ {
/// Returns the operator names reported by this compiler: only "concat".
std::vector<std::string> names() const
{
    std::vector<std::string> reported;
    reported.emplace_back("concat");
    return reported;
}
...@@ -152,19 +151,18 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -152,19 +151,18 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel"); options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs); auto axis = find_fast_axis(options.inputs);
auto op_names = v.at("ops").to_vector<std::string>(); auto op_names = v.at("ops").to_vector<std::string>();
auto args = v.at("args"); auto args = v.at("args");
vectorize vec{}; vectorize vec{};
if(axis != v.at("axis").to<std::size_t>()) if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs); vec = vectorize::elements(ctx, axis, options.inputs);
auto nelements_per_op = options.inputs.back().elements() / op_names.size(); auto nelements_per_op = options.inputs.back().elements() / op_names.size();
options.set_launch_params( options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
std::vector<std::string> concat_params; std::vector<std::string> concat_params;
std::vector<std::string> concat_args; std::vector<std::string> concat_args;
for(const auto& name:op_names) for(const auto& name : op_names)
{ {
auto n = args.at(name).to<std::size_t>(); auto n = args.at(name).to<std::size_t>();
auto prefix = name + "_concat_x"; auto prefix = name + "_concat_x";
transform(range(n), std::back_inserter(concat_params), [&](auto i) { transform(range(n), std::back_inserter(concat_params), [&](auto i) {
return "auto " + prefix + std::to_string(i); return "auto " + prefix + std::to_string(i);
...@@ -175,17 +173,16 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -175,17 +173,16 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
}); });
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")"); concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
} }
auto src = interpolate_string( auto src = interpolate_string(fused_concat_kernel,
fused_concat_kernel, {{"kernel", options.kernel_name},
{{"kernel", options.kernel_name}, {"params", enum_params(inputs.size(), "void * private_p")},
{"params", enum_params(inputs.size(), "void * private_p")}, {"args", enum_params(inputs.size(), "private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"concat_params", join_strings(concat_params, ", ")},
{"concat_params", join_strings(concat_params, ", ")}, {"concat_args", join_strings(concat_args, ", ")},
{"concat_args", join_strings(concat_args, ", ")}, {"post", v.get("post", std::string{"op::id{}"})},
{"post", v.get("post", std::string{"op::id{}"})}, {"transformers", make_transformer_args(vec)},
{"transformers", make_transformer_args(vec)}, {"preamble", v.get("preamble", std::string{})},
{"preamble", v.get("preamble", std::string{})}, {"axis", v.at("axis").to<std::string>()}});
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
...@@ -193,26 +190,43 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -193,26 +190,43 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
{ {
auto v = op.to_value(); auto v = op.to_value();
std::unordered_map<std::string, std::string> mod_names_lookup; std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()), std::inserter(mod_names_lookup, mod_names_lookup.end()), [&](auto i) { transform(range(ins->module_inputs().size()),
return std::make_pair(ins->module_inputs()[i]->name(), "pointwise" + std::to_string(i)); std::inserter(mod_names_lookup, mod_names_lookup.end()),
}); [&](auto i) {
v["preamble"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end(), std::string{}, std::plus<>{}, [&](module_ref mod) { return std::make_pair(ins->module_inputs()[i]->name(),
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n"; "pointwise" + std::to_string(i));
}); });
v["preamble"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end(),
std::string{},
std::plus<>{},
[&](module_ref mod) {
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
});
std::vector<std::string> mod_names; std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::back_inserter(mod_names), [&](module_ref mod) { std::transform(ins->module_inputs().begin(),
return mod_names_lookup.at(mod->name()); ins->module_inputs().end(),
}); std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names; v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args; std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::inserter(mod_args, mod_args.end()), [&](module_ref mod) { std::transform(ins->module_inputs().begin(),
const auto& name = mod_names_lookup.at(mod->name()); ins->module_inputs().end(),
return std::make_pair(name, mod->get_parameter_names().size()); std::inserter(mod_args, mod_args.end()),
}); [&](module_ref mod) {
v["args"] = mod_args; const auto& name = mod_names_lookup.at(mod->name());
v["kernel"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end()-1, std::string{}, std::plus<>{}, [&](module_ref mod) { return std::make_pair(name, mod->get_parameter_names().size());
return generate_name_from_ops(*mod) + "_"; });
}) + "concat_" + generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel"; v["args"] = mod_args;
v["kernel"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::string{},
std::plus<>{},
[&](module_ref mod) { return generate_name_from_ops(*mod) + "_"; }) +
"concat_" + generate_name_from_ops(*(ins->module_inputs().back())) +
"_kernel";
return compile_op(ctx, to_shapes(ins->inputs()), v); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
......
...@@ -37,7 +37,7 @@ void run_pass(migraphx::program& p) ...@@ -37,7 +37,7 @@ void run_pass(migraphx::program& p)
migraphx::run_passes(p, {migraphx::fuse_concat{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(p, {migraphx::fuse_concat{}, migraphx::dead_code_elimination{}});
} }
template<class F> template <class F>
struct concat_arg struct concat_arg
{ {
std::string name; std::string name;
...@@ -45,35 +45,40 @@ struct concat_arg ...@@ -45,35 +45,40 @@ struct concat_arg
F f; F f;
}; };
/// Bundles a module name, the instructions feeding it and a builder callable
/// into a concat_arg; all three are moved into the result.
template <class F>
concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> inputs, F f)
{
    return concat_arg<F>{std::move(name), std::move(inputs), std::move(f)};
}
/// Adds a "fused_concat" instruction with the given axis to the main module
/// of `p`.
///
/// Each entry of `args` yields one pointwise module and contributes its inputs
/// to the instruction's arguments, in order. `post_arg` yields the final
/// module, whose first parameter is named "!x0" and whose remaining parameters
/// are named "x1", "x2", ... for each of `post_arg.inputs`.
template <class Arg, class... Args>
migraphx::instruction_ref
add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
{
    std::vector<migraphx::module_ref> fused_mods;
    std::vector<migraphx::instruction_ref> fused_inputs;

    // One module per operand; its inputs are appended to the argument list.
    auto collect_operand = [&](auto operand) {
        fused_mods.push_back(create_pointwise_module(p, operand.name, operand.inputs, operand.f));
        fused_inputs.insert(fused_inputs.end(), operand.inputs.begin(), operand.inputs.end());
    };
    migraphx::each_args(collect_operand, args...);

    // The post module declares its own parameters instead of taking inputs.
    auto build_post = [&](auto* pm, auto&&) {
        std::vector<migraphx::instruction_ref> params;
        params.push_back(pm->add_parameter(
            "!x0", migraphx::shape{fused_inputs.back()->get_shape().type()}));
        for(auto input : post_arg.inputs)
        {
            // The index is the number of parameters created so far.
            auto param_name = "x" + std::to_string(params.size());
            params.push_back(
                pm->add_parameter(param_name, migraphx::shape{input->get_shape().type()}));
        }
        return post_arg.f(pm, params);
    };
    fused_mods.push_back(create_pointwise_module(p, post_arg.name, {}, build_post));

    auto* mm = p.get_main_module();
    return mm->add_instruction(
        migraphx::make_op("fused_concat", {{"axis", axis}}), fused_inputs, fused_mods);
}
TEST_CASE(simple_pointwise_concat) TEST_CASE(simple_pointwise_concat)
...@@ -81,22 +86,27 @@ TEST_CASE(simple_pointwise_concat) ...@@ -81,22 +86,27 @@ TEST_CASE(simple_pointwise_concat)
migraphx::shape s{migraphx::shape::float_type, {2, 3}}; migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1; migraphx::program p1;
{ {
auto* mm = p1.get_main_module(); auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add")); auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub")); auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub); auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu")); auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
mm->add_return({relu}); mm->add_return({relu});
} }
run_pass(p1); run_pass(p1);
migraphx::program p2; migraphx::program p2;
{ {
auto* mm = p2.get_main_module(); auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s); auto y = mm->add_parameter("y", s);
auto fused_concat = add_concat(p2, 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), arg("concat:main:pointwise1", {x, y}, single_pointwise("sub"))); auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat}); mm->add_return({fused_concat});
} }
EXPECT(p1 == p2); EXPECT(p1 == p2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment