Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
dabb2049
Commit
dabb2049
authored
Jul 11, 2019
by
Paul
Browse files
Fix duplicate branches
parent
ab35b581
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
145 additions
and
179 deletions
+145
-179
CMakeLists.txt
CMakeLists.txt
+1
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+32
-49
src/tf/tf.cpp
src/tf/tf.cpp
+112
-130
No files found.
CMakeLists.txt
View file @
dabb2049
...
@@ -81,6 +81,7 @@ rocm_enable_clang_tidy(
...
@@ -81,6 +81,7 @@ rocm_enable_clang_tidy(
-modernize-use-override
-modernize-use-override
-modernize-pass-by-value
-modernize-pass-by-value
-modernize-use-default-member-init
-modernize-use-default-member-init
-modernize-use-trailing-return-type
-modernize-use-transparent-functors
-modernize-use-transparent-functors
-performance-type-promotion-in-math-fn
-performance-type-promotion-in-math-fn
-readability-braces-around-statements
-readability-braces-around-statements
...
...
src/onnx/onnx.cpp
View file @
dabb2049
...
@@ -1469,16 +1469,16 @@ struct onnx_parser
...
@@ -1469,16 +1469,16 @@ struct onnx_parser
{
{
switch
(
attr
.
type
())
switch
(
attr
.
type
())
{
{
case
onnx
::
AttributeProto
::
UNDEFINED
:
return
{};
case
onnx
::
AttributeProto
::
FLOAT
:
return
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
FLOAT
:
return
literal
{
attr
.
f
()};
case
onnx
::
AttributeProto
::
INT
:
return
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
INT
:
return
literal
{
attr
.
i
()};
case
onnx
::
AttributeProto
::
STRING
:
return
{};
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
TENSOR
:
return
parse_tensor
(
attr
.
t
());
case
onnx
::
AttributeProto
::
GRAPH
:
return
{};
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
FLOATS
:
return
from_repeated
(
shape
::
float_type
,
attr
.
floats
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
INTS
:
return
from_repeated
(
shape
::
int64_type
,
attr
.
ints
());
case
onnx
::
AttributeProto
::
STRINGS
:
return
{};
case
onnx
::
AttributeProto
::
UNDEFINED
:
case
onnx
::
AttributeProto
::
TENSORS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPH
:
case
onnx
::
AttributeProto
::
STRING
:
case
onnx
::
AttributeProto
::
STRINGS
:
case
onnx
::
AttributeProto
::
TENSORS
:
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
case
onnx
::
AttributeProto
::
GRAPHS
:
return
{};
}
}
MIGRAPHX_THROW
(
"Invalid attribute type"
);
MIGRAPHX_THROW
(
"Invalid attribute type"
);
...
@@ -1492,47 +1492,35 @@ struct onnx_parser
...
@@ -1492,47 +1492,35 @@ struct onnx_parser
const
std
::
string
&
s
=
t
.
raw_data
();
const
std
::
string
&
s
=
t
.
raw_data
();
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT16
:
return
create_literal
(
shape
::
half_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT8
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT32
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
UINT8
:
return
create_literal
(
shape
::
half_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
UNDEFINED
:
return
create_literal
(
shape
::
double_type
,
dims
,
s
.
data
());
case
onnx
::
TensorProto
::
UINT32
:
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
switch
(
t
.
data_type
())
switch
(
t
.
data_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
UINT8
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
INT8
:
case
onnx
::
TensorProto
::
INT8
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
UINT16
:
case
onnx
::
TensorProto
::
UINT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
INT16
:
case
onnx
::
TensorProto
::
INT16
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
INT32
:
case
onnx
::
TensorProto
::
INT32
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
INT64
:
case
onnx
::
TensorProto
::
INT64
:
return
create_literal
(
shape
::
int64_type
,
dims
,
t
.
int64_data
());
return
create_literal
(
shape
::
int64_type
,
dims
,
t
.
int64_data
());
case
onnx
::
TensorProto
::
DOUBLE
:
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT
:
return
create_literal
(
shape
::
float_type
,
dims
,
t
.
float_data
());
case
onnx
::
TensorProto
::
BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
t
.
int32_data
());
case
onnx
::
TensorProto
::
FLOAT16
:
case
onnx
::
TensorProto
::
FLOAT16
:
{
{
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
std
::
vector
<
uint16_t
>
data_uint16
(
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
());
...
@@ -1543,11 +1531,12 @@ struct onnx_parser
...
@@ -1543,11 +1531,12 @@ struct onnx_parser
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
[](
uint16_t
raw_val
)
{
return
*
reinterpret_cast
<
half
*>
(
&
raw_val
);
});
return
create_literal
(
shape
::
half_type
,
dims
,
data_half
);
return
create_literal
(
shape
::
half_type
,
dims
,
data_half
);
}
}
case
onnx
::
TensorProto
::
DOUBLE
:
case
onnx
::
TensorProto
::
UNDEFINED
:
return
create_literal
(
shape
::
double_type
,
dims
,
t
.
double_data
());
case
onnx
::
TensorProto
::
UINT8
:
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT32
:
case
onnx
::
TensorProto
::
COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
COMPLEX128
:
throw
std
::
runtime_error
(
""
);
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
...
@@ -1575,28 +1564,22 @@ struct onnx_parser
...
@@ -1575,28 +1564,22 @@ struct onnx_parser
shape
::
type_t
shape_type
{};
shape
::
type_t
shape_type
{};
switch
(
t
.
tensor_type
().
elem_type
())
switch
(
t
.
tensor_type
().
elem_type
())
{
{
case
onnx
::
TensorProto
::
UNDEFINED
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
onnx
::
TensorProto
::
STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
shape_type
=
shape
::
half_type
;
break
;
case
onnx
::
TensorProto
::
FLOAT16
:
shape_type
=
shape
::
half_type
;
break
;
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
onnx
::
TensorProto
::
UINT8
:
case
onnx
::
TensorProto
::
STRING
:
case
onnx
::
TensorProto
::
BOOL
:
case
onnx
::
TensorProto
::
UNDEFINED
:
case
onnx
::
TensorProto
::
COMPLEX64
:
case
onnx
::
TensorProto
::
COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type");
case
onnx
::
TensorProto
::
COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
}
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
auto
&&
tensor_dims
=
t
.
tensor_type
().
shape
().
dim
();
...
...
src/tf/tf.cpp
View file @
dabb2049
...
@@ -831,72 +831,58 @@ struct tf_parser
...
@@ -831,72 +831,58 @@ struct tf_parser
shape
::
type_t
shape_type
{};
shape
::
type_t
shape_type
{};
switch
(
t
)
switch
(
t
)
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
break
;
// throw std::runtime_error("Unsupported type UNDEFINED");
case
tensorflow
::
DataType
::
DT_FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
tensorflow
::
DataType
::
DT_FLOAT
:
shape_type
=
shape
::
float_type
;
break
;
case
tensorflow
::
DataType
::
DT_DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
tensorflow
::
DataType
::
DT_DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT32
:
shape_type
=
shape
::
int32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT8
:
break
;
// throw std::runtime_error("Unsupported type UINT8");
case
tensorflow
::
DataType
::
DT_INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT16
:
shape_type
=
shape
::
int16_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT8
:
shape_type
=
shape
::
int8_type
;
break
;
case
tensorflow
::
DataType
::
DT_INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
tensorflow
::
DataType
::
DT_HALF
:
shape_type
=
shape
::
half_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_STRING
:
case
tensorflow
::
DataType
::
DT_STRING
:
break
;
// throw std::runtime_error("Unsupported type STRING");
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX64");
case
tensorflow
::
DataType
::
DT_INT64
:
shape_type
=
shape
::
int64_type
;
break
;
case
tensorflow
::
DataType
::
DT_BOOL
:
case
tensorflow
::
DataType
::
DT_BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
tensorflow
::
DataType
::
DT_QINT8
:
case
tensorflow
::
DataType
::
DT_QINT8
:
break
;
// throw std::runtime_error("Unsupported type QINT8");
case
tensorflow
::
DataType
::
DT_QUINT8
:
case
tensorflow
::
DataType
::
DT_QUINT8
:
break
;
// throw std::runtime_error("Unsupported type QUINT8");
case
tensorflow
::
DataType
::
DT_QINT32
:
case
tensorflow
::
DataType
::
DT_QINT32
:
break
;
// throw std::runtime_error("Unsupported type QINT32");
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
break
;
// throw std::runtime_error("Unsupported type BFLOAT16");
case
tensorflow
::
DataType
::
DT_QINT16
:
case
tensorflow
::
DataType
::
DT_QINT16
:
break
;
// throw std::runtime_error("Unsupported type QINT16");
case
tensorflow
::
DataType
::
DT_QUINT16
:
case
tensorflow
::
DataType
::
DT_QUINT16
:
break
;
// throw std::runtime_error("Unsupported type QUINT16");
case
tensorflow
::
DataType
::
DT_UINT16
:
shape_type
=
shape
::
uint16_type
;
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
break
;
// throw std::runtime_error("Unsupported type COMPLEX128");
case
tensorflow
::
DataType
::
DT_HALF
:
shape_type
=
shape
::
half_type
;
break
;
case
tensorflow
::
DataType
::
DT_RESOURCE
:
case
tensorflow
::
DataType
::
DT_RESOURCE
:
break
;
// throw std::runtime_error("Unsupported type RESOURCE");
case
tensorflow
::
DataType
::
DT_VARIANT
:
case
tensorflow
::
DataType
::
DT_VARIANT
:
break
;
// throw std::runtime_error("Unsupported type VARIANT");
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
// tf pb should not use these types
// tf pb should not use these types
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
break
;
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
break
;
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
break
;
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
break
;
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
break
;
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
break
;
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
break
;
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
break
;
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
break
;
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
break
;
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
break
;
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
\
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
break
;
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
break
;
}
}
return
shape_type
;
return
shape_type
;
...
@@ -911,61 +897,59 @@ struct tf_parser
...
@@ -911,61 +897,59 @@ struct tf_parser
const
std
::
string
&
s
=
t
.
tensor_content
();
const
std
::
string
&
s
=
t
.
tensor_content
();
switch
(
t
.
dtype
())
switch
(
t
.
dtype
())
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_
UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_
BOOL
:
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int8_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int8_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_UINT16
:
return
literal
{{
shape
::
uint16_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT16
:
case
tensorflow
::
DataType
::
DT_INT16
:
return
literal
{{
shape
::
int16_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int16_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT64
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int8_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT32
:
case
tensorflow
::
DataType
::
DT_QINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
case
tensorflow
::
DataType
::
DT_QUINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
case
tensorflow
::
DataType
::
DT_QINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8
:
case
tensorflow
::
DataType
::
DT_QINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8
:
case
tensorflow
::
DataType
::
DT_QUINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32
:
case
tensorflow
::
DataType
::
DT_RESOURCE
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
case
tensorflow
::
DataType
::
DT_VARIANT
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16
:
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_VARIANT
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
throw
std
::
runtime_error
(
""
);
}
}
...
@@ -973,11 +957,9 @@ struct tf_parser
...
@@ -973,11 +957,9 @@ struct tf_parser
}
}
switch
(
t
.
dtype
())
switch
(
t
.
dtype
())
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
create_literal
(
return
create_literal
(
shape
::
float_type
,
dims
,
get_data_vals
(
t
.
float_val
(),
shape_size
));
shape
::
float_type
,
dims
,
get_data_vals
(
t
.
float_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
case
tensorflow
::
DataType
::
DT_INT8
:
return
create_literal
(
shape
::
int8_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
return
create_literal
(
shape
::
int8_type
,
dims
,
get_data_vals
(
t
.
int_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_UINT16
:
...
@@ -989,7 +971,6 @@ struct tf_parser
...
@@ -989,7 +971,6 @@ struct tf_parser
case
tensorflow
::
DataType
::
DT_INT64
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
create_literal
(
return
create_literal
(
shape
::
int64_type
,
dims
,
get_data_vals
(
t
.
int64_val
(),
shape_size
));
shape
::
int64_type
,
dims
,
get_data_vals
(
t
.
int64_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
case
tensorflow
::
DataType
::
DT_BOOL
:
return
create_literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
bool_val
(),
shape_size
));
return
create_literal
(
shape
::
int32_type
,
dims
,
get_data_vals
(
t
.
bool_val
(),
shape_size
));
case
tensorflow
::
DataType
::
DT_HALF
:
case
tensorflow
::
DataType
::
DT_HALF
:
...
@@ -1005,45 +986,46 @@ struct tf_parser
...
@@ -1005,45 +986,46 @@ struct tf_parser
}
}
case
tensorflow
::
DataType
::
DT_DOUBLE
:
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
get_data_vals
(
t
.
double_val
(),
shape_size
)};
return
literal
{{
shape
::
double_type
,
dims
},
get_data_vals
(
t
.
double_val
(),
shape_size
)};
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT32
:
case
tensorflow
::
DataType
::
DT_QINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
case
tensorflow
::
DataType
::
DT_QUINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
case
tensorflow
::
DataType
::
DT_QINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8
:
case
tensorflow
::
DataType
::
DT_QINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8
:
case
tensorflow
::
DataType
::
DT_QUINT16
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32
:
case
tensorflow
::
DataType
::
DT_RESOURCE
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16
:
case
tensorflow
::
DataType
::
DT_VARIANT
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16
:
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16
:
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE
:
case
tensorflow
::
DataType
::
DT_INT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_VARIANT
:
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT_REF
:
case
tensorflow
::
DataType
::
DT_INT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_DOUBLE_REF
:
case
tensorflow
::
DataType
::
DT_INT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT32_REF
:
case
tensorflow
::
DataType
::
DT_STRING_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT16_REF
:
case
tensorflow
::
DataType
::
DT_INT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8_REF
:
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING_REF
:
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64_REF
:
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT64_REF
:
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL_REF
:
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT8_REF
:
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT8_REF
:
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BFLOAT16_REF
:
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QINT16_REF
:
case
tensorflow
::
DataType
::
DT_HALF_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_QUINT16_REF
:
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT16_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_HALF_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_RESOURCE_REF
:
case
tensorflow
::
DataType
::
DT_VARIANT_REF
:
case
tensorflow
::
DataType
::
DT_UINT32_REF
:
case
tensorflow
::
DataType
::
DT_UINT64_REF
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
case
tensorflow
::
DataType
::
DataType_INT_MAX_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DataType_INT_MIN_SENTINEL_DO_NOT_USE_
:
throw
std
::
runtime_error
(
""
);
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment