Commit 8c2d316e authored by Scott Thornton

Able to read raw_data from ONNX (at least in the case of Reshape)

parent 71c777bd
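Context for the change: exporters typically feed Reshape its target shape through a one-dimensional INT64 initializer, and most serializers pack those values into the TensorProto's raw_data bytes field rather than the typed int64_data field, so parse_tensor has to understand raw_data before Reshape can be imported. The snippet below is an illustrative sketch only, not part of this commit: make_shape_tensor is a hypothetical helper that builds such an initializer with the ordinary protobuf-generated setters and hands it to the parser added in the diff.

// Illustrative sketch (not part of this commit): build an INT64 shape initializer
// the way exporters commonly serialize it, with the values packed into raw_data.
// make_shape_tensor is a hypothetical name; the generated ONNX protobuf header
// (onnx.pb.h) is assumed to be included already.
#include <cstdint>
#include <string>
#include <vector>

onnx::TensorProto make_shape_tensor(const std::vector<int64_t>& shape)
{
    onnx::TensorProto t;
    t.set_data_type(onnx::TensorProto::INT64);
    t.add_dims(static_cast<int64_t>(shape.size())); // a 1-D tensor of length rank
    // Pack the int64 values as raw bytes, as most exporters do
    t.set_raw_data(std::string(reinterpret_cast<const char*>(shape.data()),
                               shape.size() * sizeof(int64_t)));
    return t;
}

// Usage, assuming parse_tensor is reachable as a public static member of onnx_parser:
//   auto l = onnx_parser::parse_tensor(make_shape_tensor({1, 3, 224, 224}));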
@@ -6,6 +6,7 @@
#include <unordered_map>
#include <functional>
#include <array>
#include <vector>
#include <numeric>  // std::accumulate used in parse_tensor below
#include <cstring>  // memcpy
#include <migraph/fallthrough.hpp>
#include <migraph/program.hpp>
@@ -314,6 +315,98 @@ struct onnx_parser
static literal parse_tensor(const onnx::TensorProto& t)
{
std::vector<std::size_t> dims(t.dims().begin(), t.dims().end());
        if(t.has_raw_data())
        {
            // raw_data packs the elements as raw bytes in row-major order
            const std::string& s = t.raw_data();
if(t.data_type() == onnx::TensorProto::FLOAT)
{
std::vector<float> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::float_type, dims}, raw};
}
            else if(t.data_type() == onnx::TensorProto::UINT8)
            {
                throw std::runtime_error("Unsupported ONNX tensor type: uint8");
            }
            else if(t.data_type() == onnx::TensorProto::INT8)
            {
                // raw_data stores one byte per element; widen each value to int32_t
                // instead of memcpy'ing the bytes straight into the wider buffer
                std::vector<int8_t> bytes(
                    std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
                memcpy(bytes.data(), s.data(), s.length());
                std::vector<int32_t> raw(bytes.begin(), bytes.end());
                return literal{{shape::int32_type, dims}, raw};
            }
            else if(t.data_type() == onnx::TensorProto::UINT16)
            {
                // raw_data stores two bytes per element; widen each value to int32_t
                std::vector<uint16_t> vals(
                    std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
                memcpy(vals.data(), s.data(), s.length());
                std::vector<int32_t> raw(vals.begin(), vals.end());
                return literal{{shape::int32_type, dims}, raw};
            }
            else if(t.data_type() == onnx::TensorProto::INT16)
            {
                // raw_data stores two bytes per element; widen each value to int32_t
                std::vector<int16_t> vals(
                    std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
                memcpy(vals.data(), s.data(), s.length());
                std::vector<int32_t> raw(vals.begin(), vals.end());
                return literal{{shape::int32_type, dims}, raw};
            }
else if(t.data_type() == onnx::TensorProto::INT32)
{
std::vector<int32_t> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::int32_type, dims}, raw};
}
else if(t.data_type() == onnx::TensorProto::INT64)
{
std::vector<int64_t> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::int64_type, dims}, raw};
}
            else if(t.data_type() == onnx::TensorProto::STRING)
            {
                throw std::runtime_error("Unsupported ONNX tensor type: string");
            }
            else if(t.data_type() == onnx::TensorProto::BOOL)
            {
                // bools are stored as one byte per element in raw_data; widen to int32_t
                std::vector<int8_t> bytes(
                    std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
                memcpy(bytes.data(), s.data(), s.length());
                std::vector<int32_t> raw(bytes.begin(), bytes.end());
                return literal{{shape::int32_type, dims}, raw};
            }
            else if(t.data_type() == onnx::TensorProto::FLOAT16)
            {
                throw std::runtime_error("Unsupported ONNX tensor type: float16");
            }
else if(t.data_type() == onnx::TensorProto::DOUBLE)
{
std::vector<double> raw(
std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<size_t>()));
memcpy(raw.data(), s.data(), s.length());
return literal{{shape::double_type, dims}, raw};
}
            else if(t.data_type() == onnx::TensorProto::UINT32)
            {
                throw std::runtime_error("Unsupported ONNX tensor type: uint32");
            }
            else if(t.data_type() == onnx::TensorProto::UINT64)
            {
                throw std::runtime_error("Unsupported ONNX tensor type: uint64");
            }
            else if(t.data_type() == onnx::TensorProto::COMPLEX64)
            {
                throw std::runtime_error("Unsupported ONNX tensor type: complex64");
            }
            else if(t.data_type() == onnx::TensorProto::COMPLEX128)
            {
                throw std::runtime_error("Unsupported ONNX tensor type: complex128");
            }
else
{
MIGRAPH_THROW("Invalid tensor type");
}
}
switch(t.data_type())
{
        case onnx::TensorProto::UNDEFINED: throw std::runtime_error("Unhandled tensor data type: undefined");
......
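One note on the structure of the new branch: every supported type repeats the same accumulate/memcpy/literal pattern, differing only in the on-disk element type and the migraph type it is stored as. A small helper along the lines of the sketch below could collapse that repetition; raw_literal and its signature are illustrative only, and literal / shape::type_t are assumed from the migraph headers included above.

// Illustrative sketch only, not part of the commit: copy raw_data into a typed
// buffer and, where needed, widen the elements (e.g. int8 -> int32) before
// building the literal. raw_literal is a hypothetical name.
#include <algorithm>
#include <cstring>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

template <class OnDisk, class Stored = OnDisk>
literal raw_literal(shape::type_t stored_type,
                    const std::vector<std::size_t>& dims,
                    const std::string& bytes)
{
    std::size_t n = std::accumulate(
        dims.begin(), dims.end(), std::size_t{1}, std::multiplies<std::size_t>());
    std::vector<OnDisk> on_disk(n);
    // Never read past the end of either buffer, even for malformed models
    std::memcpy(on_disk.data(), bytes.data(), std::min(bytes.size(), n * sizeof(OnDisk)));
    std::vector<Stored> stored(on_disk.begin(), on_disk.end()); // element-wise widening
    return literal{{stored_type, dims}, stored};
}

// The INT8 branch would then reduce to:
//   return raw_literal<int8_t, int32_t>(shape::int32_type, dims, s);
// and the FLOAT branch to:
//   return raw_literal<float>(shape::float_type, dims, s);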