gaoqiong / MIGraphX · Commits · 2fdf510d

Unverified commit 2fdf510d, authored Aug 26, 2019 by mvermeulen, committed by GitHub on Aug 26, 2019.

Merge branch 'develop' into onnx_autopad_fix

Parents: d5939189, 4085af9b
Changes: 27. Showing 20 changed files with 1031 additions and 161 deletions (+1031, -161):

  src/CMakeLists.txt                                           +1    -1
  src/include/migraphx/matcher.hpp                             +72   -1
  src/include/migraphx/rewrite_batchnorm.hpp                   +2    -2
  src/onnx/onnx.cpp                                            +30   -7
  src/quantization.cpp                                         +2    -1
  src/rewrite_batchnorm.cpp                                    +57   -0
  src/simplify_algebra.cpp                                     +131  -13
  src/targets/gpu/CMakeLists.txt                               +1    -0
  src/targets/gpu/device/add_relu.cpp                          +10   -0
  src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp  +200  -31
  src/targets/gpu/device/mul_add.cpp                           +21   -0
  src/targets/gpu/fuse_ops.cpp                                 +107  -14
  src/targets/gpu/include/migraphx/gpu/device/add_relu.hpp     +6    -0
  src/targets/gpu/include/migraphx/gpu/device/mul_add.hpp      +25   -0
  src/targets/gpu/quant_gemm.cpp                               +30   -75
  src/targets/gpu/target.cpp                                   +4    -4
  src/tf/tf.cpp                                                +89   -12
  test/gpu/miopen.cpp                                          +18   -0
  test/matcher.cpp                                             +198  -0
  test/onnx/onnx_test.cpp                                      +27   -0
src/CMakeLists.txt

@@ -12,7 +12,7 @@ add_library(migraphx
     eliminate_concat.cpp
     eliminate_identity.cpp
     eliminate_pad.cpp
-    fwd_conv_batchnorm_rewrite.cpp
+    rewrite_batchnorm.cpp
    rewrite_rnn.cpp
    rewrite_pooling.cpp
    env.cpp
...
src/include/migraphx/matcher.hpp

@@ -74,7 +74,7 @@ auto bind_match(M m, std::string name)
         [=, name = std::move(name)](matcher_context& ctx, instruction_ref ins) {
             auto result = m.match(ctx, ins);
             if(result != ctx.not_found())
-                ctx.instructions.emplace(name, ins);
+                ctx.instructions[name] = ins;
             return result;
         });
 }
...
@@ -240,6 +240,21 @@ void find_matches(program& p, Ms&&... ms)
     }
 }

+template <class M>
+struct find_skip
+{
+    M m;
+    M matcher() const { return m; }
+    void apply(program&, const matcher_result&) const {}
+};
+
+template <class M>
+find_skip<M> make_find_skip(M m)
+{
+    return {m};
+}
+
 struct lazy_and
 {
     template <class F, class G>
...
@@ -311,6 +326,12 @@ const constexpr auto all_of = match_fold_f<lazy_and, true, true>{};
 const constexpr auto any_of  = match_fold_f<lazy_or, false, true>{};
 const constexpr auto none_of = match_fold_f<lazy_or, false, false>{};

+template <class... Ms>
+auto skip_matches(Ms... ms)
+{
+    return make_find_skip(any_of(ms...));
+}
+
 inline auto inputs()
 {
     return [](auto ins, auto f) {
...
@@ -369,6 +390,50 @@ MIGRAPHX_BASIC_MATCHER(used_once, const matcher_context& ctx, instruction_ref ins)
     return ctx.not_found();
 }

+inline auto used_once_recursive(std::size_t depth)
+{
+    return make_basic_fun_matcher([=](const matcher_context& ctx, instruction_ref start) {
+        // Used once
+        if(start->outputs().size() == 1)
+            return start;
+        // Unused
+        if(start->outputs().empty())
+        {
+            if(std::next(start) == ctx.not_found())
+                return start;
+            else
+                return ctx.not_found();
+        }
+        // Check for dead instructions
+        auto is_dead = fix<bool>([&](auto self, auto ins, auto n) {
+            if(n == 0)
+                return false;
+            if(ins->get_shape().elements() == 0)
+                return false;
+            if(ins->outputs().empty() and std::next(ins) != ctx.not_found())
+                return true;
+            return std::all_of(ins->outputs().begin(), ins->outputs().end(), [&](auto i) {
+                return self(i, n - 1);
+            });
+        });
+        auto dead = std::count_if(start->outputs().begin(), start->outputs().end(), [&](auto i) {
+            return is_dead(i, depth);
+        });
+        if(dead + 1 == start->outputs().size())
+            return start;
+        return ctx.not_found();
+    });
+}
+
+MIGRAPHX_PRED_MATCHER(is_constant, instruction_ref ins) { return ins->can_eval(); }
+
+MIGRAPHX_BASIC_MATCHER(is_unused, const matcher_context& ctx, instruction_ref ins)
+{
+    if(ins->outputs().empty() and ins != std::prev(ctx.not_found()))
+        return ins;
+    return ctx.not_found();
+}
+
 template <class... Ms>
 auto skip_output(Ms... ms)
 {
...
@@ -404,6 +469,12 @@ inline auto name(std::unordered_set<std::string> names)
     });
 }

+template <class... Ts>
+inline auto name(std::string s, Ts... xs) // NOLINT
+{
+    return name(std::unordered_set<std::string>{std::move(s), std::move(xs)...});
+}
+
 inline auto nargs(std::size_t n)
 {
     return make_basic_pred_matcher([=](instruction_ref ins) { return ins->inputs().size() == n; });
...
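For orientation, here is a minimal sketch (not from this commit, and find_relu_of_add is a hypothetical pass name) of how the matcher API extended above is consumed: a struct exposes matcher() and apply(), bound instructions are retrieved by the names given to bind(), and find_matches drives the whole thing.

    #include <migraphx/matcher.hpp>
    #include <migraphx/program.hpp>

    namespace migraphx {
    struct find_relu_of_add
    {
        auto matcher() const
        {
            // bind the inner add so apply() can look it up by name
            return match::name("relu")(match::arg(0)(match::name("add").bind("add")));
        }
        void apply(program&, match::matcher_result r) const
        {
            auto add_ins = r.instructions["add"]; // filled in by bind("add") above
            (void)add_ins;                        // a real pass would rewrite here
        }
    };
    } // namespace migraphx

    // usage: migraphx::match::find_matches(p, migraphx::find_relu_of_add{});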
src/include/migraphx/fwd_conv_batchnorm_rewrite.hpp → src/include/migraphx/rewrite_batchnorm.hpp

@@ -13,9 +13,9 @@ struct program;
 /**
  * Rewrite batchnorm to a multiply and add.
  */
-struct fwd_conv_batchnorm_rewrite
+struct rewrite_batchnorm
 {
-    std::string name() const { return "fwd_conv_batchnorm_rewrite"; }
+    std::string name() const { return "rewrite_batchnorm"; }
     void apply(program& p) const;
 };
...
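The rewrite is valid because inference-time batchnorm is an affine function of its input; per channel, with the source's gamma, bias, mean, and variance:

    y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta = a x + b,
    \qquad a = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}},
    \quad  b = \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \epsilon}}

These are exactly the a2[c] and b2[c] computed in rewrite_batchnorm.cpp below.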
src/onnx/onnx.cpp

@@ -206,6 +206,16 @@ struct onnx_parser
         return out_lens;
     }

+    instruction_ref make_contiguous(instruction_ref ins)
+    {
+        if(ins->get_shape().standard())
+        {
+            return ins;
+        }
+
+        return prog.add_instruction(op::contiguous{}, ins);
+    }
+
     template <class T>
     instruction_ref add_broadcastable_binary_op(instruction_ref arg0, instruction_ref arg1, T x)
     {
...
@@ -441,12 +451,7 @@ struct onnx_parser
             s.visit([&](auto v) { copy(v, std::back_inserter(op.dims)); });
         }
-        if(!args[0]->get_shape().standard())
-        {
-            args[0] = prog.add_instruction(op::contiguous{}, args[0]);
-        }
-
-        return prog.add_instruction(op, args[0]);
+        return prog.add_instruction(op, make_contiguous(args[0]));
     }

     instruction_ref
...
@@ -494,23 +499,41 @@ struct onnx_parser
         {
             axis = parse_value(attributes.at("axis")).at<int>();
         }
         op::gather op{axis};
-        return prog.add_instruction(op, std::move(args));
+        return prog.add_instruction(op, make_contiguous(args[0]), make_contiguous(args[1]));
     }

     instruction_ref
     parse_slice(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
     {
         op::slice op;
+        std::vector<size_t> dims = args[0]->get_shape().lens();
+        size_t num_dims          = dims.size();
         if(contains(attributes, "axes"))
         {
             literal s = parse_value(attributes.at("axes"));
             s.visit([&](auto v) { copy(v, std::back_inserter(op.axes)); });
         }
+        else
+        {
+            op.axes = std::vector<int64_t>(num_dims);
+            std::iota(op.axes.begin(), op.axes.end(), 0);
+        }
         if(contains(attributes, "ends"))
         {
             literal s = parse_value(attributes.at("ends"));
             s.visit([&](auto v) { copy(v, std::back_inserter(op.ends)); });
+            for(size_t i = 0; i < num_dims; i++)
+            {
+                if(static_cast<size_t>(op.ends[i]) > dims[i])
+                {
+                    op.ends[i] = dims[i];
+                }
+            }
         }
         if(contains(attributes, "starts"))
         {
             literal s = parse_value(attributes.at("starts"));
             s.visit([&](auto v) { copy(v, std::back_inserter(op.starts)); });
...
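A standalone sketch (with hypothetical values) of the new ends clamping in parse_slice: ONNX models routinely encode "to the end of the axis" as a huge end index such as INT_MAX, which must be clamped to the dimension.

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<std::size_t> dims = {3, 4};          // input tensor lens
        std::vector<int64_t> ends     = {2147483647, 2}; // "ends" attribute from the model
        for(std::size_t i = 0; i < dims.size(); i++)
            if(static_cast<std::size_t>(ends[i]) > dims[i])
                ends[i] = dims[i];                       // clamp, as parse_slice now does
        std::printf("%lld %lld\n", (long long)ends[0], (long long)ends[1]); // prints: 3 2
    }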
src/quantization.cpp

@@ -74,7 +74,8 @@ void quantize(program& prog, const std::vector<std::string>& ins_names)
             // if the input is a convert operator, uses its input
             // as its current input
             instruction_ref input_fp16{};
-            if(input->name() == "convert")
+            if(input->name() == "convert" and
+               input->inputs().front()->get_shape().type() == shape::half_type)
             {
                 input_fp16 = input->inputs().front();
             }
...
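A plain-C++ sketch (hypothetical types, not the MIGraphX API) of the case the added half_type check guards against: unwrapping a convert is only safe when the convert's input really is fp16, not merely any convert.

    #include <cassert>
    #include <cstring>

    enum class type_t { half_type, float_type, int32_type };

    struct ins_info { const char* op_name; type_t input_type; };

    // mirrors the tightened condition above
    bool reusable_as_fp16(const ins_info& i)
    {
        return std::strcmp(i.op_name, "convert") == 0 && i.input_type == type_t::half_type;
    }

    int main()
    {
        assert(reusable_as_fp16({"convert", type_t::half_type}));
        assert(!reusable_as_fp16({"convert", type_t::int32_type})); // previously reused incorrectly
    }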
src/fwd_conv_batchnorm_rewrite.cpp → src/rewrite_batchnorm.cpp

-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/instruction.hpp>
 #include <migraphx/op/batch_norm.hpp>
 #include <migraphx/op/broadcast.hpp>
 #include <migraphx/op/add.hpp>
+#include <migraphx/op/mul.hpp>
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/ranges.hpp>
 #include <migraphx/dfor.hpp>
...
@@ -11,7 +12,7 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-void fwd_conv_batchnorm_rewrite::apply(program& p) const
+void rewrite_batchnorm::apply(program& p) const
 {
     for(auto ins : iterator_for(p))
     {
...
@@ -25,46 +26,30 @@ void fwd_conv_batchnorm_rewrite::apply(program& p) const
         if(any_of({gamma, bias, mean, variance}, [](auto arg) { return arg.empty(); }))
             continue;

-        auto conv_ins = ins->inputs()[0];
-        if(conv_ins->name() != "convolution")
-            continue;
-        // Get convolution weights
-        auto weights = conv_ins->inputs()[1]->eval();
-        if(weights.empty())
-            continue;
+        auto s = shape{ins->get_shape().type(), {ins->get_shape().lens()[1]}};

         // Get epsilon
         auto bn_op   = any_cast<op::batch_norm_inference>(ins->get_operator());
         auto epsilon = bn_op.epsilon;

-        // Get convolution op
-        auto conv_op      = conv_ins->get_operator();
-        auto weights_lens = weights.get_shape().lens();
-        auto conv_lens    = conv_ins->get_shape().lens();
-        argument new_weights{weights.get_shape()};
-        argument new_bias{{bias.get_shape().type(), {bias.get_shape().elements()}}};
-        visit_all(weights, gamma, bias, mean, variance, new_weights, new_bias)(
-            [&](auto weights2,
-                auto gamma2,
-                auto bias2,
-                auto mean2,
-                auto variance2,
-                auto new_weights2,
-                auto new_bias2) {
-                dfor(weights_lens[0], weights_lens[1], weights_lens[2], weights_lens[3])(
-                    [&](std::size_t k, std::size_t c, std::size_t h, std::size_t w) {
-                        new_weights2(k, c, h, w) =
-                            gamma2[k] / std::sqrt(variance2[k] + epsilon) * weights2(k, c, h, w);
-                    });
-                dfor(new_bias.get_shape().elements())([&](std::size_t c) {
-                    new_bias2[c] =
-                        bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
-                });
-            });
-        // Replace convolution instruction with updated weights
-        auto l_weights = p.add_literal({weights.get_shape(), new_weights.data()});
-        auto l_bias    = p.add_literal({new_bias.get_shape(), new_bias.data()});
-        auto c = p.replace_instruction(conv_ins, conv_op, {conv_ins->inputs()[0], l_weights});
-        auto b = p.insert_instruction(ins, op::broadcast{1, c->get_shape().lens()}, l_bias);
-        p.replace_instruction(ins, op::add{}, {c, b});
+        argument a{s};
+        argument b{s};
+        visit_all(gamma, bias, mean, variance, a, b)(
+            [&](auto gamma2, auto bias2, auto mean2, auto variance2, auto a2, auto b2) {
+                dfor(a.get_shape().elements())([&](std::size_t c) {
+                    a2[c] = gamma2[c] / std::sqrt(variance2[c] + epsilon);
+                });
+                dfor(b.get_shape().elements())([&](std::size_t c) {
+                    b2[c] = bias2[c] - (gamma2[c] * mean2[c] / std::sqrt(variance2[c] + epsilon));
+                });
+            });
+
+        auto broadcast   = op::broadcast{1, ins->get_shape().lens()};
+        auto a_ins       = p.add_literal({a.get_shape(), a.data()});
+        auto a_broadcast = p.insert_instruction(ins, broadcast, a_ins);
+        auto mul = p.insert_instruction(ins, op::mul{}, ins->inputs().front(), a_broadcast);
+        auto b_ins       = p.add_literal({b.get_shape(), b.data()});
+        auto b_broadcast = p.insert_instruction(ins, broadcast, b_ins);
+        auto add         = p.insert_instruction(ins, op::add{}, mul, b_broadcast);
+        p.replace_instruction(ins, add);
     }
 }
...
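A self-contained numerical check (not part of the commit) that the per-channel a and b computed above reproduce batch_norm_inference exactly:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const double gamma = 1.5, bias = 0.25, mean = -0.5, variance = 4.0, epsilon = 1e-5;
        const double a = gamma / std::sqrt(variance + epsilon);                  // a2[c]
        const double b = bias - (gamma * mean / std::sqrt(variance + epsilon));  // b2[c]
        for(double x : {-2.0, 0.0, 3.5})
        {
            const double bn = gamma * (x - mean) / std::sqrt(variance + epsilon) + bias;
            assert(std::fabs((a * x + b) - bn) < 1e-12); // a*x + b == batchnorm(x)
        }
    }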
src/simplify_algebra.cpp

 #include <migraphx/simplify_algebra.hpp>
+#include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/program.hpp>
 #include <migraphx/op/add.hpp>
+#include <migraphx/op/mul.hpp>
+#include <migraphx/op/broadcast.hpp>
 #include <migraphx/matcher.hpp>
 #include <migraphx/literal.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

-struct find_add_lit_broadcast
-{
-    auto lit_broadcast() const
-    {
-        return match::any_of(match::name("@literal"), match::name("broadcast"));
-    }
-    auto not_lit_broadcast() const
-    {
-        return match::none_of(match::name("@literal"), match::name("broadcast"));
-    }
-    auto add_lit_broadcast(std::string x, std::string y) const
-    {
-        return match::name("add")(match::either_arg(0, 1)(lit_broadcast().bind(std::move(x)),
-                                                          not_lit_broadcast().bind(std::move(y))));
-    }
-    auto matcher() const
-    {
-        return match::name("add")(
-            match::args(add_lit_broadcast("a", "x"), add_lit_broadcast("b", "y")));
-    }
+auto lit_broadcast() { return match::any_of(match::is_constant(), match::name("broadcast")); }
+auto not_lit_broadcast() { return match::none_of(match::is_constant(), match::name("broadcast")); }
+auto op_lit_broadcast(std::string op, std::string x, std::string y)
+{
+    return match::name(std::move(op))(match::either_arg(0, 1)(
+        lit_broadcast().bind(std::move(x)), not_lit_broadcast().bind(std::move(y))));
+}
+
+auto conv_const_weights()
+{
+    return match::name("convolution")(match::used_once(),
+                                      match::args(match::any(), match::is_constant().bind("w")));
+}
+
+struct find_mul_conv
+{
+    auto matcher() const
+    {
+        return match::name("mul")(match::either_arg(0, 1)(conv_const_weights().bind("conv"),
+                                                          match::name("broadcast").bind("a")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins      = r.result;
+        auto conv_ins = r.instructions["conv"];
+        auto a_ins    = r.instructions["a"];
+        auto w_ins    = r.instructions["w"];
+
+        auto broadcast_op = any_cast<op::broadcast>(a_ins->get_operator());
+        if(broadcast_op.axis != 1)
+            return;
+
+        auto new_a = p.insert_instruction(
+            ins, op::broadcast{0, w_ins->get_shape().lens()}, a_ins->inputs().front());
+        auto new_mul  = p.insert_instruction(ins, op::mul{}, new_a, w_ins);
+        auto new_conv = p.insert_instruction(
+            ins, conv_ins->get_operator(), conv_ins->inputs().front(), new_mul);
+        p.replace_instruction(ins, new_conv);
+    }
+};
+
+// a * (x + b) => a * x + a * b
+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("mul")(match::either_arg(0, 1)(
+            match::name("add")(
+                match::either_arg(0, 1)(
+                    match::any().bind("x"),
+                    match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
+                match::none_of(match::args(match::is_constant(), match::is_constant())),
+                match::used_once()),
+            match::is_constant().bind("a")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+        auto x_ins = r.instructions["x"];
+        assert(x_ins != b_ins);
+
+        auto ax_ins = p.insert_instruction(ins, op::mul{}, a_ins, x_ins);
+        auto ab_ins = p.insert_instruction(ins, op::mul{}, a_ins, b_ins);
+        p.replace_instruction(ins, op::add{}, ax_ins, ab_ins);
+    }
+};
+
+struct find_add_lit_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("add")(match::either_arg(0, 1)(op_lit_broadcast("add", "a", "x"),
+                                                          lit_broadcast().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto x_ins = r.instructions["x"];
+        auto a_ins = r.instructions["a"];
+        auto b_ins = r.instructions["b"];
+
+        auto sumab = p.insert_instruction(ins, op::add{}, a_ins, b_ins);
+        p.replace_instruction(ins, op::add{}, x_ins, sumab);
+    }
+};
+
+struct find_double_add_lit_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("add")(
+            match::args(op_lit_broadcast("add", "a", "x"), op_lit_broadcast("add", "b", "y")));
+    }

     void apply(program& p, match::matcher_result r) const
...
@@ -36,11 +117,9 @@ struct find_add_lit_broadcast
         auto a_ins = r.instructions["a"];
         auto b_ins = r.instructions["b"];

-        if(a_ins->name() != b_ins->name())
-            return;
         instruction_ref sumab;

-        if(a_ins->name() == "broadcast")
+        if(a_ins->name() == "broadcast" and b_ins->name() == "broadcast")
         {
             if(a_ins->inputs().at(0)->get_shape() != b_ins->inputs().at(0)->get_shape())
                 return;
...
@@ -59,7 +138,46 @@ struct find_add_lit_broadcast
     }
 };

-void simplify_algebra::apply(program& p) const
-{
-    match::find_matches(p, find_add_lit_broadcast{});
-}
+struct find_inner_broadcast
+{
+    auto matcher() const
+    {
+        return match::name("mul", "add")(match::args(match::name("broadcast").bind("x"),
+                                                     match::name("broadcast").bind("y")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto x_ins = r.instructions["x"];
+        auto y_ins = r.instructions["y"];
+
+        auto xbroadcast = any_cast<op::broadcast>(x_ins->get_operator());
+        auto ybroadcast = any_cast<op::broadcast>(y_ins->get_operator());
+        if(xbroadcast.axis != ybroadcast.axis)
+            return;
+
+        auto op = p.insert_instruction(
+            ins, ins->get_operator(), x_ins->inputs().front(), y_ins->inputs().front());
+        p.replace_instruction(ins, xbroadcast, op);
+    }
+};
+
+void simplify_algebra::apply(program& p) const
+{
+    // Run simplifications multiple times
+    for(int i = 0; i < 4; i++)
+    {
+        match::find_matches(p,
+                            find_inner_broadcast{},
+                            find_double_add_lit_broadcast{},
+                            find_add_lit_broadcast{},
+                            find_mul_conv{},
+                            find_mul_add{});
+        dead_code_elimination{}.apply(p);
+    }
+}

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
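A quick check (not from the commit) of why find_mul_add distributes a * (x + b) into a * x + a * b when a and b are constants: the a * b term is itself constant, so later constant propagation folds it away and only one multiply-add on x remains at runtime.

    #include <cassert>

    int main()
    {
        const float a = 2.0f, b = 3.0f; // constants in the graph
        const float ab = a * b;         // foldable ahead of time
        for(float x : {-1.0f, 0.5f, 4.0f})
            assert(a * (x + b) == a * x + ab); // exact for these binary-representable values
    }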
src/targets/gpu/CMakeLists.txt

@@ -16,6 +16,7 @@ add_library(migraphx_device
     device/argmin.cpp
     device/max.cpp
     device/min.cpp
+    device/mul_add.cpp
    device/exp.cpp
    device/erf.cpp
    device/log.cpp
...
src/targets/gpu/device/add_relu.cpp

@@ -6,6 +6,16 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+void mul_add_relu(hipStream_t stream,
+                  const argument& result,
+                  const argument& arg1,
+                  const argument& arg2,
+                  const argument& arg3)
+{
+    nary(stream, result, arg1, arg2, arg3)(
+        [](auto x, auto a, auto b) { return std::max<decltype(a * x + b)>(0, a * x + b); });
+}
+
 void add_relu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
...
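The relu is spelled std::max<decltype(a * x + b)>(0, ...) because std::max deduces a single type; mixing the int literal 0 with a floating expression would fail to compile. A minimal host-side sketch (hypothetical helper name):

    #include <algorithm>
    #include <cassert>

    template <class T>
    T mul_add_relu_scalar(T x, T a, T b)
    {
        // forcing the template argument promotes 0 to the expression's type
        return std::max<decltype(a * x + b)>(0, a * x + b);
    }

    int main()
    {
        assert(mul_add_relu_scalar(2.0f, 1.5f, -5.0f) == 0.0f); // negative result clamped
        assert(mul_add_relu_scalar(2.0f, 1.5f, 1.0f) == 4.0f);
    }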
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp

@@ -118,6 +118,111 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg, ...)
     });
 }

+template <class F, class... Arguments>
+void nary_double_broadcast_vec_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
+{
+    assert(barg1.get_shape().broadcasted());
+    assert(barg2.get_shape().broadcasted());
+    assert(barg1.get_shape() == barg2.get_shape());
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg1.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
+                          return x != 0;
+                      }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+
+    const std::size_t vec_size     = 4;
+    const std::size_t nlocal       = 1024;
+    const std::size_t nglobal      = 256 * nlocal;
+    const std::size_t bdim_vec_len = bdim_len / vec_size;
+    hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type                  = typename decltype(output)::value_type;
+            const std::size_t nelements = output.size() / vec_size;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048 / vec_size];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_vec_len; i += nlocal)
+                {
+                    buffer[i + bdim_vec_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                auto* bp = as_pointer(buffer);
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx = ((i * vec_size) % bdim_next_stride) / bdim_stride;
+                    auto b1   = bp[bidx];
+                    auto b2   = bp[bidx + bdim_len];
+                    auto out  = output.data()[i];
+                    for(std::size_t j = 0; j < vec_size; j++)
+                    {
+                        out[j] = f(inputs.data()[i][j]..., b2, b1);
+                    }
+                    output.data()[i] = out;
+                }
+            });
+        });
+}
+
+template <class F, class... Arguments>
+void nary_double_broadcast_impl(
+    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
+{
+    assert(barg1.get_shape().broadcasted());
+    assert(barg2.get_shape().broadcasted());
+    assert(barg1.get_shape() == barg2.get_shape());
+    const auto& output_shape = result.get_shape();
+    const auto& b_shape      = barg1.get_shape();
+    auto bdim =
+        std::distance(b_shape.strides().begin(),
+                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
+                          return x != 0;
+                      }));
+    auto bdim_len         = output_shape.lens()[bdim];
+    auto bdim_stride      = output_shape.strides()[bdim];
+    auto bdim_next_stride = bdim_stride * bdim_len;
+
+    const std::size_t nlocal  = 1024;
+    const std::size_t nglobal = 256 * nlocal;
+    std::size_t nelements     = result.get_shape().elements();
+    hip_visit_all(result, barg1, barg2, args...)(
+        [&](auto output, auto binput1, auto binput2, auto... inputs) {
+            using type = typename decltype(output)::value_type;
+            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
+                MIGRAPHX_DEVICE_SHARED type buffer[2048];
+                // Load bias into LDS
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i] = binput1.data()[i];
+                }
+                for(size_t i = idx.local; i < bdim_len; i += nlocal)
+                {
+                    buffer[i + bdim_len] = binput2.data()[i];
+                }
+                __syncthreads();
+                // Process the data
+                for(size_t i = idx.global; i < nelements; i += nglobal)
+                {
+                    auto bidx        = (i % bdim_next_stride) / bdim_stride;
+                    auto b1          = buffer[bidx];
+                    auto b2          = buffer[bidx + bdim_len];
+                    output.data()[i] = f(inputs.data()[i]..., b2, b1);
+                }
+            });
+        });
+}
+
 template <class F, class... Arguments>
 void nary_standard_vec_impl(hipStream_t stream, F f, argument result, Arguments... args)
 {
...
@@ -177,23 +282,18 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
 }

 template <class... Arguments>
-auto nary(hipStream_t stream, argument result)
-{
-    return [=](auto f) { nary_standard_impl(stream, f, result); };
-}
-
-template <class... Arguments>
-auto nary(hipStream_t stream, argument result, Arguments... args)
+bool broadcastable(bool& divisible_by_4,
+                   std::size_t max_size,
+                   const argument& result,
+                   const argument& barg,
+                   const Arguments&... args)
 {
-    return [=](auto f) {
-        auto barg     = back_args(args...);
-        bool fallback = pop_back_args(args...)([&](auto&&... args2) {
-            auto bshape = barg.get_shape();
-            const bool standard =
-                all_of({args2.get_shape()...}, [](const shape& s) { return s.standard(); });
-            const bool same_shapes = all_of(
-                {args2.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
+    divisible_by_4 = false;
+    auto bshape    = barg.get_shape();
+    const bool standard =
+        all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
+    const bool same_shapes =
+        all_of({args.get_shape()...}, [&](const shape& s) { return s == result.get_shape(); });
     // TODO: Check result and args shape is the same
     if(standard and same_shapes and bshape.broadcasted() and not bshape.scalar())
     {
...
@@ -204,22 +304,91 @@ auto nary(hipStream_t stream, argument result, Arguments... args)
         auto b_len    = result.get_shape().lens()[b_idx];
         auto b_stride = result.get_shape().strides()[b_idx];
         assert(bshape.lens()[b_idx] == b_len);
-        if(b_len <= 2048 and std::none_of(std::next(b_it), strides.end(), not_zero))
+        if(b_len <= max_size and std::none_of(std::next(b_it), strides.end(), not_zero))
         {
-            const bool divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
-                                        (front_args(args...).get_shape().elements() % 4 == 0);
+            divisible_by_4 = (b_len % 4 == 0) and (b_stride % 4 == 0) and
+                             (front_args(args...).get_shape().elements() % 4 == 0);
+            return true;
+        }
+    }
+    return false;
+}
+
+inline bool broadcastable(bool& divisible_by_4, std::size_t, const argument&, const argument&)
+{
+    divisible_by_4 = false;
+    return false;
+}
+
+// Nullary
+inline auto nary(hipStream_t stream, argument result)
+{
+    return [=](auto f) { nary_standard_impl(stream, f, result); };
+}
+
+// Unary
+inline auto nary(hipStream_t stream, argument result, argument arg)
+{
+    return [=](auto f) { nary_impl(stream, f, result, arg); };
+}
+
+// Binary
+inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
+{
+    return [=](auto f) {
+        bool divisible_by_4 = false;
+        if(broadcastable(divisible_by_4, 2048, result, barg, arg))
+        {
+            if(divisible_by_4)
+                nary_broadcast_vec_impl(stream, f, result, barg, arg);
+            else
+                nary_broadcast_impl(stream, f, result, barg, arg);
+        }
+        else
+        {
+            nary_impl(stream, f, result, arg, barg);
+        }
+    };
+}
+
+template <class... Arguments>
+auto nary(hipStream_t stream, argument result, Arguments... args)
+{
+    static_assert(sizeof...(args) > 2, "Args needs to be greater than 2");
+    return [=](auto f) {
+        auto barg1     = back_args(args...);
+        bool fallback1 = pop_back_args(args...)([&](auto&&... args2) {
+            auto barg2     = back_args(args2...);
+            bool fallback2 =
+                barg2.get_shape() != barg1.get_shape() or
+                not barg2.get_shape().broadcasted() or
+                pop_back_args(args2...)([&](auto&&... args3) {
+                    bool divisible_by_4 = false;
+                    if(broadcastable(divisible_by_4, 1024, result, barg2, args3...))
+                    {
                         if(divisible_by_4)
-                            nary_broadcast_vec_impl(stream, f, result, barg, args2...);
+                            nary_double_broadcast_vec_impl(
+                                stream, f, result, barg1, barg2, args3...);
                         else
-                            nary_broadcast_impl(stream, f, result, barg, args2...);
+                            nary_double_broadcast_impl(stream, f, result, barg1, barg2, args3...);
                         return false;
                     }
+                    return true;
+                });
+            if(not fallback2)
+                return false;
+            bool divisible_by_4 = false;
+            if(broadcastable(divisible_by_4, 2048, result, barg1, args2...))
+            {
+                if(divisible_by_4)
+                    nary_broadcast_vec_impl(stream, f, result, barg1, args2...);
+                else
+                    nary_broadcast_impl(stream, f, result, barg1, args2...);
+                return false;
+            }
             return true;
         });
-        if(fallback)
+        if(fallback1)
             nary_impl(stream, f, result, args...);
     };
 }
...
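A host-side sketch (hypothetical shape) of the index arithmetic the broadcast kernels above use: for an output with lens {2, 3, 4} broadcasting along axis 1, bdim_stride is 4 and bdim_next_stride is 12, so linear element i maps to bias element (i % 12) / 4.

    #include <cassert>
    #include <cstddef>

    int main()
    {
        const std::size_t bdim_len = 3, bdim_stride = 4;
        const std::size_t bdim_next_stride = bdim_stride * bdim_len; // 12
        for(std::size_t i = 0; i < 24; i++)
        {
            const std::size_t bidx = (i % bdim_next_stride) / bdim_stride;
            assert(bidx == (i / bdim_stride) % bdim_len); // equivalent formulation
            assert(bidx < bdim_len);                      // always a valid bias index
        }
    }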
src/targets/gpu/device/mul_add.cpp (new file, 0 → 100644)

#include <migraphx/gpu/device/add_relu.hpp>
#include <migraphx/gpu/device/nary.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3)
{
    nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b) { return a * x + b; });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/fuse_ops.cpp

@@ -2,6 +2,7 @@
 #include <migraphx/matcher.hpp>
 #include <migraphx/gpu/miopen.hpp>
 #include <migraphx/gpu/convolution.hpp>
+#include <migraphx/gpu/device/mul_add.hpp>
 #include <migraphx/gpu/device/add_relu.hpp>
 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/instruction.hpp>
...
@@ -198,21 +199,62 @@
     }
 };

+struct hip_mul_add
+{
+    std::string name() const { return "hip::mul_add"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4);
+        return inputs.front();
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        device::mul_add(ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
+    }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+};
+
+struct hip_mul_add_relu
+{
+    std::string name() const { return "hip::mul_add_relu"; }
+    shape compute_shape(const std::vector<shape>& inputs) const
+    {
+        check_shapes{inputs, *this}.has(4);
+        return inputs.front();
+    }
+    argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
+    {
+        device::mul_add_relu(
+            ctx.get_stream().get(), args.at(3), args.at(0), args.at(1), args.at(2));
+        return args.at(3);
+    }
+    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
+    {
+        return shapes.size() - 1;
+    }
+};
+
 void move_broadcasted_back(std::vector<instruction_ref>& args)
 {
     // Ensure the last arguments is the broadcasted one
-    auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().broadcasted(); });
-    if(it != args.end())
-        std::swap(*it, *std::prev(args.end(), 2));
+    auto last = std::prev(args.end());
+    auto it   = std::find_if(
+        args.begin(), last, [](auto arg) { return arg->get_shape().broadcasted(); });
+    if(it != last)
+        std::swap(*it, *std::prev(last));
 }

 void move_standard_front(std::vector<instruction_ref>& args)
 {
     // Ensure the first arguments is the standard one
-    auto it = std::find_if(
-        args.begin(), args.end(), [](auto arg) { return arg->get_shape().standard(); });
-    if(it != args.end())
+    auto last = std::prev(args.end());
+    auto it   = std::find_if(
+        args.begin(), last, [](auto arg) { return arg->get_shape().standard(); });
+    if(it != last)
         std::swap(*it, args.front());
 }
...
@@ -220,10 +262,12 @@ struct find_add_relu
 {
     auto matcher() const
     {
         return match::name("gpu::relu")(
             match::arg(0)(
-                match::any_of(match::name("gpu::add"),
+                match::used_once(),
+                match::any_of(match::name("gpu::add"),
                               match::name("hip::triadd"),
-                              match::any_of[match::inputs()](match::standard_shape()))
+                              match::any_of(match::name("@literal"),
+                                            match::any_of[match::inputs()](match::standard_shape())))
                     .bind("add")));
     }
...
@@ -249,8 +293,10 @@ struct find_triadd
     auto matcher() const
     {
         return match::name("gpu::add")(match::either_arg(0, 1)(
-            match::name("gpu::add").bind("add"),
-            match::any(match::any_of[match::inputs()](match::standard_shape())).bind("input")));
+            match::name("gpu::add")(match::used_once()).bind("add"),
+            match::any(match::any_of(match::name("@literal"),
+                                     match::any_of[match::inputs()](match::standard_shape())))
+                .bind("input")));
     }

     void apply(program& p, match::matcher_result r) const
...
@@ -273,6 +319,51 @@ struct find_triadd
     }
 };

+struct find_mul_add
+{
+    auto matcher() const
+    {
+        return match::name("gpu::add")(match::either_arg(0, 1)(
+            match::name("gpu::mul")(match::used_once()).bind("mul"), match::any().bind("b")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_ins = r.instructions["mul"];
+        auto b_ins   = r.instructions["b"];
+        auto ins     = r.result;
+        auto args    = mul_ins->inputs();
+        assert(mul_ins != b_ins);
+
+        move_standard_front(args);
+        move_broadcasted_back(args);
+
+        args.insert(std::prev(args.end()), b_ins);
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add{}, args);
+    }
+};
+
+struct find_mul_add_relu
+{
+    auto matcher() const
+    {
+        return match::name("gpu::relu")(
+            match::arg(0)(match::name("hip::mul_add")(match::used_once()).bind("mul_add")));
+    }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto mul_add_ins = r.instructions["mul_add"];
+        auto ins         = r.result;
+        auto args        = mul_add_ins->inputs();
+
+        // Use the allocation from the relu operator
+        args.back() = ins->inputs().back();
+        p.replace_instruction(ins, hip_mul_add_relu{}, args);
+    }
+};
+
 struct miopen_conv_bias
 {
     op::convolution op;
...
@@ -428,6 +519,8 @@ void fuse_ops::apply(program& p) const
     match::find_matches(p,
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
+                        find_mul_add{},
+                        find_mul_add_relu{},
                         find_add_relu{});
     // clang-format on
...
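A plain-C++ sketch (hypothetical types, read off compute() and output_alias() above, not the MIGraphX API) of the calling convention these fused ops follow: the last argument is a preallocated output buffer, the op writes into it and returns it, and output_alias reports that aliasing so the allocator can reuse memory.

    #include <cstddef>
    #include <vector>

    struct buffer { std::vector<float> data; };

    buffer* fused_mul_add(buffer* x, buffer* a, buffer* b, buffer* out)
    {
        for(std::size_t i = 0; i < out->data.size(); i++)
            out->data[i] = a->data[i] * x->data[i] + b->data[i]; // same math as device::mul_add
        return out; // alias of the last input, mirroring output_alias()
    }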
src/targets/gpu/include/migraphx/gpu/device/add_relu.hpp

@@ -11,6 +11,12 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+void mul_add_relu(hipStream_t stream,
+                  const argument& result,
+                  const argument& arg1,
+                  const argument& arg2,
+                  const argument& arg3);
+
 void add_relu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
...
src/targets/gpu/include/migraphx/gpu/device/mul_add.hpp (new file, 0 → 100644)

#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void mul_add(hipStream_t stream,
             const argument& result,
             const argument& arg1,
             const argument& arg2,
             const argument& arg3);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/quant_gemm.cpp

@@ -8,51 +8,6 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

-template <class... Ts>
-rocblas_status generic_rocblas_gemm_ex(Ts&&... xs)
-{
-    return rocblas_gemm_ex(std::forward<Ts>(xs)...);
-}
-
-template <class... Ts>
-rocblas_status generic_rocblas_batched_gemm_ex(Ts&&... xs)
-{
-    return rocblas_gemm_strided_batched_ex(std::forward<Ts>(xs)...);
-}
-
-template <class T>
-struct compute_rocblas_type
-{
-    using type = T;
-};
-
-template <class T>
-struct compute_rocblas_type<const T>
-{
-    using type = const typename compute_rocblas_type<T>::type;
-};
-
-template <>
-struct compute_rocblas_type<half>
-{
-    using type = rocblas_half;
-};
-
-template <class T>
-using rb_type = typename compute_rocblas_type<T>::type;
-
-template <class T>
-rb_type<T> to_rocblas_type(T x)
-{
-    return reinterpret_cast<const rb_type<T>&>(x);
-}
-
-template <class T>
-rb_type<T>* to_rocblas_type(T* x)
-{
-    return reinterpret_cast<rb_type<T>*>(x);
-}
-
 shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
     std::vector<shape> in_shapes(inputs);
...
@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
     auto a_lens = args[0].get_shape().lens();
     auto b_lens = args[1].get_shape().lens();
     output_shape.visit_type([&](auto as) {
-        auto alpha_r    = to_rocblas_type(as(op.alpha));
-        auto beta_r     = to_rocblas_type(as(beta));
+        auto alpha_r    = as(op.alpha);
+        auto beta_r     = as(beta);
         auto out_lens   = output_shape.lens();
         rocblas_int m   = out_lens[dim_0];
         rocblas_int n   = out_lens[dim_1];
         rocblas_int k   = args[0].get_shape().lens()[dim_1];
-        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+        auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
         assert(k % 4 == 0);
         auto num_matrices = std::accumulate(
...
@@ -119,7 +74,7 @@ argument rocblas_quant_gemm::compute(context& ctx,
         // column-major format. When doing a C = A * B, we actually do
         // C^T = (B^T) * (A^T). That is the reason we input args[1] as
         // A and args[0] as B in calling the rocblas_gemm.
-        generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
+        rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
                         transb ? rocblas_operation_transpose : rocblas_operation_none,
                         transa ? rocblas_operation_transpose : rocblas_operation_none,
                         n,
...
@@ -148,7 +103,7 @@ argument rocblas_quant_gemm::compute(context& ctx,
         }
         else
         {
-            generic_rocblas_batched_gemm_ex(
+            rocblas_gemm_strided_batched_ex(
                 ctx.get_stream().get_rocblas(),
                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                 transa ? rocblas_operation_transpose : rocblas_operation_none,
...
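The comment in the hunk above rests on the standard identity for matrix products:

    C = A B \iff C^{\mathsf{T}} = B^{\mathsf{T}} A^{\mathsf{T}}

A row-major buffer reinterpreted in column-major order is exactly the transpose of the matrix it stores. So handing rocBLAS (which is column-major) args[1] as A and args[0] as B makes it compute C^T in column-major layout, which is C in the row-major layout MIGraphX uses, with no data movement.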
src/targets/gpu/target.cpp

@@ -14,7 +14,7 @@
 #include <migraphx/propagate_constant.hpp>
 #include <migraphx/eliminate_contiguous.hpp>
 #include <migraphx/common_subexpression_elimination.hpp>
-#include <migraphx/fwd_conv_batchnorm_rewrite.hpp>
+#include <migraphx/rewrite_batchnorm.hpp>
 #include <migraphx/rewrite_rnn.hpp>
 #include <migraphx/rewrite_pooling.hpp>
 #include <migraphx/eliminate_concat.hpp>
...
@@ -44,13 +44,13 @@ std::vector<pass> target::get_passes(migraphx::context& gctx) const
         eliminate_identity{},
         eliminate_pad{},
         dead_code_elimination{},
-        fwd_conv_batchnorm_rewrite{},
+        rewrite_batchnorm{},
         dead_code_elimination{},
         rewrite_rnn{},
         rewrite_pooling{},
         dead_code_elimination{},
-        //common_subexpression_elimination{},
-        //dead_code_elimination{},
+        // common_subexpression_elimination{},
+        // dead_code_elimination{},
         simplify_algebra{},
         dead_code_elimination{},
         auto_contiguous{},
...
src/tf/tf.cpp
View file @
2fdf510d
...
@@ -26,7 +26,6 @@ struct tf_parser
...
@@ -26,7 +26,6 @@ struct tf_parser
{
{
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
tensorflow
::
AttrValue
>
;
using
attribute_map
=
std
::
unordered_map
<
std
::
string
,
tensorflow
::
AttrValue
>
;
using
node_map
=
std
::
map
<
std
::
string
,
tensorflow
::
NodeDef
>
;
using
node_map
=
std
::
map
<
std
::
string
,
tensorflow
::
NodeDef
>
;
// using input_node_map = std::unordered_map<std::string, std::unordered_set<std::string>>;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
using
op_func
=
std
::
function
<
instruction_ref
(
attribute_map
,
std
::
vector
<
instruction_ref
>
)
>
;
node_map
nodes
;
node_map
nodes
;
...
@@ -149,9 +148,26 @@ struct tf_parser
...
@@ -149,9 +148,26 @@ struct tf_parser
return
axes
;
return
axes
;
}
}
std
::
vector
<
int64_t
>
get_axes_from_mask
(
const
size_t
num_axes
,
const
uint32_t
mask
)
{
uint32_t
bitwise_compare
=
1
;
std
::
vector
<
int64_t
>
axes
;
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
// the LSB corresponds to axis 0 when determining which axes to begin
if
(((
mask
>>
i
)
&
bitwise_compare
)
==
1
)
axes
.
push_back
(
1
);
else
axes
.
push_back
(
0
);
}
return
axes
;
}
tf_parser
()
tf_parser
()
{
{
add_generic_op
(
"All"
,
op
::
identity
{});
add_generic_op
(
"Identity"
,
op
::
identity
{});
add_generic_op
(
"Identity"
,
op
::
identity
{});
add_generic_op
(
"LessEqual"
,
op
::
identity
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Relu"
,
op
::
relu
{});
add_generic_op
(
"Relu6"
,
op
::
clip
{
6.0
,
0.0
});
add_generic_op
(
"Relu6"
,
op
::
clip
{
6.0
,
0.0
});
add_generic_op
(
"Rsqrt"
,
op
::
rsqrt
{});
add_generic_op
(
"Rsqrt"
,
op
::
rsqrt
{});
...
@@ -166,6 +182,7 @@ struct tf_parser
...
@@ -166,6 +182,7 @@ struct tf_parser
add_mem_op
(
"AvgPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"AvgPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"BatchMatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"BatchMatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"BatchMatMulV2"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"BiasAdd"
,
&
tf_parser
::
parse_biasadd
);
add_mem_op
(
"BiasAdd"
,
&
tf_parser
::
parse_biasadd
);
add_mem_op
(
"Cast"
,
&
tf_parser
::
parse_cast
,
false
);
add_mem_op
(
"Cast"
,
&
tf_parser
::
parse_cast
,
false
);
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
,
false
);
add_mem_op
(
"ConcatV2"
,
&
tf_parser
::
parse_concat
,
false
);
...
@@ -177,14 +194,15 @@ struct tf_parser
...
@@ -177,14 +194,15 @@ struct tf_parser
add_mem_op
(
"GatherV2"
,
&
tf_parser
::
parse_gather
,
false
);
add_mem_op
(
"GatherV2"
,
&
tf_parser
::
parse_gather
,
false
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"MatMul"
,
&
tf_parser
::
parse_matmul
,
false
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"MaxPool"
,
&
tf_parser
::
parse_pooling
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
);
add_mem_op
(
"Mean"
,
&
tf_parser
::
parse_mean
,
false
);
add_mem_op
(
"OneHot"
,
&
tf_parser
::
parse_onehot
,
false
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
,
false
);
add_mem_op
(
"Pack"
,
&
tf_parser
::
parse_pack
,
false
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
);
add_mem_op
(
"Pad"
,
&
tf_parser
::
parse_pad
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
,
false
);
add_mem_op
(
"Reshape"
,
&
tf_parser
::
parse_reshape
,
false
);
add_mem_op
(
"Slice"
,
&
tf_parser
::
parse_slice
,
false
);
add_mem_op
(
"Slice"
,
&
tf_parser
::
parse_slice
,
false
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
<
op
::
softmax
>
);
add_mem_op
(
"Softmax"
,
&
tf_parser
::
parse_softmax
<
op
::
softmax
>
,
false
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
,
false
);
add_mem_op
(
"Squeeze"
,
&
tf_parser
::
parse_squeeze
,
false
);
add_mem_op
(
"StridedSlice"
,
&
tf_parser
::
parse_stridedslice
);
add_mem_op
(
"StridedSlice"
,
&
tf_parser
::
parse_stridedslice
,
false
);
add_mem_op
(
"Transpose"
,
&
tf_parser
::
parse_transpose
,
false
);
add_mem_op
(
"Transpose"
,
&
tf_parser
::
parse_transpose
,
false
);
}
}
...
@@ -547,7 +565,7 @@ struct tf_parser
...
@@ -547,7 +565,7 @@ struct tf_parser
}
}
if
(
contains
(
attributes
,
"transpose_b"
))
if
(
contains
(
attributes
,
"transpose_b"
))
{
{
transb
=
attributes
.
at
(
"transpose_
a
"
).
b
();
transb
=
attributes
.
at
(
"transpose_
b
"
).
b
();
}
}
if
(
contains
(
attributes
,
"adj_x"
))
if
(
contains
(
attributes
,
"adj_x"
))
...
@@ -574,8 +592,7 @@ struct tf_parser
...
@@ -574,8 +592,7 @@ struct tf_parser
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_mean
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
bool
keep_dims
=
attributes
.
at
(
"keep_dims"
).
b
();
auto
lens
=
args
[
0
]
->
get_shape
().
lens
();
auto
axes
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
<
int64_t
>
();
auto
axes
=
parse_axes
(
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
<
int64_t
>
(),
lens
.
size
());
if
(
keep_dims
)
if
(
keep_dims
)
{
{
...
@@ -588,6 +605,32 @@ struct tf_parser
...
@@ -588,6 +605,32 @@ struct tf_parser
}
}
}
}
instruction_ref
parse_onehot
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
size_t
depth
=
static_cast
<
size_t
>
(
args
[
1
]
->
eval
().
at
<
int32_t
>
());
int64_t
axis
=
-
1
;
float
on_value
=
args
[
2
]
->
eval
().
at
<
float
>
();
float
off_value
=
args
[
3
]
->
eval
().
at
<
float
>
();
std
::
vector
<
float
>
depth_input
(
depth
*
depth
,
off_value
);
for
(
int
i
=
0
;
i
<
depth
;
i
++
)
{
depth_input
[
depth
*
i
+
i
]
=
on_value
;
}
if
(
contains
(
attributes
,
"axis"
))
axis
=
attributes
.
at
(
"axis"
).
i
();
if
(
axis
==
-
1
)
{
shape
s
{
shape
::
float_type
,
{
depth
,
depth
}};
auto
l0
=
prog
.
add_literal
({
s
,
depth_input
});
return
prog
.
add_instruction
(
op
::
gather
{
0
},
{
l0
,
args
[
0
]});
}
MIGRAPHX_THROW
(
"MIGraphX does not support axis != -1"
);
}
instruction_ref
parse_pack
(
const
std
::
string
&
,
instruction_ref
parse_pack
(
const
std
::
string
&
,
const
attribute_map
&
attributes
,
const
attribute_map
&
attributes
,
std
::
vector
<
instruction_ref
>
args
)
std
::
vector
<
instruction_ref
>
args
)
...
@@ -801,19 +844,48 @@ struct tf_parser
...
@@ -801,19 +844,48 @@ struct tf_parser
op
::
slice
op
;
op
::
slice
op
;
auto
starts
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
starts
=
args
[
1
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
ends
=
args
[
2
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
auto
ends
=
args
[
2
]
->
eval
().
get
<
int32_t
>
().
to_vector
();
size_t
num_axes
=
args
[
0
]
->
get_shape
().
lens
().
size
();
auto
l0
=
args
[
0
];
size_t
num_axes
=
l0
->
get_shape
().
lens
().
size
();
std
::
vector
<
size_t
>
axes
=
l0
->
get_shape
().
lens
();
op
.
starts
=
std
::
vector
<
int64_t
>
(
starts
.
begin
(),
starts
.
end
());
op
.
starts
=
std
::
vector
<
int64_t
>
(
starts
.
begin
(),
starts
.
end
());
op
.
ends
=
std
::
vector
<
int64_t
>
(
ends
.
begin
(),
ends
.
end
());
op
.
ends
=
std
::
vector
<
int64_t
>
(
ends
.
begin
(),
ends
.
end
());
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
op
.
axes
=
std
::
vector
<
int64_t
>
(
num_axes
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
std
::
iota
(
op
.
axes
.
begin
(),
op
.
axes
.
end
(),
0
);
uint32_t
begin_mask
=
0
;
uint32_t
end_mask
=
0
;
uint32_t
shrink_axis_mask
=
0
;
uint32_t
shrink_axis_mask
=
0
;
uint32_t
bitwise_compare
=
1
;
uint32_t
bitwise_compare
=
1
;
std
::
vector
<
int64_t
>
squeeze_axes
;
std
::
vector
<
int64_t
>
squeeze_axes
;
if
(
contains
(
attributes
,
"begin_mask"
))
begin_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"begin_mask"
).
i
());
if
(
contains
(
attributes
,
"end_mask"
))
end_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"end_mask"
).
i
());
if
(
contains
(
attributes
,
"shrink_axis_mask"
))
if
(
contains
(
attributes
,
"shrink_axis_mask"
))
shrink_axis_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"shrink_axis_mask"
).
i
());
shrink_axis_mask
=
static_cast
<
uint32_t
>
(
attributes
.
at
(
"shrink_axis_mask"
).
i
());
std
::
vector
<
int64_t
>
begin_axes
=
get_axes_from_mask
(
num_axes
,
begin_mask
);
std
::
vector
<
int64_t
>
end_axes
=
get_axes_from_mask
(
num_axes
,
end_mask
);
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
if
(
begin_axes
.
at
(
i
)
==
1
)
{
op
.
starts
.
at
(
i
)
=
0
;
}
if
(
end_axes
.
at
(
i
)
==
1
)
{
op
.
ends
.
at
(
i
)
=
axes
.
at
(
i
);
}
}
auto
l1
=
prog
.
add_instruction
(
op
,
l0
);
if
(
shrink_axis_mask
==
0
)
return
l1
;
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
for
(
size_t
i
=
0
;
i
<
num_axes
;
i
++
)
{
{
// the LSB corresponds to axis 0 when determining which axes to squeeze
// the LSB corresponds to axis 0 when determining which axes to squeeze
...
@@ -821,8 +893,7 @@ struct tf_parser
...
@@ -821,8 +893,7 @@ struct tf_parser
squeeze_axes
.
push_back
(
i
);
squeeze_axes
.
push_back
(
i
);
}
}
auto
l0
=
prog
.
add_instruction
(
op
,
make_contiguous
(
args
[
0
]));
return
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l1
);
return
to_nhwc
(
prog
.
add_instruction
(
op
::
squeeze
{
squeeze_axes
},
l0
));
}
}
instruction_ref
instruction_ref
...
@@ -862,10 +933,16 @@ struct tf_parser
        if(instructions.count(name) == 0)
        {
            auto&& node = nodes.at(name);
+           // assert ops ignored
+           if(node.op() == "Assert" or contains(name, "Assert"))
+               return;
            std::vector<instruction_ref> args;

            for(auto&& input : node.input())
            {
+               // control dependencies (signified by ^ before the name) are ignored
+               if(contains(input, "^"))
+                   continue;
                if(nodes.count(input) > 0)
                {
                    auto&& iname = get_name(nodes.at(input));
...
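For context: TensorFlow encodes a control dependency as an input name prefixed with ^, which orders execution but carries no data, so the parser must not turn it into an argument. A small self-contained illustration (node names hypothetical):

    #include <iostream>
    #include <string>
    #include <vector>

    int main()
    {
        // Hypothetical NodeDef input list as serialized by TensorFlow
        std::vector<std::string> inputs = {"images", "conv1/weights", "^Assert/AssertGuard"};
        for(const auto& input : inputs)
        {
            if(input.find('^') != std::string::npos)
                continue; // control dependency: ordering only, skipped
            std::cout << input << "\n"; // becomes a real instruction argument
        }
    }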
test/gpu/miopen.cpp
View file @ 2fdf510d
...
@@ -502,6 +502,24 @@ struct test_triadd2 : verify_program<test_triadd2>
    }
};

+struct test_mul_add : verify_program<test_mul_add>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        migraphx::shape s{migraphx::shape::float_type, {2, 3}};
+        migraphx::shape bs{migraphx::shape::float_type, {3}};
+        auto x   = p.add_parameter("x", s);
+        auto a   = p.add_parameter("a", bs);
+        auto b   = p.add_parameter("b", bs);
+        auto ab  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, a);
+        auto bb  = p.add_instruction(migraphx::op::broadcast{1, s.lens()}, b);
+        auto mul = p.add_instruction(migraphx::op::mul{}, x, ab);
+        p.add_instruction(migraphx::op::add{}, mul, bb);
+        return p;
+    }
+};
+
struct test_add_broadcast : verify_program<test_add_broadcast>
{
    migraphx::program create_program() const
...
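The new test_mul_add case pushes a fused multiply-add pattern through the verify harness, with a and b broadcast along axis 1. As a plain C++ reference for what the program computes (a sketch, not MIGraphX API):

    #include <array>

    // Reference semantics of test_mul_add on a 2x3 input:
    // out[i][j] = x[i][j] * a[j] + b[j]
    std::array<std::array<float, 3>, 2> mul_add_ref(const std::array<std::array<float, 3>, 2>& x,
                                                    const std::array<float, 3>& a,
                                                    const std::array<float, 3>& b)
    {
        std::array<std::array<float, 3>, 2> out{};
        for(int i = 0; i < 2; i++)
            for(int j = 0; j < 3; j++)
                out[i][j] = x[i][j] * a[j] + b[j];
        return out;
    }

This elementwise mul-then-add shape is exactly the pattern a fusion pass can collapse into a single kernel, which is what the verify test guards against regressing.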
test/matcher.cpp
View file @ 2fdf510d
...
@@ -5,6 +5,8 @@
namespace match = migraphx::match;

+MIGRAPHX_PRED_MATCHER(throws, migraphx::instruction_ref) { MIGRAPHX_THROW("Matcher throws"); }
+
template <class M>
migraphx::match::matcher_result find_match(migraphx::program& p, M&& m)
{
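The throws matcher above is a predicate matcher whose body never returns normally; the lazy-evaluation tests below use it to prove that a sub-matcher placed after a deciding one is never invoked. For comparison, a typical predicate matcher returns a boolean computed from the instruction (a hypothetical example, assuming ins->inputs() lists the argument instructions):

    // Hypothetical predicate matcher: matches instructions that have
    // exactly two arguments.
    MIGRAPHX_PRED_MATCHER(has_two_args, migraphx::instruction_ref ins)
    {
        return ins->inputs().size() == 2;
    }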
...
@@ -331,6 +333,81 @@ TEST_CASE(match_either_args3)
    EXPECT(bool{r.result == p.end()});
}

+TEST_CASE(match_either_args_any1)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::any().bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any2)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::any().bind("x"), match::name("@literal").bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any3)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::name("@literal").bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum1});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any4)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::name("sum").bind("x"), match::any().bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum2});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
+TEST_CASE(match_either_args_any5)
+{
+    migraphx::program p;
+    auto one  = p.add_literal(1);
+    auto two  = p.add_literal(2);
+    auto sum1 = p.add_instruction(sum_op{}, one, two);
+    auto sum2 = p.add_instruction(sum_op{}, sum1, two);
+    p.add_instruction(pass_op{}, sum2);
+    auto m = match::name("sum")(
+        match::either_arg(0, 1)(match::any().bind("x"), match::name("sum").bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum2});
+    EXPECT(bool{r.instructions.at("x") != r.instructions.at("y")});
+}
+
TEST_CASE(match_all_of1)
{
    migraphx::program p;
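either_arg(0, 1) tries its two sub-matchers against the instruction's arguments in both orders, which is why every case asserts x != y: a successful match must bind the two names to distinct arguments even when both sub-matchers are any(). On a commutative op this makes a pattern order-insensitive; a sketch ("mul" being the registered name of migraphx::op::mul):

    // Sketch: a literal-times-anything pattern that matches no matter
    // which operand slot the literal occupies.
    auto literal_mul_matcher()
    {
        return match::name("mul")(match::either_arg(0, 1)(
            match::name("@literal").bind("c"), match::any().bind("x")));
    }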
...
@@ -370,6 +447,36 @@ TEST_CASE(match_all_of3)
    EXPECT(bool{r.result == sum});
}

+TEST_CASE(match_lazy_any_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::any_of(match::any(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == one});
+}
+
+TEST_CASE(match_lazy_all_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::all_of(match::none(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == p.end()});
+}
+
+TEST_CASE(match_lazy_none_of)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    p.add_instruction(pass_op{}, one);
+    auto m = match::none_of(match::any(), throws());
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == p.end()});
+}
+
TEST_CASE(match_any_of1)
{
    migraphx::program p;
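These three cases pin down short-circuit behavior: any_of must stop at the first alternative that matches, and all_of/none_of must stop at the first sub-matcher that settles the result, so the trailing throws() is never evaluated; an eager implementation would propagate the exception out of find_match instead of merely failing to match. It is the same property the built-in logical operators provide:

    #include <stdexcept>

    bool succeed() { return true; }
    bool boom() { throw std::runtime_error("never evaluated"); }

    int main()
    {
        // || short-circuits: boom() is never called, so nothing throws.
        return (succeed() || boom()) ? 0 : 1;
    }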
...
@@ -396,6 +503,97 @@ TEST_CASE(match_any_of2)
    EXPECT(bool{r.result == p.end()});
}

+TEST_CASE(match_any_of_lazy1)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::any(), match::any()).bind("x"),
+                      match::args(match::name("sum"), match::name("sum")).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy2)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::name("@literal"), match::name("@literal")).bind("x"),
+                      match::args(match::any(), match::any()).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy3)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(
+        match::any_of(match::args(match::any(), match::any()).bind("x"),
+                      match::args(match::name("@literal"), match::name("@literal")).bind("y")));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x"));
+    EXPECT(bool{r.instructions["x"] == sum});
+    EXPECT(not migraphx::contains(r.instructions, "y"));
+}
+
+TEST_CASE(match_any_of_lazy4)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::any_of(
+        match::args(match::name("@literal").bind("x1"), match::name("@literal").bind("y1")),
+        match::args(match::any().bind("x2"), match::any().bind("y2"))));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x1"));
+    EXPECT(migraphx::contains(r.instructions, "y1"));
+    EXPECT(bool{r.instructions["x1"] == one});
+    EXPECT(bool{r.instructions["y1"] == two});
+    EXPECT(not migraphx::contains(r.instructions, "x2"));
+    EXPECT(not migraphx::contains(r.instructions, "y2"));
+}
+
+TEST_CASE(match_any_of_lazy5)
+{
+    migraphx::program p;
+    auto one = p.add_literal(1);
+    auto two = p.add_literal(2);
+    auto sum = p.add_instruction(sum_op{}, one, two);
+    p.add_instruction(pass_op{}, sum);
+    auto m = match::name("sum")(match::any_of(
+        match::args(match::any().bind("x1"), match::any().bind("y1")),
+        match::args(match::name("@literal").bind("x2"), match::name("@literal").bind("y2"))));
+    auto r = find_match(p, m);
+    EXPECT(bool{r.result == sum});
+    EXPECT(migraphx::contains(r.instructions, "x1"));
+    EXPECT(migraphx::contains(r.instructions, "y1"));
+    EXPECT(bool{r.instructions["x1"] == one});
+    EXPECT(bool{r.instructions["y1"] == two});
+    EXPECT(not migraphx::contains(r.instructions, "x2"));
+    EXPECT(not migraphx::contains(r.instructions, "y2"));
+}
+
TEST_CASE(match_none_of1)
{
    migraphx::program p;
...
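The match_any_of_lazy cases additionally pin down how laziness interacts with bind: only names bound inside the alternative that actually ran appear in r.instructions, so code consuming a matcher_result should look optional bindings up defensively rather than index unconditionally. A sketch, assuming the matcher_result fields used above:

    // Bindings from a skipped any_of alternative are absent from the map.
    void use_optional_binding(const migraphx::match::matcher_result& r)
    {
        if(migraphx::contains(r.instructions, "y"))
        {
            auto y = r.instructions.at("y"); // present only if its matcher ran
            (void)y;
        }
    }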
test/onnx/onnx_test.cpp
View file @ 2fdf510d
...
@@ -4,6 +4,7 @@
#include <migraphx/operators.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
+#include <migraphx/instruction_ref.hpp>
#include <migraphx/onnx.hpp>
#include "test.hpp"
...
@@ -514,6 +515,32 @@ TEST_CASE(shape_gather_test)
    EXPECT(p == prog);
}

+TEST_CASE(transpose_gather_test)
+{
+    migraphx::program p;
+    auto make_contiguous = [&p](migraphx::instruction_ref ins) {
+        if(ins->get_shape().standard())
+        {
+            return ins;
+        }
+        return p.add_instruction(migraphx::op::contiguous{}, ins);
+    };
+
+    auto data = p.add_parameter(
+        "data", migraphx::shape{migraphx::shape::float_type, {3, 5, 4, 6}});
+    auto ind = p.add_parameter(
+        "indices", migraphx::shape{migraphx::shape::int32_type, {2, 4, 3, 5}});
+    auto tr_data = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, data);
+    auto tr_ind  = p.add_instruction(migraphx::op::transpose{{0, 2, 1, 3}}, ind);
+    int axis     = 1;
+    p.add_instruction(
+        migraphx::op::gather{axis}, make_contiguous(tr_data), make_contiguous(tr_ind));
+
+    auto prog = migraphx::parse_onnx("transpose_gather.onnx");
+    EXPECT(p == prog);
+}
+
TEST_CASE(flatten_test)
{
    migraphx::program p;
...
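Why the make_contiguous wrapper matters here: op::transpose permutes lens and strides without moving data, so the resulting view is no longer in standard row-major layout and shape::standard() returns false, which is what inserts the op::contiguous copy ahead of gather. A sketch of the stride arithmetic, assuming the shape constructor that takes explicit strides:

    #include <cassert>
    #include <migraphx/shape.hpp>

    void show_transposed_layout()
    {
        // {3,5,4,6} in standard layout has strides {120, 24, 6, 1}.
        migraphx::shape original{migraphx::shape::float_type, {3, 5, 4, 6}};
        // The permutation {0,2,1,3} permutes lens and strides together:
        migraphx::shape transposed{
            migraphx::shape::float_type, {3, 4, 5, 6}, {120, 6, 24, 1}};
        // Standard strides for lens {3,4,5,6} would be {120, 30, 6, 1},
        // so the transposed view is not standard and must be copied.
        assert(original.standard());
        assert(not transposed.standard());
    }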