Commit fe7562cb authored by turneram's avatar turneram
Browse files

Update ref to match gpu and handle scale and bias in parser

parent e2bbfca1
......@@ -72,6 +72,8 @@ struct layernorm
mean_inv_std_dev_dim.at(i) = 1;
} */
if (args.size() == 3)
{
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto data, auto weights, auto bias) {
par_for(norm_count, [&](auto idx) {
......@@ -87,17 +89,39 @@ struct layernorm
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
output[offset + i] = (data[offset + i] - mean) / mean_square;
/* if(args.size() == 3)
if(args.size() == 3)
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i] + bias[i];
else
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i]; */
(data[offset + i] - mean) / mean_square * weights[i];
}
});
});
}
else
{
visit_all(result, args[0])(
[&](auto output, auto data) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
double mean_square = 0;
for(std::size_t i = 0; i < norm_size; ++i)
{
mean += data[offset + i];
mean_square += data[offset + i] * data[offset + i];
}
mean /= norm_size;
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
output[offset + i] = (data[offset + i] - mean) / mean_square;
// scale and bias handled by onnx parser
}
});
});
}
return result;
}
......
......@@ -30,11 +30,11 @@ struct parse_layernorm : op_parser<parse_layernorm>
auto layernorm = info.add_instruction(
make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
if(args.size() == 3)
{
if(args.size() >= 2)
layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1));
if (args.size() == 3)
layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2));
}
return layernorm;
}
};
......
......@@ -423,7 +423,6 @@ layernorm_half(void* in1, void* data_out, index_int batch_item_num, index_int bl
void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
{
std::cout << "void layernorm" << std::endl;
auto in_s = arg1.get_shape();
auto type = in_s.type();
auto batch_item_num = in_s.lens().back();
......
......@@ -4,6 +4,7 @@
#include <migraphx/op/layernorm.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/argument.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......
......@@ -8,27 +8,14 @@ namespace gpu {
// Computes the output shape for the GPU layernorm op.
//
// The lowering pass (miopen_apply) appends a pre-allocated output buffer
// as the last input, so it is dropped before delegating to the reference
// op's shape normalization.
//
// Fix: removed a leftover debug `std::cout` statement — shape computation
// runs during compilation and must not write to stdout.
shape hip_layernorm::compute_shape(std::vector<shape> inputs) const
{
    // Drop the trailing output-allocation shape; only the real op inputs
    // participate in shape inference.
    inputs.pop_back();
    return op.normalize_compute_shape(inputs);
}
// Launches the device layernorm kernel on the op's stream.
//
// args holds the op inputs followed by the output buffer appended during
// lowering: args[0] is the data tensor to normalize, args.back() is the
// destination. Scale and bias are not handled here — per this commit they
// are emitted as separate mul/add instructions by the onnx parser.
//
// Fixes: removed two leftover debug `std::cout` statements, the dead
// commented-out 3-arg code path (superseded by the parser change), and a
// stray scope-brace pair that served no purpose.
argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
    return args.back();
}
......
......@@ -394,10 +394,6 @@ struct miopen_apply
apply_map.emplace(op_name, [=](instruction_ref ins) {
auto output = insert_allocation(ins, ins->get_shape());
std::vector<instruction_ref> refs = ins->inputs();
if(op_name == "layernorm")
{
std::cout << "layernorm op" << std::endl;
}
refs.push_back(output);
return mod->replace_instruction(ins, make_op(gpu_name), refs);
......
layernorm_op_test:
8
x
w
boutput"LayerNormalization*
epsilon'7layernorm_op_testZ
x



Z
w

Z
b

b
output



B
\ No newline at end of file
......@@ -11,8 +11,6 @@ struct test_layernorm : verify_program<test_layernorm>
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 768}});
auto w = mm->add_parameter("w", migraphx::shape{migraphx::shape::float_type, {768}});
auto b = mm->add_parameter("b", migraphx::shape{migraphx::shape::float_type, {768}});
mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}}), x);
p.debug_print();
return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment