Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b4d2a740
Unverified
Commit
b4d2a740
authored
Aug 02, 2018
by
Paul Fultz II
Committed by
GitHub
Aug 02, 2018
Browse files
Merge pull request #28 from ROCmSoftwarePlatform/batchnorm_onnx
Batchnorm onnx
parents
6bb6b72e
a8dd3210
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
1 deletion
+58
-1
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+4
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+30
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+24
-0
No files found.
src/include/migraph/operators.hpp
View file @
b4d2a740
...
@@ -20,7 +20,8 @@ struct not_computable
...
@@ -20,7 +20,8 @@ struct not_computable
struct
batch_norm_inference
struct
batch_norm_inference
{
{
double
epsilon
=
1.0e-6
;
float
epsilon
=
1.0e-6
f
;
float
momentum
=
0.9
f
;
std
::
string
name
()
const
{
return
"batch_norm_inference"
;
}
std
::
string
name
()
const
{
return
"batch_norm_inference"
;
}
...
@@ -32,6 +33,8 @@ struct batch_norm_inference
...
@@ -32,6 +33,8 @@ struct batch_norm_inference
bn_infer_mode_t
bn_mode
=
spatial
;
bn_infer_mode_t
bn_mode
=
spatial
;
bool
is_test
=
false
;
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
check_shapes
{
inputs
,
*
this
}.
has
(
5
);
...
...
src/onnx/onnx.cpp
View file @
b4d2a740
...
@@ -62,6 +62,7 @@ struct onnx_parser
...
@@ -62,6 +62,7 @@ struct onnx_parser
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"Conv"
,
&
onnx_parser
::
parse_conv
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"MaxPool"
,
&
onnx_parser
::
parse_pooling
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"Reshape"
,
&
onnx_parser
::
parse_reshape
);
add_mem_op
(
"BatchNormalization"
,
&
onnx_parser
::
parse_batchnorm
);
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -167,6 +168,35 @@ struct onnx_parser
...
@@ -167,6 +168,35 @@ struct onnx_parser
return
prog
.
add_literal
(
v
);
return
prog
.
add_literal
(
v
);
}
}
instruction_ref
parse_batchnorm
(
std
::
string
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
float
epsilon
=
1e-5
f
;
float
momentum
=
0.9
f
;
batch_norm_inference
::
bn_infer_mode_t
bn_mode
=
batch_norm_inference
::
spatial
;
bool
is_test
=
false
;
if
(
contains
(
attributes
,
"epsilon"
))
{
epsilon
=
parse_value
(
attributes
.
at
(
"epsilon"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"momentum"
))
{
epsilon
=
parse_value
(
attributes
.
at
(
"momentum"
)).
at
<
float
>
();
}
if
(
contains
(
attributes
,
"is_test"
))
{
is_test
=
parse_value
(
attributes
.
at
(
"is_test"
)).
at
<
uint64_t
>
()
>
0
;
}
if
(
contains
(
attributes
,
"spatial"
))
{
bn_mode
=
(
parse_value
(
attributes
.
at
(
"spatial"
)).
at
<
uint64_t
>
()
>
0
)
?
batch_norm_inference
::
spatial
:
batch_norm_inference
::
per_activation
;
}
batch_norm_inference
op
{
epsilon
,
momentum
,
bn_mode
,
is_test
};
return
prog
.
add_instruction
(
op
,
args
);
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
{
{
onnx
::
ModelProto
model
;
onnx
::
ModelProto
model
;
...
...
test/onnx/onnx_test.cpp
View file @
b4d2a740
...
@@ -39,6 +39,29 @@ void pytorch_conv_relu_maxpool()
...
@@ -39,6 +39,29 @@ void pytorch_conv_relu_maxpool()
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
void
pytorch_conv_bn_relu_maxpool
()
{
migraph
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
5
,
5
}});
auto
l2
=
p
.
add_parameter
(
"2"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
auto
p3
=
p
.
add_parameter
(
"3"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
auto
p4
=
p
.
add_parameter
(
"4"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
auto
p5
=
p
.
add_parameter
(
"5"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
auto
p6
=
p
.
add_parameter
(
"6"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
uint64_t
axis
=
1
;
auto
l3
=
p
.
add_instruction
(
migraph
::
convolution
{},
l0
,
l1
);
auto
l4
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l3
,
l2
);
auto
l5
=
p
.
add_instruction
(
migraph
::
add
{},
l3
,
l4
);
auto
l6
=
p
.
add_instruction
(
migraph
::
batch_norm_inference
{},
l5
,
p3
,
p4
,
p5
,
p6
);
auto
l7
=
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
l6
);
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l7
);
auto
prog
=
migraph
::
parse_onnx
(
"conv_bn_relu_maxpool.onnx"
);
EXPECT
(
p
==
prog
);
}
void
pytorch_conv_relu_maxpoolX2
()
void
pytorch_conv_relu_maxpoolX2
()
{
{
migraph
::
program
p
;
migraph
::
program
p
;
...
@@ -69,5 +92,6 @@ int main()
...
@@ -69,5 +92,6 @@ int main()
{
{
pytorch_conv_bias_test
();
pytorch_conv_bias_test
();
pytorch_conv_relu_maxpool
();
pytorch_conv_relu_maxpool
();
pytorch_conv_bn_relu_maxpool
();
pytorch_conv_relu_maxpoolX2
();
pytorch_conv_relu_maxpoolX2
();
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment