Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
0427ce2b
Commit
0427ce2b
authored
Jan 22, 2019
by
Shucai Xiao
Browse files
clang format.
parent
a2ea4ecd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
43 deletions
+44
-43
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+7
-8
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+0
-1
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+3
-2
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+34
-32
No files found.
src/include/migraphx/operators.hpp
View file @
0427ce2b
...
...
@@ -1062,28 +1062,28 @@ struct rnn
bidirectional
,
};
std
::
size_t
hidden_size
=
1
;
operation
actv_func
=
tanh
{};
std
::
size_t
hidden_size
=
1
;
operation
actv_func
=
tanh
{};
rnn_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
std
::
string
name
()
const
{
return
"rnn"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
auto
in_dims
=
inputs
[
0
].
lens
();
auto
in_dims
=
inputs
[
0
].
lens
();
auto
hidden_dims
=
inputs
[
1
].
lens
();
if
(
hidden_size
!=
hidden_dims
[
1
])
if
(
hidden_size
!=
hidden_dims
[
1
])
{
MIGRAPHX_THROW
(
"RNN: hidden size mismatch in attribute and input"
);
}
std
::
size_t
num_directions
=
1
;
if
(
direction
==
rnn_direction_t
::
bidirectional
)
if
(
direction
==
rnn_direction_t
::
bidirectional
)
{
num_directions
=
2
;
}
if
(
num_directions
!=
hidden_dims
[
0
])
if
(
num_directions
!=
hidden_dims
[
0
])
{
MIGRAPHX_THROW
(
"RNN: num_direction does not match the direction attribute"
);
}
...
...
@@ -1096,7 +1096,6 @@ struct rnn
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
0427ce2b
...
...
@@ -7,7 +7,6 @@
#include <migraphx/operators.hpp>
#include <migraphx/config.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/onnx/onnx.cpp
View file @
0427ce2b
...
...
@@ -661,7 +661,7 @@ struct onnx_parser
actv_func_map
.
insert
(
std
::
make_pair
(
"relu"
,
op
::
relu
{}));
actv_func_map
.
insert
(
std
::
make_pair
(
"sigmoid"
,
op
::
sigmoid
{}));
if
(
actv_func_map
.
count
(
activation_func
)
==
0
)
if
(
actv_func_map
.
count
(
activation_func
)
==
0
)
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
activation_func
+
" not supported"
);
}
...
...
@@ -690,7 +690,8 @@ struct onnx_parser
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
}
return
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
actv_func_map
[
activation_func
],
dirct
,
clip
},
std
::
move
(
args
));
return
prog
.
add_instruction
(
op
::
rnn
{
hidden_size
,
actv_func_map
[
activation_func
],
dirct
,
clip
},
std
::
move
(
args
));
}
void
parse_from
(
std
::
istream
&
is
)
...
...
src/rewrite_rnn.cpp
View file @
0427ce2b
...
...
@@ -18,22 +18,21 @@ void rewrite_rnn::apply(program& prog) const
}
// could be 3 to 5 inputs (though onnx::rnn has 6 inputs,
// the 5th one is undefined and ignored by protobuf. so
// the 5th one is undefined and ignored by protobuf. so
// we need to process up to 5 inputs
auto
args
=
ins
->
inputs
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
shape
wgt_shape
=
args
[
1
]
->
get_shape
();
shape
seq_shape
=
args
[
0
]
->
get_shape
();
shape
wgt_shape
=
args
[
1
]
->
get_shape
();
std
::
size_t
hidden_size
=
wgt_shape
.
lens
()[
1
];
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
std
::
size_t
batch_size
=
seq_shape
.
lens
()[
1
];
shape
::
type_t
type
=
seq_shape
.
type
();
migraphx
::
shape
s
{
type
,
{
batch_size
,
hidden_size
}};
std
::
vector
<
char
>
data
(
s
.
bytes
(),
0
);
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
auto
rnn_op
=
any_cast
<
op
::
rnn
>
(
ins
->
get_operator
());
op
::
rnn
::
rnn_direction_t
dicrt
=
rnn_op
.
direction
;
if
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
bidirectional
)
if
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
bidirectional
)
{
std
::
vector
<
int64_t
>
perm
{
1
,
0
};
// process input weight matrix
...
...
@@ -65,17 +64,19 @@ void rewrite_rnn::apply(program& prog) const
long
h_size
=
static_cast
<
long
>
(
hidden_size
);
auto
b_forward
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
1
}},
args
[
3
]);
b_forward
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
b_forward
);
auto
wbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
h_size
}},
b_forward
);
auto
rbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
h_size
},
{
2
*
h_size
}},
b_forward
);
auto
bf
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbf
,
rbf
);
auto
wbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
h_size
}},
b_forward
);
auto
rbf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
h_size
},
{
2
*
h_size
}},
b_forward
);
auto
bf
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbf
,
rbf
);
bias_forward
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
s
},
bf
);
// backward
auto
b_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
1
},
{
2
}},
args
[
3
]);
b_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
b_reverse
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
h_size
}},
b_reverse
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
h_size
},
{
2
*
h_size
}},
b_reverse
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
auto
wbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
h_size
}},
b_reverse
);
auto
rbr
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
h_size
},
{
2
*
h_size
}},
b_reverse
);
auto
br
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wbr
,
rbr
);
bias_reverse
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
s
},
br
);
}
...
...
@@ -144,9 +145,9 @@ void rewrite_rnn::apply(program& prog) const
long
h_size
=
static_cast
<
long
>
(
hidden_size
);
auto
bwr
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
args
[
3
]);
auto
wb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
h_size
}},
bwr
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
h_size
},
{
2
*
h_size
}},
bwr
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
s
},
b
);
auto
rb
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
h_size
},
{
2
*
h_size
}},
bwr
);
auto
b
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
wb
,
rb
);
bias
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
s
},
b
);
}
// process intial hidden state
...
...
@@ -159,7 +160,8 @@ void rewrite_rnn::apply(program& prog) const
{
ih
=
prog
.
add_literal
(
migraphx
::
literal
{
s
,
data
});
}
auto
ret
=
rnn_oper
(
is_forward
,
prog
,
ins
,
args
[
0
],
trans_xw
,
trans_hw
,
ih
,
bias
,
rnn_op
.
actv_func
);
auto
ret
=
rnn_oper
(
is_forward
,
prog
,
ins
,
args
[
0
],
trans_xw
,
trans_hw
,
ih
,
bias
,
rnn_op
.
actv_func
);
// add the dimension of num_direction
prog
.
replace_instruction
(
ins
,
op
::
unsqueeze
{{
1
}},
ret
[
0
]);
...
...
@@ -168,14 +170,14 @@ void rewrite_rnn::apply(program& prog) const
}
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
rnn_oper
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
wx
,
instruction_ref
wh
,
instruction_ref
ih
,
instruction_ref
bias
,
operation
&
actv_func
)
const
program
&
prog
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
wx
,
instruction_ref
wh
,
instruction_ref
ih
,
instruction_ref
bias
,
operation
&
actv_func
)
const
{
instruction_ref
hidden_out
,
final_out
;
migraphx
::
shape
input_shape
=
input
->
get_shape
();
...
...
@@ -183,8 +185,8 @@ std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
long
seq_index
=
is_forward
?
0
:
seq_len
-
1
;
for
(
std
::
size_t
i
=
0
;
i
<
seq_len
;
i
++
)
{
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
auto
xt
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
seq_index
},
{
seq_index
+
1
}},
input
);
xt
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
xt
);
auto
x_w
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
xt
,
wx
);
auto
h_r
=
prog
.
insert_instruction
(
ins
,
op
::
dot
{},
ih
,
wh
);
auto
x_h
=
prog
.
insert_instruction
(
ins
,
op
::
add
{},
x_w
,
h_r
);
...
...
@@ -208,14 +210,14 @@ std::vector<instruction_ref> rewrite_rnn::rnn_oper(bool is_forward,
if
(
is_forward
)
{
hidden_out
=
(
seq_index
==
0
)
?
output
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
output
);
?
output
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
hidden_out
,
output
);
}
else
{
hidden_out
=
(
seq_index
==
seq_len
-
1
)
?
output
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
output
,
hidden_out
);
?
output
:
prog
.
insert_instruction
(
ins
,
op
::
concat
{
0
},
output
,
hidden_out
);
}
seq_index
=
is_forward
?
(
seq_index
+
1
)
:
(
seq_index
-
1
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment