Commit 817543c7 authored by Paul

Format

parent 5b8c054e
@@ -825,7 +825,7 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
     m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }
-template<class... Strings>
+template <class... Strings>
 inline auto precompile_name(Strings... names) // NOLINT
 {
     return match::make_basic_pred_matcher([=](instruction_ref ins) {
@@ -1052,8 +1052,8 @@ struct find_layernorm_pointwise
 {
     auto matcher() const
     {
-        return precompile_name("pointwise")(
-            match::arg(0)(precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
+        return precompile_name("pointwise")(match::arg(0)(
+            precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
     }
     void apply(module& m, const match::matcher_result& r) const
@@ -1062,12 +1062,12 @@ struct find_layernorm_pointwise
         auto layernorm = r.instructions["layernorm"];
         auto* pm = ins->module_inputs().front();
-        if (not layernorm->module_inputs().empty())
+        if(not layernorm->module_inputs().empty())
             return;
         auto inputs = layernorm->inputs();
         inputs.pop_back();
-        inputs.insert(inputs.end(), ins->inputs().begin()+1, ins->inputs().end());
+        inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
         m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
     }
...
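Aside for readers unfamiliar with MIGraphX's matcher DSL: precompile_name above builds a predicate matcher over any number of operator names. The standalone sketch below shows the same variadic pattern; toy_instruction and make_name_pred are hypothetical stand-ins for illustration, not MIGraphX types.

    #include <iostream>
    #include <string>
    #include <unordered_set>

    // A plain struct stands in for MIGraphX's instruction_ref.
    struct toy_instruction
    {
        std::string op_name;
    };

    // Build a predicate that is true when the instruction's operator name is
    // any of the given names, mirroring precompile_name("a", "b", ...).
    template <class... Strings>
    auto make_name_pred(Strings... names)
    {
        return [accepted = std::unordered_set<std::string>{names...}](
                   const toy_instruction& ins) { return accepted.count(ins.op_name) > 0; };
    }

    int main()
    {
        auto pred = make_name_pred("gpu::prelayernorm", "gpu::preadd_layernorm");
        std::cout << pred(toy_instruction{"gpu::preadd_layernorm"}) << '\n'; // 1
        std::cout << pred(toy_instruction{"pointwise"}) << '\n';             // 0
    }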
@@ -43,7 +43,10 @@ __global__ void ${kernel}(${params})
 struct layernorm_compiler : compiler<layernorm_compiler>
 {
-    std::vector<std::string> names() const { return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"}; }
+    std::vector<std::string> names() const
+    {
+        return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"};
+    }
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
@@ -82,20 +85,21 @@ struct layernorm_compiler : compiler<layernorm_compiler>
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
         auto v = op.to_value();
         v["layernorm"] = "layernorm";
         v["kernel"] = "layernorm_kernel";
-        if (op.name() == "gpu::preadd_layernorm")
+        if(op.name() == "gpu::preadd_layernorm")
         {
             v["layernorm"] = "add_layernorm";
             v["kernel"] = "add_layernorm_kernel";
         }
         if(not ins->module_inputs().empty())
         {
             auto* pm = ins->module_inputs().front();
             v["preamble"] = generate_pointwise(*pm, "post_layernorm");
             v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
-            v["kernel"] = v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
+            v["kernel"] =
+                v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
         }
         return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
     }
...
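The compile() method above sets values such as v["kernel"], v["preamble"], and v["post"] that get substituted into the ${kernel}/${params}-style kernel source template visible in the hunk header. As a minimal sketch of that substitution step, assuming simple ${key} placeholders (interpolate below is a hypothetical stand-in, not MIGraphX's actual helper):

    #include <iostream>
    #include <map>
    #include <string>

    // Replace every occurrence of ${key} in text with the mapped value.
    std::string interpolate(std::string text, const std::map<std::string, std::string>& vars)
    {
        for(const auto& [key, val] : vars)
        {
            const std::string token = "${" + key + "}";
            for(auto pos = text.find(token); pos != std::string::npos;
                pos = text.find(token, pos + val.size()))
                text.replace(pos, token.size(), val);
        }
        return text;
    }

    int main()
    {
        const std::string src = "__global__ void ${kernel}(${params})";
        std::cout << interpolate(src,
                                 {{"kernel", "add_layernorm_kernel"},
                                  {"params", "void* in1, void* in2, void* out"}})
                  << '\n';
    }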
@@ -6,8 +6,15 @@
 namespace migraphx {
-template <index_int Axis, class F, class BinOp, class Output, class Input1, class Input2, class... Inputs>
-__device__ void generic_binary_layernorm(F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
+template <index_int Axis,
+          class F,
+          class BinOp,
+          class Output,
+          class Input1,
+          class Input2,
+          class... Inputs>
+__device__ void generic_binary_layernorm(
+    F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
 {
     using reduce_output = reduce::with_axis<Input1, Axis>;
     constexpr auto relements =
@@ -16,8 +23,9 @@ __device__ void generic_binary_layernorm(F compute, BinOp op, Output output, Inp
     reduce::block::run<reduce_output>([&](auto, auto r) {
         using value_type = typename Input1::type;
         auto mean = [&](auto f) {
-            return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { return f(x1, x2) / value_type{relements}; })(
-                input1, input2);
+            return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
+                return f(x1, x2) / value_type{relements};
+            })(input1, input2);
         };
         // mean(x)
         auto mean_x = mean(op);
@@ -35,17 +43,19 @@ __device__ void generic_binary_layernorm(F compute, BinOp op, Output output, Inp
     });
 }
 template <index_int Axis, class F, class Output, class Input, class... Inputs>
 __device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
 {
-    generic_binary_layernorm<Axis>(compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
+    generic_binary_layernorm<Axis>(
+        compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
 }
 template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
-__device__ void add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
+__device__ void
+add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
 {
-    generic_binary_layernorm<Axis>(compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
+    generic_binary_layernorm<Axis>(
+        compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
 }
 } // namespace migraphx
...
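For reference, here is what generic_binary_layernorm computes per reduced row: the mean lambda above folds the elementwise combination through a block sum divided by relements, and the result is normalized using the variance identity mean(x^2) - mean(x)^2. The scalar sketch below makes that concrete; the epsilon value and the omission of the per-element compute epilogue are simplifying assumptions.

    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Layer-normalize one row: subtract the mean and divide by the standard
    // deviation, with variance computed as mean(x^2) - mean(x)^2.
    std::vector<float> layernorm_row(const std::vector<float>& x, float eps = 1e-12f)
    {
        const float n = static_cast<float>(x.size());
        float mean_x = 0.f, mean_x2 = 0.f;
        for(float v : x)
        {
            mean_x += v / n;      // mean(x)
            mean_x2 += v * v / n; // mean(x^2)
        }
        const float variance = mean_x2 - mean_x * mean_x;
        std::vector<float> out(x.size());
        for(std::size_t i = 0; i < x.size(); ++i)
            out[i] = (x[i] - mean_x) / std::sqrt(variance + eps);
        return out;
    }

    int main()
    {
        for(float v : layernorm_row({1.f, 2.f, 3.f, 4.f}))
            std::printf("%f\n", v);
    }

add_layernorm is the same computation applied to the elementwise sum of two inputs, which is exactly what its BinOp lambda [](auto x1, auto x2) { return x1 + x2; } selects.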
@@ -31,7 +31,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace {
-template<class Derived, std::size_t N>
+template <class Derived, std::size_t N>
 struct layernorm_base
 {
     shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
@@ -42,7 +42,7 @@ struct layernorm_base
             auto* pm = mods.front();
             nargs = pm->get_parameter_names().size();
         }
-        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs+N);
+        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
         auto s = inputs.at(0);
         if(s.scalar())
         {
@@ -86,11 +86,14 @@ struct find_layernorm
 struct find_add_layernorm
 {
-    auto matcher() const { return match::layernorm()(match::var("x")(match::name("add").bind("add"))); }
+    auto matcher() const
+    {
+        return match::layernorm()(match::var("x")(match::name("add").bind("add")));
+    }
     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
         auto add_ins = r.instructions["add"];
         m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
...
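find_add_layernorm above rewrites a layernorm whose input is produced by an add into the fused add_layernorm operator. The toy pass below illustrates the shape of that rewrite; a linear list of op names stands in for the real dataflow graph, and the actual pass uses the matcher and m.replace_instruction machinery shown in the diff.

    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    // Collapse each add immediately consumed by a layernorm into one fused op.
    std::vector<std::string> fuse_add_layernorm(const std::vector<std::string>& ops)
    {
        std::vector<std::string> out;
        for(std::size_t i = 0; i < ops.size(); ++i)
        {
            if(ops[i] == "add" and i + 1 < ops.size() and ops[i + 1] == "layernorm")
            {
                out.push_back("add_layernorm"); // fused op keeps the add's inputs
                ++i;                            // skip the layernorm we just fused
            }
            else
            {
                out.push_back(ops[i]);
            }
        }
        return out;
    }

    int main()
    {
        for(const auto& op : fuse_add_layernorm({"mul", "add", "layernorm", "relu"}))
            std::cout << op << '\n'; // mul add_layernorm relu
    }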