Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
661f5b54
Commit
661f5b54
authored
Jan 29, 2019
by
Shucai Xiao
Browse files
add tests for rnn operator
parent
62044b86
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
167 additions
and
10 deletions
+167
-10
src/include/migraphx/rewrite_rnn.hpp
src/include/migraphx/rewrite_rnn.hpp
+3
-3
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+6
-6
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+1
-1
test/onnx/onnx_rnn_3args.onnx
test/onnx/onnx_rnn_3args.onnx
+0
-0
test/onnx/onnx_rnn_5args.onnx
test/onnx/onnx_rnn_5args.onnx
+0
-0
test/onnx/onnx_rnn_bi.onnx
test/onnx/onnx_rnn_bi.onnx
+0
-0
test/onnx/onnx_rnn_forward.onnx
test/onnx/onnx_rnn_forward.onnx
+0
-0
test/onnx/onnx_rnn_reverse.onnx
test/onnx/onnx_rnn_reverse.onnx
+0
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+157
-0
No files found.
src/include/migraphx/rewrite_rnn.hpp
View file @
661f5b54
...
@@ -25,10 +25,10 @@ struct rewrite_rnn
...
@@ -25,10 +25,10 @@ struct rewrite_rnn
program
&
prog
,
program
&
prog
,
instruction_ref
ins
,
instruction_ref
ins
,
instruction_ref
input
,
instruction_ref
input
,
instruction_ref
wx
,
instruction_ref
w
,
instruction_ref
wh
,
instruction_ref
r
,
instruction_ref
ih
,
instruction_ref
bias
,
instruction_ref
bias
,
instruction_ref
ih
,
operation
&
actv_func
)
const
;
operation
&
actv_func
)
const
;
};
};
...
...
src/onnx/onnx.cpp
View file @
661f5b54
...
@@ -664,11 +664,11 @@ struct onnx_parser
...
@@ -664,11 +664,11 @@ struct onnx_parser
if
(
contains
(
attributes
,
"hidden_size"
))
if
(
contains
(
attributes
,
"hidden_size"
))
{
{
hidden_size
=
parse_value
(
attributes
.
at
(
"hidden_size"
)).
at
<
int
>
();
std
::
size_t
hidden_size
_att
=
parse_value
(
attributes
.
at
(
"hidden_size"
)).
at
<
int
>
();
}
if
(
hidden_size
!=
hidden_size_att
)
else
{
{
MIGRAPHX_THROW
(
"RNN: hidden size mismatch in input and attribute"
);
MIGRAPHX_THROW
(
"RNN: hidden size attribute missing"
);
}
}
}
// Handling of direction to be added later
// Handling of direction to be added later
...
@@ -699,7 +699,7 @@ struct onnx_parser
...
@@ -699,7 +699,7 @@ struct onnx_parser
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
fn
)
{
for_each
(
vec_names
.
begin
(),
vec_names
.
end
(),
[
&
](
auto
&
fn
)
{
if
(
map_actv_funcs
.
count
(
fn
)
==
0
)
if
(
map_actv_funcs
.
count
(
fn
)
==
0
)
{
{
MIGRAPHX_THROW
(
"RNN: activation function "
+
fn
+
" not supported"
);
MIGRAPHX_THROW
(
"RNN: activation function "
+
std
::
string
(
fn
)
+
" not supported"
);
}
}
});
});
...
...
src/rewrite_rnn.cpp
View file @
661f5b54
...
@@ -97,7 +97,7 @@ void rewrite_rnn::apply(program& prog) const
...
@@ -97,7 +97,7 @@ void rewrite_rnn::apply(program& prog) const
}
}
else
else
{
{
bool
is_forward
=
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
forward
)
?
true
:
false
;
bool
is_forward
=
(
dicrt
==
op
::
rnn
::
rnn_direction_t
::
forward
);
// input weight matrix
// input weight matrix
auto
w
=
args
[
1
];
auto
w
=
args
[
1
];
...
...
test/onnx/onnx_rnn_3args.onnx
0 → 100644
View file @
661f5b54
File added
test/onnx/onnx_rnn_5args.onnx
0 → 100644
View file @
661f5b54
File added
test/onnx/onnx_rnn_bi.onnx
0 → 100644
View file @
661f5b54
File added
test/onnx/onnx_rnn_forward.onnx
0 → 100644
View file @
661f5b54
File added
test/onnx/onnx_rnn_reverse.onnx
0 → 100644
View file @
661f5b54
File added
test/onnx/onnx_test.cpp
View file @
661f5b54
...
@@ -439,6 +439,163 @@ TEST_CASE(shape_gather_test)
...
@@ -439,6 +439,163 @@ TEST_CASE(shape_gather_test)
EXPECT
(
p
==
prog
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
rnn_test
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
2
;
// num directions
float
clip
=
0.0
f
;
// bidirectional
{
migraphx
::
program
p
;
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
hs
}});
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
2
*
hs
}});
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn
::
bidirectional
,
clip
},
seq
,
w
,
r
,
bias
,
seq_len
,
ih
);
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
out_hs
);
auto
prog
=
migraphx
::
parse_onnx
(
"onnx_rnn_bi.onnx"
);
EXPECT
(
p
==
prog
);
}
// forward
{
nd
=
1
;
migraphx
::
program
p
;
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
hs
}});
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
2
*
hs
}});
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn
::
forward
,
clip
},
seq
,
w
,
r
,
bias
,
seq_len
,
ih
);
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
out_hs
);
auto
prog
=
migraphx
::
parse_onnx
(
"onnx_rnn_forward.onnx"
);
EXPECT
(
p
==
prog
);
}
// reverse
{
nd
=
1
;
migraphx
::
program
p
;
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
hs
}});
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
2
*
hs
}});
auto
seq_len
=
p
.
add_parameter
(
"seq_len"
,
migraphx
::
shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn
::
reverse
,
clip
},
seq
,
w
,
r
,
bias
,
seq_len
,
ih
);
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
out_hs
);
auto
prog
=
migraphx
::
parse_onnx
(
"onnx_rnn_reverse.onnx"
);
EXPECT
(
p
==
prog
);
}
// 3 argumments
{
nd
=
1
;
migraphx
::
program
p
;
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn
::
reverse
,
clip
},
seq
,
w
,
r
);
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
out_hs
);
auto
prog
=
migraphx
::
parse_onnx
(
"onnx_rnn_3args.onnx"
);
EXPECT
(
p
==
prog
);
}
// 5 argumments
{
nd
=
1
;
migraphx
::
program
p
;
auto
seq
=
p
.
add_parameter
(
"seq"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
sl
,
bs
,
is
}});
auto
w
=
p
.
add_parameter
(
"w"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
is
}});
auto
r
=
p
.
add_parameter
(
"r"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
hs
,
hs
}});
auto
bias
=
p
.
add_parameter
(
"bias"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
2
*
hs
}});
auto
ih
=
p
.
add_parameter
(
"h0"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
bs
,
hs
}});
auto
out_hs
=
p
.
add_instruction
(
migraphx
::
op
::
rnn
{
hs
,
{
migraphx
::
op
::
tanh
{},
migraphx
::
op
::
sigmoid
{}},
migraphx
::
op
::
rnn
::
reverse
,
clip
},
seq
,
w
,
r
,
bias
,
ih
);
p
.
add_instruction
(
migraphx
::
op
::
rnn_last_output
{},
out_hs
);
auto
prog
=
migraphx
::
parse_onnx
(
"onnx_rnn_5args.onnx"
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
flatten_test
)
TEST_CASE
(
flatten_test
)
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment