Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3101f6fe
Commit
3101f6fe
authored
Jun 13, 2022
by
Paul
Browse files
Add step to unsqeeze
parent
aa7ff911
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
14 deletions
+73
-14
src/include/migraphx/op/unsqueeze.hpp
src/include/migraphx/op/unsqueeze.hpp
+19
-7
test/op_shape_test.cpp
test/op_shape_test.cpp
+54
-7
No files found.
src/include/migraphx/op/unsqueeze.hpp
View file @
3101f6fe
...
@@ -19,11 +19,12 @@ namespace op {
...
@@ -19,11 +19,12 @@ namespace op {
struct
unsqueeze
struct
unsqueeze
{
{
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
steps
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
{
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
return
pack
(
f
(
self
.
axes
,
"axes"
)
,
f
(
self
.
steps
,
"steps"
)
);
}
}
value
attributes
()
const
value
attributes
()
const
...
@@ -57,16 +58,27 @@ struct unsqueeze
...
@@ -57,16 +58,27 @@ struct unsqueeze
std
::
size_t
p
=
0
;
std
::
size_t
p
=
0
;
for
(
auto
i
:
range
(
new_size
))
for
(
auto
i
:
range
(
new_size
))
{
{
if
(
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
!=
axes
.
end
())
auto
axis_idx
=
std
::
find
(
axes
.
begin
(),
axes
.
end
(),
i
)
-
axes
.
begin
();
if
(
axis_idx
<
axes
.
size
())
{
{
new_lens
[
i
]
=
1
;
std
::
int64_t
step
=
1
;
if
(
p
==
0
)
// unsqueeze on the first axes
if
(
axis_idx
<
steps
.
size
())
step
=
steps
[
axis_idx
];
if
(
step
==
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: step must be non-zero"
);
new_lens
[
i
]
=
step
;
if
(
p
<
old_strides
.
size
())
{
{
new_strides
[
i
]
=
old_lens
[
0
]
*
old_strides
[
0
];
if
((
old_lens
[
p
]
%
step
)
!=
0
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Axis dimenstion is not divisible by step"
);
old_lens
[
p
]
/=
step
;
new_strides
[
i
]
=
old_strides
[
p
]
*
old_lens
[
p
];
}
}
else
// unsqueeze on middle or last axes
else
{
{
new_strides
[
i
]
=
(
p
<
old_strides
.
size
())
?
old_strides
[
p
-
1
]
:
1
;
if
(
step
!=
1
)
MIGRAPHX_THROW
(
"UNSQUEEZE: Step must be 1 for extra axes"
);
new_strides
[
i
]
=
1
;
}
}
}
}
else
else
...
...
test/op_shape_test.cpp
View file @
3101f6fe
...
@@ -1510,15 +1510,40 @@ TEST_CASE(test_squeeze_wrong_axis)
...
@@ -1510,15 +1510,40 @@ TEST_CASE(test_squeeze_wrong_axis)
TEST_CASE
(
test_unsqueeze
)
TEST_CASE
(
test_unsqueeze
)
{
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
3
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
3
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
1
,
3
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
s1
);
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
s1
);
}
}
TEST_CASE
(
test_unsqueeze_step
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
12
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
2
,
6
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
s1
);
}
TEST_CASE
(
test_unsqueeze_step_non_divisable
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
3
}};
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
s1
);
}
TEST_CASE
(
test_unsqueeze_step_non_zero
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
12
}};
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
0
}}}),
s1
);
}
TEST_CASE
(
test_unsqueeze_step_at_end
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
12
}};
throws_shape
(
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
3
}},
{
"steps"
,
{
2
}}}),
s1
);
}
TEST_CASE
(
test_unsqueeze_negative_axis
)
TEST_CASE
(
test_unsqueeze_negative_axis
)
{
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
3
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
3
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
3
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
1
,
3
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
-
2
}}}),
s1
);
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
-
2
}}}),
s1
);
}
}
...
@@ -1544,21 +1569,29 @@ TEST_CASE(test_unsqueeze_scalar_tensor2)
...
@@ -1544,21 +1569,29 @@ TEST_CASE(test_unsqueeze_scalar_tensor2)
TEST_CASE
(
test_unsqueeze_transpose
)
TEST_CASE
(
test_unsqueeze_transpose
)
{
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
3
},
{
12
,
1
,
4
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
3
},
{
12
,
1
,
4
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
3
},
{
12
,
1
,
1
,
4
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
3
},
{
12
,
1
,
1
2
,
4
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
s1
);
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
s1
);
}
}
TEST_CASE
(
test_unsqueeze_transpose_step
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
6
},
{
24
,
1
,
4
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
2
,
3
},
{
24
,
1
,
12
,
4
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}},
{
"steps"
,
{
2
}}}),
s1
);
}
TEST_CASE
(
test_unsqueeze_multibroadcast
)
TEST_CASE
(
test_unsqueeze_multibroadcast
)
{
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
},
{
0
,
1
,
0
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
},
{
0
,
1
,
0
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1
,
4
},
{
0
,
1
,
1
,
0
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1
,
4
},
{
0
,
1
,
0
,
0
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
s1
);
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
s1
);
}
}
TEST_CASE
(
test_unsqueeze_slice
)
TEST_CASE
(
test_unsqueeze_slice
)
{
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
},
{
108
,
36
,
1
}};
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
},
{
108
,
36
,
1
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1
,
4
},
{
108
,
36
,
36
,
1
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1
,
4
},
{
108
,
36
,
4
,
1
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
s1
);
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
}}}),
s1
);
}
}
...
@@ -1590,6 +1623,20 @@ TEST_CASE(test_unsqueeze_multiple_axes_2)
...
@@ -1590,6 +1623,20 @@ TEST_CASE(test_unsqueeze_multiple_axes_2)
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
s1
);
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
s1
);
}
}
TEST_CASE
(
test_unsqueeze_multiple_axes_3
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
1
,
5
,
1
,
1
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
2
,
4
,
5
}}}),
s1
);
}
TEST_CASE
(
test_unsqueeze_multiple_axes_4
)
{
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
5
}};
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
3
,
4
,
1
,
5
,
1
,
1
}};
expect_shape
(
s2
,
migraphx
::
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
5
,
4
,
2
}}}),
s1
);
}
TEST_CASE
(
transpose_shape
)
TEST_CASE
(
transpose_shape
)
{
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}};
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment