"include/vscode:/vscode.git/clone" did not exist on "b0a4674ce7d9299ad747b741fec1815445960ad3"
Unverified Commit 2f48b11a authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Concat pointwise fusions (#1388)

parent 04e33ec5
...@@ -39,19 +39,26 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL); ...@@ -39,19 +39,26 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_GPU_COMPILE_PARALLEL);
struct precompile_op struct precompile_op
{ {
operation op = op::identity{}; operation op = op::identity{};
std::size_t additional_args = 1;
bool ignore_modules = false;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
return pack(f(self.op, "op")); return pack(f(self.op, "op"),
f(self.additional_args, "additional_args"),
f(self.ignore_modules, "ignore_modules"));
} }
std::string name() const { return "gpu::precompile_op"; } std::string name() const { return "gpu::precompile_op"; }
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{ {
inputs.pop_back(); // Pop off additional args
inputs.resize(inputs.size() - additional_args);
if(ignore_modules)
return op.compute_shape(inputs);
return op.compute_shape(inputs, mods); return op.compute_shape(inputs, mods);
} }
......
...@@ -772,11 +772,9 @@ struct find_layernorm_pointwise ...@@ -772,11 +772,9 @@ struct find_layernorm_pointwise
{ {
auto ins = r.result; auto ins = r.result;
auto layernorm = r.instructions["layernorm"]; auto layernorm = r.instructions["layernorm"];
auto* pm = ins->module_inputs().front();
if(not layernorm->module_inputs().empty()) if(not layernorm->module_inputs().empty())
return; return;
auto* pm = ins->module_inputs().front();
auto inputs = layernorm->inputs(); auto inputs = layernorm->inputs();
inputs.pop_back(); inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end()); inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
...@@ -785,6 +783,37 @@ struct find_layernorm_pointwise ...@@ -785,6 +783,37 @@ struct find_layernorm_pointwise
} }
}; };
// Fuses a precompiled "pointwise" instruction into the precompiled "concat"
// instruction that feeds its first argument, so that the concat operator is
// re-emitted with the pointwise module attached.
struct find_concat_pointwise
{
    // Matches a "pointwise" whose argument 0 is a "concat"; the concat is bound
    // under the name "concat" so apply() can retrieve it.
    auto matcher() const
    {
        auto concat_producer = precompile_name("concat").bind("concat");
        return precompile_name("pointwise")(match::arg(0)(concat_producer));
    }

    // Replaces the matched pointwise instruction with the concat's operator,
    // reconfigured through from_value and given the pointwise module.
    void apply(module& m, const match::matcher_result& r) const
    {
        auto pointwise_ins = r.result;
        auto concat_ins    = r.instructions["concat"];

        // Skip a concat that already carries module inputs.
        // TODO: Handle type conversions
        if(not concat_ins->module_inputs().empty() or
           pointwise_ins->get_shape().type() != concat_ins->get_shape().type())
            return;

        auto* pointwise_mod      = pointwise_ins->module_inputs().front();
        const auto& pointwise_in = pointwise_ins->inputs();

        // Start from the concat's inputs minus its last one (presumably the
        // output allocation -- TODO confirm), then append every pointwise
        // input except the first, which is the concat itself.
        auto fused_inputs = concat_ins->inputs();
        fused_inputs.pop_back();
        fused_inputs.insert(fused_inputs.end(), pointwise_in.begin() + 1, pointwise_in.end());

        // Record how many inputs were appended and ask the operator to ignore
        // modules when it is reconfigured.
        auto fused_op = concat_ins->get_operator();
        fused_op.from_value(
            {{"additional_args", pointwise_in.size() - 1}, {"ignore_modules", true}});

        m.replace_instruction(pointwise_ins, fused_op, fused_inputs, {pointwise_mod});
    }
};
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_contiguous_pointwise{}); match::find_matches(m, find_contiguous_pointwise{});
...@@ -793,6 +822,7 @@ void fuse_ops::apply(module& m) const ...@@ -793,6 +822,7 @@ void fuse_ops::apply(module& m) const
run_passes(m, {dead_code_elimination{}}); run_passes(m, {dead_code_elimination{}});
match::find_matches(m, match::find_matches(m,
find_layernorm_pointwise{}, find_layernorm_pointwise{},
find_concat_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_contiguous_tranpose_gemm{}, find_contiguous_tranpose_gemm{},
find_commutative_broadcast{}); find_commutative_broadcast{});
......
...@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT ...@@ -38,16 +38,19 @@ using namespace migraphx::gpu::gen; // NOLINT
static const char* const concat_kernel = R"__migraphx__( static const char* const concat_kernel = R"__migraphx__(
#include <migraphx/kernels/concat.hpp> #include <migraphx/kernels/concat.hpp>
#include <migraphx/kernels/vectorize.hpp> #include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/ops.hpp>
#include <args.hpp> #include <args.hpp>
namespace migraphx { namespace migraphx {
${preamble}
extern "C" { extern "C" {
__global__ void ${kernel}(${params}) __global__ void ${kernel}(${params})
{ {
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) { transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, ${concat_params}, auto... xs) {
concat<${axis}>(y, xs...); concat<${axis}>(${concat_args})(${post}, y, xs...);
}); });
} }
...@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler> ...@@ -68,28 +71,42 @@ struct concat_compiler : compiler<concat_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
// TODO: Use reduce_dims auto num_of_concat_inputs = v.get("concat_inputs", inputs.size() - 1);
hip_compile_options options; hip_compile_options options;
options.inputs = inputs; options.inputs = inputs;
options.output = inputs.back(); options.output = inputs.back();
options.params = "-Wno-float-equal"; options.params = "-Wno-float-equal";
options.kernel_name = v.get("kernel", "concat_kernel");
auto axis = find_fast_axis(options.inputs); auto axis = find_fast_axis(options.inputs);
auto vec = vectorize::elements(ctx, axis, options.inputs); auto vec = vectorize::elements(ctx, axis, options.inputs);
options.kernel_name = v.get("kernel", "concat_kernel");
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256)); v, compute_global_for(ctx, get_concat_elements(options.inputs) / vec.size, 256));
auto src = interpolate_string(concat_kernel, auto src = interpolate_string(
{{"kernel", options.kernel_name}, concat_kernel,
{"params", enum_params(inputs.size(), "void * private_p")}, {{"kernel", options.kernel_name},
{"args", enum_params(inputs.size(), "private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"transformers", make_transformer_args(vec)}, {"args", enum_params(inputs.size(), "private_p")},
{"axis", v.at("axis").to<std::string>()}}); {"concat_params", enum_params(num_of_concat_inputs, "auto concat_x")},
{"concat_args", enum_params(num_of_concat_inputs, "concat_x")},
{"post", v.get("post", std::string{"op::id{}"})},
{"transformers", make_transformer_args(vec)},
{"preamble", v.get("preamble", std::string{})},
{"axis", v.at("axis").to<std::string>()}});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value())); auto v = op.to_value();
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["concat_inputs"] = ins->inputs().size() - pm->get_parameter_names().size();
v["preamble"] = generate_pointwise(*pm, "post_concat");
v["post"] = "MIGRAPHX_LIFT(post_concat)";
v["kernel"] = "concat_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
} }
}; };
......
...@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start) ...@@ -41,7 +41,15 @@ constexpr auto concat_slice(Output out, Input, Start)
return Start{} * output_shape.strides[Axis]; return Start{} * output_shape.strides[Axis];
}); });
constexpr auto s = make_shape(lens, strides); constexpr auto s = make_shape(lens, strides);
return make_tensor_view(&out[offset], s); MIGRAPHX_ASSERT(offset < out.get_shape().element_space());
MIGRAPHX_ASSERT((s.element_space() + offset) <= out.get_shape().element_space());
return make_tensor_view(out.data() + offset, s);
}
template <index_int Axis, class Input, class Start, class... Ts>
constexpr auto concat_slices(Input input, Start start, Ts... xs)
{
return [=](auto f) { f(concat_slice<Axis>(xs, input, start)...); };
} }
template <index_int Axis, class Input> template <index_int Axis, class Input>
...@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input) ...@@ -51,15 +59,19 @@ constexpr auto concat_ends(Input)
return _c<lens[Axis]>; return _c<lens[Axis]>;
} }
template <index_int Axis, class Output, class... Inputs> template <index_int Axis, class... Inputs>
__device__ void concat(Output output, Inputs... inputs) __device__ auto concat(Inputs... inputs)
{ {
auto idx = make_index(); return [=](auto f, auto... ts) {
fold([&](auto start, auto input) { auto idx = make_index();
auto y = concat_slice<Axis>(output, input, start); fold([&](auto start, auto input) {
idx.global_stride(input.get_shape().elements(), [&](auto i) { y[i] = input[i]; }); concat_slices<Axis>(input, start, ts...)([&](auto y, auto... xs) {
return start + concat_ends<Axis>(input); idx.global_stride(input.get_shape().elements(),
})(_c<0>, inputs...); [&](auto i) { y[i] = f(input[i], xs[i]...); });
});
return start + concat_ends<Axis>(input);
})(_c<0>, inputs...);
};
} }
} // namespace migraphx } // namespace migraphx
......
...@@ -101,7 +101,10 @@ struct mlir_handle ...@@ -101,7 +101,10 @@ struct mlir_handle
mlir_handle(T p) : handle(ptr{p}) {} mlir_handle(T p) : handle(ptr{p}) {}
T get() const { return handle.get().get(); } T get() const
{
return handle.get().get(); // NOLINT(readability-redundant-smartptr-get)
}
T release() { return handle.release().get(); } T release() { return handle.release().get(); }
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification program: concatenates three {1, 2, 4} float parameters along
// axis 1 and adds a {6, 1} literal multibroadcast to {1, 6, 4}.
struct test_concat_broadcast_add : verify_program<test_concat_broadcast_add>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        const migraphx::shape input_shape{migraphx::shape::float_type, {1, 2, 4}};
        const migraphx::shape joined_shape{migraphx::shape::float_type, {1, 6, 4}};
        const migraphx::shape bias_shape{migraphx::shape::float_type, {6, 1}};

        // Three inputs that share one shape.
        auto in_x = main_mod->add_parameter("x", input_shape);
        auto in_y = main_mod->add_parameter("y", input_shape);
        auto in_z = main_mod->add_parameter("z", input_shape);

        auto concat_op = migraphx::make_op("concat", {{"axis", 1}});
        auto joined    = main_mod->add_instruction(concat_op, in_x, in_y, in_z);

        // Literal generated with seed 15, then broadcast to the concat's lens.
        auto bias         = main_mod->add_literal(migraphx::generate_literal(bias_shape, 15));
        auto broadcast_op = migraphx::make_op("multibroadcast", {{"out_lens", joined_shape.lens()}});
        auto bias_wide    = main_mod->add_instruction(broadcast_op, bias);

        main_mod->add_instruction(migraphx::make_op("add"), joined, bias_wide);
        return prog;
    }
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verification program: slices the first 8 entries of axis 1 from a
// {1, 24, 2, 2} input, concatenates the slice with a {1, 8, 2, 2} input used
// twice, and adds a second {1, 24, 2, 2} input to the result.
struct test_slice_concat_add : verify_program<test_slice_concat_add>
{
    migraphx::program create_program() const
    {
        migraphx::program prog;
        auto* main_mod = prog.get_main_module();

        const migraphx::shape wide_shape{migraphx::shape::float_type, {1, 24, 2, 2}};
        const migraphx::shape narrow_shape{migraphx::shape::float_type, {1, 8, 2, 2}};

        auto in_x = main_mod->add_parameter("x", wide_shape);
        auto in_y = main_mod->add_parameter("y", narrow_shape);
        auto in_z = main_mod->add_parameter("z", wide_shape);

        // Keep entries [0, 8) of axis 1 from x.
        auto slice_op = migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {8}}});
        auto x_head   = main_mod->add_instruction(slice_op, in_x);

        // y is intentionally passed twice to the concat.
        auto concat_op = migraphx::make_op("concat", {{"axis", 1}});
        auto joined    = main_mod->add_instruction(concat_op, x_head, in_y, in_y);

        main_mod->add_instruction(migraphx::make_op("add"), joined, in_z);
        return prog;
    }
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment