Commit f89638ec authored by turneram's avatar turneram
Browse files

Update ref to match gpu and handle scale and bias in parser

parent ba7a370a
...@@ -72,7 +72,9 @@ struct layernorm ...@@ -72,7 +72,9 @@ struct layernorm
mean_inv_std_dev_dim.at(i) = 1; mean_inv_std_dev_dim.at(i) = 1;
} */ } */
visit_all(result, args[0], args[1], args[2])( if (args.size() == 3)
{
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto data, auto weights, auto bias) { [&](auto output, auto data, auto weights, auto bias) {
par_for(norm_count, [&](auto idx) { par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size; auto offset = idx * norm_size;
...@@ -87,17 +89,39 @@ struct layernorm ...@@ -87,17 +89,39 @@ struct layernorm
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i) for(std::size_t i = 0; i < norm_size; ++i)
{ {
output[offset + i] = (data[offset + i] - mean) / mean_square; if(args.size() == 3)
/* if(args.size() == 3)
output[offset + i] = output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i] + bias[i]; (data[offset + i] - mean) / mean_square * weights[i] + bias[i];
else else
output[offset + i] = output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i]; */ (data[offset + i] - mean) / mean_square * weights[i];
// scale and bias handled by onnx parser
} }
}); });
}); });
}
else
{
visit_all(result, args[0])(
[&](auto output, auto data) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
double mean_square = 0;
for(std::size_t i = 0; i < norm_size; ++i)
{
mean += data[offset + i];
mean_square += data[offset + i] * data[offset + i];
}
mean /= norm_size;
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
output[offset + i] = (data[offset + i] - mean) / mean_square;
// scale and bias handled by onnx parser
}
});
});
}
return result; return result;
} }
......
...@@ -30,11 +30,11 @@ struct parse_layernorm : op_parser<parse_layernorm> ...@@ -30,11 +30,11 @@ struct parse_layernorm : op_parser<parse_layernorm>
auto layernorm = info.add_instruction( auto layernorm = info.add_instruction(
make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front()); make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
if(args.size() == 3) if(args.size() >= 2)
{
layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1)); layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1));
if (args.size() == 3)
layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2)); layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2));
}
return layernorm; return layernorm;
} }
}; };
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <migraphx/op/layernorm.hpp> #include <migraphx/op/layernorm.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp> #include <migraphx/reflect.hpp>
#include <migraphx/argument.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
......
...@@ -8,27 +8,14 @@ namespace gpu { ...@@ -8,27 +8,14 @@ namespace gpu {
shape hip_layernorm::compute_shape(std::vector<shape> inputs) const shape hip_layernorm::compute_shape(std::vector<shape> inputs) const
{ {
std::cout << "compute shape" << std::endl;
inputs.pop_back(); inputs.pop_back();
return op.normalize_compute_shape(inputs); return op.normalize_compute_shape(inputs);
} }
argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
/* if (args.size() == 3) device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
{
auto n_dim = args.front().get_shape().lens().size();
auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
device::layernorm(ctx.get_stream().get(), args.back(), args[0], args[1], args[2],
tuned_axis);
}
else */
std::cout << "calling device::ln" << std::endl;
{
device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
std::cout << "called device::ln" << std::endl;
}
return args.back(); return args.back();
} }
......
...@@ -389,10 +389,6 @@ struct miopen_apply ...@@ -389,10 +389,6 @@ struct miopen_apply
apply_map.emplace(op_name, [=](instruction_ref ins) { apply_map.emplace(op_name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs(); std::vector<instruction_ref> refs = ins->inputs();
if(op_name == "layernorm")
{
std::cout << "layernorm op" << std::endl;
}
refs.push_back(output); refs.push_back(output);
return mod->replace_instruction(ins, make_op(gpu_name), refs); return mod->replace_instruction(ins, make_op(gpu_name), refs);
......
layernorm_op_test:
8
x
w
boutput"LayerNormalization*
epsilon'7layernorm_op_testZ
x



Z
w

Z
b

b
output



B
\ No newline at end of file
...@@ -11,8 +11,6 @@ struct test_layernorm : verify_program<test_layernorm> ...@@ -11,8 +11,6 @@ struct test_layernorm : verify_program<test_layernorm>
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 768}}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 768}});
auto w = mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {768}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {768}});
mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}}), x); mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}}), x);
p.debug_print(); p.debug_print();
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment