Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
17c03174
Unverified
Commit
17c03174
authored
Jun 14, 2023
by
Umang Yadav
Committed by
GitHub
Jun 14, 2023
Browse files
Revert "Revert "Handle broadcasts across dot and concat (#1689) (#1731)"" (#1837)
parent
7b82fc39
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
213 additions
and
23 deletions
+213
-23
src/normalize_attributes.cpp
src/normalize_attributes.cpp
+11
-8
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+116
-15
test/simplify_algebra_test.cpp
test/simplify_algebra_test.cpp
+86
-0
No files found.
src/normalize_attributes.cpp
View file @
17c03174
...
...
@@ -40,10 +40,12 @@ inline namespace MIGRAPHX_INLINE_NS {
*
* See normalize_attribute.hpp for explaining the options.
*/
template
<
class
Message
>
auto
tune_attribute
(
const
std
::
vector
<
int64_t
>&
vec
,
const
std
::
vector
<
int64_t
>&
axes
,
const
value
&
val
,
const
std
::
vector
<
std
::
size_t
>&
lens
)
const
std
::
vector
<
std
::
size_t
>&
lens
,
Message
m
)
{
std
::
vector
<
int64_t
>
result
(
vec
);
int64_t
n_rank
=
lens
.
size
();
...
...
@@ -84,14 +86,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
{
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less_equal
<>
{}))
{
MIGRAPHX_THROW
(
"TUNE_VECTOR:
value out of range!"
);
MIGRAPHX_THROW
(
m
()
+
"
value out of range!"
);
}
}
else
{
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
max_vals
.
begin
(),
std
::
less
<>
{}))
{
MIGRAPHX_THROW
(
"TUNE_VECTOR:
value out of range!"
);
MIGRAPHX_THROW
(
m
()
+
"
value out of range!"
);
}
}
}
...
...
@@ -124,14 +126,14 @@ auto tune_attribute(const std::vector<int64_t>& vec,
if
(
not
std
::
equal
(
min_vals
.
begin
(),
min_vals
.
end
(),
result
.
begin
(),
std
::
less_equal
<>
{}))
{
MIGRAPHX_THROW
(
"TUNE_VECTOR:
attribute out of range!"
);
MIGRAPHX_THROW
(
m
()
+
"
attribute out of range!"
);
}
}
else
{
if
(
not
std
::
equal
(
result
.
begin
(),
result
.
end
(),
min_vals
.
begin
(),
std
::
less
<>
{}))
{
MIGRAPHX_THROW
(
"TUNE_VECTOR:
attribute out of range!"
);
MIGRAPHX_THROW
(
m
()
+
"
attribute out of range!"
);
}
}
}
...
...
@@ -193,7 +195,8 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
const
auto
&
key
=
rv
.
get_key
();
if
(
val
.
contains
(
key
))
{
auto
vv
=
val
.
at
(
key
).
without_key
();
auto
message
=
[
&
]
{
return
op
.
name
()
+
": "
+
key
+
": "
;
};
auto
vv
=
val
.
at
(
key
).
without_key
();
if
(
vv
.
is_array
())
{
std
::
vector
<
int64_t
>
axes
;
...
...
@@ -202,7 +205,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
axes
=
val
.
at
(
"axes"
).
without_key
().
to_vector
<
int64_t
>
();
}
auto
vec
=
vv
.
to_vector
<
int64_t
>
();
auto
result
=
tune_attribute
(
vec
,
axes
,
rv
.
without_key
(),
lens
);
auto
result
=
tune_attribute
(
vec
,
axes
,
rv
.
without_key
(),
lens
,
message
);
val
[
key
]
=
result
;
op
.
from_value
(
val
);
val
=
op
.
to_value
();
...
...
@@ -211,7 +214,7 @@ bool normalize_attributes(operation& op, const std::vector<std::size_t>& lens)
else
{
auto
num
=
vv
.
to
<
int64_t
>
();
auto
result
=
tune_attribute
({
num
},
{
num
},
rv
.
without_key
(),
lens
);
auto
result
=
tune_attribute
({
num
},
{
num
},
rv
.
without_key
(),
lens
,
message
);
val
[
key
]
=
result
.
front
();
op
.
from_value
(
val
);
val
=
op
.
to_value
();
...
...
src/simplify_algebra.cpp
View file @
17c03174
...
...
@@ -361,30 +361,118 @@ struct find_inner_broadcast
{
auto
matcher
()
const
{
return
pointwise
(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
static
auto
non_scalar_op
(
const
std
::
string
&
name
)
{
return
[
=
](
instruction_ref
ins
)
{
if
(
ins
->
get_shape
().
scalar
())
return
false
;
return
ins
->
name
()
==
name
;
};
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
broadcasts
=
ins
->
inputs
();
if
(
broadcasts
.
empty
())
return
;
bool
mixed_broadcasts
=
any_of
(
broadcasts
,
non_scalar_op
(
"broadcast"
))
and
any_of
(
broadcasts
,
non_scalar_op
(
"multibroadcast"
));
// If the broadcast is not a single dimension, then dont perform inner_broadcast
if
(
mixed_broadcasts
and
any_of
(
broadcasts
,
[
&
](
instruction_ref
i
)
{
if
(
i
->
get_shape
().
scalar
())
return
false
;
if
(
i
->
name
()
==
"multibroadcast"
)
return
false
;
auto
input
=
i
->
inputs
().
at
(
0
);
const
auto
&
lens
=
input
->
get_shape
().
lens
();
return
std
::
count_if
(
lens
.
begin
(),
lens
.
end
(),
[
&
](
std
::
size_t
d
)
{
return
d
==
1
;
})
<
(
lens
.
size
()
-
1
);
}))
return
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
std
::
back_inserter
(
inputs
),
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
()
!=
inputs
.
front
()
->
get_shape
()
and
i
->
get_shape
().
elements
()
!=
1
;
}))
return
;
auto
b_it
=
std
::
find_if
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
[
&
](
auto
i
)
{
return
not
i
->
get_shape
().
scalar
();
});
if
(
b_it
==
broadcasts
.
end
())
b_it
=
broadcasts
.
begin
();
[
&
](
instruction_ref
i
)
{
auto
input
=
i
->
inputs
().
front
();
if
(
mixed_broadcasts
and
not
i
->
get_shape
().
scalar
()
and
i
->
get_shape
().
lens
().
size
()
>
1
)
return
m
.
insert_instruction
(
i
,
make_op
(
"squeeze"
),
input
);
return
input
;
});
std
::
sort
(
broadcasts
.
begin
(),
broadcasts
.
end
(),
by
(
std
::
less
<>
{},
[](
instruction_ref
i
)
{
if
(
i
->
get_shape
().
scalar
())
return
2
;
else
if
(
i
->
name
()
==
"broadcast"
)
return
0
;
if
(
i
->
name
()
==
"multibroadcast"
)
return
1
;
return
3
;
}));
auto
op
=
insert_common_op
(
m
,
ins
,
ins
->
get_operator
(),
inputs
);
m
.
replace_instruction
(
ins
,
(
*
b_it
)
->
get_operator
(),
op
);
m
.
replace_instruction
(
ins
,
broadcasts
.
front
()
->
get_operator
(),
op
);
}
};
struct
find_dot_broadcast
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
broadcast
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a
=
ins
->
inputs
()[
0
];
auto
b
=
ins
->
inputs
()[
1
];
if
(
a
->
get_operator
().
name
()
!=
b
->
get_operator
().
name
())
return
;
if
(
ins
->
get_shape
().
lens
().
size
()
<
3
)
return
;
auto
nbatch_axes
=
ins
->
get_shape
().
lens
().
size
()
-
2
;
const
auto
&
a_strides
=
a
->
get_shape
().
strides
();
const
auto
&
b_strides
=
b
->
get_shape
().
strides
();
// Find leading batch axes that are broadcasted
auto
p
=
std
::
mismatch
(
a_strides
.
begin
(),
a_strides
.
begin
()
+
nbatch_axes
,
b_strides
.
begin
(),
b_strides
.
begin
()
+
nbatch_axes
,
[](
auto
astride
,
auto
bstride
)
{
return
astride
==
0
and
bstride
==
0
;
});
auto
naxes
=
p
.
first
-
a_strides
.
begin
();
assert
(
naxes
<=
nbatch_axes
);
std
::
vector
<
std
::
size_t
>
axes
(
naxes
);
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
0
);
auto
insert_broadcast
=
[
&
](
instruction_ref
b_ins
)
->
instruction_ref
{
auto
input
=
b_ins
->
inputs
()[
0
];
std
::
vector
<
std
::
size_t
>
lens
(
b_ins
->
get_shape
().
lens
().
begin
()
+
naxes
,
b_ins
->
get_shape
().
lens
().
end
());
if
(
b_ins
->
name
()
==
"multibroadcast"
)
{
return
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
input
);
}
else
if
(
b_ins
->
name
()
==
"broadcast"
)
{
auto
v
=
b_ins
->
get_operator
().
to_value
();
auto
axis
=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
()
-
naxes
;
return
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
lens
}}),
input
);
}
assert
(
false
);
return
m
.
end
();
};
auto
a1
=
insert_broadcast
(
a
);
auto
b1
=
insert_broadcast
(
b
);
auto
dot
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
a1
,
b1
);
auto
broadcast
=
m
.
insert_instruction
(
ins
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
ins
->
get_shape
().
lens
()}}),
dot
);
m
.
replace_instruction
(
ins
,
broadcast
);
}
};
...
...
@@ -393,7 +481,8 @@ struct find_concat_op
auto
matcher
()
const
{
return
match
::
name
(
"concat"
)(
match
::
any_of
[
match
::
inputs
()](
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
)),
match
::
used_once
()));
match
::
any_of
(
match
::
pointwise
(),
match
::
name
(
"broadcast"
,
"multibroadcast"
)),
match
::
used_once
()));
}
template
<
class
Iterator
>
...
...
@@ -412,7 +501,8 @@ struct find_concat_op
static
bool
is_valid_op
(
const
operation
&
op
)
{
return
op
.
name
()
==
"broadcast"
or
op
.
attributes
().
contains
(
"pointwise"
);
return
contains
({
"broadcast"
,
"multibroadcast"
},
op
.
name
())
or
op
.
attributes
().
contains
(
"pointwise"
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
...
...
@@ -440,6 +530,16 @@ struct find_concat_op
op
=
b
;
iaxis
=
0
;
}
else
if
(
op
.
name
()
==
"multibroadcast"
)
{
shape
bshape
=
(
*
start
)
->
get_shape
();
auto
input
=
(
*
start
)
->
inputs
()[
0
];
if
(
iaxis
>=
bshape
.
strides
().
size
()
or
bshape
.
strides
()[
iaxis
]
==
0
)
return
{
start
,
last
};
op
.
from_value
({{
"out_lens"
,
get_output_lens
(
start
,
last
,
iaxis
)}});
auto
delta
=
bshape
.
lens
().
size
()
-
input
->
get_shape
().
lens
().
size
();
iaxis
-=
delta
;
}
std
::
vector
<
instruction_ref
>
concats
;
for
(
std
::
size_t
i
=
0
;
i
<
x
->
inputs
().
size
();
i
++
)
...
...
@@ -1260,6 +1360,7 @@ void simplify_algebra::apply(module& m) const
{
match
::
find_matches
(
m
,
find_inner_broadcast
{},
find_dot_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
find_add_convs
{},
...
...
test/simplify_algebra_test.cpp
View file @
17c03174
...
...
@@ -613,6 +613,60 @@ TEST_CASE(simplify_inner_broadcast_scalar)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast_different_dims
)
{
auto
b
=
migraphx
::
op
::
multibroadcast
{{
2
,
384
,
768
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
384
,
768
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
768
}});
auto
xb
=
m1
.
add_instruction
(
b
,
x
);
auto
yb
=
m1
.
add_instruction
(
b
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xb
,
yb
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
384
,
768
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
768
}});
auto
yb
=
m2
.
add_instruction
(
migraphx
::
op
::
multibroadcast
{{
384
,
768
}},
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
yb
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_inner_broadcast_different_broadcasts
)
{
auto
b
=
migraphx
::
op
::
broadcast
{
1
,
{
1
,
24
,
112
,
112
}};
auto
mb
=
migraphx
::
op
::
multibroadcast
{{
1
,
24
,
112
,
112
}};
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
24
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
24
,
1
,
1
}});
auto
xb
=
m1
.
add_instruction
(
b
,
x
);
auto
yb
=
m1
.
add_instruction
(
mb
,
y
);
auto
sum
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xb
,
yb
);
m1
.
add_instruction
(
pass_op
{},
sum
);
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
int32_type
,
{
24
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
int32_type
,
{
24
,
1
,
1
}});
auto
xs
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
),
x
);
auto
ys
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
),
y
);
auto
sum
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
xs
,
ys
);
auto
sumb
=
m2
.
add_instruction
(
b
,
sum
);
m2
.
add_instruction
(
pass_op
{},
sumb
);
}
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
simplify_add_conv1
)
{
migraphx
::
module
m
;
...
...
@@ -3003,6 +3057,38 @@ TEST_CASE(reorder_slice_ins_deps)
EXPECT
(
m
==
create_module
());
}
TEST_CASE
(
dot_broadcast_different_rank
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
768
}});
auto
y
=
m1
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
768
,
3072
}});
auto
xb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
384
,
768
}}}),
x
);
auto
yb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
768
,
3072
}}}),
y
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
xb
,
yb
);
m1
.
add_return
({
dot
});
};
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
768
}});
auto
y
=
m2
.
add_parameter
(
"y"
,
{
migraphx
::
shape
::
float_type
,
{
768
,
3072
}});
auto
xb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
384
,
768
}}}),
x
);
auto
yb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
768
,
3072
}}}),
y
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
xb
,
yb
);
auto
broadcast
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
384
,
3072
}}}),
dot
);
m2
.
add_return
({
broadcast
});
};
run_pass
(
m1
);
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
dot_fusion_reshape
)
{
migraphx
::
module
m1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment