Unverified Commit a85b183b authored by Chris Austen, committed by GitHub

Final performance improvements for release (#1369)

* Improvements to handling and add constant passed to dot operator (#1280)
* Improve horizontal fusion of contiguous (#1292)
* Add pass to rewrite gelu as fast gelu (#1299)
* Add jit layernorm fusion (#1301)
parent 9a1ada1a
@@ -82,6 +82,7 @@ add_library(migraphx
     simplify_qdq.cpp
     sqlite.cpp
     rewrite_batchnorm.cpp
+    rewrite_gelu.cpp
     rewrite_pooling.cpp
     rewrite_quantization.cpp
     rewrite_rnn.cpp
...
@@ -38,11 +38,11 @@ struct gelu_erf_matcher
     F f;
     auto erf_fn() const
    {
-        return f("erf")(
-            used_once(),
-            arg(0)(used_once(),
-                   f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
-                                             has_value(M_SQRT1_2, 1e-3)))));
+        auto mul_1_sqrt_2 = f("mul")(either_arg(0, 1)(
+            none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"), has_value(M_SQRT1_2, 1e-3)));
+        auto div_sqrt_2 =
+            f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2, 1e-3)));
+        return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2)));
    }
    auto add_erf() const
...
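For context, the pattern being matched is the erf formulation of GELU,

\[ \mathrm{GELU}(x) = \frac{x}{2}\Bigl(1 + \operatorname{erf}\bigl(x/\sqrt{2}\bigr)\Bigr), \]

and the change lets the matcher accept the inner scaling in either spelling frontends emit: a multiply by M_SQRT1_2 (\(x \cdot 1/\sqrt{2}\)) or a divide by M_SQRT2 (\(x/\sqrt{2}\)), each compared to the constant with a 1e-3 tolerance.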
@@ -564,6 +564,11 @@ MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
     return nullopt;
 }
 
+MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
+{
+    return contains({"broadcast", "multibroadcast"}, ins->name());
+}
+
 template <class... Ms>
 auto skip(Ms... ms)
 {
@@ -813,8 +818,7 @@ inline auto has_attribute(const std::string& name)
 template <class... Ms>
 auto pointwise(Ms... ms)
 {
-    return match::has_attribute("pointwise")(match::any_of(match::nargs(1), match::nargs(2)),
-                                             ms...);
+    return match::has_attribute("pointwise")(ms...);
 }
 
 } // namespace match
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GELU_HPP
#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
/**
* Rewrite the standard GELU formula as its sigmoid approximation
*/
struct rewrite_gelu
{
std::string name() const { return "rewrite_gelu"; }
void apply(module& m) const;
};
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/rewrite_gelu.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/match/gelu_erf.hpp>
#include <migraphx/common.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct find_gelu_erf
{
auto matcher() const { return match::gelu_erf(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x = r.instructions["x"];
if(x->get_shape().type() != migraphx::shape::half_type)
return;
auto lit = m.add_literal(literal{shape{x->get_shape().type()}, {1.702f}});
auto mul = insert_common_op(m, ins, make_op("mul"), {x, lit});
auto sig = m.insert_instruction(ins, make_op("neg"), mul);
sig = m.insert_instruction(ins, make_op("exp"), sig);
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
sig = insert_common_op(m, ins, make_op("add"), {sig, one});
sig = m.insert_instruction(ins, make_op("div"), x, sig);
m.replace_instruction(ins, sig);
}
};
void rewrite_gelu::apply(module& m) const { match::find_matches(m, find_gelu_erf{}); }
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
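The inserted neg → exp → add-one → div sequence computes the sigmoid approximation of GELU,

\[ \mathrm{GELU}(x) \approx x\,\sigma(1.702\,x) = \frac{x}{1 + e^{-1.702x}}, \]

and the half_type guard limits the rewrite to fp16 inputs, where replacing the erf evaluation with this cheaper form is the intended fast-gelu win.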
@@ -208,6 +208,42 @@ struct find_mul_add
     }
 };
struct find_dot_add
{
auto matcher() const
{
return match::name("dot")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(match::any().bind("x"),
match::any_of(match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);
const bool flipped = a_ins == ins->inputs().back();
auto insert_dot = [&](auto x, auto y) {
if(flipped)
return m.insert_instruction(ins, make_op("dot"), y, x);
else
return m.insert_instruction(ins, make_op("dot"), x, y);
};
auto ax_ins = insert_dot(a_ins, x_ins);
auto ab_ins = insert_dot(a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
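find_dot_add uses distributivity to pull a constant dot operand a through an add with a constant term b,

\[ a(x + b) = ax + ab \qquad\text{(or } (x + b)a = xa + ba\text{, preserved by the flipped flag, since dot is non-commutative)}, \]

so the variable x feeds one dot directly and the all-constant product ab can be constant-folded; the none_of guard skips adds whose operands are both constants, which constant propagation already folds.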
 struct find_add_lit_broadcast
 {
     auto matcher() const
...
@@ -267,28 +303,26 @@ struct find_double_add_lit_broadcast
 
 struct find_inner_broadcast
 {
-    auto matcher() const
-    {
-        return pointwise(
-            match::nargs(2),
-            match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
-    }
+    auto matcher() const { return pointwise(match::all_of[match::inputs()](match::broadcast())); }
 
     void apply(module& m, const match::matcher_result& r) const
     {
-        auto ins   = r.result;
-        auto x_ins = r.instructions["x"];
-        auto y_ins = r.instructions["y"];
-
-        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
-        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
-
-        if(xbroadcast.axis != ybroadcast.axis)
+        auto ins        = r.result;
+        auto broadcasts = ins->inputs();
+        if(broadcasts.empty())
+            return;
+        std::vector<instruction_ref> inputs;
+        std::transform(broadcasts.begin(),
+                       broadcasts.end(),
+                       std::back_inserter(inputs),
+                       [](auto i) { return i->inputs().front(); });
+        if(std::any_of(inputs.begin(), inputs.end(), [&](auto i) {
+               return i->get_shape() != inputs.front()->get_shape();
+           }))
             return;
-
-        auto op = m.insert_instruction(
-            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
-        m.replace_instruction(ins, xbroadcast, op);
+        auto op = m.insert_instruction(ins, ins->get_operator(), inputs);
+        m.replace_instruction(ins, broadcasts.front()->get_operator(), op);
     }
 };
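The generalization hoists the broadcast past a pointwise op of any arity rather than only binary ones: when every input is a broadcast of same-shaped tensors (implicitly the same mapping, since the first broadcast's operator is reused for the result),

\[ f\bigl(\mathrm{bcast}(x_1), \ldots, \mathrm{bcast}(x_n)\bigr) = \mathrm{bcast}\bigl(f(x_1, \ldots, x_n)\bigr), \]

which holds for elementwise f and lets the op run on the smaller pre-broadcast shape.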
@@ -416,8 +450,9 @@ struct find_splits
 {
     auto matcher() const
     {
-        return match::any(match::any_of[match::outputs()](match::name("slice")(
-            match::any_of[match::outputs()](match::pointwise(), reduction()))));
+        return match::any(
+            match::any_of[match::outputs()](match::name("slice")(match::any_of[match::outputs()](
+                match::pointwise(match::any_of(match::nargs(1), match::nargs(2))), reduction()))));
     }
 
     static bool is_dependent(const module& m, instruction_ref ins1, instruction_ref ins2)
@@ -1048,6 +1083,7 @@ void simplify_algebra::apply(module& m) const
                            find_mul_conv{},
                            find_mul_slice_conv{},
                            find_mul_add{},
+                           find_dot_add{},
                            find_div_const{},
                            find_sub_const{},
                            find_rsqrt{},
...
@@ -151,8 +151,11 @@ struct find_transpose
 {
     auto matcher() const
     {
-        return match::name("transpose")(match::none_of(
-            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
+        auto output_not_transpose =
+            match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")));
+        auto input_has_transpose =
+            match::args(match::skip(match::name("contiguous"))(match::name("transpose")));
+        return match::name("transpose")(output_not_transpose, input_has_transpose);
     }
 
     void apply(module& m, const match::matcher_result& mr) const
@@ -664,9 +667,94 @@ struct find_slice_transpose
     }
 };
struct find_transpose_slice
{
auto matcher() const
{
return match::name("transpose")(match::all_of[match::outputs()](match::name("slice")));
}
static std::vector<int64_t> slice_distance(const op::slice& op)
{
assert(op.starts.size() == op.ends.size());
std::vector<int64_t> result(op.starts.size());
std::transform(
op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
return result;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto slices = ins->outputs();
if(slices.empty())
return;
auto slice = any_cast<op::slice>(slices.front()->get_operator());
auto sdistance = slice_distance(slice);
// Check all distances and axes are the same
if(std::any_of(slices.begin(), slices.end(), [&](auto sins) {
auto s = any_cast<op::slice>(sins->get_operator());
return s.axes != slice.axes or slice_distance(s) != sdistance;
}))
return;
// Check distances are divisible by lens of corresponding axes
auto mod_by_distance = [&](const auto& v, auto f) {
return std::inner_product(v.begin(),
v.end(),
sdistance.begin(),
0,
std::plus<>{},
[&](auto x, auto d) -> uint64_t {
if(d == 0)
return 1;
return f(x) % d;
});
};
if(mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or
mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
return;
// TODO: Handle multiple axes
if(sdistance.size() != 1)
return;
auto axis = slice.axes.front();
// Skip if axis would be packed
if(std::all_of(ins->get_shape().lens().begin(),
ins->get_shape().lens().begin() + axis,
[](auto x) { return x == 1; }))
return;
// Compute axis before transpose to use for unsqueeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
// Make unsqueeze
auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
// Make transpose
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if(i > preaxis)
return i + 1;
return i;
});
perm.insert(perm.begin(), preaxis + 1);
auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and squeeze
for(auto s : slices)
{
auto op = any_cast<op::slice>(s->get_operator());
op.axes = {0};
op.starts = {op.starts.front() / sdistance.front()};
op.ends = {op.ends.front() / sdistance.front()};
auto slice_ins = m.insert_instruction(ins, op, transpose);
auto squeeze =
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
m.replace_instruction(s, squeeze);
}
}
};
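A worked sketch of find_transpose_slice: if every consumer slices the transposed tensor along the same axis with the same width d (the slice_distance) and the axis length n is divisible by d, the pass splits that axis before the transpose into n/d groups of d and moves the group index to a new leading dimension. Each original window then collapses to a single group index,

\[ [kd,\,(k+1)d) \;\longmapsto\; [k,\,k+1) \ \text{on axis } 0, \]

which is why the loop divides each slice's starts and ends by sdistance.front(), re-slices on axis 0, and squeezes that axis away; the slice becomes a cheap outermost selection instead of a strided copy.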
 void simplify_reshapes::apply(module& m) const
 {
-    for(int i = 0; i < 2; i++)
+    for(int i = 0; i < 4; i++)
     {
         match::find_matches(m,
                             find_where_op{},
...
@@ -679,6 +767,7 @@ void simplify_reshapes::apply(module& m) const
                             find_nested_convert{},
                             find_nested_slice{},
                             find_nested_concat{},
+                            find_transpose_slice{},
                             find_slice_transpose{},
                             find_transpose_contiguous_reshaper_unary{});
         dead_code_elimination{}.apply(m);
...
@@ -25,6 +25,13 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/permutation.hpp>
 #include <migraphx/stringutils.hpp>
+#include <migraphx/module.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/cpp_generator.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/ranges.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -75,25 +82,25 @@ std::string vectorize::str() const
 preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
 {
     const std::size_t max_lds_bytes = 4096;
-    std::vector<bool> result;
-    std::transform(inputs.begin(),
-                   inputs.end(),
-                   std::back_inserter(result),
-                   [&](const shape& input) { return input.strides()[axis] == 0; });
-    auto bytes = std::inner_product(inputs.begin(),
-                                    inputs.end(),
-                                    result.begin(),
-                                    std::size_t{0},
-                                    std::plus<>{},
-                                    [](const shape& s, bool b) -> std::size_t {
-                                        if(b)
-                                            return s.bytes();
-                                        return 0;
-                                    });
-    if(bytes < max_lds_bytes)
-        return {result};
-    // TODO: Try to partially preload items
-    std::fill(result.begin(), result.end(), false);
+    std::vector<bool> result(inputs.size());
+    std::vector<std::size_t> preloaded;
+    auto idxs = range(inputs.size());
+    std::copy_if(idxs.begin(), idxs.end(), std::back_inserter(preloaded), [&](auto i) {
+        return inputs[i].strides()[axis] == 0;
+    });
+    std::sort(preloaded.begin(), preloaded.end(), by(std::less<>{}, [&](auto i) {
+                  return inputs[i].bytes();
+              }));
+
+    std::size_t bytes = 0;
+    for(auto i : preloaded)
+    {
+        auto input = inputs[i];
+        bytes += input.bytes();
+        if(bytes > max_lds_bytes)
+            break;
+        result[i] = true;
+    }
     return {result};
 }
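The preload policy changes from all-or-nothing to greedy under the 4096-byte LDS budget: broadcast inputs (stride 0 along the reduction axis) are sorted by size and admitted smallest-first until the budget would overflow. With illustrative sizes of 1024, 1536, and 2048 bytes, the new code preloads the first two (1024 + 1536 = 2560 ≤ 4096) and skips the third (2560 + 2048 = 4608 > 4096), whereas the old code, seeing a 4608-byte total, preloaded nothing at all.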
@@ -125,6 +132,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
     return join_strings(std::move(transformers), ", ");
 }
std::string generate_pointwise(const module& pm, const std::string& name)
{
module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explicit conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
g.create_function(
g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
return g.str();
}
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
std::string generate_name_from_ops(const module& m)
{
auto op_names = get_op_names(m);
return join_strings(op_names, "_");
}
} // namespace gen
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
...
@@ -48,8 +48,10 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/array.hpp>
+#include <migraphx/permutation.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/op/clip.hpp>
+#include <migraphx/op/contiguous.hpp>
 #include <cmath>
 #include <set>
@@ -279,6 +281,11 @@ MIGRAPHX_REGISTER_OP(hip_layernorm)
 struct hip_triadd_layernorm : ternary_device<hip_triadd_layernorm, &device::triadd_layernorm>
 {
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4).standard();
+        return inputs[0];
+    }
+
     // Empty finalize to skip dimension reduction
     void finalize(context&, const shape&, const std::vector<shape>&) {}
 };
@@ -827,13 +834,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
     m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }
 
-inline auto precompile_name(std::string s) // NOLINT
+template <class... Strings>
+inline auto precompile_name(Strings... names) // NOLINT
 {
     return match::make_basic_pred_matcher([=](instruction_ref ins) {
         if(ins->name() != "gpu::precompile_op")
             return false;
         auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
-        return (op.name() == s);
+        return (contains({names...}, op.name()));
     });
 }
@@ -942,28 +950,70 @@ struct find_gemm_add
     }
 };
 
-auto pointwise_name(const std::string& s)
-{
-    return precompile_name("pointwise")(match::make_basic_pred_matcher([=](auto ins) {
-        module_ref pm = ins->module_inputs().front();
-        auto n = std::count_if(pm->begin(), pm->end(), [&](auto& i) { return i.name() == s; });
-        if(n != 1)
-            return false;
-        return std::all_of(pm->begin(), pm->end(), [&](auto& i) {
-            return starts_with(i.name(), "@") or i.name() == s;
-        });
-    }));
-}
-
 struct find_gemm_pointwise
 {
     auto matcher() const
     {
-        return pointwise_name("add")(
-            match::nargs(3),
-            match::all_of[match::inputs()](match::standard_shape()),
-            match::either_arg(0, 1)(match::used_once().bind("c"),
-                                    match::name("gpu::gemm")(match::nargs(3)).bind("gemm")));
+        return precompile_name("pointwise")(
+            match::nargs(3),
+            match::either_arg(0, 1)(
+                match::any_of(match::standard_shape(), match::is_constant()).bind("c"),
+                match::name("gpu::gemm")(match::nargs(3), match::used_once()).bind("gemm")));
     }
+
+    // TODO: Move to matcher.hpp
+    static auto match_param(const std::string& name)
+    {
+        return match::make_basic_pred_matcher([=](auto ins) {
+            if(ins->name() != "@param")
+                return false;
+            auto p = any_cast<builtin::param>(ins->get_operator());
+            return p.parameter == name;
+        });
+    }
+
+    template <class M>
+    static auto match_mul_const(M m, const std::string& var)
+    {
+        return match::name("mul")(match::either_arg(0, 1)(match::name("@literal").bind(var), m))
+            .bind(var + "_mul");
+    }
+
+    static auto match_add(const std::string& input, const std::string& output)
+    {
+        auto param     = match::name("@param");
+        auto add       = match::name("add")(match::args(param, param));
+        auto inner_mul = match::any_of(match_mul_const(match_param(input), "alpha"),
+                                       match_mul_const(match_param(output), "beta"));
+        auto mul_add   = match::name("add")(match::either_arg(0, 1)(inner_mul, param));
+        auto add_mul   = match_mul_const(add, "gamma");
+        return match::name("@return")(match::args(match::any_of(add, mul_add, add_mul)));
+    }
+
+    static float get_float(instruction_ref ins) { return ins->get_literal().at<float>(); }
+
+    template <class Gemm>
+    static bool update_gemm(Gemm& gemm, module_ref pm, unsigned input)
+    {
+        auto names = pm->get_parameter_names();
+        if(names.size() != 2)
+            return false;
+        std::sort(names.begin(), names.end());
+        unsigned output = input == 0 ? 1 : 0;
+        auto mr         = match::match_instruction(
+            *pm, std::prev(pm->end()), match_add(names[input], names[output]));
+        if(mr.result == pm->end())
+            return false;
+        if(contains(mr.instructions, "alpha_mul"))
+            gemm.alpha *= get_float(mr.instructions["alpha"]);
+        else if(contains(mr.instructions, "beta_mul"))
+            gemm.beta *= get_float(mr.instructions["beta"]);
+        else if(contains(mr.instructions, "gamma_mul"))
+        {
+            gemm.alpha *= get_float(mr.instructions["gamma"]);
+            gemm.beta *= get_float(mr.instructions["gamma"]);
+        }
+        return true;
+    }
 
     void apply(module& m, const match::matcher_result& r) const
@@ -977,6 +1027,19 @@ struct find_gemm_pointwise
         // Already fused gemm
         if(not float_equal(gemm.beta, 0))
             return;
+        gemm.beta = 1;
+        if(not update_gemm(
+               gemm, ins->module_inputs().front(), ins->inputs().front() == gemm_ins ? 0 : 1))
+            return;
+        // const-fold input if not standard shape since rocblas can't handle it
+        if(not c_ins->get_shape().standard())
+        {
+            auto c = op::contiguous{};
+            auto l = c.compute(c.compute_shape({c_ins->get_shape()}), {c_ins->eval()});
+            c_ins  = m.add_literal(l.get_shape(), l.data());
+        }
         auto inputs = gemm_ins->inputs();
         inputs.pop_back();
@@ -984,11 +1047,68 @@ struct find_gemm_pointwise
         inputs.push_back(c_ins);
         inputs.push_back(ins->inputs().back());
 
-        gemm.beta = 1;
         m.replace_instruction(ins, gemm, inputs);
     }
 };
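For reference, rocBLAS's gemm_ex computes D = α·op(A)·op(B) + β·C. With gemm.beta set to 1 up front, update_gemm folds the matched pointwise module into those scalars:

\[ \alpha_{\ell}(AB) + C \;\Rightarrow\; \alpha \mathrel{*}= \alpha_{\ell}, \qquad AB + \beta_{\ell} C \;\Rightarrow\; \beta \mathrel{*}= \beta_{\ell}, \qquad \gamma_{\ell}(AB + C) \;\Rightarrow\; \alpha,\beta \mathrel{*}= \gamma_{\ell}, \]

where the ℓ-subscripted factors are the literals bound by match_mul_const, so a (possibly scaled) add epilogue collapses into the single gemm call.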
struct find_contiguous_transpose_gemm
{
auto matcher() const
{
return match::name("gpu::contiguous")(match::arg(0)(
match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
}
template <class Vector>
static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
if(i >= perm.size() or j >= perm.size())
return false;
auto perm2 = perm;
std::iota(perm2.begin(), perm2.end(), 0);
std::swap(perm2[i], perm2[j]);
return perm2 == perm;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm = r.instructions["gemm"];
auto alloc = gemm->inputs().back();
auto transpose = r.instructions["transpose"];
auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto iperm = invert_permutation(perm);
if(perm.size() < 3)
return;
if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
return;
auto lens = gemm->get_shape().lens();
if(lens.size() > 3 and
not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
return;
auto gemmv = gemm->get_operator().to_value();
gemmv["trans_batch"] = 1;
auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
auto alloc_transpose =
m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
auto inputs = gemm->inputs();
inputs.back() = alloc_transpose;
auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
auto gemm_transpose = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
m.replace_instruction(ins, gemm_transpose);
}
};
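The intent, sketched: instead of computing the gemm and then paying a gpu::contiguous copy to materialize the transposed result, the pass allocates the output already in the permuted layout and has the gemm write through a transposed view of it,

\[ \mathrm{contiguous}\bigl(\mathrm{transpose}_P(\mathrm{gemm})\bigr) \;\longrightarrow\; \mathrm{transpose}_P(\mathrm{gemm}'), \]

where trans_batch = 1 records for the rocBLAS wrapper that the batch and row dimensions are stored swapped. The guards require P to be exactly that adjacent swap (is_swapped) and any higher batch dimensions to be singleton.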
 struct find_commutative_broadcast
 {
     auto matcher() const
...
@@ -1041,6 +1161,31 @@ struct find_contiguous_pointwise
     }
 };
struct find_layernorm_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(match::arg(0)(
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto layernorm = r.instructions["layernorm"];
auto* pm = ins->module_inputs().front();
if(not layernorm->module_inputs().empty())
return;
auto inputs = layernorm->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
}
};
 void fuse_ops::apply(module& m) const
 {
     match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
...
@@ -1063,7 +1208,9 @@ void fuse_ops::apply(module& m) const
     match::find_matches(m,
                         find_triadd_layernorm{},
                         find_gemm_add{},
+                        find_layernorm_pointwise{},
                         find_gemm_pointwise{},
+                        find_contiguous_transpose_gemm{},
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
...
@@ -24,6 +24,7 @@
 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
...
@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }
shape transpose_batch(const shape& s, unsigned trans_batch)
{
if(trans_batch == 0)
return s;
if(s.lens().size() < 3)
return s;
auto batch = s.lens().size() - 3;
std::vector<int64_t> perm(s.lens().size());
std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[batch], perm[batch + trans_batch]);
return shape::from_permutation(s.type(), s.lens(), perm);
}
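An illustrative trace: for lens {2, 3, 4, 5} with trans_batch = 1, batch = 4 − 3 = 1, so the identity permutation {0, 1, 2, 3} becomes {0, 2, 1, 3}; shape::from_permutation keeps the logical lens but gives the shape the strides of a tensor stored with dims 1 and 2 exchanged, matching the transposed allocation created in find_contiguous_transpose_gemm.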
 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {
...
@@ -97,6 +111,12 @@ void gemm_impl(context& ctx,
                bool int8_x4_format,
                bool compute_fp32)
 {
+    const bool is_3inputs = (args.size() == 4);
+    if(!is_3inputs)
+    {
+        beta = 0;
+    }
+
     bool transa = is_transposed(args[0].get_shape());
     bool transb = is_transposed(args[1].get_shape());
     auto n_dim  = output_shape.lens().size();
@@ -105,12 +125,8 @@ void gemm_impl(context& ctx,
     rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
     rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
     rocblas_int ldc = args[2].get_shape().strides()[dim_0];
-
-    bool is_3inputs = (args.size() == 4);
-    if(!is_3inputs)
-    {
-        beta = 0;
-    }
+    rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
 
     rocblas_datatype arg_type = get_type(args[0].get_shape().type());
     auto output_type          = arg_type;
     if(output_type == rocblas_datatype_i8_r)
@@ -186,7 +202,7 @@ void gemm_impl(context& ctx,
                    ldc,
                    is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                    output_type,
-                   ldc,
+                   ldd,
                    compute_type,
                    rocblas_gemm_algo_standard,
                    0,
@@ -197,6 +213,7 @@ void gemm_impl(context& ctx,
         auto a_stride = get_batch_stride(args[0]);
         auto b_stride = get_batch_stride(args[1]);
         auto c_stride = get_batch_stride(args[2]);
+        auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
         rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                        ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,
@@ -220,8 +237,8 @@ void gemm_impl(context& ctx,
                        c_stride,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
-                       ldc,
-                       c_stride,
+                       ldd,
+                       d_stride,
                        num_matrices,
                        compute_type,
                        rocblas_gemm_algo_standard,
...
@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
 
 #include <migraphx/config.hpp>
+#include <migraphx/module_ref.hpp>
 #include <string>
 #include <unordered_map>
 #include <vector>
...
@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
     return make_transformer_args({xs.str()...});
 }
 
+std::string generate_pointwise(const module& pm, const std::string& name);
+
+std::string generate_name_from_ops(const module& m);
+
 } // namespace gen
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -42,15 +42,17 @@ namespace gpu {
 struct context;
 
 void blas_shape(const shape& s);
+shape transpose_batch(const shape& s, unsigned trans_batch);
 
 template <class Op>
 struct rocblas_gemm
 {
     Op op;
     float alpha          = 1;
     float beta           = 0;
     bool int8_x4_format  = true;
     bool compute_fp32    = false;
+    unsigned trans_batch = 0;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -58,7 +60,9 @@ struct rocblas_gemm
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
-                              f(self.int8_x4_format, "int8_x4_format")));
+                              f(self.int8_x4_format, "int8_x4_format"),
+                              f(self.compute_fp32, "compute_fp32"),
+                              f(self.trans_batch, "trans_batch")));
     }
 
     std::string name() const
@@ -74,13 +78,14 @@ struct rocblas_gemm
     {
         std::vector<shape> in_shapes(inputs);
         in_shapes.pop_back();
-        check_shapes{in_shapes, *this}.not_broadcasted();
+        check_shapes{in_shapes, *this}.has(2, 3);
         blas_shape(inputs[0]);
         blas_shape(inputs[1]);
         // if gemm and add are fused
         if(in_shapes.size() > 2)
         {
             auto cmat_shape = in_shapes.back();
+            check_shapes{{cmat_shape}, *this}.not_transposed().not_broadcasted();
             in_shapes.pop_back();
             blas_shape(cmat_shape);
             auto op_out_shape = op.compute_shape(in_shapes);
@@ -97,10 +102,10 @@ struct rocblas_gemm
                                to_string(cmat_shape.type()) +
                                ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch(op_out_shape, trans_batch);
         }
 
-        return op.compute_shape(in_shapes);
+        return transpose_batch(op.compute_shape(in_shapes), trans_batch);
     }
 
     argument
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct layernorm_compiler : compiler<layernorm_compiler>
{
std::vector<std::string> names() const
{
return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto axis = inputs.front().lens().size() - 1;
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(axis == faxis)
{
vec = vectorize::elements(faxis, inputs);
}
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = v.get("kernel", "layernorm_kernel");
auto src = interpolate_string(layernorm_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)},
{"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["layernorm"] = "layernorm";
v["kernel"] = "layernorm_kernel";
if(op.name() == "gpu::preadd_layernorm")
{
v["layernorm"] = "add_layernorm";
v["kernel"] = "add_layernorm_kernel";
}
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_layernorm");
v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
v["kernel"] =
v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
 )__migraphx__";
 
-static std::vector<std::string> get_op_names(const module& m)
-{
-    std::vector<std::string> result;
-    for(auto& ins : m)
-    {
-        if(starts_with(ins.name(), "@"))
-            continue;
-        result.push_back(ins.name());
-    }
-    return result;
-}
-
 struct pointwise_compiler : compiler<pointwise_compiler>
 {
     std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
@@ -126,34 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
         else
         {
             assert(not ins->module_inputs().empty());
             auto* pm = ins->module_inputs().front();
-            run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
-            cpp_generator g;
-            g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
-            g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
-            g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
-            g.add_point_op("sign",
-                           "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
-            g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
-            g.add_point_op("less", "migraphx::abs(${0} < ${1})");
-            g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
-            g.add_point_op("not", "migraphx::abs(not ${0})");
-            g.add_point_op("mod", "migraphx::mod(${0}, ${1})");
-            g.add_point_op("fmod", "migraphx::fmod(${0}, ${1})");
-            // Add explict conversions
-            g.fresult([](const shape& s) {
-                return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
-            });
-            auto name = g.create_function(
-                g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
-            std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
-            auto op_names      = get_op_names(*pm);
-            op_names.push_back("kernel");
-            auto op_name_string = join_strings(op_names, "_");
-            return replace(compile_op(
-                ctx,
-                to_shapes(ins->inputs()),
-                {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
+            auto pf            = generate_pointwise(*pm, "inner_pointwise");
+            std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)";
+            auto kernel_name   = generate_name_from_ops(*pm) + "_kernel";
+            return replace(
+                compile_op(ctx,
+                           to_shapes(ins->inputs()),
+                           {{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}));
         }
     }
 };
...
@@ -31,8 +31,9 @@
     ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
 
 // NOLINTNEXTLINE
-#define MIGRAPHX_LIFT(...) \
-    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
+#define MIGRAPHX_LIFT(...)                          \
+    [](auto&&... private_lift_xs) MIGRAPHX_RETURNS( \
+        (__VA_ARGS__)(static_cast<decltype(private_lift_xs)>(private_lift_xs)...))
 
 namespace migraphx {
...
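A minimal standalone sketch of why the parameter pack gets an unlikely private name (the macro and function names below are illustrative, not part of the library): expanding the lift macro inside another generic lambda that already binds a pack called xs would otherwise shadow it and trip -Wshadow.

    // Local stand-ins for MIGRAPHX_RETURNS/MIGRAPHX_LIFT, renamed for this sketch.
    #define SKETCH_RETURNS(...) ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
    #define SKETCH_LIFT(...)                          \
        [](auto&&... private_lift_xs) SKETCH_RETURNS( \
            (__VA_ARGS__)(static_cast<decltype(private_lift_xs)>(private_lift_xs)...))

    int add(int a, int b) { return a + b; }

    int main()
    {
        // The outer lambda binds a pack named xs; had the macro also used xs,
        // the inner lambda's pack would shadow it and warn under -Wshadow.
        auto call = [](auto&&... xs) { return SKETCH_LIFT(add)(static_cast<decltype(xs)>(xs)...); };
        return call(1, 2) == 3 ? 0 : 1;
    }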
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx {
template <index_int Axis,
class F,
class BinOp,
class Output,
class Input1,
class Input2,
class... Inputs>
__device__ void generic_binary_layernorm(
F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements};
})(input1, input2);
};
// mean(x)
auto mean_x = mean(op);
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...);
});
}
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
generic_binary_layernorm<Axis>(
compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
}
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void
add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
generic_binary_layernorm<Axis>(
compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
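In equations, the kernel computes, per slice along the reduction Axis and with t = op(x₁, x₂):

\[ y = \mathrm{compute}\!\left( \frac{t - \mathbb{E}[t]}{\sqrt{\mathbb{E}\bigl[(t - \mathbb{E}[t])^2\bigr] + 10^{-12}}},\; x_3, \ldots \right), \]

where op is the identity on x₁ for layernorm and x₁ + x₂ for add_layernorm, and compute applies the fused pointwise epilogue (op::id by default). Computing the variance as the mean of squared deviations matches the two r.reduce passes above.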
@@ -224,6 +224,18 @@ struct block
             idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
         });
     }
+
+    template <class Input>
+    constexpr auto elements() const
+    {
+        using reduce_type = decltype(slicer(Input{}));
+        using value_type  = typename Input::type;
+        constexpr auto relements = get_shape_c<reduce_type>{}.elements();
+        if constexpr(vec_size<value_type>() > 1)
+            return relements * vec_size<value_type>();
+        else
+            return relements;
+    }
 };
 
 template <class Slicer>
...
@@ -281,6 +293,13 @@ struct lane
             }
         });
     }
+
+    template <class Input>
+    constexpr auto elements() const
+    {
+        using reduce_type = decltype(slicer(Input{}));
+        return get_shape_c<reduce_type>{}.elements();
+    }
 };
 
 template <class Slicer>
...
@@ -175,7 +175,7 @@ template <class T, class Op>
 constexpr auto vec_reduce(T x, Op op)
 {
     if constexpr(vec_size<T>() < 2)
-        return x;
+        return vec_type<T>{x};
     else
     {
         vec_type<T> result = x[0];
...
@@ -24,12 +24,53 @@
 #include <migraphx/gpu/prefuse_ops.hpp>
 #include <migraphx/match/layernorm.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/register_op.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 namespace {
template <class Derived, std::size_t N>
struct layernorm_base
{
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{
std::size_t nargs = 1;
if(not mods.empty())
{
auto* pm = mods.front();
nargs = pm->get_parameter_names().size();
}
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
};
struct layernorm : layernorm_base<layernorm, 0>
{
std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1>
{
std::string name() const { return "gpu::preadd_layernorm"; }
};
MIGRAPHX_REGISTER_OP(add_layernorm);
 struct find_layernorm
 {
     auto matcher() const { return match::layernorm(); }
...
@@ -39,59 +80,30 @@ struct find_layernorm
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];
 
-        if(not x_ins->get_shape().standard())
-            x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
-
-        auto relements = x_ins->get_shape().lens().back();
-
-        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
-            return;
-
-        auto a = m.insert_instruction(
-            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
-        m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
+        m.replace_instruction(ins, layernorm{}, x_ins);
     }
 };
 
-struct find_triaddlayernorm
+struct find_add_layernorm
 {
     auto matcher() const
     {
-        auto add1 =
-            match::name("add")(match::none_of(match::is_constant()),
-                               match::args(match::any().bind("z1"), match::any().bind("z2")));
-        auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
-        return match::layernorm()(match::var("x")(add2));
+        return match::layernorm()(match::var("x")(match::name("add").bind("add")));
     }
 
     void apply(module& m, const match::matcher_result& r) const
     {
-        auto ins   = r.result;
-        auto x_ins = r.instructions["z1"];
-        auto y_ins = r.instructions["z2"];
-        auto z_ins = r.instructions["z3"];
-
-        for(auto* pins : {&x_ins, &y_ins, &z_ins})
-        {
-            if(not(*pins)->get_shape().standard())
-                *pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
-        }
-
-        auto relements = x_ins->get_shape().lens().back();
-
-        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
-            return;
-
-        auto a = m.insert_instruction(
-            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
-        m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
+        auto ins     = r.result;
+        auto add_ins = r.instructions["add"];
+
+        m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
     }
 };
 
 } // namespace
 
 void prefuse_ops::apply(module& m) const
 {
-    match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
+    match::find_matches(m, find_add_layernorm{}, find_layernorm{});
 }
 
 } // namespace gpu
...
@@ -42,6 +42,7 @@
 #include <migraphx/register_target.hpp>
 #include <migraphx/replace_allocate.hpp>
 #include <migraphx/rewrite_batchnorm.hpp>
+#include <migraphx/rewrite_gelu.hpp>
 #include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/rewrite_quantization.hpp>
 #include <migraphx/rewrite_rnn.hpp>
@@ -116,6 +117,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options
         inline_module{},
         rewrite_pooling{},
         dead_code_elimination{},
+        rewrite_gelu{},
+        dead_code_elimination{},
         eliminate_common_subexpression{},
         dead_code_elimination{},
         simplify_algebra{},
@@ -134,8 +137,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},
         dead_code_elimination{},
-        replace_allocate{gpu_allocation_model{}, options.offload_copy},
-        dead_code_elimination{},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
         pack_int8_args{},
@@ -144,6 +145,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options
         dead_code_elimination{},
         fuse_ops{&ctx, options.fast_math},
         dead_code_elimination{},
+        replace_allocate{gpu_allocation_model{}, options.offload_copy},
+        dead_code_elimination{},
         compile_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
...