Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dfbfd078
Commit
dfbfd078
authored
Feb 16, 2022
by
Shucai Xiao
Browse files
Merge branch 'shape_op' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into parse_dynamic_shape
parents
ecb1545c
f20d6acb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
52 additions
and
7 deletions
+52
-7
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraphx/op/shape_op.hpp
src/include/migraphx/op/shape_op.hpp
+38
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-0
src/onnx/parse_shape.cpp
src/onnx/parse_shape.cpp
+12
-7
No files found.
src/CMakeLists.txt
View file @
dfbfd078
...
...
@@ -161,6 +161,7 @@ register_migraphx_ops(
rsqrt
scalar
scatter
shape_op
sigmoid
sign
sinh
...
...
src/include/migraphx/op/shape_op.hpp
0 → 100644
View file @
dfbfd078
#ifndef MIGRAPHX_GUARD_OPERATORS_SHAPE_OP_HPP
#define MIGRAPHX_GUARD_OPERATORS_SHAPE_OP_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/context.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
shape_op
{
std
::
string
name
()
const
{
return
"shape"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
std
::
vector
<
std
::
size_t
>
lens
=
{
inputs
[
0
].
lens
().
size
()};
return
{
shape
::
int64_type
,
lens
};
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
auto
lens
=
args
.
front
().
get_shape
().
lens
();
result
.
visit
([
&
](
auto
v
)
{
std
::
copy
(
lens
.
begin
(),
lens
.
end
(),
v
.
begin
());
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
dfbfd078
...
...
@@ -86,6 +86,7 @@
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/shape_op.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
...
...
src/onnx/parse_shape.cpp
View file @
dfbfd078
...
...
@@ -19,14 +19,19 @@ struct parse_shape : op_parser<parse_shape>
std
::
vector
<
instruction_ref
>
args
)
const
{
if
(
args
.
size
()
!=
1
)
{
MIGRAPHX_THROW
(
"Shape: operator should have 1 operand"
);
std
::
vector
<
std
::
size_t
>
arg_shape
=
args
[
0
]
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
vec_shape
(
arg_shape
.
size
());
migraphx
::
shape
s
(
migraphx
::
shape
::
int64_type
,
{
arg_shape
.
size
()});
std
::
transform
(
arg_shape
.
begin
(),
arg_shape
.
end
(),
vec_shape
.
begin
(),
[](
auto
i
)
{
return
int64_t
(
i
);
});
return
info
.
add_literal
(
migraphx
::
literal
{
s
,
vec_shape
});
}
return
info
.
add_instruction
(
make_op
(
"shape"
),
args
);
// std::vector<std::size_t> arg_shape = args[0]->get_shape().lens();
// std::vector<int64_t> vec_shape(arg_shape.size());
// migraphx::shape s(migraphx::shape::int64_type, {arg_shape.size()});
// std::transform(arg_shape.begin(), arg_shape.end(), vec_shape.begin(), [](auto i) {
// return int64_t(i);
// });
// return info.add_literal(migraphx::literal{s, vec_shape});
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment