Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e9a3e6c1
Commit
e9a3e6c1
authored
Apr 08, 2023
by
Paul
Browse files
Merge branch 'simplify-more-reshapes' into sd-opt
parents
f6e22d56
5967d68d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
298 additions
and
38 deletions
+298
-38
src/include/migraphx/op/reshape.hpp
src/include/migraphx/op/reshape.hpp
+108
-3
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+154
-35
test/op_shape_test.cpp
test/op_shape_test.cpp
+10
-0
test/simplify_reshapes_test.cpp
test/simplify_reshapes_test.cpp
+26
-0
No files found.
src/include/migraphx/op/reshape.hpp
View file @
e9a3e6c1
...
...
@@ -96,10 +96,41 @@ struct reshape
return
{
s0
.
type
(),
output_dyn_dims
};
}
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>=
dim
;
});
if
(
x
!=
dim
)
return
start
;
return
it
;
}
template
<
class
DimIterator
,
class
StrideIterator
>
static
auto
can_strides_merge
(
DimIterator
dim_start
,
DimIterator
dim_last
,
StrideIterator
stride_start
,
StrideIterator
stride_last
)
{
auto
cstride
=
*
std
::
prev
(
stride_last
);
return
std
::
equal
(
std
::
make_reverse_iterator
(
dim_last
),
std
::
make_reverse_iterator
(
dim_start
+
1
),
std
::
make_reverse_iterator
(
stride_last
-
1
),
std
::
make_reverse_iterator
(
stride_start
),
[
&
](
auto
dim
,
auto
stride
)
{
cstride
*=
dim
;
return
stride
==
cstride
;
});
}
shape
static_compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
size_t
n_neg_dims
)
const
{
check_shapes
{
inputs
,
*
this
}.
standard
();
auto
&&
idims
=
inputs
.
front
().
lens
();
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
&&
idims
=
inputs
.
front
().
lens
();
auto
&&
istrides
=
inputs
.
front
().
strides
();
std
::
vector
<
std
::
size_t
>
rdims
(
dims
.
begin
(),
dims
.
end
());
for
(
std
::
size_t
i
=
0
;
i
<
dims
.
size
();
i
++
)
...
...
@@ -125,7 +156,81 @@ struct reshape
}
}
shape
s
{
inputs
.
front
().
type
(),
rdims
};
shape
s
;
if
(
inputs
.
front
().
standard
())
{
s
=
shape
{
inputs
.
front
().
type
(),
rdims
};
}
else
{
std
::
vector
<
std
::
size_t
>
rstrides
;
std
::
size_t
i
=
0
;
std
::
size_t
r
=
0
;
while
(
i
<
idims
.
size
()
and
r
<
rdims
.
size
())
{
auto
idim
=
idims
[
i
];
auto
rdim
=
rdims
[
r
];
if
(
rdim
==
idim
)
{
rstrides
.
push_back
(
istrides
[
i
]);
}
// squeeze
else
if
(
rdim
>
idim
)
{
auto
start
=
idims
.
begin
()
+
i
;
auto
it
=
compute_end_dim
(
start
,
idims
.
end
(),
rdim
);
if
(
it
==
start
)
break
;
auto
n
=
it
-
start
;
if
((
i
+
n
)
>
istrides
.
size
())
break
;
if
(
not
can_strides_merge
(
start
,
it
+
1
,
istrides
.
begin
()
+
i
,
istrides
.
begin
()
+
i
+
n
))
break
;
i
+=
n
;
rstrides
.
push_back
(
istrides
[
i
]);
}
// unsqueeze
else
if
(
rdim
<
idim
)
{
auto
start
=
rdims
.
begin
()
+
i
;
auto
it
=
compute_end_dim
(
start
,
rdims
.
end
(),
idim
);
if
(
it
==
start
)
break
;
auto
n
=
it
-
start
;
if
((
r
+
n
)
>
rdims
.
size
())
break
;
auto
stride
=
istrides
[
i
]
*
idim
;
std
::
for_each
(
start
,
it
+
1
,
[
&
](
auto
dim
)
{
stride
/=
dim
;
rstrides
.
push_back
(
stride
);
});
r
+=
n
;
}
i
++
;
r
++
;
}
// Handle trailing 1s
if
(
rstrides
.
size
()
<
rdims
.
size
()
and
not
rstrides
.
empty
())
{
auto
stride
=
rstrides
.
back
();
for
(
auto
d
:
range
(
rdims
.
begin
()
+
rstrides
.
size
(),
rdims
.
end
()))
{
if
(
d
!=
1
)
break
;
rstrides
.
push_back
(
stride
);
}
}
if
(
rdims
.
size
()
!=
rstrides
.
size
())
MIGRAPHX_THROW
(
"Reshape on axis that is not standard"
);
s
=
shape
{
inputs
.
front
().
type
(),
rdims
,
rstrides
};
}
assert
(
s
.
bytes
()
==
inputs
.
front
().
bytes
());
if
(
s
.
elements
()
!=
inputs
.
front
().
elements
())
MIGRAPHX_THROW
(
"Reshape: Wrong number of elements for reshape: reshape has "
+
std
::
to_string
(
s
.
elements
())
+
" elements whereas the input has "
+
...
...
src/simplify_reshapes.cpp
View file @
e9a3e6c1
...
...
@@ -49,7 +49,6 @@ const auto& reshaper_names()
static
const
std
::
unordered_set
<
std
::
string
>
names
=
{
"flatten"
,
"reshape"
,
"contiguous"
,
"squeeze"
,
"unsqueeze"
};
...
...
@@ -89,38 +88,23 @@ struct find_reshaper
{
auto
matcher
()
const
{
return
match
::
name
(
reshaper_names
())(
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
auto
no_output_reshape
=
match
::
none_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
()));
auto
input_reshape
=
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
reshaper_names
())));
auto
input
=
match
::
skip
(
match
::
name
(
reshaper_names
()),
match
::
name
(
"contiguous"
))(
match
::
arg
(
0
).
bind
(
"x"
));
return
match
::
name
(
reshaper_names
())(
no_output_reshape
,
input_reshape
,
input
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
{
assert
(
not
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
m
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
}
auto
ins
=
mr
.
result
;
auto
input
=
mr
.
instructions
[
"x"
];
auto
dims
=
ins
->
get_shape
().
lens
();
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
m
.
end
(),
m
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
return
i
->
get_shape
()
==
(
*
start
)
->
get_shape
()
and
i
!=
(
*
start
);
});
if
(
last
!=
reshapes
.
rend
())
{
r
=
std
::
make_pair
(
*
start
,
*
last
);
break
;
}
}
if
(
r
.
first
!=
r
.
second
)
{
m
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
if
(
not
input
->
get_shape
().
standard
())
input
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
input
);
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
input
);
}
};
...
...
@@ -603,14 +587,15 @@ struct find_reshape_cont
};
// match sequence of transpose --> contiguous --> reshaper_op
auto
match_transpose_contiguous_reshaper
()
template
<
class
...
Ms
>
auto
match_transpose_contiguous_reshaper
(
Ms
...
ms
)
{
return
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
})(
match
::
used_once
(),
match
::
args
(
match
::
name
(
"contiguous"
)(
match
::
used_once
(),
match
::
args
(
match
::
transpose_shape
().
bind
(
"trans_ins"
)))
.
bind
(
"cont_ins"
)))
match
::
args
(
match
::
name
(
"contiguous"
)(
match
::
used_once
(),
match
::
args
(
match
::
transpose_shape
(
ms
...
).
bind
(
"trans_ins"
)))
.
bind
(
"cont_ins"
)))
.
bind
(
"reshaper_ins"
);
};
...
...
@@ -642,6 +627,45 @@ struct find_transpose_contiguous_reshaper_unary
}
};
struct
find_mul_add_transpose_contiguous_reshaper_gemm
{
auto
matcher
()
const
{
auto
pw
=
match
::
name
(
"mul"
,
"add"
)(
match
::
used_once
(),
match
::
either_arg
(
0
,
1
)(
match
::
is_constant
().
bind
(
"c"
),
match
::
any
().
bind
(
"x"
)));
return
match
::
name
(
"dot"
)(
match
::
either_arg
(
0
,
1
)(
match_transpose_contiguous_reshaper
(
match
::
args
(
pw
.
bind
(
"pointwise"
))),
match
::
is_constant
()));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
reshaper_ins
=
r
.
instructions
[
"reshaper_ins"
];
auto
trans_ins
=
r
.
instructions
[
"trans_ins"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
pw_ins
=
r
.
instructions
[
"pointwise"
];
auto
insert_reshapes
=
[
&
](
auto
x
)
{
auto
t
=
m
.
insert_instruction
(
ins
,
trans_ins
->
get_operator
(),
x
);
auto
c
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
t
);
return
m
.
insert_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
c
);
};
if
(
x_ins
->
name
()
==
"mul"
)
{
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
{
insert_reshapes
(
x_ins
->
inputs
()[
0
]),
insert_reshapes
(
x_ins
->
inputs
()[
1
])});
}
auto
y_ins
=
m
.
insert_instruction
(
ins
,
pw_ins
->
get_operator
(),
{
x_ins
,
insert_reshapes
(
c_ins
)});
m
.
replace_instruction
(
reshaper_ins
,
y_ins
);
}
};
struct
find_slice_transpose
{
auto
matcher
()
const
...
...
@@ -797,6 +821,98 @@ struct find_transpose_slice
}
};
struct
find_reshape_gemm
{
auto
matcher
()
const
{
return
match
::
name
(
"reshape"
)(
match
::
arg
(
0
)(
match
::
name
(
"dot"
)));
}
static
bool
is_batched_unsqueeze
(
instruction_ref
ins
)
{
auto
input
=
ins
->
inputs
().
front
()
->
get_shape
().
lens
();
auto
output
=
ins
->
get_shape
().
lens
();
if
(
output
.
size
()
<=
input
.
size
())
return
false
;
if
(
not
std
::
equal
(
input
.
end
()
-
2
,
input
.
end
(),
output
.
end
()
-
2
,
output
.
end
()))
return
false
;
return
true
;
}
static
operation
make_reshape
(
std
::
vector
<
std
::
size_t
>
batches
,
instruction_ref
ins
)
{
batches
.
insert
(
batches
.
end
(),
ins
->
get_shape
().
lens
().
end
()
-
2
,
ins
->
get_shape
().
lens
().
end
());
return
make_op
(
"reshape"
,
{{
"dims"
,
batches
}});
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
reshape_ins
=
r
.
result
;
auto
dot_ins
=
reshape_ins
->
inputs
().
front
();
// TODO: Put this in the matcher
if
(
not
is_batched_unsqueeze
(
reshape_ins
))
return
;
std
::
vector
<
std
::
size_t
>
batches
;
std
::
copy
(
reshape_ins
->
get_shape
().
lens
().
begin
(),
reshape_ins
->
get_shape
().
lens
().
end
()
-
2
,
std
::
back_inserter
(
batches
));
auto
input0
=
m
.
insert_instruction
(
dot_ins
,
make_reshape
(
batches
,
dot_ins
->
inputs
()[
0
]),
dot_ins
->
inputs
()[
0
]);
auto
input1
=
m
.
insert_instruction
(
dot_ins
,
make_reshape
(
batches
,
dot_ins
->
inputs
()[
1
]),
dot_ins
->
inputs
()[
1
]);
m
.
replace_instruction
(
dot_ins
,
make_op
(
"dot"
),
input0
,
input1
);
}
};
struct
find_broadcast_reshaper
{
auto
matcher
()
const
{
auto
broadcast
=
match
::
broadcast_shape
(
match
::
skip
(
match
::
broadcast_shape
())(
match
::
any
().
bind
(
"x"
)));
return
match
::
name
(
reshaper_names
())(
match
::
args
(
match
::
skip
(
match
::
name
(
"contiguous"
))(
broadcast
.
bind
(
"broadcast"
))));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
broadcast_ins
=
r
.
instructions
[
"broadcast"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
broadcast_shape
=
broadcast_ins
->
get_shape
();
auto
result_shape
=
ins
->
get_shape
();
if
(
std
::
accumulate
(
broadcast_shape
.
strides
().
begin
(),
broadcast_shape
.
strides
().
end
(),
0
)
!=
1
)
return
;
auto
baxis
=
std
::
find
(
broadcast_shape
.
strides
().
begin
(),
broadcast_shape
.
strides
().
end
(),
1
)
-
broadcast_shape
.
strides
().
begin
();
auto
relements
=
result_shape
.
lens
();
std
::
partial_sum
(
relements
.
begin
(),
relements
.
end
(),
relements
.
begin
(),
std
::
multiplies
<>
{});
auto
prefix_elements
=
std
::
accumulate
(
broadcast_shape
.
lens
().
begin
(),
broadcast_shape
.
lens
().
begin
()
+
baxis
+
1
,
1
,
std
::
multiplies
<>
{});
auto
axis
=
std
::
find
(
relements
.
begin
(),
relements
.
end
(),
prefix_elements
)
-
relements
.
begin
();
if
(
axis
>=
relements
.
size
())
return
;
if
(
x_ins
->
get_shape
().
lens
().
size
()
>
1
)
x_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
),
x_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
ins
->
get_shape
().
lens
()}}),
x_ins
);
}
};
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
...
...
@@ -804,9 +920,10 @@ void simplify_reshapes::apply(module& m) const
match
::
find_matches
(
m
,
find_where_op
{},
find_resize
{},
find_reshape_cont
{},
find_nop_reshapes
{},
find_reshaper
{},
find_broadcast_reshaper
{},
// find_reshape_cont{},
find_transpose
{},
find_concat_transpose
{},
find_concat_multibroadcasts
{},
...
...
@@ -815,7 +932,9 @@ void simplify_reshapes::apply(module& m) const
find_nested_concat
{},
find_transpose_slice
{},
find_slice_transpose
{},
find_transpose_contiguous_reshaper_unary
{});
find_transpose_contiguous_reshaper_unary
{},
find_mul_add_transpose_contiguous_reshaper_gemm
{},
find_reshape_gemm
{});
dead_code_elimination
{}.
apply
(
m
);
}
}
...
...
test/op_shape_test.cpp
View file @
e9a3e6c1
...
...
@@ -2181,6 +2181,16 @@ TEST_CASE(reshape_shape)
}
}
TEST_CASE
(
reshape_nonstandard_unsqeeze
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{
4
,
24
,
1
,
1
,
1
},
{
1
,
4
,
1
,
1
,
1
}};
std
::
vector
<
std
::
size_t
>
lens
=
{
4
,
1
,
3
,
4
,
2
};
std
::
vector
<
int64_t
>
perm
=
{
4
,
0
,
1
,
2
,
3
};
migraphx
::
shape
output
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
lens
,
migraphx
::
invert_permutation
(
perm
));
expect_shape
(
output
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
lens
}}),
input
);
}
TEST_CASE
(
reshape_dyn_shape
)
{
migraphx
::
shape
input
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
},
{
24
,
24
},
{
1
,
1
},
{
1
,
1
}}};
...
...
test/simplify_reshapes_test.cpp
View file @
e9a3e6c1
...
...
@@ -1503,4 +1503,30 @@ TEST_CASE(transpose_slice_non_packed_multi_axis)
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
TEST_CASE
(
broadcast_transpose_reshape
)
{
migraphx
::
module
m1
;
{
auto
x
=
m1
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
320
,
1
,
1
}});
auto
broadcast
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
320
,
64
,
64
}}}),
x
);
auto
transpose
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
broadcast
);
auto
contiguous
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
transpose
);
auto
reshape
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
2
,
4096
,
320
}}}),
contiguous
);
m1
.
add_return
({
reshape
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
x
=
m2
.
add_parameter
(
"x"
,
{
migraphx
::
shape
::
float_type
,
{
320
,
1
,
1
}});
auto
squeeze
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"squeeze"
),
x
);
auto
broadcast
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
2
},
{
"out_lens"
,
{
2
,
4096
,
320
}}}),
squeeze
);
m2
.
add_return
({
broadcast
});
}
EXPECT
(
m1
.
sort
()
==
m2
.
sort
());
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment