Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0856b6e2
Commit
0856b6e2
authored
Aug 15, 2018
by
Paul
Browse files
Fix flatten operator
parent
33212f8f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
31 deletions
+17
-31
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+7
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+8
-27
src/targets/cpu/cpu_target.cpp
src/targets/cpu/cpu_target.cpp
+2
-1
No files found.
src/include/migraph/operators.hpp
View file @
0856b6e2
...
@@ -427,17 +427,21 @@ struct flatten
...
@@ -427,17 +427,21 @@ struct flatten
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
}.
has
(
1
);
check_shapes
{
inputs
}.
has
(
1
);
auto
&&
lens
=
inputs
.
front
().
lens
();
if
(
axis
==
0
)
if
(
axis
==
0
)
{
{
return
{
inputs
.
at
(
0
).
type
(),
{
1
,
inputs
.
at
(
0
).
elements
()}};
return
{
inputs
.
at
(
0
).
type
(),
{
1
,
inputs
.
at
(
0
).
elements
()}};
}
}
if
(
axis
==
1
)
else
if
(
axis
<
lens
.
size
()
)
{
{
return
{
inputs
.
at
(
0
).
type
(),
{
inputs
.
at
(
0
).
elements
(),
1
}};
auto
x
=
std
::
accumulate
(
lens
.
begin
(),
lens
.
begin
()
+
axis
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
auto
y
=
std
::
accumulate
(
lens
.
begin
()
+
axis
,
lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
{
inputs
.
at
(
0
).
type
(),
{
x
,
y
}};
}
}
else
else
{
{
MIGRAPH_THROW
(
"axis for flatten
can only be either 0 or 1
"
);
MIGRAPH_THROW
(
"axis for flatten
must be less than tensor rank
"
);
}
}
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
...
...
src/onnx/onnx.cpp
View file @
0856b6e2
...
@@ -60,8 +60,8 @@ struct onnx_parser
...
@@ -60,8 +60,8 @@ struct onnx_parser
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_
max_
pooling
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"AveragePool"
,
&
onnx_parser
::
parse_
average_
pooling
);
add_mem_op
(
"AveragePool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
...
@@ -129,28 +129,9 @@ struct onnx_parser
...
@@ -129,28 +129,9 @@ struct onnx_parser
}
}
instruction_ref
instruction_ref
parse_
max_
pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_pooling
(
std
::
string
name
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
pooling
op
{
"max"
};
pooling
op
{
name
==
"MaxPool"
?
"max"
:
"average"
};
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
}
instruction_ref
parse_average_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"average"
};
if
(
contains
(
attributes
,
"pads"
))
if
(
contains
(
attributes
,
"pads"
))
{
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
...
@@ -187,10 +168,10 @@ struct onnx_parser
...
@@ -187,10 +168,10 @@ struct onnx_parser
parse_flatten
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_flatten
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
uint64_t
axis
=
0
;
uint64_t
axis
=
0
;
//
if(contains(attributes, "axis"))
if
(
contains
(
attributes
,
"axis"
))
//
{
{
//
axis = parse_value(attributes.at("axis")).at<int>();
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
//
}
}
return
prog
.
add_instruction
(
flatten
{
axis
},
args
[
0
]);
return
prog
.
add_instruction
(
flatten
{
axis
},
args
[
0
]);
}
}
...
...
src/targets/cpu/cpu_target.cpp
View file @
0856b6e2
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/cpu/cpu_target.hpp>
#include <migraph/cpu/cpu_lowering.hpp>
#include <migraph/cpu/cpu_lowering.hpp>
#include <migraph/auto_contiguous.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
cpu
{
namespace
cpu
{
std
::
string
cpu_target
::
name
()
const
{
return
"cpu"
;
}
std
::
string
cpu_target
::
name
()
const
{
return
"cpu"
;
}
std
::
vector
<
pass
>
cpu_target
::
get_passes
(
context
&
)
const
{
return
{
cpu_lowering
{}};
}
std
::
vector
<
pass
>
cpu_target
::
get_passes
(
context
&
)
const
{
return
{
auto_contiguous
{},
cpu_lowering
{}};
}
}
// namespace cpu
}
// namespace cpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment