Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
78c799c5
Commit
78c799c5
authored
Sep 30, 2022
by
charlie
Browse files
Initial
parent
b4bbdde5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
26 deletions
+52
-26
src/include/migraphx/op/broadcast.hpp
src/include/migraphx/op/broadcast.hpp
+3
-1
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+49
-25
No files found.
src/include/migraphx/op/broadcast.hpp
View file @
78c799c5
...
@@ -44,6 +44,7 @@ struct broadcast
...
@@ -44,6 +44,7 @@ struct broadcast
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
;
std
::
vector
<
std
::
size_t
>
broadcast_lens
;
std
::
vector
<
shape
::
dynamic_dimension
>
broadcast_dyn_dims
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -54,11 +55,12 @@ struct broadcast
...
@@ -54,11 +55,12 @@ struct broadcast
std
::
string
name
()
const
{
return
"broadcast"
;
}
std
::
string
name
()
const
{
return
"broadcast"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
broadcast_lens
.
size
(),
0
);
// the broacast op is deprecated now, so not handling the negative
// the broa
d
cast op is deprecated now, so not handling the negative
// value of axis anymore
// value of axis anymore
if
(
axis
>=
broadcast_lens
.
size
())
if
(
axis
>=
broadcast_lens
.
size
())
{
{
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
78c799c5
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/common.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -46,44 +48,66 @@ struct multibroadcast
...
@@ -46,44 +48,66 @@ struct multibroadcast
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
);
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
empty
())
auto
t
=
inputs
.
at
(
0
).
type
();
auto
input_shape
=
inputs
.
at
(
0
);
if
(
inputs
.
size
()
==
1
)
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
if
(
input_shape
.
lens
().
empty
())
}
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
if
(
input
_shape
.
lens
().
size
()
>
output_lens
.
size
())
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
}
}
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
auto
offset
=
output_lens
.
size
()
-
input_shape
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
std
::
ptrdiff_t
i
=
input_shape
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
!=
input
.
lens
()[
i
]
and
input
.
lens
()[
i
]
!=
1
)
{
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input
.
lens
())
+
if
(
output_lens
[
i
+
offset
]
!=
input_shape
.
lens
()[
i
]
and
input_shape
.
lens
()[
i
]
!=
1
)
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
{
"}!"
);
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input_shape
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
std
::
ptrdiff_t
i
=
input_shape
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
==
input_shape
.
lens
()[
i
])
{
bcast_strides
[
i
+
offset
]
=
input_shape
.
strides
()[
i
];
}
}
return
{
t
,
output_lens
,
bcast_strides
};
}
else
{
{
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
// need both shapes when handling dynamic case
// shapes can be dynamic (at compile-time) or static (at evaluation time)
// this function will be called through compute_output_shape conversion to dyn_output
// new compute_broadcasted_lens for dynamic shapes
auto
other_shape
=
inputs
.
at
(
1
);
if
(
input_shape
.
dynamic
()
and
other_shape
.
dynamic
())
{}
else
if
(
not
input_shape
.
dynamic
()
and
not
other_shape
.
dynamic
())
{
{
bcast_strides
[
i
+
offset
]
=
input
.
strides
()[
i
];
auto
output_lens
=
compute_broadcasted_lens
(
input_shape
.
lens
(),
other_shape
.
lens
());
}
else
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input_shape and other_shape are not both dynamic or static"
);
}
}
}
}
return
{
t
,
output_lens
,
bcast_strides
};
}
}
argument
compute
(
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
const
dyn_output
&
dyn_out
,
std
::
vector
<
argument
>
args
)
const
{
{
return
args
[
0
].
reshape
(
out
put_shape
);
return
args
[
0
].
reshape
(
dyn_out
.
com
put
ed
_shape
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment