Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7e88e866
Commit
7e88e866
authored
Jul 27, 2018
by
Scott Thornton
Browse files
Added tests for onnx parsing
parent
b52e7149
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
101 additions
and
65 deletions
+101
-65
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+20
-64
test/CMakeLists.txt
test/CMakeLists.txt
+8
-1
test/onnx/conv.onnx
test/onnx/conv.onnx
+0
-0
test/onnx/conv_bn_relu_maxpool.onnx
test/onnx/conv_bn_relu_maxpool.onnx
+0
-0
test/onnx/conv_relu_maxpool.onnx
test/onnx/conv_relu_maxpool.onnx
+0
-0
test/onnx/conv_relu_maxpoolX2.onnx
test/onnx/conv_relu_maxpoolX2.onnx
+0
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+73
-0
No files found.
src/onnx/onnx.cpp
View file @
7e88e866
...
@@ -318,70 +318,26 @@ struct onnx_parser
...
@@ -318,70 +318,26 @@ struct onnx_parser
if
(
t
.
has_raw_data
())
if
(
t
.
has_raw_data
())
{
{
std
::
string
s
=
t
.
raw_data
();
std
::
string
s
=
t
.
raw_data
();
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT
)
switch
(
t
.
data_type
())
{
{
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
}
case
onnx
::
TensorProto
::
FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT8
)
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
{
case
onnx
::
TensorProto
::
INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
case
onnx
::
TensorProto
::
INT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT8
)
case
onnx
::
TensorProto
::
INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
{
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
}
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT16
)
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
{
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
}
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT16
)
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
{
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
}
MIGRAPH_THROW
(
"Invalid tensor type"
);
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT32
)
{
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
INT64
)
{
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
STRING
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
BOOL
)
{
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
FLOAT16
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
DOUBLE
)
{
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT32
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
UINT64
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
COMPLEX64
)
{
throw
std
::
runtime_error
(
""
);
}
else
if
(
t
.
data_type
()
==
onnx
::
TensorProto
::
COMPLEX128
)
{
throw
std
::
runtime_error
(
""
);
}
else
{
MIGRAPH_THROW
(
"Invalid tensor type"
);
}
}
}
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
...
...
test/CMakeLists.txt
View file @
7e88e866
...
@@ -84,7 +84,7 @@ function(add_test_executable TEST_NAME)
...
@@ -84,7 +84,7 @@ function(add_test_executable TEST_NAME)
add_dependencies
(
tests
${
TEST_NAME
}
)
add_dependencies
(
tests
${
TEST_NAME
}
)
add_dependencies
(
check
${
TEST_NAME
}
)
add_dependencies
(
check
${
TEST_NAME
}
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES FAIL_REGULAR_EXPRESSION
"FAILED"
)
set_tests_properties
(
${
TEST_NAME
}
PROPERTIES FAIL_REGULAR_EXPRESSION
"FAILED"
)
target_link_libraries
(
${
TEST_NAME
}
migraph migraph_cpu
)
target_link_libraries
(
${
TEST_NAME
}
migraph migraph_cpu
migraph_onnx
)
target_include_directories
(
${
TEST_NAME
}
PUBLIC include
)
target_include_directories
(
${
TEST_NAME
}
PUBLIC include
)
endfunction
(
add_test_executable
)
endfunction
(
add_test_executable
)
...
@@ -105,3 +105,10 @@ if(MIGRAPH_ENABLE_GPU)
...
@@ -105,3 +105,10 @@ if(MIGRAPH_ENABLE_GPU)
target_link_libraries
(
test_gpu_
${
BASE_NAME
}
migraph_gpu
)
target_link_libraries
(
test_gpu_
${
BASE_NAME
}
migraph_gpu
)
endforeach
()
endforeach
()
endif
()
endif
()
add_executable
(
test_onnx onnx/onnx_test.cpp
)
target_link_libraries
(
test_onnx migraph_onnx
)
target_include_directories
(
test_onnx PUBLIC include
)
add_test
(
NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
/onnx
)
add_dependencies
(
tests test_onnx
)
add_dependencies
(
check test_onnx
)
test/onnx/conv.onnx
0 → 100644
View file @
7e88e866
File added
test/onnx/conv_bn_relu_maxpool.onnx
0 → 100644
View file @
7e88e866
File added
test/onnx/conv_relu_maxpool.onnx
0 → 100644
View file @
7e88e866
File added
test/onnx/conv_relu_maxpoolX2.onnx
0 → 100644
View file @
7e88e866
File added
test/onnx/onnx_test.cpp
0 → 100644
View file @
7e88e866
#include <iostream>
#include <vector>
#include <migraph/literal.hpp>
#include <migraph/operators.hpp>
#include <migraph/program.hpp>
#include <migraph/onnx.hpp>
#include "test.hpp"
#include "verify.hpp"
void
pytorch_conv_bias_test
()
{
migraph
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
5
,
5
}});
auto
l2
=
p
.
add_parameter
(
"2"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
uint64_t
axis
=
1
;
auto
l3
=
p
.
add_instruction
(
migraph
::
convolution
{},
l0
,
l1
);
auto
l4
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l3
,
l2
);
p
.
add_instruction
(
migraph
::
add
{},
l3
,
l4
);
auto
prog
=
migraph
::
parse_onnx
(
"conv.onnx"
);
EXPECT
(
p
==
prog
);
}
void
pytorch_conv_relu_maxpool
()
{
migraph
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
5
,
5
}});
auto
l2
=
p
.
add_parameter
(
"2"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
uint64_t
axis
=
1
;
auto
l3
=
p
.
add_instruction
(
migraph
::
convolution
{},
l0
,
l1
);
auto
l4
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l3
,
l2
);
auto
l5
=
p
.
add_instruction
(
migraph
::
add
{},
l3
,
l4
);
auto
l6
=
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
l5
);
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l6
);
auto
prog
=
migraph
::
parse_onnx
(
"conv_relu_maxpool.onnx"
);
EXPECT
(
p
==
prog
);
}
void
pytorch_conv_relu_maxpoolX2
()
{
migraph
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
32
,
32
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
{
migraph
::
shape
::
float_type
,
{
5
,
3
,
5
,
5
}});
auto
l2
=
p
.
add_parameter
(
"2"
,
{
migraph
::
shape
::
float_type
,
{
5
}});
uint64_t
axis
=
1
;
auto
l3
=
p
.
add_instruction
(
migraph
::
convolution
{},
l0
,
l1
);
auto
l4
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l3
,
l2
);
auto
l5
=
p
.
add_instruction
(
migraph
::
add
{},
l3
,
l4
);
auto
l6
=
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
l5
);
auto
l7
=
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l6
);
auto
l8
=
p
.
add_parameter
(
"3"
,
{
migraph
::
shape
::
float_type
,
{
1
,
5
,
5
,
5
}});
auto
l9
=
p
.
add_parameter
(
"4"
,
{
migraph
::
shape
::
float_type
,
{
1
}});
auto
l10
=
p
.
add_instruction
(
migraph
::
convolution
{},
l7
,
l8
);
auto
l11
=
p
.
add_instruction
(
migraph
::
broadcast
{
axis
},
l10
,
l9
);
auto
l12
=
p
.
add_instruction
(
migraph
::
add
{},
l10
,
l11
);
auto
l13
=
p
.
add_instruction
(
migraph
::
activation
{
"relu"
},
l12
);
p
.
add_instruction
(
migraph
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
2
,
2
}}},
l13
);
auto
prog
=
migraph
::
parse_onnx
(
"conv_relu_maxpoolX2.onnx"
);
EXPECT
(
p
==
prog
);
}
int
main
()
{
pytorch_conv_bias_test
();
pytorch_conv_relu_maxpool
();
pytorch_conv_relu_maxpoolX2
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment