gaoqiong / MIGraphX / Commits

Commit bddd8454
authored Feb 07, 2019 by Shucai Xiao
parent 62eea2df

more test examples and code cleanup.

Showing 3 changed files with 106 additions and 29 deletions:

  src/rewrite_gru.cpp        +11  -5
  src/rewrite_rnn.cpp        +16  -24
  test/cpu_rnn_ops_test.cpp  +79  -0
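At a glance: both rewrite passes previously patched at most one last-output consumer, because a single std::find_if call returns only the first match. rewrite_gru and rewrite_rnn now loop until every gru_last_output / rnn_last_output consumer is replaced, rewrite_rnn drops its map_last_output bookkeeping in favor of rewriting consumers in place, and two new test cases exercise programs with duplicate last-output operators.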
src/rewrite_gru.cpp
@@ -154,13 +154,19 @@ void rewrite_gru::apply(program& prog) const
         // replace the corresponding gru_last_output instruction
         // with the last_output, if gru_last_output exists
-        auto last_output_it =
-            std::find_if(ins->outputs().begin(), ins->outputs().end(), [](auto i) {
-                return i->name() == "gru_last_output";
-            });
-
-        if(last_output_it != ins->outputs().end())
+        // while loop to handle case of multiple gru_last_output operators
+        auto last_output_it = ins->outputs().begin();
+        while(last_output_it != ins->outputs().end())
         {
-            prog.replace_instruction(*last_output_it, last_output);
+            last_output_it =
+                std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
+                    return i->name() == "gru_last_output";
+                });
+
+            if(last_output_it != ins->outputs().end())
+            {
+                prog.replace_instruction(*last_output_it, last_output);
+                last_output_it++;
+            }
         }
     }
 }
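For readers unfamiliar with the pattern: std::find_if returns only the first match, so the old code could rewrite at most one gru_last_output consumer. The new loop is the standard resumable-cursor idiom, restarting the search just past each match so every consumer is visited. A minimal, self-contained sketch of that idiom (a toy vector of names stands in for ins->outputs(); nothing here is MIGraphX API):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main()
{
    std::vector<std::string> outputs = {"add", "gru_last_output", "mul", "gru_last_output"};

    auto it = outputs.begin();
    while(it != outputs.end())
    {
        // resume the search from the current cursor position
        it = std::find_if(it, outputs.end(), [](const std::string& name) {
            return name == "gru_last_output";
        });
        if(it != outputs.end())
        {
            // in the real pass this is where prog.replace_instruction runs
            std::cout << "match at index " << (it - outputs.begin()) << "\n";
            ++it; // step past the match so the next search continues after it
        }
    }
    return 0;
}

Without the ++it advance the loop would find the same element forever; advancing past each match is what makes the second gru_last_output reachable.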
src/rewrite_rnn.cpp
@@ -10,7 +10,6 @@ inline namespace MIGRAPHX_INLINE_NS {
 void rewrite_rnn::apply(program& prog) const
 {
-    std::unordered_map<instruction_ref, instruction_ref> map_last_output;
     for(auto ins : iterator_for(prog))
     {
         // rewrite rnn operator
@@ -32,6 +31,7 @@ void rewrite_rnn::apply(program& prog) const
         auto actv_funcs = compute_actv_funcs(ins);
         auto rnn_op     = any_cast<op::rnn>(ins->get_operator());
         op::rnn::rnn_direction_t dicrt = rnn_op.direction;
+        instruction_ref last_output{};
         if(dicrt == op::rnn::bidirectional)
         {
             // input weight matrix
@@ -87,7 +87,7 @@ void rewrite_rnn::apply(program& prog) const
             auto concat_output =
                 prog.insert_instruction(ins, op::concat{1}, ret_forward[1], ret_reverse[1]);
-            auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
+            last_output = prog.insert_instruction(ins, op::squeeze{{0}}, concat_output);
             // The following logic is to ensure the last instruction rewritten from
             // rnn operator is a concat instruction
@@ -107,7 +107,6 @@ void rewrite_rnn::apply(program& prog) const
                 hidden_output = prog.replace_instruction(
                     ins, op::concat{1}, {ret_forward[0], ret_reverse[0]});
             }
-            map_last_output[hidden_output] = last_output;
         }
         else
         {
@@ -138,7 +137,7 @@ void rewrite_rnn::apply(program& prog) const
             auto ret =
                 rnn_cell(is_forward, prog, ins, args[0], w, r, bias, ih, actv_funcs.at(0));
-            auto last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
+            last_output = prog.insert_instruction(ins, op::squeeze{{0}}, ret[1]);
             // following logic is to ensure the last instruction is a
             // concat instruction
@@ -155,30 +154,23 @@ void rewrite_rnn::apply(program& prog) const
                 hidden_output =
                     prog.replace_instruction(ins, op::concat{0}, concat_arg0, concat_arg1);
             }
-            // auto last_it = std::find_if();
-            // if(last_it != ins->outputs().end())
-            // {
-            // }
-            map_last_output[hidden_output] = last_output;
-        }
-    }
+        }
 
-    // rewrite the rnn_last_output operator that right after the rnn
-    // operator. Intuitively, we can do a slice on its input to get
-    // the last output, but it is already existed in the rnn operator,
-    // so we can just use it as the output here
-    if(ins->name() == "rnn_last_output")
-    {
-        auto inputs = ins->inputs();
-        assert(inputs.size() == 1);
-        auto arg = inputs[0];
-        if(map_last_output.count(arg) == 0)
+        // search its output to find if there are rnn_last_output operator
+        // while loop to handle case of multiple rnn_last_output operators
+        auto last_output_it = ins->outputs().begin();
+        while(last_output_it != ins->outputs().end())
         {
-            MIGRAPHX_THROW("RNN_LAST_OUTPUT: no related rnn operator as its input");
-        }
+            last_output_it =
+                std::find_if(last_output_it, ins->outputs().end(), [](auto i) {
+                    return i->name() == "rnn_last_output";
+                });
 
-        prog.replace_instruction(ins, map_last_output[arg]);
-    }
+            if(last_output_it != ins->outputs().end())
+            {
+                prog.replace_instruction(*last_output_it, last_output);
+                last_output_it++;
+            }
+        }
     }
 }
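Two smaller changes in this file support the new loop: `instruction_ref last_output{};` is now declared before the direction branch, and both the bidirectional and single-direction paths assign into it (the local `auto` declarations are dropped), so the consumer-patching loop at the end of the rnn case can use whichever value was produced. A toy sketch of that hoisting pattern, with std::string standing in for instruction_ref (illustrative only, not MIGraphX code):

#include <iostream>
#include <string>

int main()
{
    bool bidirectional = true;

    std::string last_output{}; // hoisted out of the branches
    if(bidirectional)
        last_output = "squeeze(concat(ret_forward[1], ret_reverse[1]))"; // bidirectional path
    else
        last_output = "squeeze(ret[1])"; // single-direction path

    // the common epilogue (the while loop in the diff) consumes it
    std::cout << last_output << "\n";
    return 0;
}

This also removes the need for map_last_output: each rnn's last-output consumers are rewritten immediately, in the same loop iteration, instead of being looked up in a second pass.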
test/cpu_rnn_ops_test.cpp
@@ -138,6 +138,43 @@ TEST_CASE(rnn_forward)
         EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
     }
 
+    // multiple rnn_last_output operators
+    {
+        migraphx::program p;
+        auto seq  = p.add_literal(migraphx::literal{in_shape, input});
+        auto ih   = p.add_literal(migraphx::literal{ih_shape, ih_data});
+        auto w    = p.add_literal(migraphx::literal{w_shape, w_data});
+        auto r    = p.add_literal(migraphx::literal{r_shape, r_data});
+        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
+        auto und  = p.add_instruction(migraphx::op::undefined{});
+
+        auto out_hs = p.add_instruction(
+            migraphx::op::rnn{hidden_size, {}, migraphx::op::rnn::forward, clip},
+            seq,
+            w,
+            r,
+            bias,
+            und,
+            ih);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
+        p.add_instruction(migraphx::op::rnn_last_output{}, out_hs);
+        p.compile(migraphx::cpu::target{});
+
+        auto last_output = p.eval({});
+        std::vector<float> last_output_data;
+        last_output.visit([&](auto out) { last_output_data.assign(out.begin(), out.end()); });
+
+        std::vector<float> last_output_data_gold{0.03445704,
+                                                 0.19167931,
+                                                 -0.3946827,
+                                                 -0.30889652,
+                                                 -0.22276389,
+                                                 0.44193283,
+                                                 -0.16477929,
+                                                 -0.11893477};
+
+        EXPECT(migraphx::verify_range(last_output_data, last_output_data_gold));
+    }
+
     // 3 args
     {
         migraphx::program p;
@@ -617,6 +654,48 @@ TEST_CASE(gru_forward)
         EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
     }
 
+    // two gru_last_output operators after gru
+    {
+        migraphx::program p;
+        auto seq  = p.add_literal(migraphx::literal{in_shape, input});
+        auto w    = p.add_literal(migraphx::literal{w_shape, w_data});
+        auto r    = p.add_literal(migraphx::literal{r_shape, r_data});
+        auto bias = p.add_literal(migraphx::literal{b_shape, bias_data});
+        auto und  = p.add_instruction(migraphx::op::undefined{});
+        auto ih   = p.add_literal(migraphx::literal{ih_shape, ih_data});
+
+        auto concat_hs = p.add_instruction(
+            migraphx::op::gru{hidden_size,
+                              {migraphx::op::sigmoid{}, migraphx::op::tanh{}},
+                              migraphx::op::gru::forward,
+                              clip,
+                              1},
+            seq,
+            w,
+            r,
+            bias,
+            und,
+            ih);
+        p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
+        p.add_instruction(migraphx::op::gru_last_output{}, concat_hs);
+        p.compile(migraphx::cpu::target{});
+
+        auto hs_concat = p.eval({});
+        std::vector<float> hs_data;
+        hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
+
+        std::vector<float> hs_data_gold{-0.3969709,
+                                        0.43360898,
+                                        0.35775262,
+                                        0.23280787,
+                                        -0.52179873,
+                                        -0.21944991,
+                                        0.4535257,
+                                        -0.13735442,
+                                        0.51757574,
+                                        0.50380427};
+
+        EXPECT(migraphx::verify_range(hs_data, hs_data_gold));
+    }
+
     // last output for output, linear_before_reset = 0
     {
         migraphx::program p;
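One idiom in these tests worth noting: results are read back through `visit`, because `p.eval({})` returns a type-erased buffer whose element type is only known at runtime. A rough stand-alone analogy using std::variant (std::visit here only illustrates the shape of the MIGraphX visit member, not its actual implementation):

#include <iostream>
#include <variant>
#include <vector>

int main()
{
    // the element type could be float, int, ... depending on the program
    std::variant<std::vector<float>, std::vector<int>> result = std::vector<float>{0.1f, 0.2f};

    std::vector<float> data;
    // the generic lambda is instantiated for whichever type is held
    std::visit([&](const auto& out) { data.assign(out.begin(), out.end()); }, result);

    std::cout << data.size() << " values copied\n";
    return 0;
}

Both new cases attach two last-output operators to the same rnn/gru node, which is exactly the situation the single std::find_if in the old rewrite passes mishandled.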