Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d0dbaf41
Commit
d0dbaf41
authored
May 03, 2023
by
Paul
Browse files
Save code
parent
da78b0c0
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
228 additions
and
46 deletions
+228
-46
src/dom_info.cpp
src/dom_info.cpp
+15
-2
src/include/migraphx/common_dims.hpp
src/include/migraphx/common_dims.hpp
+4
-0
src/include/migraphx/dom_info.hpp
src/include/migraphx/dom_info.hpp
+1
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+3
-3
src/include/migraphx/module.hpp
src/include/migraphx/module.hpp
+7
-0
src/instruction.cpp
src/instruction.cpp
+1
-0
src/module.cpp
src/module.cpp
+22
-6
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+175
-34
No files found.
src/dom_info.cpp
View file @
d0dbaf41
...
@@ -46,7 +46,15 @@ bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins
...
@@ -46,7 +46,15 @@ bool dominator_info::strictly_dominate(instruction_ref ins1, instruction_ref ins
return
false
;
return
false
;
}
}
struct
module_visitor
struct
module_input_visitor
{
module
*
mm
;
module
&
get_nodes
()
const
{
return
*
mm
;
}
const
std
::
vector
<
instruction_ref
>&
get_children
(
instruction_ref
ins
)
{
return
ins
->
inputs
();
}
};
struct
module_output_visitor
{
{
module
*
mm
;
module
*
mm
;
module
&
get_nodes
()
const
{
return
*
mm
;
}
module
&
get_nodes
()
const
{
return
*
mm
;
}
...
@@ -93,7 +101,12 @@ dominator_info compute_dominator_generic(Visitor v)
...
@@ -93,7 +101,12 @@ dominator_info compute_dominator_generic(Visitor v)
dominator_info
compute_dominator
(
module
&
m
)
dominator_info
compute_dominator
(
module
&
m
)
{
{
return
compute_dominator_generic
(
module_visitor
{
&
m
});
return
compute_dominator_generic
(
module_input_visitor
{
&
m
});
}
dominator_info
compute_post_dominator
(
module
&
m
)
{
return
compute_dominator_generic
(
module_output_visitor
{
&
m
});
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/common_dims.hpp
View file @
d0dbaf41
...
@@ -12,6 +12,10 @@ struct common_dims
...
@@ -12,6 +12,10 @@ struct common_dims
{
{
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
static
common_dims
compute
(
const
std
::
vector
<
std
::
size_t
>&
dims1
,
const
std
::
vector
<
std
::
size_t
>&
dims2
);
const
std
::
vector
<
std
::
size_t
>&
dims2
);
bool
empty
()
const
{
return
dims
.
empty
();
}
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
size_t
>
dims
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map1
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
std
::
vector
<
std
::
vector
<
std
::
size_t
>>
axes_map2
;
...
...
src/include/migraphx/dom_info.hpp
View file @
d0dbaf41
...
@@ -42,7 +42,7 @@ struct dominator_info
...
@@ -42,7 +42,7 @@ struct dominator_info
};
};
dominator_info
compute_dominator
(
module
&
m
);
dominator_info
compute_dominator
(
module
&
m
);
//
dominator_info compute_dominator
_naive(const
module& m);
dominator_info
compute_
post_
dominator
(
module
&
m
);
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
d0dbaf41
...
@@ -198,8 +198,8 @@ struct basic_matcher
...
@@ -198,8 +198,8 @@ struct basic_matcher
{
{
M
m
;
M
m
;
template
<
class
...
T
s
>
template
<
class
...
M
s
>
auto
operator
()(
T
s
...
ms
)
const
auto
operator
()(
M
s
...
ms
)
const
{
{
// Copy m because we cant capture `this` by value
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
auto
mm
=
m
;
...
@@ -209,7 +209,7 @@ struct basic_matcher
...
@@ -209,7 +209,7 @@ struct basic_matcher
if
(
result
)
if
(
result
)
{
{
bool
matches
=
bool
matches
=
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
ctx
.
matched
(
y
,
result
);
})(
true
,
ms
...);
+
fold
([
&
](
auto
x
,
auto
y
)
{
return
x
and
ctx
.
matched
(
y
,
result
);
})(
true
,
ms
...);
if
(
matches
)
if
(
matches
)
return
result
;
return
result
;
}
}
...
...
src/include/migraphx/module.hpp
View file @
d0dbaf41
...
@@ -54,6 +54,7 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
...
@@ -54,6 +54,7 @@ using ins_dep_map = std::unordered_map<instruction_ref, std::unordered_set<ins
*/
*/
struct
module
struct
module
{
{
using
inserter
=
std
::
function
<
instruction_ref
(
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
,
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
module_ref
>&
module_args
)
>
;
module
(
const
std
::
string
&
name
=
""
);
module
(
const
std
::
string
&
name
=
""
);
// move constructor
// move constructor
...
@@ -137,6 +138,12 @@ struct module
...
@@ -137,6 +138,12 @@ struct module
const
std
::
vector
<
instruction_ref
>&
instructions
,
const
std
::
vector
<
instruction_ref
>&
instructions
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
vector
<
instruction_ref
>
insert_instructions
(
inserter
insert
,
instruction_ref
ins
,
const
std
::
vector
<
instruction_ref
>&
instructions
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
=
{});
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
insert_instructions
(
instruction_ref
ins
,
insert_instructions
(
instruction_ref
ins
,
const_module_ref
m
,
const_module_ref
m
,
...
...
src/instruction.cpp
View file @
d0dbaf41
...
@@ -328,6 +328,7 @@ bool instruction::can_eval() const
...
@@ -328,6 +328,7 @@ bool instruction::can_eval() const
}
}
else
if
(
is_context_free
(
op
))
else
if
(
is_context_free
(
op
))
{
{
assert
(
std
::
none_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[
&
](
instruction_ref
arg
)
{
return
std
::
addressof
(
*
arg
)
==
this
;
}));
return
std
::
all_of
(
return
std
::
all_of
(
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
can_eval
();
});
this
->
inputs
().
begin
(),
this
->
inputs
().
end
(),
[](
auto
arg
)
{
return
arg
->
can_eval
();
});
}
}
...
...
src/module.cpp
View file @
d0dbaf41
...
@@ -197,12 +197,13 @@ void module::assign(const module& m)
...
@@ -197,12 +197,13 @@ void module::assign(const module& m)
}
}
}
}
template
<
class
Range
>
template
<
class
Range
,
class
Inserter
>
static
std
::
vector
<
instruction_ref
>
static
std
::
vector
<
instruction_ref
>
insert_generic_instructions
(
module
&
m
,
insert_generic_instructions
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
Range
&&
instructions
,
Range
&&
instructions
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
,
Inserter
insert
)
{
{
assert
(
m
.
has_instruction
(
ins
)
or
is_end
(
ins
,
m
.
end
()));
assert
(
m
.
has_instruction
(
ins
)
or
is_end
(
ins
,
m
.
end
()));
std
::
vector
<
instruction_ref
>
mod_outputs
;
std
::
vector
<
instruction_ref
>
mod_outputs
;
...
@@ -244,7 +245,7 @@ insert_generic_instructions(module& m,
...
@@ -244,7 +245,7 @@ insert_generic_instructions(module& m,
break
;
break
;
}
}
copy_ins
=
m
.
insert
_instruction
(
ins
,
sins
->
get_operator
(),
copy_inputs
,
mod_args
);
copy_ins
=
insert
(
m
,
ins
,
sins
->
get_operator
(),
copy_inputs
,
mod_args
);
}
}
map_ins
[
sins
]
=
copy_ins
;
map_ins
[
sins
]
=
copy_ins
;
}
}
...
@@ -253,6 +254,13 @@ insert_generic_instructions(module& m,
...
@@ -253,6 +254,13 @@ insert_generic_instructions(module& m,
return
mod_outputs
;
return
mod_outputs
;
}
}
static
auto
default_module_inserter
()
{
return
[](
module
&
m
,
auto
&&
...
xs
)
{
return
m
.
insert_instruction
(
static_cast
<
decltype
(
xs
)
&&>
(
xs
)...);
};
}
instruction_ref
module
::
add_instruction
(
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
args
)
instruction_ref
module
::
add_instruction
(
const
operation
&
op
,
std
::
vector
<
instruction_ref
>
args
)
{
{
return
insert_instruction
(
impl
->
instructions
.
end
(),
op
,
std
::
move
(
args
));
return
insert_instruction
(
impl
->
instructions
.
end
(),
op
,
std
::
move
(
args
));
...
@@ -422,7 +430,15 @@ module::insert_instructions(instruction_ref ins,
...
@@ -422,7 +430,15 @@ module::insert_instructions(instruction_ref ins,
const
std
::
vector
<
instruction_ref
>&
instructions
,
const
std
::
vector
<
instruction_ref
>&
instructions
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
return
insert_generic_instructions
(
*
this
,
ins
,
instructions
,
std
::
move
(
map_ins
));
return
insert_generic_instructions
(
*
this
,
ins
,
instructions
,
std
::
move
(
map_ins
),
default_module_inserter
());
}
std
::
vector
<
instruction_ref
>
module
::
insert_instructions
(
module
::
inserter
insert
,
instruction_ref
ins
,
const
std
::
vector
<
instruction_ref
>&
instructions
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
return
insert_generic_instructions
(
*
this
,
ins
,
instructions
,
std
::
move
(
map_ins
),
insert
);
}
}
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
...
@@ -430,7 +446,7 @@ module::insert_instructions(instruction_ref ins,
...
@@ -430,7 +446,7 @@ module::insert_instructions(instruction_ref ins,
const_module_ref
m
,
const_module_ref
m
,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
*
m
),
std
::
move
(
map_ins
));
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
*
m
),
std
::
move
(
map_ins
)
,
default_module_inserter
()
);
}
}
std
::
vector
<
instruction_ref
>
std
::
vector
<
instruction_ref
>
...
@@ -440,7 +456,7 @@ module::insert_instructions(instruction_ref ins,
...
@@ -440,7 +456,7 @@ module::insert_instructions(instruction_ref ins,
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
)
{
{
auto
r
=
range
(
start
,
last
);
auto
r
=
range
(
start
,
last
);
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
r
),
std
::
move
(
map_ins
));
return
insert_generic_instructions
(
*
this
,
ins
,
iterator_for
(
r
),
std
::
move
(
map_ins
)
,
default_module_inserter
()
);
}
}
instruction_ref
module
::
add_literal
(
literal
l
)
{
return
insert_literal
(
begin
(),
std
::
move
(
l
));
}
instruction_ref
module
::
add_literal
(
literal
l
)
{
return
insert_literal
(
begin
(),
std
::
move
(
l
));
}
...
...
src/simplify_reshapes.cpp
View file @
d0dbaf41
...
@@ -37,6 +37,8 @@
...
@@ -37,6 +37,8 @@
#include <unordered_set>
#include <unordered_set>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/tune_axis.hpp>
#include <migraphx/common_dims.hpp>
#include <migraphx/dom_info.hpp>
#include <map>
#include <map>
...
@@ -915,16 +917,29 @@ struct find_broadcast_reshaper
...
@@ -915,16 +917,29 @@ struct find_broadcast_reshaper
struct
find_poinwise_reduce_reshape
struct
find_poinwise_reduce_reshape
{
{
template
<
class
...
Ms
>
static
auto
match_reshaper
(
Ms
...
ms
)
{
return
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
})(
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"contiguous"
))(
ms
...)));
}
auto
matcher
()
const
auto
matcher
()
const
{
{
auto
reshaper
=
match
::
name
({
"reshape"
,
"squeeze"
,
"unsqueeze"
});
auto
skip_contiguous
=
match
::
skip
(
match
::
name
(
"contiguous"
));
auto
pointwise_or_reduce
=
match
::
any_of
(
match
::
pointwise
(),
match
::
reduce
());
auto
pointwise_or_reduce
=
match
::
any_of
(
match
::
pointwise
(),
match
::
reduce
());
auto
reshape_pointwise_or_reduce
=
auto
reshape_pointwise_or_reduce
=
reshaper
(
skip_contiguous
(
pointwise_or_reduce
.
bind
(
"x"
))
)
.
bind
(
"reshape"
);
match_
reshaper
(
match
::
pointwise
()
.
bind
(
"x"
)).
bind
(
"reshape"
);
return
pointwise_or_reduce
(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise_or_reduce
));
return
pointwise_or_reduce
(
match
::
any_of
[
match
::
inputs
()](
reshape_pointwise_or_reduce
));
}
}
static
bool
is_broadcast
(
const
operation
&
op
)
{
return
contains
({
"broadcast"
,
"multibroadcast"
},
op
.
name
());
}
static
bool
is_broadcast
(
instruction_ref
ins
)
{
return
is_broadcast
(
ins
->
get_operator
());
}
static
bool
is_pointwise
(
instruction_ref
ins
)
static
bool
is_pointwise
(
instruction_ref
ins
)
{
{
auto
a
=
ins
->
get_operator
().
attributes
();
auto
a
=
ins
->
get_operator
().
attributes
();
...
@@ -933,7 +948,12 @@ struct find_poinwise_reduce_reshape
...
@@ -933,7 +948,12 @@ struct find_poinwise_reduce_reshape
static
bool
is_reduce
(
instruction_ref
ins
)
static
bool
is_reduce
(
instruction_ref
ins
)
{
{
auto
a
=
ins
->
get_operator
().
attributes
();
return
is_reduce
(
ins
->
get_operator
());
}
static
bool
is_reduce
(
const
operation
&
op
)
{
auto
a
=
op
.
attributes
();
return
a
.
get
(
"reduce"
,
false
);
return
a
.
get
(
"reduce"
,
false
);
}
}
...
@@ -943,27 +963,87 @@ struct find_poinwise_reduce_reshape
...
@@ -943,27 +963,87 @@ struct find_poinwise_reduce_reshape
return
a
.
get
(
"pointwise"
,
false
)
or
a
.
get
(
"reduce"
,
false
);
return
a
.
get
(
"pointwise"
,
false
)
or
a
.
get
(
"reduce"
,
false
);
}
}
static
std
::
vector
<
instruction_ref
>
topo_sort
(
instruction_ref
entry
,
const
std
::
unordered_set
<
instruction_ref
>&
inss
,
std
::
unordered_set
<
instruction_ref
>&
aux
)
{
std
::
vector
<
instruction_ref
>
instructions
;
bool
has_entry
=
contains
(
inss
,
entry
);
fix
([
&
](
auto
self
,
instruction_ref
ins
)
{
if
(
ins
!=
entry
or
has_entry
)
instructions
.
push_back
(
ins
);
for
(
auto
input
:
ins
->
inputs
())
{
if
(
not
contains
(
inss
,
input
))
aux
.
insert
(
input
);
}
for
(
auto
output
:
ins
->
outputs
())
{
if
(
contains
(
instructions
,
output
))
continue
;
if
(
not
contains
(
inss
,
output
))
continue
;
self
(
output
);
}
})(
entry
);
assert
(
instructions
.
size
()
==
inss
.
size
());
return
instructions
;
}
static
std
::
vector
<
instruction_ref
>
topo_sort
(
const
std
::
unordered_set
<
instruction_ref
>&
inss
,
std
::
unordered_set
<
instruction_ref
>&
aux
)
{
std
::
vector
<
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
visited
;
for
(
auto
ins
:
inss
)
{
fix
([
&
](
auto
self
,
instruction_ref
child
)
{
if
(
contains
(
visited
,
child
))
return
;
for
(
auto
output
:
child
->
outputs
())
{
if
(
not
contains
(
inss
,
output
))
continue
;
self
(
output
);
}
visited
.
insert
(
child
);
for
(
auto
input
:
child
->
inputs
())
{
if
(
not
contains
(
inss
,
input
))
aux
.
insert
(
input
);
}
instructions
.
push_back
(
child
);
})(
ins
);
}
std
::
reverse
(
instructions
.
begin
(),
instructions
.
end
());
assert
(
instructions
.
size
()
==
inss
.
size
());
return
instructions
;
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
// std::cout << "find_poinwise_reduce_reshape" << std::endl;
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
reshape_ins
=
r
.
instructions
[
"reshape"
];
auto
nelements
=
x_ins
->
get_shape
().
elements
();
auto
dims1
=
x_ins
->
get_shape
().
lens
();
auto
dims1
=
x_ins
->
get_shape
().
lens
();
auto
dims2
=
reshape_ins
->
get_shape
().
lens
();
auto
dims2
=
reshape_ins
->
get_shape
().
lens
();
std
::
vector
<
int64_t
>
axes
;
auto
cd
=
common_dims
::
compute
(
dims1
,
dims2
)
;
if
(
x_ins
->
get_operator
().
attributes
().
get
(
"reduce"
,
false
))
if
(
cd
.
empty
(
))
{
return
;
axes
=
x_ins
->
get_operator
().
to_value
()[
"axes"
].
to_vector
<
int64_t
>
();
}
// m.debug_print();
std
::
unordered_set
<
instruction_ref
>
ins
s
;
// m.debug_print(reshape_
ins
)
;
instruction_ref
entry
;
// m.debug_print(ins)
;
// Collect from inputs
// Collect from inputs
std
::
unordered_set
<
instruction_ref
>
input_inss
;
instruction_ref
entry
;
fix
([
&
](
auto
self
,
instruction_ref
i
)
{
fix
([
&
](
auto
self
,
instruction_ref
i
)
{
inss
.
insert
(
i
);
if
(
contains
(
input_inss
,
i
))
return
;
input_inss
.
insert
(
i
);
entry
=
i
;
entry
=
i
;
auto
pointwise_or_reduce
=
[
&
](
instruction_ref
input
)
{
auto
pointwise_or_reduce
=
[](
instruction_ref
input
)
{
if
(
input
->
can_eval
())
if
(
input
->
can_eval
())
return
false
;
return
false
;
return
is_pointwise
(
input
);
return
is_pointwise
(
input
);
...
@@ -977,37 +1057,97 @@ struct find_poinwise_reduce_reshape
...
@@ -977,37 +1057,97 @@ struct find_poinwise_reduce_reshape
return
;
return
;
self
(
*
it
);
self
(
*
it
);
})(
x_ins
);
})(
x_ins
);
std
::
vector
<
int64_t
>
axes
;
auto
dom
=
compute_post_dominator
(
m
);
std
::
unordered_set
<
instruction_ref
>
output_inss
;
// Collect from output
// Collect from output
fix
([
&
](
auto
self
,
instruction_ref
out
)
{
fix
([
&
](
auto
self
,
instruction_ref
out
)
{
for
(
auto
output
:
out
->
outputs
())
// if(contains(inss, out))
// return;
// std::cout << "Visit: ";
// m.debug_print(out);
// m.debug_print(out->inputs());
auto
outputs
=
out
->
outputs
();
std
::
sort
(
outputs
.
begin
(),
outputs
.
end
(),
by
(
std
::
less
<>
{},
[
&
](
instruction_ref
i
)
{
return
std
::
distance
(
reshape_ins
,
i
);
}));
// m.debug_print(outputs);
for
(
auto
output
:
outputs
)
{
{
if
(
not
std
::
all_of
(
if
(
not
std
::
all_of
(
output
->
inputs
().
begin
(),
output
->
inputs
().
end
(),
[
&
](
auto
input
)
{
output
->
inputs
().
begin
(),
output
->
inputs
().
end
(),
[
&
](
auto
input
)
{
return
input
->
can_eval
()
or
contains
(
inss
,
input
);
return
input
->
can_eval
()
or
reshape_ins
==
input
or
contains
(
output_
inss
,
input
);
// or dom.strictly_dominate(reshape_ins, input);
}))
}))
continue
;
continue
;
if
(
not
is_pointwise_or_reduce
(
ins
))
if
(
not
is_pointwise_or_reduce
(
output
)
and
not
is_broadcast
(
output
))
continue
;
continue
;
inss
.
insert
(
output
);
if
(
is_reduce
(
output
))
self
(
output
);
}
})(
x_ins
);
std
::
vector
<
instruction_ref
>
instructions
;
std
::
unordered_set
<
instruction_ref
>
aux
;
// Topological sort
fix
([
&
](
auto
self
,
instruction_ref
i
)
{
instructions
.
push_back
(
i
);
for
(
auto
output
:
i
->
outputs
())
{
if
(
not
contains
(
inss
,
output
))
{
{
aux
.
insert
(
output
);
auto
op_axes
=
output
->
get_operator
().
to_value
()[
"axes"
].
to_vector
<
int64_t
>
();
continue
;
if
(
axes
.
empty
())
axes
=
op_axes
;
if
(
axes
!=
op_axes
)
return
;
}
}
output_inss
.
insert
(
output
);
self
(
output
);
self
(
output
);
}
}
})(
entry
);
})(
reshape_ins
);
std
::
vector
<
int64_t
>
common_axes
;
for
(
auto
axis
:
axes
)
{
common_axes
.
insert
(
common_axes
.
end
(),
cd
.
axes_map2
[
axis
].
begin
(),
cd
.
axes_map2
[
axis
].
end
());
}
auto
common_rdims
=
cd
.
dims
;
for
(
auto
axis
:
common_axes
)
{
common_rdims
[
axis
]
=
1
;
}
// Topological sort
std
::
unordered_set
<
instruction_ref
>
aux
;
auto
input_instructions
=
topo_sort
(
input_inss
,
aux
);
auto
output_instructions
=
topo_sort
(
output_inss
,
aux
);
// std::cout << "output_inss:\n";
// m.debug_print({output_inss.begin(), output_inss.end()});
// std::cout << "Output instructions:\n";
// m.debug_print(output_instructions);
// std::cout << "aux:\n";
// m.debug_print({aux.begin(), aux.end()});
auto
last
=
output_instructions
.
back
();
auto
insert_reshape
=
[
&
](
instruction_ref
input
)
{
auto
use_rdims
=
input
->
get_shape
().
elements
()
<
nelements
;
auto
c
=
m
.
insert_instruction
(
last
,
make_op
(
"contiguous"
),
input
);
return
m
.
insert_instruction
(
last
,
make_op
(
"reshape"
,
{{
"dims"
,
use_rdims
?
common_rdims
:
cd
.
dims
}}),
c
);
};
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
map_ins
;
// map_ins[entry] = insert_reshape(entry);
for
(
auto
i
:
aux
)
{
map_ins
[
i
]
=
insert_reshape
(
i
);
}
auto
inserter
=
[
&
](
module
&
mm
,
instruction_ref
i
,
operation
op
,
const
std
::
vector
<
instruction_ref
>&
args
,
const
std
::
vector
<
module_ref
>&
module_args
)
{
if
(
is_reduce
(
op
))
op
.
from_value
({{
"axes"
,
common_axes
}});
if
(
is_broadcast
(
op
))
op
.
from_value
({{
"out_lens"
,
cd
.
dims
}});
// std::cout << op << std::endl;
// m.debug_print(args);
return
mm
.
insert_instruction
(
i
,
op
,
args
,
module_args
);
};
auto
new_x_ins
=
m
.
insert_instructions
(
inserter
,
last
,
input_instructions
,
map_ins
).
front
();
map_ins
[
reshape_ins
]
=
new_x_ins
;
auto
new_last
=
m
.
insert_instructions
(
inserter
,
last
,
output_instructions
,
map_ins
).
front
();
auto
new_c
=
m
.
insert_instruction
(
last
,
make_op
(
"contiguous"
),
new_last
);
auto
new_reshape
=
m
.
insert_instruction
(
last
,
make_op
(
"reshape"
,
{{
"dims"
,
dims2
}}),
new_c
);
m
.
debug_print
();
m
.
debug_print
(
last
);
m
.
debug_print
(
new_reshape
);
m
.
replace_instruction
(
last
,
new_reshape
);
std
::
abort
();
}
}
};
};
...
@@ -1020,7 +1160,7 @@ void simplify_reshapes::apply(module& m) const
...
@@ -1020,7 +1160,7 @@ void simplify_reshapes::apply(module& m) const
find_resize
{},
find_resize
{},
find_nop_reshapes
{},
find_nop_reshapes
{},
find_reshaper
{},
find_reshaper
{},
find_broadcast_reshaper
{},
//
find_broadcast_reshaper{},
// find_reshape_cont{},
// find_reshape_cont{},
find_transpose
{},
find_transpose
{},
find_concat_transpose
{},
find_concat_transpose
{},
...
@@ -1032,7 +1172,8 @@ void simplify_reshapes::apply(module& m) const
...
@@ -1032,7 +1172,8 @@ void simplify_reshapes::apply(module& m) const
find_slice_transpose
{},
find_slice_transpose
{},
find_transpose_contiguous_reshaper_unary
{},
find_transpose_contiguous_reshaper_unary
{},
find_mul_add_transpose_contiguous_reshaper_gemm
{},
find_mul_add_transpose_contiguous_reshaper_gemm
{},
find_reshape_gemm
{});
find_reshape_gemm
{},
find_poinwise_reduce_reshape
{});
dead_code_elimination
{}.
apply
(
m
);
dead_code_elimination
{}.
apply
(
m
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment