Commit ba72ce42 authored by Paul's avatar Paul
Browse files

Fix relements with vectors

parent 7cc97df7
...@@ -40,11 +40,9 @@ __device__ void generic_binary_layernorm( ...@@ -40,11 +40,9 @@ __device__ void generic_binary_layernorm(
F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs) F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{ {
using reduce_output = reduce::with_axis<Input1, Axis>; using reduce_output = reduce::with_axis<Input1, Axis>;
constexpr auto relements =
get_shape_c<Input1>{}.elements() / get_shape_c<reduce_output>{}.elements();
MIGRAPHX_ASSERT(relements > 0);
reduce::block::run<reduce_output>([&](auto, auto r) { reduce::block::run<reduce_output>([&](auto, auto r) {
using value_type = typename Input1::type; using value_type = typename Input1::type;
constexpr auto relements = r.template elements<Input1>();
auto mean = [&](auto f) { auto mean = [&](auto f) {
return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) { return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
return f(x1, x2) / value_type{relements}; return f(x1, x2) / value_type{relements};
......
...@@ -224,6 +224,18 @@ struct block ...@@ -224,6 +224,18 @@ struct block
idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); }); idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
}); });
} }
template<class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
using value_type = typename Input::type;
constexpr auto relements = get_shape_c<reduce_type>{}.elements();
if constexpr(vec_size<value_type>() > 1)
return relements * vec_size<value_type>();
else
return relements;
}
}; };
template <class Slicer> template <class Slicer>
...@@ -281,6 +293,13 @@ struct lane ...@@ -281,6 +293,13 @@ struct lane
} }
}); });
} }
template<class Input>
constexpr auto elements() const
{
using reduce_type = decltype(slicer(Input{}));
return get_shape_c<reduce_type>{}.elements();
}
}; };
template <class Slicer> template <class Slicer>
......
...@@ -175,7 +175,7 @@ template <class T, class Op> ...@@ -175,7 +175,7 @@ template <class T, class Op>
constexpr auto vec_reduce(T x, Op op) constexpr auto vec_reduce(T x, Op op)
{ {
if constexpr(vec_size<T>() < 2) if constexpr(vec_size<T>() < 2)
return x; return vec_type<T>{x};
else else
{ {
vec_type<T> result = x[0]; vec_type<T> result = x[0];
......
...@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm> ...@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm>
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
std::vector<size_t> dims = {1, 1, 5}; std::vector<size_t> dims = {1, 2, 5};
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims}); auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
add_layernorm(*mm, x, dims); add_layernorm(*mm, x, dims);
return p; return p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment