"src/op/operator.h" did not exist on "73a6cb8bfd1f6dfc6197b7ad9253719dd720d681"
Commit 817543c7 authored by Paul's avatar Paul
Browse files

Format

parent 5b8c054e
@@ -825,7 +825,7 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
     m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }
 
-template<class... Strings>
+template <class... Strings>
 inline auto precompile_name(Strings... names) // NOLINT
 {
     return match::make_basic_pred_matcher([=](instruction_ref ins) {
@@ -1052,8 +1052,8 @@ struct find_layernorm_pointwise
 {
     auto matcher() const
     {
-        return precompile_name("pointwise")(
-            match::arg(0)(precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
+        return precompile_name("pointwise")(match::arg(0)(
+            precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
     }
 
     void apply(module& m, const match::matcher_result& r) const
@@ -1062,12 +1062,12 @@ struct find_layernorm_pointwise
         auto layernorm = r.instructions["layernorm"];
         auto* pm = ins->module_inputs().front();
-        if (not layernorm->module_inputs().empty())
+        if(not layernorm->module_inputs().empty())
             return;
-        auto inputs = layernorm->inputs();
+        auto inputs = layernorm->inputs();
         inputs.pop_back();
-        inputs.insert(inputs.end(), ins->inputs().begin()+1, ins->inputs().end());
+        inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
         m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
     }
......
@@ -43,7 +43,10 @@ __global__ void ${kernel}(${params})
 struct layernorm_compiler : compiler<layernorm_compiler>
 {
-    std::vector<std::string> names() const { return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"}; }
+    std::vector<std::string> names() const
+    {
+        return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"};
+    }
 
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
@@ -82,20 +85,21 @@ struct layernorm_compiler : compiler<layernorm_compiler>
     compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
     {
-        auto v = op.to_value();
+        auto v         = op.to_value();
         v["layernorm"] = "layernorm";
-        v["kernel"] = "layernorm_kernel";
-        if (op.name() == "gpu::preadd_layernorm")
+        v["kernel"]    = "layernorm_kernel";
+        if(op.name() == "gpu::preadd_layernorm")
         {
             v["layernorm"] = "add_layernorm";
-            v["kernel"] = "add_layernorm_kernel";
+            v["kernel"]    = "add_layernorm_kernel";
         }
 
         if(not ins->module_inputs().empty())
         {
             auto* pm = ins->module_inputs().front();
             v["preamble"] = generate_pointwise(*pm, "post_layernorm");
             v["post"] = "MIGRAPHX_LIFT(post_layernorm)";
-            v["kernel"] = v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
+            v["kernel"] =
+                v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) + "_kernel";
         }
         return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
     }
......
@@ -6,8 +6,15 @@
 namespace migraphx {
 
-template <index_int Axis, class F, class BinOp, class Output, class Input1, class Input2, class... Inputs>
-__device__ void generic_binary_layernorm(F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
+template <index_int Axis,
+          class F,
+          class BinOp,
+          class Output,
+          class Input1,
+          class Input2,
+          class... Inputs>
+__device__ void generic_binary_layernorm(
+    F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
 {
     using reduce_output = reduce::with_axis<Input1, Axis>;
     constexpr auto relements =
@@ -16,8 +23,9 @@ __device__ void generic_binary_layernorm(F compute, BinOp op, Output output, Inp
     reduce::block::run<reduce_output>([&](auto, auto r) {
         using value_type = typename Input1::type;
         auto mean = [&](auto f) {
-            return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { return f(x1, x2) / value_type{relements}; })(
-                input1, input2);
+            return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
+                return f(x1, x2) / value_type{relements};
+            })(input1, input2);
         };
         // mean(x)
         auto mean_x = mean(op);
@@ -35,17 +43,19 @@ __device__ void generic_binary_layernorm(F compute, BinOp op, Output output, Inp
     });
 }
 
 template <index_int Axis, class F, class Output, class Input, class... Inputs>
 __device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
 {
-    generic_binary_layernorm<Axis>(compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
+    generic_binary_layernorm<Axis>(
+        compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
 }
 
 template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
-__device__ void add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
+__device__ void
+add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
 {
-    generic_binary_layernorm<Axis>(compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
+    generic_binary_layernorm<Axis>(
+        compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
 }
 
 } // namespace migraphx
......
@@ -31,7 +31,7 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace {
 
-template<class Derived, std::size_t N>
+template <class Derived, std::size_t N>
 struct layernorm_base
 {
     shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
@@ -42,7 +42,7 @@ struct layernorm_base
             auto* pm = mods.front();
             nargs = pm->get_parameter_names().size();
         }
-        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs+N);
+        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
         auto s = inputs.at(0);
         if(s.scalar())
         {
@@ -86,11 +86,14 @@ struct find_layernorm
 struct find_add_layernorm
 {
-    auto matcher() const { return match::layernorm()(match::var("x")(match::name("add").bind("add"))); }
+    auto matcher() const
+    {
+        return match::layernorm()(match::var("x")(match::name("add").bind("add")));
+    }
 
     void apply(module& m, const match::matcher_result& r) const
     {
-        auto ins = r.result;
+        auto ins     = r.result;
         auto add_ins = r.instructions["add"];
         m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
......