Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
483c4508
Commit
483c4508
authored
Feb 07, 2019
by
Shucai Xiao
Browse files
clang format
parent
f8c319e3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
64 additions
and
61 deletions
+64
-61
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-2
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+7
-7
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+9
-20
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+46
-32
No files found.
src/include/migraphx/operators.hpp
View file @
483c4508
...
...
@@ -1268,8 +1268,8 @@ struct lstm
std
::
size_t
hidden_size
=
1
;
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{},
tanh
{}};
lstm_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
int
input_forget
=
0
;
float
clip
=
0.0
f
;
int
input_forget
=
0
;
std
::
string
name
()
const
{
return
"lstm"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
483c4508
...
...
@@ -49,13 +49,13 @@ struct rewrite_rnn
// for lstm operators
void
apply_lstm
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
,
const
operation
&
actv_func3
)
const
;
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
,
const
operation
&
actv_func3
)
const
;
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
};
...
...
src/onnx/onnx.cpp
View file @
483c4508
...
...
@@ -931,7 +931,7 @@ struct onnx_parser
{
dirct
=
op
::
lstm
::
reverse
;
}
else
if
(
direction
==
"forward"
)
else
if
(
direction
==
"forward"
)
{
dirct
=
op
::
lstm
::
forward
;
}
...
...
@@ -958,14 +958,12 @@ struct onnx_parser
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch
(
vec_names
.
size
())
{
case
1
:
vec_names
.
insert
(
vec_names
.
end
(),
5
,
vec_names
.
back
());
break
;
case
1
:
vec_names
.
insert
(
vec_names
.
end
(),
5
,
vec_names
.
back
());
break
;
case
2
:
// repeat the 2nd actv func once, then repeat all three another time
...
...
@@ -978,33 +976,25 @@ struct onnx_parser
vec_names
.
insert
(
vec_names
.
end
(),
vec_names
.
begin
(),
vec_names
.
end
());
break
;
case
4
:
vec_names
.
insert
(
vec_names
.
end
(),
2
,
vec_names
.
back
());
break
;
case
4
:
vec_names
.
insert
(
vec_names
.
end
(),
2
,
vec_names
.
back
());
break
;
case
5
:
vec_names
.
push_back
(
vec_names
.
back
());
break
;
case
5
:
vec_names
.
push_back
(
vec_names
.
back
());
break
;
default:
break
;
default:
break
;
}
}
else
{
switch
(
vec_names
.
size
())
{
case
1
:
vec_names
.
insert
(
vec_names
.
end
(),
2
,
vec_names
.
back
());
break
;
case
1
:
vec_names
.
insert
(
vec_names
.
end
(),
2
,
vec_names
.
back
());
break
;
case
2
:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names
.
push_back
(
vec_names
.
back
());
break
;
default:
break
;
default:
break
;
}
}
...
...
@@ -1041,8 +1031,7 @@ struct onnx_parser
// first output for concatenation of hidden states
auto
hidden_states
=
prog
.
add_instruction
(
op
::
lstm
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
input_forget
},
std
::
move
(
args
));
op
::
lstm
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
input_forget
},
std
::
move
(
args
));
// second output for last lstm output
auto
last_output
=
prog
.
add_instruction
(
op
::
lstm_last_output
{},
hidden_states
);
...
...
src/rewrite_rnn.cpp
View file @
483c4508
...
...
@@ -676,13 +676,13 @@ void rewrite_rnn::apply_lstm(program& prog, instruction_ref ins) const
}
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
lstm_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
,
const
operation
&
actv_func3
)
const
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
,
const
operation
&
actv_func3
)
const
{
return
{};
}
...
...
@@ -694,55 +694,69 @@ std::vector<operation> rewrite_rnn::lstm_actv_funcs(instruction_ref ins) const
// we have 6 actv funcs, even though a user does not
// specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions
const
auto
&
actv_funcs
=
lstm_op
.
actv_funcs
;
const
auto
&
actv_funcs
=
lstm_op
.
actv_funcs
;
std
::
size_t
num_actv_funcs
=
actv_funcs
.
size
();
if
(
lstm_op
.
direction
==
op
::
lstm
::
bidirectional
)
{
switch
(
num_actv_funcs
)
{
case
0
:
return
{
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{},
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{}};
return
{
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{},
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{}};
case
1
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
)};
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
)};
case
2
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
)};
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
)};
case
3
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
)};
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
)};
case
4
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
3
)};
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
3
)};
case
5
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
4
),
actv_funcs
.
at
(
4
)};
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
4
),
actv_funcs
.
at
(
4
)};
default:
return
actv_funcs
;
default:
return
actv_funcs
;
}
}
else
{
switch
(
num_actv_funcs
)
{
case
0
:
return
{
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{}};
case
0
:
return
{
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{}};
case
1
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
)};
case
1
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
)};
case
2
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
)};
default:
return
actv_funcs
;
case
2
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
)};
default:
return
actv_funcs
;
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment