"vscode:/vscode.git/clone" did not exist on "4df8b7a2cec4f44da220b9a2ae5b40dbcd3ef288"
Commit 5ee35bf0 authored by Paul's avatar Paul
Browse files

Format

parent d1dc3e35
......@@ -27,29 +27,37 @@ struct fused_concat
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
if ((inputs.size() + 1) == mods.size())
if((inputs.size() + 1) == mods.size())
MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");
auto input_iter = inputs.begin();
std::vector<shape> concat_inputs;
for(module_ref mod:range(mods.begin(), mods.end()-1))
for(module_ref mod : range(mods.begin(), mods.end() - 1))
{
concat_inputs.push_back(*input_iter);
input_iter += mod->get_parameter_names().size();
}
module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type();
module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type();
const auto& first_shape_lens = concat_inputs.front().lens();
if(not std::all_of(concat_inputs.begin()+1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens();
return std::equal(lens.begin(), lens.begin()+axis, first_shape_lens.begin(), first_shape_lens.begin()+axis) and
std::equal(lens.begin()+axis+1, lens.end(), first_shape_lens.begin()+axis+1, first_shape_lens.end());
}))
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " + std::to_string(axis));
std::size_t new_dim_axis = transform_accumulate(concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
return input.lens()[axis];
});
auto new_lens = concat_inputs.front().lens();
if(not std::all_of(concat_inputs.begin() + 1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens();
return std::equal(lens.begin(),
lens.begin() + axis,
first_shape_lens.begin(),
first_shape_lens.begin() + axis) and
std::equal(lens.begin() + axis + 1,
lens.end(),
first_shape_lens.begin() + axis + 1,
first_shape_lens.end());
}))
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " +
std::to_string(axis));
std::size_t new_dim_axis = transform_accumulate(
concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
return input.lens()[axis];
});
auto new_lens = concat_inputs.front().lens();
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
......@@ -63,55 +71,66 @@ struct find_pointwise_concat_pointwise
{
auto matcher() const
{
auto concat = match::name("concat")(match::used_once(), match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
auto concat = match::name("concat")(
match::used_once(),
match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto ins = r.result;
auto concat_ins = r.instructions["concat"];
auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) - ins->inputs().begin();
auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) -
ins->inputs().begin();
std::vector<instruction_ref> inputs;
for(auto input:concat_ins->inputs())
for(auto input : concat_ins->inputs())
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto input) {
return input != concat_ins;
});
std::copy_if(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(inputs),
[&](auto input) { return input != concat_ins; });
std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(), concat_ins->inputs().end(), std::back_inserter(module_inputs), [&](instruction_ref input) {
if (input->name() == "pointwise")
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat" + std::to_string(counter++));
auto x = pm->add_parameter("x", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
pm->add_return({id});
return pm;
});
auto* post_pm = ins->module_inputs().front();
auto* rm = mpm.create_module(post_pm->name() + ":concat", *post_pm);
std::transform(concat_ins->inputs().begin(),
concat_ins->inputs().end(),
std::back_inserter(module_inputs),
[&](instruction_ref input) {
if(input->name() == "pointwise")
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat" + std::to_string(counter++));
auto x = pm->add_parameter("x", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
pm->add_return({id});
return pm;
});
auto* post_pm = ins->module_inputs().front();
auto* rm = mpm.create_module(post_pm->name() + ":concat", *post_pm);
std::vector<std::string> names = rm->get_parameter_names();
std::sort(names.begin(), names.end());
auto concat_param_name = names[concat_arg];
auto concat_param = rm->get_parameter(concat_param_name);
auto concat_param = rm->get_parameter(concat_param_name);
auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());
rm->replace_instruction(concat_param, param);
rm->remove_instruction(concat_param);
module_inputs.push_back(rm);
mpm.get_module().replace_instruction(ins, make_op("fused_concat", concat_ins->normalized_operator().to_value()), inputs, module_inputs);
mpm.get_module().replace_instruction(
ins,
make_op("fused_concat", concat_ins->normalized_operator().to_value()),
inputs,
module_inputs);
}
};
}
} // namespace
void fuse_concat::apply(module_pass_manager& mpm) const
{
......
......@@ -240,6 +240,7 @@ struct MIGRAPHX_EXPORT module
friend bool operator!=(const module& x, const module& y) { return not(x == y); }
friend struct program;
private:
void set_name(const std::string& name);
void assign(const module& m);
......
......@@ -134,10 +134,7 @@ module& module::operator=(module m)
std::string module::name() const { return impl->name; }
void module::set_name(const std::string& name)
{
impl->name = name;
}
void module::set_name(const std::string& name) { impl->name = name; }
bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; }
......
......@@ -61,7 +61,6 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__";
struct concat_compiler : compiler<concat_compiler>
{
std::vector<std::string> names() const { return {"concat"}; }
......@@ -152,19 +151,18 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
auto op_names = v.at("ops").to_vector<std::string>();
auto args = v.at("args");
auto op_names = v.at("ops").to_vector<std::string>();
auto args = v.at("args");
vectorize vec{};
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
auto nelements_per_op = options.inputs.back().elements() / op_names.size();
options.set_launch_params(
v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
std::vector<std::string> concat_params;
std::vector<std::string> concat_args;
for(const auto& name:op_names)
for(const auto& name : op_names)
{
auto n = args.at(name).to<std::size_t>();
auto n = args.at(name).to<std::size_t>();
auto prefix = name + "_concat_x";
transform(range(n), std::back_inserter(concat_params), [&](auto i) {
return "auto " + prefix + std::to_string(i);
......@@ -175,17 +173,16 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
});
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
}
auto src = interpolate_string(
fused_concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"concat_params", join_strings(concat_params, ", ")},
{"concat_args", join_strings(concat_args, ", ")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
auto src = interpolate_string(fused_concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"concat_params", join_strings(concat_params, ", ")},
{"concat_args", join_strings(concat_args, ", ")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
......@@ -193,26 +190,43 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
{
auto v = op.to_value();
std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()), std::inserter(mod_names_lookup, mod_names_lookup.end()), [&](auto i) {
return std::make_pair(ins->module_inputs()[i]->name(), "pointwise" + std::to_string(i));
});
v["preamble"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end(), std::string{}, std::plus<>{}, [&](module_ref mod) {
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
});
transform(range(ins->module_inputs().size()),
std::inserter(mod_names_lookup, mod_names_lookup.end()),
[&](auto i) {
return std::make_pair(ins->module_inputs()[i]->name(),
"pointwise" + std::to_string(i));
});
v["preamble"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end(),
std::string{},
std::plus<>{},
[&](module_ref mod) {
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
});
std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::back_inserter(mod_names), [&](module_ref mod) {
return mod_names_lookup.at(mod->name());
});
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(),
std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::inserter(mod_args, mod_args.end()), [&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size());
});
v["args"] = mod_args;
v["kernel"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end()-1, std::string{}, std::plus<>{}, [&](module_ref mod) {
return generate_name_from_ops(*mod) + "_";
}) + "concat_" + generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end(),
std::inserter(mod_args, mod_args.end()),
[&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size());
});
v["args"] = mod_args;
v["kernel"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::string{},
std::plus<>{},
[&](module_ref mod) { return generate_name_from_ops(*mod) + "_"; }) +
"concat_" + generate_name_from_ops(*(ins->module_inputs().back())) +
"_kernel";
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
......
......@@ -37,7 +37,7 @@ void run_pass(migraphx::program& p)
migraphx::run_passes(p, {migraphx::fuse_concat{}, migraphx::dead_code_elimination{}});
}
template<class F>
template <class F>
struct concat_arg
{
std::string name;
......@@ -45,35 +45,40 @@ struct concat_arg
F f;
};
template<class F>
template <class F>
concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> inputs, F f)
{
return {std::move(name), std::move(inputs), std::move(f)};
}
template <class Arg, class... Args>
migraphx::instruction_ref add_concat(migraphx::program& p,
std::size_t axis,
Arg post_arg,
Args... args)
migraphx::instruction_ref
add_concat(migraphx::program& p, std::size_t axis, Arg post_arg, Args... args)
{
std::vector<migraphx::module_ref> module_inputs;
std::vector<migraphx::instruction_ref> ins_inputs;
migraphx::each_args([&](auto arg) {
module_inputs.push_back(create_pointwise_module(p, arg.name, arg.inputs, arg.f));
ins_inputs.insert(ins_inputs.end(), arg.inputs.begin(), arg.inputs.end());
}, args...);
migraphx::each_args(
[&](auto arg) {
module_inputs.push_back(create_pointwise_module(p, arg.name, arg.inputs, arg.f));
ins_inputs.insert(ins_inputs.end(), arg.inputs.begin(), arg.inputs.end());
},
args...);
module_inputs.push_back(create_pointwise_module(p, post_arg.name, {}, [&](auto* pm, auto&&) {
std::vector<migraphx::instruction_ref> params;
params.push_back(pm->add_parameter("!x0", migraphx::shape{ins_inputs.back()->get_shape().type()}));
std::transform(post_arg.inputs.begin(), post_arg.inputs.end(), std::back_inserter(params), [&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
params.push_back(
pm->add_parameter("!x0", migraphx::shape{ins_inputs.back()->get_shape().type()}));
std::transform(post_arg.inputs.begin(),
post_arg.inputs.end(),
std::back_inserter(params),
[&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
return post_arg.f(pm, params);
}));
auto* mm = p.get_main_module();
return mm->add_instruction(migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
return mm->add_instruction(
migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
}
TEST_CASE(simple_pointwise_concat)
......@@ -81,22 +86,27 @@ TEST_CASE(simple_pointwise_concat)
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
mm->add_return({relu});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat = add_concat(p2, 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat =
add_concat(p2,
1,
arg("main:pointwise2:concat", {}, single_pointwise("relu")),
arg("concat:main:pointwise0", {x, y}, single_pointwise("add")),
arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment