Commit 3135fc93 authored by Paul
Browse files

Fix bugs in calculating indices

parent 22f006cd
...@@ -122,6 +122,8 @@ template <class F, class... Arguments> ...@@ -122,6 +122,8 @@ template <class F, class... Arguments>
void nary_double_broadcast_vec_impl( void nary_double_broadcast_vec_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args) hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{ {
assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape()); assert(barg1.get_shape() == barg2.get_shape());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = barg1.get_shape(); const auto& b_shape = barg1.get_shape();
...@@ -161,7 +163,7 @@ void nary_double_broadcast_vec_impl( ...@@ -161,7 +163,7 @@ void nary_double_broadcast_vec_impl(
{ {
auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride; auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
auto b1 = bp[bidx]; auto b1 = bp[bidx];
auto b2 = bp[bidx + bdim_vec_len]; auto b2 = bp[bidx + bdim_len];
auto out = output.data()[i]; auto out = output.data()[i];
for(std::size_t j = 0; j < vec_size; j++) for(std::size_t j = 0; j < vec_size; j++)
{ {
...@@ -177,6 +179,8 @@ template <class F, class... Arguments> ...@@ -177,6 +179,8 @@ template <class F, class... Arguments>
void nary_double_broadcast_impl( void nary_double_broadcast_impl(
hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args) hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{ {
assert(barg1.get_shape().broadcasted());
assert(barg2.get_shape().broadcasted());
assert(barg1.get_shape() == barg2.get_shape()); assert(barg1.get_shape() == barg2.get_shape());
const auto& output_shape = result.get_shape(); const auto& output_shape = result.get_shape();
const auto& b_shape = barg1.get_shape(); const auto& b_shape = barg1.get_shape();
...@@ -348,12 +352,13 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument bar ...@@ -348,12 +352,13 @@ inline auto nary(hipStream_t stream, argument result, argument arg, argument bar
template <class... Arguments> template <class... Arguments>
auto nary(hipStream_t stream, argument result, Arguments... args) auto nary(hipStream_t stream, argument result, Arguments... args)
{ {
static_assert(sizeof...(args) > 2, "Args needs to be greater than 2");
return [=](auto f) { return [=](auto f) {
auto barg1 = back_args(args...); auto barg1 = back_args(args...);
bool fallback1 = pop_back_args(args...)([&](auto&&... args2) { bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
auto barg2 = back_args(args2...); auto barg2 = back_args(args2...);
bool fallback2 = bool fallback2 =
barg2.get_shape() == barg1.get_shape() and barg2.get_shape().broadcasted() and barg2.get_shape() != barg1.get_shape() or not barg2.get_shape().broadcasted() or
pop_back_args(args2...)([&](auto&&... args3) { pop_back_args(args2...)([&](auto&&... args3) {
bool divisible_by_4 = false; bool divisible_by_4 = false;
if(broadcastable(divisible_by_4, 1024, result, barg2, args3...)) if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment