Commit a62ef598 authored by Paul's avatar Paul
Browse files

Replace layernorm

parent 473881cf
...@@ -39,7 +39,7 @@ __global__ void layernorm_kernel(void* input_p, void* output_p) ...@@ -39,7 +39,7 @@ __global__ void layernorm_kernel(void* input_p, void* output_p)
struct layernorm_compiler : compiler<layernorm_compiler> struct layernorm_compiler : compiler<layernorm_compiler>
{ {
std::vector<std::string> names() const { return {"layernorm"}; } std::vector<std::string> names() const { return {"layernorm", "gpu::prelayernorm"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
......
#include <migraphx/gpu/prefuse_ops.hpp> #include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp> #include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
namespace {
struct layernorm struct layernorm
{ {
std::string name() const { return "gpu::prelayernorm"; } std::string name() const { return "gpu::prelayernorm"; }
...@@ -28,12 +29,25 @@ struct layernorm ...@@ -28,12 +29,25 @@ struct layernorm
} }
} }
}; };
MIGRAPHX_REGISTER_OP(layernorm);
namespace {
struct find_layernorm struct find_layernorm
{ {
auto matcher() const { return match::layernorm(); } auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
m.replace_instruction(ins, layernorm{}, x_ins);
}
};
struct find_gpulayernorm
{
auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
...@@ -53,7 +67,7 @@ struct find_layernorm ...@@ -53,7 +67,7 @@ struct find_layernorm
} }
}; };
struct find_triaddlayernorm struct find_gputriaddlayernorm
{ {
auto matcher() const auto matcher() const
{ {
...@@ -91,7 +105,8 @@ struct find_triaddlayernorm ...@@ -91,7 +105,8 @@ struct find_triaddlayernorm
void prefuse_ops::apply(module& m) const void prefuse_ops::apply(module& m) const
{ {
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{}); match::find_matches(m, find_layernorm{});
// match::find_matches(m, find_gputriaddlayernorm{}, find_gpulayernorm{});
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment