Commit 74343e23 authored by Khalique
Browse files

adjusted create_literal for onnx to use container/range

parent b3743835
......@@ -1393,27 +1393,27 @@ struct onnx_parser
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT:
return create_literal(
shape::float_type, dims, t.float_data().begin(), t.float_data().end());
shape::float_type, dims, t.float_data());
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8:
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::UINT16:
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT16:
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT32:
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT64:
return create_literal(
shape::int64_type, dims, t.int64_data().begin(), t.int64_data().end());
shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::STRING: throw std::runtime_error("");
case onnx::TensorProto::BOOL:
return create_literal(
shape::int32_type, dims, t.int32_data().begin(), t.int32_data().end());
shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::FLOAT16:
{
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
......@@ -1422,11 +1422,11 @@ struct onnx_parser
data_uint16.end(),
std::back_inserter(data_half),
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half.begin(), data_half.end());
return create_literal(shape::half_type, dims, data_half);
}
case onnx::TensorProto::DOUBLE:
return create_literal(
shape::double_type, dims, t.double_data().begin(), t.double_data().end());
shape::double_type, dims, t.double_data());
case onnx::TensorProto::UINT32: throw std::runtime_error("");
case onnx::TensorProto::UINT64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
......@@ -1436,7 +1436,7 @@ struct onnx_parser
}
static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, const char* data)
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const char* data)
{
// in case of scalar constants in onnx file, use dims=1 to fill initializer data
if(dims.empty())
......@@ -1444,13 +1444,13 @@ struct onnx_parser
return literal{{shape_type, dims}, data};
}
// Build a literal from a container of values (e.g. a protobuf repeated
// field such as t.float_data()). Constrained to non-pointer types so the
// `const char*` overload handles raw byte buffers instead.
// The container is taken by const reference to avoid copying it.
template <class T, MIGRAPHX_REQUIRES(not std::is_pointer<T>{})>
static literal
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, const T& data)
{
    // in case of scalar constants in onnx file, use dims=1 to fill initializer data
    if(dims.empty())
        return literal{{shape_type, {1}, {0}}, data.begin(), data.end()};
    return literal{{shape_type, dims}, data.begin(), data.end()};
}
static shape parse_type(const onnx::TypeProto& t)
......
......@@ -843,7 +843,7 @@ struct tf_parser
template <class T>
static literal
create_literal(shape::type_t shape_type, std::vector<size_t> dims, std::vector<T> data)
create_literal(shape::type_t shape_type, const std::vector<size_t>& dims, std::vector<T> data)
{
// assume if explicit value is mentioned in protobuf and dim size <= 1, treat as scalar
if(dims.empty() or (dims.size() == 1 and dims.front() == 1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment