Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
148e548d
"...resnet50_tensorflow.git" did not exist on "42bb922cd1566c872c1797c46207a43424b98673"
Commit
148e548d
authored
Feb 05, 2019
by
Shucai Xiao
Browse files
code refinement.
parent
42d2549d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
18 deletions
+61
-18
src/include/migraphx/rewrite_gru.hpp
src/include/migraphx/rewrite_gru.hpp
+6
-4
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+2
-4
src/rewrite_gru.cpp
src/rewrite_gru.cpp
+53
-10
No files found.
src/include/migraphx/rewrite_gru.hpp
View file @
148e548d
...
...
@@ -25,13 +25,15 @@ struct rewrite_gru
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
w
x
,
instruction_ref
wh
,
instruction_ref
w
,
instruction_ref
r
,
instruction_ref
bias
,
instruction_ref
ih
,
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
;
const
operation
&
actv_func1
,
const
operation
&
actv_func2
)
const
;
std
::
vector
<
operation
>
compute_actv_funcs
(
instruction_ref
ins
)
const
;
};
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/onnx/onnx.cpp
View file @
148e548d
...
...
@@ -818,10 +818,8 @@ struct onnx_parser
{
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
for
(
auto
&
fn
:
names
)
{
vec_names
.
push_back
(
fn
);
}
vec_names
.
resize
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[]
(
auto
&
str
)
{
return
str
;
});
}
// need 4 activation functions
...
...
src/rewrite_gru.cpp
View file @
148e548d
...
...
@@ -15,6 +15,7 @@ void rewrite_gru::apply(program& prog) const
{
if
(
ins
->
name
()
==
"gru"
)
{
const
auto
actv_funcs
=
compute_actv_funcs
(
ins
);
// could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
// the 5th one is undefined and ignored by protobuf. so
// we need to process up to 5 inputs
...
...
@@ -70,8 +71,8 @@ void rewrite_gru::apply(program& prog) const
bias_forward
,
ih_forward
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
));
auto
ret_reverse
=
gru_cell
(
false
,
prog
,
...
...
@@ -82,8 +83,8 @@ void rewrite_gru::apply(program& prog) const
bias_reverse
,
ih_reverse
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
3
));
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
));
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
...
...
@@ -110,7 +111,7 @@ void rewrite_gru::apply(program& prog) const
}
else
{
bool
is_forward
=
(
dicrt
==
op
::
gru
::
forward
)
?
true
:
false
;
bool
is_forward
=
(
dicrt
==
op
::
gru
::
forward
);
// weight matrix
auto
w
=
args
[
1
];
auto
r
=
args
[
2
];
...
...
@@ -142,8 +143,8 @@ void rewrite_gru::apply(program& prog) const
bias
,
ih
,
gru_op
.
linear_before_reset
,
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
));
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
));
auto
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
...
...
@@ -155,7 +156,7 @@ void rewrite_gru::apply(program& prog) const
else
{
auto
concat_arg0
=
is_forward
?
ret
[
0
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
1
];
auto
concat_arg1
=
is_forward
?
ret
[
1
]
:
ret
[
0
];
hidden_state
=
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
concat_arg0
,
concat_arg1
);
}
...
...
@@ -186,9 +187,10 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
instruction_ref
bias
,
instruction_ref
ih
,
int
linear_before_reset
,
operation
&
actv_func1
,
operation
&
actv_func2
)
const
const
operation
&
actv_func1
,
const
operation
&
actv_func2
)
const
{
assert
(
actv_funcs
.
size
()
==
2
);
instruction_ref
hidden_states
=
prog
.
end
(),
last_output
;
long
seq_len
=
static_cast
<
long
>
(
input
->
get_shape
().
lens
()[
0
]);
long
hs
=
static_cast
<
long
>
(
r
->
get_shape
().
lens
()[
2
]);
...
...
@@ -334,5 +336,46 @@ std::vector<instruction_ref> rewrite_gru::gru_cell(bool is_forward,
return
{
hidden_states
,
last_output
};
}
std
::
vector
<
operation
>
rewrite_gru
::
compute_actv_funcs
(
instruction_ref
ins
)
const
{
auto
gru_op
=
any_cast
<
op
::
gru
>
(
ins
->
get_operator
());
// before rewrite the gru operator, need to ensure
// we have 4 actv funcs, even though a user does not
// specifiy any actv func. If less than 4, use the
// algorithm in parse_gru to make 4 actv functions
if
(
gru_op
.
direction
==
op
::
gru
::
bidirectional
)
{
if
(
gru_op
.
actv_funcs
.
empty
())
return
{
op
::
sigmoid
{},
op
::
tanh
{},
op
::
sigmoid
{},
op
::
tanh
{}};
else
if
(
gru_op
.
actv_funcs
.
size
()
==
1
)
return
{
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
)};
else
if
(
gru_op
.
actv_funcs
.
size
()
==
2
)
return
{
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
),
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
)};
else
if
(
gru_op
.
actv_funcs
.
size
()
==
3
)
return
{
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
1
),
gru_op
.
actv_funcs
.
at
(
2
),
gru_op
.
actv_funcs
.
at
(
0
)};
else
return
gru_op
.
actv_funcs
;
}
else
{
if
(
gru_op
.
actv_funcs
.
empty
())
return
{
op
::
sigmoid
{},
op
::
tanh
{}};
else
if
(
gru_op
.
actv_funcs
.
size
()
==
1
)
return
{
gru_op
.
actv_funcs
.
at
(
0
),
gru_op
.
actv_funcs
.
at
(
0
)};
else
return
gru_op
.
actv_funcs
;
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment