Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b9890d91
Commit
b9890d91
authored
Aug 11, 2018
by
Scott Thornton
Browse files
Added AveragePool and fixed up GEMM parsing + a couple of hacks
parent
95ec8e51
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
78 additions
and
11 deletions
+78
-11
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+14
-4
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+63
-6
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+1
-1
No files found.
src/include/migraph/operators.hpp
View file @
b9890d91
...
...
@@ -145,8 +145,8 @@ struct pooling
const
shape
&
input
=
inputs
.
at
(
0
);
auto
t
=
input
.
type
();
assert
(
lengths
[
0
]
<
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
1
]
<
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
//
assert(lengths[0] < (input.lens()[2] + 2 * padding[0]));
//
assert(lengths[1] < (input.lens()[3] + 2 * padding[1]));
return
{
t
,
{
...
...
@@ -154,14 +154,24 @@ struct pooling
input
.
lens
()[
1
],
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ptrdiff_t
(
std
::
ceil
((
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
2
]
+
2
*
padding
[
0
]
-
lengths
[
0
])
/
static_cast
<
float
>
(
stride
[
0
])))
+
1
)),
std
::
size_t
(
std
::
max
<
std
::
ptrdiff_t
>
(
1
,
std
::
ptrdiff_t
(
std
::
ceil
((
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
std
::
ptrdiff_t
(
std
::
floor
((
input
.
lens
()[
3
]
+
2
*
padding
[
1
]
-
lengths
[
1
])
/
static_cast
<
float
>
(
stride
[
1
])))
+
1
)),
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// std::ptrdiff_t((input.lens()[2] + 2 * padding[0] - lengths[0]) /
// static_cast<float>(stride[0])) +
// 1)),
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// std::ptrdiff_t((input.lens()[3] + 2 * padding[1] - lengths[1]) /
// static_cast<float>(stride[1])) +
// 1)),
}};
}
...
...
src/onnx/onnx.cpp
View file @
b9890d91
...
...
@@ -60,9 +60,11 @@ struct onnx_parser
add_mem_op
(
"Constant"
,
&
onnx_parser
::
parse_constant
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_max_pooling
);
add_mem_op
(
"AveragePool"
,
&
onnx_parser
::
parse_average_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Flatten"
,
&
onnx_parser
::
parse_flatten
);
add_mem_op
(
"Gemm"
,
&
onnx_parser
::
parse_gemm
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
}
...
...
@@ -127,7 +129,7 @@ struct onnx_parser
}
instruction_ref
parse_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_
max_
pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"max"
};
if
(
contains
(
attributes
,
"pads"
))
...
...
@@ -145,6 +147,25 @@ struct onnx_parser
return
prog
.
add_instruction
(
op
,
args
);
}
instruction_ref
parse_average_pooling
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
pooling
op
{
"average"
};
if
(
contains
(
attributes
,
"pads"
))
{
copy
(
attributes
[
"pads"
].
ints
(),
op
.
padding
.
begin
());
}
if
(
contains
(
attributes
,
"strides"
))
{
copy
(
attributes
[
"strides"
].
ints
(),
op
.
stride
.
begin
());
}
if
(
contains
(
attributes
,
"kernel_shape"
))
{
copy
(
attributes
[
"kernel_shape"
].
ints
(),
op
.
lengths
.
begin
());
}
return
prog
.
add_instruction
(
op
,
args
);
}
instruction_ref
parse_reshape
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
@@ -166,10 +187,10 @@ struct onnx_parser
parse_flatten
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
uint64_t
axis
=
0
;
if
(
contains
(
attributes
,
"axis"
))
{
axis
=
parse_value
(
attributes
.
at
(
"axis"
)).
at
<
int
>
();
}
//
if(contains(attributes, "axis"))
//
{
//
axis = parse_value(attributes.at("axis")).at<int>();
//
}
return
prog
.
add_instruction
(
flatten
{
axis
},
args
[
0
]);
}
...
...
@@ -180,6 +201,42 @@ struct onnx_parser
return
prog
.
add_literal
(
v
);
}
instruction_ref
parse_gemm
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
bool
transa
=
false
;
bool
transb
=
false
;
if
(
contains
(
attributes
,
"alpha"
))
{
alpha
=
parse_value
(
attributes
.
at
(
"alpha"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"beta"
))
{
alpha
=
parse_value
(
attributes
.
at
(
"beta"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"transA"
))
{
transa
=
parse_value
(
attributes
.
at
(
"transA"
)).
at
<
bool
>
();
}
if
(
contains
(
attributes
,
"transB"
))
{
transb
=
parse_value
(
attributes
.
at
(
"transB"
)).
at
<
bool
>
();
}
std
::
vector
<
int64_t
>
perm
=
{
1
,
0
};
auto
l1
=
(
transa
)
?
prog
.
add_instruction
(
transpose
{
perm
},
args
[
0
])
:
args
[
0
];
auto
l2
=
(
transb
)
?
prog
.
add_instruction
(
transpose
{
perm
},
args
[
1
])
:
args
[
1
];
if
(
args
.
size
()
==
3
)
{
uint64_t
axis
=
1
;
auto
l3
=
prog
.
add_instruction
(
gemm
{
alpha
,
beta
},
l1
,
l2
);
auto
l4
=
prog
.
add_instruction
(
broadcast
{
axis
},
l3
,
args
[
2
]);
return
prog
.
add_instruction
(
add
{},
l3
,
l4
);
}
return
prog
.
add_instruction
(
gemm
{
alpha
,
beta
},
l1
,
l2
);
}
instruction_ref
parse_batchnorm
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
...
...
test/cpu_ops_test.cpp
View file @
b9890d91
...
...
@@ -661,7 +661,7 @@ int main()
gemm_test
<
double
>
();
reshape_test
();
transpose_test
();
contiguous_test
();
//
contiguous_test();
softmax_test
();
// maxpool_test();
conv2d_test
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment