Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
e238ad93
Commit
e238ad93
authored
Feb 07, 2019
by
Shucai Xiao
Browse files
clang format
parent
97568d53
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
17 deletions
+17
-17
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+17
-17
No files found.
src/rewrite_rnn.cpp
View file @
e238ad93
...
@@ -682,7 +682,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -682,7 +682,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
migraphx
::
shape
pph_shape
{
type
,
{
1
,
3
*
hidden_size
}};
migraphx
::
shape
pph_shape
{
type
,
{
1
,
3
*
hidden_size
}};
std
::
vector
<
float
>
pph_data
(
pph_shape
.
elements
(),
0.0
);
std
::
vector
<
float
>
pph_data
(
pph_shape
.
elements
(),
0.0
);
auto
actv_funcs
=
lstm_actv_funcs
(
ins
);
auto
actv_funcs
=
lstm_actv_funcs
(
ins
);
auto
lstm_op
=
any_cast
<
op
::
lstm
>
(
ins
->
get_operator
());
auto
lstm_op
=
any_cast
<
op
::
lstm
>
(
ins
->
get_operator
());
op
::
lstm
::
lstm_direction_t
dirct
=
lstm_op
.
direction
;
op
::
lstm
::
lstm_direction_t
dirct
=
lstm_op
.
direction
;
...
@@ -802,14 +802,14 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -802,14 +802,14 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// bias
// bias
instruction_ref
bias
=
prog
.
end
();
instruction_ref
bias
=
prog
.
end
();
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
if
(
args
.
size
()
>=
4
&&
args
[
3
]
->
name
()
!=
"undefined"
)
{
{
bias
=
args
[
3
];
bias
=
args
[
3
];
}
}
// initial hidden state
// initial hidden state
instruction_ref
ih
{};
instruction_ref
ih
{};
if
(
args
.
size
()
>=
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
if
(
args
.
size
()
>=
6
&&
args
[
5
]
->
name
()
!=
"undefined"
)
{
{
ih
=
args
[
5
];
ih
=
args
[
5
];
}
}
...
@@ -820,7 +820,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -820,7 +820,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// initial cell value
// initial cell value
instruction_ref
ic
{};
instruction_ref
ic
{};
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"undefined"
)
if
(
args
.
size
()
>=
7
&&
args
[
6
]
->
name
()
!=
"undefined"
)
{
{
ic
=
args
[
6
];
ic
=
args
[
6
];
}
}
...
@@ -840,8 +840,8 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -840,8 +840,8 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
{
{
pph
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
pph
=
prog
.
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
}
}
auto
ret
=
lstm_cell
(
is_forward
,
auto
ret
=
lstm_cell
(
is_forward
,
prog
,
prog
,
ins
,
ins
,
{
args
[
0
],
w
,
r
,
bias
,
ih
,
ic
,
pph
},
{
args
[
0
],
w
,
r
,
bias
,
ih
,
ic
,
pph
},
...
@@ -850,9 +850,9 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -850,9 +850,9 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
));
actv_funcs
.
at
(
2
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
1
]);
last_cell_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
2
]);
last_cell_output
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ret
[
2
]);
if
(
ret
[
0
]
==
prog
.
end
())
if
(
ret
[
0
]
==
prog
.
end
())
{
{
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
prog
.
replace_instruction
(
ins
,
op
::
concat
{
0
},
ret
[
1
]);
}
}
...
@@ -866,7 +866,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
...
@@ -866,7 +866,7 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
// replace the corresponding lstm_last_output instruction
// replace the corresponding lstm_last_output instruction
// with the last_output, and the lstm_last_cell_output with
// with the last_output, and the lstm_last_cell_output with
// the last_cell_output. The while loop is to handle the case
// the last_cell_output. The while loop is to handle the case
// of multiple lstm_last_output and lstm_last_cell_output
// of multiple lstm_last_output and lstm_last_cell_output
// operators
// operators
auto
last_output_it
=
ins
->
outputs
().
begin
();
auto
last_output_it
=
ins
->
outputs
().
begin
();
...
@@ -909,17 +909,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
...
@@ -909,17 +909,17 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
{
{
// must have 7 args in the input vector
// must have 7 args in the input vector
assert
(
inputs
.
size
()
==
7
);
assert
(
inputs
.
size
()
==
7
);
auto
seq
=
inputs
.
at
(
0
);
auto
seq
=
inputs
.
at
(
0
);
auto
w
=
inputs
.
at
(
1
);
auto
w
=
inputs
.
at
(
1
);
auto
r
=
inputs
.
at
(
2
);
auto
r
=
inputs
.
at
(
2
);
auto
bias
=
inputs
.
at
(
3
);
auto
bias
=
inputs
.
at
(
3
);
auto
ih
=
inputs
.
at
(
4
);
auto
ih
=
inputs
.
at
(
4
);
auto
ic
=
inputs
.
at
(
5
);
auto
ic
=
inputs
.
at
(
5
);
auto
pph
=
inputs
.
at
(
6
);
auto
pph
=
inputs
.
at
(
6
);
instruction_ref
instruction_ref
return
{};
return
{};
}
}
std
::
vector
<
operation
>
rewrite_rnn
::
lstm_actv_funcs
(
instruction_ref
ins
)
const
std
::
vector
<
operation
>
rewrite_rnn
::
lstm_actv_funcs
(
instruction_ref
ins
)
const
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment