Commit 74343e23 authored by Khalique's avatar Khalique
Browse files

adjusted create_literal for onnx to use container/range

parent b3743835
...@@ -1393,27 +1393,27 @@ struct onnx_parser ...@@ -1393,27 +1393,27 @@ struct onnx_parser
case onnx::TensorProto::UNDEFINED: throw std::runtime_error(""); case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: case onnx::TensorProto::FLOAT:
return create_literal( return create_literal(
shape::float_type, dims, t.float_data().begin(), t.float_data().end()); shape::float_type, dims, t.float_data());
case onnx::TensorProto::UINT8: throw std::runtime_error(""); case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: case onnx::TensorProto::INT8:
return create_literal( return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::UINT16: case onnx::TensorProto::UINT16:
return create_literal( return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT16: case onnx::TensorProto::INT16:
return create_literal( return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT32: case onnx::TensorProto::INT32:
return create_literal( return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT64: case onnx::TensorProto::INT64:
return create_literal( return create_literal(
shape::int64_type, dims, t.int64_data().begin(), t.int64_data().end()); shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
return create_literal( return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end()); shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16:
{ {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end()); std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
...@@ -1422,11 +1422,11 @@ struct onnx_parser ...@@ -1422,11 +1422,11 @@ struct onnx_parser
data_uint16.end(), data_uint16.end(),
std::back_inserter(data_half), std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); }); [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half.begin(), data_half.end()); return create_literal(shape::half_type, dims, data_half);
} }
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::DOUBLE:
return create_literal( return create_literal(
shape::double_type, dims, t.double_data().begin(), t.double_data().end()); shape::double_type, dims, t.double_data());
case onnx::TensorProto::UINT32: throw std::runtime_error(""); case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error(""); case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
...@@ -1436,7 +1436,7 @@ struct onnx_parser ...@@ -1436,7 +1436,7 @@ struct onnx_parser
} }
static literal static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, const char* data) create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
{ {
// in case of scalar constants in onnx file, use dims=1 to fill initializer data // in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty()) if(dims.empty())
...@@ -1444,13 +1444,13 @@ struct onnx_parser ...@@ -1444,13 +1444,13 @@ struct onnx_parser
return literal{{shape_type, dims}, data}; return literal{{shape_type, dims}, data};
} }
template <class Iterator> template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
static literal static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, Iterator start, Iterator end) create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, T data)
{ {
if(dims.empty()) if(dims.empty())
return literal{{shape_type, {1}, {0}}, start, end}; return literal{{shape_type, {1}, {0}}, data.begin(), data.end()};
return literal{{shape_type, dims}, start, end}; return literal{{shape_type, dims}, data.begin(), data.end()};
} }
static shape parse_type(const onnx::TypeProto& t) static shape parse_type(const onnx::TypeProto& t)
......
...@@ -843,7 +843,7 @@ struct tf_parser ...@@ -843,7 +843,7 @@ struct tf_parser
template <class T> template <class T>
static literal static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, std::vector<T> data) create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
{ {
// assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar // assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
if(dims.empty() or (dims.size() == 1 and dims.front() == 1)) if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment