Commit 610f0417 authored by Paul's avatar Paul
Browse files

Format

parent f3aa2c67
...@@ -1051,7 +1051,8 @@ struct find_layernorm_pointwise ...@@ -1051,7 +1051,8 @@ struct find_layernorm_pointwise
{ {
auto matcher() const auto matcher() const
{ {
return precompile_name("pointwise")(match::arg(0)(precompile_name("gpu::prelayernorm").bind("layernorm"))); return precompile_name("pointwise")(
match::arg(0)(precompile_name("gpu::prelayernorm").bind("layernorm")));
} }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
......
...@@ -67,8 +67,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -67,8 +67,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = v.get("kernel", "layernorm_kernel"); options.kernel_name = v.get("kernel", "layernorm_kernel");
auto src = interpolate_string( auto src = interpolate_string(layernorm_kernel,
layernorm_kernel,
{{"kernel", options.kernel_name}, {{"kernel", options.kernel_name},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
{"args", enum_params(inputs.size(), "private_p")}, {"args", enum_params(inputs.size(), "private_p")},
...@@ -83,7 +82,7 @@ struct layernorm_compiler : compiler<layernorm_compiler> ...@@ -83,7 +82,7 @@ struct layernorm_compiler : compiler<layernorm_compiler>
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{ {
auto v = op.to_value(); auto v = op.to_value();
if (not ins->module_inputs().empty()) if(not ins->module_inputs().empty())
{ {
auto* pm = ins->module_inputs().front(); auto* pm = ins->module_inputs().front();
v["preamble"] = generate_pointwise(*pm, "post_layernorm"); v["preamble"] = generate_pointwise(*pm, "post_layernorm");
......
...@@ -32,7 +32,8 @@ ...@@ -32,7 +32,8 @@
// NOLINTNEXTLINE // NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \ #define MIGRAPHX_LIFT(...) \
[](auto&&... private_lisft_xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...)) [](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
(__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))
namespace migraphx { namespace migraphx {
......
...@@ -37,7 +37,8 @@ struct layernorm ...@@ -37,7 +37,8 @@ struct layernorm
shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
{ {
std::size_t nargs = 1; std::size_t nargs = 1;
if (not mods.empty()) { if(not mods.empty())
{
auto* pm = mods.front(); auto* pm = mods.front();
nargs = pm->get_parameter_names().size(); nargs = pm->get_parameter_names().size();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment