Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2d7f3523
Commit
2d7f3523
authored
Jan 28, 2019
by
Shucai Xiao
Browse files
rewrite the gru operator to support two outputs.
parent
1fbe8c48
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
235 additions
and
201 deletions
+235
-201
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+14
-0
src/include/migraphx/rewrite_gru.hpp
src/include/migraphx/rewrite_gru.hpp
+3
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+12
-3
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+204
-193
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+2
-2
No files found.
src/include/migraphx/operators.hpp
View file @
2d7f3523
...
...
@@ -1167,6 +1167,20 @@ struct rnn_last_output
}
};
struct
gru_last_output
{
std
::
string
name
()
const
{
return
"gru_last_output"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
dims
=
inputs
[
0
].
lens
();
// remove the first dimension, remaing are output shape
dims
.
erase
(
dims
.
begin
());
return
{
inputs
[
0
].
type
(),
dims
};
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/rewrite_gru.hpp
View file @
2d7f3523
...
...
@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
program
;
/**
* Rewrite
rnn
to gemm and add.
* Rewrite
gru
to gemm
, mul,
and add.
*/
struct
rewrite_gru
{
...
...
@@ -21,14 +21,14 @@ struct rewrite_gru
void
apply
(
program
&
prog
)
const
;
private:
std
::
vector
<
instruction_ref
>
gru_
oper
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
gru_
cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
wx
,
instruction_ref
wh
,
instruction_ref
ih
,
instruction_ref
bias
,
instruction_ref
ih
,
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
;
...
...
src/onnx/onnx.cpp
View file @
2d7f3523
...
...
@@ -732,14 +732,14 @@ struct onnx_parser
std
::
move
(
args
));
result
.
push_back
(
hidden_states
);
// second out for the last hidden state
// second out
put
for the last hidden state
auto
last_output
=
prog
.
add_instruction
(
op
::
rnn_last_output
{},
hidden_states
);
result
.
push_back
(
last_output
);
return
result
;
}
instruction_ref
std
::
vector
<
instruction_ref
>
parse_gru
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
...
...
@@ -842,9 +842,18 @@ struct onnx_parser
linear_before_reset
=
parse_value
(
attributes
.
at
(
"linear_before_reset"
)).
at
<
int
>
();
}
return
prog
.
add_instruction
(
std
::
vector
<
instruction_ref
>
result
;
// first output for concatenation of hidden states
auto
hidden_states
=
prog
.
add_instruction
(
op
::
gru
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
linear_before_reset
},
std
::
move
(
args
));
result
.
push_back
(
hidden_states
);
// second output for last gru output
auto
last_output
=
prog
.
add_instruction
(
op
::
gru_last_output
{},
hidden_states
);
result
.
push_back
(
last_output
);
return
result
;
}
void
parse_from
(
std
::
istream
&
is
)
...
...
src/rewrite_gru.cpp
View file @
2d7f3523
...
...
@@ -10,168 +10,165 @@ inline namespace MIGRAPHX_INLINE_NS {
void
rewrite_gru
::
apply
(
program
&
prog
)
const
{
instruction_ref
last_output
=
prog
.
end
();
for
(
auto
ins
:
iterator_for
(
prog
))
{
if
(
ins
->
name
()
!
=
"gru"
)
if
(
ins
->
name
()
=
=
"gru"
)
{
continue
;
}
// could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
// the 5th one is undefined and ignored by protobuf. so
// we need to process up to 5 inputs
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
std
::
size_t
batchs
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
batchs
,
hidden_size
}};
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
auto
gru_op
=
any_cast
<
op
::
gru
>
(
ins
->
get_operator
());
op
::
gru
::
gru_direction_t
dicrt
=
gru_op
.
direction
;
if
(
dicrt
==
op
::
gru
::
bidirectional
)
{
// forward weight
auto
uw_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
uw_forward
);
auto
ur_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ur_forward
);
// reverse weight
auto
uw_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
uw_reverse
);
auto
ur_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ur_reverse
);
// process bias
instruction_ref
bias_forward
,
bias_reverse
;
bias_forward
=
bias_reverse
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
// could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
// the 5th one is undefined and ignored by protobuf. so
// we need to process up to 5 inputs
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
std
::
size_t
batchs
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batchs
,
hidden_size
}};
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
auto
gru_op
=
any_cast
<
op
::
gru
>
(
ins
->
get_operator
());
op
::
gru
::
gru_direction_t
dicrt
=
gru_op
.
direction
;
if
(
dicrt
==
op
::
gru
::
bidirectional
)
{
// forward bias
auto
uwb_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
uwb_forward
);
// backward bias
auto
uwb_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
uwb_reverse
);
}
// intial hidden state
instruction_ref
ih_forward
,
ih_reverse
;
if
(
args
.
size
()
>=
5
)
{
// forward
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
4
]);
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih_forward
);
// reverse
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
4
]);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih_reverse
);
// w weight matrix
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
// r weight matrix
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
// bias
instruction_ref
bias_forward
,
bias_reverse
;
bias_forward
=
bias_reverse
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
{
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
}
// intial hidden state
instruction_ref
ih_forward
,
ih_reverse
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
auto
arg_ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
arg_ih
);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
arg_ih
);
}
else
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret_forward
=
gru_cell
(
true
,
prog
,
ins
,
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
ih_forward
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
auto
ret_reverse
=
gru_cell
(
false
,
prog
,
ins
,
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
ih_reverse
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
3
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_reverse
[
0
]);
// concat the forward and reverse output
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
else
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
bool
is_forward
=
(
dicrt
==
op
::
gru
::
forward
)
?
true
:
false
;
// weight matrix
auto
w
=
args
[
1
];
auto
r
=
args
[
2
];
// bias
instruction_ref
bias
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
{
bias
=
args
[
3
];
}
// intial hidden state
instruction_ref
ih
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
ih
=
args
.
size
()
==
6
?
args
[
5
]
:
args
[
4
];
}
else
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret
=
gru_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
last_output
=
ret
[
1
];
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
}
auto
ret_forward
=
gru_oper
(
true
,
prog
,
ins
,
args
[
0
],
w_forward
,
r_forward
,
ih_forward
,
bias_forward
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
auto
ret_reverse
=
gru_oper
(
false
,
prog
,
ins
,
args
[
0
],
w_reverse
,
r_reverse
,
ih_reverse
,
bias_reverse
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
3
));
// auto final_output =
// prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_reverse
[
0
]);
// concat the forward and reverse output
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
else
{
bool
is_forward
=
(
dicrt
==
op
::
gru
::
forward
)
?
true
:
false
;
// weight matrix
auto
w
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
1
]);
auto
r
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
2
]);
// bias
instruction_ref
bias
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
{
bias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
3
]);
}
// intial hidden state
instruction_ref
ih
;
if
(
args
.
size
()
>=
5
)
{
i
h
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
4
]);
}
else
// rewrite the gru_last_output operator that right after the gru
// operator. Intuitively, we can do a slice on its input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
i
f
(
ins
->
name
()
==
"gru_last_output"
)
{
if
(
last_output
!=
prog
.
end
())
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
prog
.
replace_instruction
(
ins
,
op
::
identity
{},
last_output
);
last_output
=
prog
.
end
();
}
auto
ret
=
gru_oper
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
ih
,
bias
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
}
}
}
std
::
vector
<
instruction_ref
>
rewrite_gru
::
gru_
oper
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
rewrite_gru
::
gru_
cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
w
,
instruction_ref
r
,
instruction_ref
ih
,
instruction_ref
bias
,
instruction_ref
ih
,
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
{
instruction_ref
hidden_out
,
final
_out
;
instruction_ref
hidden_out
,
last
_out
;
long
seq_len
=
static_cast
<
long
>
(
input
->
get_shape
().
lens
()[
0
]);
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
1
]);
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]);
migraphx
::
shape
s
(
input
->
get_shape
().
type
(),
{
input
->
get_shape
().
lens
()[
1
],
static_cast
<
std
::
size_t
>
(
hs
)});
...
...
@@ -180,122 +177,136 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
// weight matrix
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
wz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
w
);
auto
twz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wz
);
auto
wr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
w
);
auto
twr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wr
);
auto
wh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
w
);
auto
twh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wh
);
auto
rz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
r
);
auto
trz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rz
);
auto
rr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
r
);
auto
trr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rr
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
r
);
auto
trh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
wz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sw
);
auto
tran_wz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wz
);
auto
wr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sw
);
auto
tran_wr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wr
);
auto
wh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sw
);
auto
tran_wh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wh
);
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
rz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sr
);
auto
tran_rz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rz
);
auto
rr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sr
);
auto
tran_rr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rr
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
tran_rh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
// initial states
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
instruction_ref
br_bz
,
br_br
,
br_wbh
,
br_rbh
,
br_bh
;
instruction_ref
br
cst
_bz
,
br
cst
_br
,
br
cst
_wbh
,
br
cst
_rbh
,
br
cst
_bh
;
if
(
bias
!=
prog
.
end
())
{
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
bias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
bias
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
bias
);
br_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
wbh
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
brcst_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wbh
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
bias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
bias
);
br_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
rbh
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
s
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
s
bias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
s
bias
);
br
cst
_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
s
ih
->
get_shape
()},
rbh
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
br_bz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
bz
);
brcst_bz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
bz
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
br_br
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
br
);
brcst_br
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
br
);
auto
bh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbh
,
rbh
);
br_bh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
bh
);
br
cst
_bh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
s
ih
->
get_shape
()},
bh
);
}
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
{
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto
xwz
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twz
);
auto
hrz
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trz
);
auto
x
whr
_z
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwz
t
,
hrz
t
);
auto
x
t_
wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
t
ran_
wz
);
auto
h
t_
rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
s
ih
,
t
ran_
rz
);
auto
x
ht
_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
wz
,
h
t_
rz
);
if
(
bias
!=
prog
.
end
())
{
x
whr
_z
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
whr
_z
t
,
br_bz
);
x
ht
_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
ht
_z
,
br
cst
_bz
);
}
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
whr
_z
t
);
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
ht
_z
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto
xwr
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twr
);
auto
hrr
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trr
);
auto
x
whr
_r
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwr
t
,
hrr
t
);
auto
x
t_
wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
t
ran_
wr
);
auto
h
t_
rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
s
ih
,
t
ran_
rr
);
auto
x
ht
_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
wr
,
h
t_
rr
);
if
(
bias
!=
prog
.
end
())
{
x
whr
_r
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
whr
_r
t
,
br_br
);
x
ht
_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
ht
_r
,
br
cst
_br
);
}
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
whr
_r
t
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
ht
_r
);
instruction_ref
x
whh_rt
;
instruction_ref
x
ht_h
;
if
(
linear_before_reset
==
0
)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
xwh
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
rt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
,
trh
);
x
whh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwh
t
,
rt_rh
);
auto
x
t_
wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
t
ran_
wh
);
auto
rt_ht
1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
s
ih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
1
,
t
ran_
rh
);
x
ht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
{
x
whh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
whh_rt
,
br_bh
);
x
ht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
ht_h
,
br
cst
_bh
);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
xwh
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
i
h_rh
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trh
);
auto
x
t_
wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
t
ran_
wh
);
auto
h
t1
_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
s
ih
,
t
ran_
rh
);
if
(
bias
!=
prog
.
end
())
{
i
h_rh
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
i
h_rh
t
,
br_rbh
);
h
t1
_rh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
h
t1
_rh
,
br
cst
_rbh
);
}
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
i
h_rh
t
);
x
whh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwh
t
,
rt_rh
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
h
t1
_rh
);
x
ht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
{
x
whh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
whh_rt
,
br_wbh
);
x
ht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
ht_h
,
br
cst
_wbh
);
}
}
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
x
whh_rt
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
x
ht_h
);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto
z
1
t
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
z1t
ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
z
1
t
,
ht
);
auto
ztht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
zt
,
ih
);
ih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
z1t
ht
,
ztht1
);
final
_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
}},
ih
);
auto
one_minus_
zt
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
one_minus_zt_
ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
one_minus_
zt
,
ht
);
auto
zt
_
ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
zt
,
s
ih
);
s
ih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
one_minus_zt_
ht
,
zt
_
ht1
);
last
_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
}},
s
ih
);
if
(
is_forward
)
{
hidden_out
=
(
seq_index
==
0
)
?
final
_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
final
_out
);
?
last
_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last
_out
);
}
else
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
final
_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
final
_out
,
hidden_out
);
?
last
_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last
_out
,
hidden_out
);
}
seq_index
=
is_forward
?
(
seq_index
+
1
)
:
(
seq_index
-
1
);
}
std
::
vector
<
instruction_ref
>
out_args
;
out_args
.
push_back
(
hidden_out
);
out_args
.
push_back
(
final
_out
);
out_args
.
push_back
(
last
_out
);
return
out_args
;
}
...
...
src/rewrite_rnn.cpp
View file @
2d7f3523
...
...
@@ -26,7 +26,7 @@ void rewrite_rnn::apply(program& prog) const
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
batch_size
,
hidden_size
}};
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
...
...
@@ -133,7 +133,7 @@ void rewrite_rnn::apply(program& prog) const
}
// rewrite the rnn_last_output operator that right after the rnn
// operator. Intuitively, we can do a slice on
the
input to get
// operator. Intuitively, we can do a slice on
its
input to get
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
if
(
ins
->
name
()
==
"rnn_last_output"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment