Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b2f12dae
Commit
b2f12dae
authored
Sep 21, 2023
by
turneram
Browse files
Merge remote-tracking branch 'origin/develop' into ck-gemm-int8
parents
05122cc7
ce2adafd
Changes
143
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
263 additions
and
22 deletions
+263
-22
test/ref/sigmoid.cpp
test/ref/sigmoid.cpp
+1
-1
test/ref/sign.cpp
test/ref/sign.cpp
+1
-1
test/ref/sin.cpp
test/ref/sin.cpp
+1
-1
test/ref/sinh.cpp
test/ref/sinh.cpp
+1
-1
test/ref/slice.cpp
test/ref/slice.cpp
+1
-1
test/ref/softmax.cpp
test/ref/softmax.cpp
+1
-1
test/ref/sqdiff.cpp
test/ref/sqdiff.cpp
+1
-1
test/ref/sqrt.cpp
test/ref/sqrt.cpp
+1
-1
test/ref/squeeze.cpp
test/ref/squeeze.cpp
+1
-1
test/ref/step.cpp
test/ref/step.cpp
+1
-1
test/ref/sub.cpp
test/ref/sub.cpp
+1
-1
test/ref/tan.cpp
test/ref/tan.cpp
+1
-1
test/ref/tanh.cpp
test/ref/tanh.cpp
+1
-1
test/ref/topk.cpp
test/ref/topk.cpp
+1
-1
test/ref/transpose.cpp
test/ref/transpose.cpp
+1
-1
test/ref/unsqueeze.cpp
test/ref/unsqueeze.cpp
+1
-1
test/ref/where.cpp
test/ref/where.cpp
+1
-1
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+195
-3
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+50
-1
test/targets.cpp
test/targets.cpp
+1
-1
No files found.
test/ref/sigmoid.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/sign.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/sin.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/sinh.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/slice.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/softmax.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/sqdiff.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/sqrt.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/squeeze.cpp
View file @
b2f12dae
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/step.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/sub.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/tan.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/tanh.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/topk.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/transpose.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/unsqueeze.cpp
View file @
b2f12dae
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/ref/where.cpp
View file @
b2f12dae
...
@@ -24,7 +24,7 @@
...
@@ -24,7 +24,7 @@
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/
onnx
.hpp>
#include <migraphx/
program
.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/verify.hpp>
...
...
test/simplify_algebra_test.cpp
View file @
b2f12dae
...
@@ -2910,6 +2910,179 @@ TEST_CASE(reorder_reshape_slice_not_apply)
...
@@ -2910,6 +2910,179 @@ TEST_CASE(reorder_reshape_slice_not_apply)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
}
TEST_CASE
(
reorder_reshape_slice_multi_rsp
)
{
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
4
,
128
,
3
,
32
,
80
}};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
3
,
1
,
4
}}}),
input
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
t1
);
auto
slc1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
t1
);
auto
slc2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
3
}}}),
t1
);
auto
c1_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
,
32
,
128
,
80
}}}),
c1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
4
,
32
,
128
,
80
}}}),
c2
);
auto
r1_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
128
,
128
,
80
}}}),
c1_1
);
auto
r2_1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
128
,
128
,
80
}}}),
c2_1
);
auto
c0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc0
);
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
128
,
128
,
80
}}}),
c0
);
auto
t2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
r1_1
);
auto
c_t2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t2
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
r0
,
c_t2
);
m1
.
add_return
({
r1
,
r2
,
dot
,
r2_1
});
};
migraphx
::
module
m2
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
4
,
128
,
3
,
32
,
80
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
3
,
1
,
4
}}}),
input
);
auto
c_t1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t1
);
auto
rsp1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
384
,
128
,
80
}}}),
c_t1
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
256
}},
{
"ends"
,
{
384
}}}),
rsp1
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
128
}},
{
"ends"
,
{
256
}}}),
rsp1
);
auto
t_slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
}}}),
slc1
);
auto
c_t_slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t_slc1
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
128
}}}),
rsp1
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
slc2
,
c_t_slc1
);
auto
c_t1_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
t1
);
auto
rsp2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
12
,
32
,
128
,
80
}}}),
c_t1_1
);
auto
slc2_1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
8
}}}),
rsp2
);
auto
slc2_2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
8
}},
{
"ends"
,
{
12
}}}),
rsp2
);
m2
.
add_return
({
slc2_1
,
slc2_2
,
dot
,
slc0
});
};
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice_partial
)
{
migraphx
::
module
m1
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
128
,
96
}};
auto
input
=
m1
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
8
}}}),
input
);
auto
slc1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
8
}},
{
"ends"
,
{
16
}}}),
input
);
auto
slc2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
16
}},
{
"ends"
,
{
24
}}}),
input
);
auto
slc3
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
128
}}}),
input
);
auto
c0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc0
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
2
,
4
,
96
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
r0
,
r1
);
auto
ret
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
r2
);
m1
.
add_return
({
ret
,
slc3
});
};
migraphx
::
module
m2
;
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
128
,
96
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
rsp
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
32
,
4
,
96
}}}),
input
);
auto
slc3
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
128
}}}),
input
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
}}}),
rsp
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
2
}},
{
"ends"
,
{
4
}}}),
rsp
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
4
}},
{
"ends"
,
{
6
}}}),
rsp
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
slc0
,
slc1
);
auto
ret
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
slc2
);
m2
.
add_return
({
ret
,
slc3
});
};
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
reorder_reshape_slice_uneven_slice
)
{
auto
create_p
=
[]
{
migraphx
::
module
m
;
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
128
,
96
}};
auto
input
=
m
.
add_parameter
(
"input"
,
s
);
auto
slc0
=
m
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
31
}}}),
input
);
auto
slc1
=
m
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
31
}},
{
"ends"
,
{
62
}}}),
input
);
auto
slc2
=
m
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
62
}},
{
"ends"
,
{
93
}}}),
input
);
auto
slc3
=
m
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
93
}},
{
"ends"
,
{
128
}}}),
input
);
auto
c0
=
m
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc0
);
auto
c1
=
m
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
auto
c2
=
m
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc2
);
std
::
vector
<
int64_t
>
lens
=
{
1
,
31
,
96
};
auto
r0
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r2
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
auto
sum
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
r0
,
r1
);
auto
ret
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sum
,
r2
);
m
.
add_return
({
ret
,
slc3
});
return
m
;
};
auto
m1
=
create_p
();
auto
m2
=
m1
;
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
template
<
std
::
size_t
BS
>
template
<
std
::
size_t
BS
>
void
reorder_reshape_slice_diff_dims
()
void
reorder_reshape_slice_diff_dims
()
{
{
...
@@ -2931,13 +3104,32 @@ void reorder_reshape_slice_diff_dims()
...
@@ -2931,13 +3104,32 @@ void reorder_reshape_slice_diff_dims()
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
32
,
3
,
32
};
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
32
,
3
,
32
};
std
::
vector
<
int64_t
>
lens1
=
{
static_cast
<
int64_t
>
(
BS
),
48
,
2
,
32
};
std
::
vector
<
int64_t
>
lens1
=
{
static_cast
<
int64_t
>
(
BS
),
48
,
2
,
32
};
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r0
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c0
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c1
);
auto
r1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
1
}}),
c1
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
1
}}),
c2
);
auto
r2
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
c2
);
m1
.
add_return
({
r0
,
r1
,
r2
});
m1
.
add_return
({
r0
,
r1
,
r2
});
};
};
auto
m2
=
m1
;
migraphx
::
module
m2
;
{
auto
s
=
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
BS
,
96
,
96
}};
auto
input
=
m2
.
add_parameter
(
"input"
,
s
);
auto
slc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
2
}},
{
"starts"
,
{
32
}},
{
"ends"
,
{
64
}}}),
input
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
slc1
);
std
::
vector
<
int64_t
>
lens1
=
{
static_cast
<
int64_t
>
(
BS
),
48
,
2
,
32
};
auto
r1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens1
}}),
c1
);
std
::
vector
<
int64_t
>
lens
=
{
static_cast
<
int64_t
>
(
BS
),
32
,
3
,
96
};
auto
r_new
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
auto
slc0
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
32
}}}),
r_new
);
auto
slc2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
3
}},
{
"starts"
,
{
64
}},
{
"ends"
,
{
96
}}}),
r_new
);
m2
.
add_return
({
slc0
,
r1
,
slc2
});
};
run_pass
(
m1
);
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
}
...
...
test/simplify_reshapes_test.cpp
View file @
b2f12dae
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -67,6 +67,55 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
...
@@ -67,6 +67,55 @@ migraphx::module make_concat_multibroadcast(const std::vector<size_t>& in_lens,
return
m
;
return
m
;
}
}
TEST_CASE
(
broadcast_transpose_scalar
)
{
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
mb
);
m1
.
add_return
({
t1
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
}}}),
l
);
m2
.
add_return
({
mb
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
broadcast_transpose_scalar_multi_use
)
{
// multibroadcast used more than once
migraphx
::
module
m1
;
{
auto
l
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
t1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
mb
);
auto
id
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"identity"
),
mb
);
m1
.
add_return
({
t1
,
id
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
l
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
1
},
{
0
}});
auto
mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
2
}}}),
l
);
auto
mb2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
}}}),
l
);
auto
id
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"identity"
),
mb2
);
m2
.
add_return
({
mb
,
id
});
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
double_contig
)
TEST_CASE
(
double_contig
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
test/targets.cpp
View file @
b2f12dae
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
...
Prev
1
…
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment