Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2d7f3523
Commit
2d7f3523
authored
Jan 28, 2019
by
Shucai Xiao
Browse files
rewrite the gru operator to support two outputs.
parent
1fbe8c48
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
235 additions
and
201 deletions
+235
-201
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+14
-0
src/include/migraphx/rewrite_gru.hpp
src/include/migraphx/rewrite_gru.hpp
+3
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+12
-3
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+204
-193
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+2
-2
No files found.
src/include/migraphx/operators.hpp
View file @
2d7f3523
...
@@ -1167,6 +1167,20 @@ struct rnn_last_output
...
@@ -1167,6 +1167,20 @@ struct rnn_last_output
}
}
};
};
/// Operator placed by the ONNX parser directly after a `gru` instruction to
/// expose the last GRU output as a separate result.
struct gru_last_output
{
    /// Identifier of this operator; passes match on this string.
    std::string name() const { return "gru_last_output"; }

    /// Computes the output shape from the shape of the single input.
    ///
    /// The result keeps the element type of the input and all of its
    /// dimensions except the first one.
    /// NOTE(review): presumably `has(1)` rejects any argument count other
    /// than one, and the input is expected to have rank >= 1 -- confirm.
    shape compute_shape(std::vector<shape> inputs) const
    {
        check_shapes{inputs, *this}.has(1);

        const shape& src = inputs.front();

        // Drop the first dimension; the remaining ones form the output shape.
        auto out_lens = src.lens();
        out_lens.erase(out_lens.begin());

        return {src.type(), out_lens};
    }
};
}
// namespace op
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/include/migraphx/rewrite_gru.hpp
View file @
2d7f3523
...
@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -13,7 +13,7 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
program
;
struct
program
;
/**
/**
* Rewrite
rnn
to gemm and add.
* Rewrite
gru
to gemm
, mul,
and add.
*/
*/
struct
rewrite_gru
struct
rewrite_gru
{
{
...
@@ -21,14 +21,14 @@ struct rewrite_gru
...
@@ -21,14 +21,14 @@ struct rewrite_gru
void
apply
(
program
&
prog
)
const
;
void
apply
(
program
&
prog
)
const
;
private:
private:
std
::
vector
<
instruction_ref
>
gru_
oper
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
gru_
cell
(
bool
is_forward
,
program
&
prog
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
input
,
instruction_ref
wx
,
instruction_ref
wx
,
instruction_ref
wh
,
instruction_ref
wh
,
instruction_ref
ih
,
instruction_ref
bias
,
instruction_ref
bias
,
instruction_ref
ih
,
int
linear_before_reset
,
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
;
operation
&
actv_func2
)
const
;
...
...
src/onnx/onnx.cpp
View file @
2d7f3523
...
@@ -732,14 +732,14 @@ struct onnx_parser
...
@@ -732,14 +732,14 @@ struct onnx_parser
std
::
move
(
args
));
std
::
move
(
args
));
result
.
push_back
(
hidden_states
);
result
.
push_back
(
hidden_states
);
// second out for the last hidden state
// second out
put
for the last hidden state
auto
last_output
=
prog
.
add_instruction
(
op
::
rnn_last_output
{},
hidden_states
);
auto
last_output
=
prog
.
add_instruction
(
op
::
rnn_last_output
{},
hidden_states
);
result
.
push_back
(
last_output
);
result
.
push_back
(
last_output
);
return
result
;
return
result
;
}
}
instruction_ref
std
::
vector
<
instruction_ref
>
parse_gru
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_gru
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
...
@@ -842,9 +842,18 @@ struct onnx_parser
...
@@ -842,9 +842,18 @@ struct onnx_parser
linear_before_reset
=
parse_value
(
attributes
.
at
(
"linear_before_reset"
)).
at
<
int
>
();
linear_before_reset
=
parse_value
(
attributes
.
at
(
"linear_before_reset"
)).
at
<
int
>
();
}
}
return
prog
.
add_instruction
(
std
::
vector
<
instruction_ref
>
result
;
// first output for concatenation of hidden states
auto
hidden_states
=
prog
.
add_instruction
(
op
::
gru
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
linear_before_reset
},
op
::
gru
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
linear_before_reset
},
std
::
move
(
args
));
std
::
move
(
args
));
result
.
push_back
(
hidden_states
);
// second output for last gru output
auto
last_output
=
prog
.
add_instruction
(
op
::
gru_last_output
{},
hidden_states
);
result
.
push_back
(
last_output
);
return
result
;
}
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
...
...
src/rewrite_gru.cpp
View file @
2d7f3523
...
@@ -10,168 +10,165 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -10,168 +10,165 @@ inline namespace MIGRAPHX_INLINE_NS {
void
rewrite_gru
::
apply
(
program
&
prog
)
const
void
rewrite_gru
::
apply
(
program
&
prog
)
const
{
{
instruction_ref
last_output
=
prog
.
end
();
for
(
auto
ins
:
iterator_for
(
prog
))
for
(
auto
ins
:
iterator_for
(
prog
))
{
{
if
(
ins
->
name
()
!
=
"gru"
)
if
(
ins
->
name
()
=
=
"gru"
)
{
{
continue
;
// could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
}
// the 5th one is undefined and ignored by protobuf. so
// we need to process up to 5 inputs
// could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
auto
args
=
ins
->
inputs
();
// the 5th one is undefined and ignored by protobuf. so
// we need to process up to 5 inputs
shape
seq_shape
=
args
[
0
]
->
get_shape
();
auto
args
=
ins
->
inputs
();
std
::
size_t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
std
::
size_t
batchs
=
seq_shape
.
lens
()[
1
];
shape
seq_shape
=
args
[
0
]
->
get_shape
();
shape
::
type_t
type
=
seq_shape
.
type
();
std
::
size_t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batchs
,
hidden_size
}};
std
::
size_t
batchs
=
seq_shape
.
lens
()[
1
];
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
batchs
,
hidden_size
}};
auto
gru_op
=
any_cast
<
op
::
gru
>
(
ins
->
get_operator
());
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
op
::
gru
::
gru_direction_t
dicrt
=
gru_op
.
direction
;
if
(
dicrt
==
op
::
gru
::
bidirectional
)
auto
gru_op
=
any_cast
<
op
::
gru
>
(
ins
->
get_operator
());
op
::
gru
::
gru_direction_t
dicrt
=
gru_op
.
direction
;
if
(
dicrt
==
op
::
gru
::
bidirectional
)
{
// forward weight
auto
uw_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
uw_forward
);
auto
ur_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ur_forward
);
// reverse weight
auto
uw_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
uw_reverse
);
auto
ur_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ur_reverse
);
// process bias
instruction_ref
bias_forward
,
bias_reverse
;
bias_forward
=
bias_reverse
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
{
{
// forward bias
// w weight matrix
auto
uwb_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
uwb_forward
);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
// backward bias
// r weight matrix
auto
uwb_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
uwb_reverse
);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
}
// bias
// initial hidden state
instruction_ref
bias_forward
,
bias_reverse
;
instruction_ref
ih_forward
,
ih_reverse
;
bias_forward
=
bias_reverse
=
prog
.
end
();
if
(
args
.
size
()
>=
5
)
if
(
args
.
size
()
>=
4
)
{
{
// forward
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
4
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih_forward
);
}
// reverse
// initial hidden state
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
4
]);
instruction_ref
ih_forward
,
ih_reverse
;
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih_reverse
);
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
auto
arg_ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
arg_ih
);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
arg_ih
);
}
else
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret_forward
=
gru_cell
(
true
,
prog
,
ins
,
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
ih_forward
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
auto
ret_reverse
=
gru_cell
(
false
,
prog
,
ins
,
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
ih_reverse
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
3
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_reverse
[
0
]);
// concat the forward and reverse output
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
else
else
{
{
ih_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
bool
is_forward
=
(
dicrt
==
op
::
gru
::
forward
)
?
true
:
false
;
ih_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
// weight matrix
auto
w
=
args
[
1
];
auto
r
=
args
[
2
];
// bias
instruction_ref
bias
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
{
bias
=
args
[
3
];
}
// initial hidden state
instruction_ref
ih
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
ih
=
args
.
size
()
==
6
?
args
[
5
]
:
args
[
4
];
}
else
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
auto
ret
=
gru_cell
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
last_output
=
ret
[
1
];
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
}
}
auto
ret_forward
=
gru_oper
(
true
,
prog
,
ins
,
args
[
0
],
w_forward
,
r_forward
,
ih_forward
,
bias_forward
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
auto
ret_reverse
=
gru_oper
(
false
,
prog
,
ins
,
args
[
0
],
w_reverse
,
r_reverse
,
ih_reverse
,
bias_reverse
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
3
));
// auto final_output =
// prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
ret_reverse
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_reverse
[
0
]);
// concat the forward and reverse output
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
{
ret_forward
[
0
],
ret_reverse
[
0
]});
}
}
else
{
bool
is_forward
=
(
dicrt
==
op
::
gru
::
forward
)
?
true
:
false
;
// weight matrix
auto
w
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
1
]);
auto
r
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
2
]);
// bias
instruction_ref
bias
=
prog
.
end
();
if
(
args
.
size
()
>=
4
)
{
bias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
3
]);
}
// initial hidden state
// rewrite the gru_last_output operator that comes right after the gru
instruction_ref
ih
;
// operator. Intuitively, we can do a slice on its input to get
if
(
args
.
size
()
>=
5
)
// the last output, but it already exists in the gru operator,
{
// so we can just use it as the output here
i
h
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
4
]);
i
f
(
ins
->
name
()
==
"gru_last_output"
)
}
{
else
if
(
last_output
!=
prog
.
end
())
{
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
prog
.
replace_instruction
(
ins
,
op
::
identity
{},
last_output
);
last_output
=
prog
.
end
();
}
}
auto
ret
=
gru_oper
(
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
ih
,
bias
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
}
}
}
}
}
}
std
::
vector
<
instruction_ref
>
rewrite_gru
::
gru_
oper
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
rewrite_gru
::
gru_
cell
(
bool
is_forward
,
program
&
prog
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
input
,
instruction_ref
w
,
instruction_ref
w
,
instruction_ref
r
,
instruction_ref
r
,
instruction_ref
ih
,
instruction_ref
bias
,
instruction_ref
bias
,
instruction_ref
ih
,
int
linear_before_reset
,
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
operation
&
actv_func2
)
const
{
{
instruction_ref
hidden_out
,
final
_out
;
instruction_ref
hidden_out
,
last
_out
;
long
seq_len
=
static_cast
<
long
>
(
input
->
get_shape
().
lens
()[
0
]);
long
seq_len
=
static_cast
<
long
>
(
input
->
get_shape
().
lens
()[
0
]);
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
1
]);
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]);
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
migraphx
::
shape
s
(
input
->
get_shape
().
type
(),
migraphx
::
shape
s
(
input
->
get_shape
().
type
(),
{
input
->
get_shape
().
lens
()[
1
],
static_cast
<
std
::
size_t
>
(
hs
)});
{
input
->
get_shape
().
lens
()[
1
],
static_cast
<
std
::
size_t
>
(
hs
)});
...
@@ -180,122 +177,136 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -180,122 +177,136 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
// weight matrix
// weight matrix
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
wz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
w
);
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
twz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wz
);
auto
wz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sw
);
auto
wr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
w
);
auto
tran_wz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wz
);
auto
twr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wr
);
auto
wh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
w
);
auto
wr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sw
);
auto
twh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wh
);
auto
tran_wr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wr
);
auto
rz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
r
);
auto
wh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sw
);
auto
trz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rz
);
auto
tran_wh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wh
);
auto
rr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
r
);
auto
trr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rr
);
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
r
);
auto
rz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sr
);
auto
trh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
auto
tran_rz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rz
);
auto
rr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sr
);
auto
tran_rr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rr
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sr
);
auto
tran_rh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
// initial states
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
// bias
instruction_ref
br_bz
,
br_br
,
br_wbh
,
br_rbh
,
br_bh
;
instruction_ref
br
cst
_bz
,
br
cst
_br
,
br
cst
_wbh
,
br
cst
_rbh
,
br
cst
_bh
;
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
bias
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
bias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
bias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
br_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
wbh
);
auto
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
sbias
);
brcst_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
wbh
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
bias
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
s
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
s
bias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
bias
);
auto
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
s
bias
);
br_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
rbh
);
br
cst
_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
s
ih
->
get_shape
()},
rbh
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
br_bz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
bz
);
brcst_bz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
bz
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
br_br
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
br
);
brcst_br
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
br
);
auto
bh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbh
,
rbh
);
auto
bh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbh
,
rbh
);
br_bh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
bh
);
br
cst
_bh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
s
ih
->
get_shape
()},
bh
);
}
}
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
{
{
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto
xwz
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twz
);
auto
x
t_
wz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
t
ran_
wz
);
auto
hrz
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trz
);
auto
h
t_
rz
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
s
ih
,
t
ran_
rz
);
auto
x
whr
_z
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwz
t
,
hrz
t
);
auto
x
ht
_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
wz
,
h
t_
rz
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
x
whr
_z
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
whr
_z
t
,
br_bz
);
x
ht
_z
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
ht
_z
,
br
cst
_bz
);
}
}
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
whr
_z
t
);
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
ht
_z
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto
xwr
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twr
);
auto
x
t_
wr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
t
ran_
wr
);
auto
hrr
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trr
);
auto
h
t_
rr
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
s
ih
,
t
ran_
rr
);
auto
x
whr
_r
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwr
t
,
hrr
t
);
auto
x
ht
_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
wr
,
h
t_
rr
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
x
whr
_r
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
whr
_r
t
,
br_br
);
x
ht
_r
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
ht
_r
,
br
cst
_br
);
}
}
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
whr
_r
t
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
x
ht
_r
);
instruction_ref
x
whh_rt
;
instruction_ref
x
ht_h
;
if
(
linear_before_reset
==
0
)
if
(
linear_before_reset
==
0
)
{
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
xwh
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
x
t_
wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
t
ran_
wh
);
auto
rt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih
);
auto
rt_ht
1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
s
ih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
,
trh
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
1
,
t
ran_
rh
);
x
whh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwh
t
,
rt_rh
);
x
ht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
x
whh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
whh_rt
,
br_bh
);
x
ht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
ht_h
,
br
cst
_bh
);
}
}
}
}
else
else
{
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
xwh
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
x
t_
wh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
t
ran_
wh
);
auto
i
h_rh
t
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trh
);
auto
h
t1
_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
s
ih
,
t
ran_
rh
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
i
h_rh
t
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
i
h_rh
t
,
br_rbh
);
h
t1
_rh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
h
t1
_rh
,
br
cst
_rbh
);
}
}
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
i
h_rh
t
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
h
t1
_rh
);
x
whh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwh
t
,
rt_rh
);
x
ht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
t_
wh
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
x
whh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
whh_rt
,
br_wbh
);
x
ht_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x
ht_h
,
br
cst
_wbh
);
}
}
}
}
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
x
whh_rt
);
auto
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
x
ht_h
);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto
z
1
t
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
one_minus_
zt
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
z1t
ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
z
1
t
,
ht
);
auto
one_minus_zt_
ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
one_minus_
zt
,
ht
);
auto
ztht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
zt
,
ih
);
auto
zt
_
ht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
zt
,
s
ih
);
ih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
z1t
ht
,
ztht1
);
s
ih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
one_minus_zt_
ht
,
zt
_
ht1
);
final
_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
}},
ih
);
last
_out
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
0
}},
s
ih
);
if
(
is_forward
)
if
(
is_forward
)
{
{
hidden_out
=
(
seq_index
==
0
)
hidden_out
=
(
seq_index
==
0
)
?
final
_out
?
last
_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
final
_out
);
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
last
_out
);
}
}
else
else
{
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
final
_out
?
last
_out
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
final
_out
,
hidden_out
);
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
last
_out
,
hidden_out
);
}
}
seq_index
=
is_forward
?
(
seq_index
+
1
)
:
(
seq_index
-
1
);
seq_index
=
is_forward
?
(
seq_index
+
1
)
:
(
seq_index
-
1
);
}
}
std
::
vector
<
instruction_ref
>
out_args
;
std
::
vector
<
instruction_ref
>
out_args
;
out_args
.
push_back
(
hidden_out
);
out_args
.
push_back
(
hidden_out
);
out_args
.
push_back
(
final
_out
);
out_args
.
push_back
(
last
_out
);
return
out_args
;
return
out_args
;
}
}
...
...
src/rewrite_rnn.cpp
View file @
2d7f3523
...
@@ -26,7 +26,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -26,7 +26,7 @@ void rewrite_rnn::apply(program& prog) const
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ih_shape
{
type
,
{
batch_size
,
hidden_size
}};
migraphx
::
shape
ih_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
std
::
vector
<
char
>
data
(
ih_shape
.
bytes
(),
0
);
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
...
@@ -133,7 +133,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -133,7 +133,7 @@ void rewrite_rnn::apply(program& prog) const
}
}
// rewrite the rnn_last_output operator that right after the rnn
// rewrite the rnn_last_output operator that right after the rnn
// operator. Intuitively, we can do a slice on
the
input to get
// operator. Intuitively, we can do a slice on
its
input to get
// the last output, but it is already existed in the rnn operator,
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
// so we can just use it as the output here
if
(
ins
->
name
()
==
"rnn_last_output"
)
if
(
ins
->
name
()
==
"rnn_last_output"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment