Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
11e155c2
Commit
11e155c2
authored
Jun 13, 2022
by
Paul
Browse files
Merge
parents
8a9c5bce
aa7ff911
Changes
397
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
710 additions
and
744 deletions
+710
-744
src/rewrite_pooling.cpp
src/rewrite_pooling.cpp
+7
-8
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+303
-319
src/schedule.cpp
src/schedule.cpp
+37
-37
src/shape.cpp
src/shape.cpp
+11
-0
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+79
-80
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+1
-1
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+37
-38
src/targets/cpu/copy.cpp
src/targets/cpu/copy.cpp
+0
-1
src/targets/cpu/gather.cpp
src/targets/cpu/gather.cpp
+0
-1
src/targets/cpu/include/migraphx/cpu/parallel.hpp
src/targets/cpu/include/migraphx/cpu/parallel.hpp
+9
-0
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
+2
-5
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+2
-7
src/targets/cpu/pooling.cpp
src/targets/cpu/pooling.cpp
+3
-113
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+14
-4
src/targets/gpu/analyze_streams.cpp
src/targets/gpu/analyze_streams.cpp
+10
-10
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+105
-0
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+35
-7
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+39
-0
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+16
-33
src/targets/gpu/compile_pointwise.cpp
src/targets/gpu/compile_pointwise.cpp
+0
-80
No files found.
src/rewrite_pooling.cpp
View file @
11e155c2
...
...
@@ -12,9 +12,9 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
rewrite_pooling
::
apply
(
module
&
prog
)
const
void
rewrite_pooling
::
apply
(
module
&
m
)
const
{
for
(
auto
ins
:
iterator_for
(
prog
))
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
!=
"pooling"
)
continue
;
...
...
@@ -33,26 +33,25 @@ void rewrite_pooling::apply(module& prog) const
continue
;
std
::
int64_t
n
=
s
.
lens
()[
0
];
std
::
int64_t
c
=
s
.
lens
()[
1
];
auto
reshape
=
prog
.
insert_instruction
(
auto
reshape
=
m
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
{
n
*
c
,
-
1
}}}),
ins
->
inputs
().
front
());
instruction_ref
pooling
{};
// average pooling
if
(
op
.
mode
==
"
average
"
)
if
(
op
.
mode
==
op
::
pooling_mode
::
average
)
{
pooling
=
prog
.
insert_instruction
(
ins
,
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
1
}}}),
reshape
);
pooling
=
m
.
insert_instruction
(
ins
,
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
1
}}}),
reshape
);
}
// max pooling
else
{
pooling
=
prog
.
insert_instruction
(
ins
,
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}),
reshape
);
pooling
=
m
.
insert_instruction
(
ins
,
make_op
(
"reduce_max"
,
{{
"axes"
,
{
1
}}}),
reshape
);
}
std
::
vector
<
int64_t
>
rsp_lens
(
lens
.
size
(),
1
);
rsp_lens
[
0
]
=
n
;
rsp_lens
[
1
]
=
c
;
prog
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
pooling
);
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
rsp_lens
}}),
pooling
);
}
}
...
...
src/rewrite_rnn.cpp
100755 → 100644
View file @
11e155c2
...
...
@@ -30,27 +30,27 @@
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
void
rewrite_rnn
::
apply
(
module
&
prog
)
const
void
rewrite_rnn
::
apply
(
module
&
m
)
const
{
for
(
auto
ins
:
iterator_for
(
prog
))
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
==
"rnn"
)
{
apply_vanilla_rnn
(
prog
,
ins
);
apply_vanilla_rnn
(
m
,
ins
);
}
else
if
(
ins
->
name
()
==
"gru"
)
{
apply_gru
(
prog
,
ins
);
apply_gru
(
m
,
ins
);
}
else
if
(
ins
->
name
()
==
"lstm"
)
{
apply_lstm
(
prog
,
ins
);
apply_lstm
(
m
,
ins
);
}
}
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void
rewrite_rnn
::
apply_vanilla_rnn
(
module
&
prog
,
instruction_ref
ins
)
const
void
rewrite_rnn
::
apply_vanilla_rnn
(
module
&
m
,
instruction_ref
ins
)
const
{
assert
(
ins
->
name
()
==
"rnn"
);
// could be 3 to 6 inputs, but the parse_rnn function will
...
...
@@ -71,37 +71,37 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
op
::
rnn_direction
dirct
=
rnn_op
.
direction
;
// process sequence length
instruction_ref
seq_lens
=
prog
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"undefined"
)
{
seq_lens
=
args
[
4
];
}
bool
variable_seq_len
=
is_variable_seq_lens
(
prog
,
seq_lens
);
bool
variable_seq_len
=
is_variable_seq_lens
(
m
,
seq_lens
);
instruction_ref
last_output
{};
if
(
dirct
==
op
::
rnn_direction
::
bidirectional
)
{
// input weight matrix
auto
w_forward
=
prog
.
insert_instruction
(
auto
w_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
auto
w_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
1
]);
// hidden state weight matrix
auto
r_forward
=
prog
.
insert_instruction
(
auto
r_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
auto
r_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
2
]);
// process bias
instruction_ref
bias_forward
=
prog
.
end
();
instruction_ref
bias_reverse
=
prog
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias_forward
=
prog
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
bias_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
3
]);
}
...
...
@@ -111,57 +111,56 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
{
ih_forward
=
prog
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ih_reverse
=
prog
.
insert_instruction
(
ih_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
5
]);
}
else
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_forward
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret_forward
=
vanilla_rnn_cell
(
true
,
prog
,
m
,
ins
,
{
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
seq_lens
,
ih_forward
},
actv_funcs
.
at
(
0
));
if
(
variable_seq_len
)
{
args
[
0
]
=
prog
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
}
auto
ret_reverse
=
vanilla_rnn_cell
(
false
,
prog
,
m
,
ins
,
{
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
seq_lens
,
ih_reverse
},
actv_funcs
.
at
(
1
));
auto
concat_output
=
prog
.
insert_instruction
(
auto
concat_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
concat_output
);
last_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
concat_output
);
// The following logic is to ensure the last instruction rewritten from
// rnn operator is a concat instruction
// sequence len is 1
if
(
ret_forward
[
0
]
==
prog
.
end
())
if
(
ret_forward
[
0
]
==
m
.
end
())
{
prog
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
ret_forward
[
1
],
ret_reverse
[
1
]);
}
else
{
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ret_forward
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ret_reverse
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
...
...
@@ -175,7 +174,7 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
auto
r
=
args
[
2
];
// process bias and initial hidden state
instruction_ref
bias
=
prog
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias
=
args
[
3
];
...
...
@@ -189,43 +188,42 @@ void rewrite_rnn::apply_vanilla_rnn(module& prog, instruction_ref ins) const
}
else
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
if
(
!
is_forward
and
variable_seq_len
)
{
args
[
0
]
=
prog
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
}
auto
ret
=
vanilla_rnn_cell
(
is_forward
,
prog
,
ins
,
{
args
[
0
],
w
,
r
,
bias
,
seq_lens
,
ih
},
actv_funcs
.
at
(
0
));
last_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ret
[
1
]);
is_forward
,
m
,
ins
,
{
args
[
0
],
w
,
r
,
bias
,
seq_lens
,
ih
},
actv_funcs
.
at
(
0
));
last_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ret
[
1
]);
// following logic is to ensure the last instruction is a
// concat instruction
// sequence len is 1
if
(
ret
[
0
]
==
prog
.
end
())
if
(
ret
[
0
]
==
m
.
end
())
{
prog
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret
[
1
]);
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret
[
1
]);
}
else
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
prog
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
concat_arg0
,
concat_arg1
);
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
concat_arg0
,
concat_arg1
);
}
}
// in case of all sequences are of the same lengths and shorter than the
// max sequence length, need to pad 0's at the end for output hidden states
ins
=
pad_hidden_states
(
prog
,
args
[
0
],
seq_lens
,
ins
);
replace_last_hs_output
(
prog
,
ins
,
seq_lens
,
last_output
,
dirct
);
ins
=
pad_hidden_states
(
m
,
args
[
0
],
seq_lens
,
ins
);
replace_last_hs_output
(
m
,
ins
,
seq_lens
,
last_output
,
dirct
);
}
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
vanilla_rnn_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
operation
&
actv_func
)
const
...
...
@@ -240,60 +238,60 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
// squeeze and transpose w
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
sw
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
w
);
auto
tran_sw
=
prog
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sw
);
auto
sw
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
w
);
auto
tran_sw
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sw
);
// squeeze and transpose r
auto
sr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
r
);
auto
tran_sr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sr
);
auto
sr
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
r
);
auto
tran_sr
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sr
);
// initial hidden state
auto
sih
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ih
);
auto
sih
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ih
);
auto
sih_lens
=
sih
->
get_shape
().
lens
();
// bias
instruction_ref
bb
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
m
.
end
())
{
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
bias
);
auto
wb
=
prog
.
insert_instruction
(
auto
sbias
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
bias
);
auto
wb
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
hs
}}}),
sbias
);
auto
rb
=
prog
.
insert_instruction
(
auto
rb
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
hs
}},
{
"ends"
,
{
2
*
hs
}}}),
sbias
);
auto
wrb
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
wb
,
rb
);
bb
=
prog
.
insert_instruction
(
auto
wrb
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
wb
,
rb
);
bb
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
sih_lens
}}),
wrb
);
}
instruction_ref
hidden_out
=
prog
.
end
();
instruction_ref
hidden_out
=
m
.
end
();
instruction_ref
last_out
{};
last_out
=
prog
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
sih
);
long
seq_len
=
get_seq_len
(
prog
,
seq
,
seq_lens
);
last_out
=
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
sih
);
long
seq_len
=
get_seq_len
(
m
,
seq
,
seq_lens
);
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
{
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
auto
xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
seq_index
}},
{
"ends"
,
{
seq_index
+
1
}}}),
seq
);
auto
cont_xt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
cont_xt
);
auto
xt_wi
=
prog
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
xt
,
tran_sw
);
auto
ht_ri
=
prog
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
sih
,
tran_sr
);
if
(
bias
!=
prog
.
end
())
auto
cont_xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
xt
);
xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
cont_xt
);
auto
xt_wi
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
xt
,
tran_sw
);
auto
ht_ri
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
sih
,
tran_sr
);
if
(
bias
!=
m
.
end
())
{
xt_wi
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_wi
,
bb
);
xt_wi
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_wi
,
bb
);
}
auto
xt_ht
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_wi
,
ht_ri
);
auto
xt_ht
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_wi
,
ht_ri
);
// apply activation function
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
xt_ht
);
auto
ht
=
m
.
insert_instruction
(
ins
,
actv_func
,
xt_ht
);
sih
=
ht
;
// add the dimensions of sequence length (axis 0 for sequence length,
// axis 1 for num_directions
last_out
=
prog
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
ht
);
last_out
=
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
ht
);
// concatenation for the last last_out is performed in the apply()
// function to ensure the last instruction is concat, then we have
...
...
@@ -304,14 +302,14 @@ std::vector<instruction_ref> rewrite_rnn::vanilla_rnn_cell(bool is_forward,
{
hidden_out
=
(
seq_index
==
0
)
?
last_out
:
prog
.
insert_instruction
(
:
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
hidden_out
,
last_out
);
}
else
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
last_out
:
prog
.
insert_instruction
(
:
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
last_out
,
hidden_out
);
}
}
...
...
@@ -358,7 +356,7 @@ std::vector<operation> rewrite_rnn::vanilla_rnn_actv_funcs(instruction_ref ins)
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void
rewrite_rnn
::
apply_gru
(
module
&
prog
,
instruction_ref
ins
)
const
void
rewrite_rnn
::
apply_gru
(
module
&
m
,
instruction_ref
ins
)
const
{
assert
(
ins
->
name
()
==
"gru"
);
const
auto
actv_funcs
=
gru_actv_funcs
(
ins
);
...
...
@@ -379,37 +377,37 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
op
::
rnn_direction
dirct
=
gru_op
.
direction
;
// process sequence length
instruction_ref
seq_lens
=
prog
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"undefined"
)
{
seq_lens
=
args
[
4
];
}
bool
variable_seq_len
=
is_variable_seq_lens
(
prog
,
seq_lens
);
bool
variable_seq_len
=
is_variable_seq_lens
(
m
,
seq_lens
);
instruction_ref
last_output
{};
if
(
dirct
==
op
::
rnn_direction
::
bidirectional
)
{
// w weight matrix
auto
w_forward
=
prog
.
insert_instruction
(
auto
w_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
auto
w_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
1
]);
// r weight matrix
auto
r_forward
=
prog
.
insert_instruction
(
auto
r_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
auto
r_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
2
]);
// bias
instruction_ref
bias_forward
=
prog
.
end
();
instruction_ref
bias_reverse
=
prog
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias_forward
=
prog
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
bias_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
3
]);
}
...
...
@@ -418,20 +416,20 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
==
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
{
ih_forward
=
prog
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ih_reverse
=
prog
.
insert_instruction
(
ih_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
5
]);
}
else
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_forward
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret_forward
=
gru_cell
(
true
,
prog
,
m
,
ins
,
{
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
seq_lens
,
ih_forward
},
gru_op
.
linear_before_reset
,
...
...
@@ -440,38 +438,37 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
if
(
variable_seq_len
)
{
args
[
0
]
=
prog
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
}
auto
ret_reverse
=
gru_cell
(
false
,
prog
,
m
,
ins
,
{
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
seq_lens
,
ih_reverse
},
gru_op
.
linear_before_reset
,
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
));
auto
concat_output
=
prog
.
insert_instruction
(
auto
concat_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
concat_output
);
last_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
concat_output
);
// The following logic is to ensure the last instruction rewritten
// from gru operator is a concat
if
(
ret_forward
[
0
]
==
prog
.
end
())
if
(
ret_forward
[
0
]
==
m
.
end
())
{
prog
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
ret_forward
[
1
],
ret_reverse
[
1
]);
}
else
{
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ret_forward
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ret_reverse
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret_reverse
[
1
],
ret_reverse
[
0
]);
prog
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
...
...
@@ -483,7 +480,7 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
auto
r
=
args
[
2
];
// bias
instruction_ref
bias
=
prog
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias
=
args
[
3
];
...
...
@@ -497,47 +494,46 @@ void rewrite_rnn::apply_gru(module& prog, instruction_ref ins) const
}
else
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
if
(
!
is_forward
and
variable_seq_len
)
{
args
[
0
]
=
prog
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
}
auto
ret
=
gru_cell
(
is_forward
,
prog
,
m
,
ins
,
{
args
[
0
],
w
,
r
,
bias
,
seq_lens
,
ih
},
gru_op
.
linear_before_reset
,
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
));
last_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ret
[
1
]);
last_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ret
[
1
]);
if
(
ret
[
0
]
==
prog
.
end
())
if
(
ret
[
0
]
==
m
.
end
())
{
prog
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret
[
1
]);
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret
[
1
]);
}
else
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
prog
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
concat_arg0
,
concat_arg1
);
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
concat_arg0
,
concat_arg1
);
}
}
// in case of all sequences are of the same lengths and shorter than the
// max sequence length, need to pad 0's at the end for output hidden states
ins
=
pad_hidden_states
(
prog
,
args
[
0
],
seq_lens
,
ins
);
replace_last_hs_output
(
prog
,
ins
,
seq_lens
,
last_output
,
dirct
);
ins
=
pad_hidden_states
(
m
,
args
[
0
],
seq_lens
,
ins
);
replace_last_hs_output
(
m
,
ins
,
seq_lens
,
last_output
,
dirct
);
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
gru_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
...
...
@@ -552,7 +548,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
auto
seq_lens
=
inputs
.
at
(
4
);
auto
ih
=
inputs
.
at
(
5
);
instruction_ref
hidden_states
=
prog
.
end
();
instruction_ref
hidden_states
=
m
.
end
();
instruction_ref
last_output
{};
migraphx
::
shape
seq_shape
=
seq
->
get_shape
();
migraphx
::
shape
r_shape
=
r
->
get_shape
();
...
...
@@ -560,127 +556,127 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
migraphx
::
shape
ss
(
seq_shape
.
type
(),
{
seq_shape
.
lens
()[
1
],
r_shape
.
lens
()[
2
]});
std
::
vector
<
float
>
data
(
ss
.
elements
(),
1.0
f
);
auto
l1
=
prog
.
add_literal
(
migraphx
::
literal
{
ss
,
data
});
auto
l1
=
m
.
add_literal
(
migraphx
::
literal
{
ss
,
data
});
// w matrix squeeze to 2-dim and do a transpose
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
sw
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
w
);
auto
tw
=
prog
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sw
);
auto
sw
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
w
);
auto
tw
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sw
);
// r slide to two part, zr and h
auto
sr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
r
);
auto
rzr
=
prog
.
insert_instruction
(
auto
sr
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
r
);
auto
rzr
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
2
*
hs
}}}),
sr
);
auto
trzr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
rzr
);
auto
trzr
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
rzr
);
auto
rh
=
prog
.
insert_instruction
(
auto
rh
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
2
*
hs
}},
{
"ends"
,
{
3
*
hs
}}}),
sr
);
auto
trh
=
prog
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
rh
);
auto
trh
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
rh
);
// initial states
auto
sih
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ih
);
auto
sih
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ih
);
size_t
bs
=
ih
->
get_shape
().
lens
()[
1
];
// bias
instruction_ref
bwb
{};
instruction_ref
brb_zr
{};
instruction_ref
brb_h
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
m
.
end
())
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
bias
);
auto
wb
=
prog
.
insert_instruction
(
auto
sbias
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
bias
);
auto
wb
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
3
*
hs
}}}),
sbias
);
bwb
=
prog
.
insert_instruction
(
bwb
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
bs
,
static_cast
<
size_t
>
(
3
*
hs
)}}}),
wb
);
auto
rb_zr
=
prog
.
insert_instruction
(
auto
rb_zr
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
3
*
hs
}},
{
"ends"
,
{
5
*
hs
}}}),
sbias
);
auto
rb_h
=
prog
.
insert_instruction
(
auto
rb_h
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
5
*
hs
}},
{
"ends"
,
{
6
*
hs
}}}),
sbias
);
brb_zr
=
prog
.
insert_instruction
(
brb_zr
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
bs
,
static_cast
<
size_t
>
(
2
*
hs
)}}}),
rb_zr
);
brb_h
=
prog
.
insert_instruction
(
brb_h
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
bs
,
static_cast
<
size_t
>
(
hs
)}}}),
rb_h
);
}
long
seq_len
=
get_seq_len
(
prog
,
seq
,
seq_lens
);
long
seq_len
=
get_seq_len
(
m
,
seq
,
seq_lens
);
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
{
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
auto
xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
seq_index
}},
{
"ends"
,
{
seq_index
+
1
}}}),
seq
);
auto
cont_xt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
cont_xt
);
auto
cont_xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
xt
);
xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
cont_xt
);
auto
xt_w
=
prog
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
xt
,
tw
);
auto
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
sih
,
trzr
);
if
(
bias
!=
prog
.
end
())
auto
xt_w
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
xt
,
tw
);
auto
ih1_rzr
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
sih
,
trzr
);
if
(
bias
!=
m
.
end
())
{
xt_w
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_w
,
bwb
);
ih1_rzr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ih1_rzr
,
brb_zr
);
xt_w
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_w
,
bwb
);
ih1_rzr
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ih1_rzr
,
brb_zr
);
}
auto
xw_z
=
prog
.
insert_instruction
(
auto
xw_z
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
hs
}}}),
xt_w
);
auto
xw_r
=
prog
.
insert_instruction
(
auto
xw_r
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
hs
}},
{
"ends"
,
{
2
*
hs
}}}),
xt_w
);
auto
xw_h
=
prog
.
insert_instruction
(
auto
xw_h
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
2
*
hs
}},
{
"ends"
,
{
3
*
hs
}}}),
xt_w
);
auto
hr_z
=
prog
.
insert_instruction
(
auto
hr_z
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
hs
}}}),
ih1_rzr
);
auto
hr_r
=
prog
.
insert_instruction
(
auto
hr_r
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
hs
}},
{
"ends"
,
{
2
*
hs
}}}),
ih1_rzr
);
auto
xw_hr_z
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xw_z
,
hr_z
);
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xw_hr_z
);
auto
xw_hr_z
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xw_z
,
hr_z
);
auto
zt
=
m
.
insert_instruction
(
ins
,
actv_func1
,
xw_hr_z
);
auto
xw_hr_r
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xw_r
,
hr_r
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xw_hr_r
);
auto
xw_hr_r
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xw_r
,
hr_r
);
auto
rt
=
m
.
insert_instruction
(
ins
,
actv_func1
,
xw_hr_r
);
instruction_ref
hr_h
{};
if
(
linear_before_reset
==
0
)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
rt_ht1
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
rt
,
sih
);
hr_h
=
prog
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
rt_ht1
,
trh
);
if
(
bias
!=
prog
.
end
())
auto
rt_ht1
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
rt
,
sih
);
hr_h
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
rt_ht1
,
trh
);
if
(
bias
!=
m
.
end
())
{
hr_h
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
hr_h
,
brb_h
);
hr_h
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
hr_h
,
brb_h
);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
ht1_rh
=
prog
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
sih
,
trh
);
if
(
bias
!=
prog
.
end
())
auto
ht1_rh
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
sih
,
trh
);
if
(
bias
!=
m
.
end
())
{
ht1_rh
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ht1_rh
,
brb_h
);
ht1_rh
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ht1_rh
,
brb_h
);
}
hr_h
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
rt
,
ht1_rh
);
hr_h
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
rt
,
ht1_rh
);
}
auto
xw_hr_h
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xw_h
,
hr_h
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xw_hr_h
);
auto
xw_hr_h
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xw_h
,
hr_h
);
auto
ht
=
m
.
insert_instruction
(
ins
,
actv_func2
,
xw_hr_h
);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto
one_minus_zt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"sub"
),
l1
,
zt
);
auto
one_minus_zt_ht
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
one_minus_zt
,
ht
);
auto
zt_ht1
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
zt
,
sih
);
sih
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
one_minus_zt_ht
,
zt_ht1
);
last_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
sih
);
auto
one_minus_zt
=
m
.
insert_instruction
(
ins
,
make_op
(
"sub"
),
l1
,
zt
);
auto
one_minus_zt_ht
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
one_minus_zt
,
ht
);
auto
zt_ht1
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
zt
,
sih
);
sih
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
one_minus_zt_ht
,
zt_ht1
);
last_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
sih
);
if
(
i
<
seq_len
-
1
)
{
...
...
@@ -689,7 +685,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
hidden_states
=
(
seq_index
==
0
)
?
last_output
:
prog
.
insert_instruction
(
:
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
hidden_states
,
last_output
);
}
else
...
...
@@ -697,7 +693,7 @@ std::vector<instruction_ref> rewrite_rnn::gru_cell(bool is_forward,
hidden_states
=
(
seq_index
==
seq_len
-
1
)
?
last_output
:
prog
.
insert_instruction
(
:
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
last_output
,
hidden_states
);
}
}
...
...
@@ -748,7 +744,7 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
// for lstm operators
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
void
rewrite_rnn
::
apply_lstm
(
module
&
prog
,
instruction_ref
ins
)
const
void
rewrite_rnn
::
apply_lstm
(
module
&
m
,
instruction_ref
ins
)
const
{
assert
(
ins
->
name
()
==
"lstm"
);
auto
args
=
ins
->
inputs
();
...
...
@@ -767,13 +763,13 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
op
::
rnn_direction
dirct
=
lstm_op
.
direction
;
// process sequence length
instruction_ref
seq_lens
=
prog
.
end
();
instruction_ref
seq_lens
=
m
.
end
();
if
((
args
.
size
()
>=
5
)
&&
args
[
4
]
->
name
()
!=
"undefined"
)
{
seq_lens
=
args
[
4
];
}
bool
variable_seq_len
=
is_variable_seq_lens
(
prog
,
seq_lens
);
bool
variable_seq_len
=
is_variable_seq_lens
(
m
,
seq_lens
);
instruction_ref
last_hs_output
{};
instruction_ref
last_cell_output
{};
...
...
@@ -783,25 +779,25 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
{
// input weight matrix
// input weight matrix
auto
w_forward
=
prog
.
insert_instruction
(
auto
w_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
auto
w_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
1
]);
// hidden state weight matrix
auto
r_forward
=
prog
.
insert_instruction
(
auto
r_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
auto
r_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
2
]);
// process bias
instruction_ref
bias_forward
=
prog
.
end
();
instruction_ref
bias_reverse
=
prog
.
end
();
instruction_ref
bias_forward
=
m
.
end
();
instruction_ref
bias_reverse
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias_forward
=
prog
.
insert_instruction
(
bias_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
bias_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
3
]);
}
...
...
@@ -810,15 +806,15 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref
ih_reverse
{};
if
(
args
.
size
()
>=
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
{
ih_forward
=
prog
.
insert_instruction
(
ih_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
5
]);
ih_reverse
=
prog
.
insert_instruction
(
ih_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
5
]);
}
else
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
ih_forward
=
m
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
ih_reverse
=
m
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
}
// process initial cell value
...
...
@@ -826,30 +822,30 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
instruction_ref
ic_reverse
{};
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"undefined"
)
{
ic_forward
=
prog
.
insert_instruction
(
ic_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
6
]);
ic_reverse
=
prog
.
insert_instruction
(
ic_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
6
]);
}
else
{
ic_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
ic_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
ic_forward
=
m
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
ic_reverse
=
m
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
}
// process weight of the peephole
instruction_ref
pph_forward
=
prog
.
end
();
instruction_ref
pph_reverse
=
prog
.
end
();
instruction_ref
pph_forward
=
m
.
end
();
instruction_ref
pph_reverse
=
m
.
end
();
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"undefined"
)
{
pph_forward
=
prog
.
insert_instruction
(
pph_forward
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
1
}}}),
args
[
7
]);
pph_reverse
=
prog
.
insert_instruction
(
pph_reverse
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
1
}},
{
"ends"
,
{
2
}}}),
args
[
7
]);
}
auto
ret_forward
=
lstm_cell
(
true
,
prog
,
m
,
ins
,
{
args
[
0
],
w_forward
,
...
...
@@ -865,11 +861,11 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if
(
variable_seq_len
)
{
args
[
0
]
=
prog
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
}
auto
ret_reverse
=
lstm_cell
(
false
,
prog
,
m
,
ins
,
{
args
[
0
],
w_reverse
,
...
...
@@ -883,36 +879,36 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
actv_funcs
.
at
(
4
),
actv_funcs
.
at
(
5
));
auto
concat_hs_output
=
prog
.
insert_instruction
(
auto
concat_hs_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
ret_forward
[
1
],
ret_reverse
[
1
]);
auto
concat_cell_output
=
prog
.
insert_instruction
(
auto
concat_cell_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
ret_forward
[
3
],
ret_reverse
[
3
]);
last_hs_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
concat_hs_output
);
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
concat_hs_output
);
last_cell_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
concat_cell_output
);
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
concat_cell_output
);
// the following logic is to ensure the last instruction is a concat
if
(
ret_forward
[
0
]
==
prog
.
end
())
if
(
ret_forward
[
0
]
==
m
.
end
())
{
cell_outputs
=
concat_cell_output
;
}
else
{
ret_forward
[
1
]
=
prog
.
insert_instruction
(
ret_forward
[
1
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret_forward
[
0
],
ret_forward
[
1
]);
ret_reverse
[
1
]
=
prog
.
insert_instruction
(
ret_reverse
[
1
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret_reverse
[
1
],
ret_reverse
[
0
]);
ret_forward
[
3
]
=
prog
.
insert_instruction
(
ret_forward
[
3
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret_forward
[
2
],
ret_forward
[
3
]);
ret_reverse
[
3
]
=
prog
.
insert_instruction
(
ret_reverse
[
3
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret_reverse
[
3
],
ret_reverse
[
2
]);
cell_outputs
=
prog
.
insert_instruction
(
cell_outputs
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
ret_forward
[
3
],
ret_reverse
[
3
]);
}
hidden_state
=
prog
.
replace_instruction
(
hidden_state
=
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
{
ret_forward
[
1
],
ret_reverse
[
1
]});
}
else
...
...
@@ -923,7 +919,7 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
auto
r
=
args
[
2
];
// bias
instruction_ref
bias
=
prog
.
end
();
instruction_ref
bias
=
m
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
bias
=
args
[
3
];
...
...
@@ -937,7 +933,7 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
}
else
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
ih
=
m
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
}
// initial cell value
...
...
@@ -948,11 +944,11 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
}
else
{
ic
=
prog
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
ic
=
m
.
add_literal
(
migraphx
::
literal
{
ihc_shape
,
ihc_data
});
}
// process weight of the peephole
instruction_ref
pph
=
prog
.
end
();
instruction_ref
pph
=
m
.
end
();
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"undefined"
)
{
pph
=
args
[
7
];
...
...
@@ -960,54 +956,53 @@ void rewrite_rnn::apply_lstm(module& prog, instruction_ref ins) const
if
(
!
is_forward
and
variable_seq_len
)
{
args
[
0
]
=
prog
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
args
[
0
]
=
m
.
insert_instruction
(
ins
,
make_op
(
"rnn_var_sl_shift_sequence"
),
args
[
0
],
seq_lens
);
}
auto
ret
=
lstm_cell
(
is_forward
,
prog
,
m
,
ins
,
{
args
[
0
],
w
,
r
,
bias
,
seq_lens
,
ih
,
ic
,
pph
},
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
));
last_hs_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ret
[
1
]);
last_cell_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ret
[
3
]);
last_hs_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ret
[
1
]);
last_cell_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ret
[
3
]);
if
(
ret
[
0
]
==
prog
.
end
())
if
(
ret
[
0
]
==
m
.
end
())
{
cell_outputs
=
ret
[
3
];
hidden_state
=
prog
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret
[
1
]);
hidden_state
=
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
ret
[
1
]);
}
else
{
auto
concat_cell_arg0
=
is_forward
?
ret
[
2
]
:
ret
[
3
];
auto
concat_cell_arg1
=
is_forward
?
ret
[
3
]
:
ret
[
2
];
cell_outputs
=
prog
.
insert_instruction
(
cell_outputs
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
concat_cell_arg0
,
concat_cell_arg1
);
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
hidden_state
=
prog
.
replace_instruction
(
hidden_state
=
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
concat_arg0
,
concat_arg1
);
}
}
// in case of all sequences are of the same lengths and shorter than the
// max sequence length, need to pad 0's at the end for output hidden states
hidden_state
=
pad_hidden_states
(
prog
,
args
[
0
],
seq_lens
,
hidden_state
);
hidden_state
=
pad_hidden_states
(
m
,
args
[
0
],
seq_lens
,
hidden_state
);
// replace last hidden states with corresponding instructions
ins
=
replace_last_hs_output
(
prog
,
hidden_state
,
seq_lens
,
last_hs_output
,
dirct
);
ins
=
replace_last_hs_output
(
m
,
hidden_state
,
seq_lens
,
last_hs_output
,
dirct
);
// replace last cell outputs with corresponding instructions
replace_last_cell_output
(
prog
,
ins
,
seq_lens
,
cell_outputs
,
last_cell_output
,
dirct
);
replace_last_cell_output
(
m
,
ins
,
seq_lens
,
cell_outputs
,
last_cell_output
,
dirct
);
}
// NOLINTNEXTLINE(readability-function-cognitive-complexity)
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
lstm_cell
(
bool
is_forward
,
module
&
prog
,
module
&
m
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
const
operation
&
actv_func1
,
...
...
@@ -1025,8 +1020,8 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
ic
=
inputs
.
at
(
6
);
auto
pph
=
inputs
.
at
(
7
);
instruction_ref
hidden_states
=
prog
.
end
();
instruction_ref
cell_outputs
=
prog
.
end
();
instruction_ref
hidden_states
=
m
.
end
();
instruction_ref
cell_outputs
=
m
.
end
();
instruction_ref
last_hs_output
{};
instruction_ref
last_cell_output
{};
...
...
@@ -1037,35 +1032,35 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
// w matrix, squeeze and transpose
auto
sw
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
w
);
auto
tsw
=
prog
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sw
);
auto
sw
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
w
);
auto
tsw
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sw
);
// r matrix, squeeze and transpose
auto
sr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
r
);
auto
tsr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sr
);
auto
sr
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
r
);
auto
tsr
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
sr
);
// initial hidden state
auto
sih
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ih
);
auto
sih
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ih
);
// initial cell state
auto
sic
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ic
);
auto
sic
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
ic
);
auto
ic_lens
=
sic
->
get_shape
().
lens
();
// bias
instruction_ref
wrb
{};
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
m
.
end
())
{
auto
sbias
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
bias
);
auto
ub_wb
=
prog
.
insert_instruction
(
auto
sbias
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
bias
);
auto
ub_wb
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
4
*
hs
}}}),
sbias
);
auto
ub_rb
=
prog
.
insert_instruction
(
auto
ub_rb
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
4
*
hs
}},
{
"ends"
,
{
8
*
hs
}}}),
sbias
);
auto
ub_wrb
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ub_wb
,
ub_rb
);
auto
ub_wrb
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ub_wb
,
ub_rb
);
wrb
=
prog
.
insert_instruction
(
wrb
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
bs
,
4
*
static_cast
<
size_t
>
(
hs
)}}}),
ub_wrb
);
...
...
@@ -1075,92 +1070,91 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
instruction_ref
pphi_brcst
{};
instruction_ref
ppho_brcst
{};
instruction_ref
pphf_brcst
{};
if
(
pph
!=
prog
.
end
())
if
(
pph
!=
m
.
end
())
{
auto
spph
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
pph
);
auto
pphi
=
prog
.
insert_instruction
(
auto
spph
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
pph
);
auto
pphi
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
hs
}}}),
spph
);
pphi_brcst
=
prog
.
insert_instruction
(
pphi_brcst
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
ic_lens
}}),
pphi
);
auto
ppho
=
prog
.
insert_instruction
(
auto
ppho
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
hs
}},
{
"ends"
,
{
2
*
hs
}}}),
spph
);
ppho_brcst
=
prog
.
insert_instruction
(
ppho_brcst
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
ic_lens
}}),
ppho
);
auto
pphf
=
prog
.
insert_instruction
(
auto
pphf
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
2
*
hs
}},
{
"ends"
,
{
3
*
hs
}}}),
spph
);
pphf_brcst
=
prog
.
insert_instruction
(
pphf_brcst
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
ic_lens
}}),
pphf
);
}
long
seq_len
=
get_seq_len
(
prog
,
seq
,
seq_lens
);
long
seq_len
=
get_seq_len
(
m
,
seq
,
seq_lens
);
for
(
long
i
=
0
;
i
<
seq_len
;
++
i
)
{
long
seq_index
=
is_forward
?
i
:
(
seq_len
-
1
-
i
);
auto
xt
=
prog
.
insert_instruction
(
auto
xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
seq_index
}},
{
"ends"
,
{
seq_index
+
1
}}}),
seq
);
auto
cont_xt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
cont_xt
);
auto
cont_xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"contiguous"
),
xt
);
xt
=
m
.
insert_instruction
(
ins
,
make_op
(
"squeeze"
,
{{
"axes"
,
{
0
}}}),
cont_xt
);
auto
xt_tsw
=
prog
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
xt
,
tsw
);
auto
sih_tsr
=
prog
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
sih
,
tsr
);
auto
xt_sih
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_tsw
,
sih_tsr
);
if
(
bias
!=
prog
.
end
())
auto
xt_tsw
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
xt
,
tsw
);
auto
sih_tsr
=
m
.
insert_instruction
(
ins
,
make_op
(
"dot"
),
sih
,
tsr
);
auto
xt_sih
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_tsw
,
sih_tsr
);
if
(
bias
!=
m
.
end
())
{
xt_sih
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_sih
,
wrb
);
xt_sih
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
xt_sih
,
wrb
);
}
auto
it_before_actv
=
prog
.
insert_instruction
(
auto
it_before_actv
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
hs
}}}),
xt_sih
);
auto
ot_before_actv
=
prog
.
insert_instruction
(
auto
ot_before_actv
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
hs
}},
{
"ends"
,
{
2
*
hs
}}}),
xt_sih
);
auto
ft_before_actv
=
prog
.
insert_instruction
(
auto
ft_before_actv
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
2
*
hs
}},
{
"ends"
,
{
3
*
hs
}}}),
xt_sih
);
auto
ct_before_actv
=
prog
.
insert_instruction
(
auto
ct_before_actv
=
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
3
*
hs
}},
{
"ends"
,
{
4
*
hs
}}}),
xt_sih
);
if
(
pph
!=
prog
.
end
())
if
(
pph
!=
m
.
end
())
{
auto
pphi_ct
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
pphi_brcst
,
sic
);
it_before_actv
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
it_before_actv
,
pphi_ct
);
auto
pphi_ct
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
pphi_brcst
,
sic
);
it_before_actv
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
it_before_actv
,
pphi_ct
);
auto
pphf_ct
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
pphf_brcst
,
sic
);
ft_before_actv
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ft_before_actv
,
pphf_ct
);
auto
pphf_ct
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
pphf_brcst
,
sic
);
ft_before_actv
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ft_before_actv
,
pphf_ct
);
}
auto
it
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
it_before_actv
);
auto
ft
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ft_before_actv
);
auto
ct
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
auto
it
=
m
.
insert_instruction
(
ins
,
actv_func1
,
it_before_actv
);
auto
ft
=
m
.
insert_instruction
(
ins
,
actv_func1
,
ft_before_actv
);
auto
ct
=
m
.
insert_instruction
(
ins
,
actv_func2
,
ct_before_actv
);
// equation Ct = ft (.) Ct-1 + it (.) ct
auto
ft_cell
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ft
,
sic
);
auto
it_ct
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
it
,
ct
);
auto
cellt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ft_cell
,
it_ct
);
auto
ft_cell
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ft
,
sic
);
auto
it_ct
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
it
,
ct
);
auto
cellt
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ft_cell
,
it_ct
);
if
(
pph
!=
prog
.
end
())
if
(
pph
!=
m
.
end
())
{
auto
ppho_cellt
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ppho_brcst
,
cellt
);
ot_before_actv
=
prog
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ot_before_actv
,
ppho_cellt
);
auto
ppho_cellt
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ppho_brcst
,
cellt
);
ot_before_actv
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
ot_before_actv
,
ppho_cellt
);
}
auto
ot
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
ot_before_actv
);
auto
ot
=
m
.
insert_instruction
(
ins
,
actv_func1
,
ot_before_actv
);
// Ht = ot (.) h(Ct)
auto
h_cellt
=
prog
.
insert_instruction
(
ins
,
actv_func3
,
cellt
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ot
,
h_cellt
);
auto
h_cellt
=
m
.
insert_instruction
(
ins
,
actv_func3
,
cellt
);
auto
ht
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
ot
,
h_cellt
);
sic
=
cellt
;
sih
=
ht
;
last_hs_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
ht
);
last_hs_output
=
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
ht
);
last_cell_output
=
prog
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
cellt
);
m
.
insert_instruction
(
ins
,
make_op
(
"unsqueeze"
,
{{
"axes"
,
{
0
,
1
}}}),
cellt
);
if
(
i
<
seq_len
-
1
)
{
...
...
@@ -1173,12 +1167,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{
auto
concat_hs_arg0
=
is_forward
?
hidden_states
:
last_hs_output
;
auto
concat_hs_arg1
=
is_forward
?
last_hs_output
:
hidden_states
;
hidden_states
=
prog
.
insert_instruction
(
hidden_states
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
concat_hs_arg0
,
concat_hs_arg1
);
auto
concat_cell_arg0
=
is_forward
?
cell_outputs
:
last_cell_output
;
auto
concat_cell_arg1
=
is_forward
?
last_cell_output
:
cell_outputs
;
cell_outputs
=
prog
.
insert_instruction
(
cell_outputs
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
concat_cell_arg0
,
concat_cell_arg1
);
}
}
...
...
@@ -1266,10 +1260,10 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
}
}
bool
rewrite_rnn
::
is_variable_seq_lens
(
const
module
&
prog
,
instruction_ref
seq_lens
)
const
bool
rewrite_rnn
::
is_variable_seq_lens
(
const
module
&
m
,
instruction_ref
seq_lens
)
const
{
bool
is_var_lens
=
false
;
if
(
seq_lens
!=
prog
.
end
())
if
(
seq_lens
!=
m
.
end
())
{
if
(
seq_lens
->
can_eval
())
{
...
...
@@ -1296,12 +1290,12 @@ bool rewrite_rnn::is_variable_seq_lens(const module& prog, instruction_ref seq_l
}
std
::
size_t
rewrite_rnn
::
get_seq_len
(
const
module
&
prog
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
rewrite_rnn
::
get_seq_len
(
const
module
&
m
,
instruction_ref
input
,
instruction_ref
seq_lens
)
const
{
bool
is_var_lens
=
is_variable_seq_lens
(
prog
,
seq_lens
);
bool
is_var_lens
=
is_variable_seq_lens
(
m
,
seq_lens
);
auto
input_shape
=
input
->
get_shape
();
auto
length
=
input_shape
.
lens
()[
0
];
if
(
!
is_var_lens
and
seq_lens
!=
prog
.
end
())
if
(
!
is_var_lens
and
seq_lens
!=
m
.
end
())
{
auto
arg_len
=
seq_lens
->
eval
();
std
::
vector
<
std
::
size_t
>
vec_lens
;
...
...
@@ -1312,33 +1306,33 @@ rewrite_rnn::get_seq_len(const module& prog, instruction_ref input, instruction_
return
length
;
}
instruction_ref
rewrite_rnn
::
replace_last_hs_output
(
module
&
prog
,
instruction_ref
rewrite_rnn
::
replace_last_hs_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
last_hs_output
,
op
::
rnn_direction
dirct
)
const
{
bool
variable_seq_len
=
is_variable_seq_lens
(
prog
,
seq_lens
);
bool
variable_seq_len
=
is_variable_seq_lens
(
m
,
seq_lens
);
instruction_ref
result_ins
{};
if
(
variable_seq_len
)
{
result_ins
=
prog
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"rnn_var_sl_shift_output"
,
{{
"output_name"
,
"hidden_states"
},
{
"direction"
,
dirct
}}),
ins
,
seq_lens
);
prog
.
replace_instruction
(
ins
,
result_ins
);
result_ins
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"rnn_var_sl_shift_output"
,
{{
"output_name"
,
"hidden_states"
},
{
"direction"
,
dirct
}}),
ins
,
seq_lens
);
m
.
replace_instruction
(
ins
,
result_ins
);
auto
hs_outputs
=
find_all
(
result_ins
->
outputs
(),
[
&
](
auto
i
)
{
return
i
->
name
()
==
"rnn_last_hs_output"
;
});
for
(
auto
&
hs_out
:
hs_outputs
)
{
auto
inputs
=
hs_out
->
inputs
();
prog
.
replace_instruction
(
hs_out
,
make_op
(
"rnn_var_sl_last_output"
,
{{
"direction"
,
dirct
}}),
inputs
.
front
(),
seq_lens
);
m
.
replace_instruction
(
hs_out
,
make_op
(
"rnn_var_sl_last_output"
,
{{
"direction"
,
dirct
}}),
inputs
.
front
(),
seq_lens
);
}
}
else
...
...
@@ -1348,7 +1342,7 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
for
(
auto
&
hs_out
:
hs_outputs
)
{
prog
.
replace_instruction
(
hs_out
,
last_hs_output
);
m
.
replace_instruction
(
hs_out
,
last_hs_output
);
}
result_ins
=
ins
;
...
...
@@ -1357,14 +1351,14 @@ instruction_ref rewrite_rnn::replace_last_hs_output(module& prog,
return
result_ins
;
}
void
rewrite_rnn
::
replace_last_cell_output
(
module
&
prog
,
void
rewrite_rnn
::
replace_last_cell_output
(
module
&
m
,
instruction_ref
ins
,
instruction_ref
seq_lens
,
instruction_ref
cell_outputs
,
instruction_ref
last_cell_output
,
op
::
rnn_direction
dirct
)
const
{
bool
variable_seq_len
=
is_variable_seq_lens
(
prog
,
seq_lens
);
bool
variable_seq_len
=
is_variable_seq_lens
(
m
,
seq_lens
);
auto
ins_outputs
=
find_all
(
ins
->
outputs
(),
[
&
](
auto
i
)
{
return
i
->
name
()
==
"rnn_last_cell_output"
;
});
...
...
@@ -1372,7 +1366,7 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{
if
(
!
ins_outputs
.
empty
())
{
cell_outputs
=
prog
.
insert_instruction
(
cell_outputs
=
m
.
insert_instruction
(
std
::
next
(
ins
),
make_op
(
"rnn_var_sl_shift_output"
,
{{
"output_name"
,
"cell_outputs"
},
{
"direction"
,
dirct
}}),
...
...
@@ -1382,10 +1376,10 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
for
(
auto
co
:
ins_outputs
)
{
prog
.
replace_instruction
(
co
,
make_op
(
"rnn_var_sl_last_output"
,
{{
"direction"
,
dirct
}}),
cell_outputs
,
seq_lens
);
m
.
replace_instruction
(
co
,
make_op
(
"rnn_var_sl_last_output"
,
{{
"direction"
,
dirct
}}),
cell_outputs
,
seq_lens
);
}
}
// replace the rnn_last_cell_output with the last_cell_output. The while
...
...
@@ -1394,18 +1388,18 @@ void rewrite_rnn::replace_last_cell_output(module& prog,
{
for
(
auto
co
:
ins_outputs
)
{
prog
.
replace_instruction
(
co
,
last_cell_output
);
m
.
replace_instruction
(
co
,
last_cell_output
);
}
}
}
instruction_ref
rewrite_rnn
::
pad_hidden_states
(
module
&
prog
,
instruction_ref
rewrite_rnn
::
pad_hidden_states
(
module
&
m
,
instruction_ref
seq
,
instruction_ref
seq_lens
,
instruction_ref
hs
)
const
{
auto
max_seq_len
=
seq
->
get_shape
().
lens
()[
0
];
auto
seq_len
=
get_seq_len
(
prog
,
seq
,
seq_lens
);
auto
seq_len
=
get_seq_len
(
m
,
seq
,
seq_lens
);
// condition of all sequence are of the same length and
// less than max_seq_len, we need to append the hs outputs
...
...
@@ -1417,23 +1411,13 @@ instruction_ref rewrite_rnn::pad_hidden_states(module& prog,
pad_lens
[
0
]
=
static_cast
<
std
::
size_t
>
(
max_seq_len
-
seq_len
);
shape
pad_s
{
s
.
type
(),
pad_lens
};
std
::
vector
<
float
>
pad_data
(
pad_s
.
elements
(),
0.0
f
);
auto
pl
=
prog
.
add_literal
(
pad_s
,
pad_data
.
begin
(),
pad_data
.
end
());
hs_padded
=
prog
.
insert_instruction
(
std
::
next
(
hs
),
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
hs
,
pl
);
prog
.
replace_instruction
(
hs
,
hs_padded
);
auto
pl
=
m
.
add_literal
(
pad_s
,
pad_data
.
begin
(),
pad_data
.
end
());
hs_padded
=
m
.
insert_instruction
(
std
::
next
(
hs
),
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
hs
,
pl
);
m
.
replace_instruction
(
hs
,
hs_padded
);
}
return
hs_padded
;
}
namespace
op
{
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
rnn_direction
v
)
{
std
::
vector
<
std
::
string
>
rnn_direction_str
=
{
"forward"
,
"reverse"
,
"bidirectional"
};
os
<<
rnn_direction_str
[
static_cast
<
std
::
underlying_type
<
rnn_direction
>::
type
>
(
v
)];
return
os
;
}
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/schedule.cpp
View file @
11e155c2
...
...
@@ -42,7 +42,7 @@ struct stream_info
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
iweights
;
ins_dep_map
mod_implicit_deps
;
void
calc_implicit_deps
(
const
module
&
p
)
{
mod_implicit_deps
=
p
.
calc_implicit_deps
();
}
void
calc_implicit_deps
(
const
module
&
m
)
{
mod_implicit_deps
=
m
.
calc_implicit_deps
();
}
void
accumulate_weights
(
instruction_ref
last
,
const
schedule_model
&
model
)
{
...
...
@@ -116,15 +116,15 @@ struct stream_info
}
};
std
::
size_t
assign_streams
(
module
&
p
,
std
::
size_t
n
)
std
::
size_t
assign_streams
(
module
&
m
,
std
::
size_t
n
)
{
assert
(
n
>
0
);
partition
critical
;
std
::
unordered_map
<
instruction_ref
,
std
::
deque
<
partition
>>
partitions
;
partitions
.
reserve
(
weights
.
size
());
fix
([
&
](
auto
self
,
auto
ins
,
auto
&
part
)
{
assert
(
not
is_end
(
ins
,
p
.
end
()));
if
(
not
p
.
has_instruction
(
ins
))
assert
(
not
is_end
(
ins
,
m
.
end
()));
if
(
not
m
.
has_instruction
(
ins
))
return
;
if
(
contains
(
partitions
,
ins
))
return
;
...
...
@@ -151,8 +151,8 @@ struct stream_info
}
}
// Sort instructions
p
.
move_instruction
(
ins
,
p
.
end
());
})(
std
::
prev
(
p
.
end
()),
critical
);
m
.
move_instruction
(
ins
,
m
.
end
());
})(
std
::
prev
(
m
.
end
()),
critical
);
// Set the critical partition to stream 0
set_stream
(
critical
,
0
);
...
...
@@ -197,13 +197,13 @@ struct stream_info
}
};
void
sort
(
module
&
p
,
std
::
size_t
)
void
sort
(
module
&
m
,
std
::
size_t
)
{
std
::
set
<
weight_ins
,
compare_weight_ins
>
children
;
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
visited
;
auto
last
=
std
::
prev
(
p
.
end
());
auto
last
=
std
::
prev
(
m
.
end
());
auto
mw
=
this
->
weights
.
at
(
last
);
auto
nw
=
mw
/
(
p
.
size
()
+
1
);
auto
nw
=
mw
/
(
m
.
size
()
+
1
);
auto
add_child
=
[
&
](
auto
ins
)
{
auto
x
=
1
+
(
mw
-
this
->
weights
.
at
(
ins
))
/
(
nw
+
1
);
auto
w
=
x
*
this
->
iweights
.
at
(
ins
);
...
...
@@ -222,10 +222,10 @@ struct stream_info
// Pop the first element
auto
top
=
children
.
begin
()
->
second
;
children
.
erase
(
children
.
begin
());
p
.
move_instruction
(
top
,
p
.
begin
());
m
.
move_instruction
(
top
,
m
.
begin
());
for
(
auto
ins
:
top
->
inputs
())
{
if
(
not
p
.
has_instruction
(
ins
))
if
(
not
m
.
has_instruction
(
ins
))
continue
;
add_child
(
ins
);
}
...
...
@@ -234,7 +234,7 @@ struct stream_info
{
for
(
auto
ins
:
mod_implicit_deps
.
at
(
top
))
{
assert
(
p
.
has_instruction
(
ins
));
assert
(
m
.
has_instruction
(
ins
));
add_child
(
ins
);
}
}
...
...
@@ -242,12 +242,12 @@ struct stream_info
// move dangling parameter to the front so as not be removed
auto
ins
=
std
::
next
(
last
);
while
(
ins
!=
p
.
end
())
while
(
ins
!=
m
.
end
())
{
auto
next
=
std
::
next
(
ins
);
if
(
ins
->
name
()
==
"@param"
)
{
p
.
move_instruction
(
ins
,
p
.
begin
());
m
.
move_instruction
(
ins
,
m
.
begin
());
}
ins
=
next
;
}
...
...
@@ -364,18 +364,18 @@ struct stream_info
}
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
find_concurrent_instructions
(
module
&
p
)
const
find_concurrent_instructions
(
module
&
m
)
const
{
std
::
unordered_map
<
instruction_ref
,
std
::
vector
<
std
::
vector
<
instruction_ref
>>>
result
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
merge_from
;
dominator_info
di
=
compute_dominator
(
p
);
result
.
reserve
(
p
.
size
());
merge_from
.
reserve
(
p
.
size
());
for
(
auto
ins
:
reverse_iterator_for
(
p
))
dominator_info
di
=
compute_dominator
(
m
);
result
.
reserve
(
m
.
size
());
merge_from
.
reserve
(
m
.
size
());
for
(
auto
ins
:
reverse_iterator_for
(
m
))
{
for
(
auto
&&
arg
:
ins
->
outputs
())
{
if
(
not
p
.
has_instruction
(
arg
))
if
(
not
m
.
has_instruction
(
arg
))
continue
;
if
(
is_merge_point
(
arg
))
merge_from
[
ins
].
insert
(
arg
);
...
...
@@ -415,18 +415,18 @@ struct stream_info
}
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
get_conflicts
(
module
&
p
)
get_conflicts
(
module
&
m
)
{
using
conflict_table_type
=
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
instruction_ref
>>
;
conflict_table_type
conflict_table
;
auto
concur_ins
=
this
->
find_concurrent_instructions
(
p
);
auto
concur_ins
=
this
->
find_concurrent_instructions
(
m
);
// Compute an index for each instruction
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2index
;
std
::
size_t
index_total
=
0
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
ins2index
[
ins
]
=
index_total
++
;
std
::
vector
<
conflict_table_type
>
thread_conflict_tables
(
...
...
@@ -507,21 +507,21 @@ struct stream_info
}
};
void
schedule
::
apply
(
module
&
p
)
const
void
schedule
::
apply
(
module
&
m
)
const
{
if
(
not
enable
)
return
;
stream_info
si
;
si
.
calc_implicit_deps
(
p
);
auto
last
=
std
::
prev
(
p
.
end
());
si
.
calc_implicit_deps
(
m
);
auto
last
=
std
::
prev
(
m
.
end
());
si
.
accumulate_weights
(
last
,
model
);
auto
nstreams
=
si
.
assign_streams
(
p
,
model
.
concurrency
());
si
.
sort
(
p
,
model
.
concurrency
());
auto
nstreams
=
si
.
assign_streams
(
m
,
model
.
concurrency
());
si
.
sort
(
m
,
model
.
concurrency
());
if
(
enabled
(
MIGRAPHX_TRACE_COMPILE
{})
or
enabled
(
MIGRAPHX_TRACE_SCHEDULE
{}))
{
p
.
annotate
(
std
::
cout
,
[
&
](
auto
ins
)
{
m
.
annotate
(
std
::
cout
,
[
&
](
auto
ins
)
{
if
(
ins
->
name
()
==
"@param"
and
not
contains
(
si
.
weights
,
ins
))
return
;
...
...
@@ -548,9 +548,9 @@ void schedule::apply(module& p) const
std
::
unordered_map
<
instruction_ref
,
std
::
size_t
>
ins2wait
;
std
::
unordered_map
<
std
::
size_t
,
std
::
unordered_set
<
std
::
size_t
>>
waited_for
;
std
::
unordered_map
<
instruction_ref
,
std
::
unordered_set
<
std
::
size_t
>>
ins2waited
;
ins2wait
.
reserve
(
p
.
size
());
ins2waited
.
reserve
(
p
.
size
());
for
(
auto
ins
:
iterator_for
(
p
))
ins2wait
.
reserve
(
m
.
size
());
ins2waited
.
reserve
(
m
.
size
());
for
(
auto
ins
:
iterator_for
(
m
))
{
// Only schedule instructions that have a stream
if
(
not
si
.
has_stream
(
ins
))
...
...
@@ -559,7 +559,7 @@ void schedule::apply(module& p) const
// Schedule instruction on the stream
auto
stream
=
si
.
get_stream
(
ins
);
assert
(
stream
<
model
.
concurrency
());
model
.
sched
(
p
,
ins
,
stream
);
model
.
sched
(
m
,
ins
,
stream
);
// Insert wait instructions
if
(
si
.
is_merge_point
(
ins
,
stream
))
{
...
...
@@ -572,14 +572,14 @@ void schedule::apply(module& p) const
if
(
not
contains
(
ins2wait
,
i
))
{
ins2wait
[
i
]
=
wait_id
;
model
.
record
(
p
,
i
,
wait_id
);
model
.
record
(
m
,
i
,
wait_id
);
wait_id
++
;
}
auto
w
=
ins2wait
.
at
(
i
);
// If we already waited for the event on this stream then dont
// insert another wait event
if
(
not
contains
(
waited_for
[
stream
],
w
))
model
.
wait
(
p
,
ins
,
w
);
model
.
wait
(
m
,
ins
,
w
);
// Store the event as waited
waited_for
[
stream
].
insert
(
w
);
// Store all wait events that have been waited on prior to the recorded instruction
...
...
@@ -594,7 +594,7 @@ void schedule::apply(module& p) const
}
// Add memory conflicts
auto
conflict_table
=
si
.
get_conflicts
(
p
);
auto
conflict_table
=
si
.
get_conflicts
(
m
);
for
(
auto
&&
ip
:
conflict_table
)
{
if
(
ip
.
second
.
empty
())
...
...
@@ -602,7 +602,7 @@ void schedule::apply(module& p) const
std
::
vector
<
instruction_ref
>
args
;
args
.
push_back
(
ip
.
first
);
args
.
insert
(
args
.
end
(),
ip
.
second
.
begin
(),
ip
.
second
.
end
());
p
.
insert_instruction
(
std
::
next
(
ip
.
first
),
make_op
(
"identity"
),
args
);
m
.
insert_instruction
(
std
::
next
(
ip
.
first
),
make_op
(
"identity"
),
args
);
}
}
...
...
src/shape.cpp
100755 → 100644
View file @
11e155c2
...
...
@@ -86,6 +86,8 @@ struct shape_impl
return
std
::
accumulate
(
m_lens
.
begin
(),
m_lens
.
end
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
}
std
::
shared_ptr
<
shape_impl
>
copy
()
const
{
return
std
::
make_shared
<
shape_impl
>
(
*
this
);
}
};
const
std
::
vector
<
shape
::
type_t
>&
shape
::
types
()
...
...
@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
shape
::
shape
(
const
std
::
vector
<
shape
>&
subs
)
:
impl
(
std
::
make_shared
<
shape_impl
>
(
subs
))
{}
shape
::
shape
(
std
::
shared_ptr
<
shape_impl
>
pimpl
)
:
impl
(
std
::
move
(
pimpl
))
{}
shape
shape
::
from_permutation
(
type_t
t
,
const
std
::
vector
<
std
::
size_t
>&
l
,
const
std
::
vector
<
int64_t
>&
perm
)
...
...
@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
return
this
->
with_lens
(
this
->
type
(),
l
);
}
shape
shape
::
with_type
(
type_t
t
)
const
{
auto
c
=
impl
->
copy
();
c
->
m_type
=
t
;
return
{
c
};
}
std
::
size_t
shape
::
element_space
()
const
{
return
impl
->
element_space
();
}
std
::
string
shape
::
type_string
()
const
{
return
name
(
this
->
type
());
}
...
...
src/simplify_algebra.cpp
View file @
11e155c2
...
...
@@ -42,7 +42,7 @@ struct find_mul_conv
match
::
name
(
"broadcast"
).
bind
(
"a"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
conv_ins
=
r
.
instructions
[
"conv"
];
...
...
@@ -53,14 +53,14 @@ struct find_mul_conv
if
(
broadcast_op
.
axis
!=
1
)
return
;
auto
new_a
=
p
.
insert_instruction
(
auto
new_a
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
0
},
{
"out_lens"
,
w_ins
->
get_shape
().
lens
()}}),
a_ins
->
inputs
().
front
());
auto
new_mul
=
p
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
new_a
,
w_ins
);
auto
new_conv
=
p
.
insert_instruction
(
auto
new_mul
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
new_a
,
w_ins
);
auto
new_conv
=
m
.
insert_instruction
(
ins
,
conv_ins
->
get_operator
(),
conv_ins
->
inputs
().
front
(),
new_mul
);
p
.
replace_instruction
(
ins
,
new_conv
);
m
.
replace_instruction
(
ins
,
new_conv
);
}
};
...
...
@@ -80,7 +80,7 @@ struct find_mul_slice_conv
match
::
name
(
"broadcast"
)(
match
::
is_constant
()).
bind
(
"a"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
slice_ins
=
r
.
instructions
[
"slice"
];
...
...
@@ -116,38 +116,38 @@ struct find_mul_slice_conv
auto
w_slice_op
=
slice_op
;
w_slice_op
.
axes
=
{
0
};
auto
slice_w_ins
=
p
.
insert_instruction
(
ins
,
w_slice_op
,
w_ins
);
auto
slice_w_ins
=
m
.
insert_instruction
(
ins
,
w_slice_op
,
w_ins
);
auto
new_a
=
p
.
insert_instruction
(
auto
new_a
=
m
.
insert_instruction
(
ins
,
make_op
(
"broadcast"
,
{{
"axis"
,
0
},
{
"out_lens"
,
slice_w_ins
->
get_shape
().
lens
()}}),
a_ins
->
inputs
().
front
());
auto
new_mul
=
p
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
new_a
,
slice_w_ins
);
auto
new_mul
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
new_a
,
slice_w_ins
);
std
::
vector
<
instruction_ref
>
sliced_weights
;
if
(
slice_op
.
starts
.
front
()
!=
0
)
sliced_weights
.
push_back
(
p
.
insert_instruction
(
sliced_weights
.
push_back
(
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
{
0
}},
{
"ends"
,
slice_op
.
starts
}}),
w_ins
));
sliced_weights
.
push_back
(
new_mul
);
int64_t
end_axis
=
w_ins
->
get_shape
().
lens
().
at
(
0
);
if
(
slice_op
.
ends
.
front
()
!=
end_axis
)
sliced_weights
.
push_back
(
p
.
insert_instruction
(
sliced_weights
.
push_back
(
m
.
insert_instruction
(
ins
,
make_op
(
"slice"
,
{{
"axes"
,
{
0
}},
{
"starts"
,
slice_op
.
ends
},
{
"ends"
,
{
end_axis
}}}),
w_ins
));
auto
new_weights
=
p
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
sliced_weights
);
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
sliced_weights
);
auto
new_conv
=
p
.
insert_instruction
(
auto
new_conv
=
m
.
insert_instruction
(
ins
,
conv_ins
->
get_operator
(),
conv_ins
->
inputs
().
front
(),
new_weights
);
assert
(
conv_ins
->
get_shape
()
==
new_conv
->
get_shape
());
auto
slice1
=
p
.
insert_instruction
(
ins
,
slice_op
,
new_conv
);
auto
slice1
=
m
.
insert_instruction
(
ins
,
slice_op
,
new_conv
);
assert
(
ins
->
get_shape
().
lens
()
==
slice1
->
get_shape
().
lens
());
p
.
replace_instruction
(
ins
,
slice1
);
m
.
replace_instruction
(
ins
,
slice1
);
// TODO: Check each slice doesn't overlap and that it occurs after slice_ins
auto
outputs
=
conv_ins
->
outputs
();
for
(
auto
output
:
outputs
)
...
...
@@ -171,7 +171,7 @@ struct find_mul_add
match
::
is_constant
().
bind
(
"a"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_ins
=
r
.
instructions
[
"a"
];
...
...
@@ -179,9 +179,9 @@ struct find_mul_add
auto
x_ins
=
r
.
instructions
[
"x"
];
assert
(
x_ins
!=
b_ins
);
auto
ax_ins
=
p
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
a_ins
,
x_ins
);
auto
ab_ins
=
p
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
a_ins
,
b_ins
);
p
.
replace_instruction
(
ins
,
make_op
(
"add"
),
ax_ins
,
ab_ins
);
auto
ax_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
a_ins
,
x_ins
);
auto
ab_ins
=
m
.
insert_instruction
(
ins
,
make_op
(
"mul"
),
a_ins
,
b_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"add"
),
ax_ins
,
ab_ins
);
}
};
...
...
@@ -193,15 +193,15 @@ struct find_add_lit_broadcast
match
::
either_arg
(
0
,
1
)(
op_lit_broadcast
(
"add"
,
"a"
,
"x"
),
lit_broadcast
().
bind
(
"b"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
auto
a_ins
=
r
.
instructions
[
"a"
];
auto
b_ins
=
r
.
instructions
[
"b"
];
auto
sumab
=
p
.
insert_instruction
(
ins
,
make_op
(
"add"
),
a_ins
,
b_ins
);
p
.
replace_instruction
(
ins
,
make_op
(
"add"
),
x_ins
,
sumab
);
auto
sumab
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
a_ins
,
b_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"add"
),
x_ins
,
sumab
);
}
};
...
...
@@ -213,7 +213,7 @@ struct find_double_add_lit_broadcast
match
::
args
(
op_lit_broadcast
(
"add"
,
"a"
,
"x"
),
op_lit_broadcast
(
"add"
,
"b"
,
"y"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
...
...
@@ -228,17 +228,17 @@ struct find_double_add_lit_broadcast
if
(
a_ins
->
inputs
().
at
(
0
)
->
get_shape
()
!=
b_ins
->
inputs
().
at
(
0
)
->
get_shape
())
return
;
auto
op
=
a_ins
->
get_operator
();
auto
presum
=
p
.
insert_instruction
(
auto
presum
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
a_ins
->
inputs
().
at
(
0
),
b_ins
->
inputs
().
at
(
0
));
sumab
=
p
.
insert_instruction
(
ins
,
op
,
presum
);
sumab
=
m
.
insert_instruction
(
ins
,
op
,
presum
);
}
else
{
sumab
=
p
.
insert_instruction
(
ins
,
make_op
(
"add"
),
a_ins
,
b_ins
);
sumab
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
a_ins
,
b_ins
);
}
auto
sumxy
=
p
.
insert_instruction
(
ins
,
make_op
(
"add"
),
x_ins
,
y_ins
);
p
.
replace_instruction
(
ins
,
make_op
(
"add"
),
sumxy
,
sumab
);
auto
sumxy
=
m
.
insert_instruction
(
ins
,
make_op
(
"add"
),
x_ins
,
y_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"add"
),
sumxy
,
sumab
);
}
};
...
...
@@ -251,7 +251,7 @@ struct find_inner_broadcast
match
::
args
(
match
::
name
(
"broadcast"
).
bind
(
"x"
),
match
::
name
(
"broadcast"
).
bind
(
"y"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
...
...
@@ -263,9 +263,9 @@ struct find_inner_broadcast
if
(
xbroadcast
.
axis
!=
ybroadcast
.
axis
)
return
;
auto
op
=
p
.
insert_instruction
(
auto
op
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
x_ins
->
inputs
().
front
(),
y_ins
->
inputs
().
front
());
p
.
replace_instruction
(
ins
,
xbroadcast
,
op
);
m
.
replace_instruction
(
ins
,
xbroadcast
,
op
);
}
};
...
...
@@ -296,7 +296,7 @@ struct find_concat_op
return
op
.
name
()
==
"broadcast"
or
op
.
attributes
().
contains
(
"pointwise"
);
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
axis
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
()).
axis
;
...
...
@@ -330,12 +330,11 @@ struct find_concat_op
return
j
->
inputs
().
at
(
i
);
});
auto
concat
=
p
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
iaxis
}}),
inputs
);
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
iaxis
}}),
inputs
);
concats
.
push_back
(
concat
);
}
auto
y
=
p
.
insert_instruction
(
ins
,
op
,
concats
);
auto
y
=
m
.
insert_instruction
(
ins
,
op
,
concats
);
return
{
y
};
};
std
::
vector
<
instruction_ref
>
args
;
...
...
@@ -350,9 +349,9 @@ struct find_concat_op
};
group_unique
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
update_args
,
pred
);
if
(
args
.
size
()
==
1
)
p
.
replace_instruction
(
ins
,
args
.
front
());
m
.
replace_instruction
(
ins
,
args
.
front
());
else
p
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
axis
}}),
args
);
m
.
replace_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
axis
}}),
args
);
}
};
...
...
@@ -479,14 +478,14 @@ struct find_splits
return
true
;
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
splits
=
get_splits
(
ins
);
if
(
splits
.
empty
())
return
;
for
(
const
auto
&
group
:
get_split_groups
(
p
,
splits
))
for
(
const
auto
&
group
:
get_split_groups
(
m
,
splits
))
{
auto
start
=
group
.
front
();
auto
split_front
=
splits
.
front
();
...
...
@@ -501,10 +500,10 @@ struct find_splits
std
::
next
(
group
.
begin
()),
group
.
end
(),
[
&
](
auto
i
)
{
return
i
==
start
;
}));
auto
split_idx
=
0
;
instruction_ref
c
=
p
.
end
();
instruction_ref
c
=
m
.
end
();
if
(
start
->
inputs
().
size
()
==
1
)
{
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
op
,
ins
);
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
op
,
ins
);
}
else
if
(
start
->
inputs
().
size
()
==
2
)
{
...
...
@@ -531,7 +530,7 @@ struct find_splits
return
;
for
(
auto
data
:
data_args
)
p
.
move_instructions
(
data
,
ins
);
m
.
move_instructions
(
data
,
ins
);
auto
slice_op
=
any_cast
<
op
::
slice
>
(
splits
.
front
()
->
get_operator
());
assert
(
not
slice_op
.
axes
.
empty
());
...
...
@@ -539,16 +538,16 @@ struct find_splits
return
;
auto
concat_axis
=
slice_op
.
axes
.
front
();
// TODO: Check if axises match
auto
concat
=
p
.
insert_instruction
(
auto
concat
=
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
data_args
);
std
::
vector
<
instruction_ref
>
args
;
args
.
resize
(
2
);
args
[
split_idx
]
=
ins
;
args
[
data_idx
]
=
concat
;
c
=
p
.
insert_instruction
(
std
::
next
(
ins
),
op
,
args
);
c
=
m
.
insert_instruction
(
std
::
next
(
ins
),
op
,
args
);
}
if
(
c
!=
p
.
end
())
if
(
c
!=
m
.
end
())
{
for
(
auto
i
:
group
)
{
...
...
@@ -561,11 +560,11 @@ struct find_splits
if
(
not
contains
({
"reshape"
,
"squeeze"
,
"unsqueeze"
},
output
->
name
()))
continue
;
auto
x
=
p
.
insert_instruction
(
output
,
make_op
(
"contiguous"
),
output
->
inputs
());
p
.
replace_instruction
(
output
,
output
->
get_operator
(),
x
);
m
.
insert_instruction
(
output
,
make_op
(
"contiguous"
),
output
->
inputs
());
m
.
replace_instruction
(
output
,
output
->
get_operator
(),
x
);
}
p
.
replace_instruction
(
i
,
split
->
get_operator
(),
c
);
m
.
replace_instruction
(
i
,
split
->
get_operator
(),
c
);
}
}
}
...
...
@@ -580,7 +579,7 @@ struct find_split_concat
match
::
name
(
"slice"
)(
match
::
all_of
[
match
::
outputs
()](
match
::
name
(
"concat"
)))));
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
...
...
@@ -620,9 +619,9 @@ struct find_split_concat
args
.
erase
(
std
::
next
(
it
),
it
+
splits
.
size
());
if
(
args
.
size
()
==
1
)
p
.
replace_instruction
(
concat
,
args
.
front
());
m
.
replace_instruction
(
concat
,
args
.
front
());
else
p
.
replace_instruction
(
concat
,
concat
->
get_operator
(),
args
);
m
.
replace_instruction
(
concat
,
concat
->
get_operator
(),
args
);
}
};
...
...
@@ -665,7 +664,7 @@ struct find_add_convs
return
x
.
stride
[
0
]
/
y
.
stride
[
0
];
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
a_conv
=
r
.
instructions
[
"a"
];
...
...
@@ -694,7 +693,7 @@ struct find_add_convs
if
(
n
==
0
)
return
;
new_op
=
a_op
;
b_input
=
p
.
insert_instruction
(
b_input
=
m
.
insert_instruction
(
ins
,
make_op
(
"step"
,
{{
"axes"
,
{
2
,
3
}},
{
"steps"
,
{
n
,
n
}}}),
b_input
);
}
else
if
(
b_op
.
stride
<
a_op
.
stride
)
...
...
@@ -703,7 +702,7 @@ struct find_add_convs
if
(
n
==
0
)
return
;
new_op
=
b_op
;
a_input
=
p
.
insert_instruction
(
a_input
=
m
.
insert_instruction
(
ins
,
make_op
(
"step"
,
{{
"axes"
,
{
2
,
3
}},
{
"steps"
,
{
n
,
n
}}}),
a_input
);
}
else
...
...
@@ -714,10 +713,10 @@ struct find_add_convs
}
auto
concat_input
=
p
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
a_input
,
b_input
);
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
a_input
,
b_input
);
auto
concat_weights
=
p
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
a_weights
,
b_weights
);
p
.
replace_instruction
(
ins
,
new_op
,
concat_input
,
concat_weights
);
m
.
insert_instruction
(
ins
,
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
a_weights
,
b_weights
);
m
.
replace_instruction
(
ins
,
new_op
,
concat_input
,
concat_weights
);
}
};
...
...
@@ -738,7 +737,7 @@ struct find_conv_dot_horiz_fusion
{
auto
matcher
()
const
{
return
horiz_conv_dot
();
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
...
...
@@ -786,16 +785,16 @@ struct find_conv_dot_horiz_fusion
}
for
(
auto
arg
:
args
)
p
.
move_instructions
(
arg
,
input
);
m
.
move_instructions
(
arg
,
input
);
// TODO: Check if axises match
auto
concat
=
p
.
insert_instruction
(
input
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
args
);
auto
fused
=
p
.
insert_instruction
(
std
::
next
(
input
),
op
,
input
,
concat
);
m
.
insert_instruction
(
input
,
make_op
(
"concat"
,
{{
"axis"
,
concat_axis
}}),
args
);
auto
fused
=
m
.
insert_instruction
(
std
::
next
(
input
),
op
,
input
,
concat
);
int64_t
offset
=
0
;
for
(
auto
arg
:
range
(
start
,
last
))
{
int64_t
len
=
arg
->
get_shape
().
lens
()[
axis
];
p
.
replace_instruction
(
m
.
replace_instruction
(
arg
,
make_op
(
"slice"
,
{{
"axes"
,
{
axis
}},
{
"starts"
,
{
offset
}},
{
"ends"
,
{
offset
+
len
}}}),
...
...
@@ -816,16 +815,16 @@ struct find_div_const
return
match
::
name
(
"div"
)(
match
::
arg
(
1
)(
match
::
is_constant
().
bind
(
"c"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
recip
=
p
.
insert_instruction
(
std
::
next
(
c_ins
),
make_op
(
"recip"
),
c_ins
);
auto
recip
=
m
.
insert_instruction
(
std
::
next
(
c_ins
),
make_op
(
"recip"
),
c_ins
);
auto
args
=
ins
->
inputs
();
p
.
replace_instruction
(
ins
,
make_op
(
"mul"
),
args
.
front
(),
recip
);
m
.
replace_instruction
(
ins
,
make_op
(
"mul"
),
args
.
front
(),
recip
);
}
};
...
...
@@ -836,16 +835,16 @@ struct find_sub_const
return
match
::
name
(
"sub"
)(
match
::
arg
(
1
)(
match
::
is_constant
().
bind
(
"c"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
neg
=
p
.
insert_instruction
(
std
::
next
(
c_ins
),
make_op
(
"neg"
),
c_ins
);
auto
neg
=
m
.
insert_instruction
(
std
::
next
(
c_ins
),
make_op
(
"neg"
),
c_ins
);
auto
args
=
ins
->
inputs
();
p
.
replace_instruction
(
ins
,
make_op
(
"add"
),
args
.
front
(),
neg
);
m
.
replace_instruction
(
ins
,
make_op
(
"add"
),
args
.
front
(),
neg
);
}
};
...
...
@@ -857,12 +856,12 @@ struct find_rsqrt
match
::
name
(
"sqrt"
)(
match
::
used_once
(),
match
::
args
(
match
::
any
().
bind
(
"x"
)))));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
x_ins
=
r
.
instructions
[
"x"
];
p
.
replace_instruction
(
ins
,
make_op
(
"rsqrt"
),
x_ins
);
m
.
replace_instruction
(
ins
,
make_op
(
"rsqrt"
),
x_ins
);
}
};
...
...
@@ -882,7 +881,7 @@ struct find_split_reshape
.
bind
(
"reshape"
);
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
slc
=
r
.
instructions
[
"slice"
];
auto
rsp
=
r
.
instructions
[
"reshape"
];
...
...
@@ -937,14 +936,14 @@ struct find_split_reshape
rsp_out_lens
[
rsp_axis
]
=
std
::
accumulate
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
std
::
int64_t
{
0
});
// insert the reshape instruction
auto
rsp_ins
=
p
.
insert_instruction
(
auto
rsp_ins
=
m
.
insert_instruction
(
std
::
next
(
input
),
make_op
(
"reshape"
,
{{
"dims"
,
rsp_out_lens
}}),
input
);
// replace the original reshape with slice
int64_t
start
=
0
;
for
(
std
::
size_t
i
=
0
;
i
<
vec_rsp
.
size
();
++
i
)
{
p
.
replace_instruction
(
m
.
replace_instruction
(
vec_rsp
[
i
],
make_op
(
"slice"
,
...
...
@@ -963,7 +962,7 @@ struct find_split_transpose
.
bind
(
"trans"
);
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
slc
=
r
.
instructions
[
"slice"
];
auto
trans
=
r
.
instructions
[
"trans"
];
...
...
@@ -989,14 +988,14 @@ struct find_split_transpose
}
// insert an transpose instruction
auto
tr
=
p
.
insert_instruction
(
auto
tr
=
m
.
insert_instruction
(
std
::
next
(
input
),
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
input
);
// compute the axis in the slice
auto
axis
=
any_cast
<
op
::
slice
>
(
slc
->
get_operator
()).
axes
.
front
();
auto
it
=
std
::
find
(
perm
.
begin
(),
perm
.
end
(),
axis
);
assert
(
it
!=
perm
.
end
());
auto
axis_new
=
static_cast
<
int64_t
>
(
std
::
distance
(
perm
.
begin
(),
it
)
)
;
int64_t
axis_new
=
std
::
distance
(
perm
.
begin
(),
it
);
for
(
auto
in
:
split_outputs
)
{
...
...
@@ -1004,7 +1003,7 @@ struct find_split_transpose
auto
starts
=
oper
.
starts
;
auto
ends
=
oper
.
ends
;
auto
tr_orig
=
in
->
outputs
().
front
();
p
.
replace_instruction
(
m
.
replace_instruction
(
tr_orig
,
make_op
(
"slice"
,
{{
"axes"
,
{
axis_new
}},
{
"starts"
,
starts
},
{
"ends"
,
ends
}}),
tr
);
...
...
@@ -1012,12 +1011,12 @@ struct find_split_transpose
}
};
void
simplify_algebra
::
apply
(
module
&
p
)
const
void
simplify_algebra
::
apply
(
module
&
m
)
const
{
// Run simplifications multiple times
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
match
::
find_matches
(
p
,
match
::
find_matches
(
m
,
find_inner_broadcast
{},
find_double_add_lit_broadcast
{},
find_add_lit_broadcast
{},
...
...
@@ -1034,7 +1033,7 @@ void simplify_algebra::apply(module& p) const
find_splits
{},
find_split_reshape
{},
find_split_transpose
{});
dead_code_elimination
{}.
apply
(
p
);
dead_code_elimination
{}.
apply
(
m
);
}
}
...
...
src/simplify_qdq.cpp
View file @
11e155c2
...
...
@@ -53,7 +53,7 @@ struct match_find_quantizable_ops
match
::
arg
(
1
)(
dequantizelinear_op
(
"x2"
,
"scale2"
)));
}
void
apply
(
module
&
m
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
qop
=
r
.
result
;
auto
q1
=
r
.
instructions
[
"x1"
];
...
...
src/simplify_reshapes.cpp
View file @
11e155c2
...
...
@@ -70,19 +70,19 @@ struct find_reshaper
match
::
any_of
[
match
::
outputs
()](
match
::
name
(
reshaper_names
())));
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
std
::
vector
<
instruction_ref
>
reshapes
{
ins
};
while
(
is_reshaper
(
reshapes
.
back
()))
{
assert
(
!
reshapes
.
back
()
->
inputs
().
empty
());
assert
(
p
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
assert
(
m
.
has_instruction
(
reshapes
.
back
()
->
inputs
().
front
()));
auto
input
=
reshapes
.
back
()
->
inputs
().
front
();
reshapes
.
push_back
(
input
);
}
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
p
.
end
(),
p
.
end
()};
std
::
pair
<
instruction_ref
,
instruction_ref
>
r
{
m
.
end
(),
m
.
end
()};
for
(
auto
start
:
iterator_for
(
reshapes
))
{
auto
last
=
std
::
find_if
(
reshapes
.
rbegin
(),
reshapes
.
rend
(),
[
&
](
auto
&&
i
)
{
...
...
@@ -96,7 +96,7 @@ struct find_reshaper
}
if
(
r
.
first
!=
r
.
second
)
{
p
.
replace_instruction
(
r
.
first
,
r
.
second
);
m
.
replace_instruction
(
r
.
first
,
r
.
second
);
}
}
};
...
...
@@ -117,10 +117,10 @@ struct find_nop_reshapes
return
match
::
name
(
reshapes
)(
match
::
same_shape
(
match
::
arg
(
0
)));
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
p
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
m
.
replace_instruction
(
ins
,
ins
->
inputs
().
front
());
}
};
...
...
@@ -132,7 +132,7 @@ struct find_transpose
match
::
skip_output
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"transpose"
))));
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
x
=
ins
;
...
...
@@ -149,11 +149,11 @@ struct find_transpose
return
;
if
(
is_no_transpose
(
dims
))
{
p
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
m
.
replace_instruction
(
ins
,
t
->
inputs
().
front
());
}
else
{
p
.
replace_instruction
(
m
.
replace_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
dims
}}),
t
->
inputs
().
front
());
}
}
...
...
@@ -223,7 +223,7 @@ struct find_nested_slice
return
result
;
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
slice
=
ins
->
inputs
().
front
();
...
...
@@ -241,7 +241,7 @@ struct find_nested_slice
op
.
starts
.
push_back
(
pp
.
second
.
first
);
op
.
ends
.
push_back
(
pp
.
second
.
second
);
}
p
.
replace_instruction
(
ins
,
op
,
input
);
m
.
replace_instruction
(
ins
,
op
,
input
);
}
};
...
...
@@ -252,7 +252,7 @@ struct find_concat_transpose
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
transpose_shape
()));
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
trans_inputs
=
ins
->
inputs
();
...
...
@@ -279,14 +279,14 @@ struct find_concat_transpose
std
::
vector
<
instruction_ref
>
inputs
;
std
::
transform
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
i
)
{
return
p
.
insert_instruction
(
return
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
permutation
}}),
i
);
});
auto
concat
=
p
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
t
=
p
.
insert_instruction
(
auto
concat
=
m
.
insert_instruction
(
ins
,
op
,
inputs
);
auto
t
=
m
.
insert_instruction
(
ins
,
make_op
(
"transpose"
,
{{
"permutation"
,
ipermutation
}}),
concat
);
assert
(
ins
->
get_shape
().
lens
()
==
t
->
get_shape
().
lens
());
p
.
replace_instruction
(
ins
,
t
);
m
.
replace_instruction
(
ins
,
t
);
}
};
...
...
@@ -303,7 +303,7 @@ struct find_nested_concat
return
op
.
axis
;
}
void
apply
(
module
&
p
,
const
match
::
matcher_result
&
mr
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
axis
=
get_axis
(
ins
);
...
...
@@ -316,9 +316,8 @@ struct find_nested_concat
else
args
.
push_back
(
i
);
}
})(
ins
->
inputs
());
p
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
args
);
m
.
replace_instruction
(
ins
,
ins
->
get_operator
(),
args
);
}
};
...
...
@@ -330,7 +329,7 @@ struct find_resize
match
::
args
(
match
::
name
(
"reshape"
).
bind
(
"data"
),
match
::
is_constant
().
bind
(
"ind"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
ins_rsp
=
r
.
instructions
[
"data"
];
...
...
@@ -418,13 +417,13 @@ struct find_resize
}
auto
in_rsp
=
ins_rsp
->
inputs
().
front
();
auto
rsp_data
=
p
.
insert_instruction
(
auto
rsp_data
=
m
.
insert_instruction
(
ins_rsp
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
in_dims
}}),
in_rsp
);
auto
mb_rsp
=
p
.
insert_instruction
(
auto
mb_rsp
=
m
.
insert_instruction
(
ins_rsp
,
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_dims
}}),
rsp_data
);
auto
std_mb
=
p
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"contiguous"
),
mb_rsp
);
auto
std_mb
=
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"contiguous"
),
mb_rsp
);
std
::
vector
<
int64_t
>
rsp_dims
(
out_lens
.
begin
(),
out_lens
.
end
());
p
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
rsp_dims
}}),
std_mb
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
rsp_dims
}}),
std_mb
);
}
};
...
...
@@ -437,7 +436,7 @@ struct find_where_op
match
::
is_constant
().
bind
(
"ind"
)));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
concat
=
r
.
instructions
[
"data"
];
...
...
@@ -476,11 +475,11 @@ struct find_where_op
if
(
val
)
{
p
.
replace_instruction
(
ins
,
inputs
.
at
(
0
));
m
.
replace_instruction
(
ins
,
inputs
.
at
(
0
));
}
else
{
p
.
replace_instruction
(
ins
,
inputs
.
at
(
1
));
m
.
replace_instruction
(
ins
,
inputs
.
at
(
1
));
}
}
};
...
...
@@ -497,7 +496,7 @@ struct find_reshape_cont
match
::
any
()));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
ins_cont
=
r
.
instructions
[
"cont"
];
...
...
@@ -531,11 +530,11 @@ struct find_reshape_cont
else
{
inputs
.
push_back
(
p
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
in
));
m
.
insert_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
in
));
}
}
auto
out
=
p
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
p
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
out_dims
}}),
out
);
auto
out
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
inputs
);
m
.
replace_instruction
(
ins
,
make_op
(
"reshape"
,
{{
"dims"
,
out_dims
}}),
out
);
}
};
...
...
@@ -565,25 +564,25 @@ struct find_transpose_contiguous_reshaper_unary
match
::
args
(
match_transpose_contiguous_reshaper
()));
}
void
apply
(
module
&
p
,
match
::
matcher_result
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
reshaper_ins
=
r
.
instructions
[
"reshaper_ins"
];
auto
trans_ins
=
r
.
instructions
[
"trans_ins"
];
auto
cont_ins
=
r
.
instructions
[
"cont_ins"
];
auto
unary_op_name
=
ins
->
get_operator
().
name
();
auto
unary_ins
=
p
.
insert_instruction
(
cont_ins
,
make_op
(
unary_op_name
),
trans_ins
);
auto
new_cont_ins
=
p
.
insert_instruction
(
cont_ins
,
make_op
(
"contiguous"
),
unary_ins
);
auto
unary_ins
=
m
.
insert_instruction
(
cont_ins
,
make_op
(
unary_op_name
),
trans_ins
);
auto
new_cont_ins
=
m
.
insert_instruction
(
cont_ins
,
make_op
(
"contiguous"
),
unary_ins
);
// older cont and reshape are removed by deadcode elimination
p
.
replace_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
new_cont_ins
);
m
.
replace_instruction
(
ins
,
reshaper_ins
->
get_operator
(),
new_cont_ins
);
}
};
void
simplify_reshapes
::
apply
(
module
&
p
)
const
void
simplify_reshapes
::
apply
(
module
&
m
)
const
{
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
match
::
find_matches
(
p
,
match
::
find_matches
(
m
,
find_where_op
{},
find_resize
{},
find_reshape_cont
{},
...
...
@@ -595,7 +594,7 @@ void simplify_reshapes::apply(module& p) const
find_nested_slice
{},
find_nested_concat
{},
find_transpose_contiguous_reshaper_unary
{});
dead_code_elimination
{}.
apply
(
p
);
dead_code_elimination
{}.
apply
(
m
);
}
}
...
...
src/targets/cpu/copy.cpp
View file @
11e155c2
...
...
@@ -20,7 +20,6 @@ struct cpu_copy : reduce_dims_base, auto_register_op<cpu_copy>
return
inputs
.
at
(
1
);
}
argument
// cppcheck-suppress constParameter
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
...
...
src/targets/cpu/gather.cpp
View file @
11e155c2
...
...
@@ -26,7 +26,6 @@ struct cpu_gather : auto_register_op<cpu_gather>
}
argument
// cppcheck-suppress constParameter
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
std
::
size_t
nelements
=
output_shape
.
elements
();
...
...
src/targets/cpu/include/migraphx/cpu/parallel.hpp
View file @
11e155c2
...
...
@@ -7,7 +7,16 @@
#ifdef MIGRAPHX_DISABLE_OMP
#include <migraphx/par_for.hpp>
#else
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
#include <omp.h>
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#endif
namespace
migraphx
{
...
...
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
View file @
11e155c2
...
...
@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
bool
is_vectorizable
(
const
Xs
&
...
xs
)
{
return
all_of
({
xs
...},
[](
const
auto
&
s
)
{
if
(
s
.
standard
()
and
(
s
.
lens
().
back
()
%
N
)
==
0
)
return
true
;
if
(
s
.
broadcasted
())
...
...
@@ -320,11 +319,10 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>>
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
auto
s
=
inputs
.
at
(
0
);
const
auto
&
s
=
inputs
.
at
(
0
);
return
{
s
.
type
(),
s
.
lens
()};
}
argument
// cppcheck-suppress constParameter
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
...
...
@@ -358,12 +356,11 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>>
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
auto
s
=
inputs
.
at
(
0
);
const
auto
&
s
=
inputs
.
at
(
0
);
return
{
s
.
type
(),
s
.
lens
()};
}
argument
// cppcheck-suppress constParameter
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
...
...
src/targets/cpu/lowering.cpp
View file @
11e155c2
...
...
@@ -223,7 +223,7 @@ struct cpu_unary2 : auto_register_op<cpu_unary2<Op>>
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
s
=
inputs
.
at
(
0
);
const
auto
&
s
=
inputs
.
at
(
0
);
return
{
s
.
type
(),
s
.
lens
()};
}
...
...
@@ -352,7 +352,7 @@ struct cpu_apply
std
::
transform
(
bind_inputs
.
begin
(),
bind_inputs
.
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
const
auto
&
s
)
{
return
r
.
instructions
.
at
(
s
)
;
});
[
&
](
const
auto
&
s
)
{
return
r
.
instructions
[
s
]
;
});
inputs
.
push_back
(
this
->
insert_allocation
(
ins
,
ins
->
get_shape
()));
modl
->
replace_instruction
(
ins
,
op
,
inputs
);
});
...
...
@@ -460,11 +460,6 @@ struct cpu_apply
if
(
has_op
(
"dnnl::pooling"
)
and
ins
->
get_shape
().
type
()
==
shape
::
type_t
::
float_type
and
not
v
[
"ceil_mode"
].
to
<
bool
>
())
return
replace
(
ins
,
make_op
(
"dnnl::pooling"
,
op
.
to_value
()));
std
::
string
mode
=
v
[
"mode"
].
to
<
std
::
string
>
();
if
(
mode
==
"max"
)
return
replace
(
ins
,
make_op
(
"cpu::pooling_max"
,
v
));
else
if
(
mode
==
"average"
)
return
replace
(
ins
,
make_op
(
"cpu::pooling_average"
,
v
));
return
ins
;
}
...
...
src/targets/cpu/pooling.cpp
View file @
11e155c2
...
...
@@ -11,125 +11,14 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
cpu
{
struct
max_pool
{
static
std
::
string
name
()
{
return
"max"
;
}
template
<
class
T
>
static
T
start
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
static
double
apply
(
double
x
,
double
y
)
{
double
m
=
std
::
max
(
x
,
y
);
return
(
m
);
}
static
double
final
(
double
x
,
std
::
size_t
)
{
return
(
x
);
}
};
struct
avg_pool
{
static
std
::
string
name
()
{
return
"average"
;
}
template
<
class
T
>
static
double
start
()
{
return
0.0
;
}
static
double
apply
(
double
x
,
double
y
)
{
return
x
+
y
;
}
static
double
final
(
double
x
,
std
::
size_t
y
)
{
return
(
y
==
0
)
?
0.0
:
(
x
/
y
);
}
};
template
<
class
Op
>
struct
cpu_pooling
:
auto_register_op
<
cpu_pooling
<
Op
>>
{
cpu_pooling
()
=
default
;
cpu_pooling
(
op
::
pooling
pop
)
:
op
(
std
::
move
(
pop
))
{}
op
::
pooling
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
}
std
::
string
name
()
const
{
return
"cpu::pooling_"
+
Op
::
name
();
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
inputs
.
pop_back
();
return
op
.
normalize_compute_shape
(
inputs
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
visit_all
(
args
.
back
(),
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
auto
in_s
=
input
.
get_shape
();
auto
in_lens
=
in_s
.
lens
();
std
::
vector
<
std
::
size_t
>
vec_len
(
in_lens
.
begin
()
+
2
,
in_lens
.
end
());
par_for
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
idx_o
=
output_shape
.
multi
(
i
);
auto
n_dim
=
idx_o
.
size
();
std
::
vector
<
std
::
size_t
>
win_start
;
std
::
vector
<
std
::
size_t
>
win_size
;
for
(
std
::
size_t
dim
=
2
;
dim
<
n_dim
;
++
dim
)
{
auto
d_2
=
dim
-
2
;
int
start
=
static_cast
<
int
>
(
idx_o
[
dim
]
*
op
.
stride
[
d_2
])
-
static_cast
<
int
>
(
op
.
padding
[
d_2
]);
int
end
=
std
::
min
(
start
+
op
.
lengths
[
d_2
],
in_lens
[
dim
]);
start
=
std
::
max
(
start
,
0
);
win_start
.
push_back
(
start
);
win_size
.
push_back
(
end
-
start
);
}
shape
win_shape
{
output_shape
.
type
(),
win_size
};
auto
pool_size
=
win_shape
.
elements
();
double
acc
=
Op
::
template
start
<
type
>();
shape_for_each
(
win_shape
,
[
&
](
auto
idx_w
)
{
auto
idx
=
idx_o
;
std
::
transform
(
idx_w
.
begin
(),
idx_w
.
end
(),
win_start
.
begin
(),
idx
.
begin
()
+
2
,
[](
auto
ii
,
auto
jj
)
{
return
ii
+
jj
;
});
if
(
std
::
all_of
(
idx
.
begin
()
+
2
,
idx
.
end
(),
[
&
](
auto
ii
)
{
return
ii
>=
0
;
})
and
idx
<
in_lens
)
{
acc
=
Op
::
apply
(
acc
,
input
[
in_s
.
index
(
idx
)]);
}
});
output
[
i
]
=
type
(
Op
::
final
(
acc
,
pool_size
));
});
});
return
args
.
back
();
}
};
template
struct
cpu_pooling
<
avg_pool
>;
template
struct
cpu_pooling
<
max_pool
>;
struct
dnnl_pooling
:
dnnl_extend_op
<
dnnl_pooling
,
dnnl
::
pooling_forward
,
op
::
pooling
>
{
std
::
vector
<
int
>
arg_map
(
int
)
const
{
return
{
MIGRAPHX_DNNL_PREFIX
(
ARG_SRC
)};
}
dnnl
::
pooling_forward
::
desc
get_desc
(
const
std
::
unordered_map
<
int
,
dnnl
::
memory
::
desc
>&
m
)
const
{
auto
algo
=
op
.
mode
==
"max"
?
dnnl
::
algorithm
::
pooling_max
:
dnnl
::
algorithm
::
pooling_avg
;
auto
algo
=
op
.
mode
==
op
::
pooling_mode
::
max
?
dnnl
::
algorithm
::
pooling_max
:
dnnl
::
algorithm
::
pooling_avg
;
auto
kdims
=
op
.
kdims
();
std
::
vector
<
size_t
>
padding_l
(
op
.
padding
.
begin
(),
op
.
padding
.
begin
()
+
kdims
);
std
::
vector
<
size_t
>
padding_r
(
op
.
padding
.
begin
()
+
kdims
,
op
.
padding
.
end
());
...
...
@@ -145,5 +34,6 @@ struct dnnl_pooling : dnnl_extend_op<dnnl_pooling, dnnl::pooling_forward, op::po
};
}
// namespace cpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/CMakeLists.txt
View file @
11e155c2
...
...
@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen)
endif
()
include
(
Embed
)
file
(
GLOB KERNEL_FILES
file
(
GLOB KERNEL_FILES
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
)
...
...
@@ -93,7 +93,7 @@ add_library(migraphx_device
)
add_library
(
compile_for_gpu INTERFACE
)
target_compile_options
(
compile_for_gpu INTERFACE -std=c++17 -fno-gpu-rdc -Wno-cuda-compat -Wno-unused-command-line-argument -Xclang -fallow-half-arguments-and-returns
)
target_link_libraries
(
compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument
)
target_link_libraries
(
compile_for_gpu INTERFACE hip::device -fno-gpu-rdc -Wno-invalid-command-line-argument -Wno-unused-command-line-argument
-Wno-option-ignored
)
check_cxx_compiler_flag
(
"--cuda-host-only -fhip-lambda-host-device -x hip"
HAS_HIP_LAMBDA_HOST_DEVICE
)
if
(
HAS_HIP_LAMBDA_HOST_DEVICE
)
message
(
STATUS
"Enable -fhip-lambda-host-device"
)
...
...
@@ -114,11 +114,13 @@ foreach(KERNEL_FILE ${KERNEL_FILES})
file
(
WRITE
${
CMAKE_CURRENT_BINARY_DIR
}
/kernels/include/migraphx/kernels/
${
KERNEL_BASE_FILE
}
.cpp
"#include <migraphx/kernels/
${
KERNEL_BASE_FILE
}
.hpp>
\n
"
)
target_sources
(
kernel_file_check PRIVATE
${
CMAKE_CURRENT_BINARY_DIR
}
/kernels/include/migraphx/kernels/
${
KERNEL_BASE_FILE
}
.cpp
)
endforeach
()
target_compile_definitions
(
kernel_file_check PRIVATE -DMIGRAPHX_NLOCAL=256
)
target_include_directories
(
kernel_file_check PRIVATE $<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/>
)
target_link_libraries
(
kernel_file_check compile_for_gpu
)
rocm_clang_tidy_check
(
kernel_file_check
)
file
(
GLOB JIT_GPU_SRCS
${
CONFIGURE_DEPENDS
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/jit/*.cpp
)
add_library
(
migraphx_gpu
abs.cpp
analyze_streams.cpp
...
...
@@ -129,10 +131,10 @@ add_library(migraphx_gpu
clip.cpp
code_object_op.cpp
compile_ops.cpp
compile_gen.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_pointwise.cpp
compile_roialign.cpp
compiler.cpp
concat.cpp
convert.cpp
convolution.cpp
...
...
@@ -157,6 +159,7 @@ add_library(migraphx_gpu
nonzero.cpp
pack_args.cpp
pack_int8_args.cpp
prefuse_ops.cpp
pad.cpp
pooling.cpp
quant_convolution.cpp
...
...
@@ -170,6 +173,7 @@ add_library(migraphx_gpu
target.cpp
topk.cpp
write_literals.cpp
${
JIT_GPU_SRCS
}
)
set_target_properties
(
migraphx_gpu PROPERTIES EXPORT_NAME gpu
)
...
...
@@ -330,6 +334,12 @@ target_compile_definitions(migraphx_gpu PRIVATE
"-DMIGRAPHX_EXTRACT_KERNEL=
${
MIGRAPHX_EXTRACT_KERNEL
}
"
"-DMIGRAPHX_USE_HIPRTC=0"
)
if
(
DEFINED CMAKE_CXX_COMPILER_LAUNCHER
)
execute_process
(
COMMAND which
${
CMAKE_CXX_COMPILER_LAUNCHER
}
OUTPUT_VARIABLE MIGRAPHX_HIP_COMPILER_LAUNCHER
)
string
(
STRIP
"
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
MIGRAPHX_HIP_COMPILER_LAUNCHER
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_HIP_COMPILER_LAUNCHER=
${
MIGRAPHX_HIP_COMPILER_LAUNCHER
}
"
)
endif
()
endif
()
# Check miopen find mode api
...
...
src/targets/gpu/analyze_streams.cpp
View file @
11e155c2
...
...
@@ -28,30 +28,30 @@ struct hip_stream_model
bool
is_wait
(
migraphx
::
instruction_ref
ins
)
const
{
return
ins
->
name
()
==
"gpu::wait_event"
;
}
};
stream_model
make_stream_model
(
const
module
&
p
)
stream_model
make_stream_model
(
const
module
&
m
)
{
hip_stream_model
m
;
hip_stream_model
hs
m
;
std
::
size_t
stream
=
0
;
for
(
auto
ins
:
iterator_for
(
p
))
for
(
auto
ins
:
iterator_for
(
m
))
{
if
(
ins
->
name
()
==
"gpu::set_stream"
)
{
auto
v
=
ins
->
get_operator
().
to_value
();
stream
=
v
[
"stream"
].
to
<
std
::
size_t
>
();
m
.
max_stream
=
std
::
max
(
stream
,
m
.
max_stream
);
auto
v
=
ins
->
get_operator
().
to_value
();
stream
=
v
[
"stream"
].
to
<
std
::
size_t
>
();
hs
m
.
max_stream
=
std
::
max
(
stream
,
hs
m
.
max_stream
);
}
if
(
ins
->
get_operator
().
is_context_free
())
continue
;
if
(
contains
({
"hip::hip_allocate_memory"
,
"hip::hip_copy_literal"
,
"@param"
},
ins
->
name
()))
continue
;
m
.
ins2stream
[
ins
]
=
stream
;
hs
m
.
ins2stream
[
ins
]
=
stream
;
}
return
m
;
return
hs
m
;
}
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
p
)
std
::
vector
<
stream_race
>
analyze_streams
(
const
module
&
m
)
{
return
migraphx
::
analyze_streams
(
p
,
make_stream_model
(
p
));
return
migraphx
::
analyze_streams
(
m
,
make_stream_model
(
m
));
}
}
// namespace gpu
...
...
src/targets/gpu/compile_gen.cpp
0 → 100644
View file @
11e155c2
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gen
{
static
std
::
vector
<
std
::
size_t
>
vector_sizes
(
const
std
::
vector
<
shape
>&
inputs
)
{
// If all inputs are half then only use half2
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
return
s
.
type
()
==
shape
::
half_type
;
}))
return
{
2
};
return
{
4
,
2
};
}
vectorize
vectorize
::
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
auto
sizes
=
vector_sizes
(
inputs
);
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
max_vec_size
),
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
auto
stride
=
input
.
strides
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
return
1
;
if
(
len
==
1
and
input
.
elements
()
>
sizes
.
front
())
return
sizes
.
front
();
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
if
(
it
!=
sizes
.
end
())
return
*
it
;
return
1
;
});
return
{
*
std
::
min_element
(
max_vec_size
.
begin
(),
max_vec_size
.
end
()),
axis
};
}
std
::
string
vectorize
::
str
()
const
{
return
"vectorize<"
+
to_string
(
size
)
+
", "
+
to_string
(
axis
)
+
">()"
;
}
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
inputs
.
end
(),
result
.
begin
(),
std
::
size_t
{
0
},
std
::
plus
<>
{},
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
if
(
b
)
return
s
.
bytes
();
return
0
;
});
if
(
bytes
<
max_lds_bytes
)
return
{
result
};
// TODO: Try to partially preload items
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
return
{
result
};
}
std
::
string
preload
::
str
()
const
{
std
::
vector
<
std
::
string
>
bool_strs
;
std
::
transform
(
args
.
begin
(),
std
::
prev
(
args
.
end
()),
std
::
back_inserter
(
bool_strs
),
[](
bool
b
)
{
if
(
b
)
return
"true"
;
return
"false"
;
});
return
"auto_preload<false, "
+
join_strings
(
bool_strs
,
", "
)
+
">(idx)"
;
}
bool
preload
::
is_preloading
()
const
{
return
std
::
accumulate
(
args
.
begin
(),
args
.
end
(),
false
,
std
::
logical_or
<>
{});
}
std
::
size_t
find_fast_axis
(
const
std
::
vector
<
shape
>&
inputs
)
{
auto
permutation
=
find_permutation
(
inputs
);
auto
it
=
std
::
max_element
(
permutation
.
begin
(),
permutation
.
end
());
return
it
-
permutation
.
begin
();
}
std
::
string
make_transformer_args
(
std
::
vector
<
std
::
string
>
transformers
)
{
return
join_strings
(
std
::
move
(
transformers
),
", "
);
}
}
// namespace gen
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/compile_hip.cpp
100755 → 100644
View file @
11e155c2
...
...
@@ -21,6 +21,8 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_DEBUG
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_OPTIMIZE
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_DUMP_ASM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_DUMP_SRC
);
#if MIGRAPHX_USE_HIPRTC
...
...
@@ -178,6 +180,19 @@ bool is_hip_clang_compiler()
return
result
;
}
bool
has_compiler_launcher
()
{
static
const
auto
result
=
fs
::
exists
(
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER_LAUNCHER
));
return
result
;
}
src_compiler
assemble
(
src_compiler
compiler
)
{
compiler
.
out_ext
=
".S"
;
compiler
.
flags
=
replace_string
(
compiler
.
flags
,
" -c"
,
" -S"
);
return
compiler
;
}
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
{
...
...
@@ -210,6 +225,10 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
src_compiler
compiler
;
compiler
.
flags
=
params
;
compiler
.
compiler
=
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER
);
#ifdef MIGRAPHX_HIP_COMPILER_LAUNCHER
if
(
has_compiler_launcher
())
compiler
.
launcher
=
MIGRAPHX_STRINGIZE
(
MIGRAPHX_HIP_COMPILER_LAUNCHER
);
#endif
if
(
is_hcc_compiler
())
compiler
.
process
=
[
&
](
const
fs
::
path
&
obj_path
)
->
fs
::
path
{
...
...
@@ -228,6 +247,22 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
MIGRAPHX_THROW
(
"Missing hsaco"
);
};
if
(
enabled
(
MIGRAPHX_GPU_DUMP_SRC
{}))
{
for
(
const
auto
&
src
:
srcs
)
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
())
<<
std
::
endl
;
}
}
if
(
enabled
(
MIGRAPHX_GPU_DUMP_ASM
{}))
{
std
::
cout
<<
assemble
(
compiler
).
compile
(
srcs
).
data
()
<<
std
::
endl
;
}
return
{
compiler
.
compile
(
srcs
)};
}
...
...
@@ -238,13 +273,6 @@ std::string enum_params(std::size_t count, std::string param)
return
join_strings
(
items
,
","
);
}
std
::
size_t
compute_global
(
std
::
size_t
n
,
std
::
size_t
local
)
{
std
::
size_t
groups
=
(
n
+
local
-
1
)
/
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
256
,
groups
)
*
local
;
return
nglobal
;
}
#endif // MIGRAPHX_USE_HIPRTC
}
// namespace gpu
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
11e155c2
...
...
@@ -93,8 +93,47 @@ const std::vector<std::string>& compiler_warnings()
return
warnings
;
}
void
hip_compile_options
::
set_launch_params
(
const
value
&
v
,
const
std
::
function
<
std
::
size_t
(
std
::
size_t
local
)
>&
compute_global
,
std
::
size_t
default_local
)
{
local
=
v
.
get
(
"local"
,
default_local
);
if
(
v
.
contains
(
"global"
))
global
=
v
.
at
(
"global"
).
to
<
std
::
size_t
>
();
else
global
=
compute_global
(
local
);
}
std
::
function
<
std
::
size_t
(
std
::
size_t
local
)
>
compute_global_for
(
context
&
ctx
,
std
::
size_t
n
,
std
::
size_t
over
)
{
assert
(
over
>
0
);
std
::
size_t
max_global
=
ctx
.
get_current_device
().
get_cu_count
()
*
ctx
.
get_current_device
().
get_max_workitems_per_cu
();
return
[
n
,
over
,
max_global
](
std
::
size_t
local
)
{
std
::
size_t
groups
=
(
n
+
local
-
1
)
/
local
;
std
::
size_t
max_blocks
=
max_global
/
local
;
std
::
size_t
nglobal
=
std
::
min
(
max_blocks
*
over
,
groups
)
*
local
;
return
nglobal
;
};
}
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
)
{
size_t
block_size
=
128
;
while
(
block_size
<=
max_block_size
and
block_size
<=
n
)
block_size
*=
2
;
return
block_size
/
2
;
}
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
)
{
assert
(
options
.
global
>
0
);
assert
(
options
.
local
>
0
);
assert
(
not
options
.
inputs
.
empty
());
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
options
.
virtual_inputs
.
empty
());
std
::
vector
<
src_file
>
srcs
;
std
::
transform
(
migraphx_kernels
().
begin
(),
migraphx_kernels
().
end
(),
...
...
src/targets/gpu/compile_ops.cpp
View file @
11e155c2
...
...
@@ -6,12 +6,14 @@
#include <migraphx/par_for.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/gpu/compile
_pointwise
.hpp>
#include <migraphx/gpu/compile
r
.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_COMPILE_PARALLEL
);
struct
precompile_op
{
operation
op
=
op
::
identity
{};
...
...
@@ -38,41 +40,22 @@ struct precompile_op
MIGRAPHX_REGISTER_OP
(
precompile_op
);
struct
pointwise_compiler
struct
compiled_result
{
std
::
string
name
()
const
{
return
"pointwise"
;
}
operation
apply
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
)
const
{
assert
(
not
ins
->
module_inputs
().
empty
());
auto
*
pm
=
ins
->
module_inputs
().
front
();
return
compile_pointwise
(
ctx
,
to_shapes
(
ins
->
inputs
()),
*
pm
);
}
compiler_replace
replace
;
instruction_ref
ins
;
};
using
compiler_function
=
std
::
function
<
operation
(
context
&
,
instruction_ref
,
operation
)
>
;
template
<
class
T
>
compiler_function
make_compiler_function
(
T
x
)
template
<
class
F
>
void
par_compile
(
std
::
size_t
n
,
F
f
)
{
return
{[
=
](
auto
&&
...
xs
)
{
return
x
.
apply
(
xs
...);
}};
if
(
n
==
0
)
return
;
par_for
(
n
,
n
/
value_of
(
MIGRAPHX_GPU_COMPILE_PARALLEL
{},
n
),
f
);
}
template
<
class
...
Ts
>
std
::
unordered_map
<
std
::
string
,
compiler_function
>
make_compilers
(
Ts
...
xs
)
{
return
{{
xs
.
name
(),
make_compiler_function
(
xs
)}...};
}
struct
compiled_result
{
operation
op
;
instruction_ref
ins
;
};
void
compile_ops
::
apply
(
module
&
m
)
const
{
auto
compilers
=
make_compilers
(
pointwise_compiler
{});
std
::
vector
<
std
::
function
<
compiled_result
()
>>
compiles
;
for
(
auto
ins
:
iterator_for
(
m
))
...
...
@@ -80,15 +63,15 @@ void compile_ops::apply(module& m) const
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
continue
;
operation
preop
=
any_cast
<
precompile_op
>
(
ins
->
get_operator
()).
op
;
assert
(
contains
(
compilers
,
preop
.
name
()));
auto
c
=
compilers
[
preop
.
name
()]
;
compiles
.
emplace_back
([
=
]()
->
compiled_result
{
return
{
c
(
*
ctx
,
ins
,
preop
),
ins
};
});
compiles
.
emplace_back
([
=
]()
->
compiled_result
{
return
{
compile
(
*
ctx
,
ins
,
preop
),
ins
}
;
});
}
std
::
vector
<
compiled_result
>
results
(
compiles
.
size
());
par_
for
(
compiles
.
size
(),
1
,
[
&
](
auto
i
)
{
results
[
i
]
=
compiles
[
i
]();
});
par_
compile
(
compiles
.
size
(),
[
&
](
auto
i
)
{
results
[
i
]
=
compiles
[
i
]();
});
for
(
const
auto
&
cr
:
results
)
{
m
.
replace
_instruction
(
cr
.
ins
,
cr
.
op
,
cr
.
ins
->
inputs
()
);
cr
.
replace
(
m
,
cr
.
ins
);
}
}
...
...
src/targets/gpu/compile_pointwise.cpp
deleted
100644 → 0
View file @
8a9c5bce
#include <migraphx/gpu/compile_pointwise.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
static
const
char
*
const
pointwise_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
__global__ void kernel(${params})
{
pointwise(${lambda}, ${args});
}
}
} // namespace migraphx
int main() {}
)__migraphx__"
;
operation
compile_pointwise
(
context
&
,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
string
&
lambda
,
const
std
::
string
&
preamble
)
{
hip_compile_options
options
;
options
.
global
=
compute_global
(
inputs
.
front
().
elements
());
options
.
local
=
1024
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"lambda"
,
lambda
},
{
"preamble"
,
preamble
}});
return
compile_hip_code_object
(
src
,
options
);
}
operation
compile_pointwise
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
module
m
)
{
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
auto
name
=
g
.
create_function
(
g
.
generate_module
(
m
).
set_attributes
({
"__device__"
}).
set_generic_types
(
m
));
return
compile_pointwise
((
ctx
),
inputs
,
"MIGRAPHX_LIFT("
+
name
+
")"
,
g
.
str
());
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
Prev
1
…
5
6
7
8
9
10
11
12
13
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment