Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
295c2abb
Commit
295c2abb
authored
Oct 26, 2022
by
charlie
Browse files
Shape tests and update implement
parent
87fc2260
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
7 deletions
+73
-7
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+29
-7
test/op_shape_test.cpp
test/op_shape_test.cpp
+44
-0
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
295c2abb
...
...
@@ -44,7 +44,7 @@ namespace op {
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D
fixed
shape into a higher rank dynamic shape.
* Handles broadcasting a 1D
static
shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct
broadcast
...
...
@@ -69,17 +69,25 @@ struct broadcast
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
{
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
}
return
output
;
}
else
...
...
@@ -87,19 +95,33 @@ struct broadcast
// two inputs
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
())
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 is a static shape, does not handle broadcasting "
"a static shape"
);
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 is a dynamic shape, does not handle broadcasting "
"a dynamic shape"
);
}
if
(
s0
.
ndim
()
!=
1
)
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 has ndim "
+
migraphx
::
to_string
(
s0
.
ndim
())
+
", only handle ndim = 1"
);
if
(
axis
>
s1
.
ndim
())
}
if
(
axis
>=
s1
.
ndim
())
{
MIGRAPHX_THROW
(
"BROADCAST_2in: axis is out of range"
);
if
(
s1
.
ndim
()
-
axis
<
s0
.
ndim
())
MIGRAPHX_THROW
(
"BROADCAST_2in: (s1_ndim - axis) is less than s0 ndim"
);
}
if
(
s1
.
dynamic
())
{
s0
=
s0
.
to_dynamic
();
if
(
s0
.
dyn_dims
()[
0
]
!=
s1
.
dyn_dims
()[
axis
])
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 length doesn't match with dynamic s1 axis "
"dimension length"
);
return
s1
;
}
if
(
s0
.
lens
()[
0
]
!=
s1
.
lens
()[
axis
])
{
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 length doesn't match with static s1 axis dimension length"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
s1
.
ndim
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
s1
.
lens
(),
std
::
move
(
bcast_strides
)};
...
...
test/op_shape_test.cpp
View file @
295c2abb
...
...
@@ -118,6 +118,50 @@ TEST_CASE(broadcast)
}
}
TEST_CASE
(
broadcast_2in
)
{
{
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
{
4
},
{
1
}};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
},
{
4
,
1
}};
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
},
{
1
,
0
}},
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
}}),
a_input
,
b_input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
4
,
4
},
{
0
,
1
}},
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
a_input
,
b_input
);
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
2
}}),
a_input
,
b_input
);
}
{
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
{
4
},
{
1
}};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
2
,
1
}};
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
a_input
,
b_input
);
}
{
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
{
4
,
2
},
{
2
,
1
}};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
2
,
2
,
0
}}};
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
}}),
b_input
,
a_input
);
}
{
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dd
{{
4
,
4
,
0
}};
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
dd
};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{
4
,
4
},
{
4
,
1
}};
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
}}),
a_input
,
b_input
);
}
{
migraphx
::
shape
a_input
{
migraphx
::
shape
::
float_type
,
{
4
},
{
1
}};
migraphx
::
shape
b_input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
2
,
2
,
0
}}};
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
0
}}),
a_input
,
b_input
);
expect_shape
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
4
,
4
,
0
},
{
2
,
2
,
0
}}},
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
}}),
a_input
,
b_input
);
throws_shape
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
2
}}),
a_input
,
b_input
);
}
}
TEST_CASE
(
convolution_shape
)
{
migraphx
::
shape
output
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment