Unverified Commit 2140fe19 authored by shivadbhavsar's avatar shivadbhavsar Committed by GitHub
Browse files

scalar unsqueeze broadcast support (#1753)

Adding support for broadcasted scalars to unsqueeze op.

Specifying steps other than 1 is disallowed in this implementation since we want the output the always be a tensor. We can support varying step sizes if we allow a broadcasted scalar output from this op.
parent 177e5dbc
......@@ -95,13 +95,10 @@ struct unsqueeze
auto type = input_shape.type();
auto old_lens = input_shape.lens();
auto old_strides = input_shape.strides();
if(input_shape.scalar())
{
if(old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
else
MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
}
auto is_scalar = input_shape.scalar();
if(is_scalar and old_lens.size() == 1 and old_lens.front() == 1)
return shape{type, old_lens};
if(steps.size() > axes.size())
MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
......@@ -121,13 +118,15 @@ struct unsqueeze
step = steps[axis_idx];
if(step == 0)
MIGRAPHX_THROW("UNSQUEEZE: step must be non-zero");
if(is_scalar and step != 1)
MIGRAPHX_THROW("UNSQUEEZE: step must be 1 when input is scalar");
new_lens[i] = step;
if(p < old_strides.size())
{
if((old_lens[p] % step) != 0)
MIGRAPHX_THROW("UNSQUEEZE: Axis dimenstion is not divisible by step");
old_lens[p] /= step;
new_strides[i] = old_strides[p] * old_lens[p];
new_strides[i] = is_scalar ? 1 : old_strides[p] * old_lens[p];
}
else
{
......
......@@ -3141,14 +3141,22 @@ TEST_CASE(test_unsqueeze_scalar)
TEST_CASE(test_unsqueeze_scalar_tensor1)
{
migraphx::shape s{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
migraphx::shape s1{migraphx::shape::float_type, {4, 3, 3}, {0, 0, 0}};
migraphx::shape s2{migraphx::shape::float_type, {4, 3, 1, 3}, {0, 0, 1, 0}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar_tensor2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {-2}}}), s);
migraphx::shape s1{migraphx::shape::float_type, {1, 1, 1}, {0, 0, 0}};
migraphx::shape s2{migraphx::shape::float_type, {1, 1, 1, 1}, {0, 0, 0, 1}};
expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {-1}}}), s1);
}
TEST_CASE(test_unsqueeze_scalar_step)
{
migraphx::shape s{migraphx::shape::float_type, {6, 1, 2}, {0, 0, 0}};
throws_shape(migraphx::make_op("unsqueeze", {{"axes", {0}}, {"steps", {3}}}), s);
}
TEST_CASE(test_unsqueeze_transpose)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment