Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8bf2940a
Commit
8bf2940a
authored
Jan 29, 2019
by
Khalique
Browse files
continued progress, fixed issues with conv, pooling, added reshape
parent
cbd244d1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
11 deletions
+25
-11
src/tf/tf.cpp
src/tf/tf.cpp
+25
-11
No files found.
src/tf/tf.cpp
View file @
8bf2940a
...
@@ -48,6 +48,8 @@ struct tf_parser
...
@@ -48,6 +48,8 @@ struct tf_parser
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
// add_mem_op("Reshape", &tf_parser::parse_reshape);
}
}
template
<
class
F
>
template
<
class
F
>
...
@@ -156,7 +158,8 @@ struct tf_parser
...
@@ -156,7 +158,8 @@ struct tf_parser
std
::
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
std
::
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
std
::
size_t
axis
=
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
();
std
::
size_t
axis
=
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
();
op
::
concat
op
{
axis
};
op
::
concat
op
{
axis
};
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
// return only first N arguments (assuming last index is the axis value)
return
prog
.
add_instruction
(
op
,
std
::
vector
<
instruction_ref
>
(
args
.
begin
(),
args
.
begin
()
+
axis
));
}
}
instruction_ref
parse_constant
(
const
std
::
string
&
,
instruction_ref
parse_constant
(
const
std
::
string
&
,
...
@@ -232,9 +235,14 @@ struct tf_parser
...
@@ -232,9 +235,14 @@ struct tf_parser
op
.
dilation
[
1
]
=
dilation
[
3
];
op
.
dilation
[
1
]
=
dilation
[
3
];
}
}
}
}
auto
l0
=
args
[
0
];
auto
l0
=
prog
.
add_instruction
(
op
::
transpose
{{
2
,
3
,
0
,
1
}},
args
[
1
]);
if
(
l0
->
name
()
==
"@param"
)
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
l0
});
{
if
(
is_nhwc
)
l0
=
prog
.
add_instruction
(
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l0
);
}
auto
l1
=
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
args
[
1
]);
return
prog
.
add_instruction
(
op
,
{
l0
,
l1
});
}
}
instruction_ref
parse_pooling
(
const
std
::
string
&
name
,
instruction_ref
parse_pooling
(
const
std
::
string
&
name
,
...
@@ -258,7 +266,7 @@ struct tf_parser
...
@@ -258,7 +266,7 @@ struct tf_parser
if
(
contains
(
attributes
,
"strides"
))
if
(
contains
(
attributes
,
"strides"
))
{
{
std
::
vector
<
std
::
size_t
>
stride
;
std
::
vector
<
std
::
size_t
>
stride
;
copy
(
attributes
.
at
(
"stride"
).
list
().
i
(),
std
::
back_inserter
(
stride
));
copy
(
attributes
.
at
(
"stride
s
"
).
list
().
i
(),
std
::
back_inserter
(
stride
));
if
(
stride
.
size
()
!=
4
)
if
(
stride
.
size
()
!=
4
)
{
{
MIGRAPHX_THROW
(
"strides should have 4 values"
);
MIGRAPHX_THROW
(
"strides should have 4 values"
);
...
@@ -297,6 +305,17 @@ struct tf_parser
...
@@ -297,6 +305,17 @@ struct tf_parser
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
,
std
::
move
(
args
));
}
}
instruction_ref
parse_reshape
(
const
std
::
string
&
,
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
op
::
reshape
op
;
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"reshape needs 2 arguments (input, new_shape)"
);
literal
s
=
args
[
1
]
->
get_literal
();
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
.
add_instruction
(
op
,
args
[
0
]);
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
{
{
tensorflow
::
GraphDef
graph
;
tensorflow
::
GraphDef
graph
;
...
@@ -321,11 +340,6 @@ struct tf_parser
...
@@ -321,11 +340,6 @@ struct tf_parser
std
::
vector
<
size_t
>
dims
=
parse_dims
(
input_attrs
.
at
(
"shape"
).
shape
());
std
::
vector
<
size_t
>
dims
=
parse_dims
(
input_attrs
.
at
(
"shape"
).
shape
());
shape
s
=
shape
{
shape_type
,
dims
};
shape
s
=
shape
{
shape_type
,
dims
};
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
if
(
is_nhwc
)
{
// nhwc to nchw
prog
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
instructions
[
name
]);
}
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
{
{
...
@@ -339,7 +353,7 @@ struct tf_parser
...
@@ -339,7 +353,7 @@ struct tf_parser
{
{
auto
&&
node
=
nodes
.
at
(
name
);
auto
&&
node
=
nodes
.
at
(
name
);
std
::
vector
<
instruction_ref
>
args
;
std
::
vector
<
instruction_ref
>
args
;
std
::
cout
<<
name
<<
std
::
endl
;
//
std::cout << name << std::endl;
for
(
auto
&&
input
:
node
.
input
())
for
(
auto
&&
input
:
node
.
input
())
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment