Commit 5c4e15f2 authored by Paul's avatar Paul
Browse files

Unify the concat versions

parent 602924d4
...@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params}) ...@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
struct concat_compiler : compiler<concat_compiler> struct concat_compiler : compiler<concat_compiler>
{ {
std::vector<std::string> names() const { return {"concat"}; } std::vector<std::string> names() const { return {"fused_concat", "concat"}; }
static std::size_t get_concat_elements(const std::vector<shape>& inputs)
{
return inputs.back().elements() / (inputs.size() - 1);
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
vectorize vec{};
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(
concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
{"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
// NOLINTNEXTLINE
static const char* const fused_concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat2<${axis}>(${concat_args})(${post}, y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct fused_concat_compiler : compiler<fused_concat_compiler>
{
std::vector<std::string> names() const { return {"fused_concat"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
...@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256)); options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
std::vector<std::string> concat_params; std::vector<std::string> concat_params;
std::vector<std::string> concat_args; std::vector<std::string> concat_args;
for(const auto& name : op_names) for(auto i : range(op_names.size()))
{ {
const auto& name = op_names[i];
auto n = args.at(name).to<std::size_t>(); auto n = args.at(name).to<std::size_t>();
auto prefix = name + "_concat_x"; auto prefix = to_c_id(name + std::to_string(i) + "_concat_x");
transform(range(n), std::back_inserter(concat_params), [&](auto i) { transform(range(n), std::back_inserter(concat_params), [&](auto j) {
return "auto " + prefix + std::to_string(i); return "auto " + prefix + std::to_string(j);
}); });
std::vector<std::string> pack_args = {"MIGRAPHX_LIFT(" + name + ")"}; std::vector<std::string> pack_args = {"MIGRAPHX_LIFT(" + name + ")"};
transform(range(n), std::back_inserter(pack_args), [&](auto i) { transform(range(n), std::back_inserter(pack_args), [&](auto j) {
return prefix + std::to_string(i); return prefix + std::to_string(j);
}); });
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")"); concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
} }
auto src = interpolate_string(fused_concat_kernel, auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
...@@ -189,6 +112,8 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -189,6 +112,8 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
auto v = op.to_value(); auto v = op.to_value();
if (op.name() == "fused_concat")
{
std::unordered_map<std::string, std::string> mod_names_lookup; std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()), transform(range(ins->module_inputs().size()),
std::inserter(mod_names_lookup, mod_names_lookup.end()), std::inserter(mod_names_lookup, mod_names_lookup.end()),
...@@ -233,6 +158,23 @@ struct fused_concat_compiler : compiler<fused_concat_compiler> ...@@ -233,6 +158,23 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
}); });
v["kernel"] = prefix_name + "concat_" + v["kernel"] = prefix_name + "concat_" +
generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel"; generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
}
else if (op.name() == "concat")
{
auto concat_inputs = ins->inputs().size() - 1;
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
concat_inputs = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
std::vector<std::string> mod_names(concat_inputs, "op::id{}");
v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args = {{"op::id{}", 1}};
v["args"] = mod_args;
}
return compile_op(ctx, to_shapes(ins->inputs()), v); return compile_op(ctx, to_shapes(ins->inputs()), v);
} }
}; };
......
...@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input) ...@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>; return _c<lens[Axis]>;
} }
template <index_int Axis, class... Inputs>
__device__ auto concat(Inputs... inputs)
{
return [=](auto f, auto... ts) {
auto idx = make_index();
fold([&](auto start, auto input) {
concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
idx.global_stride(input.get_shape().elements(),
[&](auto i) { y[i] = f(input[i], xs[i]...); });
});
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
};
}
template <index_int Axis, class... InputPacks> template <index_int Axis, class... InputPacks>
__device__ auto concat2(InputPacks... input_packs) __device__ auto concat(InputPacks... input_packs)
{ {
return [=](auto f, auto... ts) { return [=](auto f, auto... ts) {
auto idx = make_index(); auto idx = make_index();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment