"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "b4c4234dde830705ceed3ed81513c7cc6988c7b0"
Unverified Commit 557b1ad1 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into fastsoftmax

parents 70687f79 7c8f2690
name: MIGraphX Performance Tests name: MIGraphX Performance Tests
on: on:
push:
branches: [develop]
pull_request: pull_request:
branches: [develop] branches: [develop]
types: [opened, synchronize, closed]
schedule: schedule:
- cron: "0 5 * * 1-6" - cron: "0 5 * * 1-6"
......
...@@ -151,8 +151,11 @@ struct find_transpose ...@@ -151,8 +151,11 @@ struct find_transpose
{ {
auto matcher() const auto matcher() const
{ {
return match::name("transpose")(match::none_of( auto output_not_transpose =
match::skip_output(match::name("contiguous"))(match::name("transpose")))); match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")));
auto input_has_transpose =
match::args(match::skip(match::name("contiguous"))(match::name("transpose")));
return match::name("transpose")(output_not_transpose, input_has_transpose);
} }
void apply(module& m, const match::matcher_result& mr) const void apply(module& m, const match::matcher_result& mr) const
...@@ -664,9 +667,94 @@ struct find_slice_transpose ...@@ -664,9 +667,94 @@ struct find_slice_transpose
} }
}; };
struct find_transpose_slice
{
auto matcher() const
{
return match::name("transpose")(match::all_of[match::outputs()](match::name("slice")));
}
static std::vector<int64_t> slice_distance(const op::slice& op)
{
assert(op.starts.size() == op.ends.size());
std::vector<int64_t> result(op.starts.size());
std::transform(
op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
return result;
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto slices = ins->outputs();
if(slices.empty())
return;
auto slice = any_cast<op::slice>(slices.front()->get_operator());
auto sdistance = slice_distance(slice);
// Check all distances and axes are the same
if(std::any_of(slices.begin(), slices.end(), [&](auto sins) {
auto s = any_cast<op::slice>(sins->get_operator());
return s.axes != slice.axes or slice_distance(s) != sdistance;
}))
return;
// Check distances are divisible by lens of corresponding axes
auto mod_by_distance = [&](const auto& v, auto f) {
return std::inner_product(v.begin(),
v.end(),
sdistance.begin(),
0,
std::plus<>{},
[&](auto x, auto d) -> uint64_t {
if(d == 0)
return 1;
return f(x) % d;
});
};
if(mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or
mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
return;
// TODO: Handle multiple axes
if(sdistance.size() != 1)
return;
auto axis = slice.axes.front();
// Skip if axis would be packed
if(std::all_of(ins->get_shape().lens().begin(),
ins->get_shape().lens().begin() + axis,
[](auto x) { return x == 1; }))
return;
// Compute axis before transpose to use for unsqueeze
auto perm = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
// Make unsqeeze
auto unsqueeze = m.insert_instruction(
ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
// Make transpose
std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
if(i > preaxis)
return i + 1;
return i;
});
perm.insert(perm.begin(), preaxis + 1);
auto transpose =
m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
// Slice and squeeze
for(auto s : slices)
{
auto op = any_cast<op::slice>(s->get_operator());
op.axes = {0};
op.starts = {op.starts.front() / sdistance.front()};
op.ends = {op.ends.front() / sdistance.front()};
auto slice_ins = m.insert_instruction(ins, op, transpose);
auto squeeze =
m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
m.replace_instruction(s, squeeze);
}
}
};
void simplify_reshapes::apply(module& m) const void simplify_reshapes::apply(module& m) const
{ {
for(int i = 0; i < 2; i++) for(int i = 0; i < 4; i++)
{ {
match::find_matches(m, match::find_matches(m,
find_where_op{}, find_where_op{},
...@@ -679,6 +767,7 @@ void simplify_reshapes::apply(module& m) const ...@@ -679,6 +767,7 @@ void simplify_reshapes::apply(module& m) const
find_nested_convert{}, find_nested_convert{},
find_nested_slice{}, find_nested_slice{},
find_nested_concat{}, find_nested_concat{},
find_transpose_slice{},
find_slice_transpose{}, find_slice_transpose{},
find_transpose_contiguous_reshaper_unary{}); find_transpose_contiguous_reshaper_unary{});
dead_code_elimination{}.apply(m); dead_code_elimination{}.apply(m);
......
...@@ -25,6 +25,13 @@ ...@@ -25,6 +25,13 @@
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp> #include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -75,25 +82,25 @@ std::string vectorize::str() const ...@@ -75,25 +82,25 @@ std::string vectorize::str() const
preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs) preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
{ {
const std::size_t max_lds_bytes = 4096; const std::size_t max_lds_bytes = 4096;
std::vector<bool> result; std::vector<bool> result(inputs.size());
std::transform(inputs.begin(), std::vector<std::size_t> preloaded;
inputs.end(), auto idxs = range(inputs.size());
std::back_inserter(result), std::copy_if(idxs.begin(), idxs.end(), std::back_inserter(preloaded), [&](auto i) {
[&](const shape& input) { return input.strides()[axis] == 0; }); return inputs[i].strides()[axis] == 0;
auto bytes = std::inner_product(inputs.begin(), });
inputs.end(), std::sort(preloaded.begin(), preloaded.end(), by(std::less<>{}, [&](auto i) {
result.begin(), return inputs[i].bytes();
std::size_t{0}, }));
std::plus<>{},
[](const shape& s, bool b) -> std::size_t { std::size_t bytes = 0;
if(b) for(auto i : preloaded)
return s.bytes(); {
return 0; auto input = inputs[i];
}); bytes += input.bytes();
if(bytes < max_lds_bytes) if(bytes > max_lds_bytes)
return {result}; break;
// TODO: Try to partially preload items result[i] = true;
std::fill(result.begin(), result.end(), false); }
return {result}; return {result};
} }
...@@ -125,6 +132,45 @@ std::string make_transformer_args(std::vector<std::string> transformers) ...@@ -125,6 +132,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
return join_strings(std::move(transformers), ", "); return join_strings(std::move(transformers), ", ");
} }
std::string generate_pointwise(const module& pm, const std::string& name)
{
module m = pm;
run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
cpp_generator g;
g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
// Add explict conversions
g.fresult(
[](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
g.create_function(
g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
return g.str();
}
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
std::string generate_name_from_ops(const module& m)
{
auto op_names = get_op_names(m);
return join_strings(op_names, "_");
}
} // namespace gen } // namespace gen
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
...@@ -827,13 +827,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r) ...@@ -827,13 +827,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins); m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
} }
inline auto precompile_name(std::string s) // NOLINT template <class... Strings>
inline auto precompile_name(Strings... names) // NOLINT
{ {
return match::make_basic_pred_matcher([=](instruction_ref ins) { return match::make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "gpu::precompile_op") if(ins->name() != "gpu::precompile_op")
return false; return false;
auto op = from_value<operation>(ins->get_operator().to_value().at("op")); auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
return (op.name() == s); return (contains({names...}, op.name()));
}); });
} }
...@@ -1041,6 +1042,31 @@ struct find_contiguous_pointwise ...@@ -1041,6 +1042,31 @@ struct find_contiguous_pointwise
} }
}; };
struct find_layernorm_pointwise
{
auto matcher() const
{
return precompile_name("pointwise")(match::arg(0)(
precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto layernorm = r.instructions["layernorm"];
auto* pm = ins->module_inputs().front();
if(not layernorm->module_inputs().empty())
return;
auto inputs = layernorm->inputs();
inputs.pop_back();
inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
}
};
void fuse_ops::apply(module& m) const void fuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math}); match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});
...@@ -1063,6 +1089,7 @@ void fuse_ops::apply(module& m) const ...@@ -1063,6 +1089,7 @@ void fuse_ops::apply(module& m) const
match::find_matches(m, match::find_matches(m,
find_triadd_layernorm{}, find_triadd_layernorm{},
find_gemm_add{}, find_gemm_add{},
find_layernorm_pointwise{},
find_gemm_pointwise{}, find_gemm_pointwise{},
find_commutative_broadcast{}); find_commutative_broadcast{});
match::find_matches(m, find_contiguous{}); match::find_matches(m, find_contiguous{});
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP #define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/module_ref.hpp>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs) ...@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
return make_transformer_args({xs.str()...}); return make_transformer_args({xs.str()...});
} }
std::string generate_pointwise(const module& pm, const std::string& name);
std::string generate_name_from_ops(const module& m);
} // namespace gen } // namespace gen
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
using namespace migraphx::gpu::gen; // NOLINT
static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void ${kernel}(${params})
{
auto idx = make_index();
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
${layernorm}<${axis}>(${post}, xs...);
});
}
}
} // namespace migraphx
)__migraphx__";
struct layernorm_compiler : compiler<layernorm_compiler>
{
std::vector<std::string> names() const
{
return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"};
}
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
// TODO: Use reduce_dims
auto axis = inputs.front().lens().size() - 1;
auto faxis = find_fast_axis({inputs.front()});
vectorize vec{};
// Vectorize if the axis is a reduction axis
if(axis == faxis)
{
vec = vectorize::elements(faxis, inputs);
}
auto preloads = preload::broadcasts(axis, inputs);
auto relements = inputs[0].lens()[axis] / vec.size;
auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);
auto block_size = compute_block_size(relements, 256);
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, nelements * block_size, 256), block_size);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = v.get("kernel", "layernorm_kernel");
auto src = interpolate_string(layernorm_kernel,
{{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")},
{"transformers", make_transformer_args(preloads, vec)},
{"post", v.get("post", std::string{"op::id{}"})},
{"preamble", v.get("preamble", std::string{})},
{"layernorm", v.get("layernorm", std::string{"layernorm"})},
{"axis", to_string(axis)}});
return compile_hip_code_object(src, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
auto v = op.to_value();
v["layernorm"] = "layernorm";
v["kernel"] = "layernorm_kernel";
if(op.name() == "gpu::preadd_layernorm")
{
v["layernorm"] = "add_layernorm";
v["kernel"] = "add_layernorm_kernel";
}
if(not ins->module_inputs().empty())
{
auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_layernorm");
v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
v["kernel"] =
v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
}
return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params}) ...@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
)__migraphx__"; )__migraphx__";
static std::vector<std::string> get_op_names(const module& m)
{
std::vector<std::string> result;
for(auto& ins : m)
{
if(starts_with(ins.name(), "@"))
continue;
result.push_back(ins.name());
}
return result;
}
struct pointwise_compiler : compiler<pointwise_compiler> struct pointwise_compiler : compiler<pointwise_compiler>
{ {
std::vector<std::string> names() const { return {"pointwise", "contiguous"}; } std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }
...@@ -126,34 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler> ...@@ -126,34 +114,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else else
{ {
assert(not ins->module_inputs().empty()); assert(not ins->module_inputs().empty());
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}}); auto pf = generate_pointwise(*pm, "inner_pointwise");
cpp_generator g; std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)";
g.fmap([](const std::string& fname) { return "migraphx::" + fname; }); auto kernel_name = generate_name_from_ops(*pm) + "_kernel";
g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})"); return replace(
g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})"); compile_op(ctx,
g.add_point_op("sign", to_shapes(ins->inputs()),
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"); {{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}));
g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
g.add_point_op("less", "migraphx::abs(${0} < ${1})");
g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
g.add_point_op("not", "migraphx::abs(not ${0})");
g.add_point_op("mod", "migraphx::mod(${0}, ${1})");
g.add_point_op("fmod", "migraphx::fmod(${0}, ${1})");
// Add explict conversions
g.fresult([](const shape& s) {
return "migraphx::convert<" + shape::cpp_type(s.type()) + ">";
});
auto name = g.create_function(
g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
auto op_names = get_op_names(*pm);
op_names.push_back("kernel");
auto op_name_string = join_strings(op_names, "_");
return replace(compile_op(
ctx,
to_shapes(ins->inputs()),
{{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
} }
} }
}; };
......
...@@ -31,8 +31,9 @@ ...@@ -31,8 +31,9 @@
->decltype(__VA_ARGS__) { return __VA_ARGS__; } ->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
[](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...)) [](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
namespace migraphx { namespace migraphx {
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx {
template <index_int Axis,
class F,
class BinOp,
class Output,
class Input1,
class Input2,
class... Inputs>
__device__ void generic_binary_layernorm(
F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
using reduce_output = reduce::with_axis<Input1, Axis>;
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements};
})(input1, input2);
};
// mean(x)
auto mean_x = mean(op);
// mean(m ^ 2)
auto mean_m2 = mean([&](auto x1, auto x2) {
auto m = op(x1, x2) - mean_x;
return m * m;
});
r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
auto m = op(x1, x2) - mean_x;
// m * rsqrt(mean(m ^ 2) + 1e-12)
y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
})(output, input1, input2, inputs...);
});
}
template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
generic_binary_layernorm<Axis>(
compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
}
template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void
add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
generic_binary_layernorm<Axis>(
compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
...@@ -224,6 +224,18 @@ struct block ...@@ -224,6 +224,18 @@ struct block
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
}); });
} }
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
else
return relements;
}
}; };
template <class Slicer> template <class Slicer>
...@@ -281,6 +293,13 @@ struct lane ...@@ -281,6 +293,13 @@ struct lane
} }
}); });
} }
template <class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
return get_shape_c<reduce_type>{}.elements();
}
}; };
template <class Slicer> template <class Slicer>
......
...@@ -175,7 +175,7 @@ template <class T, class Op> ...@@ -175,7 +175,7 @@ template <class T, class Op>
constexpr auto vec_reduce(T x, Op op) constexpr auto vec_reduce(T x, Op op)
{ {
if constexpr(vec_size<T>() < 2) if constexpr(vec_size<T>() < 2)
return x; return vec_type<T>{x};
else else
{ {
vec_type<T> result = x[0]; vec_type<T> result = x[0];
......
...@@ -24,12 +24,53 @@ ...@@ -24,12 +24,53 @@
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp> #include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace { namespace {
template <class Derived, std::size_t N>
struct layernorm_base
{
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{
std::size_t nargs = 1;
if(not mods.empty())
{
auto* pm = mods.front();
nargs = pm->get_parameter_names().size();
}
check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
auto s = inputs.at(0);
if(s.scalar())
{
return s;
}
else if(s.broadcasted())
{
return {s.type(), s.lens()};
}
else
{
return s.with_lens(s.lens());
}
}
};
struct layernorm : layernorm_base<layernorm, 0>
{
std::string name() const { return "gpu::prelayernorm"; }
};
MIGRAPHX_REGISTER_OP(layernorm);
struct add_layernorm : layernorm_base<add_layernorm, 1>
{
std::string name() const { return "gpu::preadd_layernorm"; }
};
MIGRAPHX_REGISTER_OP(add_layernorm);
struct find_layernorm struct find_layernorm
{ {
auto matcher() const { return match::layernorm(); } auto matcher() const { return match::layernorm(); }
...@@ -39,59 +80,30 @@ struct find_layernorm ...@@ -39,59 +80,30 @@ struct find_layernorm
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["x"]; auto x_ins = r.instructions["x"];
if(not x_ins->get_shape().standard()) m.replace_instruction(ins, layernorm{}, x_ins);
x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction(
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
} }
}; };
struct find_triaddlayernorm struct find_add_layernorm
{ {
auto matcher() const auto matcher() const
{ {
auto add1 = return match::layernorm()(match::var("x")(match::name("add").bind("add")));
match::name("add")(match::none_of(match::is_constant()),
match::args(match::any().bind("z1"), match::any().bind("z2")));
auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
return match::layernorm()(match::var("x")(add2));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto x_ins = r.instructions["z1"]; auto add_ins = r.instructions["add"];
auto y_ins = r.instructions["z2"];
auto z_ins = r.instructions["z3"];
for(auto* pins : {&x_ins, &y_ins, &z_ins})
{
if(not(*pins)->get_shape().standard())
*pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
}
auto relements = x_ins->get_shape().lens().back();
if(relements > 1024 or (relements % 4 != 0 and relements > 256))
return;
auto a = m.insert_instruction( m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
} }
}; };
} // namespace } // namespace
void prefuse_ops::apply(module& m) const void prefuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{}); match::find_matches(m, find_add_layernorm{}, find_layernorm{});
} }
} // namespace gpu } // namespace gpu
......
...@@ -108,15 +108,7 @@ struct function ...@@ -108,15 +108,7 @@ struct function
}; };
template <class Stream, class Iterator> template <class Stream, class Iterator>
inline Stream& stream_range(Stream& s, Iterator start, Iterator last) Stream& stream_range(Stream& s, Iterator start, Iterator last);
{
if(start != last)
{
s << *start;
std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; });
}
return s;
}
template <class Stream> template <class Stream>
inline Stream& operator<<(Stream& s, std::nullptr_t) inline Stream& operator<<(Stream& s, std::nullptr_t)
...@@ -136,6 +128,17 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v. ...@@ -136,6 +128,17 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
return s; return s;
} }
template <class Stream, class Iterator>
inline Stream& stream_range(Stream& s, Iterator start, Iterator last)
{
if(start != last)
{
s << *start;
std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; });
}
return s;
}
template <class T> template <class T>
const T& get_value(const T& x) const T& get_value(const T& x)
{ {
......
...@@ -39,6 +39,15 @@ void run_pass(migraphx::module& m) ...@@ -39,6 +39,15 @@ void run_pass(migraphx::module& m)
migraphx::run_passes(m, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}}); migraphx::run_passes(m, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
} }
inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx::shape>& shapes)
{
std::vector<std::vector<std::size_t>> result;
std::transform(shapes.begin(), shapes.end(), std::back_inserter(result), [&](const auto& s) {
return s.lens();
});
return result;
}
TEST_CASE(double_contig) TEST_CASE(double_contig)
{ {
migraphx::program p; migraphx::program p;
...@@ -1275,4 +1284,82 @@ TEST_CASE(transpose_slice_single_transpose) ...@@ -1275,4 +1284,82 @@ TEST_CASE(transpose_slice_single_transpose)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(transpose_slice_non_packed_axis)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto transpose =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), slice);
m1.add_return({sqrt});
}
auto output_shapes = m1.get_output_shapes();
run_pass(m1);
EXPECT(m1.get_output_shapes() == output_shapes);
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
auto slice = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
auto sqrt = m2.add_instruction(migraphx::make_op("sqrt"), squeeze);
m2.add_return({sqrt});
}
EXPECT(m1 == m2);
}
TEST_CASE(transpose_slice_non_packed_multi_axis)
{
migraphx::module m1;
{
auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto transpose =
m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto slice1 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}),
transpose);
auto slice2 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}),
transpose);
auto transpose2 = m1.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), slice2);
auto slice3 = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}),
transpose);
m1.add_return({slice1, transpose2, slice3});
}
auto output_shapes = m1.get_output_shapes();
run_pass(m1);
EXPECT(to_lens(m1.get_output_shapes()) == to_lens(output_shapes));
migraphx::module m2;
{
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
auto unsqueeze =
m2.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
auto transpose = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
auto slice1 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
auto slice2 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transpose);
auto squeeze2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
auto transpose2 = m2.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), squeeze2);
auto slice3 = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transpose);
auto squeeze3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
m2.add_return({squeeze1, transpose2, squeeze3});
}
EXPECT(m1.sort() == m2.sort());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm> ...@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 1, 5}; std::vector<size_t> dims = {1, 2, 5};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_layernorm(*mm, x, dims); add_layernorm(*mm, x, dims);
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment