Commit 5c4e15f2 authored by Paul's avatar Paul
Browse files

Unify the concat versions

parent 602924d4
......@@ -63,85 +63,7 @@ MIGRAPHX_GLOBAL void ${kernel}(${params})
struct concat_compiler : compiler<concat_compiler>
{
std::vector<std::string> names() const { return {"concat"}; }
static std::size_t get_concat_elements(const std::vector<shape>& inputs)
{
return inputs.back().elements() / (inputs.size() - 1);
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
hip_compile_options options;
options.inputs = inputs;
options.output = inputs.back();
options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs);
vectorize vec{};
if(axis != v.at("axis").to<std::size_t>())
vec = vectorize::elements(ctx, axis, options.inputs);
options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(
concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
{"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
// NOLINTNEXTLINE
static const char* const fused_concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat2<${axis}>(${concat_args})(${post}, y, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct fused_concat_compiler : compiler<fused_concat_compiler>
{
std::vector<std::string> names() const { return {"fused_concat"}; }
std::vector<std::string> names() const { return {"fused_concat", "concat"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
......@@ -160,20 +82,21 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
options.set_launch_params(v, compute_global_for(ctx, nelements_per_op / vec.size, 256));
std::vector<std::string> concat_params;
std::vector<std::string> concat_args;
for(const auto& name : op_names)
for(auto i : range(op_names.size()))
{
const auto& name = op_names[i];
auto n = args.at(name).to<std::size_t>();
auto prefix = name + "_concat_x";
transform(range(n), std::back_inserter(concat_params), [&](auto i) {
return "auto " + prefix + std::to_string(i);
auto prefix = to_c_id(name + std::to_string(i) + "_concat_x");
transform(range(n), std::back_inserter(concat_params), [&](auto j) {
return "auto " + prefix + std::to_string(j);
});
std::vector<std::string> pack_args = {"MIGRAPHX_LIFT(" + name + ")"};
transform(range(n), std::back_inserter(pack_args), [&](auto i) {
return prefix + std::to_string(i);
transform(range(n), std::back_inserter(pack_args), [&](auto j) {
return prefix + std::to_string(j);
});
concat_args.push_back("pack(" + join_strings(pack_args, ", ") + ")");
}
auto src = interpolate_string(fused_concat_kernel,
auto src = interpolate_string(concat_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
......@@ -189,50 +112,69 @@ struct fused_concat_compiler : compiler<fused_concat_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()),
std::inserter(mod_names_lookup, mod_names_lookup.end()),
[&](auto i) {
return std::make_pair(ins->module_inputs()[i]->name(),
"pointwise" + std::to_string(i));
});
v["preamble"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end(),
std::string{},
std::plus<>{},
[&](module_ref mod) {
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
});
std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names;
module_ref last_mod = ins->module_inputs().back();
v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::inserter(mod_args, mod_args.end()),
[&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size());
});
v["args"] = mod_args;
auto prefix_name = transform_accumulate(ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::string{},
std::plus<>{},
[&](module_ref mod) -> std::string {
auto name = generate_name_from_ops(*mod);
if(name.empty())
return "";
return name + "_";
});
v["kernel"] = prefix_name + "concat_" +
generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
if (op.name() == "fused_concat")
{
std::unordered_map<std::string, std::string> mod_names_lookup;
transform(range(ins->module_inputs().size()),
std::inserter(mod_names_lookup, mod_names_lookup.end()),
[&](auto i) {
return std::make_pair(ins->module_inputs()[i]->name(),
"pointwise" + std::to_string(i));
});
v["preamble"] = transform_accumulate(
ins->module_inputs().begin(),
ins->module_inputs().end(),
std::string{},
std::plus<>{},
[&](module_ref mod) {
return generate_pointwise(*mod, mod_names_lookup.at(mod->name())) + "\n";
});
std::vector<std::string> mod_names;
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::back_inserter(mod_names),
[&](module_ref mod) { return mod_names_lookup.at(mod->name()); });
v["ops"] = mod_names;
module_ref last_mod = ins->module_inputs().back();
v["post"] = "MIGRAPHX_LIFT(" + mod_names_lookup.at(last_mod->name()) + ")";
std::unordered_map<std::string, std::size_t> mod_args;
std::transform(ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::inserter(mod_args, mod_args.end()),
[&](module_ref mod) {
const auto& name = mod_names_lookup.at(mod->name());
return std::make_pair(name, mod->get_parameter_names().size());
});
v["args"] = mod_args;
auto prefix_name = transform_accumulate(ins->module_inputs().begin(),
ins->module_inputs().end() - 1,
std::string{},
std::plus<>{},
[&](module_ref mod) -> std::string {
auto name = generate_name_from_ops(*mod);
if(name.empty())
return "";
return name + "_";
});
v["kernel"] = prefix_name + "concat_" +
generate_name_from_ops(*(ins->module_inputs().back())) + "_kernel";
}
else if (op.name() == "concat")
{
auto concat_inputs = ins->inputs().size() - 1;
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
concat_inputs = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
std::vector<std::string> mod_names(concat_inputs, "op::id{}");
v["ops"] = mod_names;
std::unordered_map<std::string, std::size_t> mod_args = {{"op::id{}", 1}};
v["args"] = mod_args;
}
return compile_op(ctx, to_shapes(ins->inputs()), v);
}
};
......
......@@ -59,23 +59,8 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>;
}
template <index_int Axis, class... Inputs>
__device__ auto concat(Inputs... inputs)
{
return [=](auto f, auto... ts) {
auto idx = make_index();
fold([&](auto start, auto input) {
concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
idx.global_stride(input.get_shape().elements(),
[&](auto i) { y[i] = f(input[i], xs[i]...); });
});
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
};
}
template <index_int Axis, class... InputPacks>
__device__ auto concat2(InputPacks... input_packs)
__device__ auto concat(InputPacks... input_packs)
{
return [=](auto f, auto... ts) {
auto idx = make_index();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment