Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
69de5fdb
Commit
69de5fdb
authored
Apr 28, 2023
by
Paul
Browse files
Add common_dims
parent
59386637
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
233 additions
and
3 deletions
+233
-3
src/CMakeLists.txt
src/CMakeLists.txt
+1
-0
src/common_dims.cpp
src/common_dims.cpp
+82
-0
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+23
-0
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+21
-3
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+96
-0
test/common_dims.cpp
test/common_dims.cpp
+10
-0
No files found.
src/CMakeLists.txt
View file @
69de5fdb
...
...
@@ -34,6 +34,7 @@ add_library(migraphx
argument.cpp
auto_contiguous.cpp
common.cpp
common_dims.cpp
compile_src.cpp
convert_to_json.cpp
cpp_generator.cpp
...
...
src/common_dims.cpp
0 → 100644
View file @
69de5fdb
#include <migraphx/common_dims.hpp>
#include <migraphx/ranges.hpp>
#include <algorithm>
#include <cassert>
#include <numeric>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
>
static
auto
compute_end_dim
(
Iterator
start
,
Iterator
last
,
std
::
size_t
dim
)
{
std
::
size_t
x
=
1
;
auto
it
=
std
::
find_if
(
start
,
last
,
[
&
](
auto
i
)
{
x
*=
i
;
return
x
>=
dim
;
});
if
(
x
!=
dim
)
return
start
;
return
it
;
}
template
<
class
Iterator
>
static
auto
elements
(
Iterator
start
,
Iterator
last
)
{
return
std
::
accumulate
(
start
,
last
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
}
template
<
class
Range
>
static
auto
elements
(
const
Range
&
r
)
{
return
elements
(
r
.
begin
(),
r
.
end
());
}
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
{
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
common_dims
cd
;
auto
it1
=
dims1
.
begin
();
auto
it2
=
dims2
.
begin
();
std
::
size_t
rem1
=
1
;
std
::
size_t
rem2
=
1
;
while
(
it1
!=
dims1
.
end
()
and
it2
!=
dims2
.
end
())
{
auto
d1
=
*
it1
;
auto
d2
=
*
it2
;
if
(
d1
==
d2
)
{
cd
.
axes_map1
.
push_back
({
cd
.
dims
.
size
()});
cd
.
axes_map2
.
push_back
({
cd
.
dims
.
size
()});
cd
.
dims
.
push_back
(
d1
);
it1
++
;
it2
++
;
}
else
if
(
d1
<
d2
)
{
auto
dim_end
=
compute_end_dim
(
it1
,
dims1
.
begin
(),
d2
);
auto
dims
=
range
(
it1
,
dim_end
);
auto
n
=
elements
(
dims
);
if
(
n
!=
d2
)
{
// If not divisible then we can't compute a common dims
if
((
d2
%
n
)
!=
0
)
return
{};
rem1
=
d2
/
n
;
}
std
::
vector
<
std
::
size_t
>
axes
(
distance
(
dims
));
std
::
iota
(
axes
.
begin
(),
axes
.
end
(),
cd
.
dims
.
size
());
cd
.
axes_map1
.
push_back
(
axes
);
cd
.
axes_map2
.
push_back
(
axes
);
cd
.
dims
.
insert
(
cd
.
dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
rem1
!=
1
)
cd
.
dims
.
push_back
(
rem1
);
it1
+=
distance
(
dims
);
it2
++
;
}
}
return
cd
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/include/migraphx/common_dims.hpp
0 → 100644
View file @
69de5fdb
#ifndef MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#define MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#include <migraphx/config.hpp>
#include <cstdint>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
common_dims
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
src/include/migraphx/matcher.hpp
View file @
69de5fdb
...
...
@@ -474,7 +474,7 @@ struct match_fold_f
template
<
class
...
Ts
>
auto
operator
()(
Ts
...
ms
)
const
{
return
make_b
f
_matcher
(
return
make_b
asic_fun
_matcher
(
[
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
bool
matches
=
match_fold_f
::
fold_matchers
(
ctx
,
ins
,
ms
...);
if
(
matches
==
Matches
)
...
...
@@ -489,7 +489,7 @@ struct match_fold_f
return
[
=
](
auto
...
ms
)
{
// Workaround ICE on gcc by packing matchers into an object
auto
mpack
=
pack
(
ms
...);
return
make_b
f
_matcher
(
return
make_b
asic_fun
_matcher
(
[
=
](
matcher_context
&
ctx
,
instruction_ref
start
)
->
optional
<
instruction_ref
>
{
Op
op
;
bool
matches
=
Start
;
...
...
@@ -835,10 +835,28 @@ inline auto has_attribute(const std::string& name)
[
=
](
instruction_ref
ins
)
{
return
ins
->
get_operator
().
attributes
().
contains
(
name
);
});
}
template
<
class
T
>
inline
auto
has_attribute
(
const
std
::
string
&
name
,
T
value
)
{
return
make_basic_pred_matcher
(
[
=
](
instruction_ref
ins
)
{
auto
attributes
=
ins
->
get_operator
().
attributes
();
if
(
not
attributes
.
contains
(
name
))
return
false
;
return
attributes
[
name
].
to
<
T
>
()
==
value
;
});
}
template
<
class
...
Ms
>
auto
pointwise
(
Ms
...
ms
)
{
return
match
::
has_attribute
(
"pointwise"
)(
ms
...);
return
match
::
has_attribute
(
"pointwise"
,
true
)(
ms
...);
}
template
<
class
...
Ms
>
auto
reduce
(
Ms
...
ms
)
{
return
match
::
has_attribute
(
"reduce"
,
true
)(
ms
...);
}
}
// namespace match
...
...
src/simplify_reshapes.cpp
View file @
69de5fdb
...
...
@@ -913,6 +913,102 @@ struct find_broadcast_reshaper
}
};
struct
find_poinwise_reduce_reshape
{
auto
matcher
()
const
{
auto
reshaper
=
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
});
auto
skip_contiguous
=
match
::
skip
(
match
::
name
(
"contiguous"
));
auto
pointwise_or_reduce
=
match
::
any_of
(
match
::
pointwise
(),
match
::
reduce
());
auto
reshape_pointwise_or_reduce
=
reshaper
(
skip_contiguous
(
pointwise_or_reduce
.
bind
(
"x"
))).
bind
(
"reshape"
);
return
pointwise_or_reduce
(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise_or_reduce
));
}
static
bool
is_pointwise
(
instruction_ref
ins
)
{
auto
a
=
ins
->
get_operator
().
attributes
();
return
a
.
get
(
"pointwise"
,
false
);
}
static
bool
is_reduce
(
instruction_ref
ins
)
{
auto
a
=
ins
->
get_operator
().
attributes
();
return
a
.
get
(
"reduce"
,
false
);
}
static
bool
is_pointwise_or_reduce
(
instruction_ref
ins
)
{
auto
a
=
ins
->
get_operator
().
attributes
();
return
a
.
get
(
"pointwise"
,
false
)
or
a
.
get
(
"reduce"
,
false
);
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
dims1
=
x_ins
->
get_shape
().
lens
();
auto
dims2
=
reshape_ins
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
axes
;
if
(
x_ins
->
get_operator
().
attributes
().
get
(
"reduce"
,
false
))
{
axes
=
x_ins
->
get_operator
().
to_value
()[
"axes"
].
to_vector
<
int64_t
>
();
}
std
::
unordered_set
<
instruction_ref
>
inss
;
instruction_ref
entry
;
// Collect from inputs
fix
([
&
](
auto
self
,
instruction_ref
i
)
{
inss
.
insert
(
i
);
entry
=
i
;
auto
pointwise_or_reduce
=
[
&
](
instruction_ref
input
)
{
if
(
input
->
can_eval
())
return
false
;
return
is_pointwise
(
input
);
};
auto
it
=
std
::
find_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
pointwise_or_reduce
);
if
(
it
==
i
->
inputs
().
end
())
return
;
auto
it2
=
std
::
find_if
(
it
,
i
->
inputs
().
end
(),
pointwise_or_reduce
);
// If there is more than one pointwise_reduce than stop
if
(
it2
!=
i
->
inputs
().
end
())
return
;
self
(
*
it
);
})(
x_ins
);
// Collect from output
fix
([
&
](
auto
self
,
instruction_ref
out
)
{
for
(
auto
output
:
out
->
outputs
())
{
if
(
not
std
::
all_of
(
output
->
inputs
().
begin
(),
output
->
inputs
().
end
(),
[
&
](
auto
input
)
{
return
input
->
can_eval
()
or
contains
(
inss
,
input
);
}))
continue
;
if
(
not
is_pointwise_or_reduce
(
ins
))
continue
;
inss
.
insert
(
output
);
self
(
output
);
}
})(
x_ins
);
std
::
vector
<
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
aux
;
// Topological sort
fix
([
&
](
auto
self
,
instruction_ref
i
)
{
instructions
.
push_back
(
i
);
for
(
auto
output
:
i
->
outputs
())
{
if
(
not
contains
(
inss
,
output
))
{
aux
.
insert
(
output
);
continue
;
}
self
(
output
);
}
})(
entry
);
}
};
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
...
...
test/common_dims.cpp
0 → 100644
View file @
69de5fdb
#include <migraphx/common_dims.hpp>
#include <test.hpp>
TEST_CASE
(
common1
)
{
auto
cd
=
migraphx
::
common_dims
::
compute
({
2
,
32
,
2560
},
{
2
,
1280
,
8
,
8
});
EXPECT
(
cd
.
dims
==
std
::
vector
<
std
::
size_t
>
{
2
,
32
,
40
,
8
,
8
});
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment