"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "fed23ec79685ed0f5051d72bd98a959116c27641"
Commit e47c0140 authored by turneram's avatar turneram
Browse files

Formatting

parent fe7562cb
......@@ -72,37 +72,10 @@ struct layernorm
mean_inv_std_dev_dim.at(i) = 1;
} */
if (args.size() == 3)
if(args.size() == 3)
{
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto data, auto weights, auto bias) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
double mean_square = 0;
for(std::size_t i = 0; i < norm_size; ++i)
{
mean += data[offset + i];
mean_square += data[offset + i] * data[offset + i];
}
mean /= norm_size;
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
if(args.size() == 3)
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i] + bias[i];
else
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i];
}
});
});
}
else
{
visit_all(result, args[0])(
[&](auto output, auto data) {
[&](auto output, auto data, auto weights, auto bias) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
......@@ -116,12 +89,38 @@ struct layernorm
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
output[offset + i] = (data[offset + i] - mean) / mean_square;
// scale and bias handled by onnx parser
if(args.size() == 3)
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i] + bias[i];
else
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i];
}
});
});
}
else
{
visit_all(result, args[0])([&](auto output, auto data) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
double mean_square = 0;
for(std::size_t i = 0; i < norm_size; ++i)
{
mean += data[offset + i];
mean_square += data[offset + i] * data[offset + i];
}
mean /= norm_size;
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
output[offset + i] = (data[offset + i] - mean) / mean_square;
// scale and bias handled by onnx parser
}
});
});
}
return result;
}
......
......@@ -32,9 +32,9 @@ struct parse_layernorm : op_parser<parse_layernorm>
if(args.size() >= 2)
layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1));
if (args.size() == 3)
if(args.size() == 3)
layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2));
return layernorm;
}
};
......
......@@ -15,7 +15,7 @@ shape hip_layernorm::compute_shape(std::vector<shape> inputs) const
argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
return args.back();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment