"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "55333759b0850cd5c398f27839fb632fe21c7943"
Unverified Commit 5b53552d authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Fix shapes check for allocate (#2258)

parent 28614abd
...@@ -49,17 +49,22 @@ struct allocate ...@@ -49,17 +49,22 @@ struct allocate
shape compute_shape(const std::vector<shape>& inputs) const shape compute_shape(const std::vector<shape>& inputs) const
{ {
migraphx::check_shapes{inputs, *this, true}.has(0, 1);
// check if shape attribute is not default
if(s != shape()) if(s != shape())
{ {
if(inputs.size() == 1)
{
migraphx::check_shapes{inputs, *this, false}.only_dims(1);
}
else
{
migraphx::check_shapes{inputs, *this, false}.has(0);
}
return s; return s;
} }
else else
{ {
migraphx::check_shapes{inputs, *this, false}.has(1).only_dims(1);
const auto& out_dims = inputs.at(0); const auto& out_dims = inputs.at(0);
assert(not out_dims.dynamic());
assert(out_dims.ndim() == 1);
std::size_t max_val = std::numeric_limits<std::size_t>::max(); std::size_t max_val = std::numeric_limits<std::size_t>::max();
std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0), std::vector<shape::dynamic_dimension> dyn_dims(out_dims.lens().at(0),
shape::dynamic_dimension{0, max_val}); shape::dynamic_dimension{0, max_val});
......
...@@ -88,6 +88,13 @@ TEST_CASE(allocate_static) ...@@ -88,6 +88,13 @@ TEST_CASE(allocate_static)
expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}})); expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}}));
} }
TEST_CASE(allocate_static_input_error)
{
migraphx::shape input{migraphx::shape::int64_type, {3}};
migraphx::shape out_shape{migraphx::shape::float_type, {2, 3, 4}};
expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}}), input);
}
TEST_CASE(allocate_dyn) TEST_CASE(allocate_dyn)
{ {
migraphx::shape input{migraphx::shape::int64_type, {2}}; migraphx::shape input{migraphx::shape::int64_type, {2}};
...@@ -109,6 +116,14 @@ TEST_CASE(allocate_dyn_with_shape_attr) ...@@ -109,6 +116,14 @@ TEST_CASE(allocate_dyn_with_shape_attr)
input); input);
} }
TEST_CASE(allocate_dyn_no_input_error)
{
migraphx::shape shape_attr{migraphx::shape::float_type,
{{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
expect_shape(shape_attr,
migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}}));
}
TEST_CASE(argmax_axis0) TEST_CASE(argmax_axis0)
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment