Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6d0742b6
Commit
6d0742b6
authored
Jan 23, 2019
by
Shucai Xiao
Browse files
save implementation of gru operator
parent
67491293
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
403 additions
and
2 deletions
+403
-2
src/include/migraphx/rewrite_gru.hpp
src/include/migraphx/rewrite_gru.hpp
+38
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+73
-2
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+292
-0
No files found.
src/include/migraphx/rewrite_gru.hpp
0 → 100644
View file @
6d0742b6
#ifndef MIGRAPHX_GUARD_RTGLIB_REWRITE_GRU_HPP
#define MIGRAPHX_GUARD_RTGLIB_REWRITE_GRU_HPP
#include <string>
#include <vector>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
program
;
/**
 * Rewrite gru instructions into slice, transpose, dot, add, sub, mul
 * and activation instructions.
 */
struct rewrite_gru
{
    /// Name of this pass.
    std::string name() const { return "rewrite_gru"; }

    /// Replace every instruction named "gru" in prog with its expanded form.
    void apply(program& prog) const;

    private:
    /**
     * Expand the computation of a single direction of a gru.
     *
     * The declaration mirrors the out-of-line definition in rewrite_gru.cpp
     * (same name, same parameter list), which is what apply() calls.
     *
     * @param is_forward true to walk the sequence from the first to the last
     *                   step, false to walk it from the last to the first
     * @param prog program that receives the new instructions
     * @param ins the gru instruction; new instructions are inserted before it
     * @param input the input sequence
     * @param w input weights of this direction
     * @param r recurrence weights of this direction
     * @param ih initial hidden state of this direction
     * @param bias bias of this direction, or prog.end() when there is none
     * @param linear_before_reset selects which of the two hidden-gate
     *                            equations is generated (0 or non-zero)
     * @param actv_func1 activation used for the update and reset gates
     * @param actv_func2 activation used for the hidden gate
     * @return two instructions: the accumulated hidden states and the
     *         hidden state of the last processed step
     */
    std::vector<instruction_ref> gru_oper(bool is_forward,
                                          program& prog,
                                          instruction_ref ins,
                                          instruction_ref input,
                                          instruction_ref w,
                                          instruction_ref r,
                                          instruction_ref ih,
                                          instruction_ref bias,
                                          int linear_before_reset,
                                          operation& actv_func1,
                                          operation& actv_func2) const;
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/onnx/onnx.cpp
View file @
6d0742b6
...
@@ -86,6 +86,7 @@ struct onnx_parser
...
@@ -86,6 +86,7 @@ struct onnx_parser
add_mem_op
(
"ConstantFill"
,
&
onnx_parser
::
parse_constant_fill
);
add_mem_op
(
"ConstantFill"
,
&
onnx_parser
::
parse_constant_fill
);
add_mem_op
(
"Transpose"
,
&
onnx_parser
::
parse_transpose
);
add_mem_op
(
"Transpose"
,
&
onnx_parser
::
parse_transpose
);
add_mem_op
(
"RNN"
,
&
onnx_parser
::
parse_rnn
);
add_mem_op
(
"RNN"
,
&
onnx_parser
::
parse_rnn
);
add_mem_op
(
"GRU"
,
&
onnx_parser
::
parse_gru
);
// init the activation function map
// init the activation function map
init_actv_func
();
init_actv_func
();
...
@@ -651,8 +652,7 @@ struct onnx_parser
...
@@ -651,8 +652,7 @@ struct onnx_parser
parse_rnn
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
parse_rnn
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
{
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
migraphx
::
shape
w_shape
=
args
[
1
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
1
];
std
::
size_t
hidden_size
=
w_shape
.
lens
()[
1
];
if
(
contains
(
attributes
,
"hidden_size"
))
if
(
contains
(
attributes
,
"hidden_size"
))
{
{
...
@@ -702,6 +702,77 @@ struct onnx_parser
...
@@ -702,6 +702,77 @@ struct onnx_parser
std
::
move
(
args
));
std
::
move
(
args
));
}
}
// Parse an ONNX GRU node into an op::gru instruction.
//
// Attributes read: hidden_size (mandatory), direction, activations, clip and
// linear_before_reset. All node inputs in args are forwarded unchanged to the
// new instruction.
//
// Throws (MIGRAPHX_THROW) when hidden_size is missing, when fewer activation
// names are supplied than the direction needs, or when an activation name is
// not a key of actv_funcs.
instruction_ref
parse_gru(const std::string&, attribute_map attributes, std::vector<instruction_ref> args)
{
    // The attribute is mandatory, so no value is inferred from the weights.
    if(!contains(attributes, "hidden_size"))
    {
        MIGRAPHX_THROW("GRU: hidden size attribute missing");
    }
    std::size_t hidden_size = parse_value(attributes.at("hidden_size")).at<int>();

    // Any value other than "bidirectional" or "reverse" is treated as forward.
    std::string direction{"forward"};
    if(contains(attributes, "direction"))
    {
        direction = attributes.at("direction").s();
    }

    op::gru::gru_direction_t dirct = op::gru::forward;
    if(direction == "bidirectional")
    {
        dirct = op::gru::bidirectional;
    }
    else if(direction == "reverse")
    {
        dirct = op::gru::reverse;
    }

    // Defaults: first name for the update/reset gates, second for the hidden gate.
    std::vector<std::string> act_funcs = {"sigmoid", "tanh"};
    if(contains(attributes, "activations"))
    {
        // Copy whatever number of names the node carries; the count is
        // validated below instead of blindly reading elements 0 and 1.
        const auto& actv_attr = attributes.at("activations");
        act_funcs.clear();
        for(int i = 0; i < actv_attr.strings_size(); ++i)
        {
            act_funcs.push_back(actv_attr.strings(i));
        }
    }

    // One pair per direction. The rewrite of a bidirectional gru reads a
    // second pair for the reverse direction, so reuse the first pair when the
    // node only names one.
    const std::size_t required = (dirct == op::gru::bidirectional) ? 4 : 2;
    if(required == 4 && act_funcs.size() == 2)
    {
        const std::vector<std::string> forward_pair = act_funcs;
        act_funcs.insert(act_funcs.end(), forward_pair.begin(), forward_pair.end());
    }

    if(act_funcs.size() < required)
    {
        MIGRAPHX_THROW("GRU: wrong activation function attribute");
    }
    // Names beyond the required ones are ignored.
    act_funcs.resize(required);

    std::vector<operation> actv_ops;
    actv_ops.reserve(act_funcs.size());
    for(const auto& fn_name : act_funcs)
    {
        if(actv_funcs.count(fn_name) == 0)
        {
            MIGRAPHX_THROW("GRU: activation function " + fn_name + " not supported");
        }
        actv_ops.push_back(actv_funcs[fn_name]);
    }

    // To be added later
    float clip = 0.0;
    if(contains(attributes, "clip"))
    {
        clip = parse_value(attributes.at("clip")).at<float>();
    }

    int linear_before_reset = 0;
    if(contains(attributes, "linear_before_reset"))
    {
        linear_before_reset = parse_value(attributes.at("linear_before_reset")).at<int>();
    }

    return prog.add_instruction(
        op::gru{hidden_size, actv_ops, dirct, clip, linear_before_reset}, std::move(args));
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
{
{
onnx
::
ModelProto
model
;
onnx
::
ModelProto
model
;
...
...
src/rewrite_gru.cpp
0 → 100644
View file @
6d0742b6
#include <migraphx/rewrite_gru.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
/**
 * Replace every instruction named "gru" in prog by the instructions produced
 * by gru_oper.
 *
 * For a bidirectional gru the weights, bias and initial hidden state are
 * split along axis 0 into a forward and a reverse part, both directions are
 * expanded separately and their outputs are concatenated along axis 1. For a
 * single direction the leading axis of those inputs is squeezed away and the
 * expanded output gets an extra axis 1.
 */
void rewrite_gru::apply(program& prog) const
{
    for(auto ins : iterator_for(prog))
    {
        if(ins->name() != "gru")
        {
            continue;
        }

        // could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
        // the 5th one is undefined and ignored by protobuf), so
        // we need to process up to 5 inputs
        auto args = ins->inputs();

        shape seq_shape         = args[0]->get_shape();
        std::size_t hidden_size = args[2]->get_shape().lens()[2];
        std::size_t batchs      = seq_shape.lens()[1];
        shape::type_t type      = seq_shape.type();

        // zero-filled buffer for the default initial hidden state
        // NOTE(review): data holds ih_shape.bytes() zero chars; confirm this
        // is the size the literal constructor expects for a std::vector<char>.
        migraphx::shape ih_shape{type, {batchs, hidden_size}};
        std::vector<char> data(ih_shape.bytes(), 0);

        // local copy of the operator, so its activation functions can be
        // passed on as non-const references
        auto gru_op                    = any_cast<op::gru>(ins->get_operator());
        op::gru::gru_direction_t dicrt = gru_op.direction;

        if(dicrt == op::gru::bidirectional)
        {
            // pick index dir along axis 0 of arg and drop that axis
            auto select_direction = [&](instruction_ref arg, long dir) {
                auto sliced =
                    prog.insert_instruction(ins, op::slice{{0}, {dir}, {dir + 1}}, arg);
                return prog.insert_instruction(ins, op::squeeze{{0}}, sliced);
            };

            // forward weight
            auto w_forward = select_direction(args[1], 0);
            auto r_forward = select_direction(args[2], 0);

            // reverse weight
            auto w_reverse = select_direction(args[1], 1);
            auto r_reverse = select_direction(args[2], 1);

            // process bias; prog.end() tells gru_oper that there is no bias
            instruction_ref bias_forward = prog.end();
            instruction_ref bias_reverse = prog.end();
            if(args.size() >= 4)
            {
                bias_forward = select_direction(args[3], 0);
                bias_reverse = select_direction(args[3], 1);
            }

            // initial hidden state
            instruction_ref ih_forward;
            instruction_ref ih_reverse;
            if(args.size() >= 5)
            {
                ih_forward = select_direction(args[4], 0);
                ih_reverse = select_direction(args[4], 1);
            }
            else
            {
                ih_forward = prog.add_literal(migraphx::literal{ih_shape, data});
                ih_reverse = prog.add_literal(migraphx::literal{ih_shape, data});
            }

            // The reverse direction uses activation functions 2 and 3 when
            // the operator carries them; otherwise it shares the forward pair
            // instead of failing on an out-of-range access.
            const bool has_reverse_actv = gru_op.actv_funcs.size() >= 4;
            auto& reverse_actv1         = gru_op.actv_funcs.at(has_reverse_actv ? 2 : 0);
            auto& reverse_actv2         = gru_op.actv_funcs.at(has_reverse_actv ? 3 : 1);

            auto ret_forward = gru_oper(true,
                                        prog,
                                        ins,
                                        args[0],
                                        w_forward,
                                        r_forward,
                                        ih_forward,
                                        bias_forward,
                                        gru_op.linear_before_reset,
                                        gru_op.actv_funcs.at(0),
                                        gru_op.actv_funcs.at(1));

            auto ret_reverse = gru_oper(false,
                                        prog,
                                        ins,
                                        args[0],
                                        w_reverse,
                                        r_reverse,
                                        ih_reverse,
                                        bias_reverse,
                                        gru_op.linear_before_reset,
                                        reverse_actv1,
                                        reverse_actv2);

            // Element 1 of each result (the last hidden state) is not used yet.

            // add the dimension of num_direction
            ret_forward[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_forward[0]);
            ret_reverse[0] = prog.insert_instruction(ins, op::unsqueeze{{1}}, ret_reverse[0]);

            // concat the forward and reverse output
            prog.replace_instruction(ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
        }
        else
        {
            // every direction other than forward is processed in reverse
            bool is_forward = (dicrt == op::gru::forward);

            // weight matrix
            auto w = prog.insert_instruction(ins, op::squeeze{{0}}, args[1]);
            auto r = prog.insert_instruction(ins, op::squeeze{{0}}, args[2]);

            // bias; prog.end() tells gru_oper that there is no bias
            instruction_ref bias = prog.end();
            if(args.size() >= 4)
            {
                bias = prog.insert_instruction(ins, op::squeeze{{0}}, args[3]);
            }

            // initial hidden state
            instruction_ref ih;
            if(args.size() >= 5)
            {
                ih = prog.insert_instruction(ins, op::squeeze{{0}}, args[4]);
            }
            else
            {
                ih = prog.add_literal(migraphx::literal{ih_shape, data});
            }

            auto ret = gru_oper(is_forward,
                                prog,
                                ins,
                                args[0],
                                w,
                                r,
                                ih,
                                bias,
                                gru_op.linear_before_reset,
                                gru_op.actv_funcs.at(0),
                                gru_op.actv_funcs.at(1));

            // add the dimension of num_direction
            prog.replace_instruction(ins, op::unsqueeze{{1}}, ret[0]);
        }
    }
}
std
::
vector
<
instruction_ref
>
rewrite_gru
::
gru_oper
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
w
,
instruction_ref
r
,
instruction_ref
ih
,
instruction_ref
bias
,
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
{
instruction_ref
hidden_out
,
final_out
;
long
seq_len
=
static_cast
<
long
>
(
input
->
get_shape
().
lens
()[
0
]);
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
1
]);
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
migraphx
::
shape
s
(
input
->
get_shape
().
type
(),
{
1
});
auto
l1
=
prog
.
add_literal
(
migraphx
::
leteral
{
s
,
{
1
}});
// weight matrix
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
wz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
w
);
auto
twz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wz
);
auto
wr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
w
);
auto
twr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wr
);
auto
wh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
w
);
auto
twh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
wh
);
auto
rz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
r
);
auto
trz
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rz
);
auto
rr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
r
);
auto
trr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rr
);
auto
rh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
r
);
auto
trh
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
rh
);
// bias
instruction_ref
br_bz
,
br_br
,
br_wbh
,
br_rbh
,
br_bh
;
if
(
bias
!=
prog
.
end
())
{
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
bias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
bias
);
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
bias
);
br_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
wbh
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
bias
);
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
bias
);
br_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
rbh
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
br_bz
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
bz
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
br_br
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
br
);
br_bh
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
br_wbh
,
br_rbh
);
}
for
(
long
i
=
0
;
i
<
seq_len
;
i
++
)
{
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
// equation f(xt*(Wz^T) + Ht-1 * (Rz^T) + Wbz + Rbz)
auto
xwzt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twz
);
auto
hrzt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trz
);
auto
xwhr_zt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwzt
,
hrzt
);
if
(
bias
!=
prog
.
end
())
{
xwhr_zt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhr_zt
,
br_bz
);
}
auto
zt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xwhr_zt
);
// equation f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto
xwrt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twr
);
auto
hrrt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
trr
);
auto
xwhr_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwrt
,
hrrt
);
if
(
bias
!=
prog
.
end
())
{
xwhr_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhr_rt
,
br_br
);
}
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xwhr_rt
);
instruction_ref
xwhh_rt
;
if
(
linear_before_reset
==
0
)
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
xwht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
rt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
,
trh
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwht
,
rt_rt
);
if
(
bias
!=
prog
.
end
())
{
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_bh
);
}
}
else
{
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
xwht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
ih_rht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
twh
);
if
(
bias
!=
prog
.
end
())
{
ih_rht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ih_rht
,
br_rbh
);
}
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih_rht
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwht
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
{
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_wbh
);
}
}
ht
=
prog
.
insert_instruction
(
ins
,
actv_func2
,
xwhh_rt
);
// equation Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto
1
zt
=
prog
.
insert_instruction
(
ins
,
op
::
sub
{},
l1
,
zt
);
auto
1
ztht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
1
zt
,
ht
);
auto
ztht1
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
zt
,
ih
);
ih
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
1
ztht
ztht1
);
final_out
=
ih
;
if
(
is_forward
)
{
hidden_out
=
(
seq_index
==
0
)
?
ih
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
ih
);
}
else
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
ih
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ih
,
hidden_out
);
}
seq_index
=
is_forward
?
(
seq_index
+
1
)
:
(
seq_index
-
1
);
}
std
::
vector
<
instruction_ref
>
out_args
;
out_args
.
push_back
(
hidden_out
);
out_args
.
push_back
(
final_out
);
return
out_args
;
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment