Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e88e866
Commit
7e88e866
authored
Jul 27, 2018
by
Scott Thornton
Browse files
Added tests for onnx parsing
parent
b52e7149
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
101 additions
and
65 deletions
+101
-65
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+20
-64
test/CMakeLists.txt
test/CMakeLists.txt
+8
-1
test/onnx/conv.onnx
test/onnx/conv.onnx
+0
-0
test/onnx/conv_bn_relu_maxpool.onnx
test/onnx/conv_bn_relu_maxpool.onnx
+0
-0
test/onnx/conv_relu_maxpool.onnx
test/onnx/conv_relu_maxpool.onnx
+0
-0
test/onnx/conv_relu_maxpoolX2.onnx
test/onnx/conv_relu_maxpoolX2.onnx
+0
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+73
-0
No files found.
src/onnx/onnx.cpp
View file @
7e88e866
...
...
@@ -318,70 +318,26 @@ struct onnx_parser
if
(
t
.
has_raw_data
())
{
std
::
string
s
=
t
.
raw_data
();
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT
)
{
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT8
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT8
)
{
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT16
)
{
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT16
)
{
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT32
)
{
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT64
)
{
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
STRING
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
BOOL
)
{
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT16
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
DOUBLE
)
{
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT32
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT64
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
COMPLEX64
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
COMPLEX128
)
{
throw
std
::
runtime_error
(
""
);
}
else
{
MIGRAPH_THROW
(
"Invalid tensor type"
);
}
switch
(
t
.
data_type
())
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
MIGRAPH_THROW
(
"Invalid tensor type"
);
}
switch
(
t
.
data_type
())
{
...
...
test/CMakeLists.txt
View file @
7e88e866
...
...
@@ -84,7 +84,7 @@ function(add_test_executable TEST_NAME)
add_dependencies
(
tests
${
TEST_NAME
}
)
add_dependencies
(
check
${
TEST_NAME
}
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES FAIL_REGULAR_EXPRESSION
"FAILED"
)
target_link_libraries
(
${
TEST_NAME
}
migraph migraph_cpu
)
target_link_libraries
(
${
TEST_NAME
}
migraph migraph_cpu
migraph_onnx
)
target_include_directories
(
${
TEST_NAME
}
PUBLIC include
)
endfunction
(
add_test_executable
)
...
...
@@ -105,3 +105,10 @@ if(MIGRAPH_ENABLE_GPU)
target_link_libraries
(
test_gpu_
${
BASE_NAME
}
migraph_gpu
)
endforeach
()
endif
()
add_executable
(
test_onnx onnx/onnx_test.cpp
)
target_link_libraries
(
test_onnx migraph_onnx
)
target_include_directories
(
test_onnx PUBLIC include
)
add_test
(
NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
/onnx
)
add_dependencies
(
tests test_onnx
)
add_dependencies
(
check test_onnx
)
test/onnx/conv.onnx
0 → 100644
View file @
7e88e866
File added
test/onnx/conv_bn_relu_maxpool.onnx
0 → 100644
View file @
7e88e866
File added
test/onnx/conv_relu_maxpool.onnx
0 → 100644
View file @
7e88e866
File added
test/onnx/conv_relu_maxpoolX2.onnx
0 → 100644
View file @
7e88e866
File added
test/onnx/onnx_test.cpp
0 → 100644
View file @
7e88e866
#include <iostream>
#include <vector>
#include <migraph/literal.hpp>
#include <migraph/operators.hpp>
#include <migraph/program.hpp>
#include <migraph/onnx.hpp>
#include "test.hpp"
#include "verify.hpp"
void
pytorch_conv_bias_test
()
{
migraph
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
5
,
5
}});
auto
l2
=
p
.
add_parameter
(
"2"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
uint64_t
axis
=
1
;
auto
l3
=
p
.
add_instruction
(
migraph
::
convolution
{},
l0
,
l1
);
auto
l4
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l3
,
l2
);
p
.
add_instruction
(
migraph
::
add
{},
l3
,
l4
);
auto
prog
=
migraph
::
parse_onnx
(
"conv.onnx"
);
EXPECT
(
p
==
prog
);
}
void
pytorch_conv_relu_maxpool
()
{
migraph
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
5
,
5
}});
auto
l2
=
p
.
add_parameter
(
"2"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
uint64_t
axis
=
1
;
auto
l3
=
p
.
add_instruction
(
migraph
::
convolution
{},
l0
,
l1
);
auto
l4
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l3
,
l2
);
auto
l5
=
p
.
add_instruction
(
migraph
::
add
{},
l3
,
l4
);
auto
l6
=
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
l5
);
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l6
);
auto
prog
=
migraph
::
parse_onnx
(
"conv_relu_maxpool.onnx"
);
EXPECT
(
p
==
prog
);
}
void
pytorch_conv_relu_maxpoolX2
()
{
migraph
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
{
migraph
::
shape
::
float_type
,
{
5
,
3
,
5
,
5
}});
auto
l2
=
p
.
add_parameter
(
"2"
,
{
migraph
::
shape
::
float_type
,
{
5
}});
uint64_t
axis
=
1
;
auto
l3
=
p
.
add_instruction
(
migraph
::
convolution
{},
l0
,
l1
);
auto
l4
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l3
,
l2
);
auto
l5
=
p
.
add_instruction
(
migraph
::
add
{},
l3
,
l4
);
auto
l6
=
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
l5
);
auto
l7
=
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l6
);
auto
l8
=
p
.
add_parameter
(
"3"
,
{
migraph
::
shape
::
float_type
,
{
1
,
5
,
5
,
5
}});
auto
l9
=
p
.
add_parameter
(
"4"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
auto
l10
=
p
.
add_instruction
(
migraph
::
convolution
{},
l7
,
l8
);
auto
l11
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l10
,
l9
);
auto
l12
=
p
.
add_instruction
(
migraph
::
add
{},
l10
,
l11
);
auto
l13
=
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
l12
);
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l13
);
auto
prog
=
migraph
::
parse_onnx
(
"conv_relu_maxpoolX2.onnx"
);
EXPECT
(
p
==
prog
);
}
int
main
()
{
pytorch_conv_bias_test
();
pytorch_conv_relu_maxpool
();
pytorch_conv_relu_maxpoolX2
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment