gaoqiong / MIGraphX
"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "c5d7524042cd82e975677cbd3185e0e7acfc19b3"
Commit 2d666954, authored May 24, 2019 by Paul
Parent: cdf96caf

Fix parse constant
Showing 1 changed file with 31 additions and 33 deletions.

src/tf/tf.cpp (+31, -33)
@@ -43,28 +43,28 @@ struct tf_parser
     instruction_ref to_nhwc(instruction_ref ins)
     {
         if(should_transpose(ins))
             return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
         return ins;
     }
     instruction_ref to_nchw(instruction_ref ins)
     {
         if(should_transpose(ins))
             return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
         return ins;
     }
     instruction_ref to_kcxy(instruction_ref ins)
     {
         if(should_transpose(ins))
             return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
         return ins;
     }
     instruction_ref make_contiguous(instruction_ref ins)
     {
         if(ins->get_shape().standard())
             return ins;
         else
             return prog.add_instruction(op::contiguous{}, ins);
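Note: the helpers in this hunk bridge TensorFlow's default NHWC tensor layout and the NCHW layout the parser builds the program in. to_nhwc applies the permutation {0, 2, 3, 1}, to_nchw inverts it with {0, 3, 1, 2}, and to_kcxy reorders HWIO convolution weights into OIHW (KCXY) order with {3, 2, 0, 1}. A minimal stand-alone sketch of how these permutations act on a dimension vector (permute is a hypothetical helper, not part of this commit):

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // output axis i takes its extent from input axis perm[i]
    std::vector<int> permute(const std::vector<int>& dims, const std::vector<int>& perm)
    {
        std::vector<int> out(perm.size());
        for(std::size_t i = 0; i < perm.size(); i++)
            out[i] = dims[perm[i]];
        return out;
    }

    int main()
    {
        std::vector<int> nchw = {1, 3, 224, 224};  // N, C, H, W
        auto nhwc = permute(nchw, {0, 2, 3, 1});   // {1, 224, 224, 3}, as in to_nhwc
        auto back = permute(nhwc, {0, 3, 1, 2});   // {1, 3, 224, 224}, as in to_nchw
        std::vector<int> hwio = {5, 5, 16, 32};    // H, W, in, out (TF weight layout)
        auto kcxy = permute(hwio, {3, 2, 0, 1});   // {32, 16, 5, 5}, as in to_kcxy
        std::cout << nhwc[3] << ' ' << back[1] << ' ' << kcxy[0] << '\n'; // 3 3 32
    }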
@@ -73,9 +73,8 @@ struct tf_parser
     std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
     {
         std::vector<instruction_ref> result(args.size());
         std::transform(args.begin(), args.end(), result.begin(), [&](auto ins) { return to_nchw(ins); });
         return result;
     }
@@ -161,7 +160,7 @@ struct tf_parser
         add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
         add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
         add_mem_op("Const", &tf_parser::parse_constant);
-        add_mem_op("Conv2D", &tf_parser::parse_conv, false);
+        add_mem_op("Conv2D", &tf_parser::parse_conv);
         add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv, false);
         add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
         add_mem_op("MatMul", &tf_parser::parse_matmul, false);
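The third argument in these registrations is the transpose flag defined below: operators registered with false (ConcatV2, DepthwiseConv2dNative, MatMul) manage data layout themselves, while Conv2D now drops the false and relies on the default wrapping, as the add_op definition in the next hunk shows.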
@@ -176,13 +175,15 @@ struct tf_parser
     }
     template <class F>
     void add_op(std::string name, F f, bool transpose = true)
     {
         if(transpose)
         {
             ops.emplace(name,
                         op_func{[=](const attribute_map& attributes,
                                     std::vector<instruction_ref> args) -> instruction_ref {
                             return to_nhwc(f(attributes, to_nchw(args)));
                         }});
         }
         else
         {
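The transpose flag threads through add_op: when it is true, the registered callback is wrapped so its inputs are converted to NCHW first and its result converted back to NHWC; when false, the callback receives the arguments untouched and must manage layout itself. A self-contained sketch of that registration pattern, using simplified stand-in types (instruction_ref, attribute_map, and op_func here are placeholders, not the real MIGraphX definitions):

    #include <functional>
    #include <map>
    #include <string>
    #include <vector>

    using instruction_ref = int; // placeholder for MIGraphX's instruction handle
    using attribute_map   = std::map<std::string, std::string>; // simplified
    using op_func =
        std::function<instruction_ref(const attribute_map&, std::vector<instruction_ref>)>;

    std::map<std::string, op_func> ops;

    instruction_ref to_nhwc(instruction_ref ins) { return ins; } // layout stubs
    std::vector<instruction_ref> to_nchw(std::vector<instruction_ref> v) { return v; }

    template <class F>
    void add_op(std::string name, F f, bool transpose = true)
    {
        if(transpose)
            // wrap the parser: NHWC inputs -> NCHW, run f, NCHW result -> NHWC
            ops.emplace(name, op_func{[=](const attribute_map& attributes,
                                          std::vector<instruction_ref> args) -> instruction_ref {
                return to_nhwc(f(attributes, to_nchw(args)));
            }});
        else
            // the parser opts out and handles layout itself
            ops.emplace(name, op_func{f});
    }

    int main()
    {
        add_op("Relu", [](const attribute_map&,
                          std::vector<instruction_ref> args) -> instruction_ref {
            return args.front();
        });
        return ops.count("Relu") == 1 ? 0 : 1;
    }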
@@ -191,11 +192,13 @@ struct tf_parser
     }
     template <class F>
     void add_mem_op(std::string name, F f, bool transpose = true)
     {
         add_op(name,
                [=](auto&&... xs) {
                    return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
                },
                transpose);
     }
     template <class T>
@@ -261,11 +264,13 @@ struct tf_parser
     }
     template <class T>
     void add_generic_op(std::string name, T x, bool transpose = true)
     {
         add_op(name,
                [this, x](const attribute_map&, std::vector<instruction_ref> args) {
                    return prog.add_instruction(x, args);
                },
                transpose);
     }
     instruction_ref
@@ -307,15 +312,7 @@ struct tf_parser
                                    const std::vector<instruction_ref>&)
     {
         literal v = parse_tensor(attributes.at("value").tensor());
-        auto l0 = prog.add_literal(v);
-        size_t num_axes = l0->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            std::vector<int64_t> transpose_axes = get_axes(num_axes);
-            reorder_data(transpose_axes);
-            l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
-        }
-        return l0;
+        return prog.add_literal(v);
     }
     instruction_ref
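This hunk is the fix named in the commit title: parse_constant previously transposed any literal with four or more axes at parse time, and now simply adds the literal in its original layout, leaving any conversion to the consumers of the constant (for example the to_kcxy call in parse_conv below). A hypothetical stand-alone illustration of the removed behaviour, assuming get_axes(4) produced the NHWC-to-NCHW permutation {0, 3, 1, 2} (an assumption; that helper is not shown in this diff):

    #include <iostream>
    #include <vector>

    int main()
    {
        std::vector<int> tf_dims = {1, 5, 5, 16}; // constant stored as N, H, W, C
        std::vector<int> perm    = {0, 3, 1, 2};  // assumed get_axes(4) result
        std::vector<int> old_dims(4);
        for(int i = 0; i < 4; i++)
            old_dims[i] = tf_dims[perm[i]];       // old: literal transposed to {1, 16, 5, 5}
        // new behaviour: the literal keeps {1, 5, 5, 16} and is converted,
        // if at all, only where it is used
        std::cout << old_dims[1] << '\n';         // 16
    }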
@@ -369,7 +366,7 @@ struct tf_parser
             op.dilation[0] = dilation[2];
             op.dilation[1] = dilation[3];
         }
-        return prog.add_instruction(op, {to_nchw(args[0]), to_kcxy(to_nchw(args[1]))});
+        return prog.add_instruction(op, {args[0], to_kcxy(args[1])});
     }
     instruction_ref parse_depthwiseconv(const std::string&,
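This change pairs with the Conv2D registration above: since Conv2D no longer passes false, the add_op wrapper already delivers the arguments in NCHW, so parse_conv now applies only the to_kcxy reordering to the weight tensor instead of converting both inputs itself.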
@@ -487,7 +484,8 @@ struct tf_parser
                        args.end(),
                        std::back_inserter(unsqueezed_args),
                        [&](instruction_ref arg) { return prog.add_instruction(op::unsqueeze{{axis}}, arg); });
         return to_nhwc(prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
     }
     instruction_ref
@@ -514,7 +512,7 @@ struct tf_parser
             pads[i + ndims] = pad_per_dim[i].second;
         }
         op.pads = pads;
-        return prog.add_instruction(op, args.front());
+        return to_nhwc(prog.add_instruction(op, args.front()));
     }
     instruction_ref parse_pooling(const std::string& name,
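Two things are visible in this last hunk: the pads vector stores all of the leading pads followed by all of the trailing pads ({before_0, ..., before_n-1, after_0, ..., after_n-1}), and the padded result is now passed back through to_nhwc like the other wrapped operators. A small stand-alone sketch of that pads layout (illustrative only, not MIGraphX code):

    #include <cstddef>
    #include <iostream>
    #include <utility>
    #include <vector>

    int main()
    {
        // TF-style (before, after) padding pairs for a 4-D NHWC tensor
        std::vector<std::pair<int, int>> pad_per_dim = {{0, 0}, {1, 2}, {1, 2}, {0, 0}};
        std::size_t ndims = pad_per_dim.size();
        std::vector<int> pads(2 * ndims);
        for(std::size_t i = 0; i < ndims; i++)
        {
            pads[i]         = pad_per_dim[i].first;  // pad added before dim i
            pads[i + ndims] = pad_per_dim[i].second; // pad added after dim i
        }
        for(int p : pads)
            std::cout << p << ' ';                   // 0 1 1 0 0 2 2 0
        std::cout << '\n';
    }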