Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
634ea0f2
Commit
634ea0f2
authored
Jan 26, 2019
by
Khalique
Browse files
continued tf pb progress, adjusting dims for conv
parent
b12844ec
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
12 deletions
+41
-12
src/tf/read_tf.cpp
src/tf/read_tf.cpp
+1
-1
src/tf/tf.cpp
src/tf/tf.cpp
+40
-11
No files found.
src/tf/read_tf.cpp
View file @
634ea0f2
...
@@ -7,7 +7,7 @@ int main(int argc, char const* argv[])
...
@@ -7,7 +7,7 @@ int main(int argc, char const* argv[])
bool
is_nhwc
=
true
;
bool
is_nhwc
=
true
;
if
(
argc
>
2
)
if
(
argc
>
2
)
{
{
if
(
argv
[
2
]
==
"nchw"
)
if
(
strcmp
(
argv
[
2
]
,
"nchw"
)
==
0
)
is_nhwc
=
false
;
is_nhwc
=
false
;
}
}
std
::
string
file
=
argv
[
1
];
std
::
string
file
=
argv
[
1
];
...
...
src/tf/tf.cpp
View file @
634ea0f2
...
@@ -50,10 +50,23 @@ struct tf_parser
...
@@ -50,10 +50,23 @@ struct tf_parser
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
}
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
// Multi output op
template
<
class
F
>
void
add_multi_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
}
template
<
class
F
>
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
void
add_mem_op
(
std
::
string
name
,
F
f
)
{
{
ops
.
emplace
(
name
,
[
=
](
auto
&&
...
xs
)
{
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
});
}
}
...
@@ -61,7 +74,7 @@ struct tf_parser
...
@@ -61,7 +74,7 @@ struct tf_parser
template
<
class
T
>
template
<
class
T
>
void
add_binary_op
(
std
::
string
name
,
T
x
)
void
add_binary_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
return
add_broadcastable_binary_op
(
args
[
0
],
args
[
1
],
x
);
return
add_broadcastable_binary_op
(
args
[
0
],
args
[
1
],
x
);
...
@@ -115,7 +128,7 @@ struct tf_parser
...
@@ -115,7 +128,7 @@ struct tf_parser
template
<
class
T
>
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
void
add_generic_op
(
std
::
string
name
,
T
x
)
{
{
ops
.
emplace
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
add_op
(
name
,
[
this
,
x
](
attribute_map
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
return
prog
.
add_instruction
(
x
,
args
);
});
});
}
}
...
@@ -125,7 +138,7 @@ struct tf_parser
...
@@ -125,7 +138,7 @@ struct tf_parser
{
{
float
epsilon
=
1e-4
f
;
float
epsilon
=
1e-4
f
;
float
momentum
=
1.
f
;
float
momentum
=
1.
f
;
op
::
batch_norm_inference
::
bn_infer_mode_t
bn_mode
=
op
::
batch_norm_inference
::
spatial
;
op
::
batch_norm_inference
::
bn_infer_mode_t
bn_mode
=
op
::
batch_norm_inference
::
per_activation
;
if
(
contains
(
attributes
,
"epsilon"
))
if
(
contains
(
attributes
,
"epsilon"
))
{
{
epsilon
=
attributes
.
at
(
"epsilon"
).
f
();
epsilon
=
attributes
.
at
(
"epsilon"
).
f
();
...
@@ -182,16 +195,32 @@ struct tf_parser
...
@@ -182,16 +195,32 @@ struct tf_parser
}
}
if
(
contains
(
attributes
,
"strides"
))
if
(
contains
(
attributes
,
"strides"
))
{
{
copy
(
attributes
.
at
(
"strides"
).
list
().
i
(),
op
.
stride
.
begin
());
std
::
vector
<
std
::
size_t
>
stride
(
4
);
copy
(
attributes
.
at
(
"strides"
).
list
().
i
(),
stride
.
begin
());
if
(
stride
.
size
()
!=
4
)
{
MIGRAPHX_THROW
(
"stride should have 4 values"
);
}
op
.
stride
[
0
]
=
stride
[
0
];
op
.
stride
[
1
]
=
stride
[
3
];
op
.
stride
[
2
]
=
stride
[
1
];
op
.
stride
[
3
]
=
stride
[
2
];
}
}
if
(
contains
(
attributes
,
"dilations"
))
if
(
contains
(
attributes
,
"dilations"
))
{
{
copy
(
attributes
.
at
(
"dilations"
).
list
().
i
(),
op
.
dilation
.
begin
());
std
::
vector
<
std
::
size_t
>
dilation
(
4
);
copy
(
attributes
.
at
(
"dilations"
).
list
().
i
(),
dilation
.
begin
());
if
(
dilation
.
size
()
!=
4
)
{
MIGRAPHX_THROW
(
"dilation should have 4 values"
);
}
op
.
dilation
[
0
]
=
dilation
[
0
];
op
.
dilation
[
1
]
=
dilation
[
3
];
op
.
dilation
[
2
]
=
dilation
[
1
];
op
.
dilation
[
3
]
=
dilation
[
2
];
}
}
auto
l0
=
args
[
1
];
auto
l0
=
prog
.
add_instruction
(
op
::
transpose
{{
2
,
3
,
0
,
1
}},
args
[
1
]);
if
(
is_nhwc
)
l0
=
prog
.
add_instruction
(
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l0
);
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
l0
});
return
prog
.
add_instruction
(
op
,
{
args
[
0
],
l0
});
}
}
...
@@ -245,7 +274,7 @@ struct tf_parser
...
@@ -245,7 +274,7 @@ struct tf_parser
}
}
else
else
{
{
throw
std
::
runtime_error
(
"Failed reading"
);
throw
std
::
runtime_error
(
"Failed reading
tf file
"
);
}
}
}
}
...
@@ -268,7 +297,7 @@ struct tf_parser
...
@@ -268,7 +297,7 @@ struct tf_parser
}
}
for
(
auto
&&
p
:
nodes
)
for
(
auto
&&
p
:
nodes
)
{
{
this
->
parse_node
(
get_name
(
p
.
second
)
);
this
->
parse_node
(
p
.
first
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment