Commit ae513aa8 authored by turneram's avatar turneram
Browse files

Formatting

parent f89638ec
...@@ -72,37 +72,10 @@ struct layernorm ...@@ -72,37 +72,10 @@ struct layernorm
mean_inv_std_dev_dim.at(i) = 1; mean_inv_std_dev_dim.at(i) = 1;
} */ } */
if (args.size() == 3) if(args.size() == 3)
{ {
visit_all(result, args[0], args[1], args[2])( visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto data, auto weights, auto bias) { [&](auto output, auto data, auto weights, auto bias) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
double mean_square = 0;
for(std::size_t i = 0; i < norm_size; ++i)
{
mean += data[offset + i];
mean_square += data[offset + i] * data[offset + i];
}
mean /= norm_size;
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
if(args.size() == 3)
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i] + bias[i];
else
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i];
}
});
});
}
else
{
visit_all(result, args[0])(
[&](auto output, auto data) {
par_for(norm_count, [&](auto idx) { par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size; auto offset = idx * norm_size;
double mean = 0; double mean = 0;
...@@ -116,12 +89,38 @@ struct layernorm ...@@ -116,12 +89,38 @@ struct layernorm
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon); mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i) for(std::size_t i = 0; i < norm_size; ++i)
{ {
output[offset + i] = (data[offset + i] - mean) / mean_square; if(args.size() == 3)
// scale and bias handled by onnx parser output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i] + bias[i];
else
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i];
} }
}); });
}); });
} }
else
{
visit_all(result, args[0])([&](auto output, auto data) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
double mean_square = 0;
for(std::size_t i = 0; i < norm_size; ++i)
{
mean += data[offset + i];
mean_square += data[offset + i] * data[offset + i];
}
mean /= norm_size;
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
output[offset + i] = (data[offset + i] - mean) / mean_square;
// scale and bias handled by onnx parser
}
});
});
}
return result; return result;
} }
......
...@@ -32,9 +32,9 @@ struct parse_layernorm : op_parser<parse_layernorm> ...@@ -32,9 +32,9 @@ struct parse_layernorm : op_parser<parse_layernorm>
if(args.size() >= 2) if(args.size() >= 2)
layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1)); layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1));
if (args.size() == 3) if(args.size() == 3)
layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2)); layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2));
return layernorm; return layernorm;
} }
}; };
......
...@@ -15,7 +15,7 @@ shape hip_layernorm::compute_shape(std::vector<shape> inputs) const ...@@ -15,7 +15,7 @@ shape hip_layernorm::compute_shape(std::vector<shape> inputs) const
argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{ {
device::layernorm(ctx.get_stream().get(), args.back(), args[0]); device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
return args.back(); return args.back();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment