Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1596cf1f
"src/vscode:/vscode.git/clone" did not exist on "d628942b6656176d4d6b3c16405e4f640d62cf29"
Commit
1596cf1f
authored
Jan 23, 2019
by
Shucai Xiao
Browse files
clang format
parent
6d0742b6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
74 additions
and
62 deletions
+74
-62
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+10
-7
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+64
-55
No files found.
src/onnx/onnx.cpp
View file @
1596cf1f
...
@@ -741,12 +741,12 @@ struct onnx_parser
...
@@ -741,12 +741,12 @@ struct onnx_parser
act_funcs
[
1
]
=
attributes
.
at
(
"activations"
).
strings
(
1
);
act_funcs
[
1
]
=
attributes
.
at
(
"activations"
).
strings
(
1
);
}
}
if
(
act_funcs
.
size
()
!=
2
)
if
(
act_funcs
.
size
()
!=
2
)
{
{
MIGRAPHX_THROW
(
"GRU: wrong activation function attribute"
);
MIGRAPHX_THROW
(
"GRU: wrong activation function attribute"
);
}
}
for
(
std
::
size_t
i
=
0
;
i
<
act_funcs
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
act_funcs
.
size
();
++
i
)
{
{
if
(
actv_funcs
.
count
(
act_funcs
.
at
(
i
))
==
0
)
if
(
actv_funcs
.
count
(
act_funcs
.
at
(
i
))
==
0
)
{
{
...
@@ -762,14 +762,17 @@ struct onnx_parser
...
@@ -762,14 +762,17 @@ struct onnx_parser
}
}
int
linear_before_reset
=
0
;
int
linear_before_reset
=
0
;
if
(
contains
(
attributes
,
"linear_before_reset"
))
if
(
contains
(
attributes
,
"linear_before_reset"
))
{
{
linear_before_reset
=
parse_value
(
attributes
.
at
(
"linear_before_reset"
)).
at
<
int
>
();
linear_before_reset
=
parse_value
(
attributes
.
at
(
"linear_before_reset"
)).
at
<
int
>
();
}
}
return
prog
.
add_instruction
(
op
::
gru
{
hidden_size
,
return
prog
.
add_instruction
(
op
::
gru
{
hidden_size
,
{
actv_funcs
[
act_funcs
.
at
(
0
)],
actv_funcs
[
act_funcs
.
at
(
1
)]},
{
actv_funcs
[
act_funcs
.
at
(
0
)],
actv_funcs
[
act_funcs
.
at
(
1
)]},
dirct
,
clip
,
linear_before_reset
},
dirct
,
clip
,
linear_before_reset
},
std
::
move
(
args
));
std
::
move
(
args
));
}
}
...
...
src/rewrite_gru.cpp
View file @
1596cf1f
...
@@ -104,7 +104,8 @@ void rewrite_gru::apply(program& prog) const
...
@@ -104,7 +104,8 @@ void rewrite_gru::apply(program& prog) const
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
3
));
gru_op
.
actv_funcs
.
at
(
3
));
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1], ret_reverse[1]);
// auto final_output = prog.insert_instruction(ins, op::concat{0}, ret_forward[1],
// ret_reverse[1]);
// add the dimension of num_direction
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
...
@@ -138,8 +139,17 @@ void rewrite_gru::apply(program& prog) const
...
@@ -138,8 +139,17 @@ void rewrite_gru::apply(program& prog) const
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
}
}
auto
ret
=
gru_oper
(
auto
ret
=
gru_oper
(
is_forward
,
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
ih
,
bias
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
prog
,
ins
,
args
[
0
],
w
,
r
,
ih
,
bias
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
// add the dimension of num_direction
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
...
@@ -185,16 +195,16 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -185,16 +195,16 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
// bias
// bias
instruction_ref
br_bz
,
br_br
,
br_wbh
,
br_rbh
,
br_bh
;
instruction_ref
br_bz
,
br_br
,
br_wbh
,
br_rbh
,
br_bh
;
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
bias
);
auto
wbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
bias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
bias
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
bias
);
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
bias
);
wbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
bias
);
br_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
wbh
);
br_wbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
wbh
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
bias
);
auto
rbz
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
3
*
hs
},
{
4
*
hs
}},
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
bias
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
4
*
hs
},
{
5
*
hs
}},
bias
);
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
bias
);
rbh
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
5
*
hs
},
{
6
*
hs
}},
bias
);
br_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
rbh
);
br_rbh
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ih
->
get_shape
()},
rbh
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
auto
bz
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbz
,
rbz
);
...
@@ -212,7 +222,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -212,7 +222,7 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
auto
xwzt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twz
);
auto
xwzt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twz
);
auto
hrzt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trz
);
auto
hrzt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
trz
);
auto
xwhr_zt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwzt
,
hrzt
);
auto
xwhr_zt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwzt
,
hrzt
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xwhr_zt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhr_zt
,
br_bz
);
xwhr_zt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhr_zt
,
br_bz
);
}
}
...
@@ -222,21 +232,21 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -222,21 +232,21 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
auto
xwrt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twr
);
auto
xwrt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twr
);
auto
hrrt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
trr
);
auto
hrrt
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
trr
);
auto
xwhr_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwrt
,
hrrt
);
auto
xwhr_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwrt
,
hrrt
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xwhr_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhr_rt
,
br_br
);
xwhr_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhr_rt
,
br_br
);
}
}
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xwhr_rt
);
auto
rt
=
prog
.
insert_instruction
(
ins
,
actv_func1
,
xwhr_rt
);
instruction_ref
xwhh_rt
;
instruction_ref
xwhh_rt
;
if
(
linear_before_reset
==
0
)
if
(
linear_before_reset
==
0
)
{
{
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
// equation g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto
xwht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
xwht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
rt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih
);
auto
rt_ht
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
,
trh
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
rt_ht
,
trh
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwht
,
rt_rt
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwht
,
rt_rt
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_bh
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_bh
);
}
}
...
@@ -246,13 +256,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -246,13 +256,13 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
// equation ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto
xwht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
xwht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
twh
);
auto
ih_rht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
twh
);
auto
ih_rht
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
twh
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
ih_rht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ih_rht
,
br_rbh
);
ih_rht
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
ih_rht
,
br_rbh
);
}
}
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih_rht
);
auto
rt_rh
=
prog
.
insert_instruction
(
ins
,
op
::
mul
{},
rt
,
ih_rht
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwht
,
rt_rh
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwht
,
rt_rh
);
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_wbh
);
xwhh_rt
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
xwhh_rt
,
br_wbh
);
}
}
...
@@ -268,9 +278,8 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
...
@@ -268,9 +278,8 @@ std::vector<instruction_ref> rewrite_gru::gru_oper(bool is_forward,
if
(
is_forward
)
if
(
is_forward
)
{
{
hidden_out
=
(
seq_index
==
0
)
hidden_out
=
?
ih
(
seq_index
==
0
)
?
ih
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
ih
);
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
ih
);
}
}
else
else
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment