Commit 7e88e866 authored by Scott Thornton's avatar Scott Thornton
Browse files

Added tests for onnx parsing

parent b52e7149
......@@ -318,70 +318,26 @@ struct onnx_parser
if(t.has_raw_data())
{
std::string s = t.raw_data();
if(t.data_type() == onnx::TensorProto::FLOAT)
{
return literal{{shape::float_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT8)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::INT8)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT16)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT16)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT32)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::INT64)
{
return literal{{shape::int64_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::STRING)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::BOOL)
{
return literal{{shape::int32_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::FLOAT16)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::DOUBLE)
{
return literal{{shape::double_type, dims}, s.data()};
}
else if(t.data_type() == onnx::TensorProto::UINT32)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::UINT64)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::COMPLEX64)
{
throw std::runtime_error("");
}
else if(t.data_type() == onnx::TensorProto::COMPLEX128)
{
throw std::runtime_error("");
}
else
{
MIGRAPH_THROW("Invalid tensor type");
}
switch(t.data_type())
{
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: return literal{{shape::float_type, dims}, s.data()};
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::UINT16: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::INT16: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::INT32: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::INT64: return literal{{shape::int64_type, dims}, s.data()};
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: return literal{{shape::int32_type, dims}, s.data()};
case onnx::TensorProto::FLOAT16: throw std::runtime_error("");
case onnx::TensorProto::DOUBLE: return literal{{shape::double_type, dims}, s.data()};
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
}
MIGRAPH_THROW("Invalid tensor type");
}
switch(t.data_type())
{
......
......@@ -84,7 +84,7 @@ function(add_test_executable TEST_NAME)
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
set_tests_properties(${TEST_NAME} PROPERTIES FAIL_REGULAR_EXPRESSION "FAILED")
target_link_libraries(${TEST_NAME} migraph migraph_cpu)
target_link_libraries(${TEST_NAME} migraph migraph_cpu migraph_onnx)
target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable)
......@@ -105,3 +105,10 @@ if(MIGRAPH_ENABLE_GPU)
target_link_libraries(test_gpu_${BASE_NAME} migraph_gpu)
endforeach()
endif()
add_executable(test_onnx onnx/onnx_test.cpp)
target_link_libraries(test_onnx migraph_onnx)
target_include_directories(test_onnx PUBLIC include)
add_test(NAME test_onnx COMMAND $<TARGET_FILE:test_onnx> WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/onnx)
add_dependencies(tests test_onnx)
add_dependencies(check test_onnx)
#include <iostream>
#include <vector>
#include <migraph/literal.hpp>
#include <migraph/operators.hpp>
#include <migraph/program.hpp>
#include <migraph/onnx.hpp>
#include "test.hpp"
#include "verify.hpp"
void pytorch_conv_bias_test()
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2);
p.add_instruction(migraph::add{}, l3, l4);
auto prog = migraph::parse_onnx("conv.onnx");
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpool()
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {1, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {1}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::activation{"relu"}, l5);
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto prog = migraph::parse_onnx("conv_relu_maxpool.onnx");
EXPECT(p == prog);
}
void pytorch_conv_relu_maxpoolX2()
{
migraph::program p;
auto l0 = p.add_parameter("0", {migraph::shape::float_type, {1, 3, 32, 32}});
auto l1 = p.add_parameter("1", {migraph::shape::float_type, {5, 3, 5, 5}});
auto l2 = p.add_parameter("2", {migraph::shape::float_type, {5}});
uint64_t axis = 1;
auto l3 = p.add_instruction(migraph::convolution{}, l0, l1);
auto l4 = p.add_instruction(migraph::broadcast{axis}, l3, l2);
auto l5 = p.add_instruction(migraph::add{}, l3, l4);
auto l6 = p.add_instruction(migraph::activation{"relu"}, l5);
auto l7 = p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l6);
auto l8 = p.add_parameter("3", {migraph::shape::float_type, {1, 5, 5, 5}});
auto l9 = p.add_parameter("4", {migraph::shape::float_type, {1}});
auto l10 = p.add_instruction(migraph::convolution{}, l7, l8);
auto l11 = p.add_instruction(migraph::broadcast{axis}, l10, l9);
auto l12 = p.add_instruction(migraph::add{}, l10, l11);
auto l13 = p.add_instruction(migraph::activation{"relu"}, l12);
p.add_instruction(migraph::pooling{"max", {{0, 0}}, {{2, 2}}, {{2, 2}}}, l13);
auto prog = migraph::parse_onnx("conv_relu_maxpoolX2.onnx");
EXPECT(p == prog);
}
int main()
{
pytorch_conv_bias_test();
pytorch_conv_relu_maxpool();
pytorch_conv_relu_maxpoolX2();
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment