Commit 5ee35bf0 authored by Paul's avatar Paul
Browse files

Format

parent d1dc3e35
......@@ -27,11 +27,11 @@ struct fused_concat
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
if ((inputs.size() + 1) == mods.size())
if((inputs.size() + 1) == mods.size())
MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");
auto input_iter = inputs.begin();
std::vector<shape> concat_inputs;
for(module_ref mod:range(mods.begin(), mods.end()-1))
for(module_ref mod : range(mods.begin(), mods.end() - 1))
{
concat_inputs.push_back(*input_iter);
input_iter += mod->get_parameter_names().size();
......@@ -39,14 +39,22 @@ struct fused_concat
module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type();
const auto& first_shape_lens = concat_inputs.front().lens();
if(not std::all_of(concat_inputs.begin()+1, concat_inputs.end(), [&](auto s) {
if(not std::all_of(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens();
return std::equal(lens.begin(), lens.begin()+axis, first_shape_lens.begin(), first_shape_lens.begin()+axis) and
std::equal(lens.begin()+axis+1, lens.end(), first_shape_lens.begin()+axis+1, first_shape_lens.end());
return std::equal(lens.begin(),
lens.begin() + axis,
first_shape_lens.begin(),
first_shape_lens.begin() + axis) and
std::equal(lens.begin() + axis + 1,
lens.end(),
first_shape_lens.begin() + axis + 1,
first_shape_lens.end());
}))
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " + std::to_string(axis));
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " +
std::to_string(axis));
std::size_t new_dim_axis = transform_accumulate(concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
std::size_t new_dim_axis = transform_accumulate(
concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
return input.lens()[axis];
});
auto new_lens = concat_inputs.front().lens();
......@@ -63,7 +71,9 @@ struct find_pointwise_concat_pointwise
{
auto matcher() const
{
auto concat = match::name("concat")(match::used_once(), match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
auto concat = match::name("concat")(
match::used_once(),
match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
}
......@@ -72,17 +82,22 @@ struct find_pointwise_concat_pointwise
auto ins = r.result;
auto concat_ins = r.instructions["concat"];
auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) - ins->inputs().begin();
auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
ins->inputs().begin();
std::vector<instruction_ref> inputs;
for(auto input:concat_ins->inputs())
for(auto input : concat_ins->inputs())
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto input) {
return input != concat_ins;
});
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != concat_ins; });
std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(), concat_ins->inputs().end(), std::back_inserter(module_inputs), [&](instruction_ref input) {
if (input->name() == "pointwise")
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise")
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
......@@ -107,11 +122,15 @@ struct find_pointwise_concat_pointwise
module_inputs.push_back(rm);
mpm.get_module().replace_instruction(ins, make_op("fused_concat", concat_ins->normalized_operator().to_value()), inputs, module_inputs);
mpm.get_module().replace_instruction(
ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
}
};
}
} // namespace
void fuse_concat::apply(module_pass_manager& mpm) const
{
......
......@@ -240,6 +240,7 @@ struct MIGRAPHX_EXPORT module
friend bool operator!=(const module& x, const module& y) { return not(x == y); }
friend struct program;
private:
void set_name(const std::string& name);
void assign(const module& m);
......
......@@ -134,10 +134,7 @@ module& module::operator=(module m)
std::string module::name() const { return impl->name; }
void module::set_name(const std::string& name)
{
impl->name = name;
}
void module::set_name(const std::string& name) { impl->name = name; }
bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; }
......
......@@ -61,7 +61,6 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__";
struct concat_compiler : compiler<concat_compiler>
{
std::vector<std::string> names() const { return {"concat"}; }
......@@ -158,11 +157,10 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
auto nelements_per_op = options.inputs.back().elements() / op_names.size();
options.set_launch_params(
v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
std::vector<std::string> concat_params;
std::vector<std::string> concat_args;
for(const auto& name:op_names)
for(const auto& name : op_names)
{
auto n = args.at(name).to<std::size_t>();
auto prefix = name + "_concat_x";
......@@ -175,8 +173,7 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
});
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
}
auto src = interpolate_string(
fused_concat_kernel,
auto src = interpolate_string(fused_concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
......@@ -193,26 +190,43 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
{
auto v = op.to_value();
std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()), std::inserter(mod_names_lookup, mod_names_lookup.end()), [&](auto i) {
return std::make_pair(ins->module_inputs()[i]->name(), "pointwise" + std::to_string(i));
transform(range(ins->module_inputs().size()),
std::inserter(mod_names_lookup, mod_names_lookup.end()),
[&](auto i) {
return std::make_pair(ins->module_inputs()[i]->name(),
"pointwise" + std::to_string(i));
});
v["preamble"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end(), std::string{}, std::plus<>{}, [&](module_ref mod) {
v["preamble"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end(),
std::string{},
std::plus<>{},
[&](module_ref mod) {
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
});
std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::back_inserter(mod_names), [&](module_ref mod) {
return mod_names_lookup.at(mod->name());
});
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(),
std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::inserter(mod_args, mod_args.end()), [&](module_ref mod) {
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(),
std::inserter(mod_args, mod_args.end()),
[&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size());
});
v["args"] = mod_args;
v["kernel"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end()-1, std::string{}, std::plus<>{}, [&](module_ref mod) {
return generate_name_from_ops(*mod) + "_";
}) + "concat_" + generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
v["kernel"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::string{},
std::plus<>{},
[&](module_ref mod) { return generate_name_from_ops(*mod) + "_"; }) +
"concat_" + generate_name_from_ops(*(ins->module_inputs().back())) +
"_kernel";
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
......
......@@ -37,7 +37,7 @@ void run_pass(migraphx::program& p)
migraphx::run_passes(p, {migraphx::fuse_concat{}, migraphx::dead_code_elimination{}});
}
template<class F>
template <class F>
struct concat_arg
{
std::string name;
......@@ -45,35 +45,40 @@ struct concat_arg
F f;
};
template<class F>
template <class F>
concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> inputs, F f)
{
return {std::move(name), std::move(inputs), std::move(f)};
}
template <class Arg, class... Args>
migraphx::instruction_ref add_concat(migraphx::program& p,
std::size_t axis,
Arg post_arg,
Args... args)
migraphx::instruction_ref
add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
{
std::vector<migraphx::module_ref> module_inputs;
std::vector<migraphx::instruction_ref> ins_inputs;
migraphx::each_args([&](auto arg) {
migraphx::each_args(
[&](auto arg) {
module_inputs.push_back(create_pointwise_module(p, arg.name, arg.inputs, arg.f));
ins_inputs.insert(ins_inputs.end(), arg.inputs.begin(), arg.inputs.end());
}, args...);
},
args...);
module_inputs.push_back(create_pointwise_module(p, post_arg.name, {}, [&](auto* pm, auto&&) {
std::vector<migraphx::instruction_ref> params;
params.push_back(pm->add_parameter("!x0", migraphx::shape{ins_inputs.back()->get_shape().type()}));
std::transform(post_arg.inputs.begin(), post_arg.inputs.end(), std::back_inserter(params), [&](auto input) {
params.push_back(
pm->add_parameter("!x0", migraphx::shape{ins_inputs.back()->get_shape().type()}));
std::transform(post_arg.inputs.begin(),
post_arg.inputs.end(),
std::back_inserter(params),
[&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
return post_arg.f(pm, params);
}));
auto* mm = p.get_main_module();
return mm->add_instruction(migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
return mm->add_instruction(
migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
}
TEST_CASE(simple_pointwise_concat)
......@@ -96,7 +101,12 @@ TEST_CASE(simple_pointwise_concat)
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat = add_concat(p2, 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment