Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6e058792
Commit
6e058792
authored
Jan 23, 2019
by
Khalique
Browse files
formatting
parent
c8a91e20
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
31 additions
and
34 deletions
+31
-34
src/tf/tf.cpp
src/tf/tf.cpp
+31
-34
No files found.
src/tf/tf.cpp
View file @
6e058792
...
@@ -24,10 +24,10 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -24,10 +24,10 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
tf_parser
struct
tf_parser
{
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
tensorflow
::
AttrValue
>
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
tensorflow
::
AttrValue
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
tensorflow
::
NodeDef
>
;
using
node_map
=
std
::
unordered_map
<
std
::
string
,
tensorflow
::
NodeDef
>
;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
node_map
nodes
;
std
::
vector
<
tensorflow
::
NodeDef
>
input_nodes
;
std
::
vector
<
tensorflow
::
NodeDef
>
input_nodes
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
...
@@ -130,7 +130,7 @@ struct tf_parser
...
@@ -130,7 +130,7 @@ struct tf_parser
{
{
epsilon
=
attributes
.
at
(
"epsilon"
).
f
();
epsilon
=
attributes
.
at
(
"epsilon"
).
f
();
}
}
op
::
batch_norm_inference
op
{
epsilon
,
momentum
,
bn_mode
};
op
::
batch_norm_inference
op
{
epsilon
,
momentum
,
bn_mode
};
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
}
...
@@ -140,7 +140,7 @@ struct tf_parser
...
@@ -140,7 +140,7 @@ struct tf_parser
{
{
// get index for axis within args
// get index for axis within args
std
::
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
std
::
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
std
::
size_t
axis
=
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
();
std
::
size_t
axis
=
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
();
op
::
concat
op
{
axis
};
op
::
concat
op
{
axis
};
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
}
...
@@ -164,7 +164,7 @@ struct tf_parser
...
@@ -164,7 +164,7 @@ struct tf_parser
{
{
op
.
padding_mode
=
op
::
convolution
::
same
;
op
.
padding_mode
=
op
::
convolution
::
same
;
}
}
else
if
(
pad_mode
.
find
(
"EXPLICIT"
)
!=
std
::
string
::
npos
)
else
if
(
pad_mode
.
find
(
"EXPLICIT"
)
!=
std
::
string
::
npos
)
{
{
std
::
vector
<
std
::
size_t
>
padding
(
4
);
std
::
vector
<
std
::
size_t
>
padding
(
4
);
copy
(
attributes
.
at
(
"explicit_paddings"
).
list
().
i
(),
padding
.
begin
());
copy
(
attributes
.
at
(
"explicit_paddings"
).
list
().
i
(),
padding
.
begin
());
...
@@ -200,7 +200,7 @@ struct tf_parser
...
@@ -200,7 +200,7 @@ struct tf_parser
// std::vector<instruction_ref> args)
// std::vector<instruction_ref> args)
// {
// {
// op::pooling op{starts_with(name, "Max") ? "max" : "average"};
// op::pooling op{starts_with(name, "Max") ? "max" : "average"};
// if(contains(attributes, "pads"))
// if(contains(attributes, "pads"))
// {
// {
// std::vector<std::size_t> padding(4);
// std::vector<std::size_t> padding(4);
...
@@ -254,12 +254,12 @@ struct tf_parser
...
@@ -254,12 +254,12 @@ struct tf_parser
nodes
=
get_nodes
(
graph
,
input_nodes
);
nodes
=
get_nodes
(
graph
,
input_nodes
);
for
(
auto
&&
input
:
input_nodes
)
for
(
auto
&&
input
:
input_nodes
)
{
{
const
std
::
string
&
name
=
input
.
name
();
const
std
::
string
&
name
=
input
.
name
();
attribute_map
input_attrs
=
get_attributes
(
input
);
attribute_map
input_attrs
=
get_attributes
(
input
);
shape
::
type_t
shape_type
=
parse_type
(
input_attrs
.
at
(
"dtype"
).
type
());
shape
::
type_t
shape_type
=
parse_type
(
input_attrs
.
at
(
"dtype"
).
type
());
std
::
vector
<
size_t
>
dims
=
parse_dims
(
input_attrs
.
at
(
"shape"
).
shape
());
std
::
vector
<
size_t
>
dims
=
parse_dims
(
input_attrs
.
at
(
"shape"
).
shape
());
shape
s
=
shape
{
shape_type
,
dims
};
shape
s
=
shape
{
shape_type
,
dims
};
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
if
(
is_nhwc
)
if
(
is_nhwc
)
{
{
// nhwc to nchw
// nhwc to nchw
...
@@ -308,19 +308,17 @@ struct tf_parser
...
@@ -308,19 +308,17 @@ struct tf_parser
static
attribute_map
get_attributes
(
const
tensorflow
::
NodeDef
&
node
)
static
attribute_map
get_attributes
(
const
tensorflow
::
NodeDef
&
node
)
{
{
attribute_map
result
;
attribute_map
result
;
for
(
auto
&&
attr
:
node
.
attr
())
for
(
auto
&&
attr
:
node
.
attr
())
{
{
result
[
attr
.
first
]
=
attr
.
second
;
result
[
attr
.
first
]
=
attr
.
second
;
}
}
return
result
;
return
result
;
}
}
static
std
::
string
get_name
(
const
tensorflow
::
NodeDef
&
node
)
static
std
::
string
get_name
(
const
tensorflow
::
NodeDef
&
node
)
{
return
node
.
name
();
}
{
return
node
.
name
();
}
static
node_map
get_nodes
(
const
tensorflow
::
GraphDef
&
graph
,
std
::
vector
<
tensorflow
::
NodeDef
>&
input_nodes
)
static
node_map
get_nodes
(
const
tensorflow
::
GraphDef
&
graph
,
std
::
vector
<
tensorflow
::
NodeDef
>&
input_nodes
)
{
{
node_map
result
;
node_map
result
;
for
(
auto
&&
node
:
graph
.
node
())
for
(
auto
&&
node
:
graph
.
node
())
...
@@ -381,8 +379,7 @@ struct tf_parser
...
@@ -381,8 +379,7 @@ struct tf_parser
break
;
// throw std::runtime_error("Unsupported type VARIANT");
break
;
// throw std::runtime_error("Unsupported type VARIANT");
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
case
tensorflow
::
DataType
::
DT_UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
default:
default:
break
;
break
;
}
}
return
shape_type
;
return
shape_type
;
}
}
...
@@ -397,27 +394,32 @@ struct tf_parser
...
@@ -397,27 +394,32 @@ struct tf_parser
if
(
!
t
.
tensor_content
().
empty
())
// has raw data
if
(
!
t
.
tensor_content
().
empty
())
// has raw data
{
{
const
std
::
string
&
s
=
t
.
tensor_content
();
const
std
::
string
&
s
=
t
.
tensor_content
();
switch
(
t
.
dtype
())
switch
(
t
.
dtype
())
{
{
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INVALID
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_FLOAT
:
return
literal
{{
shape
::
float_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT8
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT8
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT16
:
case
tensorflow
::
DataType
::
DT_INT16
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT16
:
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT32
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_STRING
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_HALF
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT32
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
default:
default:
break
;
break
;
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
...
@@ -449,11 +451,9 @@ struct tf_parser
...
@@ -449,11 +451,9 @@ struct tf_parser
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_UINT64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX64
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
case
tensorflow
::
DataType
::
DT_COMPLEX128
:
throw
std
::
runtime_error
(
""
);
default:
default:
break
;
break
;
}
}
MIGRAPHX_THROW
(
"Invalid tensor type"
);
MIGRAPHX_THROW
(
"Invalid tensor type"
);
}
}
static
std
::
vector
<
size_t
>
parse_dims
(
const
tensorflow
::
TensorShapeProto
&
s
)
static
std
::
vector
<
size_t
>
parse_dims
(
const
tensorflow
::
TensorShapeProto
&
s
)
...
@@ -466,9 +466,6 @@ struct tf_parser
...
@@ -466,9 +466,6 @@ struct tf_parser
}
}
return
dims
;
return
dims
;
}
}
};
};
program
parse_tf
(
const
std
::
string
&
name
,
bool
is_nhwc
)
program
parse_tf
(
const
std
::
string
&
name
,
bool
is_nhwc
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment