Commit dabb2049 authored by Paul's avatar Paul
Browse files

Fix duplicate branches

parent ab35b581
...@@ -81,6 +81,7 @@ rocm_enable_clang_tidy( ...@@ -81,6 +81,7 @@ rocm_enable_clang_tidy(
-modernize-use-override -modernize-use-override
-modernize-pass-by-value -modernize-pass-by-value
-modernize-use-default-member-init -modernize-use-default-member-init
-modernize-use-trailing-return-type
-modernize-use-transparent-functors -modernize-use-transparent-functors
-performance-type-promotion-in-math-fn -performance-type-promotion-in-math-fn
-readability-braces-around-statements -readability-braces-around-statements
......
...@@ -1469,16 +1469,16 @@ struct onnx_parser ...@@ -1469,16 +1469,16 @@ struct onnx_parser
{ {
switch(attr.type()) switch(attr.type())
{ {
case onnx::AttributeProto::UNDEFINED: return {};
case onnx::AttributeProto::FLOAT: return literal{attr.f()}; case onnx::AttributeProto::FLOAT: return literal{attr.f()};
case onnx::AttributeProto::INT: return literal{attr.i()}; case onnx::AttributeProto::INT: return literal{attr.i()};
case onnx::AttributeProto::STRING: return {};
case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t()); case onnx::AttributeProto::TENSOR: return parse_tensor(attr.t());
case onnx::AttributeProto::GRAPH: return {};
case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats()); case onnx::AttributeProto::FLOATS: return from_repeated(shape::float_type, attr.floats());
case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints()); case onnx::AttributeProto::INTS: return from_repeated(shape::int64_type, attr.ints());
case onnx::AttributeProto::STRINGS: return {}; case onnx::AttributeProto::UNDEFINED:
case onnx::AttributeProto::TENSORS: return {}; case onnx::AttributeProto::GRAPH:
case onnx::AttributeProto::STRING:
case onnx::AttributeProto::STRINGS:
case onnx::AttributeProto::TENSORS:
case onnx::AttributeProto::GRAPHS: return {}; case onnx::AttributeProto::GRAPHS: return {};
} }
MIGRAPHX_THROW("Invalid attribute type"); MIGRAPHX_THROW("Invalid attribute type");
...@@ -1492,47 +1492,35 @@ struct onnx_parser ...@@ -1492,47 +1492,35 @@ struct onnx_parser
const std::string& s = t.raw_data(); const std::string& s = t.raw_data();
switch(t.data_type()) switch(t.data_type())
{ {
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data()); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, s.data());
case onnx::TensorProto::UINT8: throw std::runtime_error(""); case onnx::TensorProto::FLOAT16: return create_literal(shape::half_type, dims, s.data());
case onnx::TensorProto::INT8: return create_literal(shape::int32_type, dims, s.data()); case onnx::TensorProto::DOUBLE: return create_literal(shape::double_type, dims, s.data());
case onnx::TensorProto::UINT16:
return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT16: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT32: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data()); case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, s.data());
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::INT8:
case onnx::TensorProto::UINT16:
case onnx::TensorProto::INT16:
case onnx::TensorProto::INT32:
case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data()); case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, s.data());
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::UINT8:
return create_literal(shape::half_type, dims, s.data()); case onnx::TensorProto::STRING:
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::UNDEFINED:
return create_literal(shape::double_type, dims, s.data()); case onnx::TensorProto::UINT32:
case onnx::TensorProto::UINT32: throw std::runtime_error(""); case onnx::TensorProto::UINT64:
case onnx::TensorProto::UINT64: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX64: throw std::runtime_error("");
case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
} }
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
switch(t.data_type()) switch(t.data_type())
{ {
case onnx::TensorProto::UNDEFINED: throw std::runtime_error("");
case onnx::TensorProto::FLOAT:
return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::UINT8: throw std::runtime_error("");
case onnx::TensorProto::INT8: case onnx::TensorProto::INT8:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::UINT16: case onnx::TensorProto::UINT16:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT16: case onnx::TensorProto::INT16:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT32: case onnx::TensorProto::INT32:
return create_literal(shape::int32_type, dims, t.int32_data()); case onnx::TensorProto::BOOL: return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::INT64: case onnx::TensorProto::INT64: return create_literal(shape::int64_type, dims, t.int64_data());
return create_literal(shape::int64_type, dims, t.int64_data()); case onnx::TensorProto::DOUBLE: return create_literal(shape::double_type, dims, t.double_data());
case onnx::TensorProto::STRING: throw std::runtime_error(""); case onnx::TensorProto::FLOAT: return create_literal(shape::float_type, dims, t.float_data());
case onnx::TensorProto::BOOL:
return create_literal(shape::int32_type, dims, t.int32_data());
case onnx::TensorProto::FLOAT16: case onnx::TensorProto::FLOAT16:
{ {
std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end()); std::vector<uint16_t> data_uint16(t.int32_data().begin(), t.int32_data().end());
...@@ -1543,11 +1531,12 @@ struct onnx_parser ...@@ -1543,11 +1531,12 @@ struct onnx_parser
[](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); }); [](uint16_t raw_val) { return *reinterpret_cast<half*>(&raw_val); });
return create_literal(shape::half_type, dims, data_half); return create_literal(shape::half_type, dims, data_half);
} }
case onnx::TensorProto::DOUBLE: case onnx::TensorProto::UNDEFINED:
return create_literal(shape::double_type, dims, t.double_data()); case onnx::TensorProto::UINT8:
case onnx::TensorProto::UINT32: throw std::runtime_error(""); case onnx::TensorProto::STRING:
case onnx::TensorProto::UINT64: throw std::runtime_error(""); case onnx::TensorProto::UINT32:
case onnx::TensorProto::COMPLEX64: throw std::runtime_error(""); case onnx::TensorProto::UINT64:
case onnx::TensorProto::COMPLEX64:
case onnx::TensorProto::COMPLEX128: throw std::runtime_error(""); case onnx::TensorProto::COMPLEX128: throw std::runtime_error("");
} }
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
...@@ -1575,28 +1564,22 @@ struct onnx_parser ...@@ -1575,28 +1564,22 @@ struct onnx_parser
shape::type_t shape_type{}; shape::type_t shape_type{};
switch(t.tensor_type().elem_type()) switch(t.tensor_type().elem_type())
{ {
case onnx::TensorProto::UNDEFINED:
break; // throw std::runtime_error("Unsupported type UNDEFINED");
case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break; case onnx::TensorProto::FLOAT: shape_type = shape::float_type; break;
case onnx::TensorProto::UINT8:
break; // throw std::runtime_error("Unsupported type UINT8");
case onnx::TensorProto::INT8: shape_type = shape::int8_type; break; case onnx::TensorProto::INT8: shape_type = shape::int8_type; break;
case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break; case onnx::TensorProto::UINT16: shape_type = shape::uint16_type; break;
case onnx::TensorProto::INT16: shape_type = shape::int16_type; break; case onnx::TensorProto::INT16: shape_type = shape::int16_type; break;
case onnx::TensorProto::INT32: shape_type = shape::int32_type; break; case onnx::TensorProto::INT32: shape_type = shape::int32_type; break;
case onnx::TensorProto::INT64: shape_type = shape::int64_type; break; case onnx::TensorProto::INT64: shape_type = shape::int64_type; break;
case onnx::TensorProto::STRING:
break; // throw std::runtime_error("Unsupported type STRING");
case onnx::TensorProto::BOOL:
break; // throw std::runtime_error("Unsupported type BOOL");
case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break; case onnx::TensorProto::FLOAT16: shape_type = shape::half_type; break;
case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break; case onnx::TensorProto::DOUBLE: shape_type = shape::double_type; break;
case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break; case onnx::TensorProto::UINT32: shape_type = shape::uint32_type; break;
case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break; case onnx::TensorProto::UINT64: shape_type = shape::uint64_type; break;
case onnx::TensorProto::UINT8:
case onnx::TensorProto::STRING:
case onnx::TensorProto::BOOL:
case onnx::TensorProto::UNDEFINED:
case onnx::TensorProto::COMPLEX64: case onnx::TensorProto::COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64"); case onnx::TensorProto::COMPLEX128: break; // throw std::runtime_error("Unsupported type");
case onnx::TensorProto::COMPLEX128:
break; // throw std::runtime_error("Unsupported type COMPLEX128");
} }
std::vector<std::size_t> dims; std::vector<std::size_t> dims;
auto&& tensor_dims = t.tensor_type().shape().dim(); auto&& tensor_dims = t.tensor_type().shape().dim();
......
...@@ -831,72 +831,58 @@ struct tf_parser ...@@ -831,72 +831,58 @@ struct tf_parser
shape::type_t shape_type{}; shape::type_t shape_type{};
switch(t) switch(t)
{ {
case tensorflow::DataType::DT_INVALID:
break; // throw std::runtime_error("Unsupported type UNDEFINED");
case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break; case tensorflow::DataType::DT_FLOAT: shape_type = shape::float_type; break;
case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break; case tensorflow::DataType::DT_DOUBLE: shape_type = shape::double_type; break;
case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break; case tensorflow::DataType::DT_INT32: shape_type = shape::int32_type; break;
case tensorflow::DataType::DT_UINT8:
break; // throw std::runtime_error("Unsupported type UINT8");
case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break; case tensorflow::DataType::DT_INT16: shape_type = shape::int16_type; break;
case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break; case tensorflow::DataType::DT_INT8: shape_type = shape::int8_type; break;
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64:
shape_type = shape::uint64_type;
break;
case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_STRING: case tensorflow::DataType::DT_STRING:
break; // throw std::runtime_error("Unsupported type STRING");
case tensorflow::DataType::DT_COMPLEX64: case tensorflow::DataType::DT_COMPLEX64:
break; // throw std::runtime_error("Unsupported type COMPLEX64");
case tensorflow::DataType::DT_INT64: shape_type = shape::int64_type; break;
case tensorflow::DataType::DT_BOOL: case tensorflow::DataType::DT_BOOL:
break; // throw std::runtime_error("Unsupported type BOOL");
case tensorflow::DataType::DT_QINT8: case tensorflow::DataType::DT_QINT8:
break; // throw std::runtime_error("Unsupported type QINT8");
case tensorflow::DataType::DT_QUINT8: case tensorflow::DataType::DT_QUINT8:
break; // throw std::runtime_error("Unsupported type QUINT8");
case tensorflow::DataType::DT_QINT32: case tensorflow::DataType::DT_QINT32:
break; // throw std::runtime_error("Unsupported type QINT32");
case tensorflow::DataType::DT_BFLOAT16: case tensorflow::DataType::DT_BFLOAT16:
break; // throw std::runtime_error("Unsupported type BFLOAT16");
case tensorflow::DataType::DT_QINT16: case tensorflow::DataType::DT_QINT16:
break; // throw std::runtime_error("Unsupported type QINT16");
case tensorflow::DataType::DT_QUINT16: case tensorflow::DataType::DT_QUINT16:
break; // throw std::runtime_error("Unsupported type QUINT16");
case tensorflow::DataType::DT_UINT16: shape_type = shape::uint16_type; break;
case tensorflow::DataType::DT_COMPLEX128: case tensorflow::DataType::DT_COMPLEX128:
break; // throw std::runtime_error("Unsupported type COMPLEX128");
case tensorflow::DataType::DT_HALF: shape_type = shape::half_type; break;
case tensorflow::DataType::DT_RESOURCE: case tensorflow::DataType::DT_RESOURCE:
break; // throw std::runtime_error("Unsupported type RESOURCE");
case tensorflow::DataType::DT_VARIANT: case tensorflow::DataType::DT_VARIANT:
break; // throw std::runtime_error("Unsupported type VARIANT");
case tensorflow::DataType::DT_UINT32: shape_type = shape::uint32_type; break;
case tensorflow::DataType::DT_UINT64:
shape_type = shape::uint64_type;
break;
// tf pb should not use these types // tf pb should not use these types
case tensorflow::DataType::DT_FLOAT_REF: break; case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_DOUBLE_REF: break; case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT32_REF: break; case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_UINT8_REF: break; case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_INT16_REF: break; case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT8_REF: break; case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_STRING_REF: break; case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_COMPLEX64_REF: break; case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_INT64_REF: break; case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_BOOL_REF: break; case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_QINT8_REF: break; case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QUINT8_REF: break; case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QINT32_REF: break; case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_BFLOAT16_REF: break; case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_QINT16_REF: break; case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_QUINT16_REF: break; case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_UINT16_REF: break; case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF: break; case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_HALF_REF: break; case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_RESOURCE_REF: break; case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF: break; case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF: break; case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF: break; case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: break; case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:\
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break; case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: break;
} }
return shape_type; return shape_type;
...@@ -911,61 +897,59 @@ struct tf_parser ...@@ -911,61 +897,59 @@ struct tf_parser
const std::string& s = t.tensor_content(); const std::string& s = t.tensor_content();
switch(t.dtype()) switch(t.dtype())
{ {
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: case tensorflow::DataType::DT_FLOAT:
return literal{{shape::float_type, dims}, s.data()}; return literal{{shape::float_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_BOOL:
case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()}; case tensorflow::DataType::DT_INT8: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT16: case tensorflow::DataType::DT_UINT16:
return literal{{shape::uint16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT16: case tensorflow::DataType::DT_INT16:
return literal{{shape::int16_type, dims}, s.data()}; return literal{{shape::int16_type, dims}, s.data()};
case tensorflow::DataType::DT_INT32: case tensorflow::DataType::DT_INT32:
return literal{{shape::int32_type, dims}, s.data()}; return literal{{shape::int32_type, dims}, s.data()};
case tensorflow::DataType::DT_INT64: case tensorflow::DataType::DT_INT64:
return literal{{shape::int64_type, dims}, s.data()}; return literal{{shape::int64_type, dims}, s.data()};
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: return literal{{shape::int8_type, dims}, s.data()};
case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()}; case tensorflow::DataType::DT_HALF: return literal{{shape::half_type, dims}, s.data()};
case tensorflow::DataType::DT_DOUBLE: case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, s.data()}; return literal{{shape::double_type, dims}, s.data()};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error(""); case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT64: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_QINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_QUINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_QINT32: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_BFLOAT16: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QINT16: throw std::runtime_error(""); case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QUINT16: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_RESOURCE: throw std::runtime_error(""); case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_VARIANT: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_FLOAT_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_DOUBLE_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_INT32_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_UINT8_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_INT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT8_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_STRING_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_COMPLEX64_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT64_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_BOOL_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_QINT8_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_QUINT8_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_QINT32_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_BFLOAT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QINT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QUINT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_UINT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_HALF_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_RESOURCE_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_VARIANT_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_UINT32_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_UINT64_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error(""); throw std::runtime_error("");
} }
...@@ -973,11 +957,9 @@ struct tf_parser ...@@ -973,11 +957,9 @@ struct tf_parser
} }
switch(t.dtype()) switch(t.dtype())
{ {
case tensorflow::DataType::DT_INVALID: throw std::runtime_error("");
case tensorflow::DataType::DT_FLOAT: case tensorflow::DataType::DT_FLOAT:
return create_literal( return create_literal(
shape::float_type, dims, get_data_vals(t.float_val(), shape_size)); shape::float_type, dims, get_data_vals(t.float_val(), shape_size));
case tensorflow::DataType::DT_UINT8: throw std::runtime_error("");
case tensorflow::DataType::DT_INT8: case tensorflow::DataType::DT_INT8:
return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size)); return create_literal(shape::int8_type, dims, get_data_vals(t.int_val(), shape_size));
case tensorflow::DataType::DT_UINT16: case tensorflow::DataType::DT_UINT16:
...@@ -989,7 +971,6 @@ struct tf_parser ...@@ -989,7 +971,6 @@ struct tf_parser
case tensorflow::DataType::DT_INT64: case tensorflow::DataType::DT_INT64:
return create_literal( return create_literal(
shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size)); shape::int64_type, dims, get_data_vals(t.int64_val(), shape_size));
case tensorflow::DataType::DT_STRING: throw std::runtime_error("");
case tensorflow::DataType::DT_BOOL: case tensorflow::DataType::DT_BOOL:
return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size)); return create_literal(shape::int32_type, dims, get_data_vals(t.bool_val(), shape_size));
case tensorflow::DataType::DT_HALF: case tensorflow::DataType::DT_HALF:
...@@ -1005,45 +986,46 @@ struct tf_parser ...@@ -1005,45 +986,46 @@ struct tf_parser
} }
case tensorflow::DataType::DT_DOUBLE: case tensorflow::DataType::DT_DOUBLE:
return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)}; return literal{{shape::double_type, dims}, get_data_vals(t.double_val(), shape_size)};
case tensorflow::DataType::DT_UINT32: throw std::runtime_error(""); case tensorflow::DataType::DT_INVALID:
case tensorflow::DataType::DT_UINT64: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8:
case tensorflow::DataType::DT_COMPLEX64: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING:
case tensorflow::DataType::DT_COMPLEX128: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT32:
case tensorflow::DataType::DT_QINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT64:
case tensorflow::DataType::DT_QUINT8: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64:
case tensorflow::DataType::DT_QINT32: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX128:
case tensorflow::DataType::DT_BFLOAT16: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT8:
case tensorflow::DataType::DT_QINT16: throw std::runtime_error(""); case tensorflow::DataType::DT_QUINT8:
case tensorflow::DataType::DT_QUINT16: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT32:
case tensorflow::DataType::DT_RESOURCE: throw std::runtime_error(""); case tensorflow::DataType::DT_BFLOAT16:
case tensorflow::DataType::DT_VARIANT: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT16:
case tensorflow::DataType::DT_FLOAT_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QUINT16:
case tensorflow::DataType::DT_DOUBLE_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_RESOURCE:
case tensorflow::DataType::DT_INT32_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_VARIANT:
case tensorflow::DataType::DT_UINT8_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_FLOAT_REF:
case tensorflow::DataType::DT_INT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_DOUBLE_REF:
case tensorflow::DataType::DT_INT8_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_INT32_REF:
case tensorflow::DataType::DT_STRING_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT8_REF:
case tensorflow::DataType::DT_COMPLEX64_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_INT16_REF:
case tensorflow::DataType::DT_INT64_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_INT8_REF:
case tensorflow::DataType::DT_BOOL_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_STRING_REF:
case tensorflow::DataType::DT_QINT8_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX64_REF:
case tensorflow::DataType::DT_QUINT8_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_INT64_REF:
case tensorflow::DataType::DT_QINT32_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_BOOL_REF:
case tensorflow::DataType::DT_BFLOAT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT8_REF:
case tensorflow::DataType::DT_QINT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QUINT8_REF:
case tensorflow::DataType::DT_QUINT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT32_REF:
case tensorflow::DataType::DT_UINT16_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_BFLOAT16_REF:
case tensorflow::DataType::DT_COMPLEX128_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QINT16_REF:
case tensorflow::DataType::DT_HALF_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_QUINT16_REF:
case tensorflow::DataType::DT_RESOURCE_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_UINT16_REF:
case tensorflow::DataType::DT_VARIANT_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_COMPLEX128_REF:
case tensorflow::DataType::DT_UINT32_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_HALF_REF:
case tensorflow::DataType::DT_UINT64_REF: throw std::runtime_error(""); case tensorflow::DataType::DT_RESOURCE_REF:
case tensorflow::DataType::DT_VARIANT_REF:
case tensorflow::DataType::DT_UINT32_REF:
case tensorflow::DataType::DT_UINT64_REF:
case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_: case tensorflow::DataType::DataType_INT_MAX_SENTINEL_DO_NOT_USE_:
throw std::runtime_error(""); case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_: throw std::runtime_error("");
case tensorflow::DataType::DataType_INT_MIN_SENTINEL_DO_NOT_USE_:
throw std::runtime_error("");
} }
MIGRAPHX_THROW("Invalid tensor type"); MIGRAPHX_THROW("Invalid tensor type");
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment