Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f8c319e3
Commit
f8c319e3
authored
Feb 07, 2019
by
Shucai Xiao
Browse files
add more code for lstm operator
parent
398c0157
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
246 additions
and
1 deletion
+246
-1
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-1
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+13
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+153
-0
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+79
-0
No files found.
src/include/migraphx/operators.hpp
View file @
f8c319e3
...
@@ -1267,7 +1267,7 @@ struct lstm
...
@@ -1267,7 +1267,7 @@ struct lstm
std
::
size_t
hidden_size
=
1
;
std
::
size_t
hidden_size
=
1
;
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{},
tanh
{}};
std
::
vector
<
operation
>
actv_funcs
{
sigmoid
{},
tanh
{},
tanh
{}};
gru
_direction_t
direction
=
forward
;
lstm
_direction_t
direction
=
forward
;
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
int
input_forget
=
0
;
int
input_forget
=
0
;
...
...
src/include/migraphx/rewrite_rnn.hpp
View file @
f8c319e3
...
@@ -45,6 +45,19 @@ struct rewrite_rnn
...
@@ -45,6 +45,19 @@ struct rewrite_rnn
const
operation
&
actv_func2
)
const
;
const
operation
&
actv_func2
)
const
;
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
std
::
vector
<
operation
>
gru_actv_funcs
(
instruction_ref
ins
)
const
;
// for lstm operators
void
apply_lstm
(
program
&
prog
,
instruction_ref
ins
)
const
;
std
::
vector
<
instruction_ref
>
lstm_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
,
const
operation
&
actv_func3
)
const
;
std
::
vector
<
operation
>
lstm_actv_funcs
(
instruction_ref
ins
)
const
;
};
};
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/onnx/onnx.cpp
View file @
f8c319e3
...
@@ -900,6 +900,159 @@ struct onnx_parser
...
@@ -900,6 +900,159 @@ struct onnx_parser
return
{
hidden_states
,
last_output
};
return
{
hidden_states
,
last_output
};
}
}
std
::
vector
<
instruction_ref
>
parse_lstm
(
const
std
::
string
&
,
attribute_map
attributes
,
std
::
vector
<
instruction_ref
>
args
)
{
migraphx
::
shape
input_shape
=
args
[
0
]
->
get_shape
();
std
::
size_t
hidden_size
=
args
[
2
]
->
get_shape
().
lens
()[
2
];
if
(
contains
(
attributes
,
"hidden_size"
))
{
std
::
size_t
hidden_size_att
=
parse_value
(
attributes
.
at
(
"hidden_size"
)).
at
<
int
>
();
if
(
hidden_size
!=
hidden_size_att
)
{
MIGRAPHX_THROW
(
"LSTM: hidden size mismatch in input and attribute"
);
}
}
// Handling of direction to be added later
std
::
string
direction
{
"forward"
};
if
(
contains
(
attributes
,
"direction"
))
{
direction
=
attributes
.
at
(
"direction"
).
s
();
}
op
::
lstm
::
lstm_direction_t
dirct
=
op
::
lstm
::
forward
;
if
(
direction
==
"bidirectional"
)
{
dirct
=
op
::
lstm
::
bidirectional
;
}
else
if
(
direction
==
"reverse"
)
{
dirct
=
op
::
lstm
::
reverse
;
}
else
if
(
direction
==
"forward"
)
{
dirct
=
op
::
lstm
::
forward
;
}
else
{
MIGRAPHX_THROW
(
"LSTM: incorrect direction attribute"
);
}
std
::
vector
<
std
::
string
>
vec_names
=
{
"sigmoid"
,
"tanh"
,
"tanh"
};
if
(
contains
(
attributes
,
"activations"
))
{
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[](
auto
&
str
)
{
return
str
;
});
}
// need 6 activation functions for bidirectional directions
if
(
dirct
==
op
::
lstm
::
bidirectional
)
{
// 6 activation functions are used in the bidirectional
// scenario. No spec is provided in onnx::operator. we
// use the algorithm that: if 1 actv function is provided,
// repeat 1st six times. If 2 actv functins are provided,
// repeat 2nd once, then repeat all three once
// if 3 actv funcs are provide, repeat all three once.
// the same algorithm is used for 4, 5, and 6 actv funcions
// provided. This may need change later
switch
(
vec_names
.
size
())
{
case
1
:
vec_names
.
insert
(
vec_names
.
end
(),
5
,
vec_names
.
back
());
break
;
case
2
:
// repeat the 2nd actv func once, then repeat all three another time
vec_names
.
push_back
(
vec_names
.
back
());
vec_names
.
insert
(
vec_names
.
end
(),
vec_names
.
begin
(),
vec_names
.
end
());
break
;
case
3
:
// repeat all three actv funcs once
vec_names
.
insert
(
vec_names
.
end
(),
vec_names
.
begin
(),
vec_names
.
end
());
break
;
case
4
:
vec_names
.
insert
(
vec_names
.
end
(),
2
,
vec_names
.
back
());
break
;
case
5
:
vec_names
.
push_back
(
vec_names
.
back
());
break
;
default:
break
;
}
}
else
{
switch
(
vec_names
.
size
())
{
case
1
:
vec_names
.
insert
(
vec_names
.
end
(),
2
,
vec_names
.
back
());
break
;
case
2
:
// repeat the 2nd actv func once, so we have 3 actv funcs
vec_names
.
push_back
(
vec_names
.
back
());
break
;
default:
break
;
}
}
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
if
(
map_actv_funcs
.
count
(
name
)
==
0
)
{
MIGRAPHX_THROW
(
"LSTM: activation function "
+
std
::
string
(
name
)
+
" not supported"
);
}
});
std
::
vector
<
operation
>
vec_actv_funcs
(
vec_names
.
size
());
std
::
transform
(
vec_names
.
begin
(),
vec_names
.
end
(),
vec_actv_funcs
.
begin
(),
[
&
](
auto
&
name
)
{
return
map_actv_funcs
[
name
];
});
float
clip
=
0.0
;
if
(
contains
(
attributes
,
"clip"
))
{
clip
=
parse_value
(
attributes
.
at
(
"clip"
)).
at
<
float
>
();
}
int
input_forget
=
0
;
if
(
contains
(
attributes
,
"input_forget"
))
{
input_forget
=
parse_value
(
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
}
// append undefined opeator to make 6 arguments
if
(
args
.
size
()
<
8
)
{
auto
ins
=
prog
.
add_instruction
(
op
::
undefined
{});
args
.
insert
(
args
.
end
(),
6
-
args
.
size
(),
ins
);
}
// first output for concatenation of hidden states
auto
hidden_states
=
prog
.
add_instruction
(
op
::
lstm
{
hidden_size
,
vec_actv_funcs
,
dirct
,
clip
,
input_forget
},
std
::
move
(
args
));
// second output for last lstm output
auto
last_output
=
prog
.
add_instruction
(
op
::
lstm_last_output
{},
hidden_states
);
// third output for last cell output
auto
last_cell_output
=
prog
.
add_instruction
(
op
::
lstm_last_cell_output
{},
hidden_states
);
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
void
parse_from
(
std
::
istream
&
is
)
void
parse_from
(
std
::
istream
&
is
)
{
{
onnx
::
ModelProto
model
;
onnx
::
ModelProto
model
;
...
...
src/rewrite_rnn.cpp
View file @
f8c319e3
...
@@ -668,5 +668,84 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
...
@@ -668,5 +668,84 @@ std::vector<operation> rewrite_rnn::gru_actv_funcs(instruction_ref ins) const
}
}
}
}
// for lstm operators
void
rewrite_rnn
::
apply_lstm
(
program
&
prog
,
instruction_ref
ins
)
const
{
assert
(
ins
->
name
()
==
"lstm"
);
auto
args
=
ins
->
inputs
();
}
std
::
vector
<
instruction_ref
>
rewrite_rnn
::
lstm_cell
(
bool
is_forward
,
program
&
prog
,
instruction_ref
ins
,
std
::
vector
<
instruction_ref
>
inputs
,
int
linear_before_reset
,
const
operation
&
actv_func1
,
const
operation
&
actv_func2
,
const
operation
&
actv_func3
)
const
{
return
{};
}
std
::
vector
<
operation
>
rewrite_rnn
::
lstm_actv_funcs
(
instruction_ref
ins
)
const
{
auto
lstm_op
=
any_cast
<
op
::
lstm
>
(
ins
->
get_operator
());
// before rewrite the lstm operator, need to ensure
// we have 6 actv funcs, even though a user does not
// specifiy any actv func. If less than 46, use the
// algorithm in parse_lstm to make 6 actv functions
const
auto
&
actv_funcs
=
lstm_op
.
actv_funcs
;
std
::
size_t
num_actv_funcs
=
actv_funcs
.
size
();
if
(
lstm_op
.
direction
==
op
::
lstm
::
bidirectional
)
{
switch
(
num_actv_funcs
)
{
case
0
:
return
{
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{},
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{}};
case
1
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
)};
case
2
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
)};
case
3
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
)};
case
4
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
3
)};
case
5
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
2
),
actv_funcs
.
at
(
3
),
actv_funcs
.
at
(
4
),
actv_funcs
.
at
(
4
)};
default:
return
actv_funcs
;
}
}
else
{
switch
(
num_actv_funcs
)
{
case
0
:
return
{
op
::
sigmoid
{},
op
::
tanh
{},
op
::
tanh
{}};
case
1
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
0
)};
case
2
:
return
{
actv_funcs
.
at
(
0
),
actv_funcs
.
at
(
1
),
actv_funcs
.
at
(
1
)};
default:
return
actv_funcs
;
}
}
}
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment