"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "2fdf510d05a11280fff4688aa231491be98ef8d6"
Commit 8b8395be authored by Paul
Browse files

Apply vec_reduce first

parent d8038982
......@@ -49,9 +49,9 @@ __device__ void generic_binary_layernorm(
reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
-auto means = r.reduce(op::sum{}, make_array<value_type>(0, 0), [&](auto x1, auto x2) {
+auto means = r.reduce(op::sum{}, make_array<vec_type<value_type>>(0, 0), [&](auto x1, auto x2) {
auto x = op(x1, x2);
-return make_array(x, x * x) / value_type{relements};
+return make_array(x, x * x) / vec_type<value_type>{relements};
})(input1, input2);
auto mean_x = means[0];
......
......@@ -201,12 +201,11 @@ struct block
__device__ auto reduce(Op op, T init, Read read) const
{
return sliced(slicer, [=](auto x, auto... xs) {
-return vec_reduce(block_reduce(idx,
+return block_reduce(idx,
op,
init,
x.get_shape().elements(),
-[&](auto j) { return read(x[j], xs[j]...); }),
-op);
+[&](auto j) { return vec_reduce(read(x[j], xs[j]...), op); });
});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment