Commit 7e88e866 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added tests for onnx parsing

parent b52e7149
...@@ -318,71 +318,27 @@ struct onnx_parser ...@@ -318,71 +318,27 @@ struct onnx_parser
if(t.has_raw_data()) if(t.has_raw_data())
{ {
std::string s = t.raw_data(); std::string s = t.raw_data();
if(t.data_type() == onnx::TensorProto::FLOAT) switch(t.data_type())
{
return literal{{shape::float_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT8)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::INT8)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT16)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT16)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT32)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT64)
{
return literal{{shape::int64_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::STRING)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::BOOL)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::FLOAT16)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::DOUBLE)
{
return literal{{shape::double_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT32)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::UINT64)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::COMPLEX64)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::COMPLEX128)
{ {
throw std::runtime_error(""); case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: return literal{{shape::float_type, dims}, s.data()};
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::UINT16: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::INT16: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::INT32: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
} }
else
{
MIGRAPH_THROW("Invalid tensor type"); MIGRAPH_THROW("Invalid tensor type");
} }
}
switch(t.data_type()) switch(t.data_type())
{ {
case onnx::TensorProto::UNDEFINED: throw std::runtime_error(""); case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
......
...@@ -84,7 +84,7 @@ function(add_test_executable TEST_NAME) ...@@ -84,7 +84,7 @@ function(add_test_executable TEST_NAME)
add_dependencies(tests ${TEST_NAME}) add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME}) add_dependencies(check ${TEST_NAME})
set_tests_properties(${TEST_NAME} PROPERTIES FAIL_REGULAR_EXPRESSION "FAILED") set_tests_properties(${TEST_NAME} PROPERTIES FAIL_REGULAR_EXPRESSION "FAILED")
target_link_libraries(${TEST_NAME} migraph migraph_cpu) target_link_libraries(${TEST_NAME} migraph migraph_cpu migraph_onnx)
target_include_directories(${TEST_NAME} PUBLIC include) target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable) endfunction(add_test_executable)
...@@ -105,3 +105,10 @@ if(MIGRAPH_ENABLE_GPU) ...@@ -105,3 +105,10 @@ if(MIGRAPH_ENABLE_GPU)
target_link_libraries(test_gpu_${BASE_NAME} migraph_gpu) target_link_libraries(test_gpu_${BASE_NAME} migraph_gpu)
endforeach() endforeach()
endif() endif()
# Build the ONNX-parser test binary and register it with CTest.
add_executable(test_onnx onnx/onnx_test.cpp)
target_link_libraries(test_onnx migraph_onnx)
target_include_directories(test_onnx PUBLIC include)
# Run from the onnx/ source directory so the test can open its .onnx
# fixture files (conv.onnx, conv_relu_maxpool.onnx, ...) via relative paths.
add_test(NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
# Make the aggregate `tests` and `check` targets build this test too.
add_dependencies(tests test_onnx)
add_dependencies(check test_onnx)
#include <iostream>
#include <vector>
#include <migraph/literal.hpp>
#include <migraph/operators.hpp>
#include <migraph/program.hpp>
#include <migraph/onnx.hpp>
#include "test.hpp"
#include "verify.hpp"
void pytorch_conv_bias_test()
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2);
p.add_instruction(migraph::add{}, l3, l4);
auto prog = migraph::parse_onnx("conv.onnx");
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpool()
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::activation{"relu"}, l5);
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = migraph::parse_onnx("conv_relu_maxpool.onnx");
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpoolX2()
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {5, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::activation{"relu"}, l5);
auto l7 = p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}});
auto l10 = p.add_instruction(migraph::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraph::broadcast{axis}, l10, l9);
auto l12 = p.add_instruction(migraph::add{}, l10, l11);
auto l13 = p.add_instruction(migraph::activation{"relu"}, l12);
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = migraph::parse_onnx("conv_relu_maxpoolX2.onnx");
EXPECT(p == prog);
}
// Entry point: run every ONNX-parsing test case in sequence.
// EXPECT prints "FAILED" on mismatch, which CTest's FAIL_REGULAR_EXPRESSION catches.
int main()
{
    pytorch_conv_bias_test();
    pytorch_conv_relu_maxpool();
    pytorch_conv_relu_maxpoolX2();
    return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment