Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6b36c82e
Commit
6b36c82e
authored
Aug 27, 2019
by
Khalique
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into test_gen_scripts
parents
34150e61
01ff5321
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
935 additions
and
76 deletions
+935
-76
src/CMakeLists.txt
src/CMakeLists.txt
+1
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+72
-1
src/include/migraphx/rewrite_batchnorm.hpp
src/include/migraphx/rewrite_batchnorm.hpp
+2
-2
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+1
-1
src/rewrite_batchnorm.cpp
src/rewrite_batchnorm.cpp
+57
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+131
-13
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-0
src/targets/gpu/device/add_relu.cpp
src/targets/gpu/device/add_relu.cpp
+10
-0
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+200
-31
src/targets/gpu/device/mul_add.cpp
src/targets/gpu/device/mul_add.cpp
+21
-0
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+107
-14
src/targets/gpu/include/migraphx/gpu/device/add_relu.hpp
src/targets/gpu/include/migraphx/gpu/device/add_relu.hpp
+6
-0
src/targets/gpu/include/migraphx/gpu/device/mul_add.hpp
src/targets/gpu/include/migraphx/gpu/device/mul_add.hpp
+25
-0
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+4
-4
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+18
-0
test/matcher.cpp
test/matcher.cpp
+198
-0
test/rewrite_batchnorm_test.cpp
test/rewrite_batchnorm_test.cpp
+6
-6
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+75
-3
No files found.
src/CMakeLists.txt
View file @
6b36c82e
...
@@ -12,7 +12,7 @@ add_library(migraphx
...
@@ -12,7 +12,7 @@ add_library(migraphx
eliminate_concat.cpp
eliminate_concat.cpp
eliminate_identity.cpp
eliminate_identity.cpp
eliminate_pad.cpp
eliminate_pad.cpp
fwd_conv
_batchnorm
_rewrite
.cpp
rewrite
_batchnorm.cpp
rewrite_rnn.cpp
rewrite_rnn.cpp
rewrite_pooling.cpp
rewrite_pooling.cpp
env.cpp
env.cpp
...
...
src/include/migraphx/matcher.hpp
View file @
6b36c82e
...
@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
...
@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
[
=
,
name
=
std
::
move
(
name
)
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
[
=
,
name
=
std
::
move
(
name
)
](
matcher_context
&
ctx
,
instruction_ref
ins
)
{
auto
result
=
m
.
match
(
ctx
,
ins
);
auto
result
=
m
.
match
(
ctx
,
ins
);
if
(
result
!=
ctx
.
not_found
())
if
(
result
!=
ctx
.
not_found
())
ctx
.
instructions
.
emplace
(
name
,
ins
)
;
ctx
.
instructions
[
name
]
=
ins
;
return
result
;
return
result
;
});
});
}
}
...
@@ -240,6 +240,21 @@ void find_matches(program& p, Ms&&... ms)
...
@@ -240,6 +240,21 @@ void find_matches(program& p, Ms&&... ms)
}
}
}
}
template
<
class
M
>
struct
find_skip
{
M
m
;
M
matcher
()
const
{
return
m
;
}
void
apply
(
program
&
,
const
matcher_result
&
)
const
{}
};
template
<
class
M
>
find_skip
<
M
>
make_find_skip
(
M
m
)
{
return
{
m
};
}
struct
lazy_and
struct
lazy_and
{
{
template
<
class
F
,
class
G
>
template
<
class
F
,
class
G
>
...
@@ -311,6 +326,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
...
@@ -311,6 +326,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
const
constexpr
auto
any_of
=
match_fold_f
<
lazy_or
,
false
,
true
>
{};
const
constexpr
auto
any_of
=
match_fold_f
<
lazy_or
,
false
,
true
>
{};
const
constexpr
auto
none_of
=
match_fold_f
<
lazy_or
,
false
,
false
>
{};
const
constexpr
auto
none_of
=
match_fold_f
<
lazy_or
,
false
,
false
>
{};
template
<
class
...
Ms
>
auto
skip_matches
(
Ms
...
ms
)
{
return
make_find_skip
(
any_of
(
ms
...));
}
inline
auto
inputs
()
inline
auto
inputs
()
{
{
return
[](
auto
ins
,
auto
f
)
{
return
[](
auto
ins
,
auto
f
)
{
...
@@ -369,6 +390,50 @@ MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref in
...
@@ -369,6 +390,50 @@ MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref in
return
ctx
.
not_found
();
return
ctx
.
not_found
();
}
}
inline
auto
used_once_recursive
(
std
::
size_t
depth
)
{
return
make_basic_fun_matcher
([
=
](
const
matcher_context
&
ctx
,
instruction_ref
start
)
{
// Used once
if
(
start
->
outputs
().
size
()
==
1
)
return
start
;
// Unused
if
(
start
->
outputs
().
empty
())
{
if
(
std
::
next
(
start
)
==
ctx
.
not_found
())
return
start
;
else
return
ctx
.
not_found
();
}
// Check for dead instructions
auto
is_dead
=
fix
<
bool
>
([
&
](
auto
self
,
auto
ins
,
auto
n
)
{
if
(
n
==
0
)
return
false
;
if
(
ins
->
get_shape
().
elements
()
==
0
)
return
false
;
if
(
ins
->
outputs
().
empty
()
and
std
::
next
(
ins
)
!=
ctx
.
not_found
())
return
true
;
return
std
::
all_of
(
ins
->
outputs
().
begin
(),
ins
->
outputs
().
end
(),
[
&
](
auto
i
)
{
return
self
(
i
,
n
-
1
);
});
});
auto
dead
=
std
::
count_if
(
start
->
outputs
().
begin
(),
start
->
outputs
().
end
(),
[
&
](
auto
i
)
{
return
is_dead
(
i
,
depth
);
});
if
(
dead
+
1
==
start
->
outputs
().
size
())
return
start
;
return
ctx
.
not_found
();
});
}
MIGRAPHX_PRED_MATCHER
(
is_constant
,
instruction_ref
ins
)
{
return
ins
->
can_eval
();
}
MIGRAPHX_BASIC_MATCHER
(
is_unused
,
const
matcher_context
&
ctx
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
empty
()
and
ins
!=
std
::
prev
(
ctx
.
not_found
()))
return
ins
;
return
ctx
.
not_found
();
}
template
<
class
...
Ms
>
template
<
class
...
Ms
>
auto
skip_output
(
Ms
...
ms
)
auto
skip_output
(
Ms
...
ms
)
{
{
...
@@ -404,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
...
@@ -404,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
});
});
}
}
template
<
class
...
Ts
>
inline
auto
name
(
std
::
string
s
,
Ts
...
xs
)
// NOLINT
{
return
name
(
std
::
unordered_set
<
std
::
string
>
{
std
::
move
(
s
),
std
::
move
(
xs
)...});
}
inline
auto
nargs
(
std
::
size_t
n
)
inline
auto
nargs
(
std
::
size_t
n
)
{
{
return
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
ins
->
inputs
().
size
()
==
n
;
});
return
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
return
ins
->
inputs
().
size
()
==
n
;
});
...
...
src/include/migraphx/
fwd_conv
_batchnorm
_rewrite
.hpp
→
src/include/migraphx/
rewrite
_batchnorm.hpp
View file @
6b36c82e
...
@@ -13,9 +13,9 @@ struct program;
...
@@ -13,9 +13,9 @@ struct program;
/**
/**
* Rewrite batchnorm to a multiply and add.
* Rewrite batchnorm to a multiply and add.
*/
*/
struct
fwd_conv
_batchnorm
_rewrite
struct
rewrite
_batchnorm
{
{
std
::
string
name
()
const
{
return
"
fwd_conv
_batchnorm
_rewrite
"
;
}
std
::
string
name
()
const
{
return
"
rewrite
_batchnorm"
;
}
void
apply
(
program
&
p
)
const
;
void
apply
(
program
&
p
)
const
;
};
};
...
...
src/onnx/onnx.cpp
View file @
6b36c82e
...
@@ -375,7 +375,7 @@ struct onnx_parser
...
@@ -375,7 +375,7 @@ struct onnx_parser
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
uint64_t
axis
=
1
;
uint64_t
axis
=
1
;
auto
l1
=
prog
.
add_instruction
(
op
,
args
[
0
]
,
args
[
1
]);
auto
l1
=
prog
.
add_instruction
(
op
,
l0
,
args
[
1
]);
auto
l2
=
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
l1
->
get_shape
().
lens
()},
args
[
2
]);
auto
l2
=
prog
.
add_instruction
(
op
::
broadcast
{
axis
,
l1
->
get_shape
().
lens
()},
args
[
2
]);
return
prog
.
add_instruction
(
op
::
add
{},
l1
,
l2
);
return
prog
.
add_instruction
(
op
::
add
{},
l1
,
l2
);
}
}
...
...
src/
fwd_conv
_batchnorm
_rewrite
.cpp
→
src/
rewrite
_batchnorm.cpp
View file @
6b36c82e
#include <migraphx/
fwd_conv
_batchnorm
_rewrite
.hpp>
#include <migraphx/
rewrite
_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
...
@@ -11,7 +12,7 @@
...
@@ -11,7 +12,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
fwd_conv
_batchnorm
_rewrite
::
apply
(
program
&
p
)
const
void
rewrite
_batchnorm
::
apply
(
program
&
p
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
p
))
{
{
...
@@ -25,46 +26,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
...
@@ -25,46 +26,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
if
(
any_of
({
gamma
,
bias
,
mean
,
variance
},
[](
auto
arg
)
{
return
arg
.
empty
();
}))
continue
;
continue
;
auto
conv_ins
=
ins
->
inputs
()[
0
];
auto
s
=
shape
{
ins
->
get_shape
().
type
(),
{
ins
->
get_shape
().
lens
()[
1
]}};
if
(
conv_ins
->
name
()
!=
"convolution"
)
continue
;
// Get convolution weights
auto
weights
=
conv_ins
->
inputs
()[
1
]
->
eval
();
if
(
weights
.
empty
())
continue
;
// Get epsilon
// Get epsilon
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
bn_op
=
any_cast
<
op
::
batch_norm_inference
>
(
ins
->
get_operator
());
auto
epsilon
=
bn_op
.
epsilon
;
auto
epsilon
=
bn_op
.
epsilon
;
// Get convolution op
auto
conv_op
=
conv_ins
->
get_operator
();
argument
a
{
s
};
auto
weights_lens
=
weights
.
get_shape
().
lens
();
argument
b
{
s
};
auto
conv_lens
=
conv_ins
->
get_shape
().
lens
();
visit_all
(
gamma
,
bias
,
mean
,
variance
,
a
,
b
)(
argument
new_weights
{
weights
.
get_shape
()};
[
&
](
auto
gamma2
,
auto
bias2
,
auto
mean2
,
auto
variance2
,
auto
a2
,
auto
b2
)
{
argument
new_bias
{{
bias
.
get_shape
().
type
(),
{
bias
.
get_shape
().
elements
()}}};
dfor
(
a
.
get_shape
().
elements
())(
visit_all
(
weights
,
gamma
,
bias
,
mean
,
variance
,
new_weights
,
new_bias
)(
[
&
](
std
::
size_t
c
)
{
a2
[
c
]
=
gamma2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
);
});
[
&
](
auto
weights2
,
dfor
(
b
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
auto
gamma2
,
b2
[
c
]
=
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
auto
bias2
,
auto
mean2
,
auto
variance2
,
auto
new_weights2
,
auto
new_bias2
)
{
dfor
(
weights_lens
[
0
],
weights_lens
[
1
],
weights_lens
[
2
],
weights_lens
[
3
])(
[
&
](
std
::
size_t
k
,
std
::
size_t
c
,
std
::
size_t
h
,
std
::
size_t
w
)
{
new_weights2
(
k
,
c
,
h
,
w
)
=
gamma2
[
k
]
/
std
::
sqrt
(
variance2
[
k
]
+
epsilon
)
*
weights2
(
k
,
c
,
h
,
w
);
});
dfor
(
new_bias
.
get_shape
().
elements
())([
&
](
std
::
size_t
c
)
{
new_bias2
[
c
]
=
bias2
[
c
]
-
(
gamma2
[
c
]
*
mean2
[
c
]
/
std
::
sqrt
(
variance2
[
c
]
+
epsilon
));
});
});
});
});
// Replace convolution instruction with updated weights
auto
l_weights
=
p
.
add_literal
({
weights
.
get_shape
(),
new_weights
.
data
()});
auto
broadcast
=
op
::
broadcast
{
1
,
ins
->
get_shape
().
lens
()};
auto
l_bias
=
p
.
add_literal
({
new_bias
.
get_shape
(),
new_bias
.
data
()});
auto
a_ins
=
p
.
add_literal
({
a
.
get_shape
(),
a
.
data
()});
auto
c
=
p
.
replace_instruction
(
conv_ins
,
conv_op
,
{
conv_ins
->
inputs
()[
0
],
l_weights
});
auto
a_broadcast
=
p
.
insert_instruction
(
ins
,
broadcast
,
a_ins
);
auto
b
=
p
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
c
->
get_shape
().
lens
()},
l_bias
);
auto
mul
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
ins
->
inputs
().
front
(),
a_broadcast
);
p
.
replace_instruction
(
ins
,
op
::
add
{},
{
c
,
b
});
auto
b_ins
=
p
.
add_literal
({
b
.
get_shape
(),
b
.
data
()});
auto
b_broadcast
=
p
.
insert_instruction
(
ins
,
broadcast
,
b_ins
);
auto
add
=
p
.
insert_instruction
(
ins
,
op
::
add
{},
mul
,
b_broadcast
);
p
.
replace_instruction
(
ins
,
add
);
}
}
}
}
...
...
src/simplify_algebra.cpp
View file @
6b36c82e
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
find_add_lit_broadcast
auto
lit_broadcast
()
{
return
match
::
any_of
(
match
::
is_constant
(),
match
::
name
(
"broadcast"
));
}
auto
not_lit_broadcast
()
{
return
match
::
none_of
(
match
::
is_constant
(),
match
::
name
(
"broadcast"
));
}
auto
op_lit_broadcast
(
std
::
string
op
,
std
::
string
x
,
std
::
string
y
)
{
return
match
::
name
(
std
::
move
(
op
))(
match
::
either_arg
(
0
,
1
)(
lit_broadcast
().
bind
(
std
::
move
(
x
)),
not_lit_broadcast
().
bind
(
std
::
move
(
y
))));
}
auto
conv_const_weights
()
{
return
match
::
name
(
"convolution"
)(
match
::
used_once
(),
match
::
args
(
match
::
any
(),
match
::
is_constant
().
bind
(
"w"
)));
}
struct
find_mul_conv
{
{
auto
lit_broadcast
()
const
auto
matcher
()
const
{
{
return
match
::
any_of
(
match
::
name
(
"@literal"
),
match
::
name
(
"broadcast"
));
return
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
conv_const_weights
().
bind
(
"conv"
),
match
::
name
(
"broadcast"
).
bind
(
"a"
)));
}
}
auto
not_lit_broadcast
()
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
{
return
match
::
none_of
(
match
::
name
(
"@literal"
),
match
::
name
(
"broadcast"
));
auto
ins
=
r
.
result
;
auto
conv_ins
=
r
.
instructions
[
"conv"
];
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
w_ins
=
r
.
instructions
[
"w"
];
auto
broadcast_op
=
any_cast
<
op
::
broadcast
>
(
a_ins
->
get_operator
());
if
(
broadcast_op
.
axis
!=
1
)
return
;
auto
new_a
=
p
.
insert_instruction
(
ins
,
op
::
broadcast
{
0
,
w_ins
->
get_shape
().
lens
()},
a_ins
->
inputs
().
front
());
auto
new_mul
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
new_a
,
w_ins
);
auto
new_conv
=
p
.
insert_instruction
(
ins
,
conv_ins
->
get_operator
(),
conv_ins
->
inputs
().
front
(),
new_mul
);
p
.
replace_instruction
(
ins
,
new_conv
);
}
}
auto
add_lit_broadcast
(
std
::
string
x
,
std
::
string
y
)
const
};
// a * (x + b) => a * x + a * b
struct
find_mul_add
{
auto
matcher
()
const
{
{
return
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
lit_broadcast
().
bind
(
std
::
move
(
x
)),
return
match
::
name
(
"mul"
)(
match
::
either_arg
(
0
,
1
)(
not_lit_broadcast
().
bind
(
std
::
move
(
y
))));
match
::
name
(
"add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any_of
(
conv_const_weights
(),
match
::
is_constant
()).
bind
(
"b"
)),
match
::
none_of
(
match
::
args
(
match
::
is_constant
(),
match
::
is_constant
())),
match
::
used_once
()),
match
::
is_constant
().
bind
(
"a"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
assert
(
x_ins
!=
b_ins
);
auto
ax_ins
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
a_ins
,
x_ins
);
auto
ab_ins
=
p
.
insert_instruction
(
ins
,
op
::
mul
{},
a_ins
,
b_ins
);
p
.
replace_instruction
(
ins
,
op
::
add
{},
ax_ins
,
ab_ins
);
}
};
struct
find_add_lit_broadcast
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"add"
)(
return
match
::
name
(
"add"
)(
match
::
args
(
add_lit_broadcast
(
"a"
,
"x"
),
add_lit_broadcast
(
"b"
,
"y"
)));
match
::
either_arg
(
0
,
1
)(
op_lit_broadcast
(
"add"
,
"a"
,
"x"
),
lit_broadcast
().
bind
(
"b"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
sumab
=
p
.
insert_instruction
(
ins
,
op
::
add
{},
a_ins
,
b_ins
);
p
.
replace_instruction
(
ins
,
op
::
add
{},
x_ins
,
sumab
);
}
};
struct
find_double_add_lit_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"add"
)(
match
::
args
(
op_lit_broadcast
(
"add"
,
"a"
,
"x"
),
op_lit_broadcast
(
"add"
,
"b"
,
"y"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
@@ -36,11 +117,9 @@ struct find_add_lit_broadcast
...
@@ -36,11 +117,9 @@ struct find_add_lit_broadcast
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
if
(
a_ins
->
name
()
!=
b_ins
->
name
())
return
;
instruction_ref
sumab
;
instruction_ref
sumab
;
if
(
a_ins
->
name
()
==
"broadcast"
)
if
(
a_ins
->
name
()
==
"broadcast"
and
b_ins
->
name
()
==
"broadcast"
)
{
{
if
(
a_ins
->
inputs
().
at
(
0
)
->
get_shape
()
!=
b_ins
->
inputs
().
at
(
0
)
->
get_shape
())
if
(
a_ins
->
inputs
().
at
(
0
)
->
get_shape
()
!=
b_ins
->
inputs
().
at
(
0
)
->
get_shape
())
return
;
return
;
...
@@ -59,7 +138,46 @@ struct find_add_lit_broadcast
...
@@ -59,7 +138,46 @@ struct find_add_lit_broadcast
}
}
};
};
void
simplify_algebra
::
apply
(
program
&
p
)
const
{
match
::
find_matches
(
p
,
find_add_lit_broadcast
{});
}
struct
find_inner_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"mul"
,
"add"
)(
match
::
args
(
match
::
name
(
"broadcast"
).
bind
(
"x"
),
match
::
name
(
"broadcast"
).
bind
(
"y"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
y_ins
=
r
.
instructions
[
"y"
];
auto
xbroadcast
=
any_cast
<
op
::
broadcast
>
(
x_ins
->
get_operator
());
auto
ybroadcast
=
any_cast
<
op
::
broadcast
>
(
y_ins
->
get_operator
());
if
(
xbroadcast
.
axis
!=
ybroadcast
.
axis
)
return
;
auto
op
=
p
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
x_ins
->
inputs
().
front
(),
y_ins
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
xbroadcast
,
op
);
}
};
void
simplify_algebra
::
apply
(
program
&
p
)
const
{
// Run simplifications multiple times
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
match
::
find_matches
(
p
,
find_inner_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_mul_conv
{},
find_mul_add
{});
dead_code_elimination
{}.
apply
(
p
);
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/targets/gpu/CMakeLists.txt
View file @
6b36c82e
...
@@ -16,6 +16,7 @@ add_library(migraphx_device
...
@@ -16,6 +16,7 @@ add_library(migraphx_device
device/argmin.cpp
device/argmin.cpp
device/max.cpp
device/max.cpp
device/min.cpp
device/min.cpp
device/mul_add.cpp
device/exp.cpp
device/exp.cpp
device/erf.cpp
device/erf.cpp
device/log.cpp
device/log.cpp
...
...
src/targets/gpu/device/add_relu.cpp
View file @
6b36c82e
...
@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
mul_add_relu
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
{
nary
(
stream
,
result
,
arg1
,
arg2
,
arg3
)(
[](
auto
x
,
auto
a
,
auto
b
)
{
return
std
::
max
<
decltype
(
a
*
x
+
b
)
>
(
0
,
a
*
x
+
b
);
});
}
void
add_relu
(
hipStream_t
stream
,
void
add_relu
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg1
,
...
...
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
6b36c82e
...
@@ -118,6 +118,111 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
...
@@ -118,6 +118,111 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
});
});
}
}
template
<
class
F
,
class
...
Arguments
>
void
nary_double_broadcast_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg1
,
argument
barg2
,
Arguments
...
args
)
{
assert
(
barg1
.
get_shape
().
broadcasted
());
assert
(
barg2
.
get_shape
().
broadcasted
());
assert
(
barg1
.
get_shape
()
==
barg2
.
get_shape
());
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
barg1
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
const
std
::
size_t
vec_size
=
4
;
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
hip_vec_visit_all
<
vec_size
>
(
result
,
barg1
,
barg2
,
args
...)(
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
const
std
::
size_t
nelements
=
output
.
size
()
/
vec_size
;
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPHX_DEVICE_SHARED
type
buffer
[
2048
/
vec_size
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
binput1
.
data
()[
i
];
}
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
{
buffer
[
i
+
bdim_vec_len
]
=
binput2
.
data
()[
i
];
}
__syncthreads
();
auto
*
bp
=
as_pointer
(
buffer
);
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
{
auto
bidx
=
((
i
*
vec_size
)
%
bdim_next_stride
)
/
bdim_stride
;
auto
b1
=
bp
[
bidx
];
auto
b2
=
bp
[
bidx
+
bdim_len
];
auto
out
=
output
.
data
()[
i
];
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
{
out
[
j
]
=
f
(
inputs
.
data
()[
i
][
j
]...,
b2
,
b1
);
}
output
.
data
()[
i
]
=
out
;
}
});
});
}
template
<
class
F
,
class
...
Arguments
>
void
nary_double_broadcast_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
argument
barg1
,
argument
barg2
,
Arguments
...
args
)
{
assert
(
barg1
.
get_shape
().
broadcasted
());
assert
(
barg2
.
get_shape
().
broadcasted
());
assert
(
barg1
.
get_shape
()
==
barg2
.
get_shape
());
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
barg1
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
std
::
size_t
nelements
=
result
.
get_shape
().
elements
();
hip_visit_all
(
result
,
barg1
,
barg2
,
args
...)(
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
launch
(
stream
,
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPHX_DEVICE_SHARED
type
buffer
[
2048
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
binput1
.
data
()[
i
];
}
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
+
bdim_len
]
=
binput2
.
data
()[
i
];
}
__syncthreads
();
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
nelements
;
i
+=
nglobal
)
{
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b1
=
buffer
[
bidx
];
auto
b2
=
buffer
[
bidx
+
bdim_len
];
output
.
data
()[
i
]
=
f
(
inputs
.
data
()[
i
]...,
b2
,
b1
);
}
});
});
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
...
@@ -177,49 +282,113 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
...
@@ -177,49 +282,113 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
hipStream_t
stream
,
argument
result
)
bool
broadcastable
(
bool
&
divisible_by_4
,
std
::
size_t
max_size
,
const
argument
&
result
,
const
argument
&
barg
,
const
Arguments
&
...
args
)
{
divisible_by_4
=
false
;
auto
bshape
=
barg
.
get_shape
();
const
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
const
bool
same_shapes
=
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
// TODO: Check result and args shape is the same
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
const
auto
&
strides
=
bshape
.
strides
();
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
bshape
.
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
max_size
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
return
true
;
}
}
return
false
;
}
inline
bool
broadcastable
(
bool
&
divisible_by_4
,
std
::
size_t
,
const
argument
&
,
const
argument
&
)
{
divisible_by_4
=
false
;
return
false
;
}
// Nullary
inline
auto
nary
(
hipStream_t
stream
,
argument
result
)
{
{
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
);
};
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
);
};
}
}
// Unary
inline
auto
nary
(
hipStream_t
stream
,
argument
result
,
argument
arg
)
{
return
[
=
](
auto
f
)
{
nary_impl
(
stream
,
f
,
result
,
arg
);
};
}
// Binary
inline
auto
nary
(
hipStream_t
stream
,
argument
result
,
argument
arg
,
argument
barg
)
{
return
[
=
](
auto
f
)
{
bool
divisible_by_4
=
false
;
if
(
broadcastable
(
divisible_by_4
,
2048
,
result
,
barg
,
arg
))
{
if
(
divisible_by_4
)
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg
,
arg
);
else
nary_broadcast_impl
(
stream
,
f
,
result
,
barg
,
arg
);
}
else
{
nary_impl
(
stream
,
f
,
result
,
arg
,
barg
);
}
};
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
{
static_assert
(
sizeof
...(
args
)
>
2
,
"Args needs to be greater than 2"
);
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
auto
barg
=
back_args
(
args
...);
auto
barg1
=
back_args
(
args
...);
bool
fallback
=
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
bool
fallback1
=
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
auto
bshape
=
barg
.
get_shape
();
auto
barg2
=
back_args
(
args2
...);
const
bool
standard
=
bool
fallback2
=
all_of
({
args2
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
barg2
.
get_shape
()
!=
barg1
.
get_shape
()
or
not
barg2
.
get_shape
().
broadcasted
()
or
const
bool
same_shapes
=
all_of
(
pop_back_args
(
args2
...)([
&
](
auto
&&
...
args3
)
{
{
args2
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
bool
divisible_by_4
=
false
;
// TODO: Check result and args shape is the same
if
(
broadcastable
(
divisible_by_4
,
1024
,
result
,
barg2
,
args3
...))
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
{
if
(
divisible_by_4
)
nary_double_broadcast_vec_impl
(
stream
,
f
,
result
,
barg1
,
barg2
,
args3
...);
else
nary_double_broadcast_impl
(
stream
,
f
,
result
,
barg1
,
barg2
,
args3
...);
return
false
;
}
return
true
;
});
if
(
not
fallback2
)
return
false
;
bool
divisible_by_4
=
false
;
if
(
broadcastable
(
divisible_by_4
,
2048
,
result
,
barg1
,
args2
...))
{
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
if
(
divisible_by_4
)
const
auto
&
strides
=
bshape
.
strides
();
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg1
,
args2
...);
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
else
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
nary_broadcast_impl
(
stream
,
f
,
result
,
barg1
,
args2
...);
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
return
false
;
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
bshape
.
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg
,
args2
...);
else
nary_broadcast_impl
(
stream
,
f
,
result
,
barg
,
args2
...);
return
false
;
}
}
}
return
true
;
return
true
;
});
});
if
(
fallback
)
if
(
fallback
1
)
nary_impl
(
stream
,
f
,
result
,
args
...);
nary_impl
(
stream
,
f
,
result
,
args
...);
};
};
}
}
...
...
src/targets/gpu/device/mul_add.cpp
0 → 100644
View file @
6b36c82e
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/nary.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
mul_add
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
)
{
nary
(
stream
,
result
,
arg1
,
arg2
,
arg3
)([](
auto
x
,
auto
a
,
auto
b
)
{
return
a
*
x
+
b
;
});
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/fuse_ops.cpp
View file @
6b36c82e
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
...
@@ -198,21 +199,62 @@ struct hip_add_relu
...
@@ -198,21 +199,62 @@ struct hip_add_relu
}
}
};
};
struct
hip_mul_add
{
std
::
string
name
()
const
{
return
"hip::mul_add"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
);
return
inputs
.
front
();
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
mul_add
(
ctx
.
get_stream
().
get
(),
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
return
args
.
at
(
3
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
struct
hip_mul_add_relu
{
std
::
string
name
()
const
{
return
"hip::mul_add_relu"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
4
);
return
inputs
.
front
();
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
device
::
mul_add_relu
(
ctx
.
get_stream
().
get
(),
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
return
args
.
at
(
3
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
void
move_broadcasted_back
(
std
::
vector
<
instruction_ref
>&
args
)
void
move_broadcasted_back
(
std
::
vector
<
instruction_ref
>&
args
)
{
{
// Ensure the last arguments is the broadcasted one
// Ensure the last arguments is the broadcasted one
auto
it
=
std
::
find_if
(
auto
last
=
std
::
prev
(
args
.
end
());
args
.
begin
(),
args
.
end
(),
[](
auto
arg
)
{
return
arg
->
get_shape
().
broadcasted
();
});
auto
it
=
if
(
it
!=
args
.
end
())
std
::
find_if
(
args
.
begin
(),
last
,
[](
auto
arg
)
{
return
arg
->
get_shape
().
broadcasted
();
});
std
::
swap
(
*
it
,
*
std
::
prev
(
args
.
end
(),
2
));
if
(
it
!=
last
)
std
::
swap
(
*
it
,
*
std
::
prev
(
last
));
}
}
void
move_standard_front
(
std
::
vector
<
instruction_ref
>&
args
)
void
move_standard_front
(
std
::
vector
<
instruction_ref
>&
args
)
{
{
// Ensure the first arguments is the standard one
// Ensure the first arguments is the standard one
auto
it
=
std
::
find_if
(
auto
last
=
std
::
prev
(
args
.
end
());
args
.
begin
(),
args
.
end
(),
[](
auto
arg
)
{
return
arg
->
get_shape
().
standard
();
});
auto
it
=
if
(
it
!=
args
.
end
())
std
::
find_if
(
args
.
begin
(),
last
,
[](
auto
arg
)
{
return
arg
->
get_shape
().
standard
();
});
if
(
it
!=
last
)
std
::
swap
(
*
it
,
args
.
front
());
std
::
swap
(
*
it
,
args
.
front
());
}
}
...
@@ -220,11 +262,13 @@ struct find_add_relu
...
@@ -220,11 +262,13 @@ struct find_add_relu
{
{
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::relu"
)(
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
arg
(
0
)(
match
::
any_of
(
match
::
name
(
"gpu::add"
),
match
::
used_once
(),
match
::
name
(
"hip::triadd"
),
match
::
any_of
(
match
::
name
(
"gpu::add"
),
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
()))
match
::
name
(
"hip::triadd"
),
.
bind
(
"add"
)));
match
::
any_of
(
match
::
name
(
"@literal"
),
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
())))
.
bind
(
"add"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
@@ -249,8 +293,10 @@ struct find_triadd
...
@@ -249,8 +293,10 @@ struct find_triadd
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::add"
).
bind
(
"add"
),
match
::
name
(
"gpu::add"
)(
match
::
used_once
()).
bind
(
"add"
),
match
::
any
(
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
())).
bind
(
"input"
)));
match
::
any
(
match
::
any_of
(
match
::
name
(
"@literal"
),
match
::
any_of
[
match
::
inputs
()](
match
::
standard_shape
())))
.
bind
(
"input"
)));
}
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
...
@@ -273,6 +319,51 @@ struct find_triadd
...
@@ -273,6 +319,51 @@ struct find_triadd
}
}
};
};
struct
find_mul_add
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::add"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"gpu::mul"
)(
match
::
used_once
()).
bind
(
"mul"
),
match
::
any
().
bind
(
"b"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
mul_ins
=
r
.
instructions
[
"mul"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
ins
=
r
.
result
;
auto
args
=
mul_ins
->
inputs
();
assert
(
mul_ins
!=
b_ins
);
move_standard_front
(
args
);
move_broadcasted_back
(
args
);
args
.
insert
(
std
::
prev
(
args
.
end
()),
b_ins
);
args
.
back
()
=
ins
->
inputs
().
back
();
p
.
replace_instruction
(
ins
,
hip_mul_add
{},
args
);
}
};
struct
find_mul_add_relu
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::relu"
)(
match
::
arg
(
0
)(
match
::
name
(
"hip::mul_add"
)(
match
::
used_once
()).
bind
(
"mul_add"
)));
}
void
apply
(
program
&
p
,
match
::
matcher_result
r
)
const
{
auto
mul_add_ins
=
r
.
instructions
[
"mul_add"
];
auto
ins
=
r
.
result
;
auto
args
=
mul_add_ins
->
inputs
();
// Use the allocation from the relu operator
args
.
back
()
=
ins
->
inputs
().
back
();
p
.
replace_instruction
(
ins
,
hip_mul_add_relu
{},
args
);
}
};
struct
miopen_conv_bias
struct
miopen_conv_bias
{
{
op
::
convolution
op
;
op
::
convolution
op
;
...
@@ -428,6 +519,8 @@ void fuse_ops::apply(program& p) const
...
@@ -428,6 +519,8 @@ void fuse_ops::apply(program& p) const
match
::
find_matches
(
p
,
match
::
find_matches
(
p
,
find_conv_bias_relu
{
ctx
},
find_conv_bias_relu
{
ctx
},
find_conv_bias
{
ctx
},
find_conv_bias
{
ctx
},
find_mul_add
{},
find_mul_add_relu
{},
find_add_relu
{}
find_add_relu
{}
);
);
// clang-format on
// clang-format on
...
...
src/targets/gpu/include/migraphx/gpu/device/add_relu.hpp
View file @
6b36c82e
...
@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
void
mul_add_relu
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
);
void
add_relu
(
hipStream_t
stream
,
void
add_relu
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg1
,
...
...
src/targets/gpu/include/migraphx/gpu/device/mul_add.hpp
0 → 100644
View file @
6b36c82e
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
mul_add
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg3
);
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/target.cpp
View file @
6b36c82e
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#include <migraphx/propagate_constant.hpp>
#include <migraphx/propagate_constant.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/eliminate_contiguous.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/common_subexpression_elimination.hpp>
#include <migraphx/
fwd_conv
_batchnorm
_rewrite
.hpp>
#include <migraphx/
rewrite
_batchnorm.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/rewrite_pooling.hpp>
#include <migraphx/eliminate_concat.hpp>
#include <migraphx/eliminate_concat.hpp>
...
@@ -44,13 +44,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
...
@@ -44,13 +44,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
eliminate_identity
{},
eliminate_identity
{},
eliminate_pad
{},
eliminate_pad
{},
dead_code_elimination
{},
dead_code_elimination
{},
fwd_conv
_batchnorm
_rewrite
{},
rewrite
_batchnorm
{},
dead_code_elimination
{},
dead_code_elimination
{},
rewrite_rnn
{},
rewrite_rnn
{},
rewrite_pooling
{},
rewrite_pooling
{},
dead_code_elimination
{},
dead_code_elimination
{},
//common_subexpression_elimination{},
//
common_subexpression_elimination{},
//dead_code_elimination{},
//
dead_code_elimination{},
simplify_algebra
{},
simplify_algebra
{},
dead_code_elimination
{},
dead_code_elimination
{},
auto_contiguous
{},
auto_contiguous
{},
...
...
test/gpu/miopen.cpp
View file @
6b36c82e
...
@@ -502,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
...
@@ -502,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
}
}
};
};
struct
test_mul_add
:
verify_program
<
test_mul_add
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
shape
bs
{
migraphx
::
shape
::
float_type
,
{
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
auto
a
=
p
.
add_parameter
(
"a"
,
bs
);
auto
b
=
p
.
add_parameter
(
"b"
,
bs
);
auto
ab
=
p
.
add_instruction
(
migraphx
::
op
::
broadcast
{
1
,
s
.
lens
()},
a
);
auto
bb
=
p
.
add_instruction
(
migraphx
::
op
::
broadcast
{
1
,
s
.
lens
()},
b
);
auto
mul
=
p
.
add_instruction
(
migraphx
::
op
::
mul
{},
x
,
ab
);
p
.
add_instruction
(
migraphx
::
op
::
add
{},
mul
,
bb
);
return
p
;
}
};
struct
test_add_broadcast
:
verify_program
<
test_add_broadcast
>
struct
test_add_broadcast
:
verify_program
<
test_add_broadcast
>
{
{
migraphx
::
program
create_program
()
const
migraphx
::
program
create_program
()
const
...
...
test/matcher.cpp
View file @
6b36c82e
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
namespace
match
=
migraphx
::
match
;
namespace
match
=
migraphx
::
match
;
MIGRAPHX_PRED_MATCHER
(
throws
,
migraphx
::
instruction_ref
)
{
MIGRAPHX_THROW
(
"Matcher throws"
);
}
template
<
class
M
>
template
<
class
M
>
migraphx
::
match
::
matcher_result
find_match
(
migraphx
::
program
&
p
,
M
&&
m
)
migraphx
::
match
::
matcher_result
find_match
(
migraphx
::
program
&
p
,
M
&&
m
)
{
{
...
@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
...
@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
}
TEST_CASE
(
match_either_args_any1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
any
().
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum1
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_either_args_any2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
name
(
"@literal"
).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum1
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_either_args_any3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"@literal"
).
bind
(
"x"
),
match
::
any
().
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum1
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_either_args_any4
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
name
(
"sum"
).
bind
(
"x"
),
match
::
any
().
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum2
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_either_args_any5
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum1
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
auto
sum2
=
p
.
add_instruction
(
sum_op
{},
sum1
,
two
);
p
.
add_instruction
(
pass_op
{},
sum2
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
either_arg
(
0
,
1
)(
match
::
any
().
bind
(
"x"
),
match
::
name
(
"sum"
).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum2
});
EXPECT
(
bool
{
r
.
instructions
.
at
(
"x"
)
!=
r
.
instructions
.
at
(
"y"
)});
}
TEST_CASE
(
match_all_of1
)
TEST_CASE
(
match_all_of1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
...
@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
bool
{
r
.
result
==
sum
});
}
}
TEST_CASE
(
match_lazy_any_of
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
p
.
add_instruction
(
pass_op
{},
one
);
auto
m
=
match
::
any_of
(
match
::
any
(),
throws
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
one
});
}
TEST_CASE
(
match_lazy_all_of
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
p
.
add_instruction
(
pass_op
{},
one
);
auto
m
=
match
::
all_of
(
match
::
none
(),
throws
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_lazy_none_of
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
p
.
add_instruction
(
pass_op
{},
one
);
auto
m
=
match
::
none_of
(
match
::
any
(),
throws
());
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
TEST_CASE
(
match_any_of1
)
TEST_CASE
(
match_any_of1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2)
...
@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2)
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
}
TEST_CASE
(
match_any_of_lazy1
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
any
(),
match
::
any
()).
bind
(
"x"
),
match
::
args
(
match
::
name
(
"sum"
),
match
::
name
(
"sum"
)).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x"
));
EXPECT
(
bool
{
r
.
instructions
[
"x"
]
==
sum
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y"
));
}
TEST_CASE
(
match_any_of_lazy2
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
name
(
"@literal"
),
match
::
name
(
"@literal"
)).
bind
(
"x"
),
match
::
args
(
match
::
any
(),
match
::
any
()).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x"
));
EXPECT
(
bool
{
r
.
instructions
[
"x"
]
==
sum
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y"
));
}
TEST_CASE
(
match_any_of_lazy3
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
any
(),
match
::
any
()).
bind
(
"x"
),
match
::
args
(
match
::
name
(
"@literal"
),
match
::
name
(
"@literal"
)).
bind
(
"y"
)));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x"
));
EXPECT
(
bool
{
r
.
instructions
[
"x"
]
==
sum
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y"
));
}
TEST_CASE
(
match_any_of_lazy4
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
name
(
"@literal"
).
bind
(
"x1"
),
match
::
name
(
"@literal"
).
bind
(
"y1"
)),
match
::
args
(
match
::
any
().
bind
(
"x2"
),
match
::
any
().
bind
(
"y2"
))));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x1"
));
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"y1"
));
EXPECT
(
bool
{
r
.
instructions
[
"x1"
]
==
one
});
EXPECT
(
bool
{
r
.
instructions
[
"y1"
]
==
two
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"x2"
));
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y2"
));
}
TEST_CASE
(
match_any_of_lazy5
)
{
migraphx
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
auto
two
=
p
.
add_literal
(
2
);
auto
sum
=
p
.
add_instruction
(
sum_op
{},
one
,
two
);
p
.
add_instruction
(
pass_op
{},
sum
);
auto
m
=
match
::
name
(
"sum"
)(
match
::
any_of
(
match
::
args
(
match
::
any
().
bind
(
"x1"
),
match
::
any
().
bind
(
"y1"
)),
match
::
args
(
match
::
name
(
"@literal"
).
bind
(
"x2"
),
match
::
name
(
"@literal"
).
bind
(
"y2"
))));
auto
r
=
find_match
(
p
,
m
);
EXPECT
(
bool
{
r
.
result
==
sum
});
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"x1"
));
EXPECT
(
migraphx
::
contains
(
r
.
instructions
,
"y1"
));
EXPECT
(
bool
{
r
.
instructions
[
"x1"
]
==
one
});
EXPECT
(
bool
{
r
.
instructions
[
"y1"
]
==
two
});
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"x2"
));
EXPECT
(
not
migraphx
::
contains
(
r
.
instructions
,
"y2"
));
}
TEST_CASE
(
match_none_of1
)
TEST_CASE
(
match_none_of1
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/
fwd_conv
_batchnorm_
rewrite_
test.cpp
→
test/
rewrite
_batchnorm_test.cpp
View file @
6b36c82e
#include <migraphx/
fwd_conv
_batchnorm
_rewrite
.hpp>
#include <migraphx/
rewrite
_batchnorm.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/cpu/target.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/convolution.hpp>
...
@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
...
@@ -56,7 +56,7 @@ TEST_CASE(fwd_conv_batchnorm_rewrite_test)
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv
_batchnorm
_rewrite
opt
;
migraphx
::
rewrite
_batchnorm
opt
;
opt
.
apply
(
p2
);
opt
.
apply
(
p2
);
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p1
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
p2
.
compile
(
migraphx
::
cpu
::
target
{});
...
@@ -93,10 +93,10 @@ TEST_CASE(non_literal)
...
@@ -93,10 +93,10 @@ TEST_CASE(non_literal)
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv
_batchnorm
_rewrite
opt
;
migraphx
::
rewrite
_batchnorm
opt
;
opt
.
apply
(
p2
);
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any
_of
(
p2
,
&
is_batch_norm
));
EXPECT
(
none
_of
(
p2
,
&
is_batch_norm
));
}
}
TEST_CASE
(
as_literal
)
TEST_CASE
(
as_literal
)
...
@@ -121,7 +121,7 @@ TEST_CASE(as_literal)
...
@@ -121,7 +121,7 @@ TEST_CASE(as_literal)
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv
_batchnorm
_rewrite
opt
;
migraphx
::
rewrite
_batchnorm
opt
;
opt
.
apply
(
p2
);
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
...
@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape)
...
@@ -159,7 +159,7 @@ TEST_CASE(literal_reshape)
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p1
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
program
p2
=
create_program
();
migraphx
::
fwd_conv
_batchnorm
_rewrite
opt
;
migraphx
::
rewrite
_batchnorm
opt
;
opt
.
apply
(
p2
);
opt
.
apply
(
p2
);
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
any_of
(
p1
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
EXPECT
(
none_of
(
p2
,
&
is_batch_norm
));
...
...
test/simplify_algebra_test.cpp
View file @
6b36c82e
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/simplify_algebra.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/instruction.hpp>
#include <basic_ops.hpp>
#include <basic_ops.hpp>
#include <test.hpp>
#include <test.hpp>
...
@@ -91,9 +94,9 @@ TEST_CASE(simplify_add3)
...
@@ -91,9 +94,9 @@ TEST_CASE(simplify_add3)
auto
x
=
p2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
x
=
p2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
one
=
p2
.
add_literal
(
1
);
auto
one
=
p2
.
add_literal
(
1
);
auto
two
=
p2
.
add_literal
(
2
);
auto
two
=
p2
.
add_literal
(
2
);
auto
sum1
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
one
,
x
);
auto
sum1
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
one
,
two
);
auto
sum2
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
one
,
two
);
auto
sum2
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
one
,
sum1
);
auto
sum3
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
sum1
,
sum2
);
auto
sum3
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
x
,
sum2
);
p2
.
add_instruction
(
pass_op
{},
sum3
);
p2
.
add_instruction
(
pass_op
{},
sum3
);
}
}
EXPECT
(
p1
==
p2
);
EXPECT
(
p1
==
p2
);
...
@@ -129,4 +132,73 @@ void simplify_add4()
...
@@ -129,4 +132,73 @@ void simplify_add4()
EXPECT
(
p1
==
p2
);
EXPECT
(
p1
==
p2
);
}
}
TEST_CASE
(
simplify_mul_conv1
)
{
migraphx
::
program
p
;
auto
x
=
p
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
128
,
28
,
28
}});
auto
w
=
p
.
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
int32_type
,
{
256
,
128
,
3
,
3
}}));
auto
conv
=
p
.
add_instruction
(
migraphx
::
op
::
convolution
{{
1
,
1
},
{
2
,
2
},
{
1
,
1
}},
x
,
w
);
auto
a
=
p
.
add_literal
(
migraphx
::
generate_literal
({
migraphx
::
shape
::
int32_type
,
{
256
}}));
auto
b
=
p
.
add_instruction
(
migraphx
::
op
::
broadcast
{
1
,
{
1
,
256
,
14
,
14
}},
a
);
auto
mul
=
p
.
add_instruction
(
migraphx
::
op
::
mul
{},
conv
,
b
);
p
.
add_instruction
(
pass_op
{},
mul
);
EXPECT
(
conv
->
outputs
().
front
()
->
name
()
==
"mul"
);
p
.
compile
(
simplify_algebra_target
{});
auto
new_conv
=
std
::
find_if
(
p
.
begin
(),
p
.
end
(),
[](
auto
&&
ins
)
{
return
ins
.
name
()
==
"convolution"
;
});
EXPECT
(
new_conv
->
outputs
().
front
()
->
name
()
!=
"mul"
);
}
TEST_CASE
(
simplify_mul_add
)
{
migraphx
::
program
p1
;
{
auto
x
=
p1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
one
=
p1
.
add_literal
(
1
);
auto
two
=
p1
.
add_literal
(
2
);
auto
sum
=
p1
.
add_instruction
(
migraphx
::
op
::
add
{},
one
,
x
);
auto
mul
=
p1
.
add_instruction
(
migraphx
::
op
::
mul
{},
sum
,
two
);
p1
.
add_instruction
(
pass_op
{},
mul
);
}
p1
.
compile
(
simplify_algebra_target
{});
migraphx
::
program
p2
;
{
auto
x
=
p2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
one
=
p2
.
add_literal
(
1
);
auto
two
=
p2
.
add_literal
(
2
);
auto
mul1
=
p2
.
add_instruction
(
migraphx
::
op
::
mul
{},
two
,
x
);
auto
mul2
=
p2
.
add_instruction
(
migraphx
::
op
::
mul
{},
two
,
one
);
auto
sum
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
mul1
,
mul2
);
p2
.
add_instruction
(
pass_op
{},
sum
);
}
EXPECT
(
p1
==
p2
);
}
TEST_CASE
(
simplify_inner_broadcast
)
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
2
,
1
,
4
,
5
}};
migraphx
::
program
p1
;
{
auto
x
=
p1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
y
=
p1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
xb
=
p1
.
add_instruction
(
b
,
x
);
auto
yb
=
p1
.
add_instruction
(
b
,
y
);
auto
sum
=
p1
.
add_instruction
(
migraphx
::
op
::
add
{},
xb
,
yb
);
p1
.
add_instruction
(
pass_op
{},
sum
);
}
p1
.
compile
(
simplify_algebra_target
{});
migraphx
::
program
p2
;
{
auto
x
=
p2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
y
=
p2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
}});
auto
sum
=
p2
.
add_instruction
(
migraphx
::
op
::
add
{},
x
,
y
);
auto
sumb
=
p2
.
add_instruction
(
b
,
sum
);
p2
.
add_instruction
(
pass_op
{},
sumb
);
}
EXPECT
(
p1
==
p2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment