Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
665455e1
"test/srt/vscode:/vscode.git/clone" did not exist on "46e9d1c7c19e4734b32b70a1bcafef70464f6f49"
Commit
665455e1
authored
Apr 28, 2023
by
Paul
Browse files
Format
parent
69de5fdb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
42 additions
and
41 deletions
+42
-41
src/common_dims.cpp
src/common_dims.cpp
+13
-12
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+2
-3
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+7
-8
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+20
-18
No files found.
src/common_dims.cpp
View file @
665455e1
...
@@ -20,18 +20,19 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
...
@@ -20,18 +20,19 @@ static auto compute_end_dim(Iterator start, Iterator last, std::size_t dim)
return
it
;
return
it
;
}
}
template
<
class
Iterator
>
template
<
class
Iterator
>
static
auto
elements
(
Iterator
start
,
Iterator
last
)
static
auto
elements
(
Iterator
start
,
Iterator
last
)
{
{
return
std
::
accumulate
(
start
,
last
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
return
std
::
accumulate
(
start
,
last
,
std
::
size_t
{
1
},
std
::
multiplies
<>
{});
}
}
template
<
class
Range
>
template
<
class
Range
>
static
auto
elements
(
const
Range
&
r
)
static
auto
elements
(
const
Range
&
r
)
{
{
return
elements
(
r
.
begin
(),
r
.
end
());
return
elements
(
r
.
begin
(),
r
.
end
());
}
}
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
common_dims
common_dims
::
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
)
{
{
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
assert
(
elements
(
dims1
)
==
elements
(
dims2
));
common_dims
cd
;
common_dims
cd
;
...
@@ -43,7 +44,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
...
@@ -43,7 +44,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
{
{
auto
d1
=
*
it1
;
auto
d1
=
*
it1
;
auto
d2
=
*
it2
;
auto
d2
=
*
it2
;
if
(
d1
==
d2
)
if
(
d1
==
d2
)
{
{
cd
.
axes_map1
.
push_back
({
cd
.
dims
.
size
()});
cd
.
axes_map1
.
push_back
({
cd
.
dims
.
size
()});
cd
.
axes_map2
.
push_back
({
cd
.
dims
.
size
()});
cd
.
axes_map2
.
push_back
({
cd
.
dims
.
size
()});
...
@@ -51,15 +52,15 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
...
@@ -51,15 +52,15 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
it1
++
;
it1
++
;
it2
++
;
it2
++
;
}
}
else
if
(
d1
<
d2
)
else
if
(
d1
<
d2
)
{
{
auto
dim_end
=
compute_end_dim
(
it1
,
dims1
.
begin
(),
d2
);
auto
dim_end
=
compute_end_dim
(
it1
,
dims1
.
begin
(),
d2
);
auto
dims
=
range
(
it1
,
dim_end
);
auto
dims
=
range
(
it1
,
dim_end
);
auto
n
=
elements
(
dims
);
auto
n
=
elements
(
dims
);
if
(
n
!=
d2
)
if
(
n
!=
d2
)
{
{
// If not divisible then we can't compute a common dims
// If not divisible then we can't compute a common dims
if
((
d2
%
n
)
!=
0
)
if
((
d2
%
n
)
!=
0
)
return
{};
return
{};
rem1
=
d2
/
n
;
rem1
=
d2
/
n
;
}
}
...
@@ -69,7 +70,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
...
@@ -69,7 +70,7 @@ common_dims common_dims::compute(const std::vector<std::size_t>& dims1, const st
cd
.
axes_map2
.
push_back
(
axes
);
cd
.
axes_map2
.
push_back
(
axes
);
cd
.
dims
.
insert
(
cd
.
dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
cd
.
dims
.
insert
(
cd
.
dims
.
end
(),
dims
.
begin
(),
dims
.
end
());
if
(
rem1
!=
1
)
if
(
rem1
!=
1
)
cd
.
dims
.
push_back
(
rem1
);
cd
.
dims
.
push_back
(
rem1
);
it1
+=
distance
(
dims
);
it1
+=
distance
(
dims
);
it2
++
;
it2
++
;
...
...
src/include/migraphx/common_dims.hpp
View file @
665455e1
...
@@ -10,14 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,14 +10,13 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
common_dims
struct
common_dims
{
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
#endif // MIGRAPHX_GUARD_MIGRAPHX_COMMON_DIMS_HPP
src/include/migraphx/matcher.hpp
View file @
665455e1
...
@@ -835,13 +835,12 @@ inline auto has_attribute(const std::string& name)
...
@@ -835,13 +835,12 @@ inline auto has_attribute(const std::string& name)
[
=
](
instruction_ref
ins
)
{
return
ins
->
get_operator
().
attributes
().
contains
(
name
);
});
[
=
](
instruction_ref
ins
)
{
return
ins
->
get_operator
().
attributes
().
contains
(
name
);
});
}
}
template
<
class
T
>
template
<
class
T
>
inline
auto
has_attribute
(
const
std
::
string
&
name
,
T
value
)
inline
auto
has_attribute
(
const
std
::
string
&
name
,
T
value
)
{
{
return
make_basic_pred_matcher
(
return
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
[
=
](
instruction_ref
ins
)
{
auto
attributes
=
ins
->
get_operator
().
attributes
();
auto
attributes
=
ins
->
get_operator
().
attributes
();
if
(
not
attributes
.
contains
(
name
))
if
(
not
attributes
.
contains
(
name
))
return
false
;
return
false
;
return
attributes
[
name
].
to
<
T
>
()
==
value
;
return
attributes
[
name
].
to
<
T
>
()
==
value
;
});
});
...
...
src/simplify_reshapes.cpp
View file @
665455e1
...
@@ -920,7 +920,8 @@ struct find_poinwise_reduce_reshape
...
@@ -920,7 +920,8 @@ struct find_poinwise_reduce_reshape
auto
reshaper
=
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
});
auto
reshaper
=
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
});
auto
skip_contiguous
=
match
::
skip
(
match
::
name
(
"contiguous"
));
auto
skip_contiguous
=
match
::
skip
(
match
::
name
(
"contiguous"
));
auto
pointwise_or_reduce
=
match
::
any_of
(
match
::
pointwise
(),
match
::
reduce
());
auto
pointwise_or_reduce
=
match
::
any_of
(
match
::
pointwise
(),
match
::
reduce
());
auto
reshape_pointwise_or_reduce
=
reshaper
(
skip_contiguous
(
pointwise_or_reduce
.
bind
(
"x"
))).
bind
(
"reshape"
);
auto
reshape_pointwise_or_reduce
=
reshaper
(
skip_contiguous
(
pointwise_or_reduce
.
bind
(
"x"
))).
bind
(
"reshape"
);
return
pointwise_or_reduce
(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise_or_reduce
));
return
pointwise_or_reduce
(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise_or_reduce
));
}
}
...
@@ -952,7 +953,7 @@ struct find_poinwise_reduce_reshape
...
@@ -952,7 +953,7 @@ struct find_poinwise_reduce_reshape
auto
dims2
=
reshape_ins
->
get_shape
().
lens
();
auto
dims2
=
reshape_ins
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
axes
;
std
::
vector
<
int64_t
>
axes
;
if
(
x_ins
->
get_operator
().
attributes
().
get
(
"reduce"
,
false
))
if
(
x_ins
->
get_operator
().
attributes
().
get
(
"reduce"
,
false
))
{
{
axes
=
x_ins
->
get_operator
().
to_value
()[
"axes"
].
to_vector
<
int64_t
>
();
axes
=
x_ins
->
get_operator
().
to_value
()[
"axes"
].
to_vector
<
int64_t
>
();
}
}
...
@@ -963,28 +964,29 @@ struct find_poinwise_reduce_reshape
...
@@ -963,28 +964,29 @@ struct find_poinwise_reduce_reshape
inss
.
insert
(
i
);
inss
.
insert
(
i
);
entry
=
i
;
entry
=
i
;
auto
pointwise_or_reduce
=
[
&
](
instruction_ref
input
)
{
auto
pointwise_or_reduce
=
[
&
](
instruction_ref
input
)
{
if
(
input
->
can_eval
())
if
(
input
->
can_eval
())
return
false
;
return
false
;
return
is_pointwise
(
input
);
return
is_pointwise
(
input
);
};
};
auto
it
=
std
::
find_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
pointwise_or_reduce
);
auto
it
=
std
::
find_if
(
i
->
inputs
().
begin
(),
i
->
inputs
().
end
(),
pointwise_or_reduce
);
if
(
it
==
i
->
inputs
().
end
())
if
(
it
==
i
->
inputs
().
end
())
return
;
return
;
auto
it2
=
std
::
find_if
(
it
,
i
->
inputs
().
end
(),
pointwise_or_reduce
);
auto
it2
=
std
::
find_if
(
it
,
i
->
inputs
().
end
(),
pointwise_or_reduce
);
// If there is more than one pointwise_reduce than stop
// If there is more than one pointwise_reduce than stop
if
(
it2
!=
i
->
inputs
().
end
())
if
(
it2
!=
i
->
inputs
().
end
())
return
;
return
;
self
(
*
it
);
self
(
*
it
);
})(
x_ins
);
})(
x_ins
);
// Collect from output
// Collect from output
fix
([
&
](
auto
self
,
instruction_ref
out
)
{
fix
([
&
](
auto
self
,
instruction_ref
out
)
{
for
(
auto
output
:
out
->
outputs
())
for
(
auto
output
:
out
->
outputs
())
{
{
if
(
not
std
::
all_of
(
output
->
inputs
().
begin
(),
output
->
inputs
().
end
(),
[
&
](
auto
input
)
{
if
(
not
std
::
all_of
(
output
->
inputs
().
begin
(),
output
->
inputs
().
end
(),
[
&
](
auto
input
)
{
return
input
->
can_eval
()
or
contains
(
inss
,
input
);
return
input
->
can_eval
()
or
contains
(
inss
,
input
);
}))
}))
continue
;
continue
;
if
(
not
is_pointwise_or_reduce
(
ins
))
if
(
not
is_pointwise_or_reduce
(
ins
))
continue
;
continue
;
inss
.
insert
(
output
);
inss
.
insert
(
output
);
self
(
output
);
self
(
output
);
...
@@ -996,9 +998,9 @@ struct find_poinwise_reduce_reshape
...
@@ -996,9 +998,9 @@ struct find_poinwise_reduce_reshape
// Topological sort
// Topological sort
fix
([
&
](
auto
self
,
instruction_ref
i
)
{
fix
([
&
](
auto
self
,
instruction_ref
i
)
{
instructions
.
push_back
(
i
);
instructions
.
push_back
(
i
);
for
(
auto
output
:
i
->
outputs
())
for
(
auto
output
:
i
->
outputs
())
{
{
if
(
not
contains
(
inss
,
output
))
if
(
not
contains
(
inss
,
output
))
{
{
aux
.
insert
(
output
);
aux
.
insert
(
output
);
continue
;
continue
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment