Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
cdf96caf
Commit
cdf96caf
authored
May 24, 2019
by
Paul
Browse files
Transpose every instruction
parent
0d796941
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
86 additions
and
80 deletions
+86
-80
src/tf/tf.cpp
src/tf/tf.cpp
+86
-80
No files found.
src/tf/tf.cpp
View file @
cdf96caf
...
...
@@ -36,6 +36,49 @@ struct tf_parser
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
bool
should_transpose
(
instruction_ref
ins
)
{
return
is_nhwc
and
ins
->
get_shape
().
lens
().
size
()
==
4
;
}
instruction_ref
to_nhwc
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
0
,
2
,
3
,
1
}},
ins
);
return
ins
;
}
instruction_ref
to_nchw
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
0
,
3
,
1
,
2
}},
ins
);
return
ins
;
}
instruction_ref
to_kcxy
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
ins
);
return
ins
;
}
instruction_ref
make_contiguous
(
instruction_ref
ins
)
{
if
(
ins
->
get_shape
().
standard
())
return
ins
;
else
return
prog
.
add_instruction
(
op
::
contiguous
{},
ins
);
}
std
::
vector
<
instruction_ref
>
to_nchw
(
const
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
instruction_ref
>
result
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
result
.
begin
(),
[
&
](
auto
ins
)
{
return
to_nchw
(
ins
);
});
return
result
;
}
std
::
vector
<
size_t
>
parse_axes
(
const
attribute_map
&
attributes
,
const
std
::
string
&
s
)
const
{
auto
attrs
=
attributes
.
at
(
s
).
list
().
i
();
...
...
@@ -116,41 +159,43 @@ struct tf_parser
add_mem_op
(
"AvgPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"BiasAdd"
,
&
tf_parser
::
parse_biasadd
);
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
);
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
,
false
);
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
,
false
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
,
false
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
,
false
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
,
false
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
,
false
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
,
false
);
add_mem_op
(
"StridedSlice"
,
&
tf_parser
::
parse_stridedslice
);
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
// Multi output op
template
<
class
F
>
void
add_multi_op
(
std
::
string
name
,
F
f
)
void
add_op
(
std
::
string
name
,
F
f
,
bool
transpose
=
true
)
{
ops
.
emplace
(
name
,
f
);
if
(
transpose
)
{
ops
.
emplace
(
name
,
op_func
{[
=
](
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
->
instruction_ref
{
return
to_nhwc
(
f
(
attributes
,
to_nchw
(
args
)));
}});
}
else
{
ops
.
emplace
(
name
,
f
);
}
}
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
void
add_mem_op
(
std
::
string
name
,
F
f
,
bool
transpose
=
true
)
{
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
}
,
transpose
);
}
template
<
class
T
>
...
...
@@ -159,15 +204,15 @@ struct tf_parser
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
auto
l0
=
args
[
1
];
if
(
contains
(
attributes
,
"data_format"
))
{
if
(
is_nhwc
)
{
l0
=
prog
.
add_instruction
(
op
::
transpose
{{
0
,
3
,
1
,
2
}},
args
[
1
]);
}
// TODO
// if(is_nhwc)
// {
// l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// }
}
return
add_broadcastable_binary_op
(
args
[
0
],
l0
,
x
);
return
add_broadcastable_binary_op
(
args
[
0
],
args
[
1
]
,
x
);
});
}
...
...
@@ -207,20 +252,20 @@ struct tf_parser
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
return
to_nhwc
(
prog
.
add_instruction
(
x
,
to_nchw
(
l0
)
,
to_nchw
(
l1
))
);
}
else
{
return
prog
.
add_instruction
(
x
,
{
arg0
,
arg1
}
);
return
to_nhwc
(
prog
.
add_instruction
(
x
,
{
to_nchw
(
arg0
)
,
to_nchw
(
arg1
)})
);
}
}
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
void
add_generic_op
(
std
::
string
name
,
T
x
,
bool
transpose
=
true
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
});
}
,
transpose
);
}
instruction_ref
...
...
@@ -250,11 +295,11 @@ struct tf_parser
{
// get index for axis within args
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
size_t
axis
=
parse_axis
(
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
()
)
;
size_t
axis
=
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
();
op
::
concat
op
{
axis
};
// return only first N arguments (assuming last index is the axis value)
return
prog
.
add_instruction
(
op
,
std
::
vector
<
instruction_ref
>
(
args
.
begin
(),
args
.
begin
()
+
args
.
size
()
-
1
));
return
to_nhwc
(
prog
.
add_instruction
(
op
,
std
::
vector
<
instruction_ref
>
(
args
.
begin
(),
args
.
begin
()
+
args
.
size
()
-
1
))
)
;
}
instruction_ref
parse_constant
(
const
std
::
string
&
,
...
...
@@ -324,22 +369,7 @@ struct tf_parser
op
.
dilation
[
0
]
=
dilation
[
2
];
op
.
dilation
[
1
]
=
dilation
[
3
];
}
auto
weights
=
args
[
1
];
// check if weights are from a constant
if
(
weights
->
name
()
!=
"@param"
)
{
if
(
is_nhwc
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
1
,
3
,
0
,
2
}},
args
[
1
]);
}
else
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
args
[
1
]);
}
}
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
weights
});
return
prog
.
add_instruction
(
op
,
{
to_nchw
(
args
[
0
]),
to_kcxy
(
to_nchw
(
args
[
1
]))});
}
instruction_ref
parse_depthwiseconv
(
const
std
::
string
&
,
...
...
@@ -369,19 +399,7 @@ struct tf_parser
op
.
stride
[
0
]
=
stride
[
2
];
op
.
stride
[
1
]
=
stride
[
3
];
}
auto
weights
=
args
[
1
];
// check if weights are from a constant
if
(
weights
->
name
()
!=
"@param"
)
{
if
(
is_nhwc
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
1
,
3
,
0
,
2
}},
args
[
1
]);
}
else
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
args
[
1
]);
}
}
auto
weights
=
to_kcxy
(
to_nchw
(
args
[
1
]));
std
::
vector
<
int64_t
>
new_weights_shape
;
copy
(
weights
->
get_shape
().
lens
(),
std
::
back_inserter
(
new_weights_shape
));
...
...
@@ -429,7 +447,7 @@ struct tf_parser
instruction_ref
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
()
)
;
auto
axes
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
std
::
vector
<
int32_t
>
hw_axes
{
2
,
3
};
// check if conditions for GlobalAvgPool are met
...
...
@@ -463,16 +481,13 @@ struct tf_parser
MIGRAPHX_THROW
(
"TF_PARSER: axis value of "
+
to_string
(
axis
)
+
" must be smaller than input size "
+
to_string
(
input_size
));
}
// check if input arg needs axis to be converted to NCHW
if
(
input_size
>=
4
)
axis
=
parse_axis
(
axis
);
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
unsqueezed_args
),
[
&
](
instruction_ref
arg
)
{
return
prog
.
add_instruction
(
op
::
unsqueeze
{{
axis
}},
arg
);
});
return
prog
.
add_instruction
(
op
::
concat
{
static_cast
<
size_t
>
(
axis
)},
unsqueezed_args
);
return
to_nhwc
(
prog
.
add_instruction
(
op
::
concat
{
static_cast
<
size_t
>
(
axis
)},
unsqueezed_args
)
)
;
}
instruction_ref
...
...
@@ -555,7 +570,7 @@ struct tf_parser
MIGRAPHX_THROW
(
"reshape needs 2 arguments (input, new_shape)"
);
auto
s
=
args
[
1
]
->
eval
();
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
])
)
;
}
void
parse_from
(
std
::
istream
&
is
)
...
...
@@ -586,7 +601,7 @@ struct tf_parser
std
::
vector
<
instruction_ref
>
args
)
{
op
::
squeeze
op
;
auto
axes
=
parse_axes
(
attributes
,
"squeeze_dims"
);
auto
axes
=
attributes
.
at
(
"squeeze_dims"
)
.
list
().
i
()
;
copy
(
axes
,
std
::
back_inserter
(
op
.
axes
));
auto
args0_dims
=
args
[
0
]
->
get_shape
().
lens
();
if
(
op
.
axes
.
empty
())
// no squeeze_dims provided, remove any dim that equals 1
...
...
@@ -599,7 +614,7 @@ struct tf_parser
}
}
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
])
)
;
}
instruction_ref
parse_stridedslice
(
const
std
::
string
&
,
...
...
@@ -610,11 +625,6 @@ struct tf_parser
auto
starts
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
ends
=
args
[
2
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
size_t
num_axes
=
args
[
0
]
->
get_shape
().
lens
().
size
();
if
(
num_axes
>=
4
)
{
reorder_data
(
starts
);
reorder_data
(
ends
);
}
op
.
starts
=
std
::
vector
<
int64_t
>
(
starts
.
begin
(),
starts
.
end
());
op
.
ends
=
std
::
vector
<
int64_t
>
(
ends
.
begin
(),
ends
.
end
());
...
...
@@ -633,13 +643,9 @@ struct tf_parser
if
(((
shrink_axis_mask
>>
i
)
&
bitwise_compare
)
==
1
)
squeeze_axes
.
push_back
(
i
);
}
if
(
num_axes
>=
4
)
{
squeeze_axes
=
parse_axes
(
squeeze_axes
);
}
auto
l0
=
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l0
);
auto
l0
=
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
])
)
;
return
to_nhwc
(
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l0
)
)
;
}
void
parse_graph
(
const
tensorflow
::
GraphDef
&
graph
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment