Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
295c2abb
"docs/source/Compression/Quantizer.rst" did not exist on "b7d5aac7ebc68a761039c51cc67fe04baad91aa6"
Commit
295c2abb
authored
Oct 26, 2022
by
charlie
Browse files
Shape tests and update implement
parent
87fc2260
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
7 deletions
+73
-7
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+29
-7
test/op_shape_test.cpp
test/op_shape_test.cpp
+44
-0
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
295c2abb
...
@@ -44,7 +44,7 @@ namespace op {
...
@@ -44,7 +44,7 @@ namespace op {
*
*
* 2 input version:
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D
fixed
shape into a higher rank dynamic shape.
* Handles broadcasting a 1D
static
shape into a higher rank dynamic shape.
* broadcast_lens is not used
* broadcast_lens is not used
*/
*/
struct
broadcast
struct
broadcast
...
@@ -69,17 +69,25 @@ struct broadcast
...
@@ -69,17 +69,25 @@ struct broadcast
// the ONNX broadcast op is deprecated now, so not handling the negative
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
if
(
output
.
elements
()
<
s0
.
elements
())
{
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
}
return
output
;
return
output
;
}
}
else
else
...
@@ -87,19 +95,33 @@ struct broadcast
...
@@ -87,19 +95,33 @@ struct broadcast
// two inputs
// two inputs
auto
s1
=
inputs
.
at
(
1
);
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
())
if
(
s0
.
dynamic
())
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 is a static shape, does not handle broadcasting "
{
"a static shape"
);
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
"a dynamic shape"
);
}
if
(
s0
.
ndim
()
!=
1
)
if
(
s0
.
ndim
()
!=
1
)
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 has ndim "
+
migraphx
::
to_string
(
s0
.
ndim
())
+
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 has ndim "
+
migraphx
::
to_string
(
s0
.
ndim
())
+
", only handle ndim = 1"
);
", only handle ndim = 1"
);
if
(
axis
>
s1
.
ndim
())
}
if
(
axis
>=
s1
.
ndim
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: axis is out of range"
);
MIGRAPHX_THROW
(
"BROADCAST_2in: axis is out of range"
);
if
(
s1
.
ndim
()
-
axis
<
s0
.
ndim
())
}
MIGRAPHX_THROW
(
"BROADCAST_2in: (s1_ndim - axis) is less than s0 ndim"
);
if
(
s1
.
dynamic
())
if
(
s1
.
dynamic
())
{
s0
=
s0
.
to_dynamic
();
if
(
s0
.
dyn_dims
()[
0
]
!=
s1
.
dyn_dims
()[
axis
])
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length"
);
return
s1
;
return
s1
;
}
if
(
s0
.
lens
()[
0
]
!=
s1
.
lens
()[
axis
])
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 length doesn't match with static s1 axis dimension length"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
s1
.
ndim
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
s1
.
ndim
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
s1
.
lens
(),
std
::
move
(
bcast_strides
)};
shape
output
{
t
,
s1
.
lens
(),
std
::
move
(
bcast_strides
)};
...
...
test/op_shape_test.cpp
View file @
295c2abb
...
@@ -118,6 +118,50 @@ TEST_CASE(broadcast)
...
@@ -118,6 +118,50 @@ TEST_CASE(broadcast)
}
}
}
}
TEST_CASE
(
broadcast_2in
)
{
{
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
{
4
},
{
1
}};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
},
{
4
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
},
{
1
,
0
}},
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
}}),
a_input
,
b_input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
},
{
0
,
1
}},
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
a_input
,
b_input
);
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
2
}}),
a_input
,
b_input
);
}
{
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
{
4
},
{
1
}};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
2
,
1
}};
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
a_input
,
b_input
);
}
{
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
{
4
,
2
},
{
2
,
1
}};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
2
,
2
,
0
}}};
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
}}),
b_input
,
a_input
);
}
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dd
{{
4
,
4
,
0
}};
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
dd
};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
},
{
4
,
1
}};
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
}}),
a_input
,
b_input
);
}
{
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
{
4
},
{
1
}};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
2
,
2
,
0
}}};
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
}}),
a_input
,
b_input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
2
,
2
,
0
}}},
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
a_input
,
b_input
);
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
2
}}),
a_input
,
b_input
);
}
}
TEST_CASE
(
convolution_shape
)
TEST_CASE
(
convolution_shape
)
{
{
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment