gaoqiong / MIGraphX — Commit bc80dee8 (unverified)

Authored Jul 08, 2019 by mvermeulen; committed by GitHub, Jul 08, 2019.

Merge pull request #265 from ROCmSoftwarePlatform/tf-transpose

Transpose each layer of TF

Parents: 8d5a2210, 2ee59b2b

Changes: 12 changed files with 940 additions and 265 deletions (+940 −265)
src/driver/main.cpp                  +24   −0
src/include/migraphx/functional.hpp  +17   −0
src/include/migraphx/matcher.hpp     +171  −45
src/simplify_reshapes.cpp            +179  −56
src/targets/gpu/device/concat.cpp    +5    −3
src/targets/gpu/fuse_ops.cpp         +40   −14
src/targets/gpu/target.cpp           +4    −2
src/tf/tf.cpp                        +108  −100
test/matcher.cpp                     +189  −0
test/shape_test.cpp                  +10   −1
test/simplify_reshapes_test.cpp      +141  −0
test/tf/tf_test.cpp                  +52   −44
src/driver/main.cpp

```diff
@@ -7,6 +7,14 @@
 #include <migraphx/onnx.hpp>
 #include <migraphx/stringutils.hpp>
 #include <migraphx/pass_manager.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_identity.hpp>
+#include <migraphx/eliminate_pad.hpp>
+#include <migraphx/propagate_constant.hpp>
+#include <migraphx/simplify_algebra.hpp>
+#include <migraphx/simplify_reshapes.hpp>
 
 namespace migraphx {
 namespace driver {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -17,6 +25,7 @@ struct loader
     std::string file_type;
     bool is_nhwc  = true;
     unsigned trim = 0;
+    bool optimize = false;
 
     void parse(argument_parser& ap)
     {
@@ -26,6 +35,7 @@ struct loader
         ap(is_nhwc, {"--nhwc"}, ap.help("Treat tensorflow format as nhwc"), ap.set_value(true));
         ap(is_nhwc, {"--nchw"}, ap.help("Treat tensorflow format as nchw"), ap.set_value(false));
         ap(trim, {"--trim", "-t"}, ap.help("Trim instructions from the end"));
+        ap(optimize, {"--optimize"}, ap.help("Optimize when reading"), ap.set_value(true));
     }
 
     program load()
@@ -48,6 +58,20 @@ struct loader
             auto last = std::prev(p.end(), trim);
             p.remove_instructions(last, p.end());
         }
+        if(optimize)
+            migraphx::run_passes(p,
+                                 {
+                                     migraphx::eliminate_identity{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::simplify_algebra{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::simplify_reshapes{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::propagate_constant{},
+                                     migraphx::dead_code_elimination{},
+                                     migraphx::eliminate_pad{},
+                                     migraphx::dead_code_elimination{},
+                                 });
         return p;
     }
 };
```
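The new `--optimize` flag replays the simplification pipeline at load time, with `dead_code_elimination` interleaved after every rewriting pass so instructions orphaned by one pass are gone before the next pass walks the program. A minimal sketch of that interleaving pattern — plain C++ with hypothetical `program`/`pass` stand-ins, not the real MIGraphX pass manager:

```cpp
#include <algorithm>
#include <functional>
#include <vector>

// Hypothetical stand-ins: the real types are migraphx::program and the
// pass structs listed in the diff above.
struct program
{
    std::vector<int> instructions; // 0 plays the role of a dead instruction
};
using pass = std::function<void(program&)>;

// Run each rewrite, then immediately clean up, mirroring how the driver
// alternates dead_code_elimination with each simplification pass.
void run_passes(program& p, const std::vector<pass>& rewrites, const pass& cleanup)
{
    for(const auto& rewrite : rewrites)
    {
        rewrite(p); // may orphan instructions
        cleanup(p); // remove them before the next pass inspects the program
    }
}

int main()
{
    program p{{1, 2, 3}};
    pass rewrite = [](program& q) { q.instructions.push_back(0); }; // leaves a dead value
    pass dce     = [](program& q) {
        q.instructions.erase(std::remove(q.instructions.begin(), q.instructions.end(), 0),
                             q.instructions.end());
    };
    run_passes(p, {rewrite, rewrite}, dce);
    return p.instructions.size() == 3 ? 0 : 1;
}
```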
src/include/migraphx/functional.hpp

```diff
@@ -190,6 +190,23 @@ auto pop_back_args(Ts&&... xs)
     };
 }
 
+template <class T>
+struct always_f
+{
+    T x;
+    template <class... Ts>
+    constexpr T operator()(Ts&&...) const
+    {
+        return x;
+    }
+};
+
+template <class T>
+auto always(T x)
+{
+    return always_f<T>{x};
+}
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
```
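`always(x)` wraps a value in a callable that ignores its arguments and returns the value. In the matcher rework below it lets an already-computed `bool` take part in a fold whose operator expects callables, so only the not-yet-evaluated side actually runs. A self-contained sketch of that idea (the `lazy_and` here mirrors the struct introduced in matcher.hpp below):

```cpp
#include <cassert>

template <class T>
struct always_f
{
    T x;
    template <class... Ts>
    constexpr T operator()(Ts&&...) const
    {
        return x;
    }
};

template <class T>
constexpr auto always(T x)
{
    return always_f<T>{x};
}

// A lazy "and" over two callables: g() never runs when f() returns false.
struct lazy_and
{
    template <class F, class G>
    bool operator()(F f, G g) const
    {
        return f() and g();
    }
};

int main()
{
    bool expensive_ran = false;
    auto expensive     = [&] {
        expensive_ran = true;
        return true;
    };
    // Feed an already-known false through always(): the right side is skipped.
    assert(not lazy_and{}(always(false), expensive));
    assert(not expensive_ran);
    // always() ignores whatever arguments it is called with.
    assert(always(42)(1, "ignored", 3.0) == 42);
}
```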
src/include/migraphx/matcher.hpp

```diff
@@ -8,6 +8,7 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/config.hpp>
 #include <unordered_map>
+#include <unordered_set>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -20,6 +21,12 @@ struct matcher_context
     std::unordered_map<std::string, instruction_ref> instructions;
 
     instruction_ref not_found() const { return last; }
 
+    template <class M>
+    bool matched(M m, instruction_ref ins)
+    {
+        return m.match(*this, ins) != this->not_found();
+    }
+
     private:
     instruction_ref last;
 };
@@ -205,74 +212,147 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
     return result;
 }
 
+/// Find matches for an instruction in the program
+template <class... Ms>
+void find_matches(program& p, instruction_ref ins, Ms&&... ms)
+{
+    bool match = false;
+    each_args(
+        [&](auto&& m) {
+            if(match)
+                return;
+            auto r = match_instruction(p, ins, m.matcher());
+            if(r.result == p.end())
+                return;
+            m.apply(p, r);
+            match = true;
+        },
+        ms...);
+}
+
 /// Find matches in a program
 template <class... Ms>
 void find_matches(program& p, Ms&&... ms)
 {
     for(auto ins : iterator_for(p))
     {
-        bool match = false;
-        each_args(
-            [&](auto&& m) {
-                if(match)
-                    return;
-                auto r = match_instruction(p, ins, m.matcher());
-                if(r.result == p.end())
-                    return;
-                m.apply(p, r);
-                match = true;
-            },
-            ms...);
+        find_matches(p, ins, ms...);
     }
 }
 
-template <class... Ts>
-auto all_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x and y.match(ctx, ins) != ctx.not_found();
-        })(true, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
-
-template <class... Ts>
-auto none_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x and y.match(ctx, ins) == ctx.not_found();
-        })(true, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
-
-template <class... Ts>
-auto any_of(Ts... ms)
-{
-    return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
-        bool matches = fold([&](auto x, auto y) {
-            return x or y.match(ctx, ins) != ctx.not_found();
-        })(false, ms...);
-        if(matches)
-            return ins;
-        return ctx.not_found();
-    });
-}
+struct lazy_and
+{
+    template <class F, class G>
+    bool operator()(F f, G g) const
+    {
+        return f() and g();
+    }
+};
+
+struct lazy_or
+{
+    template <class F, class G>
+    bool operator()(F f, G g) const
+    {
+        return f() or g();
+    }
+};
+
+template <class Op, bool Start, bool Matches>
+struct match_fold_f
+{
+    template <class... Ms>
+    static bool fold_matchers(matcher_context& ctx, instruction_ref ins, Ms... ms)
+    {
+        Op op;
+        auto matched = [&](auto m) {
+            return [=, &ctx] { return ctx.matched(m, ins); };
+        };
+        return fold([&](auto x, auto y) { return op(always(x), matched(y)); })(Start, ms...);
+    }
+
+    template <class Pack>
+    static bool fold_matchers_pack(matcher_context& ctx, instruction_ref ins, Pack p)
+    {
+        return p([&](auto... ms) { return match_fold_f::fold_matchers(ctx, ins, ms...); });
+    }
+
+    template <class... Ts>
+    auto operator()(Ts... ms) const
+    {
+        return make_bf_matcher([=](matcher_context& ctx, instruction_ref ins) {
+            bool matches = match_fold_f::fold_matchers(ctx, ins, ms...);
+            if(matches == Matches)
+                return ins;
+            return ctx.not_found();
+        });
+    }
+
+    template <class Selector>
+    auto operator[](Selector select) const
+    {
+        return [=](auto... ms) {
+            // Workaround ICE on gcc by packing matchers into an object
+            auto mpack = pack(ms...);
+            return make_bf_matcher([=](matcher_context& ctx, instruction_ref start) {
+                Op op;
+                bool matches = Start;
+                select(start, [&](auto ins) {
+                    auto fm = [&] { return match_fold_f::fold_matchers_pack(ctx, ins, mpack); };
+                    matches = op(always(matches), fm);
+                });
+                if(matches == Matches)
+                    return start;
+                return ctx.not_found();
+            });
+        };
+    }
+};
+
+const constexpr auto all_of  = match_fold_f<lazy_and, true, true>{};
+const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
+const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};
+
+inline auto inputs()
+{
+    return [](auto ins, auto f) {
+        for(auto&& x : ins->inputs())
+            f(x);
+    };
+}
+
+inline auto outputs()
+{
+    return [](auto ins, auto f) {
+        for(auto&& x : ins->outputs())
+            f(x);
+    };
+}
 
 MIGRAPHX_PRED_MATCHER(any, instruction_ref) { return true; }
 MIGRAPHX_PRED_MATCHER(none, instruction_ref) { return false; }
 MIGRAPHX_PRED_MATCHER(standard_shape, instruction_ref ins) { return ins->get_shape().standard(); }
+MIGRAPHX_PRED_MATCHER(not_standard_shape, instruction_ref ins)
+{
+    return not ins->get_shape().standard();
+}
+MIGRAPHX_PRED_MATCHER(broadcast_shape, instruction_ref ins)
+{
+    return ins->get_shape().broadcasted();
+}
+MIGRAPHX_PRED_MATCHER(transpose_shape, instruction_ref ins)
+{
+    return ins->get_shape().transposed();
+}
+MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
+{
+    if(ins->inputs().empty())
+        return false;
+    auto s = ins->inputs().front()->get_shape();
+    return std::all_of(
+        ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; });
+}
 
 MIGRAPHX_BASIC_MATCHER(output, matcher_context& ctx, instruction_ref ins)
 {
     if(ins->outputs().size() == 1)
@@ -289,10 +369,39 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
     return ctx.not_found();
 }
 
+template <class... Ms>
+auto skip_output(Ms... ms)
+{
+    auto m = any_of(ms...);
+    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref start) {
+        return fix<instruction_ref>([&](auto self, auto ins) {
+            if(ins->outputs().size() == 1)
+            {
+                auto next = ins->outputs().front();
+                if(ctx.matched(m, next))
+                {
+                    auto skipped_next = self(next);
+                    if(skipped_next != ctx.not_found())
+                        return skipped_next;
+                }
+                return next;
+            }
+            return ctx.not_found();
+        })(start);
+    });
+}
+
-inline auto name(std::string name)
+inline auto name(std::string s)
 {
-    return make_basic_pred_matcher(
-        [=, name = std::move(name)](instruction_ref ins) { return ins->name() == name; });
+    return make_basic_pred_matcher(
+        [=, s = std::move(s)](instruction_ref ins) { return ins->name() == s; });
 }
+
+inline auto name(std::unordered_set<std::string> names)
+{
+    return make_basic_pred_matcher([=, names = std::move(names)](instruction_ref ins) {
+        return names.count(ins->name()) > 0;
+    });
+}
 
 inline auto nargs(std::size_t n)
@@ -338,6 +447,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
     };
 }
 
+template <class M>
+auto same_shape(M m)
+{
+    return make_basic_fun_matcher([=](matcher_context& ctx, instruction_ref ins) {
+        auto i = m.match(ctx, ins);
+        if(i != ctx.not_found() and i->get_shape() == ins->get_shape())
+            return ins;
+        return ctx.not_found();
+    });
+}
+
+template <class... Ms>
+auto same_shape(Ms... ms)
+{
+    return all_of(same_shape(ms)...);
+}
+
 } // namespace match
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
```
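These changes turn `all_of`, `any_of`, and `none_of` from free functions into `match_fold_f` instances, which adds the subscript form: `m[selector](ms...)` folds the matchers over every instruction the selector visits, so `any_of[match::inputs()](m)` reads "some input matches m" and `all_of[match::outputs()](m)` reads "every output matches m". A sketch of how the new combinators compose — these exact patterns appear in src/simplify_reshapes.cpp below, and the snippet assumes the headers from this commit:

```cpp
// A sketch only; compiles against the matcher.hpp introduced above.
#include <migraphx/matcher.hpp>

namespace match = migraphx::match;

// "a concat whose inputs all share one shape, and every input is transposed"
// — the trigger for find_concat_transpose in simplify_reshapes.cpp below.
inline auto concat_of_transposes()
{
    return match::name("concat")(match::same_input_shapes(),
                                 match::all_of[match::inputs()](match::transpose_shape()));
}

// "a transpose not consumed, even through contiguous copies, by another
// transpose" — i.e. the outermost transpose of a chain (used by find_transpose).
inline auto last_transpose()
{
    return match::name("transpose")(match::none_of(
        match::skip_output(match::name("contiguous"))(match::name("transpose"))));
}
```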
src/simplify_reshapes.cpp

```diff
@@ -2,14 +2,17 @@
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/op/as_shape.hpp>
+#include <migraphx/op/transpose.hpp>
+#include <migraphx/op/concat.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
+#include <migraphx/matcher.hpp>
 #include <unordered_set>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 
-bool is_reshaper(instruction_ref ins)
+const auto& reshaper_names()
 {
     // clang-format off
     static const std::unordered_set<std::string> names = {
@@ -19,17 +22,10 @@ bool is_reshaper(instruction_ref ins)
         "unsqueeze"
     };
     // clang-format on
-    return contains(names, ins->name());
+    return names;
 }
 
-bool is_transpose_output(instruction_ref ins)
-{
-    if(ins->outputs().size() != 1)
-        return false;
-    if(ins->outputs().front()->name() == "contiguous")
-        return is_transpose_output(ins->outputs().front());
-    return ins->outputs().front()->name() == "transpose";
-}
+bool is_reshaper(instruction_ref ins) { return contains(reshaper_names(), ins->name()); }
 
 instruction_ref find_transpose_input(instruction_ref ins)
 {
@@ -42,62 +38,189 @@ instruction_ref find_transpose_input(instruction_ref ins)
     return ins;
 }
 
+auto get_transpose_dims(instruction_ref ins)
+{
+    return any_cast<const op::transpose&>(ins->get_operator()).dims;
+}
+
+std::vector<int64_t> reorder_dims(std::vector<int64_t> dims, std::vector<int64_t> permutation)
+{
+    std::vector<int64_t> result(dims.size());
+    assert(dims.size() == permutation.size());
+    for(std::size_t i = 0; i < dims.size(); i++)
+    {
+        result[i] = dims[permutation[i]];
+    }
+    return result;
+}
+
+bool is_no_transpose(const std::vector<int64_t>& dims)
+{
+    if(dims.empty())
+        return true;
+    if(dims.front() != 0)
+        return false;
+    return std::adjacent_find(
+               dims.begin(), dims.end(), [](auto x, auto y) { return (y - x) != 1; }) == dims.end();
+}
+
+template <class Vector, class Op>
+std::vector<int64_t> sort_permutation(const Vector& data, Op op)
+{
+    std::vector<std::int64_t> result(data.size());
+    std::iota(result.begin(), result.end(), 0);
+    std::sort(result.begin(), result.end(), [&](auto x, auto y) { return op(data[x], data[y]); });
+    return result;
+}
+
+std::vector<int64_t> invert_permutation(const std::vector<int64_t>& permutation)
+{
+    return sort_permutation(permutation, std::less<>{});
+}
+
+std::vector<int64_t> find_permutation(const shape& s)
+{
+    return sort_permutation(s.strides(), std::greater<>{});
+}
+
+struct find_reshaper
+{
+    auto matcher() const
+    {
+        return match::name(reshaper_names())(
+            match::any_of[match::outputs()](match::name(reshaper_names())));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        // Gather reshapes
+        std::vector<instruction_ref> reshapes{ins};
+        while(is_reshaper(reshapes.back()))
+        {
+            assert(!reshapes.back()->inputs().empty());
+            assert(p.has_instruction(reshapes.back()->inputs().front()));
+            auto input = reshapes.back()->inputs().front();
+            reshapes.push_back(input);
+        }
+
+        std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
+        for(auto start : iterator_for(reshapes))
+        {
+            auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
+                return i->get_shape() == (*start)->get_shape() and i != (*start);
+            });
+            if(last != reshapes.rend())
+            {
+                r = std::make_pair(*start, *last);
+                break;
+            }
+        }
+        if(r.first != r.second)
+        {
+            p.replace_instruction(r.first, r.second);
+        }
+    }
+};
+
+struct find_nop_reshapes
+{
+    auto matcher() const
+    {
+        auto reshapes = reshaper_names();
+        reshapes.insert("transpose");
+        reshapes.insert("slice");
+        return match::name(reshapes)(match::same_shape(match::arg(0)));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        p.replace_instruction(ins, ins->inputs().front());
+    }
+};
+
+struct find_transpose
+{
+    auto matcher() const
+    {
+        return match::name("transpose")(match::none_of(
+            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        auto x   = ins;
+        auto t   = ins;
+        std::vector<std::int64_t> dims(ins->get_shape().lens().size());
+        std::iota(dims.begin(), dims.end(), 0);
+        do
+        {
+            dims = reorder_dims(get_transpose_dims(t), dims);
+            x    = t;
+            t    = find_transpose_input(x);
+        } while(x != t and t->name() == "transpose");
+        if(t == ins or t->name() != "transpose")
+            return;
+        if(is_no_transpose(dims))
+        {
+            p.replace_instruction(ins, t->inputs().front());
+        }
+        else
+        {
+            p.replace_instruction(ins, op::transpose{{dims}}, t->inputs().front());
+        }
+    }
+};
+
+struct find_concat_transpose
+{
+    auto matcher() const
+    {
+        return match::name("concat")(match::same_input_shapes(),
+                                     match::all_of[match::inputs()](match::transpose_shape()));
+    }
+
+    void apply(program& p, const match::matcher_result& mr) const
+    {
+        auto ins = mr.result;
+        auto s   = ins->inputs().front()->get_shape();
+        assert(s.transposed());
+        auto op           = any_cast<op::concat>(ins->get_operator());
+        auto permutation  = find_permutation(s);
+        auto ipermutation = invert_permutation(permutation);
+        op.axis           = ipermutation[op.axis];
+
+        std::vector<instruction_ref> inputs;
+        std::transform(
+            ins->inputs().begin(), ins->inputs().end(), std::back_inserter(inputs), [&](auto i) {
+                if(i->name() == "transpose" and i->inputs().front()->get_shape().standard())
+                    return i->inputs().front();
+                return p.insert_instruction(ins, op::transpose{permutation}, i);
+            });
+        auto concat = p.insert_instruction(ins, op, inputs);
+        auto t      = p.insert_instruction(ins, op::transpose{ipermutation}, concat);
+        assert(ins->get_shape().lens() == t->get_shape().lens());
+        p.replace_instruction(ins, t);
+    }
+};
+
 void simplify_reshapes::apply(program& p) const
 {
     auto end = std::prev(p.end());
     for(auto ins : iterator_for(p))
     {
         if(ins == end and ins->name() == "contiguous")
             continue;
         // Skip possible dead instructions
         if(ins->outputs().empty() and ins != end)
             continue;
-        if(is_reshaper(ins))
-        {
-            if(std::any_of(ins->outputs().begin(), ins->outputs().end(), &is_reshaper))
-                continue;
-            // Gather reshapes
-            std::vector<instruction_ref> reshapes{ins};
-            while(is_reshaper(reshapes.back()))
-            {
-                assert(!reshapes.back()->inputs().empty());
-                assert(p.has_instruction(reshapes.back()->inputs().front()));
-                auto input = reshapes.back()->inputs().front();
-                reshapes.push_back(input);
-            }
-            std::pair<instruction_ref, instruction_ref> r{p.end(), p.end()};
-            for(auto start : iterator_for(reshapes))
-            {
-                auto last = std::find_if(reshapes.rbegin(), reshapes.rend(), [&](auto&& i) {
-                    return i->get_shape() == (*start)->get_shape() and i != (*start);
-                });
-                if(last != reshapes.rend())
-                {
-                    r = std::make_pair(*start, *last);
-                    break;
-                }
-            }
-            if(r.first != r.second)
-                p.replace_instruction(r.first, r.second);
-        }
-        else if(ins->name() == "transpose")
-        {
-            if(is_transpose_output(ins))
-                continue;
-            auto x = ins;
-            auto t = ins;
-            do
-            {
-                x = t;
-                t = find_transpose_input(x);
-            } while(x != t and t->name() == "transpose");
-            if(t == ins or t->name() != "transpose")
-                continue;
-            p.replace_instruction(ins, t->inputs().front());
-        }
+        match::find_matches(
+            p, ins, find_nop_reshapes{}, find_reshaper{}, find_transpose{}, find_concat_transpose{});
     }
 }
```
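A worked example of what `find_transpose` now computes: walking from the outermost transpose inward, `dims` accumulates the composite permutation via `reorder_dims`, and if the composite is the identity the whole chain collapses to its input. A self-contained sketch (same `reorder_dims` semantics as above; the input permutations match the `transpose_partial1` test later in this commit):

```cpp
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Same semantics as reorder_dims above: result[i] = dims[permutation[i]].
std::vector<int64_t> reorder_dims(const std::vector<int64_t>& dims,
                                  const std::vector<int64_t>& permutation)
{
    std::vector<int64_t> result(dims.size());
    for(std::size_t i = 0; i < dims.size(); i++)
        result[i] = dims[permutation[i]];
    return result;
}

int main()
{
    // transpose{1,0,2} followed by transpose{1,2,0}, as in transpose_partial1.
    std::vector<int64_t> inner{1, 0, 2};
    std::vector<int64_t> outer{1, 2, 0};

    // Start from the identity and fold in each transpose, outermost first,
    // exactly as find_transpose::apply does.
    std::vector<int64_t> dims(3);
    std::iota(dims.begin(), dims.end(), 0);
    dims = reorder_dims(outer, dims); // composite so far: {1,2,0}
    dims = reorder_dims(inner, dims); // composite: {0,2,1}
    assert((dims == std::vector<int64_t>{0, 2, 1})); // one transpose replaces two

    // transpose{1,0,2} applied twice composes to the identity, so
    // is_no_transpose(dims) holds and the pass removes the pair entirely.
    std::vector<int64_t> twice(3);
    std::iota(twice.begin(), twice.end(), 0);
    twice = reorder_dims(inner, twice);
    twice = reorder_dims(inner, twice);
    assert((twice == std::vector<int64_t>{0, 1, 2}));
}
```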
src/targets/gpu/device/concat.cpp

```diff
@@ -20,10 +20,12 @@ argument concat(hipStream_t stream,
         auto&& arg            = args[j];
         std::size_t nelements = arg.get_shape().elements();
         auto offset           = offsets[j];
-        hip_visit_all(args.back(), arg)([&](auto output, auto input) {
+        shape arg_shape{arg.get_shape().type(), arg.get_shape().lens()};
+        hip_visit_all(args.back(), arg, arg_shape)([&](auto output, auto input, auto input_shape) {
             gs_launch(stream, nelements)([=](auto i) {
-                auto idx = output.get_shape().index(input.get_shape().multi(i));
-                output.data()[idx + offset] = input.data()[i];
+                auto input_idx              = input_shape.multi(i);
+                auto idx                    = output.get_shape().index(input_idx);
+                output.data()[idx + offset] = input[input_idx];
             });
         });
     }
```
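The fix builds `arg_shape` from the argument's type and lengths only, so it carries standard (dense, row-major) strides even when the actual input is transposed. Decoding the flat loop index against that standard shape, then indexing both tensors by the resulting multi-index, keeps the copy correct for non-contiguous inputs. A self-contained sketch of that multi-index round trip, assuming row-major strides (`shape::multi` and `shape::index` are the real MIGraphX calls):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major strides for a set of dimension lengths.
std::vector<std::size_t> make_strides(const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size() - 1; i > 0; i--)
        strides[i - 1] = strides[i] * lens[i];
    return strides;
}

// Decode a flat element index into a multi-index, as shape::multi does
// for a standard shape.
std::vector<std::size_t> multi(std::size_t i, const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> idx(lens.size());
    auto strides = make_strides(lens);
    for(std::size_t d = 0; d < lens.size(); d++)
    {
        idx[d] = i / strides[d];
        i %= strides[d];
    }
    return idx;
}

// Re-encode a multi-index against (possibly different) strides, as
// output.get_shape().index(...) does in the kernel above.
std::size_t index(const std::vector<std::size_t>& idx, const std::vector<std::size_t>& strides)
{
    std::size_t result = 0;
    for(std::size_t d = 0; d < idx.size(); d++)
        result += idx[d] * strides[d];
    return result;
}

int main()
{
    std::vector<std::size_t> lens{2, 3};
    // Element 4 of a 2x3 tensor is row 1, column 1.
    assert((multi(4, lens) == std::vector<std::size_t>{1, 1}));
    // Mapped into a wider 2x5 output: offset 1*5 + 1 = 6.
    assert(index({1, 1}, make_strides({2, 5})) == 6);
}
```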
src/targets/gpu/fuse_ops.cpp

```diff
@@ -200,12 +200,33 @@ struct hip_add_relu
     }
 };
 
+void move_broadcasted_back(std::vector<instruction_ref>& args)
+{
+    // Ensure the last arguments is the broadcasted one
+    auto it = std::find_if(
+        args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
+    if(it != args.end())
+        std::swap(*it, *std::prev(args.end(), 2));
+}
+
+void move_standard_front(std::vector<instruction_ref>& args)
+{
+    // Ensure the first arguments is the standard one
+    auto it = std::find_if(
+        args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
+    if(it != args.end())
+        std::swap(*it, args.front());
+}
+
 struct find_add_relu
 {
     auto matcher() const
     {
-        return match::name("gpu::relu")(match::arg(0)(
-            match::any_of(match::name("gpu::add"), match::name("hip::triadd")).bind("add")));
+        return match::name("gpu::relu")(
+            match::arg(0)(match::any_of(match::name("gpu::add"),
+                                        match::name("hip::triadd"),
+                                        match::any_of[match::inputs()](match::standard_shape()))
+                              .bind("add")));
     }
 
     void apply(program& p, match::matcher_result r) const
@@ -213,6 +234,9 @@ struct find_add_relu
         auto add_ins = r.instructions["add"];
         auto ins     = r.result;
         auto args    = add_ins->inputs();
+        move_standard_front(args);
+        move_broadcasted_back(args);
         // Use the allocation from the relu operator
         args.back() = ins->inputs().back();
         if(add_ins->name() == "gpu::add")
@@ -226,24 +250,26 @@ struct find_triadd
 {
     auto matcher() const
     {
-        return match::name("gpu::add")(match::either_arg(0, 1)(
-            match::name("gpu::add").bind("add"), match::any().bind("input")));
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::add").bind("add"),
+            match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
     }
 
     void apply(program& p, match::matcher_result r) const
     {
         auto add_ins   = r.instructions["add"];
         auto input_ins = r.instructions["input"];
         auto ins       = r.result;
         auto args      = add_ins->inputs();
+        assert(add_ins != input_ins);
         auto is_broadcasted = [](auto arg) { return arg->get_shape().broadcasted(); };
         if(std::count_if(args.begin(), args.end(), is_broadcasted) > 1)
             return;
         args.insert(args.begin(), input_ins);
-        // Ensure the last arguments is the broadcasted one
-        auto it = std::find_if(args.begin(), args.end(), is_broadcasted);
-        if(it != args.end())
-            std::swap(*it, *std::prev(args.end(), 2));
+        move_standard_front(args);
+        move_broadcasted_back(args);
         args.back() = ins->inputs().back();
         p.replace_instruction(ins, hip_triadd{}, args);
     }
@@ -402,8 +428,8 @@ void fuse_ops::apply(program& p) const
     // clang-format off
     match::find_matches(p, find_triadd{});
     match::find_matches(p,
-        // find_conv_bias_relu{ctx},
-        // find_conv_bias{ctx},
+        find_conv_bias_relu{ctx},
+        find_conv_bias{ctx},
         find_add_relu{}
     );
    // clang-format on
```
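`move_standard_front` and `move_broadcasted_back` normalize argument order before fusing, so the fused kernel can assume the dense operand comes first and the broadcast operand sits just before the trailing allocation argument. A self-contained sketch of only that ordering logic, with a stand-in argument type instead of `instruction_ref`:

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <string>
#include <vector>

// Stand-in for instruction_ref and its shape queries.
struct arg_t
{
    std::string name;
    bool standard    = false;
    bool broadcasted = false;
};

void move_broadcasted_back(std::vector<arg_t>& args)
{
    // Put the broadcasted argument just before the last slot
    // (the last slot is reserved for the output allocation).
    auto it = std::find_if(
        args.begin(), args.end(), [](const arg_t& a) { return a.broadcasted; });
    if(it != args.end())
        std::swap(*it, *std::prev(args.end(), 2));
}

void move_standard_front(std::vector<arg_t>& args)
{
    // Put the densely laid out argument first.
    auto it =
        std::find_if(args.begin(), args.end(), [](const arg_t& a) { return a.standard; });
    if(it != args.end())
        std::swap(*it, args.front());
}

int main()
{
    std::vector<arg_t> args{{"bias", false, true}, {"x", true, false}, {"alloc"}};
    move_standard_front(args);
    move_broadcasted_back(args);
    assert(args[0].name == "x");     // standard operand first
    assert(args[1].name == "bias");  // broadcasted operand second-to-last
    assert(args[2].name == "alloc"); // allocation stays last
}
```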
src/targets/gpu/target.cpp

```diff
@@ -36,6 +36,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
     // clang-format off
     return
     {
         dead_code_elimination{},
+        simplify_reshapes{},
+        dead_code_elimination{},
         eliminate_identity{},
         eliminate_pad{},
@@ -48,11 +50,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         //dead_code_elimination{},
         simplify_algebra{},
         dead_code_elimination{},
-        propagate_constant{},
-        dead_code_elimination{},
         auto_contiguous{},
         simplify_reshapes{},
         dead_code_elimination{},
+        propagate_constant{},
+        dead_code_elimination{},
         lowering{ctx},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
```
src/tf/tf.cpp

```diff
@@ -37,6 +37,48 @@ struct tf_parser
     std::unordered_map<std::string, op_func> ops;
 
+    bool should_transpose(instruction_ref ins) const
+    {
+        return is_nhwc and ins->get_shape().lens().size() == 4;
+    }
+
+    instruction_ref to_nhwc(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{0, 2, 3, 1}}, ins);
+        return ins;
+    }
+
+    instruction_ref to_nchw(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{0, 3, 1, 2}}, ins);
+        return ins;
+    }
+
+    instruction_ref to_kcxy(instruction_ref ins)
+    {
+        if(should_transpose(ins))
+            return prog.add_instruction(op::transpose{{3, 2, 0, 1}}, ins);
+        return ins;
+    }
+
+    instruction_ref make_contiguous(instruction_ref ins)
+    {
+        if(ins->get_shape().standard())
+            return ins;
+        else
+            return prog.add_instruction(op::contiguous{}, ins);
+    }
+
+    std::vector<instruction_ref> to_nchw(const std::vector<instruction_ref>& args)
+    {
+        std::vector<instruction_ref> result(args.size());
+        std::transform(
+            args.begin(), args.end(), result.begin(), [&](auto ins) { return this->to_nchw(ins); });
+        return result;
+    }
+
     std::vector<size_t> parse_axes(const attribute_map& attributes, const std::string& s) const
     {
         auto attrs = attributes.at(s).list().i();
@@ -119,59 +161,67 @@ struct tf_parser
         add_mem_op("AvgPool", &tf_parser::parse_pooling);
         add_mem_op("BiasAdd", &tf_parser::parse_biasadd);
-        add_mem_op("ConcatV2", &tf_parser::parse_concat);
+        add_mem_op("ConcatV2", &tf_parser::parse_concat, false);
         add_mem_op("Const", &tf_parser::parse_constant);
         add_mem_op("Conv2D", &tf_parser::parse_conv);
         add_mem_op("DepthwiseConv2dNative", &tf_parser::parse_depthwiseconv);
         add_mem_op("FusedBatchNorm", &tf_parser::parse_batchnorm);
-        add_mem_op("MatMul", &tf_parser::parse_matmul);
+        add_mem_op("MatMul", &tf_parser::parse_matmul, false);
         add_mem_op("MaxPool", &tf_parser::parse_pooling);
         add_mem_op("Mean", &tf_parser::parse_mean);
-        add_mem_op("Pack", &tf_parser::parse_pack);
+        add_mem_op("Pack", &tf_parser::parse_pack, false);
         add_mem_op("Pad", &tf_parser::parse_pad);
-        add_mem_op("Reshape", &tf_parser::parse_reshape);
+        add_mem_op("Reshape", &tf_parser::parse_reshape, false);
         add_mem_op("Softmax", &tf_parser::parse_softmax);
-        add_mem_op("Squeeze", &tf_parser::parse_squeeze);
+        add_mem_op("Squeeze", &tf_parser::parse_squeeze, false);
         add_mem_op("StridedSlice", &tf_parser::parse_stridedslice);
     }
 
     template <class F>
-    void add_op(std::string name, F f)
-    {
-        ops.emplace(name, f);
-    }
+    void add_op(std::string name, F f, bool transpose = true)
+    {
+        if(transpose)
+        {
+            ops.emplace(name,
+                        op_func{[=](const attribute_map& attributes,
+                                    const std::vector<instruction_ref>& args) -> instruction_ref {
+                            return to_nhwc(f(attributes, to_nchw(args)));
+                        }});
+        }
+        else
+        {
+            ops.emplace(name, f);
+        }
+    }
 
     // Multi output op
     template <class F>
     void add_multi_op(std::string name, F f)
     {
         ops.emplace(name, f);
     }
 
     template <class F>
-    void add_mem_op(std::string name, F f)
+    void add_mem_op(std::string name, F f, bool transpose = true)
     {
-        add_op(name, [=](auto&&... xs) {
-            return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
-        });
+        add_op(name,
+               [=](auto&&... xs) {
+                   return std::mem_fn(f)(*this, name, std::forward<decltype(xs)>(xs)...);
+               },
+               transpose);
     }
 
     template <class T>
     void add_binary_op(std::string name, T x)
     {
-        add_op(name, [this, x](const attribute_map& attributes, std::vector<instruction_ref> args) {
-            if(args.size() != 2)
-                MIGRAPHX_THROW("binary operators should have 2 operands");
-            auto l0 = args[1];
-            if(contains(attributes, "data_format"))
-            {
-                if(is_nhwc)
-                {
-                    l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
-                }
-            }
-            return add_broadcastable_binary_op(args[0], l0, x);
-        });
+        add_op(name,
+               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
+                   if(args.size() != 2)
+                       MIGRAPHX_THROW("binary operators should have 2 operands");
+                   // TODO
+                   // if(contains(attributes, "data_format"))
+                   // {
+                   //     if(is_nhwc)
+                   //     {
+                   //         l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
+                   //     }
+                   // }
+                   return add_broadcastable_binary_op(args[0], args[1], x);
+               },
+               false);
     }
@@ -210,20 +260,22 @@ struct tf_parser
             auto l0 = prog.add_instruction(op::multibroadcast{output_lens}, arg0);
             auto l1 = prog.add_instruction(op::multibroadcast{output_lens}, arg1);
-            return prog.add_instruction(x, l0, l1);
+            return to_nhwc(prog.add_instruction(x, to_nchw(l0), to_nchw(l1)));
         }
         else
         {
-            return prog.add_instruction(x, {arg0, arg1});
+            return to_nhwc(prog.add_instruction(x, {to_nchw(arg0), to_nchw(arg1)}));
         }
     }
 
     template <class T>
-    void add_generic_op(std::string name, T x)
+    void add_generic_op(std::string name, T x, bool transpose = true)
     {
-        add_op(name, [this, x](const attribute_map&, std::vector<instruction_ref> args) {
-            return prog.add_instruction(x, args);
-        });
+        add_op(name,
+               [this, x](const attribute_map&, std::vector<instruction_ref> args) {
+                   return prog.add_instruction(x, args);
+               },
+               transpose);
    }
 
     instruction_ref
@@ -253,7 +305,7 @@ struct tf_parser
     {
         // get index for axis within args
         size_t axis_idx = attributes.at("N").i();
-        size_t axis     = parse_axis(args[axis_idx]->eval().at<int64_t>());
+        size_t axis     = args[axis_idx]->eval().at<int64_t>();
         op::concat op{axis};
         // return only first N arguments (assuming last index is the axis value)
         return prog.add_instruction(
@@ -264,16 +316,8 @@ struct tf_parser
                    attribute_map attributes,
                    const std::vector<instruction_ref>&)
     {
         literal v = parse_tensor(attributes.at("value").tensor());
-        auto l0         = prog.add_literal(v);
-        size_t num_axes = l0->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            std::vector<int64_t> transpose_axes = get_axes(num_axes);
-            reorder_data(transpose_axes);
-            l0 = prog.add_instruction(op::transpose{transpose_axes}, l0);
-        }
-        return l0;
+        return prog.add_literal(v);
     }
 
     instruction_ref
@@ -304,22 +348,9 @@ struct tf_parser
             op.dilation[0] = dilation[2];
             op.dilation[1] = dilation[3];
         }
-        auto weights = args[1];
-        // check if weights are from a constant
-        if(weights->name() != "@param")
-        {
-            if(is_nhwc)
-            {
-                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
-            }
-            else
-            {
-                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
-            }
-        }
-        auto l0 = args[0];
+        auto weights = to_kcxy(args[1]);
+        auto l0      = args[0];
         if(contains(attributes, "padding"))
         {
             const std::string& pad_mode = attributes.at("padding").s();
@@ -368,8 +399,7 @@ struct tf_parser
                 op.padding[1] = padding[1];
             }
         }
-        return prog.add_instruction(op, {l0, weights});
+        return prog.add_instruction(op, {l0, to_kcxy(args[1])});
     }
 
     instruction_ref parse_depthwiseconv(const std::string&,
@@ -392,6 +422,8 @@ struct tf_parser
             op.stride[0] = stride[2];
             op.stride[1] = stride[3];
         }
+        auto weights = to_kcxy(args[1]);
         if(contains(attributes, "dilations"))
         {
             std::vector<size_t> dilation;
@@ -405,20 +437,6 @@ struct tf_parser
             op.dilation[0] = dilation[2];
             op.dilation[1] = dilation[3];
         }
-        auto weights = args[1];
-        // check if weights are from a constant
-        if(weights->name() != "@param")
-        {
-            if(is_nhwc)
-            {
-                weights = prog.add_instruction(op::transpose{{1, 3, 0, 2}}, args[1]);
-            }
-            else
-            {
-                weights = prog.add_instruction(op::transpose{{3, 2, 0, 1}}, args[1]);
-            }
-        }
         auto l0 = args[0];
         if(contains(attributes, "padding"))
         {
@@ -466,8 +484,8 @@ struct tf_parser
         new_weights_shape[0] = out_channels;
         new_weights_shape[1] = 1;
-        // Make sure weights are contiguous before doing reshape
-        auto cweights    = prog.add_instruction(op::contiguous{}, weights);
-        auto new_weights = prog.add_instruction(op::reshape{new_weights_shape}, cweights);
+        auto new_weights =
+            prog.add_instruction(op::reshape{new_weights_shape}, make_contiguous(weights));
         return prog.add_instruction(op, {l0, new_weights});
     }
@@ -535,16 +553,14 @@ struct tf_parser
             MIGRAPHX_THROW("TF_PARSER: axis value of " + to_string(axis) +
                            " must be smaller than input size " + to_string(input_size));
         }
-        // check if input arg needs axis to be converted to NCHW
-        if(input_size >= 4)
-            axis = parse_axis(axis);
         std::transform(args.begin(),
                        args.end(),
                        std::back_inserter(unsqueezed_args),
                        [&](instruction_ref arg) {
                            return prog.add_instruction(op::unsqueeze{{axis}}, arg);
                        });
-        return prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
+        return to_nhwc(
+            prog.add_instruction(op::concat{static_cast<size_t>(axis)}, unsqueezed_args));
    }
 
     instruction_ref
@@ -647,7 +663,7 @@ struct tf_parser
             MIGRAPHX_THROW("reshape needs 2 arguments (input, new_shape)");
         auto s = args[1]->eval();
         s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }
 
     void parse_from(std::istream& is)
@@ -678,7 +694,7 @@ struct tf_parser
                                   std::vector<instruction_ref> args)
     {
         op::squeeze op;
-        auto axes = parse_axes(attributes, "squeeze_dims");
+        auto axes = attributes.at("squeeze_dims").list().i();
         copy(axes, std::back_inserter(op.axes));
         auto args0_dims = args[0]->get_shape().lens();
         if(op.axes.empty()) // no squeeze_dims provided, remove any dim that equals 1
@@ -691,7 +707,7 @@ struct tf_parser
             }
         }
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }
 
     instruction_ref parse_stridedslice(const std::string&,
@@ -702,11 +718,6 @@ struct tf_parser
         auto starts     = args[1]->eval().get<int32_t>().to_vector();
         auto ends       = args[2]->eval().get<int32_t>().to_vector();
         size_t num_axes = args[0]->get_shape().lens().size();
-        if(num_axes >= 4)
-        {
-            reorder_data(starts);
-            reorder_data(ends);
-        }
         op.starts = std::vector<int64_t>(starts.begin(), starts.end());
         op.ends   = std::vector<int64_t>(ends.begin(), ends.end());
@@ -725,13 +736,9 @@ struct tf_parser
             if(((shrink_axis_mask >> i) & bitwise_compare) == 1)
                 squeeze_axes.push_back(i);
         }
-        if(num_axes >= 4)
-        {
-            squeeze_axes = parse_axes(squeeze_axes);
-        }
-        auto l0 = prog.add_instruction(op, args[0]);
-        return prog.add_instruction(op::squeeze{squeeze_axes}, l0);
+        auto l0 = prog.add_instruction(op, make_contiguous(args[0]));
+        return to_nhwc(prog.add_instruction(op::squeeze{squeeze_axes}, l0));
     }
 
     void parse_graph(const tensorflow::GraphDef& graph)
@@ -748,7 +755,7 @@ struct tf_parser
             reorder_data(dims);
         }
         shape s            = shape{shape_type, dims};
-        instructions[name] = prog.add_parameter(name, s);
+        instructions[name] = to_nhwc(prog.add_parameter(name, s));
     }
     for(auto&& p : nodes)
     {
@@ -1098,6 +1105,7 @@ program parse_tf(const std::string& name, bool is_nhwc)
 #else
     parser.parse_from(input);
 #endif
+    parser.to_nchw(std::prev(parser.prog.end()));
     return std::move(parser.prog);
 }
```
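This is the "transpose each layer" scheme of the PR title: `add_op` wraps a parser so operands are rotated into NCHW before each parsed op and the result rotated back to NHWC, so each layer sees the layout MIGraphX ops expect while the graph's values keep TF's layout, and the resulting back-to-back transpose pairs are left for `simplify_reshapes` to erase. A self-contained check that the permutations behave as intended — `{0,2,3,1}` and `{0,3,1,2}` are inverses, and `{3,2,0,1}` turns TF's HWCK filter layout into KCXY:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Apply a transpose permutation to labeled axes: output axis i takes
// input axis perm[i], matching migraphx::op::transpose semantics.
std::vector<std::string> permute(const std::vector<std::string>& axes,
                                 const std::vector<std::int64_t>& perm)
{
    std::vector<std::string> out(axes.size());
    for(std::size_t i = 0; i < axes.size(); i++)
        out[i] = axes[perm[i]];
    return out;
}

int main()
{
    const std::vector<std::string> nhwc{"N", "H", "W", "C"};
    // to_nchw uses {0, 3, 1, 2}; to_nhwc uses {0, 2, 3, 1}.
    auto nchw = permute(nhwc, {0, 3, 1, 2});
    assert((nchw == std::vector<std::string>{"N", "C", "H", "W"}));
    // The two permutations are inverses, so a to_nchw/to_nhwc pair is a
    // no-op that simplify_reshapes can later remove.
    assert((permute(nchw, {0, 2, 3, 1}) == nhwc));

    // to_kcxy turns TF's HWCK filter layout into the KCXY layout that
    // migraphx convolution weights use.
    const std::vector<std::string> hwck{"H", "W", "C", "K"};
    assert((permute(hwck, {3, 2, 0, 1}) == std::vector<std::string>{"K", "C", "H", "W"}));
}
```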
test/matcher.cpp

```diff
@@ -148,6 +148,56 @@ TEST_CASE(match_arg7)
     EXPECT(bool{r.result == sum});
 }
 
+TEST_CASE(match_arg8)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::all_of(match::arg(0)(match::name("@literal")),
+                                              match::arg(1)(match::name("@literal"))),
+                                match::standard_shape());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+}
+
+TEST_CASE(match_nargs1)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::nargs(2));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+}
+
+TEST_CASE(match_nargs2)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::nargs(2), match::standard_shape());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+}
+
+TEST_CASE(match_nargs3)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::all_of(match::nargs(2)));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+}
+
 TEST_CASE(match_args1)
 {
     migraphx::program p;
@@ -307,6 +357,19 @@ TEST_CASE(match_all_of2)
     EXPECT(bool{r.result == p.end()});
 }
 
+TEST_CASE(match_all_of3)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::all_of(match::all_of(
+        match::arg(0)(match::name("@literal")), match::arg(1)(match::name("@literal")))));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+}
+
 TEST_CASE(match_any_of1)
 {
     migraphx::program p;
@@ -359,6 +422,132 @@ TEST_CASE(match_none_of2)
     EXPECT(bool{r.result == p.end()});
 }
 
+TEST_CASE(match_output1)
+{
+    migraphx::program p;
+    auto one   = p.add_literal(1);
+    auto two   = p.add_literal(2);
+    auto minus = p.add_instruction(minus_op{}, two, one);
+    auto sum   = p.add_instruction(sum_op{}, minus, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("minus")(match::output(match::name("sum")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == minus});
+}
+
+TEST_CASE(match_output2)
+{
+    migraphx::program p;
+    auto one   = p.add_literal(1);
+    auto two   = p.add_literal(2);
+    auto minus = p.add_instruction(minus_op{}, two, one);
+    auto sum   = p.add_instruction(sum_op{}, minus, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("@literal")(match::output(match::name("sum")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == p.end()});
+}
+
+TEST_CASE(match_skip_output1)
+{
+    migraphx::program p;
+    auto one   = p.add_literal(1);
+    auto two   = p.add_literal(2);
+    auto minus = p.add_instruction(minus_op{}, two, one);
+    auto sum   = p.add_instruction(sum_op{}, minus, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == minus});
+}
+
+TEST_CASE(match_skip_output2)
+{
+    migraphx::program p;
+    auto one        = p.add_literal(1);
+    auto two        = p.add_literal(2);
+    auto minus      = p.add_instruction(minus_op{}, two, one);
+    auto minus_pass = p.add_instruction(pass_op{}, minus);
+    auto sum        = p.add_instruction(sum_op{}, minus_pass, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == minus});
+}
+
+TEST_CASE(match_skip_output3)
+{
+    migraphx::program p;
+    auto one         = p.add_literal(1);
+    auto two         = p.add_literal(2);
+    auto minus       = p.add_instruction(minus_op{}, two, one);
+    auto minus_pass1 = p.add_instruction(pass_op{}, minus);
+    auto minus_pass2 = p.add_instruction(pass_op{}, minus_pass1);
+    auto minus_pass3 = p.add_instruction(pass_op{}, minus_pass2);
+    auto sum         = p.add_instruction(sum_op{}, minus_pass3, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == minus});
+}
+
+TEST_CASE(match_skip_output4)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto pass = p.add_instruction(pass_op{}, one);
+    auto sum  = p.add_instruction(sum_op{}, pass, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == two});
+}
+
+TEST_CASE(match_skip_output5)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto pass = p.add_instruction(pass_op{}, one);
+    auto sum1 = p.add_instruction(sum_op{}, pass, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, one);
+    auto sum3 = p.add_instruction(sum_op{}, sum2, two);
+    p.add_instruction(pass_op{}, sum3);
+    auto m = match::name("@literal")(match::skip_output(match::name("pass"))(match::name("sum")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == p.end()});
+}
+
+TEST_CASE(match_skip_output6)
+{
+    migraphx::program p;
+    auto one   = p.add_literal(1);
+    auto two   = p.add_literal(2);
+    auto minus = p.add_instruction(minus_op{}, two, one);
+    auto sum1  = p.add_instruction(sum_op{}, minus, two);
+    auto sum2  = p.add_instruction(sum_op{}, sum1, one);
+    auto sum3  = p.add_instruction(sum_op{}, sum2, two);
+    p.add_instruction(pass_op{}, sum3);
+    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("sum")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == minus});
+}
+
+TEST_CASE(match_skip_output7)
+{
+    migraphx::program p;
+    auto one    = p.add_literal(1);
+    auto two    = p.add_literal(2);
+    auto minus1 = p.add_instruction(minus_op{}, two, one);
+    auto minus2 = p.add_instruction(minus_op{}, two, minus1);
+    auto sum    = p.add_instruction(sum_op{}, one, minus2);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("minus")(match::skip_output(match::name("pass"))(match::name("minus")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == minus1});
+}
+
 TEST_CASE(match_bind1)
 {
     migraphx::program p;
```
test/shape_test.cpp

```diff
@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
     EXPECT(not s.broadcasted());
 }
 
-TEST_CASE(test_shape_transposed)
+TEST_CASE(test_shape_transposed1)
 {
     migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}};
     EXPECT(not s.standard());
@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed1)
     EXPECT(not s.broadcasted());
 }
 
+TEST_CASE(test_shape_transposed2)
+{
+    migraphx::shape s{migraphx::shape::float_type, {1, 1, 1, 1, 2}, {2, 2, 2, 2, 1}};
+    EXPECT(s.standard());
+    EXPECT(s.packed());
+    EXPECT(not s.transposed());
+    EXPECT(not s.broadcasted());
+}
+
 TEST_CASE(test_shape_broadcasted)
 {
     migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
```
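The new `test_shape_transposed2` case looks permuted at a glance, but strides `{2,2,2,2,1}` are exactly the row-major packed strides for lengths `{1,1,1,1,2}`, since each stride over a length-1 dimension never contributes to an offset; that is why the shape must still report `standard()` and not `transposed()`. A self-contained check of that computation, assuming row-major layout:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Row-major packed strides for given lengths (what a "standard" shape has).
std::vector<std::size_t> packed_strides(const std::vector<std::size_t>& lens)
{
    std::vector<std::size_t> strides(lens.size(), 1);
    for(std::size_t i = lens.size() - 1; i > 0; i--)
        strides[i - 1] = strides[i] * lens[i];
    return strides;
}

int main()
{
    // The test_shape_transposed2 case: row-major strides for {1,1,1,1,2}
    // really are {2,2,2,2,1}, so the shape is standard, not transposed.
    assert((packed_strides({1, 1, 1, 1, 2}) == std::vector<std::size_t>{2, 2, 2, 2, 1}));

    // A genuinely transposed 2x2 (the test_shape_transposed1 case):
    // its strides {1, 2} differ from the packed {2, 1}.
    assert((packed_strides({2, 2}) == std::vector<std::size_t>{2, 1}));
}
```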
test/simplify_reshapes_test.cpp

```diff
 #include <migraphx/simplify_reshapes.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/operators.hpp>
+#include <migraphx/instruction.hpp>
 #include <basic_ops.hpp>
 #include <test.hpp>
@@ -165,4 +166,144 @@ TEST_CASE(transpose_double_contiguous)
     EXPECT(p.has_instruction(t));
 }
 
+TEST_CASE(transpose_partial1)
+{
+    migraphx::program p;
+    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
+    auto x  = p.add_parameter("x", s);
+    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
+    auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
+    p.add_instruction(pass_op{}, t2);
+    auto out_shape = p.get_shape();
+    auto n         = std::distance(p.begin(), p.end());
+    p.compile(simplify_reshapes_target{});
+    EXPECT(p.get_shape() == out_shape);
+    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
+}
+
+TEST_CASE(transpose_partial2)
+{
+    migraphx::program p;
+    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
+    auto x  = p.add_parameter("x", s);
+    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
+    auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
+    auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
+    p.add_instruction(pass_op{}, t3);
+    auto out_shape = p.get_shape();
+    auto n         = std::distance(p.begin(), p.end());
+    p.compile(simplify_reshapes_target{});
+    EXPECT(p.get_shape() == out_shape);
+    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
+}
+
+TEST_CASE(transpose_partial3)
+{
+    migraphx::program p;
+    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
+    auto x  = p.add_parameter("x", s);
+    auto t1 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, x);
+    auto t2 = p.add_instruction(migraphx::op::transpose{{1, 2, 0}}, t1);
+    auto t3 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t2);
+    auto t4 = p.add_instruction(migraphx::op::transpose{{1, 0, 2}}, t3);
+    p.add_instruction(pass_op{}, t4);
+    auto out_shape = p.get_shape();
+    auto n         = std::distance(p.begin(), p.end());
+    p.compile(simplify_reshapes_target{});
+    EXPECT(p.get_shape() == out_shape);
+    EXPECT(std::distance(p.begin(), p.end()) == n - 3);
+}
+
+TEST_CASE(nop_transpose1)
+{
+    migraphx::program p;
+    auto s = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
+    auto x = p.add_parameter("x", s);
+    auto t = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
+    p.add_instruction(pass_op{}, t);
+    auto out_shape = p.get_shape();
+    auto n         = std::distance(p.begin(), p.end());
+    p.compile(simplify_reshapes_target{});
+    EXPECT(p.get_shape() == out_shape);
+    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
+}
+
+TEST_CASE(nop_transpose2)
+{
+    migraphx::program p;
+    auto s  = migraphx::shape{migraphx::shape::float_type, {1, 2, 3}};
+    auto x  = p.add_parameter("x", s);
+    auto t1 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, x);
+    auto t2 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t1);
+    auto t3 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t2);
+    auto t4 = p.add_instruction(migraphx::op::transpose{{0, 1, 2}}, t3);
+    p.add_instruction(pass_op{}, t4);
+    auto out_shape = p.get_shape();
+    auto n         = std::distance(p.begin(), p.end());
+    p.compile(simplify_reshapes_target{});
+    EXPECT(p.get_shape() == out_shape);
+    EXPECT(std::distance(p.begin(), p.end()) == n - 4);
+}
+
+TEST_CASE(nop_transpose3)
+{
+    migraphx::program p;
+    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
+    auto x      = p.add_parameter("x", s);
+    auto y      = p.add_parameter("y", s);
+    auto concat = p.add_instruction(migraphx::op::concat{3}, x, y);
+    auto t1     = p.add_instruction(migraphx::op::transpose{{0, 1, 2, 3}}, concat);
+    auto t2     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, t1);
+    p.add_instruction(pass_op{}, t2);
+    auto out_shape = p.get_shape();
+    auto n         = std::distance(p.begin(), p.end());
+    p.compile(simplify_reshapes_target{});
+    EXPECT(p.get_shape() == out_shape);
+    EXPECT(std::distance(p.begin(), p.end()) == n - 1);
+}
+
+TEST_CASE(concat_transpose1)
+{
+    migraphx::program p;
+    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
+    auto x      = p.add_parameter("x", s);
+    auto y      = p.add_parameter("y", s);
+    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, x);
+    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, y);
+    auto concat = p.add_instruction(migraphx::op::concat{2}, xt, yt);
+    auto t      = p.add_instruction(migraphx::op::transpose{{0, 1, 3, 2}}, concat);
+    p.add_instruction(pass_op{}, t);
+    auto out_shape = p.get_shape();
+    auto n         = std::distance(p.begin(), p.end());
+    p.compile(simplify_reshapes_target{});
+    EXPECT(p.get_shape().lens() == out_shape.lens());
+    EXPECT(std::distance(p.begin(), p.end()) == n - 3);
+    auto new_concat =
+        std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
+    EXPECT(bool{new_concat != p.end()});
+    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 3);
+}
+
+TEST_CASE(concat_transpose2)
+{
+    migraphx::program p;
+    auto s      = migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 4}};
+    auto x      = p.add_parameter("x", s);
+    auto y      = p.add_parameter("y", s);
+    auto xt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, x);
+    auto yt     = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, y);
+    auto concat = p.add_instruction(migraphx::op::concat{3}, xt, yt);
+    auto t      = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, concat);
+    p.add_instruction(pass_op{}, t);
+    auto out_shape = p.get_shape();
+    auto n         = std::distance(p.begin(), p.end());
+    p.compile(simplify_reshapes_target{});
+    EXPECT(p.get_shape().lens() == out_shape.lens());
+    EXPECT(std::distance(p.begin(), p.end()) == n - 2);
+    auto new_concat =
+        std::find_if(p.begin(), p.end(), [](auto ins) { return ins.name() == "concat"; });
+    EXPECT(bool{new_concat != p.end()});
+    EXPECT(migraphx::any_cast<migraphx::op::concat>(new_concat->get_operator()).axis == 1);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
```
test/tf/tf_test.cpp

```diff
 #include <iostream>
 #include <vector>
 #include <migraphx/literal.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/simplify_reshapes.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_identity.hpp>
 #include <migraphx/operators.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/tf.hpp>
 #include "test.hpp"
 
+migraphx::program optimize_tf(const std::string& name, bool is_nhwc)
+{
+    auto prog = migraphx::parse_tf(name, is_nhwc);
+    if(is_nhwc)
+        migraphx::run_passes(prog,
+                             {migraphx::simplify_reshapes{},
+                              migraphx::dead_code_elimination{},
+                              migraphx::eliminate_identity{}});
+    return prog;
+}
+
 TEST_CASE(add_test)
 {
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 2, 3}});
     p.add_instruction(migraphx::op::add{}, l0, l1);
-    auto prog = migraphx::parse_tf("add_test.pb", false);
+    auto prog = optimize_tf("add_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -28,7 +43,7 @@ TEST_CASE(add_bcast_test)
     auto l2 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l0);
     auto l3 = p.add_instruction(migraphx::op::multibroadcast{s0.lens()}, l1);
     p.add_instruction(migraphx::op::add{}, l2, l3);
-    auto prog = migraphx::parse_tf("add_bcast_test.pb", false);
+    auto prog = optimize_tf("add_bcast_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -51,7 +66,7 @@ TEST_CASE(batchnorm_test)
     auto l4 = p.add_parameter("4", s0);
     auto l1 = p.add_literal(migraphx::literal{s0, const_vals});
     p.add_instruction(op, l0, l1, l2, l3, l4);
-    auto prog = migraphx::parse_tf("batchnorm_test.pb", true);
+    auto prog = optimize_tf("batchnorm_test.pb", true);
 
     EXPECT(p == prog);
 }
@@ -65,7 +80,7 @@ TEST_CASE(biasadd_test)
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {500}});
     auto l2 = p.add_instruction(migraphx::op::broadcast{axis, l0->get_shape().lens()}, l1);
     p.add_instruction(migraphx::op::add{}, l0, l2);
-    auto prog = migraphx::parse_tf("biasadd_test.pb", true);
+    auto prog = optimize_tf("biasadd_test.pb", true);
 
     EXPECT(p == prog);
 }
@@ -83,7 +98,7 @@ TEST_CASE(concat_test)
     p.add_literal(migraphx::shape{migraphx::shape::int32_type}, std::vector<int>{axis});
     p.add_instruction(migraphx::op::concat{static_cast<std::size_t>(axis)}, l0, l1);
-    auto prog = migraphx::parse_tf("concat_test.pb", false);
+    auto prog = optimize_tf("concat_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -92,7 +107,7 @@ TEST_CASE(const_test)
 {
     migraphx::program p;
     p.add_literal(migraphx::shape{migraphx::shape::float_type}, std::vector<float>{1.0f});
-    auto prog = migraphx::parse_tf("constant_test.pb", false);
+    auto prog = optimize_tf("constant_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -112,10 +127,9 @@ TEST_CASE(conv_test)
     op.padding  = {1, 1};
     op.stride   = {1, 1};
     op.dilation = {1, 1};
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
-    p.add_instruction(op, l0, l3);
-    auto prog = migraphx::parse_tf("conv_test.pb", true);
+    auto l2 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
+    p.add_instruction(op, l0, l2);
+    auto prog = optimize_tf("conv_test.pb", true);
 
     EXPECT(p == prog);
 }
@@ -136,12 +150,11 @@ TEST_CASE(depthwiseconv_test)
     op.stride   = {1, 1};
     op.dilation = {1, 1};
     op.group    = 3;
-    auto l2 = p.add_instruction(migraphx::op::transpose{{0, 3, 1, 2}}, l1);
-    auto l3 = p.add_instruction(migraphx::op::transpose{{1, 3, 0, 2}}, l2);
+    auto l3 = p.add_instruction(migraphx::op::transpose{{3, 2, 0, 1}}, l1);
     auto l4 = p.add_instruction(migraphx::op::contiguous{}, l3);
     auto l5 = p.add_instruction(migraphx::op::reshape{{3, 1, 3, 3}}, l4);
     p.add_instruction(op, l0, l5);
-    auto prog = migraphx::parse_tf("depthwise_conv_test.pb", true);
+    auto prog = optimize_tf("depthwise_conv_test.pb", true);
 
     EXPECT(p == prog);
 }
@@ -151,7 +164,7 @@ TEST_CASE(identity_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::identity{}, l0);
-    auto prog = migraphx::parse_tf("identity_test.pb", false);
+    auto prog = optimize_tf("identity_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -166,7 +179,7 @@ TEST_CASE(matmul_test)
     auto trans_l1 = p.add_instruction(migraphx::op::transpose{{1, 0}}, l1);
     p.add_instruction(migraphx::op::dot{}, trans_l0, trans_l1);
-    auto prog = migraphx::parse_tf("matmul_test.pb", false);
+    auto prog = optimize_tf("matmul_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -183,7 +196,7 @@ TEST_CASE(mean_test)
     p.add_instruction(op, l0);
     auto l3 = p.add_instruction(op, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
-    auto prog = migraphx::parse_tf("mean_test.pb", false);
+    auto prog = optimize_tf("mean_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -193,14 +206,11 @@ TEST_CASE(mean_test_nhwc)
     migraphx::program p;
     migraphx::literal l{migraphx::shape{migraphx::shape::int32_type, {2}}, {1, 2}};
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_literal(l);
-    p.add_literal(l);
     migraphx::op::pooling op;
     op.lengths = {16, 16};
-    p.add_instruction(op, l0);
     auto l3 = p.add_instruction(op, l0);
     p.add_instruction(migraphx::op::squeeze{{2, 3}}, l3);
-    auto prog = migraphx::parse_tf("mean_test_nhwc.pb", true);
+    auto prog = optimize_tf("mean_test_nhwc.pb", true);
 
     EXPECT(p == prog);
 }
@@ -212,7 +222,7 @@ TEST_CASE(mul_test)
     auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 1, 1, 16}});
     p.add_instruction(migraphx::op::mul{}, l0, l1);
-    auto prog = migraphx::parse_tf("mul_test.pb", false);
+    auto prog = optimize_tf("mul_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -234,7 +244,7 @@ TEST_CASE(pack_test)
                        return p.add_instruction(migraphx::op::unsqueeze{{axis}}, arg);
                    });
     p.add_instruction(migraphx::op::concat{static_cast<size_t>(axis)}, unsqueezed_args);
-    auto prog = migraphx::parse_tf("pack_test.pb", false);
+    auto prog = optimize_tf("pack_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -242,12 +252,15 @@ TEST_CASE(pack_test)
 TEST_CASE(pack_test_nhwc)
 {
     migraphx::program p;
-    auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
-    auto l1 = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
-    auto l2 = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
-    std::vector<migraphx::instruction_ref> args{l0, l1, l2};
+    auto l0  = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+    auto lt0 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l0);
+    auto l1  = p.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+    auto lt1 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l1);
+    auto l2  = p.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+    auto lt2 = p.add_instruction(migraphx::op::transpose{{0, 2, 3, 1}}, l2);
+    std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
     std::vector<migraphx::instruction_ref> unsqueezed_args;
-    int64_t nchw_axis = 1;
+    int64_t nchw_axis = 3;
 
     std::transform(args.begin(),
                    args.end(),
@@ -256,7 +269,7 @@ TEST_CASE(pack_test_nhwc)
                        return p.add_instruction(migraphx::op::unsqueeze{{nchw_axis}}, arg);
                    });
     p.add_instruction(migraphx::op::concat{static_cast<size_t>(nchw_axis)}, unsqueezed_args);
-    auto prog = migraphx::parse_tf("pack_test_nhwc.pb", true);
+    auto prog = optimize_tf("pack_test_nhwc.pb", true);
 
     EXPECT(p == prog);
 }
@@ -273,9 +286,9 @@ TEST_CASE(pooling_test)
     max_pool_op.stride  = {2, 2};
     avg_pool_op.lengths = {2, 2};
     max_pool_op.lengths = {2, 2};
-    p.add_instruction(avg_pool_op, l0);
+    // p.add_instruction(avg_pool_op, l0);
     p.add_instruction(max_pool_op, l0);
-    auto prog = migraphx::parse_tf("pooling_test.pb", true);
+    auto prog = optimize_tf("pooling_test.pb", true);
 
     EXPECT(p == prog);
 }
@@ -285,7 +298,7 @@ TEST_CASE(relu_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::relu{}, l0);
-    auto prog = migraphx::parse_tf("relu_test.pb", false);
+    auto prog = optimize_tf("relu_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -295,7 +308,7 @@ TEST_CASE(relu6_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 3, 16, 16}});
     p.add_instruction(migraphx::op::clip{6.0, 0.0}, l0);
-    auto prog = migraphx::parse_tf("relu6_test.pb", false);
+    auto prog = optimize_tf("relu6_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -308,7 +321,7 @@ TEST_CASE(reshape_test)
     // in tf, the second arg is a literal that contains new dimensions
     p.add_literal(migraphx::literal{s0, {1, 1, 1, 16}});
     p.add_instruction(migraphx::op::reshape{{1, 1, 1, 16}}, l0);
-    auto prog = migraphx::parse_tf("reshape_test.pb", false);
+    auto prog = optimize_tf("reshape_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -321,7 +334,7 @@ TEST_CASE(softmax_test)
     auto r = p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1]), 1, 1}}, l0);
     auto s = p.add_instruction(migraphx::op::softmax{}, r);
     p.add_instruction(migraphx::op::reshape{{long(dims[0]), long(dims[1])}}, s);
-    auto prog = migraphx::parse_tf("softmax_test.pb", false);
+    auto prog = optimize_tf("softmax_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -331,7 +344,7 @@ TEST_CASE(squeeze_test)
     migraphx::program p;
     auto l0 = p.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 3, 1}});
     p.add_instruction(migraphx::op::squeeze{{0, 3}}, l0);
-    auto prog = migraphx::parse_tf("squeeze_test.pb", false);
+    auto prog = optimize_tf("squeeze_test.pb", false);
 
     EXPECT(p == prog);
 }
@@ -343,18 +356,13 @@ TEST_CASE(stridedslice_test)
     std::size_t num_axes = 4;
     migraphx::op::slice op;
     op.starts = {0, 0, 0, 0};
-    op.ends   = {1, 5, 1, 1};
+    op.ends   = {1, 1, 1, 5};
     op.axes   = std::vector<int64_t>(num_axes);
     std::iota(op.axes.begin(), op.axes.end(), 0);
     // add literals for starts, ends, and strides in tf (NHWC format)
     p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{0, 0, 0, 0});
     p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 5});
     p.add_literal(migraphx::shape{migraphx::shape::int32_type, {4}}, std::vector<int>{1, 1, 1, 1});
     auto l1 = p.add_instruction(op, l0);
-    auto shrink_axis = 2;
+    auto shrink_axis = 1;
     p.add_instruction(migraphx::op::squeeze{{shrink_axis}}, l1);
-    auto prog = migraphx::parse_tf("stridedslice_test.pb", true);
+    auto prog = optimize_tf("stridedslice_test.pb", true);
 
     EXPECT(p == prog);
 }
```