"...composable_kernel_rocm.git" did not exist on "0345963eef4f92e9c5eab608bb8557b5463a1dcb"
Commit 702a8092 authored by charlie's avatar charlie
Browse files

Update tests, possible throw that won't occur

parent f888f611
...@@ -86,6 +86,7 @@ struct broadcast ...@@ -86,6 +86,7 @@ struct broadcast
shape output{t, broadcast_lens, std::move(bcast_strides)}; shape output{t, broadcast_lens, std::move(bcast_strides)};
if(output.elements() < s0.elements()) if(output.elements() < s0.elements())
{ {
// don't think this can occur?
MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size"); MIGRAPHX_THROW("BROADCAST: output size must be greater than or equal to s0 size");
} }
return output; return output;
...@@ -125,9 +126,6 @@ struct broadcast ...@@ -125,9 +126,6 @@ struct broadcast
std::vector<size_t> bcast_strides(s1.ndim(), 0); std::vector<size_t> bcast_strides(s1.ndim(), 0);
std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis); std::copy(s0.strides().begin(), s0.strides().end(), bcast_strides.begin() + axis);
shape output{t, s1.lens(), std::move(bcast_strides)}; shape output{t, s1.lens(), std::move(bcast_strides)};
if(output.elements() < s0.elements())
MIGRAPHX_THROW(
"BROADCAST_2in: output size must be greater than or equal to s0 size");
return output; return output;
} }
} }
......
...@@ -118,48 +118,67 @@ TEST_CASE(broadcast) ...@@ -118,48 +118,67 @@ TEST_CASE(broadcast)
} }
} }
TEST_CASE(broadcast_2in) TEST_CASE(broadcast_axis_out_of_range_error)
{ {
{ std::vector<std::size_t> lens{1, 1};
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}}; migraphx::shape input{migraphx::shape::float_type, {1}, {0}};
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}}; throws_shape(migraphx::make_op("broadcast", {{"axis", 4}, {"out_lens", lens}}), input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {1, 0}}, }
migraphx::make_op("broadcast", {{"axis", 0}}),
a_input, TEST_CASE(broadcast_2in_static_static)
b_input); {
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {0, 1}}, migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::make_op("broadcast", {{"axis", 1}}), migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
a_input, expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {1, 0}},
b_input); migraphx::make_op("broadcast", {{"axis", 0}}),
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input); a_input,
} b_input);
{ expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 4}, {0, 1}},
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}}; migraphx::make_op("broadcast", {{"axis", 1}}),
migraphx::shape b_input{migraphx::shape::float_type, {2, 2}, {2, 1}}; a_input,
throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input); b_input);
} throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
{ }
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}}; TEST_CASE(broadcast_2in_not_matching_error)
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input); {
} migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
{ migraphx::shape b_input{migraphx::shape::float_type, {2, 2}, {2, 1}};
std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}}; throws_shape(migraphx::make_op("broadcast", {{"axis", 1}}), a_input, b_input);
migraphx::shape a_input{migraphx::shape::float_type, dd}; }
migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input); TEST_CASE(broadcast_2in_dynamic_s0_error1)
} {
{ migraphx::shape a_input{migraphx::shape::float_type, {4, 2}, {2, 1}};
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}}; migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}}; throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), b_input, a_input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input); }
expect_shape(
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}}, TEST_CASE(broadcast_2in_dynamic_s0_error2)
migraphx::make_op("broadcast", {{"axis", 1}}), {
a_input, std::vector<migraphx::shape::dynamic_dimension> dd{{4, 4, 0}};
b_input); migraphx::shape a_input{migraphx::shape::float_type, dd};
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input); migraphx::shape b_input{migraphx::shape::float_type, {4, 4}, {4, 1}};
} throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_static_dyn)
{
migraphx::shape a_input{migraphx::shape::float_type, {4}, {1}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}},
migraphx::make_op("broadcast", {{"axis", 1}}),
a_input,
b_input);
throws_shape(migraphx::make_op("broadcast", {{"axis", 2}}), a_input, b_input);
}
TEST_CASE(broadcast_2in_dyn_s0_ndim_greater_than_1_error)
{
migraphx::shape a_input{migraphx::shape::float_type, {4, 2}};
migraphx::shape b_input{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {2, 2, 0}}};
throws_shape(migraphx::make_op("broadcast", {{"axis", 0}}), a_input, b_input);
} }
TEST_CASE(convolution_shape) TEST_CASE(convolution_shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment