Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
079ccd40
Commit
079ccd40
authored
Jul 08, 2019
by
Shucai Xiao
Browse files
fix review comments
parent
2b8daf9c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
45 deletions
+14
-45
src/include/migraphx/op/reduce_mean.hpp
src/include/migraphx/op/reduce_mean.hpp
+2
-2
src/include/migraphx/op/reduce_sum.hpp
src/include/migraphx/op/reduce_sum.hpp
+2
-2
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+10
-41
No files found.
src/include/migraphx/op/reduce_mean.hpp
View file @
079ccd40
...
@@ -14,7 +14,7 @@ namespace op {
...
@@ -14,7 +14,7 @@ namespace op {
struct
reduce_mean
struct
reduce_mean
{
{
std
::
vector
<
int64
_t
>
axes
{};
std
::
vector
<
std
::
size
_t
>
axes
{};
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -31,7 +31,7 @@ struct reduce_mean
...
@@ -31,7 +31,7 @@ struct reduce_mean
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
for
(
auto
axis
:
axes
)
for
(
auto
axis
:
axes
)
{
{
if
(
axis
<
0
or
axis
>=
lens
.
size
())
if
(
axis
>=
lens
.
size
())
MIGRAPHX_THROW
(
"REDUCE_MEAN: axis out of range"
);
MIGRAPHX_THROW
(
"REDUCE_MEAN: axis out of range"
);
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
}
}
...
...
src/include/migraphx/op/reduce_sum.hpp
View file @
079ccd40
...
@@ -14,7 +14,7 @@ namespace op {
...
@@ -14,7 +14,7 @@ namespace op {
struct
reduce_sum
struct
reduce_sum
{
{
std
::
vector
<
int64
_t
>
axes
{};
std
::
vector
<
std
::
size
_t
>
axes
{};
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -31,7 +31,7 @@ struct reduce_sum
...
@@ -31,7 +31,7 @@ struct reduce_sum
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
for
(
auto
axis
:
axes
)
for
(
auto
axis
:
axes
)
{
{
if
(
axis
<
0
or
axis
>=
lens
.
size
())
if
(
axis
>=
lens
.
size
())
MIGRAPHX_THROW
(
"REDUCE_SUM: axis out of range"
);
MIGRAPHX_THROW
(
"REDUCE_SUM: axis out of range"
);
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
}
}
...
...
src/onnx/onnx.cpp
View file @
079ccd40
...
@@ -95,8 +95,8 @@ struct onnx_parser
...
@@ -95,8 +95,8 @@ struct onnx_parser
add_mem_op
(
"GRU"
,
&
onnx_parser
::
parse_gru
);
add_mem_op
(
"GRU"
,
&
onnx_parser
::
parse_gru
);
add_mem_op
(
"LSTM"
,
&
onnx_parser
::
parse_lstm
);
add_mem_op
(
"LSTM"
,
&
onnx_parser
::
parse_lstm
);
add_mem_op
(
"Pad"
,
&
onnx_parser
::
parse_pad
);
add_mem_op
(
"Pad"
,
&
onnx_parser
::
parse_pad
);
add_mem_op
(
"ReduceSum"
,
&
onnx_parser
::
parse_reduce_sum
);
add_mem_op
(
"ReduceSum"
,
&
onnx_parser
::
parse_reduce_
oper
<
op
::
reduce_
sum
>
);
add_mem_op
(
"ReduceMean"
,
&
onnx_parser
::
parse_reduce_mean
);
add_mem_op
(
"ReduceMean"
,
&
onnx_parser
::
parse_reduce_
oper
<
op
::
reduce_
mean
>
);
// init the activation function map
// init the activation function map
init_actv_func
();
init_actv_func
();
...
@@ -1288,20 +1288,21 @@ struct onnx_parser
...
@@ -1288,20 +1288,21 @@ struct onnx_parser
return
{
hidden_states
,
last_output
,
last_cell_output
};
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
}
instruction_ref
parse_reduce_sum
(
const
std
::
string
&
,
template
<
class
T
>
instruction_ref
parse_reduce_oper
(
const
std
::
string
&
,
attribute_map
attributes
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
std
::
vector
<
instruction_ref
>
args
)
{
{
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
// default to reduce over all dimensions
// default to reduce over all dimensions
std
::
vector
<
int64
_t
>
axes
(
n_dim
);
std
::
vector
<
std
::
size
_t
>
axes
(
n_dim
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
contains
(
attributes
,
"axes"
))
if
(
contains
(
attributes
,
"axes"
))
{
{
axes
.
clear
();
axes
.
clear
();
auto
&&
attr_axes
=
attributes
[
"axes"
].
ints
();
auto
&&
attr_axes
=
attributes
[
"axes"
].
ints
();
axes
=
std
::
vector
<
int64
_t
>
(
attr_axes
.
begin
(),
attr_axes
.
end
());
axes
=
std
::
vector
<
std
::
size
_t
>
(
attr_axes
.
begin
(),
attr_axes
.
end
());
}
}
int
keep_dims
=
1
;
int
keep_dims
=
1
;
...
@@ -1312,45 +1313,13 @@ struct onnx_parser
...
@@ -1312,45 +1313,13 @@ struct onnx_parser
if
(
keep_dims
==
1
)
if
(
keep_dims
==
1
)
{
{
return
prog
.
add_instruction
(
op
::
reduce_sum
{
axes
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
T
{
axes
},
std
::
move
(
args
));
}
}
else
else
{
{
auto
ins
=
prog
.
add_instruction
(
op
::
reduce_sum
{
axes
},
std
::
move
(
args
));
auto
ins
=
prog
.
add_instruction
(
T
{
axes
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{
axes
},
ins
);
std
::
vector
<
int64_t
>
sq_axes
(
axes
.
begin
(),
axes
.
end
());
}
return
prog
.
add_instruction
(
op
::
squeeze
{
sq_axes
},
ins
);
}
instruction_ref
parse_reduce_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
// default to reduce over all dimensions
std
::
vector
<
int64_t
>
axes
(
n_dim
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
contains
(
attributes
,
"axes"
))
{
axes
.
clear
();
auto
&&
attr_axes
=
attributes
[
"axes"
].
ints
();
axes
=
std
::
vector
<
int64_t
>
(
attr_axes
.
begin
(),
attr_axes
.
end
());
}
int
keep_dims
=
1
;
if
(
contains
(
attributes
,
"keepdims"
))
{
keep_dims
=
parse_value
(
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
}
if
(
keep_dims
==
1
)
{
return
prog
.
add_instruction
(
op
::
reduce_mean
{
axes
},
std
::
move
(
args
));
}
else
{
auto
ins
=
prog
.
add_instruction
(
op
::
reduce_mean
{
axes
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{
axes
},
ins
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment