Commit a67e6327 authored by charlie's avatar charlie
Browse files

fixed-fixed test and fix strides bug

parent bb7c3a25
......@@ -64,11 +64,11 @@ struct multibroadcast
MIGRAPHX_THROW("MULTIBROADCAST: inputs dimensions should be > 0");
}
auto make_bcast_strides = [&](std::size_t out_num_dims, std::size_t offset) {
std::vector<size_t> bcast_strides(out_num_dims, 0);
auto make_bcast_strides = [&](std::vector<std::size_t> bcast_lens, std::size_t offset) {
std::vector<size_t> bcast_strides(bcast_lens.size(), 0);
for(std::ptrdiff_t i = s0.lens().size() - 1; i >= 0; i--)
{
if(output_lens[i + offset] == s0.lens()[i])
if(bcast_lens[i + offset] == s0.lens()[i])
{
bcast_strides[i + offset] = s0.strides()[i];
}
......@@ -94,7 +94,7 @@ struct multibroadcast
}
}
auto bcast_strides = make_bcast_strides(output_lens.size(), offset);
auto bcast_strides = make_bcast_strides(output_lens, offset);
return {t, output_lens, std::move(bcast_strides)};
}
else
......@@ -119,8 +119,8 @@ struct multibroadcast
else
{
auto bcast_lens = compute_broadcasted_lens(s0.lens(), s1.lens());
auto offset = s1.lens().size() - s0.lens().size();
auto bcast_strides = make_bcast_strides(s1.lens().size(), offset);
auto offset = bcast_lens.size() - s0.lens().size();
auto bcast_strides = make_bcast_strides(bcast_lens, offset);
return {t, std::move(bcast_lens), std::move(bcast_strides)};
}
}
......
......@@ -1216,7 +1216,53 @@ TEST_CASE(multibroadcast_2in)
migraphx::shape b_shape{migraphx::shape::float_type, b};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
}
// both inputs are fixed
// fixed-fixed
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 6}};
migraphx::shape b_shape{migraphx::shape::float_type, {3, 6}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 6}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {1, 8}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}, {0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {8}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4, 8}, {0, 0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {16, 4, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 1, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {3, 4, 4}, {4, 0, 1}},
migraphx::make_op("multibroadcast"),
a_shape,
b_shape);
}
{
migraphx::shape a_shape{migraphx::shape::float_type, {3, 4, 4}};
migraphx::shape b_shape{migraphx::shape::float_type, {4, 3}};
throws_shape(migraphx::make_op("multibroadcast"), a_shape, b_shape);
}
}
TEST_CASE(multinomial)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment