"include/vscode:/vscode.git/clone" did not exist on "ad0a4ce13d66a02a69421ba3cc11c9f3f0d883e2"
Commit c7096299 authored by turneram's avatar turneram
Browse files

Use parse_layernorm to un-fuse layernorm op

parent ebfbae82
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/op/layernorm.hpp> #include <migraphx/op/layernorm.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
...@@ -16,7 +18,14 @@ struct parse_layernorm : op_parser<parse_layernorm> ...@@ -16,7 +18,14 @@ struct parse_layernorm : op_parser<parse_layernorm>
onnx_parser::node_info info, onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const const std::vector<instruction_ref>& args) const
{ {
float epsilon = 1e-3f; // un-fuse layernorm op so migraphx can handle fusion instead
auto x = args.front();
auto x_type = x->get_shape().type();
auto weights = args.at(1);
auto bias = args.at(2);
float epsilon = 1e-12f;
int64_t axis = -1; int64_t axis = -1;
if(contains(info.attributes, "epsilon")) if(contains(info.attributes, "epsilon"))
{ {
...@@ -26,16 +35,25 @@ struct parse_layernorm : op_parser<parse_layernorm> ...@@ -26,16 +35,25 @@ struct parse_layernorm : op_parser<parse_layernorm>
{ {
epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>(); epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
} }
auto epsilon_lit = info.add_literal(literal{shape{x_type, {1}}, {epsilon}});
auto exponent = info.add_literal(literal{shape{x_type, {1}}, {2.0}});
auto dims = x->get_shape().lens();
auto layernorm = info.add_instruction( auto mean = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x);
make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front()); auto mean_mbcast =
info.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
if(args.size() >= 2) auto sub = info.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
layernorm = info.add_broadcastable_binary_op("mul", layernorm, args.at(1)); auto exponent_mbcast =
if(args.size() == 3) info.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
layernorm = info.add_broadcastable_binary_op("add", layernorm, args.at(2)); auto pow = info.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), pow);
auto add_epsilon = info.add_broadcastable_binary_op("add", var, epsilon_lit);
auto sqrt = info.add_instruction(migraphx::make_op("sqrt"), add_epsilon);
auto div = info.add_broadcastable_binary_op("div", sub, sqrt);
auto mul = info.add_broadcastable_binary_op("mul", div, weights);
return layernorm; return info.add_broadcastable_binary_op("add", mul, bias);
} }
}; };
......
...@@ -28,7 +28,7 @@ __device__ void transposectx(const T& input_t, const U& output_t) ...@@ -28,7 +28,7 @@ __device__ void transposectx(const T& input_t, const U& output_t)
const int NHS = NH * sequence_length; const int NHS = NH * sequence_length;
const int out_offset = n * head_size + s * NH + b * NHS; const int out_offset = n * head_size + s * NH + b * NHS;
if(index.local < 1024) if(index.global < input_shape.elements())
output_t[out_offset + idx[3]] = input_t[index.global]; output_t[out_offset + idx[3]] = input_t[index.global];
} }
......
...@@ -23,7 +23,6 @@ __device__ void transposeqkv(const T& input_t, const U& output_t) ...@@ -23,7 +23,6 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
const int s = idx[1]; const int s = idx[1];
const int m = idx[2]; const int m = idx[2];
const int n = idx[3]; const int n = idx[3];
// const int j = idx[4];
const int num_heads = lens[3]; const int num_heads = lens[3];
const int sequence_length = lens[1]; const int sequence_length = lens[1];
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment