Commit d1dc3e35 authored by Paul's avatar Paul
Browse files

Add fuse_concat pass

parent 8ba66ebc
...@@ -53,6 +53,7 @@ add_library(migraphx ...@@ -53,6 +53,7 @@ add_library(migraphx
eliminate_pad.cpp eliminate_pad.cpp
env.cpp env.cpp
file_buffer.cpp file_buffer.cpp
fuse_concat.cpp
fuse_pointwise.cpp fuse_pointwise.cpp
fuse_reduce.cpp fuse_reduce.cpp
generate.cpp generate.cpp
......
#include <migraphx/fuse_concat.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/algorithm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct fused_concat
{
int64_t axis = 0;
std::string name() const { return "fused_concat"; }
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.axis, "axis"));
}
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{
check_shapes{inputs, *this}.same_ndims();
if ((inputs.size() + 1) == mods.size())
MIGRAPHX_THROW("FUSED_CONCAT: Missing fused modules");
auto input_iter = inputs.begin();
std::vector<shape> concat_inputs;
for(module_ref mod:range(mods.begin(), mods.end()-1))
{
concat_inputs.push_back(*input_iter);
input_iter += mod->get_parameter_names().size();
}
module_ref post_mod = mods.back();
auto type = std::prev(post_mod->end())->get_shape().type();
const auto& first_shape_lens = concat_inputs.front().lens();
if(not std::all_of(concat_inputs.begin()+1, concat_inputs.end(), [&](auto s) {
const auto& lens = s.lens();
return std::equal(lens.begin(), lens.begin()+axis, first_shape_lens.begin(), first_shape_lens.begin()+axis) and
std::equal(lens.begin()+axis+1, lens.end(), first_shape_lens.begin()+axis+1, first_shape_lens.end());
}))
MIGRAPHX_THROW("FUSED_CONCAT: all input dimensions should match along non-axis: " + std::to_string(axis));
std::size_t new_dim_axis = transform_accumulate(concat_inputs.begin(), concat_inputs.end(), 0, std::plus<>{}, [&](const auto& input) {
return input.lens()[axis];
});
auto new_lens = concat_inputs.front().lens();
new_lens[axis] = new_dim_axis;
return shape::from_permutation(type, new_lens, find_permutation(inputs));
}
};
MIGRAPHX_REGISTER_OP(fused_concat);
namespace {
static unsigned int counter = 0;
struct find_pointwise_concat_pointwise
{
auto matcher() const
{
auto concat = match::name("concat")(match::used_once(), match::any_of[match::inputs()](match::name("pointwise")(match::used_once())));
return match::name("pointwise")(match::any_of[match::inputs()](concat.bind("concat")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto concat_ins = r.instructions["concat"];
auto concat_arg = std::find(ins->inputs().begin(), ins->inputs().end(), concat_ins) - ins->inputs().begin();
std::vector<instruction_ref> inputs;
for(auto input:concat_ins->inputs())
inputs.insert(inputs.end(), input->inputs().begin(), input->inputs().end());
std::copy_if(ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto input) {
return input != concat_ins;
});
std::vector<module_ref> module_inputs;
std::transform(concat_ins->inputs().begin(), concat_ins->inputs().end(), std::back_inserter(module_inputs), [&](instruction_ref input) {
if (input->name() == "pointwise")
{
auto* pm = input->module_inputs().front();
return mpm.create_module("concat:" + pm->name(), *pm);
}
auto* pm = mpm.create_module("concat" + std::to_string(counter++));
auto x = pm->add_parameter("x", shape{input->get_shape().type()});
auto id = pm->add_instruction(make_op("identity"), x);
pm->add_return({id});
return pm;
});
auto* post_pm = ins->module_inputs().front();
auto* rm = mpm.create_module(post_pm->name() + ":concat", *post_pm);
std::vector<std::string> names = rm->get_parameter_names();
std::sort(names.begin(), names.end());
auto concat_param_name = names[concat_arg];
auto concat_param = rm->get_parameter(concat_param_name);
auto param = rm->add_parameter("!" + concat_param_name, concat_param->get_shape());
rm->replace_instruction(concat_param, param);
rm->remove_instruction(concat_param);
module_inputs.push_back(rm);
mpm.get_module().replace_instruction(ins, make_op("fused_concat", concat_ins->normalized_operator().to_value()), inputs, module_inputs);
}
};
}
void fuse_concat::apply(module_pass_manager& mpm) const
{
match::find_matches(mpm, find_pointwise_concat_pointwise{});
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
#include <migraphx/config.hpp>
#include <string>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module_pass_manager;
struct MIGRAPHX_EXPORT fuse_concat
{
std::string name() const { return "fuse_concat"; }
void apply(module_pass_manager& mpm) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_FUSE_CONCAT_HPP
...@@ -239,7 +239,9 @@ struct MIGRAPHX_EXPORT module ...@@ -239,7 +239,9 @@ struct MIGRAPHX_EXPORT module
MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y); MIGRAPHX_EXPORT friend bool operator==(const module& x, const module& y);
friend bool operator!=(const module& x, const module& y) { return not(x == y); } friend bool operator!=(const module& x, const module& y) { return not(x == y); }
friend struct program;
private: private:
void set_name(const std::string& name);
void assign(const module& m); void assign(const module& m);
void calc_implicit_deps(const module& smod, void calc_implicit_deps(const module& smod,
const module& pmod, const module& pmod,
......
...@@ -39,6 +39,7 @@ struct module_pass_manager ...@@ -39,6 +39,7 @@ struct module_pass_manager
module_pass_manager(const module_pass_manager&) = delete; module_pass_manager(const module_pass_manager&) = delete;
virtual module& get_module() = 0; virtual module& get_module() = 0;
virtual module* create_module(const std::string& name) = 0; virtual module* create_module(const std::string& name) = 0;
virtual module* create_module(const std::string& name, const module& m) = 0;
virtual module* get_common_parent() = 0; virtual module* get_common_parent() = 0;
virtual module* get_root_module() = 0; virtual module* get_root_module() = 0;
virtual void run_pass(const pass& p) = 0; virtual void run_pass(const pass& p) = 0;
......
...@@ -136,6 +136,7 @@ struct MIGRAPHX_EXPORT program ...@@ -136,6 +136,7 @@ struct MIGRAPHX_EXPORT program
// module related api // module related api
module* create_module(const std::string& name); module* create_module(const std::string& name);
module* create_module(const std::string& name, module m);
module* get_module(const std::string& name); module* get_module(const std::string& name);
const module* get_module(const std::string& name) const; const module* get_module(const std::string& name) const;
......
...@@ -177,6 +177,18 @@ inline std::string interpolate_string(const std::string& input, ...@@ -177,6 +177,18 @@ inline std::string interpolate_string(const std::string& input,
std::move(end)); std::move(end));
} }
inline std::string to_c_id(const std::string& name, char rep = '_')
{
std::string id = transform_string(name, [&](auto c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return rep;
});
while(id.find("__") != std::string::npos)
replace_string_inplace(id, "__", "_");
return id;
}
template <class Iterator> template <class Iterator>
inline std::string to_string_range(Iterator start, Iterator last, const char* delim = ", ") inline std::string to_string_range(Iterator start, Iterator last, const char* delim = ", ")
{ {
......
...@@ -134,6 +134,11 @@ module& module::operator=(module m) ...@@ -134,6 +134,11 @@ module& module::operator=(module m)
std::string module::name() const { return impl->name; } std::string module::name() const { return impl->name; }
void module::set_name(const std::string& name)
{
impl->name = name;
}
bool module::bypass() const { return impl->bypass; } bool module::bypass() const { return impl->bypass; }
void module::set_bypass(bool b) { impl->bypass = b; } void module::set_bypass(bool b) { impl->bypass = b; }
...@@ -784,18 +789,6 @@ void module::print_graph(std::ostream& os, bool brief) const ...@@ -784,18 +789,6 @@ void module::print_graph(std::ostream& os, bool brief) const
os << "}" << std::endl; os << "}" << std::endl;
} }
static std::string to_c_id(const std::string& name, char rep = '_')
{
std::string id = transform_string(name, [&](auto c) {
if(with_char(::isalnum)(c) or c == '_')
return c;
return rep;
});
while(contains(id, "__"))
replace_string_inplace(id, "__", "_");
return id;
}
static std::string cpp_var_name(const std::string& name) static std::string cpp_var_name(const std::string& name)
{ {
std::string prefix = "x_"; std::string prefix = "x_";
......
...@@ -99,6 +99,12 @@ struct module_pm : module_pass_manager ...@@ -99,6 +99,12 @@ struct module_pm : module_pass_manager
return prog->create_module(name); return prog->create_module(name);
} }
virtual module* create_module(const std::string& name, const module& m) override
{
assert(prog);
return prog->create_module(name, m);
}
virtual module* get_common_parent() override { return common_parent; } virtual module* get_common_parent() override { return common_parent; }
virtual module* get_root_module() override virtual module* get_root_module() override
......
...@@ -1064,6 +1064,13 @@ module* program::create_module(const std::string& name) ...@@ -1064,6 +1064,13 @@ module* program::create_module(const std::string& name)
auto r = impl->modules.emplace(name, name); auto r = impl->modules.emplace(name, name);
return &(r.first->second); return &(r.first->second);
} }
module* program::create_module(const std::string& name, module m)
{
assert(not contains(impl->modules, name));
m.set_name(name);
auto r = impl->modules.emplace(name, std::move(m));
return &(r.first->second);
}
module* program::get_module(const std::string& name) { return &impl->modules.at(name); } module* program::get_module(const std::string& name) { return &impl->modules.at(name); }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp> #include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/algorithm.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -60,6 +61,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params}) ...@@ -60,6 +61,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
)__migraphx__"; )__migraphx__";
struct concat_compiler : compiler<concat_compiler> struct concat_compiler : compiler<concat_compiler>
{ {
std::vector<std::string> names() const { return {"concat"}; } std::vector<std::string> names() const { return {"concat"}; }
...@@ -112,6 +114,109 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -112,6 +114,109 @@ struct concat_compiler : compiler<concat_compiler>
} }
}; };
// NOLINTNEXTLINE
static const char* const fused_concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat2<${axis}>(${concat_args})(${post}, y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct fused_concat_compiler : compiler<fused_concat_compiler>
{
std::vector<std::string> names() const { return {"fused_concat"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
auto op_names = v.at("ops").to_vector<std::string>();
auto args = v.at("args");
vectorize vec{};
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
auto nelements_per_op = options.inputs.back().elements() / op_names.size();
options.set_launch_params(
v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
std::vector<std::string> concat_params;
std::vector<std::string> concat_args;
for(const auto& name:op_names)
{
auto n = args.at(name).to<std::size_t>();
auto prefix = name + "_concat_x";
transform(range(n), std::back_inserter(concat_params), [&](auto i) {
return "auto " + prefix + std::to_string(i);
});
std::vector<std::string> pack_args = {"MIGRAPHX_LIFT(" + name + ")"};
transform(range(n), std::back_inserter(pack_args), [&](auto i) {
return prefix + std::to_string(i);
});
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
}
auto src = interpolate_string(
fused_concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"concat_params", join_strings(concat_params, ", ")},
{"concat_args", join_strings(concat_args, ", ")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()), std::inserter(mod_names_lookup, mod_names_lookup.end()), [&](auto i) {
return std::make_pair(ins->module_inputs()[i]->name(), "pointwise" + std::to_string(i));
});
v["preamble"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end(), std::string{}, std::plus<>{}, [&](module_ref mod) {
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
});
std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::back_inserter(mod_names), [&](module_ref mod) {
return mod_names_lookup.at(mod->name());
});
v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(), ins->module_inputs().end(), std::inserter(mod_args, mod_args.end()), [&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size());
});
v["args"] = mod_args;
v["kernel"] = transform_accumulate(ins->module_inputs().begin(), ins->module_inputs().end()-1, std::string{}, std::plus<>{}, [&](module_ref mod) {
return generate_name_from_ops(*mod) + "_";
}) + "concat_" + generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
...@@ -89,7 +89,8 @@ __device__ auto concat2(InputPacks... input_packs) ...@@ -89,7 +89,8 @@ __device__ auto concat2(InputPacks... input_packs)
}); });
}); });
})(_c<0>, input_packs...); })(_c<0>, input_packs...);
} };
}
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP #endif // MIGRAPHX_GUARD_KERNELS_CONCAT_HPP
...@@ -31,6 +31,7 @@ ...@@ -31,6 +31,7 @@
#include <migraphx/eliminate_data_type.hpp> #include <migraphx/eliminate_data_type.hpp>
#include <migraphx/eliminate_identity.hpp> #include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp> #include <migraphx/eliminate_pad.hpp>
#include <migraphx/fuse_concat.hpp>
#include <migraphx/fuse_pointwise.hpp> #include <migraphx/fuse_pointwise.hpp>
#include <migraphx/fuse_reduce.hpp> #include <migraphx/fuse_reduce.hpp>
#include <migraphx/inline_module.hpp> #include <migraphx/inline_module.hpp>
...@@ -140,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti ...@@ -140,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{}, dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}), enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{}, dead_code_elimination{},
fuse_concat{},
dead_code_elimination{},
#ifndef _WIN32 #ifndef _WIN32
enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}), enable_pass(enabled(MIGRAPHX_ENABLE_CK{}), fuse_ck{}),
#endif #endif
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/fuse_concat.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/functional.hpp>
#include <test.hpp>
#include <pointwise.hpp>
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::fuse_concat{}, migraphx::dead_code_elimination{}});
}
template<class F>
struct concat_arg
{
std::string name;
std::vector<migraphx::instruction_ref> inputs;
F f;
};
template<class F>
concat_arg<F> arg(std::string name, std::vector<migraphx::instruction_ref> inputs, F f)
{
return {std::move(name), std::move(inputs), std::move(f)};
}
template <class Arg, class... Args>
migraphx::instruction_ref add_concat(migraphx::program& p,
std::size_t axis,
Arg post_arg,
Args... args)
{
std::vector<migraphx::module_ref> module_inputs;
std::vector<migraphx::instruction_ref> ins_inputs;
migraphx::each_args([&](auto arg) {
module_inputs.push_back(create_pointwise_module(p, arg.name, arg.inputs, arg.f));
ins_inputs.insert(ins_inputs.end(), arg.inputs.begin(), arg.inputs.end());
}, args...);
module_inputs.push_back(create_pointwise_module(p, post_arg.name, {}, [&](auto* pm, auto&&) {
std::vector<migraphx::instruction_ref> params;
params.push_back(pm->add_parameter("!x0", migraphx::shape{ins_inputs.back()->get_shape().type()}));
std::transform(post_arg.inputs.begin(), post_arg.inputs.end(), std::back_inserter(params), [&](auto input) {
return pm->add_parameter("x" + std::to_string(params.size()),
migraphx::shape{input->get_shape().type()});
});
return post_arg.f(pm, params);
}));
auto* mm = p.get_main_module();
return mm->add_instruction(migraphx::make_op("fused_concat", {{"axis", axis}}), ins_inputs, module_inputs);
}
TEST_CASE(simple_pointwise_concat)
{
migraphx::shape s{migraphx::shape::float_type, {2, 3}};
migraphx::program p1;
{
auto* mm = p1.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto add = add_pointwise(p1, "main:pointwise0", {x, y}, single_pointwise("add"));
auto sub = add_pointwise(p1, "main:pointwise1", {x, y}, single_pointwise("sub"));
auto concat = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), add, sub);
auto relu = add_pointwise(p1, "main:pointwise2", {concat}, single_pointwise("relu"));
mm->add_return({relu});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", s);
auto y = mm->add_parameter("y", s);
auto fused_concat = add_concat(p2, 1, arg("main:pointwise2:concat", {}, single_pointwise("relu")), arg("concat:main:pointwise0", {x, y}, single_pointwise("add")), arg("concat:main:pointwise1", {x, y}, single_pointwise("sub")));
mm->add_return({fused_concat});
}
EXPECT(p1 == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment