Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0856b6e2
Commit
0856b6e2
authored
Aug 15, 2018
by
Paul
Browse files
Fix flatten operator
parent
33212f8f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
31 deletions
+17
-31
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+7
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+8
-27
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+2
-1
No files found.
src/include/migraph/operators.hpp
View file @
0856b6e2
...
@@ -427,17 +427,21 @@ struct flatten
...
@@ -427,17 +427,21 @@ struct flatten
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
auto
&&
lens
=
inputs
.
front
().
lens
();
if
(
axis
==
0
)
if
(
axis
==
0
)
{
{
return
{
inputs
.
at
(
0
).
type
(),
{
1
,
inputs
.
at
(
0
).
elements
()}};
return
{
inputs
.
at
(
0
).
type
(),
{
1
,
inputs
.
at
(
0
).
elements
()}};
}
}
if
(
axis
==
1
)
else
if
(
axis
<
lens
.
size
()
)
{
{
return
{
inputs
.
at
(
0
).
type
(),
{
inputs
.
at
(
0
).
elements
(),
1
}};
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
}
else
else
{
{
MIGRAPH_THROW
(
"axis for flatten
can only be either 0 or 1
"
);
MIGRAPH_THROW
(
"axis for flatten
must be less than tensor rank
"
);
}
}
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/onnx/onnx.cpp
View file @
0856b6e2
...
@@ -60,8 +60,8 @@ struct onnx_parser
...
@@ -60,8 +60,8 @@ struct onnx_parser
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_
max_
pooling
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"AveragePool"
,
&
onnx_parser
::
parse_
average_
pooling
);
add_mem_op
(
"AveragePool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
...
@@ -129,28 +129,9 @@ struct onnx_parser
...
@@ -129,28 +129,9 @@ struct onnx_parser
}
}
instruction_ref
instruction_ref
parse_
max_
pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_pooling
(
std
::
string
name
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
pooling
op
{
"max"
};
pooling
op
{
name
==
"MaxPool"
?
"max"
:
"average"
};
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
}
instruction_ref
parse_average_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"average"
};
if
(
contains
(
attributes
,
"pads"
))
if
(
contains
(
attributes
,
"pads"
))
{
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
...
@@ -187,10 +168,10 @@ struct onnx_parser
...
@@ -187,10 +168,10 @@ struct onnx_parser
parse_flatten
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_flatten
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
//
if(contains(attributes, "axis"))
if
(
contains
(
attributes
,
"axis"
))
//
{
{
//
axis = parse_value(attributes.at("axis")).at<int>();
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
//
}
}
return
prog
.
add_instruction
(
flatten
{
axis
},
args
[
0
]);
return
prog
.
add_instruction
(
flatten
{
axis
},
args
[
0
]);
}
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
0856b6e2
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/cpu/cpu_lowering.hpp>
#include <migraph/cpu/cpu_lowering.hpp>
#include <migraph/auto_contiguous.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
cpu
{
namespace
cpu
{
std
::
string
cpu_target
::
name
()
const
{
return
"cpu"
;
}
std
::
string
cpu_target
::
name
()
const
{
return
"cpu"
;
}
std
::
vector
<
pass
>
cpu_target
::
get_passes
(
context
&
)
const
{
return
{
cpu_lowering
{}};
}
std
::
vector
<
pass
>
cpu_target
::
get_passes
(
context
&
)
const
{
return
{
auto_contiguous
{},
cpu_lowering
{}};
}
}
// namespace cpu
}
// namespace cpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment