Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
19072784
Commit
19072784
authored
Sep 24, 2023
by
Khalique Ahmed
Browse files
add change and test
parent
1af66a1c
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
0 deletions
+40
-0
src/common.cpp
src/common.cpp
+1
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+13
-0
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+26
-0
No files found.
src/common.cpp
View file @
19072784
...
...
@@ -182,6 +182,7 @@ insert_common_args(module& m, instruction_ref ins, std::vector<instruction_ref>
else
{
auto
common
=
common_shape
(
to_shapes
(
inputs
));
std
::
cout
<<
common
<<
std
::
endl
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[
&
](
auto
input
)
{
if
(
input
->
get_shape
().
lens
()
!=
common
.
lens
())
{
...
...
src/simplify_algebra.cpp
View file @
19072784
...
...
@@ -521,6 +521,19 @@ struct find_inner_broadcast
})
<
(
lens
.
size
()
-
1
);
}))
return
;
auto
bcast_strides
=
broadcasts
.
front
()
->
get_shape
().
strides
().
size
();
std
::
vector
<
size_t
>
common_axis
(
bcast_strides
,
0
);
for
(
auto
i
=
0
;
i
<
broadcasts
.
front
()
->
get_shape
().
strides
().
size
();
i
++
)
{
for
(
auto
j
=
0
;
j
<
broadcasts
.
size
();
j
++
)
{
if
(
broadcasts
[
j
]
->
get_shape
().
strides
()[
i
]
==
0
)
common_axis
[
i
]
++
;
}
}
if
(
std
::
find_if
(
common_axis
.
begin
(),
common_axis
.
end
(),
[](
auto
num_common
){
return
num_common
>
1
;
})
==
common_axis
.
end
())
return
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
...
...
test/simplify_algebra_test.cpp
View file @
19072784
...
...
@@ -639,6 +639,32 @@ TEST_CASE(simplify_inner_broadcast_different_dims)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast_different_dims2
)
{
auto
b
=
migraphx
::
op
::
multibroadcast
{{
1
,
1024
,
3072
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
1024
,
3072
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
1
,
1024
,
1
}});
auto
xb
=
m1
.
add_instruction
(
b
,
x
);
auto
yb
=
m1
.
add_instruction
(
b
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xb
,
yb
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
run_pass
(
m1
);
m1
.
debug_print
();
// migraphx::module m2;
// {
// auto x = m2.add_parameter("x", {migraphx::shape::int32_type, {1024, 768}});
// auto y = m2.add_parameter("y", {migraphx::shape::int32_type, {768}});
// auto yb = m2.add_instruction(migraphx::op::multibroadcast{{384, 768}}, y);
// auto sum = m2.add_instruction(migraphx::make_op("add"), x, yb);
// auto sumb = m2.add_instruction(b, sum);
// m2.add_instruction(pass_op{}, sumb);
// }
// EXPECT(m1 == m2);
}
TEST_CASE
(
simplify_inner_broadcast_different_broadcasts
)
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
1
,
24
,
112
,
112
}};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment