Commit 5c4e15f2 authored by Paul's avatar Paul
Browse files

Unify the concat versions

parent 602924d4
...@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params}) ...@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
struct concat_compiler : compiler<concat_compiler> struct concat_compiler : compiler<concat_compiler>
{ {
std::vector<std::string> names() const { return {"concat"}; } std::vector<std::string> names() const { return {"fused_concat", "concat"}; }
static std::size_t get_concat_elements(const std::vector<shape>& inputs)
{
return inputs.back().elements() / (inputs.size() - 1);
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
vectorize vec{};
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(
concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
{"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
// NOLINTNEXTLINE
static const char* const fused_concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat2<${axis}>(${concat_args})(${post}, y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct fused_concat_compiler : compiler<fused_concat_compiler>
{
std::vector<std::string> names() const { return {"fused_concat"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
...@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256)); options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
std::vector<std::string> concat_params; std::vector<std::string> concat_params;
std::vector<std::string> concat_args; std::vector<std::string> concat_args;
for(const auto& name : op_names) for(auto i : range(op_names.size()))
{ {
const auto& name = op_names[i];
auto n = args.at(name).to<std::size_t>(); auto n = args.at(name).to<std::size_t>();
auto prefix = name + "_concat_x"; auto prefix = to_c_id(name + std::to_string(i) + "_concat_x");
transform(range(n), std::back_inserter(concat_params), [&](auto i) { transform(range(n), std::back_inserter(concat_params), [&](auto j) {
return "auto " + prefix + std::to_string(i); return "auto " + prefix + std::to_string(j);
}); });
std::vector<std::string> pack_args = {"MIGRAPHX_LIFT(" + name + ")"}; std::vector<std::string> pack_args = {"MIGRAPHX_LIFT(" + name + ")"};
transform(range(n), std::back_inserter(pack_args), [&](auto i) { transform(range(n), std::back_inserter(pack_args), [&](auto j) {
return prefix + std::to_string(i); return prefix + std::to_string(j);
}); });
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")"); concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
} }
auto src = interpolate_string(fused_concat_kernel, auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
...@@ -189,50 +112,69 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -189,50 +112,69 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
auto v = op.to_value(); auto v = op.to_value();
std::unordered_map<std::string, std::string> mod_names_lookup; if (op.name() == "fused_concat")
transform(range(ins->module_inputs().size()), {
std::inserter(mod_names_lookup, mod_names_lookup.end()), std::unordered_map<std::string, std::string> mod_names_lookup;
[&](auto i) { transform(range(ins->module_inputs().size()),
return std::make_pair(ins->module_inputs()[i]->name(), std::inserter(mod_names_lookup, mod_names_lookup.end()),
"pointwise" + std::to_string(i)); [&](auto i) {
}); return std::make_pair(ins->module_inputs()[i]->name(),
v["preamble"] = transform_accumulate( "pointwise" + std::to_string(i));
ins->module_inputs().begin(), });
ins->module_inputs().end(), v["preamble"] = transform_accumulate(
std::string{}, ins->module_inputs().begin(),
std::plus<>{}, ins->module_inputs().end(),
[&](module_ref mod) { std::string{},
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n"; std::plus<>{},
}); [&](module_ref mod) {
std::vector<std::string> mod_names; return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
std::transform(ins->module_inputs().begin(), });
ins->module_inputs().end() - 1, std::vector<std::string> mod_names;
std::back_inserter(mod_names), std::transform(ins->module_inputs().begin(),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); }); ins->module_inputs().end() - 1,
v["ops"] = mod_names; std::back_inserter(mod_names),
module_ref last_mod = ins->module_inputs().back(); [&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")"; v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args; module_ref last_mod = ins->module_inputs().back();
std::transform(ins->module_inputs().begin(), v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
ins->module_inputs().end() - 1, std::unordered_map<std::string, std::size_t> mod_args;
std::inserter(mod_args, mod_args.end()), std::transform(ins->module_inputs().begin(),
[&](module_ref mod) { ins->module_inputs().end() - 1,
const auto& name = mod_names_lookup.at(mod->name()); std::inserter(mod_args, mod_args.end()),
return std::make_pair(name, mod->get_parameter_names().size()); [&](module_ref mod) {
}); const auto& name = mod_names_lookup.at(mod->name());
v["args"] = mod_args; return std::make_pair(name, mod->get_parameter_names().size());
auto prefix_name = transform_accumulate(ins->module_inputs().begin(), });
ins->module_inputs().end() - 1, v["args"] = mod_args;
std::string{}, auto prefix_name = transform_accumulate(ins->module_inputs().begin(),
std::plus<>{}, ins->module_inputs().end() - 1,
[&](module_ref mod) -> std::string { std::string{},
auto name = generate_name_from_ops(*mod); std::plus<>{},
if(name.empty()) [&](module_ref mod) -> std::string {
return ""; auto name = generate_name_from_ops(*mod);
return name + "_"; if(name.empty())
}); return "";
v["kernel"] = prefix_name + "concat_" + return name + "_";
generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel"; });
v["kernel"] = prefix_name + "concat_" +
generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
}
else if (op.name() == "concat")
{
auto concat_inputs = ins->inputs().size() - 1;
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
concat_inputs = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
std::vector<std::string> mod_names(concat_inputs, "op::id{}");
v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args = {{"op::id{}", 1}};
v["args"] = mod_args;
}
return compile_op(ctx, to_shapes(ins->inputs()), v); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
......
...@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input) ...@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>; return _c<lens[Axis]>;
} }
template <index_int Axis, class... Inputs>
__device__ auto concat(Inputs... inputs)
{
return [=](auto f, auto... ts) {
auto idx = make_index();
fold([&](auto start, auto input) {
concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
idx.global_stride(input.get_shape().elements(),
[&](auto i) { y[i] = f(input[i], xs[i]...); });
});
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
};
}
template <index_int Axis, class... InputPacks> template <index_int Axis, class... InputPacks>
__device__ auto concat2(InputPacks... input_packs) __device__ auto concat(InputPacks... input_packs)
{ {
return [=](auto f, auto... ts) { return [=](auto f, auto... ts) {
auto idx = make_index(); auto idx = make_index();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment