Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
079ccd40
Commit
079ccd40
authored
Jul 08, 2019
by
Shucai Xiao
Browse files
fix review comments
parent
2b8daf9c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
45 deletions
+14
-45
src/include/migraphx/op/reduce_mean.hpp
src/include/migraphx/op/reduce_mean.hpp
+2
-2
src/include/migraphx/op/reduce_sum.hpp
src/include/migraphx/op/reduce_sum.hpp
+2
-2
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+10
-41
No files found.
src/include/migraphx/op/reduce_mean.hpp
View file @
079ccd40
...
@@ -14,7 +14,7 @@ namespace op {
...
@@ -14,7 +14,7 @@ namespace op {
struct
reduce_mean
struct
reduce_mean
{
{
std
::
vector
<
int64
_t
>
axes
{};
std
::
vector
<
std
::
size
_t
>
axes
{};
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -31,7 +31,7 @@ struct reduce_mean
...
@@ -31,7 +31,7 @@ struct reduce_mean
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
for
(
auto
axis
:
axes
)
for
(
auto
axis
:
axes
)
{
{
if
(
axis
<
0
or
axis
>=
lens
.
size
())
if
(
axis
>=
lens
.
size
())
MIGRAPHX_THROW
(
"REDUCE_MEAN: axis out of range"
);
MIGRAPHX_THROW
(
"REDUCE_MEAN: axis out of range"
);
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
}
}
...
...
src/include/migraphx/op/reduce_sum.hpp
View file @
079ccd40
...
@@ -14,7 +14,7 @@ namespace op {
...
@@ -14,7 +14,7 @@ namespace op {
struct
reduce_sum
struct
reduce_sum
{
{
std
::
vector
<
int64
_t
>
axes
{};
std
::
vector
<
std
::
size
_t
>
axes
{};
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -31,7 +31,7 @@ struct reduce_sum
...
@@ -31,7 +31,7 @@ struct reduce_sum
auto
lens
=
s
.
lens
();
auto
lens
=
s
.
lens
();
for
(
auto
axis
:
axes
)
for
(
auto
axis
:
axes
)
{
{
if
(
axis
<
0
or
axis
>=
lens
.
size
())
if
(
axis
>=
lens
.
size
())
MIGRAPHX_THROW
(
"REDUCE_SUM: axis out of range"
);
MIGRAPHX_THROW
(
"REDUCE_SUM: axis out of range"
);
lens
[
axis
]
=
1
;
lens
[
axis
]
=
1
;
}
}
...
...
src/onnx/onnx.cpp
View file @
079ccd40
...
@@ -95,8 +95,8 @@ struct onnx_parser
...
@@ -95,8 +95,8 @@ struct onnx_parser
add_mem_op
(
"GRU"
,
&
onnx_parser
::
parse_gru
);
add_mem_op
(
"GRU"
,
&
onnx_parser
::
parse_gru
);
add_mem_op
(
"LSTM"
,
&
onnx_parser
::
parse_lstm
);
add_mem_op
(
"LSTM"
,
&
onnx_parser
::
parse_lstm
);
add_mem_op
(
"Pad"
,
&
onnx_parser
::
parse_pad
);
add_mem_op
(
"Pad"
,
&
onnx_parser
::
parse_pad
);
add_mem_op
(
"ReduceSum"
,
&
onnx_parser
::
parse_reduce_sum
);
add_mem_op
(
"ReduceSum"
,
&
onnx_parser
::
parse_reduce_
oper
<
op
::
reduce_
sum
>
);
add_mem_op
(
"ReduceMean"
,
&
onnx_parser
::
parse_reduce_mean
);
add_mem_op
(
"ReduceMean"
,
&
onnx_parser
::
parse_reduce_
oper
<
op
::
reduce_
mean
>
);
// init the activation function map
// init the activation function map
init_actv_func
();
init_actv_func
();
...
@@ -1288,20 +1288,21 @@ struct onnx_parser
...
@@ -1288,20 +1288,21 @@ struct onnx_parser
return
{
hidden_states
,
last_output
,
last_cell_output
};
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
}
instruction_ref
parse_reduce_sum
(
const
std
::
string
&
,
template
<
class
T
>
instruction_ref
parse_reduce_oper
(
const
std
::
string
&
,
attribute_map
attributes
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
std
::
vector
<
instruction_ref
>
args
)
{
{
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
// default to reduce over all dimensions
// default to reduce over all dimensions
std
::
vector
<
int64
_t
>
axes
(
n_dim
);
std
::
vector
<
std
::
size
_t
>
axes
(
n_dim
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
contains
(
attributes
,
"axes"
))
if
(
contains
(
attributes
,
"axes"
))
{
{
axes
.
clear
();
axes
.
clear
();
auto
&&
attr_axes
=
attributes
[
"axes"
].
ints
();
auto
&&
attr_axes
=
attributes
[
"axes"
].
ints
();
axes
=
std
::
vector
<
int64
_t
>
(
attr_axes
.
begin
(),
attr_axes
.
end
());
axes
=
std
::
vector
<
std
::
size
_t
>
(
attr_axes
.
begin
(),
attr_axes
.
end
());
}
}
int
keep_dims
=
1
;
int
keep_dims
=
1
;
...
@@ -1312,45 +1313,13 @@ struct onnx_parser
...
@@ -1312,45 +1313,13 @@ struct onnx_parser
if
(
keep_dims
==
1
)
if
(
keep_dims
==
1
)
{
{
return
prog
.
add_instruction
(
op
::
reduce_sum
{
axes
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
T
{
axes
},
std
::
move
(
args
));
}
}
else
else
{
{
auto
ins
=
prog
.
add_instruction
(
op
::
reduce_sum
{
axes
},
std
::
move
(
args
));
auto
ins
=
prog
.
add_instruction
(
T
{
axes
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{
axes
},
ins
);
std
::
vector
<
int64_t
>
sq_axes
(
axes
.
begin
(),
axes
.
end
());
}
return
prog
.
add_instruction
(
op
::
squeeze
{
sq_axes
},
ins
);
}
instruction_ref
parse_reduce_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
std
::
size_t
n_dim
=
args
.
front
()
->
get_shape
().
lens
().
size
();
// default to reduce over all dimensions
std
::
vector
<
int64_t
>
axes
(
n_dim
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
if
(
contains
(
attributes
,
"axes"
))
{
axes
.
clear
();
auto
&&
attr_axes
=
attributes
[
"axes"
].
ints
();
axes
=
std
::
vector
<
int64_t
>
(
attr_axes
.
begin
(),
attr_axes
.
end
());
}
int
keep_dims
=
1
;
if
(
contains
(
attributes
,
"keepdims"
))
{
keep_dims
=
parse_value
(
attributes
.
at
(
"keepdims"
)).
at
<
int
>
();
}
if
(
keep_dims
==
1
)
{
return
prog
.
add_instruction
(
op
::
reduce_mean
{
axes
},
std
::
move
(
args
));
}
else
{
auto
ins
=
prog
.
add_instruction
(
op
::
reduce_mean
{
axes
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
squeeze
{
axes
},
ins
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment