Commit 344cf591 authored by Paul's avatar Paul
Browse files

Formatting

parent dabb2049
...@@ -1493,8 +1493,10 @@ struct onnx_parser ...@@ -1493,8 +1493,10 @@ struct onnx_parser
switch(t.data_type()) switch(t.data_type())
{ {
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data()); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
case onnx::TensorProto::FLOAT16: return create_literal(shape::half_type, dims, s.data()); case onnx::TensorProto::FLOAT16:
case onnx::TensorProto::DOUBLE: return create_literal(shape::double_type, dims, s.data()); return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, s.data());
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data()); case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::INT8: case onnx::TensorProto::INT8:
case onnx::TensorProto::UINT16: case onnx::TensorProto::UINT16:
...@@ -1517,10 +1519,14 @@ struct onnx_parser ...@@ -1517,10 +1519,14 @@ struct onnx_parser
case onnx::TensorProto::UINT16: case onnx::TensorProto::UINT16:
case onnx::TensorProto::INT16: case onnx::TensorProto::INT16:
case onnx::TensorProto::INT32: case onnx::TensorProto::INT32:
case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, t.int32_data()); case onnx::TensorProto::BOOL:
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data()); return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::DOUBLE: return create_literal(shape::double_type, dims, t.double_data()); case onnx::TensorProto::INT64:
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data()); return create_literal(shape::int64_type, dims, t.int64_data());
case onnx::TensorProto::DOUBLE:
return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::FLOAT:
return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16:
{ {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end()); std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
...@@ -1579,7 +1585,8 @@ struct onnx_parser ...@@ -1579,7 +1585,8 @@ struct onnx_parser
case onnx::TensorProto::BOOL: case onnx::TensorProto::BOOL:
case onnx::TensorProto::UNDEFINED: case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::COMPLEX64: case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: break; // throw std::runtime_error("Unsupported type"); case onnx::TensorProto::COMPLEX128:
break; // throw std::runtime_error("Unsupported type");
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
......
...@@ -840,9 +840,7 @@ struct tf_parser ...@@ -840,9 +840,7 @@ struct tf_parser
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break; case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break; case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break; case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64: case tensorflow::DataType::DT_UINT64: shape_type = shape::uint64_type; break;
shape_type = shape::uint64_type;
break;
case tensorflow::DataType::DT_INVALID: case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8: case tensorflow::DataType::DT_UINT8:
...@@ -882,7 +880,7 @@ struct tf_parser ...@@ -882,7 +880,7 @@ struct tf_parser
case tensorflow::DataType::DT_VARIANT_REF: case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF: case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF: case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:\ case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break; case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
} }
return shape_type; return shape_type;
...@@ -1025,7 +1023,8 @@ struct tf_parser ...@@ -1025,7 +1023,8 @@ struct tf_parser
case tensorflow::DataType::DT_UINT32_REF: case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF: case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: throw std::runtime_error(""); case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
} }
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment