Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dabb2049
Commit
dabb2049
authored
Jul 11, 2019
by
Paul
Browse files
Fix duplicate branches
parent
ab35b581
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
145 additions
and
179 deletions
+145
-179
CMakeLists.txt
CMakeLists.txt
+1
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+32
-49
src/tf/tf.cpp
src/tf/tf.cpp
+112
-130
No files found.
CMakeLists.txt
View file @
dabb2049
...
@@ -81,6 +81,7 @@ rocm_enable_clang_tidy(
...
@@ -81,6 +81,7 @@ rocm_enable_clang_tidy(
-modernize-use-override
-modernize-use-override
-modernize-pass-by-value
-modernize-pass-by-value
-modernize-use-default-member-init
-modernize-use-default-member-init
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-modernize-use-transparent-functors
-performance-type-promotion-in-math-fn
-performance-type-promotion-in-math-fn
-readability-braces-around-statements
-readability-braces-around-statements
...
...
src/onnx/onnx.cpp
View file @
dabb2049
...
@@ -1469,16 +1469,16 @@ struct onnx_parser
...
@@ -1469,16 +1469,16 @@ struct onnx_parser
{
{
switch
(
attr
.
type
())
switch
(
attr
.
type
())
{
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
case
onnx
::
AttributeProto
::
FLOAT
:
return
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
FLOAT
:
return
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
INT
:
return
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
INT
:
return
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
UNDEFINED
:
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPH
:
case
onnx
::
AttributeProto
::
STRING
:
case
onnx
::
AttributeProto
::
STRINGS
:
case
onnx
::
AttributeProto
::
TENSORS
:
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
}
MIGRAPHX_THROW
(
"Invalid attribute type"
);
MIGRAPHX_THROW
(
"Invalid attribute type"
);
...
@@ -1492,47 +1492,35 @@ struct onnx_parser
...
@@ -1492,47 +1492,35 @@ struct onnx_parser
const
std
::
string
&
s
=
t
.
raw_data
();
const
std
::
string
&
s
=
t
.
raw_data
();
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT16
:
return
create_literal
(
shape
::
half_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT8
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT32
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
UINT8
:
return
create_literal
(
shape
::
half_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
UNDEFINED
:
return
create_literal
(
shape
::
double_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT32
:
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
INT8
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
UINT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
INT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
INT32
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
INT64
:
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
t
.
int64_data
());
return
create_literal
(
shape
::
int64_type
,
dims
,
t
.
int64_data
());
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
FLOAT16
:
{
{
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
...
@@ -1543,11 +1531,12 @@ struct onnx_parser
...
@@ -1543,11 +1531,12 @@ struct onnx_parser
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
return
create_literal
(
shape
::
half_type
,
dims
,
data_half
);
return
create_literal
(
shape
::
half_type
,
dims
,
data_half
);
}
}
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
UNDEFINED
:
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
case
onnx
::
TensorProto
::
UINT8
:
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT32
:
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
...
@@ -1575,28 +1564,22 @@ struct onnx_parser
...
@@ -1575,28 +1564,22 @@ struct onnx_parser
shape
::
type_t
shape_type
{};
shape
::
type_t
shape_type
{};
switch
(
t
.
tensor_type
().
elem_type
())
switch
(
t
.
tensor_type
().
elem_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
shape_type
=
shape
::
half_type
;
break
;
case
onnx
::
TensorProto
::
FLOAT16
:
shape_type
=
shape
::
half_type
;
break
;
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
BOOL
:
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
...
...
src/tf/tf.cpp
View file @
dabb2049
...
@@ -831,72 +831,58 @@ struct tf_parser
...
@@ -831,72 +831,58 @@ struct tf_parser
shape
::
type_t
shape_type
{};
shape
::
type_t
shape_type
{};
switch
(
t
)
switch
(
t
)
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
tensorflow
::
DataType
::
DT_FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
tensorflow
::
DataType
::
DT_FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
tensorflow
::
DataType
::
DT_DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
tensorflow
::
DataType
::
DT_DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
tensorflow
::
DataType
::
DT_INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
tensorflow
::
DataType
::
DT_HALF
:
shape_type
=
shape
::
half_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_STRING
:
case
tensorflow
::
DataType
::
DT_STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
tensorflow
::
DataType
::
DT_INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
tensorflow
::
DataType
::
DT_BOOL
:
case
tensorflow
::
DataType
::
DT_BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
tensorflow
::
DataType
::
DT_QINT8
:
case
tensorflow
::
DataType
::
DT_QINT8
:
break
;
// throw std::runtime_error("Unsupported type QINT8");
case
tensorflow
::
DataType
::
DT_QUINT8
:
case
tensorflow
::
DataType
::
DT_QUINT8
:
break
;
// throw std::runtime_error("Unsupported type QUINT8");
case
tensorflow
::
DataType
::
DT_QINT32
:
case
tensorflow
::
DataType
::
DT_QINT32
:
break
;
// throw std::runtime_error("Unsupported type QINT32");
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
break
;
// throw std::runtime_error("Unsupported type BFLOAT16");
case
tensorflow
::
DataType
::
DT_QINT16
:
case
tensorflow
::
DataType
::
DT_QINT16
:
break
;
// throw std::runtime_error("Unsupported type QINT16");
case
tensorflow
::
DataType
::
DT_QUINT16
:
case
tensorflow
::
DataType
::
DT_QUINT16
:
break
;
// throw std::runtime_error("Unsupported type QUINT16");
case
tensorflow
::
DataType
::
DT_UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
case
tensorflow
::
DataType
::
DT_HALF
:
shape_type
=
shape
::
half_type
;
break
;
case
tensorflow
::
DataType
::
DT_RESOURCE
:
case
tensorflow
::
DataType
::
DT_RESOURCE
:
break
;
// throw std::runtime_error("Unsupported type RESOURCE");
case
tensorflow
::
DataType
::
DT_VARIANT
:
case
tensorflow
::
DataType
::
DT_VARIANT
:
break
;
// throw std::runtime_error("Unsupported type VARIANT");
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
// tf pb should not use these types
// tf pb should not use these types
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
break
;
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
break
;
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
break
;
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
break
;
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
break
;
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
break
;
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
break
;
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
break
;
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
\
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
break
;
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
break
;
}
}
return
shape_type
;
return
shape_type
;
...
@@ -911,61 +897,59 @@ struct tf_parser
...
@@ -911,61 +897,59 @@ struct tf_parser
const
std
::
string
&
s
=
t
.
tensor_content
();
const
std
::
string
&
s
=
t
.
tensor_content
();
switch
(
t
.
dtype
())
switch
(
t
.
dtype
())
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_
UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_
BOOL
:
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int8_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int8_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_UINT16
:
return
literal
{{
shape
::
uint16_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT16
:
case
tensorflow
::
DataType
::
DT_INT16
:
return
literal
{{
shape
::
int16_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int16_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT64
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int8_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT32
:
case
tensorflow
::
DataType
::
DT_QINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
case
tensorflow
::
DataType
::
DT_QUINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
case
tensorflow
::
DataType
::
DT_QINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8
:
case
tensorflow
::
DataType
::
DT_QINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8
:
case
tensorflow
::
DataType
::
DT_QUINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32
:
case
tensorflow
::
DataType
::
DT_RESOURCE
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
case
tensorflow
::
DataType
::
DT_VARIANT
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16
:
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_VARIANT
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
throw
std
::
runtime_error
(
""
);
}
}
...
@@ -973,11 +957,9 @@ struct tf_parser
...
@@ -973,11 +957,9 @@ struct tf_parser
}
}
switch
(
t
.
dtype
())
switch
(
t
.
dtype
())
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
create_literal
(
return
create_literal
(
shape
::
float_type
,
dims
,
get_data_vals
(
t
.
float_val
(),
shape_size
));
shape
::
float_type
,
dims
,
get_data_vals
(
t
.
float_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
case
tensorflow
::
DataType
::
DT_INT8
:
return
create_literal
(
shape
::
int8_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
return
create_literal
(
shape
::
int8_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_UINT16
:
...
@@ -989,7 +971,6 @@ struct tf_parser
...
@@ -989,7 +971,6 @@ struct tf_parser
case
tensorflow
::
DataType
::
DT_INT64
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
create_literal
(
return
create_literal
(
shape
::
int64_type
,
dims
,
get_data_vals
(
t
.
int64_val
(),
shape_size
));
shape
::
int64_type
,
dims
,
get_data_vals
(
t
.
int64_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
case
tensorflow
::
DataType
::
DT_BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
bool_val
(),
shape_size
));
return
create_literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
bool_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_HALF
:
case
tensorflow
::
DataType
::
DT_HALF
:
...
@@ -1005,45 +986,46 @@ struct tf_parser
...
@@ -1005,45 +986,46 @@ struct tf_parser
}
}
case
tensorflow
::
DataType
::
DT_DOUBLE
:
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
get_data_vals
(
t
.
double_val
(),
shape_size
)};
return
literal
{{
shape
::
double_type
,
dims
},
get_data_vals
(
t
.
double_val
(),
shape_size
)};
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT32
:
case
tensorflow
::
DataType
::
DT_QINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
case
tensorflow
::
DataType
::
DT_QUINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
case
tensorflow
::
DataType
::
DT_QINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8
:
case
tensorflow
::
DataType
::
DT_QINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8
:
case
tensorflow
::
DataType
::
DT_QUINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32
:
case
tensorflow
::
DataType
::
DT_RESOURCE
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
case
tensorflow
::
DataType
::
DT_VARIANT
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16
:
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_VARIANT
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment