Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
87fc2260
Commit
87fc2260
authored
Oct 26, 2022
by
charlie
Browse files
initial
parent
57884353
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
29 deletions
+56
-29
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+56
-29
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
87fc2260
...
@@ -33,18 +33,24 @@ namespace migraphx {
...
@@ -33,18 +33,24 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
namespace
op
{
/// The broadcast operator performs the numpy-style broadcasting of an axis of a given tensor. This
/**
/// is achieved primarily by setting the stride of the broadcasted axis to zero. Linear indicies are
* 1 input version:
/// computed from multi-indicies by computing the inner product on the multi-index with the strides.
* Broadcasts a tensor from the original shape to the broadcast_lens by setting the stride of
/// For example, if we have a tensor A(2,3) it has lengths of (2,3) and strides of (3,1). If we want
* broadcasted dimensions to zero. `axis` attribute for a 1D input shape is the output dimension
/// to compute the linear offset that corresponds to the element on the 2nd row (i = 1) and 3rd
* that stays the same. ex: broadcasting shape [1024] -> [4, 1024, 3] has axis = 1 For higher rank
/// column (j = 2), we compute the following inner product (1,2) dot (3, 1) = 1*3 + 2*1 = 5. It is
* input shapes, axis is an offset parameter for the broadcasting. Such that this operator would
/// obvious from there that we can negate the effects of a given axis by setting the stride of that
* work in the opposite direction of NumPy broadcasting. ex: broadcasting shape [2, 2] -> [2, 2, 3]
/// axis to zero.
* with axis = 0
*
* 2 input version:
* Broadcast the first input 1D shape into the second input shape based on the axis parameter.
* Handles broadcasting a 1D fixed shape into a higher rank dynamic shape.
* broadcast_lens is not used
*/
struct
broadcast
struct
broadcast
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
=
{}
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -55,33 +61,54 @@ struct broadcast
...
@@ -55,33 +61,54 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
s0
=
inputs
.
at
(
0
);
auto
s0
=
inputs
.
at
(
0
);
auto
t
=
s0
.
type
();
auto
t
=
s0
.
type
();
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
if
(
inputs
.
size
()
==
1
)
// the broadcast op is deprecated now, so not handling the negative
{
// the ONNX broadcast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
{
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
MIGRAPHX_THROW
(
"BROADCAST : axis is out of range"
);
}
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
if
(
broadcast_lens
.
size
()
-
axis
<
s0
.
lens
().
size
())
{
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
MIGRAPHX_THROW
(
"BROADCAST: (broadcast ndims - axis) is less than s0 ndims"
);
}
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
if
(
not
std
::
equal
(
s0
.
lens
().
begin
(),
s0
.
lens
().
end
(),
broadcast_lens
.
begin
()
+
axis
))
{
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
MIGRAPHX_THROW
(
"BROADCAST: when broadcasting, succeeding sizes must match"
);
}
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
shape
output
{
t
,
broadcast_lens
,
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
MIGRAPHX_THROW
(
"BROADCAST: output size must be greater than or equal to s0 size"
);
return
output
;
return
output
;
}
}
else
{
// two inputs
auto
s1
=
inputs
.
at
(
1
);
if
(
s0
.
dynamic
())
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 is a static shape, does not handle broadcasting "
"a static shape"
);
if
(
s0
.
ndim
()
!=
1
)
MIGRAPHX_THROW
(
"BROADCAST_2in: s0 has ndim "
+
migraphx
::
to_string
(
s0
.
ndim
())
+
", only handle ndim = 1"
);
if
(
axis
>
s1
.
ndim
())
MIGRAPHX_THROW
(
"BROADCAST_2in: axis is out of range"
);
if
(
s1
.
ndim
()
-
axis
<
s0
.
ndim
())
MIGRAPHX_THROW
(
"BROADCAST_2in: (s1_ndim - axis) is less than s0 ndim"
);
if
(
s1
.
dynamic
())
return
s1
;
std
::
vector
<
size_t
>
bcast_strides
(
s1
.
ndim
(),
0
);
std
::
copy
(
s0
.
strides
().
begin
(),
s0
.
strides
().
end
(),
bcast_strides
.
begin
()
+
axis
);
shape
output
{
t
,
s1
.
lens
(),
std
::
move
(
bcast_strides
)};
if
(
output
.
elements
()
<
s0
.
elements
())
MIGRAPHX_THROW
(
"BROADCAST_2in: output size must be greater than or equal to s0 size"
);
return
output
;
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment