Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d9a5acbd
Unverified
Commit
d9a5acbd
authored
May 17, 2022
by
Paul Fultz II
Committed by
GitHub
May 17, 2022
Browse files
Merge branch 'develop' into jit-vector-reduce
parents
d0b7fc9a
a27dd28c
Changes
60
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
598 additions
and
516 deletions
+598
-516
src/include/migraphx/make_op.hpp
src/include/migraphx/make_op.hpp
+13
-1
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+61
-3
src/include/migraphx/memory_coloring.hpp
src/include/migraphx/memory_coloring.hpp
+1
-1
src/include/migraphx/propagate_constant.hpp
src/include/migraphx/propagate_constant.hpp
+1
-1
src/include/migraphx/rewrite_batchnorm.hpp
src/include/migraphx/rewrite_batchnorm.hpp
+1
-1
src/include/migraphx/rewrite_pooling.hpp
src/include/migraphx/rewrite_pooling.hpp
+1
-1
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+12
-13
src/include/migraphx/schedule.hpp
src/include/migraphx/schedule.hpp
+1
-1
src/include/migraphx/simplify_algebra.hpp
src/include/migraphx/simplify_algebra.hpp
+1
-1
src/include/migraphx/simplify_reshapes.hpp
src/include/migraphx/simplify_reshapes.hpp
+1
-1
src/make_op.cpp
src/make_op.cpp
+28
-7
src/opt/memory_coloring.cpp
src/opt/memory_coloring.cpp
+2
-2
src/propagate_constant.cpp
src/propagate_constant.cpp
+4
-4
src/rewrite_batchnorm.cpp
src/rewrite_batchnorm.cpp
+9
-9
src/rewrite_pooling.cpp
src/rewrite_pooling.cpp
+6
-7
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+303
-310
src/schedule.cpp
src/schedule.cpp
+37
-37
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+78
-78
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+1
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+37
-37
No files found.
src/include/migraphx/make_op.hpp
View file @
d9a5acbd
...
@@ -9,7 +9,19 @@ namespace migraphx {
...
@@ -9,7 +9,19 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
);
operation
make_op
(
const
std
::
string
&
name
);
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
);
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
);
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
);
// A template overload is added for migraphx::value so the initializer_list
// cannot be passed in directly. This is to enforce at compile-time that all
// initializer_list are key-value pairs, whereas migraphx::value allows other
// types of initializer_list such as for arrays.
template
<
class
Value
>
operation
make_op
(
const
std
::
string
&
name
,
const
Value
&
v
)
{
return
make_op_from_value
(
name
,
v
);
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/matcher.hpp
View file @
d9a5acbd
...
@@ -156,6 +156,19 @@ struct id_matcher
...
@@ -156,6 +156,19 @@ struct id_matcher
}
}
};
};
// Forward declare class and constructors
template
<
class
M
>
struct
basic_matcher
;
template
<
class
M
>
basic_matcher
<
M
>
make_basic_matcher
(
M
m
);
template
<
class
F
>
basic_matcher
<
function_matcher
<
F
>>
make_basic_fun_matcher
(
F
f
);
template
<
class
P
>
basic_matcher
<
predicate_matcher
<
P
>>
make_basic_pred_matcher
(
P
p
);
/// The basic matcher provides the all_of composability of the matcher
/// The basic matcher provides the all_of composability of the matcher
template
<
class
M
>
template
<
class
M
>
struct
basic_matcher
struct
basic_matcher
...
@@ -167,7 +180,7 @@ struct basic_matcher
...
@@ -167,7 +180,7 @@ struct basic_matcher
{
{
// Copy m because we cant capture `this` by value
// Copy m because we cant capture `this` by value
auto
mm
=
m
;
auto
mm
=
m
;
return
make_b
f
_matcher
([
=
](
matcher_context
&
ctx
,
return
make_b
asic_fun
_matcher
([
=
](
matcher_context
&
ctx
,
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
instruction_ref
ins
)
->
optional
<
instruction_ref
>
{
auto
result
=
mm
.
match
(
ctx
,
ins
);
auto
result
=
mm
.
match
(
ctx
,
ins
);
if
(
result
)
if
(
result
)
...
@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
...
@@ -239,7 +252,39 @@ struct any_matcher : any_matcher_base
struct
matcher_result
struct
matcher_result
{
{
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
instructions
;
struct
instruction_container
{
instruction_container
()
=
default
;
instruction_container
(
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
x
)
:
ins_map
(
std
::
move
(
x
))
{
}
instruction_ref
operator
[](
const
std
::
string
&
name
)
const
{
auto
it
=
ins_map
.
find
(
name
);
if
(
it
==
ins_map
.
end
())
MIGRAPHX_THROW
(
"Accessing name that wasn't bound in matcher: "
+
name
);
return
it
->
second
;
}
auto
find
(
const
std
::
string
&
name
)
const
{
return
ins_map
.
find
(
name
);
}
auto
begin
()
const
{
return
ins_map
.
cbegin
();
}
auto
end
()
const
{
return
ins_map
.
cend
();
}
bool
has_instructions_in
(
const
module
&
mod
)
const
{
return
std
::
all_of
(
ins_map
.
begin
(),
ins_map
.
end
(),
[
&
](
auto
&&
p
)
{
return
mod
.
has_instruction
(
p
.
second
);
});
}
private:
std
::
unordered_map
<
std
::
string
,
instruction_ref
>
ins_map
;
};
instruction_container
instructions
;
instruction_ref
result
;
instruction_ref
result
;
};
};
...
@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
...
@@ -255,6 +300,7 @@ matcher_result match_instruction(module& mod, instruction_ref ins, M&& m)
{
{
result
.
result
=
ins
;
result
.
result
=
ins
;
result
.
instructions
=
ctx
.
instructions
;
result
.
instructions
=
ctx
.
instructions
;
assert
(
result
.
instructions
.
has_instructions_in
(
mod
));
}
}
else
else
{
{
...
@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
...
@@ -533,6 +579,18 @@ auto skip_output(Ms... ms)
});
});
}
}
inline
auto
var
(
std
::
string
s
)
{
return
make_basic_fun_matcher
(
[
=
,
s
=
std
::
move
(
s
)](
const
matcher_context
&
ctx
,
instruction_ref
)
->
optional
<
instruction_ref
>
{
auto
it
=
ctx
.
instructions
.
find
(
s
);
if
(
it
==
ctx
.
instructions
.
end
())
return
nullopt
;
return
it
->
second
;
});
}
inline
auto
name
(
std
::
string
s
)
inline
auto
name
(
std
::
string
s
)
{
{
return
make_basic_pred_matcher
(
return
make_basic_pred_matcher
(
...
...
src/include/migraphx/memory_coloring.hpp
View file @
d9a5acbd
...
@@ -17,7 +17,7 @@ struct memory_coloring
...
@@ -17,7 +17,7 @@ struct memory_coloring
std
::
string
allocation_op
{};
std
::
string
allocation_op
{};
bool
verify
=
false
;
bool
verify
=
false
;
std
::
string
name
()
const
{
return
"memory coloring"
;
}
std
::
string
name
()
const
{
return
"memory coloring"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/propagate_constant.hpp
View file @
d9a5acbd
...
@@ -15,7 +15,7 @@ struct module;
...
@@ -15,7 +15,7 @@ struct module;
struct
propagate_constant
struct
propagate_constant
{
{
std
::
string
name
()
const
{
return
"propagate_constant"
;
}
std
::
string
name
()
const
{
return
"propagate_constant"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_batchnorm.hpp
View file @
d9a5acbd
...
@@ -16,7 +16,7 @@ struct module;
...
@@ -16,7 +16,7 @@ struct module;
struct
rewrite_batchnorm
struct
rewrite_batchnorm
{
{
std
::
string
name
()
const
{
return
"rewrite_batchnorm"
;
}
std
::
string
name
()
const
{
return
"rewrite_batchnorm"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_pooling.hpp
View file @
d9a5acbd
...
@@ -15,7 +15,7 @@ struct module;
...
@@ -15,7 +15,7 @@ struct module;
struct
rewrite_pooling
struct
rewrite_pooling
{
{
std
::
string
name
()
const
{
return
"rewrite_pooling"
;
}
std
::
string
name
()
const
{
return
"rewrite_pooling"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
d9a5acbd
...
@@ -19,22 +19,22 @@ struct module;
...
@@ -19,22 +19,22 @@ struct module;
struct
rewrite_rnn
struct
rewrite_rnn
{
{
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
std
::
string
name
()
const
{
return
"rewrite_rnn"
;
}
void
apply
(
module
&
prog
)
const
;
void
apply
(
module
&
m
)
const
;
private:
private:
// for vanilla rnn operators
// for vanilla rnn operators
void
apply_vanilla_rnn
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_vanilla_rnn
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
vanilla_rnn_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
std
::
vector
<
instruction_ref
>
inputs
,
operation
&
actv_func
)
const
;
operation
&
actv_func
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
std
::
vector
<
operation
>
vanilla_rnn_actv_funcs
(
instruction_ref
ins
)
const
;
// for gru operators
// for gru operators
void
apply_gru
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_gru
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
gru_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
int
linear_before_reset
,
...
@@ -44,9 +44,9 @@ struct rewrite_rnn
...
@@ -44,9 +44,9 @@ struct rewrite_rnn
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
// for lstm operators
// for lstm operators
void
apply_lstm
(
module
&
prog
,
instruction_ref
ins
)
const
;
void
apply_lstm
(
module
&
m
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
std
::
vector
<
instruction_ref
>
inputs
,
const
operation
&
actv_func1
,
const
operation
&
actv_func1
,
...
@@ -55,24 +55,23 @@ struct rewrite_rnn
...
@@ -55,24 +55,23 @@ struct rewrite_rnn
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
bool
is_variable_seq_lens
(
const
module
&
prog
,
instruction_ref
seq_lens
)
const
;
bool
is_variable_seq_lens
(
const
module
&
m
,
instruction_ref
seq_lens
)
const
;
instruction_ref
replace_last_hs_output
(
module
&
prog
,
instruction_ref
replace_last_hs_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
seq_lens
,
instruction_ref
last_hs_output
,
instruction_ref
last_hs_output
,
op
::
rnn_direction
dirct
)
const
;
op
::
rnn_direction
dirct
)
const
;
void
replace_last_cell_output
(
module
&
prog
,
void
replace_last_cell_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
seq_lens
,
instruction_ref
cell_outputs
,
instruction_ref
cell_outputs
,
instruction_ref
last_cell_output
,
instruction_ref
last_cell_output
,
op
::
rnn_direction
dirct
)
const
;
op
::
rnn_direction
dirct
)
const
;
std
::
size_t
std
::
size_t
get_seq_len
(
const
module
&
m
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
get_seq_len
(
const
module
&
prog
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
;
instruction_ref
pad_hidden_states
(
module
&
prog
,
instruction_ref
pad_hidden_states
(
module
&
m
,
instruction_ref
seq
,
instruction_ref
seq
,
instruction_ref
seq_lens
,
instruction_ref
seq_lens
,
instruction_ref
hs
)
const
;
instruction_ref
hs
)
const
;
...
...
src/include/migraphx/schedule.hpp
View file @
d9a5acbd
...
@@ -19,7 +19,7 @@ struct schedule
...
@@ -19,7 +19,7 @@ struct schedule
schedule_model
model
{};
schedule_model
model
{};
bool
enable
=
true
;
bool
enable
=
true
;
std
::
string
name
()
const
{
return
"schedule"
;
}
std
::
string
name
()
const
{
return
"schedule"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/simplify_algebra.hpp
View file @
d9a5acbd
...
@@ -15,7 +15,7 @@ struct module;
...
@@ -15,7 +15,7 @@ struct module;
struct
simplify_algebra
struct
simplify_algebra
{
{
std
::
string
name
()
const
{
return
"simplify_algebra"
;
}
std
::
string
name
()
const
{
return
"simplify_algebra"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/include/migraphx/simplify_reshapes.hpp
View file @
d9a5acbd
...
@@ -16,7 +16,7 @@ struct module;
...
@@ -16,7 +16,7 @@ struct module;
struct
simplify_reshapes
struct
simplify_reshapes
{
{
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
std
::
string
name
()
const
{
return
"simplify_reshapes"
;
}
void
apply
(
module
&
p
)
const
;
void
apply
(
module
&
m
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/make_op.cpp
100755 → 100644
View file @
d9a5acbd
...
@@ -5,20 +5,41 @@ namespace migraphx {
...
@@ -5,20 +5,41 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
operation
make_op
(
const
std
::
string
&
name
)
{
return
load_op
(
name
);
}
operation
make_op
(
const
std
::
string
&
name
)
{
return
load_op
(
name
);
}
operation
make_op
(
const
std
::
string
&
name
,
const
value
&
v
)
template
<
class
F
>
operation
make_op_generic
(
const
std
::
string
&
name
,
F
for_each
)
{
{
if
(
not
(
v
.
is_object
()
or
(
v
.
empty
()
and
v
.
is_array
())))
MIGRAPHX_THROW
(
"Value is not an object"
);
auto
op
=
load_op
(
name
);
auto
op
=
load_op
(
name
);
// Merge values
// Merge values
value
w
=
op
.
to_value
();
value
w
=
op
.
to_value
();
for
(
auto
&&
x
:
v
)
for_each
([
&
](
const
auto
&
key
,
const
auto
&
x
)
{
{
if
(
not
w
.
contains
(
key
))
w
.
at
(
x
.
get_key
())
=
x
.
without_key
();
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
}
MIGRAPHX_THROW
(
"No key '"
+
key
+
"' in "
+
name
);
w
.
at
(
key
)
=
x
;
});
op
.
from_value
(
w
);
op
.
from_value
(
w
);
return
op
;
return
op
;
}
}
operation
make_op
(
const
std
::
string
&
name
,
const
std
::
initializer_list
<
std
::
pair
<
std
::
string
,
value
>>&
v
)
{
return
make_op_generic
(
name
,
[
&
](
auto
f
)
{
for
(
auto
&&
[
key
,
x
]
:
v
)
f
(
key
,
x
);
});
}
operation
make_op_from_value
(
const
std
::
string
&
name
,
const
value
&
v
)
{
if
(
not
(
v
.
is_object
()
or
(
v
.
empty
()
and
v
.
is_array
())))
MIGRAPHX_THROW
(
"Value is not an object for make_op: "
+
name
);
return
make_op_generic
(
name
,
[
&
](
auto
f
)
{
for
(
auto
&&
x
:
v
)
f
(
x
.
get_key
(),
x
.
without_key
());
});
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
src/opt/memory_coloring.cpp
View file @
d9a5acbd
...
@@ -4,11 +4,11 @@
...
@@ -4,11 +4,11 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
memory_coloring
::
apply
(
module
&
p
)
const
void
memory_coloring
::
apply
(
module
&
m
)
const
{
{
if
(
!
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{}))
if
(
!
enabled
(
MIGRAPHX_DISABLE_MEMORY_COLORING
{}))
{
{
memory_coloring_impl
opt
(
&
p
,
allocation_op
,
verify
);
memory_coloring_impl
opt
(
&
m
,
allocation_op
,
verify
);
opt
.
run
();
opt
.
run
();
}
}
}
}
...
...
src/propagate_constant.cpp
View file @
d9a5acbd
...
@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins)
...
@@ -20,9 +20,9 @@ bool skip_propogate(instruction_ref ins)
return
false
;
return
false
;
}
}
void
propagate_constant
::
apply
(
module
&
p
)
const
void
propagate_constant
::
apply
(
module
&
m
)
const
{
{
for
(
auto
i
:
iterator_for
(
p
))
for
(
auto
i
:
iterator_for
(
m
))
{
{
if
(
i
->
name
()
!=
"@literal"
)
if
(
i
->
name
()
!=
"@literal"
)
continue
;
continue
;
...
@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const
...
@@ -42,8 +42,8 @@ void propagate_constant::apply(module& p) const
if
(
not
r
.
empty
())
if
(
not
r
.
empty
())
{
{
assert
(
r
.
get_shape
()
==
child
->
get_shape
());
assert
(
r
.
get_shape
()
==
child
->
get_shape
());
auto
l
=
p
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
auto
l
=
m
.
add_literal
(
r
.
get_shape
(),
r
.
data
());
self
(
p
.
replace_instruction
(
child
,
l
));
self
(
m
.
replace_instruction
(
child
,
l
));
}
}
}
}
})(
i
);
})(
i
);
...
...
src/rewrite_batchnorm.cpp
View file @
d9a5acbd
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
rewrite_batchnorm
::
apply
(
module
&
p
)
const
void
rewrite_batchnorm
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
if
(
ins
->
name
()
!=
"batch_norm_inference"
)
continue
;
continue
;
...
@@ -46,13 +46,13 @@ void rewrite_batchnorm::apply(module& p) const
...
@@ -46,13 +46,13 @@ void rewrite_batchnorm::apply(module& p) const
});
});
auto
broadcast
=
op
::
broadcast
{
1
,
ins
->
get_shape
().
lens
()};
auto
broadcast
=
op
::
broadcast
{
1
,
ins
->
get_shape
().
lens
()};
auto
a_ins
=
p
.
add_literal
({
a
.
get_shape
(),
a
.
data
()});
auto
a_ins
=
m
.
add_literal
({
a
.
get_shape
(),
a
.
data
()});
auto
a_broadcast
=
p
.
insert_instruction
(
ins
,
broadcast
,
a_ins
);
auto
a_broadcast
=
m
.
insert_instruction
(
ins
,
broadcast
,
a_ins
);
auto
mul
=
p
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ins
->
inputs
().
front
(),
a_broadcast
);
auto
mul
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ins
->
inputs
().
front
(),
a_broadcast
);
auto
b_ins
=
p
.
add_literal
({
b
.
get_shape
(),
b
.
data
()});
auto
b_ins
=
m
.
add_literal
({
b
.
get_shape
(),
b
.
data
()});
auto
b_broadcast
=
p
.
insert_instruction
(
ins
,
broadcast
,
b_ins
);
auto
b_broadcast
=
m
.
insert_instruction
(
ins
,
broadcast
,
b_ins
);
auto
add
=
p
.
insert_instruction
(
ins
,
make_op
(
"add"
),
mul
,
b_broadcast
);
auto
add
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
mul
,
b_broadcast
);
p
.
replace_instruction
(
ins
,
add
);
m
.
replace_instruction
(
ins
,
add
);
}
}
}
}
...
...
src/rewrite_pooling.cpp
View file @
d9a5acbd
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
rewrite_pooling
::
apply
(
module
&
prog
)
const
void
rewrite_pooling
::
apply
(
module
&
m
)
const
{
{
for
(
auto
ins
:
iterator_for
(
prog
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
if
(
ins
->
name
()
!=
"pooling"
)
if
(
ins
->
name
()
!=
"pooling"
)
continue
;
continue
;
...
@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const
...
@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const
continue
;
continue
;
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
auto
reshape
=
prog
.
insert_instruction
(
auto
reshape
=
m
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
{
n
*
c
,
-
1
}}}),
ins
->
inputs
().
front
());
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
{
n
*
c
,
-
1
}}}),
ins
->
inputs
().
front
());
instruction_ref
pooling
{};
instruction_ref
pooling
{};
// average pooling
// average pooling
if
(
op
.
mode
==
op
::
pooling_mode
::
average
)
if
(
op
.
mode
==
op
::
pooling_mode
::
average
)
{
{
pooling
=
pooling
=
m
.
insert_instruction
(
ins
,
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
1
}}}),
reshape
);
prog
.
insert_instruction
(
ins
,
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
1
}}}),
reshape
);
}
}
// max pooling
// max pooling
else
else
{
{
pooling
=
prog
.
insert_instruction
(
ins
,
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}),
reshape
);
pooling
=
m
.
insert_instruction
(
ins
,
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}),
reshape
);
}
}
std
::
vector
<
int64_t
>
rsp_lens
(
lens
.
size
(),
1
);
std
::
vector
<
int64_t
>
rsp_lens
(
lens
.
size
(),
1
);
rsp_lens
[
0
]
=
n
;
rsp_lens
[
0
]
=
n
;
rsp_lens
[
1
]
=
c
;
rsp_lens
[
1
]
=
c
;
prog
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
pooling
);
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
pooling
);
}
}
}
}
...
...
src/rewrite_rnn.cpp
View file @
d9a5acbd
This diff is collapsed.
Click to expand it.
src/schedule.cpp
View file @
d9a5acbd
...
@@ -42,7 +42,7 @@ struct stream_info
...
@@ -42,7 +42,7 @@ struct stream_info
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
iweights
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
iweights
;
ins_dep_map
mod_implicit_deps
;
ins_dep_map
mod_implicit_deps
;
void
calc_implicit_deps
(
const
module
&
p
)
{
mod_implicit_deps
=
p
.
calc_implicit_deps
();
}
void
calc_implicit_deps
(
const
module
&
m
)
{
mod_implicit_deps
=
m
.
calc_implicit_deps
();
}
void
accumulate_weights
(
instruction_ref
last
,
const
schedule_model
&
model
)
void
accumulate_weights
(
instruction_ref
last
,
const
schedule_model
&
model
)
{
{
...
@@ -116,15 +116,15 @@ struct stream_info
...
@@ -116,15 +116,15 @@ struct stream_info
}
}
};
};
std
::
size_t
assign_streams
(
module
&
p
,
std
::
size_t
n
)
std
::
size_t
assign_streams
(
module
&
m
,
std
::
size_t
n
)
{
{
assert
(
n
>
0
);
assert
(
n
>
0
);
partition
critical
;
partition
critical
;
std
::
unordered_map
<
instruction_ref
,
std
::
deque
<
partition
>>
partitions
;
std
::
unordered_map
<
instruction_ref
,
std
::
deque
<
partition
>>
partitions
;
partitions
.
reserve
(
weights
.
size
());
partitions
.
reserve
(
weights
.
size
());
fix
([
&
](
auto
self
,
auto
ins
,
auto
&
part
)
{
fix
([
&
](
auto
self
,
auto
ins
,
auto
&
part
)
{
assert
(
not
is_end
(
ins
,
p
.
end
()));
assert
(
not
is_end
(
ins
,
m
.
end
()));
if
(
not
p
.
has_instruction
(
ins
))
if
(
not
m
.
has_instruction
(
ins
))
return
;
return
;
if
(
contains
(
partitions
,
ins
))
if
(
contains
(
partitions
,
ins
))
return
;
return
;
...
@@ -151,8 +151,8 @@ struct stream_info
...
@@ -151,8 +151,8 @@ struct stream_info
}
}
}
}
// Sort instructions
// Sort instructions
p
.
move_instruction
(
ins
,
p
.
end
());
m
.
move_instruction
(
ins
,
m
.
end
());
})(
std
::
prev
(
p
.
end
()),
critical
);
})(
std
::
prev
(
m
.
end
()),
critical
);
// Set the critical partition to stream 0
// Set the critical partition to stream 0
set_stream
(
critical
,
0
);
set_stream
(
critical
,
0
);
...
@@ -197,13 +197,13 @@ struct stream_info
...
@@ -197,13 +197,13 @@ struct stream_info
}
}
};
};
void
sort
(
module
&
p
,
std
::
size_t
)
void
sort
(
module
&
m
,
std
::
size_t
)
{
{
std
::
set
<
weight_ins
,
compare_weight_ins
>
children
;
std
::
set
<
weight_ins
,
compare_weight_ins
>
children
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
visited
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
visited
;
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
auto
mw
=
this
->
weights
.
at
(
last
);
auto
mw
=
this
->
weights
.
at
(
last
);
auto
nw
=
mw
/
(
p
.
size
()
+
1
);
auto
nw
=
mw
/
(
m
.
size
()
+
1
);
auto
add_child
=
[
&
](
auto
ins
)
{
auto
add_child
=
[
&
](
auto
ins
)
{
auto
x
=
1
+
(
mw
-
this
->
weights
.
at
(
ins
))
/
(
nw
+
1
);
auto
x
=
1
+
(
mw
-
this
->
weights
.
at
(
ins
))
/
(
nw
+
1
);
auto
w
=
x
*
this
->
iweights
.
at
(
ins
);
auto
w
=
x
*
this
->
iweights
.
at
(
ins
);
...
@@ -222,10 +222,10 @@ struct stream_info
...
@@ -222,10 +222,10 @@ struct stream_info
// Pop the first element
// Pop the first element
auto
top
=
children
.
begin
()
->
second
;
auto
top
=
children
.
begin
()
->
second
;
children
.
erase
(
children
.
begin
());
children
.
erase
(
children
.
begin
());
p
.
move_instruction
(
top
,
p
.
begin
());
m
.
move_instruction
(
top
,
m
.
begin
());
for
(
auto
ins
:
top
->
inputs
())
for
(
auto
ins
:
top
->
inputs
())
{
{
if
(
not
p
.
has_instruction
(
ins
))
if
(
not
m
.
has_instruction
(
ins
))
continue
;
continue
;
add_child
(
ins
);
add_child
(
ins
);
}
}
...
@@ -234,7 +234,7 @@ struct stream_info
...
@@ -234,7 +234,7 @@ struct stream_info
{
{
for
(
auto
ins
:
mod_implicit_deps
.
at
(
top
))
for
(
auto
ins
:
mod_implicit_deps
.
at
(
top
))
{
{
assert
(
p
.
has_instruction
(
ins
));
assert
(
m
.
has_instruction
(
ins
));
add_child
(
ins
);
add_child
(
ins
);
}
}
}
}
...
@@ -242,12 +242,12 @@ struct stream_info
...
@@ -242,12 +242,12 @@ struct stream_info
// move dangling parameter to the front so as not be removed
// move dangling parameter to the front so as not be removed
auto
ins
=
std
::
next
(
last
);
auto
ins
=
std
::
next
(
last
);
while
(
ins
!=
p
.
end
())
while
(
ins
!=
m
.
end
())
{
{
auto
next
=
std
::
next
(
ins
);
auto
next
=
std
::
next
(
ins
);
if
(
ins
->
name
()
==
"@param"
)
if
(
ins
->
name
()
==
"@param"
)
{
{
p
.
move_instruction
(
ins
,
p
.
begin
());
m
.
move_instruction
(
ins
,
m
.
begin
());
}
}
ins
=
next
;
ins
=
next
;
}
}
...
@@ -364,18 +364,18 @@ struct stream_info
...
@@ -364,18 +364,18 @@ struct stream_info
}
}
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
find_concurrent_instructions
(
module
&
p
)
const
find_concurrent_instructions
(
module
&
m
)
const
{
{
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
merge_from
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
merge_from
;
dominator_info
di
=
compute_dominator
(
p
);
dominator_info
di
=
compute_dominator
(
m
);
result
.
reserve
(
p
.
size
());
result
.
reserve
(
m
.
size
());
merge_from
.
reserve
(
p
.
size
());
merge_from
.
reserve
(
m
.
size
());
for
(
auto
ins
:
reverse_iterator_for
(
p
))
for
(
auto
ins
:
reverse_iterator_for
(
m
))
{
{
for
(
auto
&&
arg
:
ins
->
outputs
())
for
(
auto
&&
arg
:
ins
->
outputs
())
{
{
if
(
not
p
.
has_instruction
(
arg
))
if
(
not
m
.
has_instruction
(
arg
))
continue
;
continue
;
if
(
is_merge_point
(
arg
))
if
(
is_merge_point
(
arg
))
merge_from
[
ins
].
insert
(
arg
);
merge_from
[
ins
].
insert
(
arg
);
...
@@ -415,18 +415,18 @@ struct stream_info
...
@@ -415,18 +415,18 @@ struct stream_info
}
}
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
get_conflicts
(
module
&
p
)
get_conflicts
(
module
&
m
)
{
{
using
conflict_table_type
=
using
conflict_table_type
=
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
;
conflict_table_type
conflict_table
;
conflict_table_type
conflict_table
;
auto
concur_ins
=
this
->
find_concurrent_instructions
(
p
);
auto
concur_ins
=
this
->
find_concurrent_instructions
(
m
);
// Compute an index for each instruction
// Compute an index for each instruction
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2index
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2index
;
std
::
size_t
index_total
=
0
;
std
::
size_t
index_total
=
0
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
ins2index
[
ins
]
=
index_total
++
;
ins2index
[
ins
]
=
index_total
++
;
std
::
vector
<
conflict_table_type
>
thread_conflict_tables
(
std
::
vector
<
conflict_table_type
>
thread_conflict_tables
(
...
@@ -507,21 +507,21 @@ struct stream_info
...
@@ -507,21 +507,21 @@ struct stream_info
}
}
};
};
void
schedule
::
apply
(
module
&
p
)
const
void
schedule
::
apply
(
module
&
m
)
const
{
{
if
(
not
enable
)
if
(
not
enable
)
return
;
return
;
stream_info
si
;
stream_info
si
;
si
.
calc_implicit_deps
(
p
);
si
.
calc_implicit_deps
(
m
);
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
si
.
accumulate_weights
(
last
,
model
);
si
.
accumulate_weights
(
last
,
model
);
auto
nstreams
=
si
.
assign_streams
(
p
,
model
.
concurrency
());
auto
nstreams
=
si
.
assign_streams
(
m
,
model
.
concurrency
());
si
.
sort
(
p
,
model
.
concurrency
());
si
.
sort
(
m
,
model
.
concurrency
());
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{})
or
enabled
(
MIGRAPHX_TRACE_SCHEDULE
{}))
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{})
or
enabled
(
MIGRAPHX_TRACE_SCHEDULE
{}))
{
{
p
.
annotate
(
std
::
cout
,
[
&
](
auto
ins
)
{
m
.
annotate
(
std
::
cout
,
[
&
](
auto
ins
)
{
if
(
ins
->
name
()
==
"@param"
and
not
contains
(
si
.
weights
,
ins
))
if
(
ins
->
name
()
==
"@param"
and
not
contains
(
si
.
weights
,
ins
))
return
;
return
;
...
@@ -548,9 +548,9 @@ void schedule::apply(module& p) const
...
@@ -548,9 +548,9 @@ void schedule::apply(module& p) const
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2wait
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2wait
;
std
::
unordered_map
<
std
::
size_t
,
std
::
unordered_set
<
std
::
size_t
>>
waited_for
;
std
::
unordered_map
<
std
::
size_t
,
std
::
unordered_set
<
std
::
size_t
>>
waited_for
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
std
::
size_t
>>
ins2waited
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
std
::
size_t
>>
ins2waited
;
ins2wait
.
reserve
(
p
.
size
());
ins2wait
.
reserve
(
m
.
size
());
ins2waited
.
reserve
(
p
.
size
());
ins2waited
.
reserve
(
m
.
size
());
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
{
// Only schedule instructions that have a stream
// Only schedule instructions that have a stream
if
(
not
si
.
has_stream
(
ins
))
if
(
not
si
.
has_stream
(
ins
))
...
@@ -559,7 +559,7 @@ void schedule::apply(module& p) const
...
@@ -559,7 +559,7 @@ void schedule::apply(module& p) const
// Schedule instruction on the stream
// Schedule instruction on the stream
auto
stream
=
si
.
get_stream
(
ins
);
auto
stream
=
si
.
get_stream
(
ins
);
assert
(
stream
<
model
.
concurrency
());
assert
(
stream
<
model
.
concurrency
());
model
.
sched
(
p
,
ins
,
stream
);
model
.
sched
(
m
,
ins
,
stream
);
// Insert wait instructions
// Insert wait instructions
if
(
si
.
is_merge_point
(
ins
,
stream
))
if
(
si
.
is_merge_point
(
ins
,
stream
))
{
{
...
@@ -572,14 +572,14 @@ void schedule::apply(module& p) const
...
@@ -572,14 +572,14 @@ void schedule::apply(module& p) const
if
(
not
contains
(
ins2wait
,
i
))
if
(
not
contains
(
ins2wait
,
i
))
{
{
ins2wait
[
i
]
=
wait_id
;
ins2wait
[
i
]
=
wait_id
;
model
.
record
(
p
,
i
,
wait_id
);
model
.
record
(
m
,
i
,
wait_id
);
wait_id
++
;
wait_id
++
;
}
}
auto
w
=
ins2wait
.
at
(
i
);
auto
w
=
ins2wait
.
at
(
i
);
// If we already waited for the event on this stream then dont
// If we already waited for the event on this stream then dont
// insert another wait event
// insert another wait event
if
(
not
contains
(
waited_for
[
stream
],
w
))
if
(
not
contains
(
waited_for
[
stream
],
w
))
model
.
wait
(
p
,
ins
,
w
);
model
.
wait
(
m
,
ins
,
w
);
// Store the event as waited
// Store the event as waited
waited_for
[
stream
].
insert
(
w
);
waited_for
[
stream
].
insert
(
w
);
// Store all wait events that have been waited on prior to the recorded instruction
// Store all wait events that have been waited on prior to the recorded instruction
...
@@ -594,7 +594,7 @@ void schedule::apply(module& p) const
...
@@ -594,7 +594,7 @@ void schedule::apply(module& p) const
}
}
// Add memory conflicts
// Add memory conflicts
auto
conflict_table
=
si
.
get_conflicts
(
p
);
auto
conflict_table
=
si
.
get_conflicts
(
m
);
for
(
auto
&&
ip
:
conflict_table
)
for
(
auto
&&
ip
:
conflict_table
)
{
{
if
(
ip
.
second
.
empty
())
if
(
ip
.
second
.
empty
())
...
@@ -602,7 +602,7 @@ void schedule::apply(module& p) const
...
@@ -602,7 +602,7 @@ void schedule::apply(module& p) const
std
::
vector
<
instruction_ref
>
args
;
std
::
vector
<
instruction_ref
>
args
;
args
.
push_back
(
ip
.
first
);
args
.
push_back
(
ip
.
first
);
args
.
insert
(
args
.
end
(),
ip
.
second
.
begin
(),
ip
.
second
.
end
());
args
.
insert
(
args
.
end
(),
ip
.
second
.
begin
(),
ip
.
second
.
end
());
p
.
insert_instruction
(
std
::
next
(
ip
.
first
),
make_op
(
"identity"
),
args
);
m
.
insert_instruction
(
std
::
next
(
ip
.
first
),
make_op
(
"identity"
),
args
);
}
}
}
}
...
...
src/simplify_algebra.cpp
View file @
d9a5acbd
This diff is collapsed.
Click to expand it.
src/simplify_qdq.cpp
View file @
d9a5acbd
...
@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
...
@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
match
::
arg
(
1
)(
dequantizelinear_op
(
"x2"
,
"scale2"
)));
match
::
arg
(
1
)(
dequantizelinear_op
(
"x2"
,
"scale2"
)));
}
}
void
apply
(
module
&
m
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
qop
=
r
.
result
;
auto
qop
=
r
.
result
;
auto
q1
=
r
.
instructions
[
"x1"
];
auto
q1
=
r
.
instructions
[
"x1"
];
...
...
src/simplify_reshapes.cpp
View file @
d9a5acbd
...
@@ -70,19 +70,19 @@ struct find_reshaper
...
@@ -70,19 +70,19 @@ struct find_reshaper
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
}
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
{
auto
ins
=
mr
.
result
;
auto
ins
=
mr
.
result
;
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
while
(
is_reshaper
(
reshapes
.
back
()))
{
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
assert
(
m
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
reshapes
.
push_back
(
input
);
}
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
m
.
end
(),
m
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
for
(
auto
start
:
iterator_for
(
reshapes
))
{
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
...
@@ -96,7 +96,7 @@ struct find_reshaper
...
@@ -96,7 +96,7 @@ struct find_reshaper
}
}
if
(
r
.
first
!=
r
.
second
)
if
(
r
.
first
!=
r
.
second
)
{
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
m
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
}
}
};
};
...
@@ -117,10 +117,10 @@ struct find_nop_reshapes
...
@@ -117,10 +117,10 @@ struct find_nop_reshapes
return
match
::
name
(
reshapes
)(
match
::
same_shape
(
match
::
arg
(
0
)));
return
match
::
name
(
reshapes
)(
match
::
same_shape
(
match
::
arg
(
0
)));
}
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
{
auto
ins
=
mr
.
result
;
auto
ins
=
mr
.
result
;
p
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
m
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
}
}
};
};
...
@@ -132,7 +132,7 @@ struct find_transpose
...
@@ -132,7 +132,7 @@ struct find_transpose
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
}
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
{
auto
ins
=
mr
.
result
;
auto
ins
=
mr
.
result
;
auto
x
=
ins
;
auto
x
=
ins
;
...
@@ -149,11 +149,11 @@ struct find_transpose
...
@@ -149,11 +149,11 @@ struct find_transpose
return
;
return
;
if
(
is_no_transpose
(
dims
))
if
(
is_no_transpose
(
dims
))
{
{
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
m
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
}
else
else
{
{
p
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
dims
}}),
t
->
inputs
().
front
());
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
dims
}}),
t
->
inputs
().
front
());
}
}
}
}
...
@@ -223,7 +223,7 @@ struct find_nested_slice
...
@@ -223,7 +223,7 @@ struct find_nested_slice
return
result
;
return
result
;
}
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
{
auto
ins
=
mr
.
result
;
auto
ins
=
mr
.
result
;
auto
slice
=
ins
->
inputs
().
front
();
auto
slice
=
ins
->
inputs
().
front
();
...
@@ -241,7 +241,7 @@ struct find_nested_slice
...
@@ -241,7 +241,7 @@ struct find_nested_slice
op
.
starts
.
push_back
(
pp
.
second
.
first
);
op
.
starts
.
push_back
(
pp
.
second
.
first
);
op
.
ends
.
push_back
(
pp
.
second
.
second
);
op
.
ends
.
push_back
(
pp
.
second
.
second
);
}
}
p
.
replace_instruction
(
ins
,
op
,
input
);
m
.
replace_instruction
(
ins
,
op
,
input
);
}
}
};
};
...
@@ -252,7 +252,7 @@ struct find_concat_transpose
...
@@ -252,7 +252,7 @@ struct find_concat_transpose
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
{
auto
ins
=
mr
.
result
;
auto
ins
=
mr
.
result
;
auto
trans_inputs
=
ins
->
inputs
();
auto
trans_inputs
=
ins
->
inputs
();
...
@@ -279,14 +279,14 @@ struct find_concat_transpose
...
@@ -279,14 +279,14 @@ struct find_concat_transpose
std
::
vector
<
instruction_ref
>
inputs
;
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
return
p
.
insert_instruction
(
return
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
permutation
}}),
i
);
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
permutation
}}),
i
);
});
});
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
concat
=
m
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
t
=
p
.
insert_instruction
(
auto
t
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
ipermutation
}}),
concat
);
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
ipermutation
}}),
concat
);
assert
(
ins
->
get_shape
().
lens
()
==
t
->
get_shape
().
lens
());
assert
(
ins
->
get_shape
().
lens
()
==
t
->
get_shape
().
lens
());
p
.
replace_instruction
(
ins
,
t
);
m
.
replace_instruction
(
ins
,
t
);
}
}
};
};
...
@@ -303,7 +303,7 @@ struct find_nested_concat
...
@@ -303,7 +303,7 @@ struct find_nested_concat
return
op
.
axis
;
return
op
.
axis
;
}
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
{
auto
ins
=
mr
.
result
;
auto
ins
=
mr
.
result
;
auto
axis
=
get_axis
(
ins
);
auto
axis
=
get_axis
(
ins
);
...
@@ -317,7 +317,7 @@ struct find_nested_concat
...
@@ -317,7 +317,7 @@ struct find_nested_concat
args
.
push_back
(
i
);
args
.
push_back
(
i
);
}
}
})(
ins
->
inputs
());
})(
ins
->
inputs
());
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
args
);
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
args
);
}
}
};
};
...
@@ -329,7 +329,7 @@ struct find_resize
...
@@ -329,7 +329,7 @@ struct find_resize
match
::
args
(
match
::
name
(
"reshape"
).
bind
(
"data"
),
match
::
is_constant
().
bind
(
"ind"
)));
match
::
args
(
match
::
name
(
"reshape"
).
bind
(
"data"
),
match
::
is_constant
().
bind
(
"ind"
)));
}
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
ins_rsp
=
r
.
instructions
[
"data"
];
auto
ins_rsp
=
r
.
instructions
[
"data"
];
...
@@ -417,13 +417,13 @@ struct find_resize
...
@@ -417,13 +417,13 @@ struct find_resize
}
}
auto
in_rsp
=
ins_rsp
->
inputs
().
front
();
auto
in_rsp
=
ins_rsp
->
inputs
().
front
();
auto
rsp_data
=
p
.
insert_instruction
(
auto
rsp_data
=
m
.
insert_instruction
(
ins_rsp
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
in_dims
}}),
in_rsp
);
ins_rsp
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
in_dims
}}),
in_rsp
);
auto
mb_rsp
=
p
.
insert_instruction
(
auto
mb_rsp
=
m
.
insert_instruction
(
ins_rsp
,
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_dims
}}),
rsp_data
);
ins_rsp
,
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_dims
}}),
rsp_data
);
auto
std_mb
=
p
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"contiguous"
),
mb_rsp
);
auto
std_mb
=
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"contiguous"
),
mb_rsp
);
std
::
vector
<
int64_t
>
rsp_dims
(
out_lens
.
begin
(),
out_lens
.
end
());
std
::
vector
<
int64_t
>
rsp_dims
(
out_lens
.
begin
(),
out_lens
.
end
());
p
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
rsp_dims
}}),
std_mb
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
rsp_dims
}}),
std_mb
);
}
}
};
};
...
@@ -436,7 +436,7 @@ struct find_where_op
...
@@ -436,7 +436,7 @@ struct find_where_op
match
::
is_constant
().
bind
(
"ind"
)));
match
::
is_constant
().
bind
(
"ind"
)));
}
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
concat
=
r
.
instructions
[
"data"
];
auto
concat
=
r
.
instructions
[
"data"
];
...
@@ -475,11 +475,11 @@ struct find_where_op
...
@@ -475,11 +475,11 @@ struct find_where_op
if
(
val
)
if
(
val
)
{
{
p
.
replace_instruction
(
ins
,
inputs
.
at
(
0
));
m
.
replace_instruction
(
ins
,
inputs
.
at
(
0
));
}
}
else
else
{
{
p
.
replace_instruction
(
ins
,
inputs
.
at
(
1
));
m
.
replace_instruction
(
ins
,
inputs
.
at
(
1
));
}
}
}
}
};
};
...
@@ -496,7 +496,7 @@ struct find_reshape_cont
...
@@ -496,7 +496,7 @@ struct find_reshape_cont
match
::
any
()));
match
::
any
()));
}
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
ins_cont
=
r
.
instructions
[
"cont"
];
auto
ins_cont
=
r
.
instructions
[
"cont"
];
...
@@ -530,11 +530,11 @@ struct find_reshape_cont
...
@@ -530,11 +530,11 @@ struct find_reshape_cont
else
else
{
{
inputs
.
push_back
(
inputs
.
push_back
(
p
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
in
));
m
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
in
));
}
}
}
}
auto
out
=
p
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
auto
out
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
p
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
out_dims
}}),
out
);
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
out_dims
}}),
out
);
}
}
};
};
...
@@ -564,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary
...
@@ -564,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary
match
::
args
(
match_transpose_contiguous_reshaper
()));
match
::
args
(
match_transpose_contiguous_reshaper
()));
}
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
ins
=
r
.
result
;
auto
ins
=
r
.
result
;
auto
reshaper_ins
=
r
.
instructions
[
"reshaper_ins"
];
auto
reshaper_ins
=
r
.
instructions
[
"reshaper_ins"
];
auto
trans_ins
=
r
.
instructions
[
"trans_ins"
];
auto
trans_ins
=
r
.
instructions
[
"trans_ins"
];
auto
cont_ins
=
r
.
instructions
[
"cont_ins"
];
auto
cont_ins
=
r
.
instructions
[
"cont_ins"
];
auto
unary_op_name
=
ins
->
get_operator
().
name
();
auto
unary_op_name
=
ins
->
get_operator
().
name
();
auto
unary_ins
=
p
.
insert_instruction
(
cont_ins
,
make_op
(
unary_op_name
),
trans_ins
);
auto
unary_ins
=
m
.
insert_instruction
(
cont_ins
,
make_op
(
unary_op_name
),
trans_ins
);
auto
new_cont_ins
=
p
.
insert_instruction
(
cont_ins
,
make_op
(
"contiguous"
),
unary_ins
);
auto
new_cont_ins
=
m
.
insert_instruction
(
cont_ins
,
make_op
(
"contiguous"
),
unary_ins
);
// older cont and reshape are removed by deadcode elimination
// older cont and reshape are removed by deadcode elimination
p
.
replace_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
new_cont_ins
);
m
.
replace_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
new_cont_ins
);
}
}
};
};
void
simplify_reshapes
::
apply
(
module
&
p
)
const
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
{
match
::
find_matches
(
p
,
match
::
find_matches
(
m
,
find_where_op
{},
find_where_op
{},
find_resize
{},
find_resize
{},
find_reshape_cont
{},
find_reshape_cont
{},
...
@@ -594,7 +594,7 @@ void simplify_reshapes::apply(module& p) const
...
@@ -594,7 +594,7 @@ void simplify_reshapes::apply(module& p) const
find_nested_slice
{},
find_nested_slice
{},
find_nested_concat
{},
find_nested_concat
{},
find_transpose_contiguous_reshaper_unary
{});
find_transpose_contiguous_reshaper_unary
{});
dead_code_elimination
{}.
apply
(
p
);
dead_code_elimination
{}.
apply
(
m
);
}
}
}
}
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment