Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
20f89fcc
Commit
20f89fcc
authored
Jan 28, 2019
by
Shucai Xiao
Browse files
clang format
parent
0fe4c56b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
34 additions
and
39 deletions
+34
-39
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+3
-3
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+31
-36
No files found.
src/onnx/onnx.cpp
View file @
20f89fcc
...
@@ -729,12 +729,12 @@ struct onnx_parser
...
@@ -729,12 +729,12 @@ struct onnx_parser
std
::
vector
<
instruction_ref
>
result
;
std
::
vector
<
instruction_ref
>
result
;
// first output for the concatenation of hidden states
// first output for the concatenation of hidden states
auto
hidden_states
=
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
},
auto
hidden_states
=
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
},
std
::
move
(
args
));
std
::
move
(
args
));
result
.
push_back
(
hidden_states
);
result
.
push_back
(
hidden_states
);
// second out for the last hidden state
// second out for the last hidden state
//auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
//
auto last_output = prog.add_instruction(op::rnn_last_output{}, hidden_states);
//result.push_back(last_output);
//
result.push_back(last_output);
return
result
;
return
result
;
}
}
...
...
src/rewrite_rnn.cpp
View file @
20f89fcc
...
@@ -16,9 +16,9 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -16,9 +16,9 @@ void rewrite_rnn::apply(program& prog) const
// rewrite rnn operator
// rewrite rnn operator
if
(
ins
->
name
()
==
"rnn"
)
if
(
ins
->
name
()
==
"rnn"
)
{
{
// could be 3 to 6 inputs, but the 5th input is undefined in
// could be 3 to 6 inputs, but the 5th input is undefined in
// pytorch exported onnx, and it is ignored by protobuf. So
// pytorch exported onnx, and it is ignored by protobuf. So
// for input arguments 5 and 6, we need to check the shape,
// for input arguments 5 and 6, we need to check the shape,
// then based on the shape to judge the specific input info
// then based on the shape to judge the specific input info
auto
args
=
ins
->
inputs
();
auto
args
=
ins
->
inputs
();
...
@@ -34,12 +34,12 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -34,12 +34,12 @@ void rewrite_rnn::apply(program& prog) const
if
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
bidirectional
)
if
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
bidirectional
)
{
{
// input weight matrix
// input weight matrix
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
w_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
auto
w_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
1
]);
// hidden state weight matrix
// hidden state weight matrix
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
r_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
auto
r_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
2
]);
// process bias
// process bias
instruction_ref
bias_forward
,
bias_reverse
;
instruction_ref
bias_forward
,
bias_reverse
;
...
@@ -53,11 +53,12 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -53,11 +53,12 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state, it could be the 6th argument
// process intial hidden state, it could be the 6th argument
// or the 5th one (if the sequence len argument is ignored)
// or the 5th one (if the sequence len argument is ignored)
instruction_ref
ih_forward
,
ih_reverse
;
instruction_ref
ih_forward
,
ih_reverse
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
{
auto
arg_ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
auto
arg_ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
arg_ih
);
ih_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
arg_ih
);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
arg_ih
);
ih_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
arg_ih
);
}
}
else
else
{
{
...
@@ -84,7 +85,8 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -84,7 +85,8 @@ void rewrite_rnn::apply(program& prog) const
ih_reverse
,
ih_reverse
,
rnn_op
.
actv_funcs
.
at
(
1
));
rnn_op
.
actv_funcs
.
at
(
1
));
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
last_output
=
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
ret_forward
[
1
],
ret_reverse
[
1
]);
// add the dimension of num_direction
// add the dimension of num_direction
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
ret_forward
[
0
]
=
prog
.
insert_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret_forward
[
0
]);
...
@@ -111,7 +113,8 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -111,7 +113,8 @@ void rewrite_rnn::apply(program& prog) const
// process intial hidden state
// process intial hidden state
instruction_ref
ih
;
instruction_ref
ih
;
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
if
(
args
.
size
()
==
6
||
(
args
.
size
()
==
5
&&
args
[
4
]
->
get_shape
().
lens
().
size
()
==
3
))
{
{
ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
ih
=
(
args
.
size
()
==
6
)
?
args
[
5
]
:
args
[
4
];
}
}
...
@@ -120,15 +123,8 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -120,15 +123,8 @@ void rewrite_rnn::apply(program& prog) const
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
ih_shape
,
data
});
}
}
auto
ret
=
rnn_cell
(
is_forward
,
auto
ret
=
rnn_cell
(
prog
,
is_forward
,
prog
,
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
ins
,
args
[
0
],
w
,
r
,
bias
,
ih
,
rnn_op
.
actv_funcs
.
at
(
0
));
last_output
=
ret
[
1
];
last_output
=
ret
[
1
];
// add the dimension of num_direction
// add the dimension of num_direction
...
@@ -136,11 +132,11 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -136,11 +132,11 @@ void rewrite_rnn::apply(program& prog) const
}
}
}
}
// rewrite the rnn_last_output operator that right after the rnn
// rewrite the rnn_last_output operator that right after the rnn
// operator. Intuitively, we can do a slice on the input to get
// operator. Intuitively, we can do a slice on the input to get
// the last output, but it is already existed in the rnn operator,
// the last output, but it is already existed in the rnn operator,
// so we can just use it as the output here
// so we can just use it as the output here
//if (ins->name() == "rnn_last_output")
//
if (ins->name() == "rnn_last_output")
//{
//{
// // if rnn operator is executed, the last_output != prog.end()
// // if rnn operator is executed, the last_output != prog.end()
// if (last_output != prog.end())
// if (last_output != prog.end())
...
@@ -164,31 +160,30 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
...
@@ -164,31 +160,30 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
{
{
// squeeze and transpose w
// squeeze and transpose w
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
sw
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
w
);
auto
tran_sw
=
prog
.
insert_instruction
(
sw
,
op
::
transpose
{
perm
},
sw
);
auto
tran_sw
=
prog
.
insert_instruction
(
sw
,
op
::
transpose
{
perm
},
sw
);
// squeeze and transpose r
// squeeze and transpose r
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
sr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
r
);
auto
tran_sr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sr
);
auto
tran_sr
=
prog
.
insert_instruction
(
ins
,
op
::
transpose
{
perm
},
sr
);
// initial hidden state
// initial hidden state
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
auto
sih
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
ih
);
// bias
// bias
if
(
bias
!=
prog
.
end
())
if
(
bias
!=
prog
.
end
())
{
{
long
hs
=
r
->
get_shape
().
lens
()[
2
];
long
hs
=
r
->
get_shape
().
lens
()[
2
];
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
sbias
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
bias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
sbias
);
auto
rb
=
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
sbias
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
sih
->
get_shape
()},
b
);
}
}
instruction_ref
hidden_out
,
last_out
;
instruction_ref
hidden_out
,
last_out
;
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
std
::
size_t
seq_len
=
input
->
get_shape
().
lens
()[
0
];
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
{
{
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
...
@@ -207,7 +202,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
...
@@ -207,7 +202,7 @@ std::vector<instruction_ref> rewrite_rnn::rnn_cell(bool is_forward,
}
}
// apply activation function
// apply activation function
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
ht
=
prog
.
insert_instruction
(
ins
,
actv_func
,
ht
);
sih
=
ht
;
sih
=
ht
;
// add the dimension of sequence length
// add the dimension of sequence length
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment