Commit a62ef598 authored by Paul's avatar Paul
Browse files

Replace layernorm

parent 473881cf
......@@ -39,7 +39,7 @@ __global__ void layernorm_kernel(void* input_p, void* output_p)
struct layernorm_compiler : compiler<layernorm_compiler>
{
std::vector<std::string> names() const { return {"layernorm"}; }
std::vector<std::string> names() const { return {"layernorm", "gpu::prelayernorm"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
......
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace {
struct layernorm
{
std::string name() const { return "gpu::prelayernorm"; }
......@@ -28,12 +29,25 @@ struct layernorm
}
}
};
MIGRAPHX_REGISTER_OP(layernorm);
namespace {
struct find_layernorm
{
auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
m.replace_instruction(ins, layernorm{}, x_ins);
}
};
struct find_gpulayernorm
{
auto matcher() const { return match::layernorm(); }
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
......@@ -53,7 +67,7 @@ struct find_layernorm
}
};
struct find_triaddlayernorm
struct find_gputriaddlayernorm
{
auto matcher() const
{
......@@ -91,7 +105,8 @@ struct find_triaddlayernorm
void prefuse_ops::apply(module& m) const
{
match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
match::find_matches(m, find_layernorm{});
// match::find_matches(m, find_gputriaddlayernorm{}, find_gpulayernorm{});
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment