Commit e34cb7c1 authored by Paul's avatar Paul
Browse files

Add pointwise fusion

parent 97072183
...@@ -31,16 +31,15 @@ struct ck_gemm ...@@ -31,16 +31,15 @@ struct ck_gemm
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
{ {
check_shapes{inputs, *this}.not_broadcasted(); check_shapes{inputs, *this}.same_ndims();
// if(mods.size() != 1) // if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule."); // MIGRAPHX_THROW("should have one submodule.");
if(inputs.size() < 2) if(inputs.size() < 2)
MIGRAPHX_THROW("should have at least two inputs."); MIGRAPHX_THROW("should have at least two inputs.");
auto n = inputs.size(); auto a = inputs[0];
auto a = inputs[n - 2]; auto b = inputs[1];
auto b = inputs[n - 1]; for(const auto& input:inputs)
check_gemm_shape(a); check_gemm_shape(input);
check_gemm_shape(b);
return op.compute_shape({a, b}); return op.compute_shape({a, b});
} }
}; };
...@@ -64,13 +63,40 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins) ...@@ -64,13 +63,40 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
struct find_ck_gemm struct find_ck_gemm
{ {
// Find a convolution followed by a pointwise operation. // Find a gemm followed by a pointwise operation.
auto matcher() const { return match::name("dot")(is_ck_gemm().bind("gemm")); } auto matcher() const {
auto gemm = match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm")));
return match::name("pointwise")(match::any_of[match::inputs()](gemm.bind("x")));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
mpm.get_module().replace_instruction(ins, ck_gemm{ins->get_operator()}, ins->inputs()); auto gemm_ins = r.instructions["gemm"];
auto x_ins = r.instructions["x"]; // input after contiguous
auto* pm = ins->module_inputs().front();
auto names = pm->get_parameter_names();
std::sort(names.begin(), names.end());
auto inputs = ins->inputs();
auto gemm_it = std::find(inputs.begin(), inputs.end(), x_ins);
auto gemm_idx = gemm_it - inputs.begin();
assert(gemm_it != inputs.end());
if (gemm_idx != 0)
{
// std::swap(inputs[0], inputs[gemm_idx]);
auto first_param = pm->get_parameter(names[0]);
auto gemm_param = pm->get_parameter(names[gemm_idx]);
auto new_gemm_param = pm->add_parameter(names[0] + ".0", gemm_param->get_shape());
auto new_first_param = pm->add_parameter(names[gemm_idx] + ".0", first_param->get_shape());
pm->replace_instruction(gemm_param, new_gemm_param);
pm->replace_instruction(first_param, new_first_param);
pm->remove_instruction(first_param);
pm->remove_instruction(gemm_param);
}
inputs.erase(gemm_it);
inputs.insert(inputs.begin(), gemm_ins->inputs().begin(), gemm_ins->inputs().end());
mpm.get_module().replace_instruction(ins, ck_gemm{}, inputs, {pm});
} }
}; };
......
...@@ -29,14 +29,12 @@ ...@@ -29,14 +29,12 @@
#include <migraphx/gpu/compile_hip_code_object.hpp> #include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp> #include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp> #include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
...@@ -48,6 +46,8 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -48,6 +46,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING); MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
...@@ -55,15 +55,18 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING); ...@@ -55,15 +55,18 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
static const char* const ck_gemm_kernel = R"__migraphx__( static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp> #include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp> #include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
namespace migraphx { namespace migraphx {
${preamble}
extern "C" { extern "C" {
__global__ void ck_gemm_kernel(void* a_p, void* b_p, void* c_p) __global__ void ${kernel}(${params})
{ {
make_tensors()(a_p, b_p, c_p)([&](auto a, auto b, auto c) { transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm<CK_DeviceGemmMultipleD<${instance}>>(a, b, c); ck_gemm<CK_DeviceGemmMultipleD<${instance}>>(xs...);
}); });
} }
...@@ -136,23 +139,42 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -136,23 +139,42 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
return shape::cpp_type(s.type()); return shape::cpp_type(s.type());
} }
template<class Iterator, class F>
static std::string ck_tuple(Iterator start, Iterator last, F f)
{
std::vector<std::string> s;
std::transform(start, last, std::back_inserter(s), f);
return "ck::Tuple<" + join_strings(s, ",") + ">";
}
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; } std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
{ {
auto a_shape = inputs[0]; auto a_shape = inputs[0];
auto b_shape = inputs[1]; auto b_shape = inputs[1];
auto c_shape = inputs[2]; auto c_shape = inputs.back();
auto m = c_shape.lens().front(); auto m = c_shape.lens().front();
auto n = c_shape.lens().back(); auto n = c_shape.lens().back();
auto i = v.get("tuning_val", get_tuning_for(inputs)); auto i = v.get("tuning_val", get_tuning_for(inputs));
const auto& instance = get_instance(i, [&](const auto& x) -> bool { auto instance = get_instance(i, [&](const auto& x) -> bool {
return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type(b_shape) == x[5] and get_type(c_shape) == x[9]; get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
}); });
assert(inputs.size() < 4 or v.contains("post"));
if (v.contains("post"))
{
assert(instance[2] == "ck::Tuple<>");
instance[2] = ck_tuple(inputs.begin()+2, inputs.end()-1, &get_layout);
assert(instance[8] == "ck::Tuple<>");
instance[8] = ck_tuple(inputs.begin()+2, inputs.end()-1, &get_type);
assert(instance[12] == "ck_passthrough");
instance[12] = v.at("post").to<std::string>();
}
hip_compile_options options; hip_compile_options options;
auto block_size = get_block_size(instance); auto block_size = get_block_size(instance);
...@@ -160,18 +182,34 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -160,18 +182,34 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.set_launch_params(v, grid_size * block_size, block_size); options.set_launch_params(v, grid_size * block_size, block_size);
options.inputs = inputs; options.inputs = inputs;
options.output = c_shape; options.output = c_shape;
options.kernel_name = "ck_gemm_kernel"; options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
auto src = interpolate_string(ck_gemm_kernel, {{"instance", join_strings(instance, ",")}}); auto src = interpolate_string(ck_gemm_kernel, {
{"instance", join_strings(instance, ",")},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"preamble", v.get("preamble", std::string{})},
{"kernel", options.kernel_name}
});
return compile_hip_code_object(src, options); return compile_hip_code_object(src, options);
} }
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
auto v = op.to_value();
v["kernel"] = "ck_gemm_kernel";
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_function") + "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm, post_ck_gemm_function);";
v["post"] = "ck_function_adaptor<post_ck_gemm>";
v["kernel"] = "ck_gemm_" + generate_name_from_ops(*pm) + "_kernel";
}
auto shapes = to_shapes(ins->inputs()); auto shapes = to_shapes(ins->inputs());
return action_decorate(replace(compile_op(ctx, shapes, op.to_value())), [=] { return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
if(enabled(MIGRAPHX_LOG_CK_GEMM{})) if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
std::cout << "ck_gemm: " << to_json_string(to_value(shapes)) << std::endl; std::cout << "ck_gemm: " << to_json_string(to_value(shapes)) << std::endl;
}); });
......
...@@ -24,6 +24,12 @@ struct to_ck_type_impl<migraphx::half> ...@@ -24,6 +24,12 @@ struct to_ck_type_impl<migraphx::half>
using type = ck::half_t; using type = ck::half_t;
}; };
template <class T>
struct to_ck_type_impl<const T>
{
using type = const typename to_ck_type_impl<T>::type;
};
template <class Shape> template <class Shape>
constexpr bool is_row_major() constexpr bool is_row_major()
{ {
...@@ -44,6 +50,18 @@ constexpr bool is_row_major() ...@@ -44,6 +50,18 @@ constexpr bool is_row_major()
template <class T> template <class T>
using to_ck_type = typename detail::to_ck_type_impl<T>::type; using to_ck_type = typename detail::to_ck_type_impl<T>::type;
template<class T>
constexpr auto to_ck_pointer(T* x)
{
return static_cast<to_ck_type<T>*>(x);
}
template<class T>
constexpr auto to_ck_const_pointer(const T* x)
{
return static_cast<const to_ck_type<T>*>(x);
}
template <class Shape> template <class Shape>
using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(), using to_ck_gemm_layout = conditional_t<detail::is_row_major<get_shape_c<Shape>>(),
ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor,
......
...@@ -33,8 +33,8 @@ ...@@ -33,8 +33,8 @@
namespace migraphx { namespace migraphx {
template <class G, class A, class B, class E, class... Ds> template <class G, class E, class A, class B, class... Ds>
__device__ void ck_gemm(A a, B b, E e, Ds... ds) __device__ void ck_gemm(E e, A a, B b, Ds... ds)
{ {
constexpr const G gemm{}; constexpr const G gemm{};
...@@ -64,10 +64,10 @@ __device__ void ck_gemm(A a, B b, E e, Ds... ds) ...@@ -64,10 +64,10 @@ __device__ void ck_gemm(A a, B b, E e, Ds... ds)
constexpr const bool HasMainKBlockLoop = constexpr const bool HasMainKBlockLoop =
GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) * GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{})); a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
GridwiseGemm::template Run<HasMainKBlockLoop>(a.data(), GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
b.data(), to_ck_const_pointer(b.data()),
ck::make_tuple(ds.data()...), ck::make_tuple(to_ck_const_pointer(ds.data())...),
e.data(), to_ck_pointer(e.data()),
p_shared_block, p_shared_block,
gemm.a_element_op, gemm.a_element_op,
gemm.b_element_op, gemm.b_element_op,
......
...@@ -35,6 +35,15 @@ ...@@ -35,6 +35,15 @@
[](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \ [](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) (__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT_CLASS(name, ...) \
struct name \
{ \
template<class... PrivateLiftTs> \
constexpr auto operator()(PrivateLiftTs&&... private_lisft_xs) const MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) \
}
namespace migraphx { namespace migraphx {
struct swallow struct swallow
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_add_relu : verify_program<gemm_add_relu>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto a = mm->add_parameter("1", {migraphx::shape::half_type, {16, 8}});
auto b = mm->add_parameter("2", {migraphx::shape::half_type, {8, 32}});
auto c = mm->add_parameter("3", {migraphx::shape::half_type, {16, 32}});
auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto add = mm->add_instruction(migraphx::make_op("add"), dot, c);
mm->add_instruction(migraphx::make_op("relu"), add);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment