#include "verify_program.hpp"
// NOTE(review): the targets of the three MIGraphX includes below were missing
// from the text under review; they are reconstructed from the names this file
// uses (program, shape, op::*) -- confirm against the build.
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operators.hpp>
#include <cstddef>
#include <cstdint>
#include <vector>

// Appends a layer-normalization subgraph to `p` and returns its final
// instruction.
//
// The subgraph declares three float parameters:
//   "x"     with lengths `dims`,
//   "scale" with lengths {dims.back()},
//   "bias"  with lengths {dims.back()},
// and computes
//   scale * (x - mean(x)) / sqrt(mean((x - mean(x))^2) + 1e-12) + bias
// where both means are taken over the last axis of `dims`.
//
// `dims` must be non-empty (dims.back() is read unconditionally).
//
// NOTE(review): the element type of `dims` was missing from the text under
// review; std::size_t is assumed because the vector is handed straight to
// migraphx::shape and op::multibroadcast -- confirm.
migraphx::instruction_ref add_layernorm(migraphx::program& p, std::vector<std::size_t> dims)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, dims};
    const migraphx::shape channel_shape{migraphx::shape::float_type, {dims.back()}};

    auto x     = p.add_parameter("x", input_shape);
    auto scale = p.add_parameter("scale", channel_shape);
    auto bias  = p.add_parameter("bias", channel_shape);

    auto epsilon  = p.add_literal(1e-12f);
    auto exponent = p.add_literal(2.0f);

    // Normalize over the last axis. The reduced tensors are expected to keep
    // that axis with length 1 (the original epsilon broadcast relied on the
    // same layout), so build that shape once from `dims` instead of
    // hard-coding {1, dims.at(1), 1}: the old form only matched the variance
    // when dims[0] == 1 and the input was rank 3.
    const auto last_axis = static_cast<std::int64_t>(dims.size()) - 1;
    auto reduced_lens    = dims;
    reduced_lens.back()  = 1;

    // centered = x - mean(x)
    auto mean        = p.add_instruction(migraphx::op::reduce_mean({last_axis}), x);
    auto mean_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, mean);
    auto centered    = p.add_instruction(migraphx::op::sub{}, x, mean_mbcast);

    // variance = mean(centered ^ 2)
    auto exponent_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, exponent);
    auto squared  = p.add_instruction(migraphx::op::pow{}, centered, exponent_mbcast);
    auto variance = p.add_instruction(migraphx::op::reduce_mean({last_axis}), squared);

    // stddev = sqrt(variance + epsilon); epsilon keeps the sqrt argument positive.
    auto epsilon_mbcast =
        p.add_instruction(migraphx::op::multibroadcast{reduced_lens}, epsilon);
    auto stabilized    = p.add_instruction(migraphx::op::add{}, variance, epsilon_mbcast);
    auto stddev        = p.add_instruction(migraphx::op::sqrt{}, stabilized);
    auto stddev_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, stddev);

    // result = scale * (centered / stddev) + bias
    auto normalized   = p.add_instruction(migraphx::op::div{}, centered, stddev_mbcast);
    auto scale_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, scale);
    auto scaled       = p.add_instruction(migraphx::op::mul{}, scale_mbcast, normalized);
    auto bias_mbcast  = p.add_instruction(migraphx::op::multibroadcast{dims}, bias);
    return p.add_instruction(migraphx::op::add{}, scaled, bias_mbcast);
}

// Builds the layer-norm program for a {1, 1, 5} input.
//
// NOTE(review): the template argument of verify_program was missing from the
// text under review; the CRTP form verify_program<test_layernorm> is assumed
// -- confirm against verify_program.hpp.
struct test_layernorm : verify_program<test_layernorm>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        add_layernorm(p, {1, 1, 5});
        return p;
    }
};

// Builds the layer-norm program for a {1, 4, 24} input.
//
// NOTE(review): same assumption about the verify_program template argument as
// for test_layernorm -- confirm.
struct test_layernorm2 : verify_program<test_layernorm2>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        add_layernorm(p, {1, 4, 24});
        return p;
    }
};