Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6d8fcb3d
Commit
6d8fcb3d
authored
Feb 07, 2019
by
Shucai Xiao
Browse files
clang format
parent
233f3bcc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
33 deletions
+36
-33
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+36
-33
No files found.
src/rewrite_rnn.cpp
View file @
6d8fcb3d
...
@@ -675,23 +675,23 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -675,23 +675,23 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
assert
(
ins
->
name
()
==
"lstm"
);
assert
(
ins
->
name
()
==
"lstm"
);
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
2
];
std
::
size_t
hidden_size
=
args
[
1
]
->
get_shape
().
lens
()[
2
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
ihc_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ihc_shape
{
type
,
{
1
,
batch_size
,
hidden_size
}};
std
::
vector
<
float
>
ihc_data
(
ih_shape
.
elements
(),
0.0
);
std
::
vector
<
float
>
ihc_data
(
ih_shape
.
elements
(),
0.0
);
migraphx
::
shape
pph_shape
{
type
,
{
1
,
3
*
hidden_size
}};
migraphx
::
shape
pph_shape
{
type
,
{
1
,
3
*
hidden_size
}};
std
::
vector
<
float
>
ppl_data
(
pph_shape
.
elements
(),
0.0
);
std
::
vector
<
float
>
ppl_data
(
pph_shape
.
elements
(),
0.0
);
auto
&
actv_funcs
=
lstm_actv_funcs
(
ins
);
auto
&
actv_funcs
=
lstm_actv_funcs
(
ins
);
auto
lstm_op
=
any_cast
<
op
::
lstm
>
(
ins
->
get_operator
());
auto
lstm_op
=
any_cast
<
op
::
lstm
>
(
ins
->
get_operator
());
op
::
lstm
::
lstm_direction_t
dirct
=
lstm_op
.
direction
;
op
::
lstm
::
lstm_direction_t
dirct
=
lstm_op
.
direction
;
instruction_ref
last_output
{};
instruction_ref
last_output
{};
instruction_ref
last_cell_output
{};
instruction_ref
last_cell_output
{};
if
(
dirct
==
op
::
lstm
::
bidirectional
)
if
(
dirct
==
op
::
lstm
::
bidirectional
)
{
{
// input weight matrix
// input weight matrix
// input weight matrix
// input weight matrix
...
@@ -705,7 +705,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -705,7 +705,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// process bias
// process bias
instruction_ref
bias_forward
=
prog
.
end
();
instruction_ref
bias_forward
=
prog
.
end
();
instruction_ref
bias_reverse
=
prog
.
end
();
instruction_ref
bias_reverse
=
prog
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
{
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
...
@@ -728,7 +728,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -728,7 +728,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// process initial cell value
// process initial cell value
instruction_ref
ic_forward
{};
instruction_ref
ic_forward
{};
instruction_ref
ic_reverse
{};
instruction_ref
ic_reverse
{};
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"undefined"
)
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"undefined"
)
{
{
ic_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
6
]);
ic_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
6
]);
ic_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
6
]);
ic_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
6
]);
...
@@ -742,7 +742,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -742,7 +742,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// process weight of the peephole
// process weight of the peephole
instruction_ref
pph_forward
{};
instruction_ref
pph_forward
{};
instruction_ref
pph_reverse
{};
instruction_ref
pph_reverse
{};
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"undefined"
)
if
(
args
.
size
()
==
8
&&
args
[
7
]
->
name
()
!=
"undefined"
)
{
{
pph_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
7
]);
pph_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
7
]);
pph_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
7
]);
pph_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
7
]);
...
@@ -752,45 +752,48 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -752,45 +752,48 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
pph_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
pph_forward
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
pph_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
pph_reverse
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
}
}
auto
ret_forward
=
lstm_cell
(
true
,
prog
,
ins
,
auto
ret_forward
=
lstm_cell
(
{
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
true
,
ih_forward
,
ic_forward
,
pph_forward
},
prog
,
lstm_op
.
input_forget
,
ins
,
actv_funcs
.
at
(
0
),
{
args
[
0
],
w_forward
,
r_forward
,
bias_forward
,
ih_forward
,
ic_forward
,
pph_forward
},
actv_funcs
.
at
(
1
),
lstm_op
.
input_forget
,
actv_funcs
.
at
(
2
));
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
auto
ret_reverse
=
lstm_cell
(
false
,
prog
,
ins
,
actv_funcs
.
at
(
2
));
{
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
ih_reverse
,
ic_reverse
,
pph_reverse
},
auto
ret_reverse
=
lstm_cell
(
lstm_op
.
input_forget
,
false
,
actv_funcs
.
at
(
3
),
prog
,
actv_funcs
.
at
(
4
),
ins
,
actv_funcs
.
at
(
5
));
{
args
[
0
],
w_reverse
,
r_reverse
,
bias_reverse
,
ih_reverse
,
ic_reverse
,
pph_reverse
},
lstm_op
.
input_forget
,
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
4
),
actv_funcs
.
at
(
5
));
auto
concat_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
concat_output
);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
concat_output
);
// last cell output
// last cell output
auto
concat_cell_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
2
],
ret_reverse
[
2
]);
auto
concat_cell_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
2
],
ret_reverse
[
2
]);
last_cell_output
=
prog
.
insert_instruction
(
ins
,
squeeze
{{
0
}},
concat_cell_output
);
last_cell_output
=
prog
.
insert_instruction
(
ins
,
squeeze
{{
0
}},
concat_cell_output
);
// the following logic is to ensure the last instruction is a concat
// the following logic is to ensure the last instruction is a concat
if
(
ret_forward
[
0
]
==
prog
.
end
())
if
(
ret_forward
[
0
]
==
prog
.
end
())
{
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
1
},
ret_forward
[
1
],
ret_reverse
[
1
]);
}
}
else
else
{
{
}
}
}
}
else
else
{
{
}
}
}
}
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
lstm_cell
(
bool
is_forward
,
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
lstm_cell
(
bool
is_forward
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment