Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6e058792
Commit
6e058792
authored
Jan 23, 2019
by
Khalique
Browse files
formatting
parent
c8a91e20
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
34 deletions
+31
-34
src/tf/tf.cpp
src/tf/tf.cpp
+31
-34
No files found.
src/tf/tf.cpp
View file @
6e058792
...
@@ -164,7 +164,7 @@ struct tf_parser
...
@@ -164,7 +164,7 @@ struct tf_parser
{
{
op
.
padding_mode
=
op
::
convolution
::
same
;
op
.
padding_mode
=
op
::
convolution
::
same
;
}
}
else
if
(
pad_mode
.
find
(
"EXPLICIT"
)
!=
std
::
string
::
npos
)
else
if
(
pad_mode
.
find
(
"EXPLICIT"
)
!=
std
::
string
::
npos
)
{
{
std
::
vector
<
std
::
size_t
>
padding
(
4
);
std
::
vector
<
std
::
size_t
>
padding
(
4
);
copy
(
attributes
.
at
(
"explicit_paddings"
).
list
().
i
(),
padding
.
begin
());
copy
(
attributes
.
at
(
"explicit_paddings"
).
list
().
i
(),
padding
.
begin
());
...
@@ -308,19 +308,17 @@ struct tf_parser
...
@@ -308,19 +308,17 @@ struct tf_parser
static
attribute_map
get_attributes
(
const
tensorflow
::
NodeDef
&
node
)
static
attribute_map
get_attributes
(
const
tensorflow
::
NodeDef
&
node
)
{
{
attribute_map
result
;
attribute_map
result
;
for
(
auto
&&
attr
:
node
.
attr
())
for
(
auto
&&
attr
:
node
.
attr
())
{
{
result
[
attr
.
first
]
=
attr
.
second
;
result
[
attr
.
first
]
=
attr
.
second
;
}
}
return
result
;
return
result
;
}
}
static
std
::
string
get_name
(
const
tensorflow
::
NodeDef
&
node
)
static
std
::
string
get_name
(
const
tensorflow
::
NodeDef
&
node
)
{
return
node
.
name
();
}
{
return
node
.
name
();
}
static
node_map
get_nodes
(
const
tensorflow
::
GraphDef
&
graph
,
std
::
vector
<
tensorflow
::
NodeDef
>&
input_nodes
)
static
node_map
get_nodes
(
const
tensorflow
::
GraphDef
&
graph
,
std
::
vector
<
tensorflow
::
NodeDef
>&
input_nodes
)
{
{
node_map
result
;
node_map
result
;
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
...
@@ -381,8 +379,7 @@ struct tf_parser
...
@@ -381,8 +379,7 @@ struct tf_parser
break
;
// throw std::runtime_error("Unsupported type VARIANT");
break
;
// throw std::runtime_error("Unsupported type VARIANT");
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
default:
default:
break
;
break
;
}
}
return
shape_type
;
return
shape_type
;
}
}
...
@@ -397,27 +394,32 @@ struct tf_parser
...
@@ -397,27 +394,32 @@ struct tf_parser
if
(
!
t
.
tensor_content
().
empty
())
// has raw data
if
(
!
t
.
tensor_content
().
empty
())
// has raw data
{
{
const
std
::
string
&
s
=
t
.
tensor_content
();
const
std
::
string
&
s
=
t
.
tensor_content
();
switch
(
t
.
dtype
())
switch
(
t
.
dtype
())
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_INT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT16
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
default:
default:
break
;
break
;
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
...
@@ -449,11 +451,9 @@ struct tf_parser
...
@@ -449,11 +451,9 @@ struct tf_parser
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
default:
default:
break
;
break
;
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
static
std
::
vector
<
size_t
>
parse_dims
(
const
tensorflow
::
TensorShapeProto
&
s
)
static
std
::
vector
<
size_t
>
parse_dims
(
const
tensorflow
::
TensorShapeProto
&
s
)
...
@@ -466,9 +466,6 @@ struct tf_parser
...
@@ -466,9 +466,6 @@ struct tf_parser
}
}
return
dims
;
return
dims
;
}
}
};
};
program
parse_tf
(
const
std
::
string
&
name
,
bool
is_nhwc
)
program
parse_tf
(
const
std
::
string
&
name
,
bool
is_nhwc
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment