Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3f0e74b4
Commit
3f0e74b4
authored
Jul 08, 2019
by
Shucai Xiao
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into reduce_mean
parents
0525939c
bc80dee8
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
940 additions
and
265 deletions
+940
-265
src/driver/main.cpp
src/driver/main.cpp
+24
-0
src/include/migraphx/functional.hpp
src/include/migraphx/functional.hpp
+17
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+171
-45
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+179
-56
src/targets/gpu/device/concat.cpp
src/targets/gpu/device/concat.cpp
+5
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+40
-14
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+4
-2
src/tf/tf.cpp
src/tf/tf.cpp
+108
-100
test/matcher.cpp
test/matcher.cpp
+189
-0
test/shape_test.cpp
test/shape_test.cpp
+10
-1
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+141
-0
test/tf/tf_test.cpp
test/tf/tf_test.cpp
+52
-44
No files found.
src/driver/main.cpp
View file @
3f0e74b4
...
...
@@ -7,6 +7,14 @@
#include <migraphx/onnx.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/eliminate_pad.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_reshapes.hpp>
namespace
migraphx
{
namespace
driver
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -17,6 +25,7 @@ struct loader
std
::
string
file_type
;
bool
is_nhwc
=
true
;
unsigned
trim
=
0
;
bool
optimize
=
false
;
void
parse
(
argument_parser
&
ap
)
{
...
...
@@ -26,6 +35,7 @@ struct loader
ap
(
is_nhwc
,
{
"--nhwc"
},
ap
.
help
(
"Treat tensorflow format as nhwc"
),
ap
.
set_value
(
true
));
ap
(
is_nhwc
,
{
"--nchw"
},
ap
.
help
(
"Treat tensorflow format as nchw"
),
ap
.
set_value
(
false
));
ap
(
trim
,
{
"--trim"
,
"-t"
},
ap
.
help
(
"Trim instructions from the end"
));
ap
(
optimize
,
{
"--optimize"
},
ap
.
help
(
"Optimize when reading"
),
ap
.
set_value
(
true
));
}
program
load
()
...
...
@@ -48,6 +58,20 @@ struct loader
auto
last
=
std
::
prev
(
p
.
end
(),
trim
);
p
.
remove_instructions
(
last
,
p
.
end
());
}
if
(
optimize
)
migraphx
::
run_passes
(
p
,
{
migraphx
::
eliminate_identity
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_algebra
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
simplify_reshapes
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
propagate_constant
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
eliminate_pad
{},
migraphx
::
dead_code_elimination
{},
});
return
p
;
}
};
...
...
src/include/migraphx/functional.hpp
View file @
3f0e74b4
...
...
@@ -190,6 +190,23 @@ auto pop_back_args(Ts&&... xs)
};
}
template
<
class
T
>
struct
always_f
{
T
x
;
template
<
class
...
Ts
>
constexpr
T
operator
()(
Ts
&&
...)
const
{
return
x
;
}
};
template
<
class
T
>
auto
always
(
T
x
)
{
return
always_f
<
T
>
{
x
};
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
3f0e74b4
...
...
@@ -8,6 +8,7 @@
#include <migraphx/iterator_for.hpp>
#include <migraphx/config.hpp>
#include <unordered_map>
#include <unordered_set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -20,6 +21,12 @@ struct matcher_context
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
instruction_ref
not_found
()
const
{
return
last
;
}
template
<
class
M
>
bool
matched
(
M
m
,
instruction_ref
ins
)
{
return
m
.
match
(
*
this
,
ins
)
!=
this
->
not_found
();
}
private:
instruction_ref
last
;
};
...
...
@@ -205,74 +212,147 @@ matcher_result match_instruction(program& p, instruction_ref ins, M&& m)
return
result
;
}
/// Find matches for an instruction in the program
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
instruction_ref
ins
,
Ms
&&
...
ms
)
{
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
if
(
match
)
return
;
auto
r
=
match_instruction
(
p
,
ins
,
m
.
matcher
());
if
(
r
.
result
==
p
.
end
())
return
;
m
.
apply
(
p
,
r
);
match
=
true
;
},
ms
...);
}
/// Find matches in a program
template
<
class
...
Ms
>
void
find_matches
(
program
&
p
,
Ms
&&
...
ms
)
{
for
(
auto
ins
:
iterator_for
(
p
))
{
bool
match
=
false
;
each_args
(
[
&
](
auto
&&
m
)
{
if
(
match
)
return
;
auto
r
=
match_instruction
(
p
,
ins
,
m
.
matcher
());
if
(
r
.
result
==
p
.
end
())
return
;
m
.
apply
(
p
,
r
);
match
=
true
;
},
ms
...);
find_matches
(
p
,
ins
,
ms
...);
}
}
template
<
class
...
Ts
>
auto
all_of
(
Ts
...
ms
)
struct
lazy_and
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
and
g
();
}
};
template
<
class
...
Ts
>
auto
none_of
(
Ts
...
ms
)
struct
lazy_or
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
y
.
match
(
ctx
,
ins
)
==
ctx
.
not_found
();
})(
true
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
template
<
class
F
,
class
G
>
bool
operator
()(
F
f
,
G
g
)
const
{
return
f
()
or
g
();
}
};
template
<
class
Op
,
bool
Start
,
bool
Matches
>
struct
match_fold_f
{
template
<
class
...
Ms
>
static
bool
fold_matchers
(
matcher_context
&
ctx
,
instruction_ref
ins
,
Ms
...
ms
)
{
Op
op
;
auto
matched
=
[
&
](
auto
m
)
{
return
[
=
,
&
ctx
]
{
return
ctx
.
matched
(
m
,
ins
);
};
};
return
fold
([
&
](
auto
x
,
auto
y
)
{
return
op
(
always
(
x
),
matched
(
y
));
})(
Start
,
ms
...);
}
template
<
class
Pack
>
static
bool
fold_matchers_pack
(
matcher_context
&
ctx
,
instruction_ref
ins
,
Pack
p
)
{
return
p
([
&
](
auto
...
ms
)
{
return
match_fold_f
::
fold_matchers
(
ctx
,
ins
,
ms
...);
});
}
template
<
class
...
Ts
>
auto
operator
()(
Ts
...
ms
)
const
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
match_fold_f
::
fold_matchers
(
ctx
,
ins
,
ms
...);
if
(
matches
==
Matches
)
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
Selector
>
auto
operator
[](
Selector
select
)
const
{
return
[
=
](
auto
...
ms
)
{
// Workaround ICE on gcc by packing matchers into an object
auto
mpack
=
pack
(
ms
...);
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
Op
op
;
bool
matches
=
Start
;
select
(
start
,
[
&
](
auto
ins
)
{
auto
fm
=
[
&
]
{
return
match_fold_f
::
fold_matchers_pack
(
ctx
,
ins
,
mpack
);
};
matches
=
op
(
always
(
matches
),
fm
);
});
if
(
matches
==
Matches
)
return
start
;
return
ctx
.
not_found
();
});
};
}
};
const
constexpr
auto
all_of
=
match_fold_f
<
lazy_and
,
true
,
true
>
{};
const
constexpr
auto
any_of
=
match_fold_f
<
lazy_or
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
match_fold_f
<
lazy_or
,
false
,
false
>
{};
inline
auto
inputs
()
{
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
inputs
())
f
(
x
);
};
}
template
<
class
...
Ts
>
auto
any_of
(
Ts
...
ms
)
inline
auto
outputs
()
{
return
make_bf_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
or
y
.
match
(
ctx
,
ins
)
!=
ctx
.
not_found
();
})(
false
,
ms
...);
if
(
matches
)
return
ins
;
return
ctx
.
not_found
();
});
return
[](
auto
ins
,
auto
f
)
{
for
(
auto
&&
x
:
ins
->
outputs
())
f
(
x
);
};
}
MIGRAPHX_PRED_MATCHER
(
any
,
instruction_ref
)
{
return
true
;
}
MIGRAPHX_PRED_MATCHER
(
none
,
instruction_ref
)
{
return
false
;
}
MIGRAPHX_PRED_MATCHER
(
standard_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
not_standard_shape
,
instruction_ref
ins
)
{
return
not
ins
->
get_shape
().
standard
();
}
MIGRAPHX_PRED_MATCHER
(
broadcast_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
broadcasted
();
}
MIGRAPHX_PRED_MATCHER
(
transpose_shape
,
instruction_ref
ins
)
{
return
ins
->
get_shape
().
transposed
();
}
MIGRAPHX_PRED_MATCHER
(
same_input_shapes
,
instruction_ref
ins
)
{
if
(
ins
->
inputs
().
empty
())
return
false
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
return
std
::
all_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
MIGRAPHX_BASIC_MATCHER
(
output
,
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
...
...
@@ -289,10 +369,39 @@ MIGRAPHX_BASIC_MATCHER(used_once, matcher_context& ctx, instruction_ref ins)
return
ctx
.
not_found
();
}
inline
auto
name
(
std
::
string
name
)
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
{
auto
m
=
any_of
(
ms
...);
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
{
return
fix
<
instruction_ref
>
([
&
](
auto
self
,
auto
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
{
auto
next
=
ins
->
outputs
().
front
();
if
(
ctx
.
matched
(
m
,
next
))
{
auto
skipped_next
=
self
(
next
);
if
(
skipped_next
!=
ctx
.
not_found
())
return
skipped_next
;
}
return
next
;
}
return
ctx
.
not_found
();
})(
start
);
});
}
inline
auto
name
(
std
::
string
s
)
{
return
make_basic_pred_matcher
(
[
=
,
name
=
std
::
move
(
name
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
name
;
});
[
=
,
s
=
std
::
move
(
s
)
](
instruction_ref
ins
)
{
return
ins
->
name
()
==
s
;
});
}
inline
auto
name
(
std
::
unordered_set
<
std
::
string
>
names
)
{
return
make_basic_pred_matcher
([
=
,
names
=
std
::
move
(
names
)
](
instruction_ref
ins
)
{
return
names
.
count
(
ins
->
name
())
>
0
;
});
}
inline
auto
nargs
(
std
::
size_t
n
)
...
...
@@ -338,6 +447,23 @@ inline auto either_arg(std::size_t i, std::size_t j)
};
}
template
<
class
M
>
auto
same_shape
(
M
m
)
{
return
make_basic_fun_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
i
=
m
.
match
(
ctx
,
ins
);
if
(
i
!=
ctx
.
not_found
()
and
i
->
get_shape
()
==
ins
->
get_shape
())
return
ins
;
return
ctx
.
not_found
();
});
}
template
<
class
...
Ms
>
auto
same_shape
(
Ms
...
ms
)
{
return
all_of
(
same_shape
(
ms
)...);
}
}
// namespace match
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/simplify_reshapes.cpp
View file @
3f0e74b4
...
...
@@ -2,14 +2,17 @@
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/as_shape.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
#include <unordered_set>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
bool
is_reshaper
(
instruction_ref
ins
)
const
auto
&
reshaper_names
(
)
{
// clang-format off
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
...
...
@@ -19,17 +22,10 @@ bool is_reshaper(instruction_ref ins)
"unsqueeze"
};
// clang-format on
return
contains
(
names
,
ins
->
name
())
;
return
names
;
}
bool
is_transpose_output
(
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
!=
1
)
return
false
;
if
(
ins
->
outputs
().
front
()
->
name
()
==
"contiguous"
)
return
is_transpose_output
(
ins
->
outputs
().
front
());
return
ins
->
outputs
().
front
()
->
name
()
==
"transpose"
;
}
bool
is_reshaper
(
instruction_ref
ins
)
{
return
contains
(
reshaper_names
(),
ins
->
name
());
}
instruction_ref
find_transpose_input
(
instruction_ref
ins
)
{
...
...
@@ -42,62 +38,189 @@ instruction_ref find_transpose_input(instruction_ref ins)
return
ins
;
}
void
simplify_reshapes
::
apply
(
program
&
p
)
co
ns
t
auto
get_transpose_dims
(
instruction_ref
i
ns
)
{
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
return
any_cast
<
const
op
::
transpose
&>
(
ins
->
get_operator
()).
dims
;
}
std
::
vector
<
int64_t
>
reorder_dims
(
std
::
vector
<
int64_t
>
dims
,
std
::
vector
<
int64_t
>
permutation
)
{
std
::
vector
<
int64_t
>
result
(
dims
.
size
());
assert
(
dims
.
size
()
==
permutation
.
size
());
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
if
(
is_reshaper
(
ins
))
result
[
i
]
=
dims
[
permutation
[
i
]];
}
return
result
;
}
bool
is_no_transpose
(
const
std
::
vector
<
int64_t
>&
dims
)
{
if
(
dims
.
empty
())
return
true
;
if
(
dims
.
front
()
!=
0
)
return
false
;
return
std
::
adjacent_find
(
dims
.
begin
(),
dims
.
end
(),
[](
auto
x
,
auto
y
)
{
return
(
y
-
x
)
!=
1
;
})
==
dims
.
end
();
}
template
<
class
Vector
,
class
Op
>
std
::
vector
<
int64_t
>
sort_permutation
(
const
Vector
&
data
,
Op
op
)
{
std
::
vector
<
std
::
int64_t
>
result
(
data
.
size
());
std
::
iota
(
result
.
begin
(),
result
.
end
(),
0
);
std
::
sort
(
result
.
begin
(),
result
.
end
(),
[
&
](
auto
x
,
auto
y
)
{
return
op
(
data
[
x
],
data
[
y
]);
});
return
result
;
}
std
::
vector
<
int64_t
>
invert_permutation
(
const
std
::
vector
<
int64_t
>&
permutation
)
{
return
sort_permutation
(
permutation
,
std
::
less
<>
{});
}
std
::
vector
<
int64_t
>
find_permutation
(
const
shape
&
s
)
{
return
sort_permutation
(
s
.
strides
(),
std
::
greater
<>
{});
}
struct
find_reshaper
{
auto
matcher
()
const
{
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
}
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
{
if
(
std
::
any_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
&
is_reshaper
))
continue
;
// Gather reshapes
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
}
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
}
}
if
(
r
.
first
!=
r
.
second
)
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
}
}
else
if
(
ins
->
name
()
==
"transpose"
)
if
(
r
.
first
!=
r
.
second
)
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
};
struct
find_nop_reshapes
{
auto
matcher
()
const
{
auto
reshapes
=
reshaper_names
();
reshapes
.
insert
(
"transpose"
);
reshapes
.
insert
(
"slice"
);
return
match
::
name
(
reshapes
)(
match
::
same_shape
(
match
::
arg
(
0
)));
}
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
p
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
}
};
struct
find_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"transpose"
)(
match
::
none_of
(
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
}
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
x
=
ins
;
auto
t
=
ins
;
std
::
vector
<
std
::
int64_t
>
dims
(
ins
->
get_shape
().
lens
().
size
());
std
::
iota
(
dims
.
begin
(),
dims
.
end
(),
0
);
do
{
dims
=
reorder_dims
(
get_transpose_dims
(
t
),
dims
);
x
=
t
;
t
=
find_transpose_input
(
x
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
return
;
if
(
is_no_transpose
(
dims
))
{
if
(
is_transpose_output
(
ins
))
continue
;
auto
x
=
ins
;
auto
t
=
ins
;
do
{
x
=
t
;
t
=
find_transpose_input
(
x
);
}
while
(
x
!=
t
and
t
->
name
()
==
"transpose"
);
if
(
t
==
ins
or
t
->
name
()
!=
"transpose"
)
continue
;
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
else
{
p
.
replace_instruction
(
ins
,
op
::
transpose
{{
dims
}},
t
->
inputs
().
front
());
}
}
};
struct
find_concat_transpose
{
auto
matcher
()
const
{
return
match
::
name
(
"concat"
)(
match
::
same_input_shapes
(),
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
void
apply
(
program
&
p
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
s
=
ins
->
inputs
().
front
()
->
get_shape
();
assert
(
s
.
transposed
());
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
auto
permutation
=
find_permutation
(
s
);
auto
ipermutation
=
invert_permutation
(
permutation
);
op
.
axis
=
ipermutation
[
op
.
axis
];
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
if
(
i
->
name
()
==
"transpose"
and
i
->
inputs
().
front
()
->
get_shape
().
standard
())
return
i
->
inputs
().
front
();
return
p
.
insert_instruction
(
ins
,
op
::
transpose
{
permutation
},
i
);
});
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
t
=
p
.
insert_instruction
(
ins
,
op
::
transpose
{
ipermutation
},
concat
);
assert
(
ins
->
get_shape
().
lens
()
==
t
->
get_shape
().
lens
());
p
.
replace_instruction
(
ins
,
t
);
}
};
void
simplify_reshapes
::
apply
(
program
&
p
)
const
{
auto
end
=
std
::
prev
(
p
.
end
());
for
(
auto
ins
:
iterator_for
(
p
))
{
if
(
ins
==
end
and
ins
->
name
()
==
"contiguous"
)
continue
;
// Skip possible dead instructions
if
(
ins
->
outputs
().
empty
()
and
ins
!=
end
)
continue
;
match
::
find_matches
(
p
,
ins
,
find_nop_reshapes
{},
find_reshaper
{},
find_transpose
{},
find_concat_transpose
{});
}
}
...
...
src/targets/gpu/device/concat.cpp
View file @
3f0e74b4
...
...
@@ -20,10 +20,12 @@ argument concat(hipStream_t stream,
auto
&&
arg
=
args
[
j
];
std
::
size_t
nelements
=
arg
.
get_shape
().
elements
();
auto
offset
=
offsets
[
j
];
hip_visit_all
(
args
.
back
(),
arg
)([
&
](
auto
output
,
auto
input
)
{
shape
arg_shape
{
arg
.
get_shape
().
type
(),
arg
.
get_shape
().
lens
()};
hip_visit_all
(
args
.
back
(),
arg
,
arg_shape
)([
&
](
auto
output
,
auto
input
,
auto
input_shape
)
{
gs_launch
(
stream
,
nelements
)([
=
](
auto
i
)
{
auto
idx
=
output
.
get_shape
().
index
(
input
.
get_shape
().
multi
(
i
));
output
.
data
()[
idx
+
offset
]
=
input
.
data
()[
i
];
auto
input_idx
=
input_shape
.
multi
(
i
);
auto
idx
=
output
.
get_shape
().
index
(
input_idx
);
output
.
data
()[
idx
+
offset
]
=
input
[
input_idx
];
});
});
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
3f0e74b4
...
...
@@ -200,12 +200,33 @@ struct hip_add_relu
}
};
void
move_broadcasted_back
(
std
::
vector
<
instruction_ref
>&
args
)
{
// Ensure the last arguments is the broadcasted one
auto
it
=
std
::
find_if
(
args
.
begin
(),
args
.
end
(),
[](
auto
arg
)
{
return
arg
->
get_shape
().
broadcasted
();
});
if
(
it
!=
args
.
end
())
std
::
swap
(
*
it
,
*
std
::
prev
(
args
.
end
(),
2
));
}
void
move_standard_front
(
std
::
vector
<
instruction_ref
>&
args
)
{
// Ensure the first arguments is the standard one
auto
it
=
std
::
find_if
(
args
.
begin
(),
args
.
end
(),
[](
auto
arg
)
{
return
arg
->
get_shape
().
standard
();
});
if
(
it
!=
args
.
end
())
std
::
swap
(
*
it
,
args
.
front
());
}
struct
find_add_relu
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
any_of
(
match
::
name
(
"gpu::add"
),
match
::
name
(
"hip::triadd"
)).
bind
(
"add"
)));
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
any_of
(
match
::
name
(
"gpu::add"
),
match
::
name
(
"hip::triadd"
),
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
()))
.
bind
(
"add"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
...
@@ -213,6 +234,9 @@ struct find_add_relu
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
ins
=
r
.
result
;
auto
args
=
add_ins
->
inputs
();
move_standard_front
(
args
);
move_broadcasted_back
(
args
);
// Use the allocation from the relu operator
args
.
back
()
=
ins
->
inputs
().
back
();
if
(
add_ins
->
name
()
==
"gpu::add"
)
...
...
@@ -226,24 +250,26 @@ struct find_triadd
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::add"
).
bind
(
"add"
),
match
::
any
().
bind
(
"input"
)));
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::add"
).
bind
(
"add"
),
match
::
any
(
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
())).
bind
(
"input"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
input_ins
=
r
.
instructions
[
"input"
];
auto
ins
=
r
.
result
;
auto
args
=
add_ins
->
inputs
();
auto
add_ins
=
r
.
instructions
[
"add"
];
auto
input_ins
=
r
.
instructions
[
"input"
];
auto
ins
=
r
.
result
;
auto
args
=
add_ins
->
inputs
();
assert
(
add_ins
!=
input_ins
);
auto
is_broadcasted
=
[](
auto
arg
)
{
return
arg
->
get_shape
().
broadcasted
();
};
if
(
std
::
count_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
)
>
1
)
return
;
args
.
insert
(
args
.
begin
(),
input_ins
);
// Ensure the last arguments is the broadcasted one
auto
it
=
std
::
find_if
(
args
.
begin
(),
args
.
end
(),
is_broadcasted
);
if
(
it
!=
args
.
end
())
std
::
swap
(
*
it
,
*
std
::
prev
(
args
.
end
(),
2
));
move_standard_front
(
args
);
move_broadcasted_back
(
args
);
args
.
back
()
=
ins
->
inputs
().
back
();
p
.
replace_instruction
(
ins
,
hip_triadd
{},
args
);
}
...
...
@@ -402,8 +428,8 @@ void fuse_ops::apply(program& p) const
// clang-format off
match
::
find_matches
(
p
,
find_triadd
{});
match
::
find_matches
(
p
,
//
find_conv_bias_relu{ctx},
//
find_conv_bias{ctx},
find_conv_bias_relu
{
ctx
},
find_conv_bias
{
ctx
},
find_add_relu
{}
);
// clang-format on
...
...
src/targets/gpu/target.cpp
View file @
3f0e74b4
...
...
@@ -36,6 +36,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
// clang-format off
return
{
dead_code_elimination
{},
simplify_reshapes
{},
dead_code_elimination
{},
eliminate_identity
{},
eliminate_pad
{},
...
...
@@ -48,11 +50,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
//dead_code_elimination{},
simplify_algebra
{},
dead_code_elimination
{},
propagate_constant
{},
dead_code_elimination
{},
auto_contiguous
{},
simplify_reshapes
{},
dead_code_elimination
{},
propagate_constant
{},
dead_code_elimination
{},
lowering
{
ctx
},
eliminate_concat
{
concat_gpu_optimization
{}},
dead_code_elimination
{},
...
...
src/tf/tf.cpp
View file @
3f0e74b4
...
...
@@ -37,6 +37,48 @@ struct tf_parser
std
::
unordered_map
<
std
::
string
,
op_func
>
ops
;
bool
should_transpose
(
instruction_ref
ins
)
const
{
return
is_nhwc
and
ins
->
get_shape
().
lens
().
size
()
==
4
;
}
instruction_ref
to_nhwc
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
0
,
2
,
3
,
1
}},
ins
);
return
ins
;
}
instruction_ref
to_nchw
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
0
,
3
,
1
,
2
}},
ins
);
return
ins
;
}
instruction_ref
to_kcxy
(
instruction_ref
ins
)
{
if
(
should_transpose
(
ins
))
return
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
ins
);
return
ins
;
}
instruction_ref
make_contiguous
(
instruction_ref
ins
)
{
if
(
ins
->
get_shape
().
standard
())
return
ins
;
else
return
prog
.
add_instruction
(
op
::
contiguous
{},
ins
);
}
std
::
vector
<
instruction_ref
>
to_nchw
(
const
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
instruction_ref
>
result
(
args
.
size
());
std
::
transform
(
args
.
begin
(),
args
.
end
(),
result
.
begin
(),
[
&
](
auto
ins
)
{
return
this
->
to_nchw
(
ins
);
});
return
result
;
}
std
::
vector
<
size_t
>
parse_axes
(
const
attribute_map
&
attributes
,
const
std
::
string
&
s
)
const
{
auto
attrs
=
attributes
.
at
(
s
).
list
().
i
();
...
...
@@ -119,59 +161,67 @@ struct tf_parser
add_mem_op
(
"AvgPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"BiasAdd"
,
&
tf_parser
::
parse_biasadd
);
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
);
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
,
false
);
add_mem_op
(
"Const"
,
&
tf_parser
::
parse_constant
);
add_mem_op
(
"Conv2D"
,
&
tf_parser
::
parse_conv
);
add_mem_op
(
"DepthwiseConv2dNative"
,
&
tf_parser
::
parse_depthwiseconv
);
add_mem_op
(
"FusedBatchNorm"
,
&
tf_parser
::
parse_batchnorm
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
,
false
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
,
false
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
,
false
);
add_mem_op
(
"StridedSlice"
,
&
tf_parser
::
parse_stridedslice
);
}
template
<
class
F
>
void
add_op
(
std
::
string
name
,
F
f
)
void
add_op
(
std
::
string
name
,
F
f
,
bool
transpose
=
true
)
{
ops
.
emplace
(
name
,
f
);
}
// Multi output op
template
<
class
F
>
void
add_multi_op
(
std
::
string
name
,
F
f
)
{
ops
.
emplace
(
name
,
f
);
if
(
transpose
)
{
ops
.
emplace
(
name
,
op_func
{[
=
](
const
attribute_map
&
attributes
,
const
std
::
vector
<
instruction_ref
>&
args
)
->
instruction_ref
{
return
to_nhwc
(
f
(
attributes
,
to_nchw
(
args
)));
}});
}
else
{
ops
.
emplace
(
name
,
f
);
}
}
template
<
class
F
>
void
add_mem_op
(
std
::
string
name
,
F
f
)
void
add_mem_op
(
std
::
string
name
,
F
f
,
bool
transpose
=
true
)
{
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
});
add_op
(
name
,
[
=
](
auto
&&
...
xs
)
{
return
std
::
mem_fn
(
f
)(
*
this
,
name
,
std
::
forward
<
decltype
(
xs
)
>
(
xs
)...);
},
transpose
);
}
template
<
class
T
>
void
add_binary_op
(
std
::
string
name
,
T
x
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
auto
l0
=
args
[
1
];
if
(
contains
(
attributes
,
"data_format"
))
{
if
(
is_nhwc
)
{
l0
=
prog
.
add_instruction
(
op
::
transpose
{{
0
,
3
,
1
,
2
}},
args
[
1
]);
}
}
return
add_broadcastable_binary_op
(
args
[
0
],
l0
,
x
);
});
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
if
(
args
.
size
()
!=
2
)
MIGRAPHX_THROW
(
"binary operators should have 2 operands"
);
// TODO
// if(contains(attributes, "data_format"))
// {
// if(is_nhwc)
// {
// l0 = prog.add_instruction(op::transpose{{0, 3, 1, 2}}, args[1]);
// }
// }
return
add_broadcastable_binary_op
(
args
[
0
],
args
[
1
],
x
);
},
false
);
}
template
<
class
T
>
...
...
@@ -210,20 +260,22 @@ struct tf_parser
auto
l0
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg0
);
auto
l1
=
prog
.
add_instruction
(
op
::
multibroadcast
{
output_lens
},
arg1
);
return
prog
.
add_instruction
(
x
,
l0
,
l1
);
return
to_nhwc
(
prog
.
add_instruction
(
x
,
to_nchw
(
l0
)
,
to_nchw
(
l1
))
);
}
else
{
return
prog
.
add_instruction
(
x
,
{
arg0
,
arg1
}
);
return
to_nhwc
(
prog
.
add_instruction
(
x
,
{
to_nchw
(
arg0
)
,
to_nchw
(
arg1
)})
);
}
}
template
<
class
T
>
void
add_generic_op
(
std
::
string
name
,
T
x
)
void
add_generic_op
(
std
::
string
name
,
T
x
,
bool
transpose
=
true
)
{
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
});
add_op
(
name
,
[
this
,
x
](
const
attribute_map
&
,
std
::
vector
<
instruction_ref
>
args
)
{
return
prog
.
add_instruction
(
x
,
args
);
},
transpose
);
}
instruction_ref
...
...
@@ -253,7 +305,7 @@ struct tf_parser
{
// get index for axis within args
size_t
axis_idx
=
attributes
.
at
(
"N"
).
i
();
size_t
axis
=
parse_axis
(
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
()
)
;
size_t
axis
=
args
[
axis_idx
]
->
eval
().
at
<
int64_t
>
();
op
::
concat
op
{
axis
};
// return only first N arguments (assuming last index is the axis value)
return
prog
.
add_instruction
(
...
...
@@ -264,16 +316,8 @@ struct tf_parser
attribute_map
attributes
,
const
std
::
vector
<
instruction_ref
>&
)
{
literal
v
=
parse_tensor
(
attributes
.
at
(
"value"
).
tensor
());
auto
l0
=
prog
.
add_literal
(
v
);
size_t
num_axes
=
l0
->
get_shape
().
lens
().
size
();
if
(
num_axes
>=
4
)
{
std
::
vector
<
int64_t
>
transpose_axes
=
get_axes
(
num_axes
);
reorder_data
(
transpose_axes
);
l0
=
prog
.
add_instruction
(
op
::
transpose
{
transpose_axes
},
l0
);
}
return
l0
;
literal
v
=
parse_tensor
(
attributes
.
at
(
"value"
).
tensor
());
return
prog
.
add_literal
(
v
);
}
instruction_ref
...
...
@@ -304,22 +348,9 @@ struct tf_parser
op
.
dilation
[
0
]
=
dilation
[
2
];
op
.
dilation
[
1
]
=
dilation
[
3
];
}
auto
weights
=
args
[
1
];
// check if weights are from a constant
if
(
weights
->
name
()
!=
"@param"
)
{
if
(
is_nhwc
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
1
,
3
,
0
,
2
}},
args
[
1
]);
}
else
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
args
[
1
]);
}
}
auto
l0
=
args
[
0
];
auto
weights
=
to_kcxy
(
args
[
1
]);
auto
l0
=
args
[
0
];
if
(
contains
(
attributes
,
"padding"
))
{
const
std
::
string
&
pad_mode
=
attributes
.
at
(
"padding"
).
s
();
...
...
@@ -368,8 +399,7 @@ struct tf_parser
op
.
padding
[
1
]
=
padding
[
1
];
}
}
return
prog
.
add_instruction
(
op
,
{
l0
,
weights
});
return
prog
.
add_instruction
(
op
,
{
l0
,
to_kcxy
(
args
[
1
])});
}
instruction_ref
parse_depthwiseconv
(
const
std
::
string
&
,
...
...
@@ -392,6 +422,8 @@ struct tf_parser
op
.
stride
[
0
]
=
stride
[
2
];
op
.
stride
[
1
]
=
stride
[
3
];
}
auto
weights
=
to_kcxy
(
args
[
1
]);
if
(
contains
(
attributes
,
"dilations"
))
{
std
::
vector
<
size_t
>
dilation
;
...
...
@@ -405,20 +437,6 @@ struct tf_parser
op
.
dilation
[
1
]
=
dilation
[
3
];
}
auto
weights
=
args
[
1
];
// check if weights are from a constant
if
(
weights
->
name
()
!=
"@param"
)
{
if
(
is_nhwc
)
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
1
,
3
,
0
,
2
}},
args
[
1
]);
}
else
{
weights
=
prog
.
add_instruction
(
op
::
transpose
{{
3
,
2
,
0
,
1
}},
args
[
1
]);
}
}
auto
l0
=
args
[
0
];
if
(
contains
(
attributes
,
"padding"
))
{
...
...
@@ -466,8 +484,8 @@ struct tf_parser
new_weights_shape
[
0
]
=
out_channels
;
new_weights_shape
[
1
]
=
1
;
// Make sure weights are contiguous before doing reshape
auto
c
weights
=
prog
.
add_instruction
(
op
::
contiguous
{},
weights
);
auto
new_weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
c
weights
);
auto
new_
weights
=
prog
.
add_instruction
(
op
::
reshape
{
new_weights_shape
},
make_contiguous
(
weights
)
)
;
return
prog
.
add_instruction
(
op
,
{
l0
,
new_weights
});
}
...
...
@@ -535,16 +553,14 @@ struct tf_parser
MIGRAPHX_THROW
(
"TF_PARSER: axis value of "
+
to_string
(
axis
)
+
" must be smaller than input size "
+
to_string
(
input_size
));
}
// check if input arg needs axis to be converted to NCHW
if
(
input_size
>=
4
)
axis
=
parse_axis
(
axis
);
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
unsqueezed_args
),
[
&
](
instruction_ref
arg
)
{
return
prog
.
add_instruction
(
op
::
unsqueeze
{{
axis
}},
arg
);
});
return
prog
.
add_instruction
(
op
::
concat
{
static_cast
<
size_t
>
(
axis
)},
unsqueezed_args
);
return
to_nhwc
(
prog
.
add_instruction
(
op
::
concat
{
static_cast
<
size_t
>
(
axis
)},
unsqueezed_args
));
}
instruction_ref
...
...
@@ -647,7 +663,7 @@ struct tf_parser
MIGRAPHX_THROW
(
"reshape needs 2 arguments (input, new_shape)"
);
auto
s
=
args
[
1
]
->
eval
();
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
op
.
dims
));
});
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
])
)
;
}
void
parse_from
(
std
::
istream
&
is
)
...
...
@@ -678,7 +694,7 @@ struct tf_parser
std
::
vector
<
instruction_ref
>
args
)
{
op
::
squeeze
op
;
auto
axes
=
parse_axes
(
attributes
,
"squeeze_dims"
);
auto
axes
=
attributes
.
at
(
"squeeze_dims"
)
.
list
().
i
()
;
copy
(
axes
,
std
::
back_inserter
(
op
.
axes
));
auto
args0_dims
=
args
[
0
]
->
get_shape
().
lens
();
if
(
op
.
axes
.
empty
())
// no squeeze_dims provided, remove any dim that equals 1
...
...
@@ -691,7 +707,7 @@ struct tf_parser
}
}
}
return
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
])
)
;
}
instruction_ref
parse_stridedslice
(
const
std
::
string
&
,
...
...
@@ -702,11 +718,6 @@ struct tf_parser
auto
starts
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
ends
=
args
[
2
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
size_t
num_axes
=
args
[
0
]
->
get_shape
().
lens
().
size
();
if
(
num_axes
>=
4
)
{
reorder_data
(
starts
);
reorder_data
(
ends
);
}
op
.
starts
=
std
::
vector
<
int64_t
>
(
starts
.
begin
(),
starts
.
end
());
op
.
ends
=
std
::
vector
<
int64_t
>
(
ends
.
begin
(),
ends
.
end
());
...
...
@@ -725,13 +736,9 @@ struct tf_parser
if
(((
shrink_axis_mask
>>
i
)
&
bitwise_compare
)
==
1
)
squeeze_axes
.
push_back
(
i
);
}
if
(
num_axes
>=
4
)
{
squeeze_axes
=
parse_axes
(
squeeze_axes
);
}
auto
l0
=
prog
.
add_instruction
(
op
,
args
[
0
]);
return
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l0
);
auto
l0
=
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
])
)
;
return
to_nhwc
(
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l0
)
)
;
}
void
parse_graph
(
const
tensorflow
::
GraphDef
&
graph
)
...
...
@@ -748,7 +755,7 @@ struct tf_parser
reorder_data
(
dims
);
}
shape
s
=
shape
{
shape_type
,
dims
};
instructions
[
name
]
=
prog
.
add_parameter
(
name
,
s
);
instructions
[
name
]
=
to_nhwc
(
prog
.
add_parameter
(
name
,
s
)
)
;
}
for
(
auto
&&
p
:
nodes
)
{
...
...
@@ -1098,6 +1105,7 @@ program parse_tf(const std::string& name, bool is_nhwc)
#else
parser
.
parse_from
(
input
);
#endif
parser
.
to_nchw
(
std
::
prev
(
parser
.
prog
.
end
()));
return
std
::
move
(
parser
.
prog
);
}
...
...
test/matcher.cpp
View file @
3f0e74b4
...
...
@@ -148,6 +148,56 @@ TEST_CASE(match_arg7)
EXPECT
(
bool
{
r
.
result
==
sum
});
}
TEST_CASE
(
match_arg8
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
all_of
(
match
::
arg
(
0
)(
match
::
name
(
"@literal"
)),
match
::
arg
(
1
)(
match
::
name
(
"@literal"
))),
match
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
TEST_CASE
(
match_nargs1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
nargs
(
2
));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
TEST_CASE
(
match_nargs2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
nargs
(
2
),
match
::
standard_shape
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
TEST_CASE
(
match_nargs3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
all_of
(
match
::
nargs
(
2
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
TEST_CASE
(
match_args1
)
{
migraphx
::
program
p
;
...
...
@@ -307,6 +357,19 @@ TEST_CASE(match_all_of2)
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_all_of3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
all_of
(
match
::
all_of
(
match
::
arg
(
0
)(
match
::
name
(
"@literal"
)),
match
::
arg
(
1
)(
match
::
name
(
"@literal"
)))));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
}
TEST_CASE
(
match_any_of1
)
{
migraphx
::
program
p
;
...
...
@@ -359,6 +422,132 @@ TEST_CASE(match_none_of2)
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_output1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
output
(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_output2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
output
(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_skip_output1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus_pass
=
p
.
add_instruction
(
pass_op
{},
minus
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus_pass1
=
p
.
add_instruction
(
pass_op
{},
minus
);
auto
minus_pass2
=
p
.
add_instruction
(
pass_op
{},
minus_pass1
);
auto
minus_pass3
=
p
.
add_instruction
(
pass_op
{},
minus_pass2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
minus_pass3
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output4
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
one
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
two
});
}
TEST_CASE
(
match_skip_output5
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
pass
=
p
.
add_instruction
(
pass_op
{},
one
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
pass
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
auto
sum3
=
p
.
add_instruction
(
sum_op
{},
sum2
,
two
);
p
.
add_instruction
(
pass_op
{},
sum3
);
auto
m
=
match
::
name
(
"@literal"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_skip_output6
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
minus
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
one
);
auto
sum3
=
p
.
add_instruction
(
sum_op
{},
sum2
,
two
);
p
.
add_instruction
(
pass_op
{},
sum3
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"sum"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus
});
}
TEST_CASE
(
match_skip_output7
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
minus1
=
p
.
add_instruction
(
minus_op
{},
two
,
one
);
auto
minus2
=
p
.
add_instruction
(
minus_op
{},
two
,
minus1
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
minus2
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"minus"
)(
match
::
skip_output
(
match
::
name
(
"pass"
))(
match
::
name
(
"minus"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
minus1
});
}
TEST_CASE
(
match_bind1
)
{
migraphx
::
program
p
;
...
...
test/shape_test.cpp
View file @
3f0e74b4
...
...
@@ -38,7 +38,7 @@ TEST_CASE(test_shape_packed)
EXPECT
(
not
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_transposed
)
TEST_CASE
(
test_shape_transposed
1
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
EXPECT
(
not
s
.
standard
());
...
...
@@ -47,6 +47,15 @@ TEST_CASE(test_shape_transposed)
EXPECT
(
not
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_transposed2
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
1
,
2
},
{
2
,
2
,
2
,
2
,
1
}};
EXPECT
(
s
.
standard
());
EXPECT
(
s
.
packed
());
EXPECT
(
not
s
.
transposed
());
EXPECT
(
not
s
.
broadcasted
());
}
TEST_CASE
(
test_shape_broadcasted
)
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
0
}};
...
...
test/simplify_reshapes_test.cpp
View file @
3f0e74b4
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
...
...
@@ -165,4 +166,144 @@ TEST_CASE(transpose_double_contiguous)
EXPECT
(
p
.
has_instruction
(
t
));
}
TEST_CASE
(
transpose_partial1
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
,
2
}},
x
);
auto
t2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
2
,
0
}},
t1
);
p
.
add_instruction
(
pass_op
{},
t2
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
1
);
}
TEST_CASE
(
transpose_partial2
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
,
2
}},
x
);
auto
t2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
2
,
0
}},
t1
);
auto
t3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
,
2
}},
t2
);
p
.
add_instruction
(
pass_op
{},
t3
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
2
);
}
TEST_CASE
(
transpose_partial3
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
,
2
}},
x
);
auto
t2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
2
,
0
}},
t1
);
auto
t3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
,
2
}},
t2
);
auto
t4
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
,
2
}},
t3
);
p
.
add_instruction
(
pass_op
{},
t4
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
3
);
}
TEST_CASE
(
nop_transpose1
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
x
);
p
.
add_instruction
(
pass_op
{},
t
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
1
);
}
TEST_CASE
(
nop_transpose2
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
t1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
x
);
auto
t2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
t1
);
auto
t3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
t2
);
auto
t4
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
}},
t3
);
p
.
add_instruction
(
pass_op
{},
t4
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
4
);
}
TEST_CASE
(
nop_transpose3
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
y
=
p
.
add_parameter
(
"y"
,
s
);
auto
concat
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
3
},
x
,
y
);
auto
t1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
2
,
3
}},
concat
);
auto
t2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
t1
);
p
.
add_instruction
(
pass_op
{},
t2
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
()
==
out_shape
);
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
1
);
}
TEST_CASE
(
concat_transpose1
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
y
=
p
.
add_parameter
(
"y"
,
s
);
auto
xt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
x
);
auto
yt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
y
);
auto
concat
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
2
},
xt
,
yt
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
1
,
3
,
2
}},
concat
);
p
.
add_instruction
(
pass_op
{},
t
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
3
);
auto
new_concat
=
std
::
find_if
(
p
.
begin
(),
p
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
p
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
3
);
}
TEST_CASE
(
concat_transpose2
)
{
migraphx
::
program
p
;
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
y
=
p
.
add_parameter
(
"y"
,
s
);
auto
xt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
x
);
auto
yt
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
y
);
auto
concat
=
p
.
add_instruction
(
migraphx
::
op
::
concat
{
3
},
xt
,
yt
);
auto
t
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
concat
);
p
.
add_instruction
(
pass_op
{},
t
);
auto
out_shape
=
p
.
get_shape
();
auto
n
=
std
::
distance
(
p
.
begin
(),
p
.
end
());
p
.
compile
(
simplify_reshapes_target
{});
EXPECT
(
p
.
get_shape
().
lens
()
==
out_shape
.
lens
());
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
n
-
2
);
auto
new_concat
=
std
::
find_if
(
p
.
begin
(),
p
.
end
(),
[](
auto
ins
)
{
return
ins
.
name
()
==
"concat"
;
});
EXPECT
(
bool
{
new_concat
!=
p
.
end
()});
EXPECT
(
migraphx
::
any_cast
<
migraphx
::
op
::
concat
>
(
new_concat
->
get_operator
()).
axis
==
1
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/tf/tf_test.cpp
View file @
3f0e74b4
#include <iostream>
#include <vector>
#include <migraphx/literal.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_identity.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/tf.hpp>
#include "test.hpp"
migraphx
::
program
optimize_tf
(
const
std
::
string
&
name
,
bool
is_nhwc
)
{
auto
prog
=
migraphx
::
parse_tf
(
name
,
is_nhwc
);
if
(
is_nhwc
)
migraphx
::
run_passes
(
prog
,
{
migraphx
::
simplify_reshapes
{},
migraphx
::
dead_code_elimination
{},
migraphx
::
eliminate_identity
{}});
return
prog
;
}
TEST_CASE
(
add_test
)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
2
,
3
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
2
,
3
}});
p
.
add_instruction
(
migraphx
::
op
::
add
{},
l0
,
l1
);
auto
prog
=
migraphx
::
pars
e_tf
(
"add_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"add_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -28,7 +43,7 @@ TEST_CASE(add_bcast_test)
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{
s0
.
lens
()},
l0
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{
s0
.
lens
()},
l1
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
l2
,
l3
);
auto
prog
=
migraphx
::
pars
e_tf
(
"add_bcast_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"add_bcast_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -51,7 +66,7 @@ TEST_CASE(batchnorm_test)
auto
l4
=
p
.
add_parameter
(
"4"
,
s0
);
auto
l1
=
p
.
add_literal
(
migraphx
::
literal
{
s0
,
const_vals
});
p
.
add_instruction
(
op
,
l0
,
l1
,
l2
,
l3
,
l4
);
auto
prog
=
migraphx
::
pars
e_tf
(
"batchnorm_test.pb"
,
true
);
auto
prog
=
optimiz
e_tf
(
"batchnorm_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -65,7 +80,7 @@ TEST_CASE(biasadd_test)
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
500
}});
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
broadcast
{
axis
,
l0
->
get_shape
().
lens
()},
l1
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
l0
,
l2
);
auto
prog
=
migraphx
::
pars
e_tf
(
"biasadd_test.pb"
,
true
);
auto
prog
=
optimiz
e_tf
(
"biasadd_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -83,7 +98,7 @@ TEST_CASE(concat_test)
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
},
std
::
vector
<
int
>
{
axis
});
p
.
add_instruction
(
migraphx
::
op
::
concat
{
static_cast
<
std
::
size_t
>
(
axis
)},
l0
,
l1
);
auto
prog
=
migraphx
::
pars
e_tf
(
"concat_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"concat_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -92,7 +107,7 @@ TEST_CASE(const_test)
{
migraphx
::
program
p
;
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
float_type
},
std
::
vector
<
float
>
{
1.0
f
});
auto
prog
=
migraphx
::
pars
e_tf
(
"constant_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"constant_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -112,10 +127,9 @@ TEST_CASE(conv_test)
op
.
padding
=
{
1
,
1
};
op
.
stride
=
{
1
,
1
};
op
.
dilation
=
{
1
,
1
};
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l1
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
3
,
0
,
2
}},
l2
);
p
.
add_instruction
(
op
,
l0
,
l3
);
auto
prog
=
migraphx
::
parse_tf
(
"conv_test.pb"
,
true
);
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
3
,
2
,
0
,
1
}},
l1
);
p
.
add_instruction
(
op
,
l0
,
l2
);
auto
prog
=
optimize_tf
(
"conv_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -136,12 +150,11 @@ TEST_CASE(depthwiseconv_test)
op
.
stride
=
{
1
,
1
};
op
.
dilation
=
{
1
,
1
};
op
.
group
=
3
;
auto
l2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
3
,
1
,
2
}},
l1
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
3
,
0
,
2
}},
l2
);
auto
l3
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
3
,
2
,
0
,
1
}},
l1
);
auto
l4
=
p
.
add_instruction
(
migraphx
::
op
::
contiguous
{},
l3
);
auto
l5
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
3
,
1
,
3
,
3
}},
l4
);
p
.
add_instruction
(
op
,
l0
,
l5
);
auto
prog
=
migraphx
::
pars
e_tf
(
"depthwise_conv_test.pb"
,
true
);
auto
prog
=
optimiz
e_tf
(
"depthwise_conv_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -151,7 +164,7 @@ TEST_CASE(identity_test)
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
16
,
16
}});
p
.
add_instruction
(
migraphx
::
op
::
identity
{},
l0
);
auto
prog
=
migraphx
::
pars
e_tf
(
"identity_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"identity_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -166,7 +179,7 @@ TEST_CASE(matmul_test)
auto
trans_l1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
1
,
0
}},
l1
);
p
.
add_instruction
(
migraphx
::
op
::
dot
{},
trans_l0
,
trans_l1
);
auto
prog
=
migraphx
::
pars
e_tf
(
"matmul_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"matmul_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -183,7 +196,7 @@ TEST_CASE(mean_test)
p
.
add_instruction
(
op
,
l0
);
auto
l3
=
p
.
add_instruction
(
op
,
l0
);
p
.
add_instruction
(
migraphx
::
op
::
squeeze
{{
2
,
3
}},
l3
);
auto
prog
=
migraphx
::
pars
e_tf
(
"mean_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"mean_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -193,14 +206,11 @@ TEST_CASE(mean_test_nhwc)
migraphx
::
program
p
;
migraphx
::
literal
l
{
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
2
}},
{
1
,
2
}};
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
16
,
16
}});
p
.
add_literal
(
l
);
p
.
add_literal
(
l
);
migraphx
::
op
::
pooling
op
;
op
.
lengths
=
{
16
,
16
};
p
.
add_instruction
(
op
,
l0
);
auto
l3
=
p
.
add_instruction
(
op
,
l0
);
auto
l3
=
p
.
add_instruction
(
op
,
l0
);
p
.
add_instruction
(
migraphx
::
op
::
squeeze
{{
2
,
3
}},
l3
);
auto
prog
=
migraphx
::
pars
e_tf
(
"mean_test_nhwc.pb"
,
true
);
auto
prog
=
optimiz
e_tf
(
"mean_test_nhwc.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -212,7 +222,7 @@ TEST_CASE(mul_test)
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
1
,
1
,
16
}});
p
.
add_instruction
(
migraphx
::
op
::
mul
{},
l0
,
l1
);
auto
prog
=
migraphx
::
pars
e_tf
(
"mul_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"mul_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -234,7 +244,7 @@ TEST_CASE(pack_test)
return
p
.
add_instruction
(
migraphx
::
op
::
unsqueeze
{{
axis
}},
arg
);
});
p
.
add_instruction
(
migraphx
::
op
::
concat
{
static_cast
<
size_t
>
(
axis
)},
unsqueezed_args
);
auto
prog
=
migraphx
::
pars
e_tf
(
"pack_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"pack_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -242,12 +252,15 @@ TEST_CASE(pack_test)
TEST_CASE
(
pack_test_nhwc
)
{
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
l2
=
p
.
add_parameter
(
"2"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
std
::
vector
<
migraphx
::
instruction_ref
>
args
{
l0
,
l1
,
l2
};
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt0
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
l0
);
auto
l1
=
p
.
add_parameter
(
"1"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt1
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
l1
);
auto
l2
=
p
.
add_parameter
(
"2"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
1
,
1
}});
auto
lt2
=
p
.
add_instruction
(
migraphx
::
op
::
transpose
{{
0
,
2
,
3
,
1
}},
l2
);
std
::
vector
<
migraphx
::
instruction_ref
>
args
{
lt0
,
lt1
,
lt2
};
std
::
vector
<
migraphx
::
instruction_ref
>
unsqueezed_args
;
int64_t
nchw_axis
=
1
;
int64_t
nchw_axis
=
3
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
...
...
@@ -256,7 +269,7 @@ TEST_CASE(pack_test_nhwc)
return
p
.
add_instruction
(
migraphx
::
op
::
unsqueeze
{{
nchw_axis
}},
arg
);
});
p
.
add_instruction
(
migraphx
::
op
::
concat
{
static_cast
<
size_t
>
(
nchw_axis
)},
unsqueezed_args
);
auto
prog
=
migraphx
::
pars
e_tf
(
"pack_test_nhwc.pb"
,
true
);
auto
prog
=
optimiz
e_tf
(
"pack_test_nhwc.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -273,9 +286,9 @@ TEST_CASE(pooling_test)
max_pool_op
.
stride
=
{
2
,
2
};
avg_pool_op
.
lengths
=
{
2
,
2
};
max_pool_op
.
lengths
=
{
2
,
2
};
p
.
add_instruction
(
avg_pool_op
,
l0
);
p
.
add_instruction
(
max_pool_op
,
l0
);
auto
prog
=
migraphx
::
parse_tf
(
"pooling_test.pb"
,
true
);
// p.add_instruction(avg_pool_op, l0);
auto
prog
=
optimize_tf
(
"pooling_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -285,7 +298,7 @@ TEST_CASE(relu_test)
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
16
,
16
}});
p
.
add_instruction
(
migraphx
::
op
::
relu
{},
l0
);
auto
prog
=
migraphx
::
pars
e_tf
(
"relu_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"relu_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -295,7 +308,7 @@ TEST_CASE(relu6_test)
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
3
,
16
,
16
}});
p
.
add_instruction
(
migraphx
::
op
::
clip
{
6.0
,
0.0
},
l0
);
auto
prog
=
migraphx
::
pars
e_tf
(
"relu6_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"relu6_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -308,7 +321,7 @@ TEST_CASE(reshape_test)
// in tf, the second arg is a literal that contains new dimensions
p
.
add_literal
(
migraphx
::
literal
{
s0
,
{
1
,
1
,
1
,
16
}});
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
1
,
1
,
1
,
16
}},
l0
);
auto
prog
=
migraphx
::
pars
e_tf
(
"reshape_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"reshape_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -321,7 +334,7 @@ TEST_CASE(softmax_test)
auto
r
=
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
]),
1
,
1
}},
l0
);
auto
s
=
p
.
add_instruction
(
migraphx
::
op
::
softmax
{},
r
);
p
.
add_instruction
(
migraphx
::
op
::
reshape
{{
long
(
dims
[
0
]),
long
(
dims
[
1
])}},
s
);
auto
prog
=
migraphx
::
pars
e_tf
(
"softmax_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"softmax_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -331,7 +344,7 @@ TEST_CASE(squeeze_test)
migraphx
::
program
p
;
auto
l0
=
p
.
add_parameter
(
"0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
3
,
1
}});
p
.
add_instruction
(
migraphx
::
op
::
squeeze
{{
0
,
3
}},
l0
);
auto
prog
=
migraphx
::
pars
e_tf
(
"squeeze_test.pb"
,
false
);
auto
prog
=
optimiz
e_tf
(
"squeeze_test.pb"
,
false
);
EXPECT
(
p
==
prog
);
}
...
...
@@ -343,18 +356,13 @@ TEST_CASE(stridedslice_test)
std
::
size_t
num_axes
=
4
;
migraphx
::
op
::
slice
op
;
op
.
starts
=
{
0
,
0
,
0
,
0
};
op
.
ends
=
{
1
,
5
,
1
,
1
};
op
.
ends
=
{
1
,
1
,
1
,
5
};
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
// add literals for starts, ends, and strides in tf (NHWC format)
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
4
}},
std
::
vector
<
int
>
{
0
,
0
,
0
,
0
});
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
4
}},
std
::
vector
<
int
>
{
1
,
1
,
1
,
5
});
p
.
add_literal
(
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
4
}},
std
::
vector
<
int
>
{
1
,
1
,
1
,
1
});
auto
l1
=
p
.
add_instruction
(
op
,
l0
);
auto
shrink_axis
=
2
;
auto
shrink_axis
=
1
;
p
.
add_instruction
(
migraphx
::
op
::
squeeze
{{
shrink_axis
}},
l1
);
auto
prog
=
migraphx
::
pars
e_tf
(
"stridedslice_test.pb"
,
true
);
auto
prog
=
optimiz
e_tf
(
"stridedslice_test.pb"
,
true
);
EXPECT
(
p
==
prog
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment