Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1adf6096
Commit
1adf6096
authored
Jun 24, 2021
by
Shucai Xiao
Browse files
add scatter op
parent
e00479af
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
71 additions
and
0 deletions
+71
-0
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/include/migraphx/op/scatter.hpp
src/include/migraphx/op/scatter.hpp
+68
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-0
src/onnx/parse_generic_op.cpp
src/onnx/parse_generic_op.cpp
+1
-0
No files found.
src/CMakeLists.txt
View file @
1adf6096
...
@@ -144,6 +144,7 @@ register_migraphx_ops(
...
@@ -144,6 +144,7 @@ register_migraphx_ops(
round
round
rsqrt
rsqrt
scalar
scalar
scatter
sigmoid
sigmoid
sign
sign
sinh
sinh
...
...
src/include/migraphx/op/scatter.hpp
0 → 100755
View file @
1adf6096
#ifndef MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#define MIGRAPHX_GUARD_OPERATORS_SCATTER_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
scatter
{
int64_t
axis
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axis
,
"axis"
));
}
value
attributes
()
const
{
value
normalize
;
normalize
[
"axis"
]
=
value
::
array
{
normalize_attribute
::
include_min
};
return
{{
"normalize_axes"
,
normalize
}};
}
std
::
string
name
()
const
{
return
"scatter"
;
}
shape
normalize_compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
standard
();
return
inputs
.
front
();
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
// max dimension in axis
visit_all
(
result
,
args
[
0
],
args
[
2
])([
&
](
auto
output
,
auto
data
,
auto
update
)
{
std
::
copy
(
data
.
begin
(),
data
.
end
(),
output
.
begin
());
args
[
1
].
visit
([
&
](
auto
indices
)
{
auto
ind_s
=
indices
.
get_shape
();
shape_for_each
(
ind_s
,
[
&
](
const
auto
&
idx
)
{
auto
out_idx
=
idx
;
out_idx
[
axis
]
=
indices
[
ind_s
.
index
(
idx
)];
output
[
output_shape
.
index
(
out_idx
)]
=
update
[
ind_s
.
index
(
idx
)];
});
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
1adf6096
...
@@ -80,6 +80,7 @@
...
@@ -80,6 +80,7 @@
#include <migraphx/op/round.hpp>
#include <migraphx/op/round.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/rsqrt.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scalar.hpp>
#include <migraphx/op/scatter.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sigmoid.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sign.hpp>
#include <migraphx/op/sinh.hpp>
#include <migraphx/op/sinh.hpp>
...
...
src/onnx/parse_generic_op.cpp
View file @
1adf6096
...
@@ -35,6 +35,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
...
@@ -35,6 +35,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
{
"Reciprocal"
,
"recip"
},
{
"Reciprocal"
,
"recip"
},
{
"Relu"
,
"relu"
},
{
"Relu"
,
"relu"
},
{
"Round"
,
"round"
},
{
"Round"
,
"round"
},
{
"Scatter"
,
"scatter"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sigmoid"
,
"sigmoid"
},
{
"Sign"
,
"sign"
},
{
"Sign"
,
"sign"
},
{
"Sin"
,
"sin"
},
{
"Sin"
,
"sin"
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment