Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
af90a792
Commit
af90a792
authored
Feb 25, 2019
by
Shucai Xiao
Browse files
Merge branch 'lstm_operator' into seq2seq_example
parents
89b80be6
7031da28
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
24 deletions
+15
-24
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+15
-24
No files found.
src/onnx/onnx.cpp
View file @
af90a792
...
@@ -787,17 +787,14 @@ struct onnx_parser
...
@@ -787,17 +787,14 @@ struct onnx_parser
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
vec_names
.
resize
(
names
.
size
());
std
::
transform
(
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[](
auto
&
str
)
{
return
str
;
});
}
}
if
(
std
::
any_of
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
}))
});
if
(
name_it
!=
vec_names
.
end
())
{
{
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
});
MIGRAPHX_THROW
(
"RNN: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
MIGRAPHX_THROW
(
"RNN: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
}
}
...
@@ -881,8 +878,7 @@ struct onnx_parser
...
@@ -881,8 +878,7 @@ struct onnx_parser
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
vec_names
.
resize
(
names
.
size
());
std
::
transform
(
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[](
auto
&
str
)
{
return
str
;
});
}
}
// need 4 activation functions
// need 4 activation functions
...
@@ -920,13 +916,11 @@ struct onnx_parser
...
@@ -920,13 +916,11 @@ struct onnx_parser
}
}
}
}
if
(
std
::
any_of
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
}))
});
if
(
name_it
!=
vec_names
.
end
())
{
{
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
});
MIGRAPHX_THROW
(
"GRU: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
MIGRAPHX_THROW
(
"GRU: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
}
}
...
@@ -1011,8 +1005,7 @@ struct onnx_parser
...
@@ -1011,8 +1005,7 @@ struct onnx_parser
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
auto
names
=
attributes
.
at
(
"activations"
).
strings
();
vec_names
.
clear
();
vec_names
.
clear
();
vec_names
.
resize
(
names
.
size
());
vec_names
.
resize
(
names
.
size
());
std
::
transform
(
std
::
copy
(
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
());
names
.
begin
(),
names
.
end
(),
vec_names
.
begin
(),
[](
auto
&
str
)
{
return
str
;
});
}
}
// need 6 activation functions for bidirectional directions
// need 6 activation functions for bidirectional directions
...
@@ -1093,13 +1086,11 @@ struct onnx_parser
...
@@ -1093,13 +1086,11 @@ struct onnx_parser
}
}
}
}
if
(
std
::
any_of
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
}))
});
if
(
name_it
!=
vec_names
.
end
())
{
{
auto
name_it
=
std
::
find_if
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
name
)
{
return
(
map_actv_funcs
.
count
(
name
)
==
0
);
});
MIGRAPHX_THROW
(
"LSTM: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
MIGRAPHX_THROW
(
"LSTM: activation function "
+
std
::
string
(
*
name_it
)
+
" not supported"
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment