gaoqiong / MIGraphX · Commits

Commit 5e1bb505 (unverified)
Authored Aug 26, 2019 by mvermeulen; committed by GitHub on Aug 26, 2019
Parents: e8b8acf8, 4085af9b

Merge branch 'develop' into round_operator

Showing 17 changed files with 934 additions and 75 deletions (+934 -75)
src/CMakeLists.txt                                            +1    -1
src/include/migraphx/matcher.hpp                              +72   -1
src/include/migraphx/rewrite_batchnorm.hpp                    +2    -2
src/rewrite_batchnorm.cpp                                     +57   -0
src/simplify_algebra.cpp                                      +131  -13
src/targets/gpu/CMakeLists.txt                                +1    -0
src/targets/gpu/device/add_relu.cpp                           +10   -0
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp   +200  -31
src/targets/gpu/device/mul_add.cpp                            +21   -0
src/targets/gpu/fuse_ops.cpp                                  +107  -14
src/targets/gpu/include/migraphx/gpu/device/add_relu.hpp      +6    -0
src/targets/gpu/include/migraphx/gpu/device/mul_add.hpp       +25   -0
src/targets/gpu/target.cpp                                    +4    -4
test/gpu/miopen.cpp                                           +18   -0
test/matcher.cpp                                              +198  -0
test/rewrite_batchnorm_test.cpp                               +6    -6
test/simplify_algebra_test.cpp                                +75   -3
src/CMakeLists.txt

@@ -12,7 +12,7 @@ add_library(migraphx
     eliminate_concat.cpp
     eliminate_identity.cpp
     eliminate_pad.cpp
-    fwd_conv_batchnorm_rewrite.cpp
+    rewrite_batchnorm.cpp
     rewrite_rnn.cpp
     rewrite_pooling.cpp
     env.cpp
src/include/migraphx/matcher.hpp

@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
     [=, name = std::move(name)](matcher_context& ctx, instruction_ref ins) {
         auto result = m.match(ctx, ins);
         if(result != ctx.not_found())
-            ctx.instructions.emplace(name, ins);
+            ctx.instructions[name] = ins;
         return result;
     });
 }

@@ -240,6 +240,21 @@ void find_matches(program& p, Ms&&... ms)
     }
 }

+template <class M>
+struct find_skip
+{
+    M m;
+    M matcher() const { return m; }
+    void apply(program&, const matcher_result&) const {}
+};
+
+template <class M>
+find_skip<M> make_find_skip(M m)
+{
+    return {m};
+}
+
 struct lazy_and
 {
     template <class F, class G>

@@ -311,6 +326,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
 const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
 const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};

+template <class... Ms>
+auto skip_matches(Ms... ms)
+{
+    return make_find_skip(any_of(ms...));
+}
+
 inline auto inputs()
 {
     return [](auto ins, auto f) {

@@ -369,6 +390,50 @@ MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins
     return ctx.not_found();
 }

+inline auto used_once_recursive(std::size_t depth)
+{
+    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
+        // Used once
+        if(start->outputs().size() == 1)
+            return start;
+        // Unused
+        if(start->outputs().empty())
+        {
+            if(std::next(start) == ctx.not_found())
+                return start;
+            else
+                return ctx.not_found();
+        }
+        // Check for dead instructions
+        auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
+            if(n == 0)
+                return false;
+            if(ins->get_shape().elements() == 0)
+                return false;
+            if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
+                return true;
+            return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
+                return self(i, n - 1);
+            });
+        });
+        auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
+            return is_dead(i, depth);
+        });
+        if(dead + 1 == start->outputs().size())
+            return start;
+        return ctx.not_found();
+    });
+}
+
+MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
+
+MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
+{
+    if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
+        return ins;
+    return ctx.not_found();
+}
+
 template <class... Ms>
 auto skip_output(Ms... ms)
 {

@@ -404,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
     });
 }

+template <class... Ts>
+inline auto name(std::string s, Ts... xs) // NOLINT
+{
+    return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
+}
+
 inline auto nargs(std::size_t n)
 {
     return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
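The additions above extend the matcher DSL rather than change its shape: finders are still small structs with a matcher() and an apply(), driven by match::find_matches. A minimal sketch of how the new pieces compose, assuming only what this diff shows (the finder name, the printed message, and the iostream usage are illustrative, not part of MIGraphX):

    #include <migraphx/matcher.hpp>
    #include <migraphx/program.hpp>
    #include <migraphx/instruction.hpp>
    #include <iostream>

    namespace match = migraphx::match;

    // Hypothetical finder: reports any "add" or "mul" whose result is used exactly once.
    // match::name("add", "mul") is the new variadic name() overload added in this hunk,
    // and match::used_once() is the matcher this commit supplements with
    // used_once_recursive(), is_constant() and is_unused().
    struct find_single_use_arith
    {
        auto matcher() const { return match::name("add", "mul")(match::used_once()); }

        void apply(migraphx::program&, match::matcher_result r) const
        {
            // r.result is the matched instruction; bound names would live in r.instructions.
            std::cout << "matched: " << r.result->name() << std::endl;
        }
    };

    void report_single_use_arith(migraphx::program& p)
    {
        // find_matches visits every instruction and calls apply() for each match.
        match::find_matches(p, find_single_use_arith{});
    }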
src/include/migraphx/fwd_conv_batchnorm_rewrite.hpp → src/include/migraphx/rewrite_batchnorm.hpp

@@ -13,9 +13,9 @@ struct program;
 /**
  * Rewrite batchnorm to a multiply and add.
  */
-struct fwd_conv_batchnorm_rewrite
+struct rewrite_batchnorm
 {
-    std::string name() const { return "fwd_conv_batchnorm_rewrite"; }
+    std::string name() const { return "rewrite_batchnorm"; }
     void apply(program& p) const;
 };
src/fwd_conv_batchnorm_rewrite.cpp → src/rewrite_batchnorm.cpp

-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/op/batch_norm.hpp>
 #include <migraphx/op/broadcast.hpp>
 #include <migraphx/op/add.hpp>
+#include <migraphx/op/mul.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/dfor.hpp>

@@ -11,7 +12,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-void fwd_conv_batchnorm_rewrite::apply(program& p) const
+void rewrite_batchnorm::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {

@@ -25,46 +26,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
         if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
             continue;

-        auto conv_ins = ins->inputs()[0];
-        if(conv_ins->name() != "convolution")
-            continue;
-        // Get convolution weights
-        auto weights = conv_ins->inputs()[1]->eval();
-        if(weights.empty())
-            continue;
+        auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};
         // Get epsilon
         auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
         auto epsilon = bn_op.epsilon;
-        // Get convolution op
-        auto conv_op      = conv_ins->get_operator();
-        auto weights_lens = weights.get_shape().lens();
-        auto conv_lens    = conv_ins->get_shape().lens();
-        argument new_weights{weights.get_shape()};
-        argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
-        visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
-            [&](auto weights2,
-                auto gamma2,
-                auto bias2,
-                auto mean2,
-                auto variance2,
-                auto new_weights2,
-                auto new_bias2) {
-                dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
-                    [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
-                        new_weights2(k, c, h, w) =
-                            gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
-                    });
-                dfor(new_bias.get_shape().elements())([&](std::size_t c) {
-                    new_bias2[c] =
-                        bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
-                });
-            });
-        // Replace convolution instruction with updated weights
-        auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
-        auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
-        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
-        auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
-        p.replace_instruction(ins, op::add{}, {c, b});
+        argument a{s};
+        argument b{s};
+        visit_all(gamma, bias, mean, variance, a, b)(
+            [&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
+                dfor(a.get_shape().elements())([&](std::size_t c) {
+                    a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon);
+                });
+                dfor(b.get_shape().elements())([&](std::size_t c) {
+                    b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
+                });
+            });
+        auto broadcast   = op::broadcast{1, ins->get_shape().lens()};
+        auto a_ins       = p.add_literal({a.get_shape(), a.data()});
+        auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
+        auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
+        auto b_ins       = p.add_literal({b.get_shape(), b.data()});
+        auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
+        auto add         = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
+        p.replace_instruction(ins, add);
     }
 }
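The renamed pass drops the requirement that a convolution feed the batchnorm: batch_norm_inference is an affine function of its input, so it can always be replaced by a per-channel multiply and add. A standalone sketch of the algebra the new a2/b2 loops compute (illustrative only, not part of the patch):

    #include <cmath>

    // gamma * (x - mean) / sqrt(variance + epsilon) + bias  equals  a * x + b, with:
    struct channel_affine
    {
        float a;
        float b;
    };

    inline channel_affine
    fold_batchnorm(float gamma, float bias, float mean, float variance, float epsilon)
    {
        channel_affine r{};
        r.a = gamma / std::sqrt(variance + epsilon);                // a2[c] in the diff
        r.b = bias - gamma * mean / std::sqrt(variance + epsilon);  // b2[c] in the diff
        return r;
    }

The pass then emits two broadcasts of these per-channel literals, a mul, and an add in place of the batchnorm instruction.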
src/simplify_algebra.cpp

 #include <migraphx/simplify_algebra.hpp>
+#include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/op/add.hpp>
+#include <migraphx/op/mul.hpp>
+#include <migraphx/op/broadcast.hpp>
 #include <migraphx/matcher.hpp>
 #include <migraphx/literal.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct find_add_lit_broadcast
-{
-    auto lit_broadcast() const
-    {
-        return match::any_of(match::name("@literal"), match::name("broadcast"));
-    }
-    auto not_lit_broadcast() const
-    {
-        return match::none_of(match::name("@literal"), match::name("broadcast"));
-    }
-    auto add_lit_broadcast(std::string x, std::string y) const
-    {
-        return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
-                                                          not_lit_broadcast().bind(std::move(y))));
-    }
-    auto matcher() const
-    {
-        return match::name("add")(
-            match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
-    }
+auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
+auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
+auto op_lit_broadcast(std::string op, std::string x, std::string y)
+{
+    return match::name(std::move(op))(match::either_arg(0, 1)(
+        lit_broadcast().bind(std::move(x)), not_lit_broadcast().bind(std::move(y))));
+}
+
+auto conv_const_weights()
+{
+    return match::name("convolution")(match::used_once(),
+                                      match::args(match::any(), match::is_constant().bind("w")));
+}
+
+struct find_mul_conv
+{
+    auto matcher() const
+    {
+        return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
+                                                          match::name("broadcast").bind("a")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins      = r.result;
+        auto conv_ins = r.instructions["conv"];
+        auto a_ins    = r.instructions["a"];
+        auto w_ins    = r.instructions["w"];
+
+        auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
+        if(broadcast_op.axis != 1)
+            return;
+
+        auto new_a = p.insert_instruction(
+            ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
+        auto new_mul  = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
+        auto new_conv = p.insert_instruction(
+            ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
+        p.replace_instruction(ins, new_conv);
+    }
+};
+
+// a * (x + b) => a * x + a * b
+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("mul")(match::either_arg(0, 1)(
+            match::name("add")(
+                match::either_arg(0, 1)(
+                    match::any().bind("x"),
+                    match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
+                match::none_of(match::args(match::is_constant(), match::is_constant())),
+                match::used_once()),
+            match::is_constant().bind("a")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto x_ins = r.instructions["x"];
+        assert(x_ins != b_ins);
+
+        auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
+        auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
+        p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
+    }
+};
+
+struct find_add_lit_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("add")(
+            match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"), lit_broadcast().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto x_ins = r.instructions["x"];
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+
+        auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
+        p.replace_instruction(ins, op::add{}, x_ins, sumab);
+    }
+};
+
+struct find_double_add_lit_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("add")(
+            match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
+    }
+
     void apply(program& p, match::matcher_result r) const

@@ -36,11 +117,9 @@ struct find_add_lit_broadcast
         auto a_ins = r.instructions["a"];
         auto b_ins = r.instructions["b"];

-        if(a_ins->name() != b_ins->name())
-            return;
         instruction_ref sumab;
-        if(a_ins->name() == "broadcast")
+        if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
         {
             if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
                 return;

@@ -59,7 +138,46 @@ struct find_add_lit_broadcast
     }
 };

-void simplify_algebra::apply(program& p) const { match::find_matches(p, find_add_lit_broadcast{}); }
+struct find_inner_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("mul", "add")(
+            match::args(match::name("broadcast").bind("x"), match::name("broadcast").bind("y")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto x_ins = r.instructions["x"];
+        auto y_ins = r.instructions["y"];
+
+        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
+        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
+        if(xbroadcast.axis != ybroadcast.axis)
+            return;
+
+        auto op = p.insert_instruction(
+            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
+        p.replace_instruction(ins, xbroadcast, op);
+    }
+};
+
+void simplify_algebra::apply(program& p) const
+{
+    // Run simplifications multiple times
+    for(int i = 0; i < 4; i++)
+    {
+        match::find_matches(p,
+                            find_inner_broadcast{},
+                            find_double_add_lit_broadcast{},
+                            find_add_lit_broadcast{},
+                            find_mul_conv{},
+                            find_mul_add{});
+        dead_code_elimination{}.apply(p);
+    }
+}

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
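For a concrete picture of what find_mul_add does, here is a small sketch mirroring the simplify_mul_add test added later in this commit: 2 * (1 + x) is rewritten to 2 * x + 2 * 1 so the constant product can be folded afterwards. Applying the passes directly like this (rather than through the test target used in test/simplify_algebra_test.cpp) is an assumption for illustration:

    #include <migraphx/program.hpp>
    #include <migraphx/operators.hpp>
    #include <migraphx/simplify_algebra.hpp>
    #include <migraphx/dead_code_elimination.hpp>

    migraphx::program mul_add_example()
    {
        migraphx::program p;
        auto x   = p.add_parameter("x", {migraphx::shape::int32_type, {1}});
        auto one = p.add_literal(1);
        auto two = p.add_literal(2);
        auto sum = p.add_instruction(migraphx::op::add{}, one, x);
        // 2 * (1 + x): matched by find_mul_add because "two" is constant and "sum" is used once
        p.add_instruction(migraphx::op::mul{}, sum, two);
        migraphx::simplify_algebra{}.apply(p);        // becomes add(mul(2, x), mul(2, 1))
        migraphx::dead_code_elimination{}.apply(p);   // removes the now-unused add
        return p;
    }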
src/targets/gpu/CMakeLists.txt

@@ -16,6 +16,7 @@ add_library(migraphx_device
     device/argmin.cpp
     device/max.cpp
     device/min.cpp
+    device/mul_add.cpp
     device/exp.cpp
     device/erf.cpp
     device/log.cpp
src/targets/gpu/device/add_relu.cpp

@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+void mul_add_relu(hipStream_t stream,
+                  const argument& result,
+                  const argument& arg1,
+                  const argument& arg2,
+                  const argument& arg3)
+{
+    nary(stream, result, arg1, arg2, arg3)(
+        [](auto x, auto a, auto b) { return std::max<decltype(a * x + b)>(0, a * x + b); });
+}
+
 void add_relu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp

@@ -118,6 +118,111 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
     });
 }

+template <class F, class... Arguments>
+void nary_double_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
+{
+    assert(barg1.get_shape().broadcasted());
+    assert(barg2.get_shape().broadcasted());
+    assert(barg1.get_shape() == barg2.get_shape());
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg1.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
+                          return x != 0;
+                      }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i + bdim_vec_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b1   = bp[bidx];
+                    auto b2   = bp[bidx + bdim_len];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b2, b1);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
+}
+
+template <class F, class... Arguments>
+void nary_double_broadcast_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
+{
+    assert(barg1.get_shape().broadcasted());
+    assert(barg2.get_shape().broadcasted());
+    assert(barg1.get_shape() == barg2.get_shape());
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg1.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
+                          return x != 0;
+                      }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+
+    const std::size_t nlocal  = 1024;
+    const std::size_t nglobal = 256 * nlocal;
+    std::size_t nelements     = result.get_shape().elements();
+    hip_visit_all(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type = typename decltype(output)::value_type;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i + bdim_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx        = (i % bdim_next_stride) / bdim_stride;
+                    auto b1          = buffer[bidx];
+                    auto b2          = buffer[bidx + bdim_len];
+                    output.data()[i] = f(inputs.data()[i]..., b2, b1);
+                }
+            });
+        });
+}
+
 template <class F, class... Arguments>
 void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {

@@ -177,49 +282,113 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 }

 template <class... Arguments>
-auto nary(hipStream_t stream, argument result)
+bool broadcastable(bool& divisible_by_4,
+                   std::size_t max_size,
+                   const argument& result,
+                   const argument& barg,
+                   const Arguments&... args)
+{
+    divisible_by_4 = false;
+    auto bshape    = barg.get_shape();
+    const bool standard =
+        all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
+    const bool same_shapes =
+        all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
+    // TODO: Check result and args shape is the same
+    if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
+    {
+        auto not_zero       = [](auto x) { return x != 0; };
+        const auto& strides = bshape.strides();
+        auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
+        auto b_idx          = std::distance(strides.begin(), b_it);
+        auto b_len          = result.get_shape().lens()[b_idx];
+        auto b_stride       = result.get_shape().strides()[b_idx];
+        assert(bshape.lens()[b_idx] == b_len);
+        if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
+        {
+            divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                             (front_args(args...).get_shape().elements() % 4 == 0);
+            return true;
+        }
+    }
+    return false;
+}
+
+inline bool broadcastable(bool& divisible_by_4, std::size_t, const argument&, const argument&)
+{
+    divisible_by_4 = false;
+    return false;
+}
+
+// Nullary
+inline auto nary(hipStream_t stream, argument result)
 {
     return [=](auto f) { nary_standard_impl(stream, f, result); };
 }
+
+// Unary
+inline auto nary(hipStream_t stream, argument result, argument arg)
+{
+    return [=](auto f) { nary_impl(stream, f, result, arg); };
+}
+
+// Binary
+inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
+{
+    return [=](auto f) {
+        bool divisible_by_4 = false;
+        if(broadcastable(divisible_by_4, 2048, result, barg, arg))
+        {
+            if(divisible_by_4)
+                nary_broadcast_vec_impl(stream, f, result, barg, arg);
+            else
+                nary_broadcast_impl(stream, f, result, barg, arg);
+        }
+        else
+        {
+            nary_impl(stream, f, result, arg, barg);
+        }
+    };
+}

 template <class... Arguments>
 auto nary(hipStream_t stream, argument result, Arguments... args)
 {
+    static_assert(sizeof...(args) > 2, "Args needs to be greater than 2");
     return [=](auto f) {
-        auto barg     = back_args(args...);
-        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
-            auto bshape = barg.get_shape();
-            const bool standard =
-                all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
-            const bool same_shapes = all_of(
-                {args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
-            // TODO: Check result and args shape is the same
-            if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
-            {
-                auto not_zero       = [](auto x) { return x != 0; };
-                const auto& strides = bshape.strides();
-                auto b_it           = std::find_if(strides.begin(), strides.end(), not_zero);
-                auto b_idx          = std::distance(strides.begin(), b_it);
-                auto b_len          = result.get_shape().lens()[b_idx];
-                auto b_stride       = result.get_shape().strides()[b_idx];
-                assert(bshape.lens()[b_idx] == b_len);
-                if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
-                {
-                    const bool divisible_by_4 =
-                        (b_len % 4 == 0) and (b_stride % 4 == 0) and
-                        (front_args(args...).get_shape().elements() % 4 == 0);
-                    if(divisible_by_4)
-                        nary_broadcast_vec_impl(stream, f, result, barg, args2...);
-                    else
-                        nary_broadcast_impl(stream, f, result, barg, args2...);
-                    return false;
-                }
-            }
-            return true;
-        });
-        if(fallback)
+        auto barg1     = back_args(args...);
+        bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
+            auto barg2     = back_args(args2...);
+            bool fallback2 = barg2.get_shape() != barg1.get_shape() or
+                             not barg2.get_shape().broadcasted() or
+                             pop_back_args(args2...)([&](auto&&... args3) {
+                                 bool divisible_by_4 = false;
+                                 if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
+                                 {
+                                     if(divisible_by_4)
+                                         nary_double_broadcast_vec_impl(
+                                             stream, f, result, barg1, barg2, args3...);
+                                     else
+                                         nary_double_broadcast_impl(
+                                             stream, f, result, barg1, barg2, args3...);
+                                     return false;
+                                 }
+                                 return true;
+                             });
+            if(not fallback2)
+                return false;
+            bool divisible_by_4 = false;
+            if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
+            {
+                if(divisible_by_4)
+                    nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
+                else
+                    nary_broadcast_impl(stream, f, result, barg1, args2...);
+                return false;
+            }
+            return true;
+        });
+        if(fallback1)
             nary_impl(stream, f, result, args...);
     };
 }
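Both new kernels recover the broadcast coordinate from a flat output index using only the broadcast dimension's length and stride. A host-side illustration of that arithmetic (standalone sketch, not part of the patch):

    #include <cassert>
    #include <cstddef>

    // For a standard (row-major) output, the coordinate along the broadcast dimension of
    // flat element i is (i % (len * stride)) / stride, which is the bidx computation in
    // nary_double_broadcast_impl above.
    inline std::size_t
    broadcast_index(std::size_t i, std::size_t bdim_len, std::size_t bdim_stride)
    {
        auto bdim_next_stride = bdim_stride * bdim_len;
        return (i % bdim_next_stride) / bdim_stride;
    }

    inline void broadcast_index_check()
    {
        // Output shape {2, 3, 4} broadcast along axis 1: len = 3, stride = 4.
        assert(broadcast_index(5, 3, 4) == 1);   // element (0, 1, 1)
        assert(broadcast_index(13, 3, 4) == 0);  // element (1, 0, 1)
    }

The vectorized variant performs the same computation on i * vec_size and stages both broadcast inputs in LDS so each work-item reads them from shared memory.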
src/targets/gpu/device/mul_add.cpp (new file, mode 100644)

#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/nary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3)
{
    nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b) { return a * x + b; });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/fuse_ops.cpp

@@ -2,6 +2,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/convolution.hpp>
+#include <migraphx/gpu/device/mul_add.hpp>
 #include <migraphx/gpu/device/add_relu.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/instruction.hpp>

@@ -198,21 +199,62 @@ struct hip_add_relu
 };

+struct hip_mul_add
+{
+    std::string name() const { return "hip::mul_add"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4);
+        return inputs.front();
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
+    }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+};
+
+struct hip_mul_add_relu
+{
+    std::string name() const { return "hip::mul_add_relu"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4);
+        return inputs.front();
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        device::mul_add_relu(
+            ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
+    }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+};
+
 void move_broadcasted_back(std::vector<instruction_ref>& args)
 {
     // Ensure the last arguments is the broadcasted one
-    auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
-    if(it != args.end())
-        std::swap(*it, *std::prev(args.end(), 2));
+    auto last = std::prev(args.end());
+    auto it   = std::find_if(
+        args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
+    if(it != last)
+        std::swap(*it, *std::prev(last));
 }

 void move_standard_front(std::vector<instruction_ref>& args)
 {
     // Ensure the first arguments is the standard one
-    auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
-    if(it != args.end())
-        std::swap(*it, args.front());
+    auto last = std::prev(args.end());
+    auto it =
+        std::find_if(args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
+    if(it != last)
+        std::swap(*it, args.front());
 }

@@ -220,11 +262,13 @@ struct find_add_relu
 {
     auto matcher() const
     {
-        return match::name("gpu::relu")(match::arg(0)(
-            match::any_of(match::name("gpu::add"),
-                          match::name("hip::triadd"),
-                          match::any_of[match::inputs()](match::standard_shape()))
-                .bind("add")));
+        return match::name("gpu::relu")(match::arg(0)(
+            match::used_once(),
+            match::any_of(match::name("gpu::add"),
+                          match::name("hip::triadd"),
+                          match::any_of(match::name("@literal"),
+                                        match::any_of[match::inputs()](match::standard_shape())))
+                .bind("add")));
     }
     void apply(program& p, match::matcher_result r) const

@@ -249,8 +293,10 @@ struct find_triadd
     auto matcher() const
     {
         return match::name("gpu::add")(match::either_arg(0, 1)(
-            match::name("gpu::add").bind("add"),
-            match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
+            match::name("gpu::add")(match::used_once()).bind("add"),
+            match::any(match::any_of(match::name("@literal"),
+                                     match::any_of[match::inputs()](match::standard_shape())))
+                .bind("input")));
     }
     void apply(program& p, match::matcher_result r) const

@@ -273,6 +319,51 @@ struct find_triadd
     }
 };

+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_ins = r.instructions["mul"];
+        auto b_ins   = r.instructions["b"];
+        auto ins     = r.result;
+        auto args    = mul_ins->inputs();
+        assert(mul_ins != b_ins);
+
+        move_standard_front(args);
+        move_broadcasted_back(args);
+
+        args.insert(std::prev(args.end()), b_ins);
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add{}, args);
+    }
+};
+
+struct find_mul_add_relu
+{
+    auto matcher() const
+    {
+        return match::name("gpu::relu")(
+            match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_add_ins = r.instructions["mul_add"];
+        auto ins         = r.result;
+        auto args        = mul_add_ins->inputs();
+
+        // Use the allocation from the relu operator
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add_relu{}, args);
+    }
+};
+
 struct miopen_conv_bias
 {
     op::convolution op;

@@ -428,6 +519,8 @@ void fuse_ops::apply(program& p) const
     match::find_matches(p,
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
+                        find_mul_add{},
+                        find_mul_add_relu{},
                         find_add_relu{}
     );
     // clang-format on
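Element-wise, the two fused operators registered above compute a * x + b and its relu-clamped form; the actual device code is the nary(...) lambdas in device/mul_add.cpp and device/add_relu.cpp earlier in this diff. A host-side reference sketch (illustrative only):

    #include <algorithm>

    // hip::mul_add and hip::mul_add_relu, written out as scalar reference functions.
    inline float mul_add_ref(float x, float a, float b) { return a * x + b; }

    inline float mul_add_relu_ref(float x, float a, float b)
    {
        return std::max(0.0f, a * x + b);
    }

find_mul_add collapses a gpu::mul feeding a gpu::add into one hip::mul_add, and find_mul_add_relu then folds a following gpu::relu into hip::mul_add_relu, reusing the relu's output allocation (the last argument of these ops is the output buffer, per output_alias).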
src/targets/gpu/include/migraphx/gpu/device/add_relu.hpp

@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+void mul_add_relu(hipStream_t stream,
+                  const argument& result,
+                  const argument& arg1,
+                  const argument& arg2,
+                  const argument& arg3);
+
 void add_relu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
src/targets/gpu/include/migraphx/gpu/device/mul_add.hpp (new file, mode 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/target.cpp

@@ -14,7 +14,7 @@
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/eliminate_contiguous.hpp>
 #include <migraphx/common_subexpression_elimination.hpp>
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/eliminate_concat.hpp>

@@ -44,13 +44,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         eliminate_identity{},
         eliminate_pad{},
         dead_code_elimination{},
-        fwd_conv_batchnorm_rewrite{},
+        rewrite_batchnorm{},
         dead_code_elimination{},
         rewrite_rnn{},
         rewrite_pooling{},
         dead_code_elimination{},
-        //common_subexpression_elimination{},
-        //dead_code_elimination{},
+        // common_subexpression_elimination{},
+        // dead_code_elimination{},
         simplify_algebra{},
         dead_code_elimination{},
         auto_contiguous{},
test/gpu/miopen.cpp

@@ -502,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
 };

+struct test_mul_add : verify_program<test_mul_add>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
+        migraphx::shape bs{migraphx::shape::float_type, {3}};
+        auto x   = p.add_parameter("x", s);
+        auto a   = p.add_parameter("a", bs);
+        auto b   = p.add_parameter("b", bs);
+        auto ab  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
+        auto bb  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
+        auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
+        p.add_instruction(migraphx::op::add{}, mul, bb);
+        return p;
+    }
+};
+
 struct test_add_broadcast : verify_program<test_add_broadcast>
 {
     migraphx::program create_program() const
test/matcher.cpp

@@ -5,6 +5,8 @@
 namespace match = migraphx::match;

+MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
+
 template <class M>
 migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
 {

@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
     EXPECT(bool{r.result == p.end()});
 }

+TEST_CASE(match_either_args_any1)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any2)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any3)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any4)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum2});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any5)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum2});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
 TEST_CASE(match_all_of1)
 {
     migraphx::program p;

@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
     EXPECT(bool{r.result == sum});
 }

+TEST_CASE(match_lazy_any_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::any_of(match::any(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == one});
+}
+
+TEST_CASE(match_lazy_all_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::all_of(match::none(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == p.end()});
+}
+
+TEST_CASE(match_lazy_none_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::none_of(match::any(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == p.end()});
+}
+
 TEST_CASE(match_any_of1)
 {
     migraphx::program p;

@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2)
     EXPECT(bool{r.result == p.end()});
 }

+TEST_CASE(match_any_of_lazy1)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::any(), match::any()).bind("x"),
+                      match::args(match::name("sum"), match::name("sum")).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy2)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
+                      match::args(match::any(), match::any()).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy3)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::any(), match::any()).bind("x"),
+                      match::args(match::name("@literal"), match::name("@literal")).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy4)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::name("@literal").bind("x1"),
+                                  match::name("@literal").bind("y1")),
+                      match::args(match::any().bind("x2"), match::any().bind("y2"))));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x1"));
+    EXPECT(migraphx::contains(r.instructions, "y1"));
+    EXPECT(bool{r.instructions["x1"] == one});
+    EXPECT(bool{r.instructions["y1"] == two});
+    EXPECT(not migraphx::contains(r.instructions, "x2"));
+    EXPECT(not migraphx::contains(r.instructions, "y2"));
+}
+
+TEST_CASE(match_any_of_lazy5)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::any().bind("x1"), match::any().bind("y1")),
+                      match::args(match::name("@literal").bind("x2"),
+                                  match::name("@literal").bind("y2"))));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x1"));
+    EXPECT(migraphx::contains(r.instructions, "y1"));
+    EXPECT(bool{r.instructions["x1"] == one});
+    EXPECT(bool{r.instructions["y1"] == two});
+    EXPECT(not migraphx::contains(r.instructions, "x2"));
+    EXPECT(not migraphx::contains(r.instructions, "y2"));
+}
+
 TEST_CASE(match_none_of1)
 {
     migraphx::program p;
test/fwd_conv_batchnorm_rewrite_test.cpp → test/rewrite_batchnorm_test.cpp

-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/cpu/target.hpp>
 #include <migraphx/op/convolution.hpp>

@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();
-    migraphx::fwd_conv_batchnorm_rewrite opt;
+    migraphx::rewrite_batchnorm opt;
     opt.apply(p2);
     p1.compile(migraphx::cpu::target{});
     p2.compile(migraphx::cpu::target{});

@@ -93,10 +93,10 @@ TEST_CASE(non_literal)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();
-    migraphx::fwd_conv_batchnorm_rewrite opt;
+    migraphx::rewrite_batchnorm opt;
     opt.apply(p2);
     EXPECT(any_of(p1, &is_batch_norm));
-    EXPECT(any_of(p2, &is_batch_norm));
+    EXPECT(none_of(p2, &is_batch_norm));
 }

 TEST_CASE(as_literal)
@@ -121,7 +121,7 @@ TEST_CASE(as_literal)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();
-    migraphx::fwd_conv_batchnorm_rewrite opt;
+    migraphx::rewrite_batchnorm opt;
     opt.apply(p2);
     EXPECT(any_of(p1, &is_batch_norm));
     EXPECT(none_of(p2, &is_batch_norm));

@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape)
     migraphx::program p1 = create_program();
     migraphx::program p2 = create_program();
-    migraphx::fwd_conv_batchnorm_rewrite opt;
+    migraphx::rewrite_batchnorm opt;
     opt.apply(p2);
     EXPECT(any_of(p1, &is_batch_norm));
     EXPECT(none_of(p2, &is_batch_norm));
test/simplify_algebra_test.cpp

 #include <migraphx/simplify_algebra.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/operators.hpp>
+#include <migraphx/generate.hpp>
+#include <migraphx/ranges.hpp>
+#include <migraphx/instruction.hpp>
 #include <basic_ops.hpp>
 #include <test.hpp>

@@ -91,9 +94,9 @@ TEST_CASE(simplify_add3)
         auto x   = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
         auto one = p2.add_literal(1);
         auto two = p2.add_literal(2);
-        auto sum1 = p2.add_instruction(migraphx::op::add{}, one, x);
-        auto sum2 = p2.add_instruction(migraphx::op::add{}, one, two);
-        auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2);
+        auto sum1 = p2.add_instruction(migraphx::op::add{}, one, two);
+        auto sum2 = p2.add_instruction(migraphx::op::add{}, one, sum1);
+        auto sum3 = p2.add_instruction(migraphx::op::add{}, x, sum2);
         p2.add_instruction(pass_op{}, sum3);
     }
     EXPECT(p1 == p2);

@@ -129,4 +132,73 @@ void simplify_add4()
     EXPECT(p1 == p2);
 }

+TEST_CASE(simplify_mul_conv1)
+{
+    migraphx::program p;
+    auto x = p.add_parameter("x", {migraphx::shape::int32_type, {1, 128, 28, 28}});
+    auto w =
+        p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256, 128, 3, 3}}));
+    auto conv = p.add_instruction(migraphx::op::convolution{{1, 1}, {2, 2}, {1, 1}}, x, w);
+    auto a    = p.add_literal(migraphx::generate_literal({migraphx::shape::int32_type, {256}}));
+    auto b    = p.add_instruction(migraphx::op::broadcast{1, {1, 256, 14, 14}}, a);
+    auto mul  = p.add_instruction(migraphx::op::mul{}, conv, b);
+    p.add_instruction(pass_op{}, mul);
+    EXPECT(conv->outputs().front()->name() == "mul");
+    p.compile(simplify_algebra_target{});
+    auto new_conv =
+        std::find_if(p.begin(), p.end(), [](auto&& ins) { return ins.name() == "convolution"; });
+    EXPECT(new_conv->outputs().front()->name() != "mul");
+}
+
+TEST_CASE(simplify_mul_add)
+{
+    migraphx::program p1;
+    {
+        auto x   = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
+        auto one = p1.add_literal(1);
+        auto two = p1.add_literal(2);
+        auto sum = p1.add_instruction(migraphx::op::add{}, one, x);
+        auto mul = p1.add_instruction(migraphx::op::mul{}, sum, two);
+        p1.add_instruction(pass_op{}, mul);
+    }
+    p1.compile(simplify_algebra_target{});
+
+    migraphx::program p2;
+    {
+        auto x    = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
+        auto one  = p2.add_literal(1);
+        auto two  = p2.add_literal(2);
+        auto mul1 = p2.add_instruction(migraphx::op::mul{}, two, x);
+        auto mul2 = p2.add_instruction(migraphx::op::mul{}, two, one);
+        auto sum  = p2.add_instruction(migraphx::op::add{}, mul1, mul2);
+        p2.add_instruction(pass_op{}, sum);
+    }
+    EXPECT(p1 == p2);
+}
+
+TEST_CASE(simplify_inner_broadcast)
+{
+    auto b = migraphx::op::broadcast{1, {2, 1, 4, 5}};
+    migraphx::program p1;
+    {
+        auto x   = p1.add_parameter("x", {migraphx::shape::int32_type, {1}});
+        auto y   = p1.add_parameter("y", {migraphx::shape::int32_type, {1}});
+        auto xb  = p1.add_instruction(b, x);
+        auto yb  = p1.add_instruction(b, y);
+        auto sum = p1.add_instruction(migraphx::op::add{}, xb, yb);
+        p1.add_instruction(pass_op{}, sum);
+    }
+    p1.compile(simplify_algebra_target{});
+
+    migraphx::program p2;
+    {
+        auto x    = p2.add_parameter("x", {migraphx::shape::int32_type, {1}});
+        auto y    = p2.add_parameter("y", {migraphx::shape::int32_type, {1}});
+        auto sum  = p2.add_instruction(migraphx::op::add{}, x, y);
+        auto sumb = p2.add_instruction(b, sum);
+        p2.add_instruction(pass_op{}, sumb);
+    }
+    EXPECT(p1 == p2);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }