Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
61775eab
Commit
61775eab
authored
Nov 16, 2023
by
Umang Yadav
Browse files
Merge branch 'ref_fp8' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ref_fp8
parents
a5c38ebe
e7e5ba23
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3274 additions
and
410 deletions
+3274
-410
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+19
-0
src/onnx/parse_lstm.cpp
src/onnx/parse_lstm.cpp
+47
-0
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+103
-35
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+171
-0
test/onnx/lstm_bi_layout_cell_test.onnx
test/onnx/lstm_bi_layout_cell_test.onnx
+0
-0
test/onnx/lstm_bi_layout_last_test.onnx
test/onnx/lstm_bi_layout_last_test.onnx
+0
-0
test/onnx/lstm_f_layout_cell_test.onnx
test/onnx/lstm_f_layout_cell_test.onnx
+0
-0
test/onnx/lstm_f_layout_hs_test.onnx
test/onnx/lstm_f_layout_hs_test.onnx
+0
-0
test/onnx/lstm_r_layout_hs_cell_test.onnx
test/onnx/lstm_r_layout_hs_cell_test.onnx
+0
-0
test/onnx/lstm_r_layout_test.onnx
test/onnx/lstm_r_layout_test.onnx
+0
-0
test/onnx/onnx_rnn_test.cpp
test/onnx/onnx_rnn_test.cpp
+332
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-1
test/quantization.cpp
test/quantization.cpp
+70
-70
test/ref/rnn_ops.cpp
test/ref/rnn_ops.cpp
+1727
-234
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+363
-70
test/verify/test_lstm_bidirct_3args_layout.cpp
test/verify/test_lstm_bidirct_3args_layout.cpp
+77
-0
test/verify/test_lstm_bidirct_last_layout.cpp
test/verify/test_lstm_bidirct_last_layout.cpp
+95
-0
test/verify/test_lstm_forward_hs_layout.cpp
test/verify/test_lstm_forward_hs_layout.cpp
+95
-0
test/verify/test_lstm_forward_last_layout.cpp
test/verify/test_lstm_forward_last_layout.cpp
+97
-0
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
+78
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
61775eab
...
...
@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
MIGRAPHX_PRED_MATCHER
(
has_same_value
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
bool
all_same
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
s
)
{
all_same
=
std
::
all_of
(
s
.
begin
()
+
1
,
s
.
end
(),
[
&
](
const
auto
&
scale
)
{
return
float_equal
(
scale
,
s
.
front
());
});
});
return
all_same
;
}
MIGRAPHX_BASIC_MATCHER
(
output
,
const
matcher_context
&
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
...
...
@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
template
<
class
...
Ms
>
auto
skip_broadcasts_transposes_contiguous
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"transpose"
))(
ms
...);
}
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
...
...
src/onnx/parse_lstm.cpp
View file @
61775eab
...
...
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
void
lstm_transpose_inputs
(
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
0
]);
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_undefined
())
{
args
[
5
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
5
]);
}
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_undefined
())
{
args
[
6
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
6
]);
}
}
void
lstm_transpose_outputs
(
onnx_parser
::
node_info
&
info
,
instruction_ref
&
hidden_states
,
instruction_ref
&
last_output
,
instruction_ref
&
last_cell_output
)
{
std
::
vector
<
int64_t
>
perm_hs
{
2
,
0
,
1
,
3
};
hidden_states
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hs
}}),
hidden_states
);
std
::
vector
<
int64_t
>
perm_last
{
1
,
0
,
2
};
last_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_output
);
last_cell_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_cell_output
);
}
struct
parse_lstm
:
op_parser
<
parse_lstm
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"LSTM"
}};
}
...
...
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
}
int
layout
=
0
;
if
(
contains
(
info
.
attributes
,
"layout"
))
{
layout
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"layout"
)).
at
<
int
>
();
}
// append undefined opeator to make 6 arguments
if
(
args
.
size
()
<
8
)
{
...
...
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
}
if
(
layout
!=
0
)
{
lstm_transpose_inputs
(
info
,
args
);
}
// first output for concatenation of hidden states
auto
hidden_states
=
info
.
add_instruction
(
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
...
...
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto
last_cell_output
=
info
.
add_instruction
(
make_op
(
"rnn_last_cell_output"
),
hidden_states
);
if
(
layout
!=
0
)
{
lstm_transpose_outputs
(
info
,
hidden_states
,
last_output
,
last_cell_output
);
}
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
};
...
...
src/simplify_qdq.cpp
View file @
61775eab
...
...
@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return
s
;
}
MIGRAPHX_PRED_MATCHER
(
has_same_value
,
instruction_ref
ins
)
struct
match_find_quantizable_ops
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
bool
all_same
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
s
)
{
all_same
=
std
::
all_of
(
s
.
begin
()
+
1
,
s
.
end
(),
[
&
](
const
auto
&
scale
)
{
return
float_equal
(
scale
,
s
.
front
());
static
bool
is_valid_scale
(
instruction_ref
scale
,
std
::
vector
<
std
::
size_t
>
lens
,
std
::
size_t
axis
)
{
return
scale
->
get_shape
().
scalar
()
or
scale
->
get_shape
().
elements
()
==
lens
.
at
(
axis
);
}
static
bool
is_valid_zero_point
(
instruction_ref
zp
)
{
if
(
not
zp
->
can_eval
())
return
false
;
bool
all_zeros
=
false
;
zp
->
eval
().
visit
([
&
](
auto
z
)
{
all_zeros
=
std
::
all_of
(
z
.
begin
(),
z
.
end
(),
[
&
](
auto
val
)
{
return
float_equal
(
val
,
0
);
});
});
});
return
all_same
;
}
return
all_zeros
;
}
struct
match_find_quantizable_ops
{
static
auto
scale_broadcast_op
(
instruction_ref
scale
,
std
::
vector
<
std
::
size_t
>
lens
,
std
::
size_t
axis
)
{
if
(
scale
->
get_shape
().
scalar
())
{
return
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}});
}
else
{
return
migraphx
::
make_op
(
"broadcast"
,
{{
"out_lens"
,
lens
},
{
"axis"
,
axis
}});
}
}
static
auto
dequantizelinear_op
(
const
std
::
string
&
name
,
const
std
::
string
&
scale
)
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static
auto
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
)
{
auto
qinp
=
dqins
->
inputs
().
front
();
auto
next_ins
=
dqins
;
while
(
next_ins
!=
qop
)
{
if
(
next_ins
->
name
()
!=
"dequantizelinear"
)
{
qinp
=
m
.
insert_instruction
(
qop
,
next_ins
->
get_operator
(),
qinp
);
}
next_ins
=
next_ins
->
outputs
().
front
();
}
return
qinp
;
}
static
auto
dequantizelinear_op
(
const
std
::
string
&
scale
,
const
std
::
string
&
zp
)
{
return
match
::
name
(
"dequantizelinear"
)(
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"quantizelinear"
))(
match
::
any
()
.
bind
(
name
)
)),
match
::
arg
(
1
)(
match
::
skip_broadcasts
(
has_same_value
().
bind
(
scale
))),
match
::
arg
(
2
)(
match
::
skip_broadcasts
(
match
::
all_of
(
match
::
has_value
(
0
)
))));
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"quantizelinear"
))(
match
::
any
())),
match
::
arg
(
1
)(
match
::
skip_broadcasts
(
match
::
is_constant
().
bind
(
scale
))),
match
::
arg
(
2
)(
match
::
skip_broadcasts
(
match
::
is_constant
().
bind
(
zp
))));
}
auto
matcher
()
const
{
return
match
::
name
(
get_quantizable_op_names
())(
match
::
arg
(
0
)(
dequantizelinear_op
(
"x1"
,
"scale1"
)),
match
::
arg
(
1
)(
dequantizelinear_op
(
"x2"
,
"scale2"
)));
match
::
arg
(
0
)(
match
::
skip_broadcasts_transposes_contiguous
(
dequantizelinear_op
(
"scale1"
,
"zp1"
).
bind
(
"dq1"
))),
match
::
arg
(
1
)(
match
::
skip_broadcasts_transposes_contiguous
(
dequantizelinear_op
(
"scale2"
,
"zp2"
).
bind
(
"dq2"
))));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
qop
=
r
.
result
;
auto
q1
=
r
.
instructions
[
"
x
1"
];
auto
q2
=
r
.
instructions
[
"
x
2"
];
auto
d
q1
=
r
.
instructions
[
"
dq
1"
];
auto
d
q2
=
r
.
instructions
[
"
dq
2"
];
auto
scale1
=
r
.
instructions
[
"scale1"
];
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
// Only INT8 type currently supported
if
(
q1
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
q2
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
if
(
d
q1
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
d
q2
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
return
;
double
scale
;
visit_all
(
scale1
->
get_literal
(),
scale2
->
get_literal
(
))
(
[
&
](
const
auto
s1
,
const
auto
s2
)
{
scale
=
s1
.
front
()
*
s2
.
front
();
})
;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if
(
not
(
is_valid_zero_point
(
zp1
)
and
is_valid_zero_point
(
zp2
))
)
return
;
// Only support scalar and 1D scales
if
(
scale1
->
get_shape
().
lens
().
size
()
!=
1
or
scale2
->
get_shape
().
lens
().
size
()
!=
1
)
return
;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto
qop_args
=
qop
->
inputs
();
qop_args
.
at
(
0
)
=
q1
;
qop_args
.
at
(
1
)
=
q2
;
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
)
;
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
)
;
instruction_ref
dq
;
instruction_ref
dq
_scale
;
instruction_ref
out
_scale
;
instruction_ref
zero_point
;
if
(
qop
->
name
()
==
"convolution"
)
{
auto
conv_val
=
qop
->
get_operator
().
to_value
();
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_convolution"
,
conv_val
),
qop_args
);
auto
out_lens
=
dq
->
get_shape
().
lens
();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if
(
not
(
is_valid_scale
(
scale1
,
out_lens
,
1
)
and
is_valid_scale
(
scale2
,
out_lens
,
1
)))
return
;
auto
s1_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale1
,
out_lens
,
1
),
scale1
);
auto
s2_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale2
,
out_lens
,
1
),
scale2
);
out_scale
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"mul"
),
s1_bcast
,
s2_bcast
);
}
else
if
(
qop
->
name
()
==
"dot"
)
{
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_dot"
),
qop_args
);
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_dot"
),
qop_args
);
auto
out_lens
=
dq
->
get_shape
().
lens
();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if
(
not
(
is_valid_scale
(
scale1
,
out_lens
,
out_lens
.
size
()
-
2
)
and
is_valid_scale
(
scale2
,
out_lens
,
out_lens
.
size
()
-
1
)))
return
;
auto
s1_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale1
,
out_lens
,
out_lens
.
size
()
-
2
),
scale1
);
auto
s2_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale2
,
out_lens
,
out_lens
.
size
()
-
1
),
scale2
);
out_scale
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"mul"
),
s1_bcast
,
s2_bcast
);
}
auto
ins_type
=
qop
->
get_shape
().
type
();
dq_scale
=
m
.
add_literal
(
literal
({
ins_type
},
{
scale
}));
auto
lens
=
dq
->
get_shape
().
lens
();
auto
scale_mb
=
m
.
insert_instruction
(
qop
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
dq_scale
);
dq
=
m
.
insert_instruction
(
qop
,
make_op
(
"dequantizelinear"
),
dq
,
scale_mb
);
dq
=
m
.
insert_instruction
(
qop
,
make_op
(
"dequantizelinear"
),
dq
,
out_scale
);
m
.
replace_instruction
(
qop
,
dq
);
}
};
...
...
test/onnx/gen_onnx.py
View file @
61775eab
...
...
@@ -4484,6 +4484,177 @@ def lrn_test():
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
lstm_bi_layout_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
2
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
2
,
60
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
''
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'bidirectional'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
cellout
])
@
onnx_test
()
def
lstm_bi_layout_last_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
2
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
2
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
2
,
20
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
,
'output'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'bidirectional'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
,
output
])
@
onnx_test
()
def
lstm_f_layout_hs_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
1
,
20
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
,
'output'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'forward'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
,
output
])
@
onnx_test
()
def
lstm_f_layout_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
''
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'forward'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
cellout
])
@
onnx_test
()
def
lstm_r_layout_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'reverse'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
])
@
onnx_test
()
def
lstm_r_layout_hs_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
'output'
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'reverse'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
output
,
cellout
])
@
onnx_test
()
def
matmul_bmbm_test
():
m1
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
FLOAT
,
[
3
,
6
,
7
])
...
...
test/onnx/lstm_bi_layout_cell_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_bi_layout_last_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_f_layout_cell_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_f_layout_hs_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_r_layout_hs_cell_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_r_layout_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/onnx_rnn_test.cpp
View file @
61775eab
...
...
@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
}
}
TEST_CASE
(
lstm_forward_layout
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
1
;
// num directions
float
clip
=
0.0
f
;
int
input_forget
=
1
;
migraphx
::
shape
seq_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
sl
,
is
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
is
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
hs
}};
migraphx
::
shape
bias_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
8
*
hs
}};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
nd
,
hs
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
}};
// 8 args, hs and last output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
auto
prog
=
optimize_onnx
(
"lstm_f_layout_hs_test.onnx"
);
EXPECT
(
p
==
prog
);
}
// 8 args, cell output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_cell
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_cell
);
auto
prog
=
optimize_onnx
(
"lstm_f_layout_cell_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
// activation functions
TEST_CASE
(
lstm_forward_actv_func
)
{
...
...
@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
}
}
TEST_CASE
(
lstm_reverse_layout
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
1
;
// num directions
float
clip
=
0.0
f
;
int
input_forget
=
1
;
migraphx
::
shape
seq_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
sl
,
is
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
is
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
hs
}};
migraphx
::
shape
bias_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
8
*
hs
}};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
nd
,
hs
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
}};
// 8 args, hs output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
auto
prog
=
optimize_onnx
(
"lstm_r_layout_test.onnx"
);
EXPECT
(
p
==
prog
);
}
// 8 args, last and cell output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
auto
last_cell
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_cell
);
auto
prog
=
optimize_onnx
(
"lstm_r_layout_hs_cell_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
lstm_bidirectional
)
{
std
::
size_t
sl
=
5
;
// sequence len
...
...
@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
}
}
TEST_CASE(lstm_bidirectional_layout)
{
    // Model dimensions for the expected graph
    std::size_t sl = 5;  // sequence length
    std::size_t bs = 3;  // batch size
    std::size_t hs = 20; // hidden size
    std::size_t is = 10; // input size
    std::size_t nd = 2;  // number of directions (bidirectional)
    float clip       = 0.0f;
    int input_forget = 1;

    // Inputs arrive batch-first ({bs, sl, is} / {bs, nd, hs}) and are
    // transposed below into the order the "lstm" op consumes.
    migraphx::shape seq_shape{migraphx::shape::float_type, {bs, sl, is}};
    migraphx::shape w_shape{migraphx::shape::float_type, {nd, 4 * hs, is}};
    migraphx::shape r_shape{migraphx::shape::float_type, {nd, 4 * hs, hs}};
    migraphx::shape bias_shape{migraphx::shape::float_type, {nd, 8 * hs}};
    migraphx::shape sl_shape{migraphx::shape::int32_type, {bs}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {bs, nd, hs}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {nd, 3 * hs}};

    std::vector<int64_t> perm{1, 0, 2};

    // Builds the layout-transposed bidirectional lstm graph that both expected
    // programs below share, and returns the lstm output instruction.
    auto build_lstm = [&](auto* mm) {
        auto seq     = mm->add_parameter("seq", seq_shape);
        auto w       = mm->add_parameter("w", w_shape);
        auto r       = mm->add_parameter("r", r_shape);
        auto bias    = mm->add_parameter("bias", bias_shape);
        auto seq_len = mm->add_parameter("seq_len", sl_shape);
        auto ih      = mm->add_parameter("h0", ih_shape);
        auto ic      = mm->add_parameter("c0", ih_shape);
        auto pph     = mm->add_parameter("pph", pph_shape);
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
        return mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hs},
                 {"actv_func",
                  migraphx::to_value(
                      std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                       migraphx::make_op("tanh"),
                                                       migraphx::make_op("tanh"),
                                                       migraphx::make_op("sigmoid"),
                                                       migraphx::make_op("tanh"),
                                                       migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip},
                 {"input_forget", input_forget}}),
            seq,
            w,
            r,
            bias,
            seq_len,
            ih,
            ic,
            pph);
    };

    // last hidden-state output as program output
    {
        migraphx::program p;
        auto* mm         = p.get_main_module();
        auto out_hs      = build_lstm(mm);
        auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
        std::vector<int64_t> perm_hid{2, 0, 1, 3};
        out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
                                     out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);

        auto prog = optimize_onnx("lstm_bi_layout_last_test.onnx");
        EXPECT(p == prog);
    }

    // last cell-state output as program output
    {
        migraphx::program p;
        auto* mm       = p.get_main_module();
        auto out_hs    = build_lstm(mm);
        auto last_cell = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_cell);

        auto prog = optimize_onnx("lstm_bi_layout_cell_test.onnx");
        EXPECT(p == prog);
    }
}
TEST_CASE
(
lstm_bi_actv_funcs
)
{
std
::
size_t
sl
=
5
;
// sequence len
...
...
test/py/onnx_backend_test.py
View file @
61775eab
...
...
@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails
# from OnnxBackendNodeModelTest
backend_test
.
exclude
(
r
'test_gru_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_lstm_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_simple_rnn_batchwise_cpu'
)
# from OnnxBackendPyTorchConvertedModelTest
backend_test
.
exclude
(
r
'test_MaxPool1d_stride_padding_dilation_cpu'
)
...
...
test/quantization.cpp
View file @
61775eab
...
...
@@ -636,13 +636,12 @@ TEST_CASE(dot_float)
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale
);
auto
zp_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
auto
quant_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
quant_a
,
quant_b
);
std
::
vector
<
float
>
vec
(
sc
.
elements
(),
100.0
f
);
auto
dc
=
mm
->
add_literal
(
100.0
f
);
auto
mdc
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sc
.
lens
()}}),
dc
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
mdc
);
auto
quant_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
quant_a
,
quant_b
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto
pa
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
pb
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
scale_a
=
mm
->
add_literal
(
10.0
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
scale_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale_a
);
auto
scale_a
_lit
=
mm
->
add_literal
(
10.0
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale_a
_lit
);
auto
zp_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
auto
qa
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pa
,
scale_a
,
zp_a
);
auto
scale_b
=
mm
->
add_literal
(
5.0
);
scale_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale_b
);
auto
qa
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pa
,
scale_a
,
zp_a
);
auto
scale_b
_lit
=
mm
->
add_literal
(
5.0
);
auto
scale_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale_b
_lit
);
auto
zp_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
auto
qb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
scale
=
mm
->
add_literal
(
50.0
);
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
scale
);
auto
qb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
scale_a_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale_a_lit
);
auto
scale_b_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale_b_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_a_mb
,
scale_b_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
};
...
...
@@ -798,19 +801,16 @@ TEST_CASE(dot_half_1arg)
migraphx
::
shape
sa
{
migraphx
::
shape
::
half_type
,
{
9
,
9
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
sa
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
10.0
}));
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
_lit
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
10.0
}));
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
auto
qx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
scale
,
zp
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qx
,
qx
);
auto
dq_scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
100.0
}));
dq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
dq_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
dq_scale
);
auto
qx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
scale
,
zp
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qx
,
qx
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale
,
scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
};
...
...
@@ -851,10 +851,10 @@ TEST_CASE(conv_float)
auto
px
=
mm
->
add_parameter
(
"x"
,
sx
);
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
10.0
f
);
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
_lit
=
mm
->
add_literal
(
10.0
f
);
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
...
...
@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
migraphx
::
shape
sc
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
std
::
vector
<
float
>
vec
(
sc
.
elements
(),
100.0
f
);
migraphx
::
shape
s_scale
{
migraphx
::
shape
::
float_type
,
sc
.
lens
()};
auto
d_scale
=
mm
->
add_literal
(
100.0
f
);
d_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
4
,
1
,
1
}}}),
d_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
d_scale
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -930,20 +928,21 @@ TEST_CASE(conv_half)
auto
px
=
mm
->
add_parameter
(
"x"
,
sx
);
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
10.0
}));
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
_lit
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
10.0
}));
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
auto
quant_w
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pw
,
scale
,
zp
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
auto
d_scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
100.0
}));
d_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
4
,
1
,
1
}}}),
d_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
d_scale
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -1185,12 +1184,12 @@ TEST_CASE(int8_subgraph)
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sy
.
lens
()}}),
s1
);
auto
zpb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sy
.
lens
()}}),
zp1
);
auto
qb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
b
,
sb
,
zpb
);
auto
qdot
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
s
o
=
then_mod
->
add_
literal
(
100.0
f
);
so
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sout
.
lens
()}}),
so
);
auto
r
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
so
);
auto
qb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
b
,
sb
,
zpb
);
auto
qdot
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
s
1_mb
=
then_mod
->
add_
instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
s1
);
auto
so
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
s1_mb
,
s1_mb
);
auto
r
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
so
);
then_mod
->
add_return
({
r
});
migraphx
::
shape
sd
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
4
,
6
}};
...
...
@@ -1199,24 +1198,25 @@ TEST_CASE(int8_subgraph)
auto
w
=
mm
->
add_parameter
(
"w"
,
sw
);
// else submod
auto
*
else_mod
=
p
.
create_module
(
"If_6_else"
);
auto
sax
=
else_mod
->
add_literal
(
2.0
f
);
auto
sax
_lit
=
else_mod
->
add_literal
(
2.0
f
);
auto
zp
=
else_mod
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
sax
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
sax
);
auto
sax
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
sax
_lit
);
auto
zpx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
zp
);
auto
qx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
sax
,
zpx
);
auto
ssw
=
else_mod
->
add_literal
(
1.66667
f
);
ssw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
ssw
);
auto
qx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
sax
,
zpx
);
auto
ssw
_lit
=
else_mod
->
add_literal
(
1.66667
f
);
auto
ssw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
ssw
_lit
);
auto
zpw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
zp
);
auto
qw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
w
,
ssw
,
zpw
);
auto
qconv
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
qx
,
qw
);
auto
so1
=
else_mod
->
add_literal
(
3.33333
f
);
so1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sout
.
lens
()}}),
so1
);
auto
r1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qconv
,
so1
);
auto
qw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
w
,
ssw
,
zpw
);
auto
qconv
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
qx
,
qw
);
auto
ssw_mb
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qconv
->
get_shape
().
lens
()}}),
ssw_lit
);
auto
so1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sax
,
ssw_mb
);
auto
r1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qconv
,
so1
);
else_mod
->
add_return
({
r1
});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
},
{
then_mod
,
else_mod
});
...
...
test/ref/rnn_ops.cpp
View file @
61775eab
...
...
@@ -3228,6 +3228,264 @@ TEST_CASE(lstm_forward)
}
}
// Reference (CPU "ref" target) test of the "lstm" operator when inputs come in
// batch-first layout: seq/h0/c0 are created as {batch, ...} literals and then
// transposed into the order the lstm op consumes, and the outputs are
// transposed back before being compared against the gold values.
TEST_CASE(lstm_forward_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
// Input weight matrix W: {num_dirct, 4 * hidden_size, input_size}
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285};
// Recurrent weight matrix R: {num_dirct, 4 * hidden_size, hidden_size}
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
// Bias: {num_dirct, 8 * hidden_size}
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
// Input sequence, batch-first: {batch_size, seq_len, input_size}
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
// Initial hidden state h0, batch-first: {batch_size, num_dirct, hidden_size}
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
// Initial cell state c0, batch-first: {batch_size, num_dirct, hidden_size}
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
// Peephole weights: {num_dirct, 3 * hidden_size} (declared but not passed
// below; the "und" placeholder is used in that input slot instead)
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, hidden state concatenation as output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
// placeholder for the optional lstm inputs that are not supplied here
auto und = mm->add_instruction(migraphx::make_op("undefined"));
// swap the first two axes: batch-first -> the order the lstm op consumes
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
und);
// transpose the full hidden-state sequence output back to batch-first order
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.0417273, -0.272355, 0.206765, 0.223879, 0.0742487, -0.0800085, 0.259897,
0.0670196, -0.00532985, 0.0440265, 0.29654, -0.0463156, -0.0847427, 0.0874114,
0.304256, -0.0585745, 0.138193, -0.0322939, -0.0891815, 0.15773, 0.184266,
0.0610048, -0.138041, 0.0963885, 0.0498799, 0.125772, 0.0533032, -0.131413,
-0.0223018, 0.131113, 0.135643, -0.056620, 0.19139, -0.127708, -0.409371,
-0.136186, 0.0213755, -0.146027, -0.0324509, -0.0620429, 0.0988431, -0.018085,
-0.159434, 0.030266, 0.142701, 0.0342236, -0.198664, 0.0702607};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// forward, last_output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
// placeholder for the optional lstm inputs that are not supplied here
auto und = mm->add_instruction(migraphx::make_op("undefined"));
// swap the first two axes: batch-first -> the order the lstm op consumes
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
und);
// take only the last-step hidden state and return it in batch-first order
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref"));
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.0847427,
0.0874114,
0.304256,
-0.0585745,
-0.0223018,
0.131113,
0.135643,
-0.0566208,
0.142701,
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// forward, last_cell_output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
// placeholder for the optional lstm inputs that are not supplied here
auto und = mm->add_instruction(migraphx::make_op("undefined"));
// swap the first two axes: batch-first -> the order the lstm op consumes
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
und);
// take only the last-step cell state and return it in batch-first order
auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
p.compile(migraphx::make_target("ref"));
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.111454,
0.247794,
0.471087,
-0.220574,
-0.048196,
0.263184,
0.283258,
-0.14882,
0.605585,
0.078598,
-0.64457,
0.119811};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_forward_more)
{
std::size_t batch_size = 3;
...
...
@@ -3519,7 +3777,7 @@ TEST_CASE(lstm_forward_more)
}
}
TEST_CASE
(
lstm_
reverse
)
TEST_CASE(lstm_
forward_more_layout
)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
...
...
@@ -3527,32 +3785,668 @@ TEST_CASE(lstm_reverse)
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
-
0.
2763
,
-
0.
4715
,
-
0.
3010
,
-
0.
2306
,
-
0.
2283
,
-
0.2
656
,
0.2
035
,
0.3
570
,
-
0.1499
,
0.4390
,
-
0.
1843
,
0.23
51
,
0.
3357
,
0.
1217
,
0.
1401
,
0.3
300
,
-
0.0429
,
0.3266
,
0.4
834
,
-
0.
3914
,
-
0.
1480
,
0.373
4
,
-
0.03
7
2
,
-
0.
174
6
,
0.0
550
,
0.4
17
7
,
-
0.
1332
,
0.4391
,
-
0.
3287
,
-
0.
4401
,
0.
1486
,
0.1346
,
0.
1048
,
-
0.4361
,
0.0886
,
-
0.
3840
,
-
0.
2730
,
-
0.1710
,
0.3274
,
0.
0169
,
-
0.
4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4
150
,
-
0.4684
,
-
0.
2522
};
0.
1236
,
-0.
3942
, 0.
4149
,
0.
0795
,
0.
4934,
-0.2
858
, 0.2
602
,
-
0.3
098, 0.0567, 0.3344
,
0.
3607, -0.05
51, 0.
4952
, 0.
3799
, 0.
0630
,
-
0.3
532, 0.0023, -0.0592
, 0.4
267
, 0.
2382
,
-0.
078
4, -0.
0
032, -0.
247
6,
-
0.0
206, -0.4963,
0.4
83
7, 0.
0827, 0.0123
, -0.
1203
, -0.
0279
,
-
0.
0049, 0.4721
,
-
0.
3564, -0.1286, 0.4090
, -0.
0504
, 0.
0575, -0.2138, 0.1071
, 0.
1976
,
-0.
0758, 0.0139, -0.0761, 0.3991, -0.2965
,
-
0.4
845, -0.1496
, 0.
3285
};
std::vector<float> r_data{
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260};
std::vector<float> bias_data{0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881,
-0.4407, -0.2760, 0.1274, -0.0083, -0.2885, 0.3949, -0.0182,
0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067, 0.0807,
0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316,
-0.3025, 0.3637, -0.3181, -0.4655};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104,
-1.9004,
0.3337,
0.5741,
0.5671,
0.0458,
0.4514,
-0.8968,
-0.9201,
0.1962,
0.5771,
-0.5332};
std::vector<float> ic_data{0.9569,
-0.5981,
1.1312,
1.0945,
1.1055,
-0.1212,
-0.9097,
0.7831,
-1.6991,
-1.9498,
-1.2567,
-0.4114};
std::vector<float> pph_data{1.84369764,
0.68413646,
-0.44892886,
-1.50904413,
0.3860796,
-0.52186625,
1.08474445,
-1.80867321,
1.32594529,
0.4336262,
-0.83699064,
0.49162736};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// forward, 3 args
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, -0.0786602, -0.0613048, 0.179592,
-0.071286, -0.102509, -0.0372696, 0.252296, -0.144544, -0.165194, -0.0372928,
0.273786, -0.100877, 0.0319021, -0.00298698, -0.0623361, 0.0598866, 0.074206,
0.0124086, -0.139544, 0.108016, 0.00496085, 0.0662588, -0.048577, -0.187329,
-0.0458544, -0.0401315, 0.0737483, -0.064505, 0.101585, 0.0687269, -0.161725,
-0.25617, -0.00973633, -0.0552699, 0.0252681, -0.0562072, 0.0855831, -0.0171894,
-0.140202, 0.0828391, 0.136898, 0.00160891, -0.184812, 0.147774};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// forward, 8 args
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, 0.186991, -0.0624168, 0.205513,
0.0836373, 0.0459033, 0.0414126, 0.272303, 0.0393149, -0.058052, 0.0795391,
0.266617, -0.0128746, 0.294074, -0.0319677, -0.0955337, 0.104168, 0.421857,
0.0459771, -0.144955, 0.0720673, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.022618, -0.121195, -0.4065,
-0.252054, -0.0300906, -0.0890598, -0.135266, -0.0413375, 0.103489, 0.0142918,
-0.123408, 0.0401075, 0.187761, 0.0501726, -0.121584, 0.0606723};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
// forward, last_output as program output, sequence length shorter
// than max_seq_len
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
und);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref"));
auto last_hs = p.eval({}).back();
std::vector<float> output_data;
last_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.0847427,
0.0874114,
0.304256,
-0.0585745,
-0.0223018,
0.131113,
0.135643,
-0.0566208,
0.142701,
0.0342236,
-0.198664,
0.0702607};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// seq_len = 1
{
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> hs_data;
hs_concat.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
std::vector<float> hs_data_gold{0.079753,
-0.289854,
0.160043,
0.115056,
0.294074,
-0.0319677,
-0.0955337,
0.104168,
0.022618,
-0.121195,
-0.4065,
-0.252054};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
}
}
// Reference tests for a single-direction LSTM run in "reverse" mode.
// Each sub-scope builds a small program around the "lstm" op, evaluates it on
// the "ref" target, and compares against hard-coded golden values.
TEST_CASE(lstm_reverse)
{
    std::size_t batch_size  = 3;
    std::size_t seq_len     = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size  = 3;
    std::size_t num_dirct   = 1; // single (reverse) direction only
    // Input weight matrix W: {num_dirct, 4 * hidden_size, input_size};
    // the factor 4 corresponds to the four LSTM gates.
    std::vector<float> w_data{
        -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
        -0.1843, 0.2351,  0.3357,  0.1217,  0.1401,  0.3300,  -0.0429, 0.3266, 0.4834, -0.3914,
        -0.1480, 0.3734,  -0.0372, -0.1746, 0.0550,  0.4177,  -0.1332, 0.4391, -0.3287, -0.4401,
        0.1486,  0.1346,  0.1048,  -0.4361, 0.0886,  -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
        -0.4462, 0.0729,  0.3983,  -0.0669, 0.0756,  0.4150,  -0.4684, -0.2522};
    // Recurrent weight matrix R: {num_dirct, 4 * hidden_size, hidden_size}.
    std::vector<float> r_data{
        -0.4564, -0.4432, 0.1605,  0.4387,  0.0034,  0.4116,  0.2824,  0.4775,  -0.2729, -0.4707,
        0.1363,  0.2218,  0.0559,  0.2828,  0.2093,  0.4687,  0.3794,  -0.1069, -0.3049, 0.1430,
        -0.2506, 0.4644,  0.2755,  -0.3645, -0.3155, 0.1425,  0.2891,  0.1786,  -0.3274, 0.2365,
        0.2522,  -0.4312, -0.0562, -0.2748, 0.0776,  -0.3154, 0.2851,  -0.3930, -0.1174, 0.4360,
        0.2436,  0.0164,  -0.0680, 0.3403,  -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
        -0.4659, 0.3300,  0.0454,  0.4981,  -0.4706, -0.4584, 0.2596,  0.2871,  -0.3509, -0.1910,
        0.3987,  -0.1687, -0.0032, -0.1038};
    // Bias: {num_dirct, 8 * hidden_size} — presumably the W-biases followed by
    // the R-biases for the four gates (ONNX LSTM "B" layout); confirm against op.
    std::vector<float> bias_data{-0.0258, 0.0073,  -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
                                 -0.1823, 0.1479,  0.1677,  -0.2603, 0.0381,  0.1575,  0.1896,
                                 0.4755,  -0.4794, 0.2167,  -0.4474, -0.3139, 0.1018,  0.4470,
                                 -0.4232, 0.3247,  -0.1636, -0.1582, -0.1703, 0.3920,  0.2055,
                                 -0.4386, 0.4208,  0.0717,  0.3789};
    // Input sequence X: {seq_len, batch_size, input_size} (seq-major layout).
    std::vector<float> input_data{
        -0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930,  -0.5221, -0.1331,
        -0.0910, 1.2122, -0.1952, 0.4661,  0.6494,  2.1332,  -1.0972, 0.9816,  0.1122,
        0.3577,  1.3508, -0.5366, 1.7449,  0.5483,  -0.0701, -0.4100, -2.2344, 0.3685,
        0.4583,  2.3794, 1.0372,  -0.8887, 0.7892,  -0.4012, -0.2818, -2.3374, 1.5310};
    // Initial hidden state: {num_dirct, batch_size, hidden_size}.
    std::vector<float> ih_data{1.5289,
                               1.0986,
                               0.6091,
                               1.6462,
                               0.8720,
                               0.5349,
                               -0.1962,
                               -1.7416,
                               -0.9912,
                               1.2831,
                               1.0896,
                               -0.6959};
    // Initial cell state: {num_dirct, batch_size, hidden_size}.
    std::vector<float> ic_data{-0.8323,
                               0.3998,
                               0.1831,
                               0.5938,
                               2.7096,
                               -0.1790,
                               0.0022,
                               -0.8040,
                               0.1578,
                               0.0567,
                               0.8069,
                               -0.5141};
    // pph: {num_dirct, 3 * hidden_size} — presumably peephole weights
    // (ONNX LSTM "P" input); verify against the lstm op definition.
    std::vector<float> pph_data{-0.8271,
                                -0.5683,
                                0.4562,
                                -1.2545,
                                1.2729,
                                -0.4082,
                                -0.4392,
                                -0.9406,
                                0.7794,
                                1.8194,
                                -0.5811,
                                0.2166};
    migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    float clip = 0.0f; // 0 disables cell clipping
    // reverse, concatenation of hidden states as program output
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic   = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph  = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        // "undefined" fills the optional sequence-lengths argument slot
        // (compare with the later sub-cases that pass `sql` in this position).
        auto und  = mm->add_instruction(migraphx::make_op("undefined"));
        mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{
            -0.120174, 0.043157,  0.117138,  -0.222188, 0.789732,   0.128538,  0.20909,
            0.0553812, -0.224905, 0.32421,   0.344048,  0.271694,   -0.175114, -0.00543549,
            0.178681,  -0.266999, 0.928866,  0.113685,  0.220626,   -0.0432316, -0.063456,
            0.148524,  0.05108,   -0.0234895, -0.182201, -0.0232277, 0.235501,  -0.213485,
            0.960938,  0.133565,  0.269741,  0.130438,  -0.0252804, 0.267356,  0.146353,
            0.0789186, -0.185038, -0.026845, 0.177273,  -0.0774616, 0.946669,  0.0868676,
            0.044508,  -0.373961, -0.0681467, 0.382748,  0.230211,  -0.161537};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
    // reverse, sequence lengths are the same, but less than max_seq_lens
    {
        migraphx::program p;
        auto* mm      = p.get_main_module();
        auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto ih       = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic       = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto w        = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r        = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias     = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph      = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        // Pad the sequence with two extra all-zero timesteps so that
        // max_seq_len (6) exceeds the actual sequence length (4).
        migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
        std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
        auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
        auto seq   = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p);
        // Per-batch sequence lengths, all equal to the true length (4).
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
        auto sql = mm->add_literal(seq_len_s, len_data);
        mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            sql,
            ih,
            ic,
            pph);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        // First 48 values match the unpadded case above; the two padded
        // timesteps contribute trailing zeros.
        std::vector<float> output_data_gold{
            -0.120174, 0.043157,  0.117138,  -0.222188, 0.789732,   0.128538,  0.20909,
            0.0553812, -0.224905, 0.32421,   0.344048,  0.271694,   -0.175114, -0.00543549,
            0.178681,  -0.266999, 0.928866,  0.113685,  0.220626,   -0.0432316, -0.063456,
            0.148524,  0.05108,   -0.0234895, -0.182201, -0.0232277, 0.235501,  -0.213485,
            0.960938,  0.133565,  0.269741,  0.130438,  -0.0252804, 0.267356,  0.146353,
            0.0789186, -0.185038, -0.026845, 0.177273,  -0.0774616, 0.946669,  0.0868676,
            0.044508,  -0.373961, -0.0681467, 0.382748,  0.230211,  -0.161537, 0.0,
            0.0,       0.0,       0.0,       0.0,       0.0,        0.0,       0.0,
            0.0,       0.0,       0.0,       0.0,       0.0,        0.0,       0.0,
            0.0,       0.0,       0.0,       0.0,       0.0,        0.0,       0.0,
            0.0,       0.0};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
    // variable sequence lengths
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic   = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph  = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        // Each batch element gets a different length; timesteps past an
        // element's length are expected to produce zeros in the gold output.
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data{3, 2, 1};
        auto sql = mm->add_literal(seq_len_s, len_data);
        mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            sql,
            ih,
            ic,
            pph);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{
            -0.126517, 0.0359124, 0.107453, -0.0617278, 0.911307,  0.11468,   0.114449,
            0.0196755, -0.102969, 0.295872, 0.515859,   0.246501,  -0.168327, 0.00023761,
            0.167567,  -0.0621982, 0.96657, 0.0755112,  0.0620917, -0.264845, 0,
            0,         0,         0,        -0.204545,  0.0146403, 0.210057,  0.0296268,
            0,         0,         0,        0,          0,         0,         0,
            0,         0,         0,        0,          0,         0,         0,
            0,         0,         0,        0,          0,         0};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
    // reverse, 3 args, last cell output as program output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w   = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r   = mm->add_literal(migraphx::literal{r_shape, r_data});
        // Minimal argument form: no bias, no initial states, no peepholes.
        auto hs  = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                      migraphx::make_op("tanh"),
                                                                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r);
        mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{-0.443077,
                                            -0.325425,
                                            -0.249367,
                                            -0.270812,
                                            0.122913,
                                            0.118537,
                                            0.0370199,
                                            -0.0164687,
                                            -0.00754759,
                                            0.141613,
                                            0.348002,
                                            0.667298};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
    // reverse, 3 args, 0 actv function
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w   = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r   = mm->add_literal(migraphx::literal{r_shape, r_data});
        // Empty actv_func list: gold values are identical to the previous
        // sub-case, i.e. the op is expected to fall back to the default
        // sigmoid/tanh/tanh activations.
        auto hs  = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func", {}},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r);
        mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{-0.443077,
                                            -0.325425,
                                            -0.249367,
                                            -0.270812,
                                            0.122913,
                                            0.118537,
                                            0.0370199,
                                            -0.0164687,
                                            -0.00754759,
                                            0.141613,
                                            0.348002,
                                            0.667298};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
}
TEST_CASE(lstm_reverse_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 1;
std::vector<float> w_data{
-0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390,
-0.1843, 0.2351, 0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914,
-0.1480, 0.3734, -0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401,
0.1486, 0.1346, 0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169,
-0.4462, 0.0729, 0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
-0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116, 0.2824, 0.4775, -0.2729, -0.4707,
0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687, 0.3794, -0.1069, -0.3049, 0.1430,
-0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425, 0.2891, 0.1786, -0.3274, 0.2365,
0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154, 0.2851, -0.3930, -0.1174, 0.4360,
0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459, -0.2991, -0.2624, 0.4194, -0.3291,
-0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584, 0.2596, 0.2871, -0.3509, -0.1910,
0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{-0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,
-0.1823, 0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896,
0.4755, -0.4794, 0.2167, -0.4474, -0.3139, 0.1018, 0.4470,
-0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.
4313
,
-
0.9730
,
-
0.2005
,
2
.3
930
,
-
0.5221
,
-
0.1331
,
-
0.
0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.112
2
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.
07
01
,
-
0.4100
,
-
2.2344
,
0.
3685
,
0.4583
,
2.3794
,
1.0
3
72
,
-
0.8887
,
0.
789
2
,
-
0.4
012
,
-
0.2818
,
-
2.3374
,
1.5310
};
-0.5516, 0.2391, -1.6951, -0.
0910, 1.2122, -0.1952
,
0
.3
577, 1.3508, -0.5366
,
0.
4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.133
2,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892
, -0.
4
01
2
,
2.3930, -0.5221
,
-
0.
1331
,
-
1.0
9
72,
0.9816
, 0.
112
2, -0.4
100, -2.2344, 0.3685,
-0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.5289,
1.0986,
...
...
@@ -3593,14 +4487,15 @@ TEST_CASE(lstm_reverse)
-0.5811,
0.2166};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size,
seq_len,
input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx::shape ih_shape{migraphx::shape::float_type, {
batch_size, num_dirct
, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {
batch_size, num_dirct
, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
// reverse, concatenation of hidden states as program output
{
migraphx::program p;
...
...
@@ -3614,7 +4509,13 @@ TEST_CASE(lstm_reverse)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm
->
add_instruction
(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
...
...
@@ -3633,18 +4534,21 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
-0.120174, 0.043157, 0.117138, -0.222188, -0.175114, -0.00543549, 0.178681,
-0.266999, -0.182201, -0.0232277, 0.235501, -0.213485, -0.185038, -0.026845,
0.177273, -0.0774616, 0.789732, 0.128538, 0.20909, 0.0553812, 0.928866,
0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741, 0.130438,
0.946669, 0.0868676, 0.044508, -0.373961, -0.224905, 0.32421, 0.344048,
0.271694, -0.063456, 0.148524, 0.05108, -0.0234895, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
...
...
@@ -3661,14 +4565,20 @@ TEST_CASE(lstm_reverse)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
2
,
batch_size
,
input_size
}};
migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size,
2,
input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
seq_orig
,
seq_p
);
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis",
1
}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
mm
->
add_instruction
(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
...
...
@@ -3687,22 +4597,26 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
};
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.120174, 0.043157, 0.117138, -0.222188, -0.175114, -0.00543549, 0.178681,
-0.266999, -0.182201, -0.0232277, 0.235501, -0.213485, -0.185038, -0.026845,
0.177273, -0.0774616, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.789732, 0.128538, 0.20909, 0.0553812,
0.928866, 0.113685, 0.220626, -0.0432316, 0.960938, 0.133565, 0.269741,
0.130438, 0.946669, 0.0868676, 0.044508, -0.373961, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.224905,
0.32421, 0.344048, 0.271694, -0.063456, 0.148524, 0.05108, -0.0234895,
-0.0252804, 0.267356, 0.146353, 0.0789186, -0.0681467, 0.382748, 0.230211,
-0.161537, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
...
...
@@ -3722,7 +4636,13 @@ TEST_CASE(lstm_reverse)
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data{3, 2, 1};
auto sql = mm->add_literal(seq_len_s, len_data);
mm
->
add_instruction
(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
...
...
@@ -3741,18 +4661,22 @@ TEST_CASE(lstm_reverse)
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-
0.126517
,
0.0359124
,
0.107453
,
-
0.0617278
,
0.911307
,
0.11468
,
0.114449
,
0.0196755
,
-
0.102969
,
0.295872
,
0.515859
,
0.246501
,
-
0.168327
,
0.00023761
,
0.167567
,
-
0.0621982
,
0.96657
,
0.0755112
,
0.0620917
,
-
0.264845
,
0
,
0
,
0
,
0
,
-
0.204545
,
0.0146403
,
0.210057
,
0.0296268
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
-0.126517, 0.0359124, 0.107453, -0.0617278, -0.168327, 0.00023761, 0.167567,
-0.0621982, -0.204545, 0.0146403, 0.210057, 0.0296268, 0, 0,
0, 0, 0.911307, 0.11468, 0.114449, 0.0196755, 0.96657,
0.0755112, 0.0620917, -0.264845, 0, 0, 0, 0,
0, 0, 0, 0, -0.102969, 0.295872, 0.515859,
0.246501, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
...
...
@@ -3763,6 +4687,10 @@ TEST_CASE(lstm_reverse)
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
...
...
@@ -3777,46 +4705,8 @@ TEST_CASE(lstm_reverse)
seq,
w,
r);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.443077
,
-
0.325425
,
-
0.249367
,
-
0.270812
,
0.122913
,
0.118537
,
0.0370199
,
-
0.0164687
,
-
0.00754759
,
0.141613
,
0.348002
,
0.667298
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// reverse, 3 args, 0 actv function
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
{}},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
...
...
@@ -3900,19 +4790,214 @@ TEST_CASE(lstm_reverse_actv)
0.8069,
-0.5141};
std
::
vector
<
float
>
pph_data
{
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
std::vector<float> pph_data{-0.8271,
-0.5683,
0.4562,
-1.2545,
1.2729,
-0.4082,
-0.4392,
-0.9406,
0.7794,
1.8194,
-0.5811,
0.2166};
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float clip = 0.0f;
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(
std::vector<migraphx::operation>{migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.246078, 0.199709, 0.303753, 0.301178, 0.264634, 0.304661, 0.349371, 0.288934,
0.405483, 0.445586, 0.515814, 0.473186, 0.301937, 0.264893, 0.254353, 0.269231,
0.359258, 0.400097, 0.288884, 0.247329, 0.276519, 0.264249, 0.1769, 0.23213,
0.310306, 0.262902, 0.276964, 0.295002, 0.373802, 0.366785, 0.419791, 0.393216,
0.262827, 0.371441, 0.369022, 0.298262, 0.334143, 0.309444, 0.174822, 0.251634,
0.244564, 0.214386, 0.185994, 0.226699, 0.28445, 0.376092, 0.338326, 0.259502};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{
migraphx::make_op("tanh"), migraphx::make_op("sigmoid")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.132123,
-0.37531,
-0.12943,
-0.00798307,
-0.133882,
-0.0251383,
0.0486486,
-0.0220606,
0.292495,
0.233866,
0.48646,
0.481844};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
{
seq_len = 1;
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{-0.104351,
-0.0471426,
-0.0905753,
0.01506,
0.059797,
0.104239,
-0.0266768,
0.0727547,
-0.146298,
0.070535,
0.327809,
0.407388};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE(lstm_bidirectional)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
std::size_t num_dirct = 2;
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331,
-0.0910, 1.2122, -0.1952, 0.4661, 0.6494, 2.1332, -1.0972, 0.9816, 0.1122,
0.3577, 1.3508, -0.5366, 1.7449, 0.5483, -0.0701, -0.4100, -2.2344, 0.3685,
0.4583, 2.3794, 1.0372, -0.8887, 0.7892, -0.4012, -0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 0.5671, 0.0458,
0.4514, -0.8968, -0.9201, 0.1962, 0.5771, -0.5332,
1.5289, 1.0986, 0.6091, 1.6462, 0.8720, 0.5349,
-0.1962, -1.7416, -0.9912, 1.2831, 1.0896, -0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, 1.1055, -0.1212,
-0.9097, 0.7831, -1.6991, -1.9498, -1.2567, -0.4114,
-0.8323, 0.3998, 0.1831, 0.5938, 2.7096, -0.1790,
0.0022, -0.8040, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
...
...
@@ -3920,95 +5005,200 @@ TEST_CASE(lstm_reverse_actv)
migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
float
clip
=
0.0
f
;
// concatenation of hidden states as program output
{
migraphx::program p;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r
);
r,
bias,
und,
ih,
ic,
pph);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.246078
,
0.199709
,
0.303753
,
0.301178
,
0.264634
,
0.304661
,
0.349371
,
0.288934
,
0.405483
,
0.445586
,
0.515814
,
0.473186
,
0.301937
,
0.264893
,
0.254353
,
0.269231
,
0.359258
,
0.400097
,
0.288884
,
0.247329
,
0.276519
,
0.264249
,
0.1769
,
0.23213
,
0.310306
,
0.262902
,
0.276964
,
0.295002
,
0.373802
,
0.366785
,
0.419791
,
0.393216
,
0.262827
,
0.371441
,
0.369022
,
0.298262
,
0.334143
,
0.309444
,
0.174822
,
0.251634
,
0.244564
,
0.214386
,
0.185994
,
0.226699
,
0.28445
,
0.376092
,
0.338326
,
0.259502
};
0.079753, -0.289854, 0.160043, 0.115056, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.022618, -0.121195, -0.4065, -0.252054, -0.120174, 0.043157,
0.117138, -0.222188, 0.789732, 0.128538, 0.20909, 0.0553812, -0.224905,
0.32421, 0.344048, 0.271694, 0.186991, -0.0624168, 0.205513, 0.0836373,
0.421857, 0.0459771, -0.144955, 0.0720673, -0.0300906, -0.0890598, -0.135266,
-0.0413375, -0.175114, -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
0.220626, -0.0432316, -0.063456, 0.148524, 0.05108, -0.0234895, 0.0459032,
0.0414126, 0.272303, 0.0393149, 0.218258, 0.0944405, 0.0431211, -0.132394,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.182201, -0.0232277, 0.235501,
-0.213485, 0.960938, 0.133565, 0.269741, 0.130438, -0.0252804, 0.267356,
0.146353, 0.0789186, -0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878,
0.0971544, 0.149294, -0.0492549, 0.187761, 0.0501726, -0.121584, 0.0606723,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.946669, 0.0868676, 0.044508,
-0.373961, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, 2 actv functions
// last hidden state as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, 0.0309878, 0.0971544, 0.149294, -0.0492549,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.120174, 0.043157, 0.117138, -0.222188,
0.789732, 0.128538, 0.20909, 0.0553812, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
und,
ih,
ic,
pph);
mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, 0.0685243, 0.20465, 0.277867, -0.112934,
0.67312, 0.120508, -0.726968, 0.113845, -0.889294, 0.182463, 0.186512, -0.402334,
1.48161, 0.524116, 0.347113, 0.181813, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// 3 args, concatenation of hidden states as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std
::
vector
<
float
>
output_data_gold
{
-
0.132123
,
-
0.37531
,
-
0.12943
,
-
0.00798307
,
-
0.133882
,
-
0.0251383
,
0.0486486
,
-
0.0220606
,
0.292495
,
0.233866
,
0.48646
,
0.481844
};
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698, -0.0623361,
0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617, -0.162851, -0.102647,
-0.113827, -0.142818, 0.0513685, 0.0547876, 0.0201981, -0.00808453, -0.00520328,
0.0945081, 0.264123, 0.410805, -0.0786602, -0.0613048, 0.179592, -0.071286,
0.074206, 0.0124086, -0.139544, 0.108016, -0.00973633, -0.0552699, 0.0252681,
-0.0562072, -0.123496, -0.153616, -0.032874, -0.195349, 0.0192675, -0.108636,
0.098927, -0.140733, 0.162602, 0.0143099, -0.0455534, 0.0151574, -0.102509,
-0.0372696, 0.252296, -0.144544, 0.00496085, 0.0662588, -0.048577, -0.187329,
0.0855831, -0.0171894, -0.140202, 0.0828391, -0.1073, -0.150145, 0.015065,
-0.192699, -0.112764, -0.120496, 0.155754, 0.148256, 0.208491, 0.348432,
0.0291103, 0.230275, -0.165194, -0.0372928, 0.273786, -0.100877, -0.0458544,
-0.0401315, 0.0737483, -0.064505, 0.136898, 0.00160891, -0.184812, 0.147774,
-0.021205, -0.125423, 0.0206439, -0.187097, -0.0051453, -0.0767618, -0.0735348,
-0.0826436, 0.214159, 0.262295, 0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
//
reverse, 3 args, seq_len =
1, con
ca
tenation of hidden state
s
as program output
//
sequence length is
1, contenation of hidden state as program output
{
seq_len
=
1
;
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx::program p;
auto* mm = p.get_main_module();
seq_len = 1;
migraphx::shape in_shape1{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm->add_instruction(
migraphx::make_op(
"lstm",
...
...
@@ -4017,7 +5207,7 @@ TEST_CASE(lstm_reverse_actv)
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::
bidirectional
)},
{"clip", clip},
{"input_forget", 0}}),
seq,
...
...
@@ -4027,23 +5217,16 @@ TEST_CASE(lstm_reverse_actv)
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std
::
vector
<
float
>
output_data_gold
{
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
std::vector<float> output_data_gold{
-0.0327039, -0.0543852, 0.114378, -0.0768855, 0.0319021, -0.00298698,
-0.0623361, 0.0598866, 0.101585, 0.0687269, -0.161725, -0.25617,
-0.104351, -0.0471426, -0.0905753, 0.01506, 0.059797, 0.104239,
-0.0266768, 0.0727547, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
TEST_CASE
(
lstm_bidirectional
)
TEST_CASE(lstm_bidirectional
_layout
)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
...
...
@@ -4087,20 +5270,20 @@ TEST_CASE(lstm_bidirectional)
-0.4386, 0.4208, 0.0717, 0.3789};
std::vector<float> input_data{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.
4313
,
-
0.9730
,
-
0.2005
,
2
.3
930
,
-
0.5221
,
-
0.1331
,
-
0.
0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.112
2
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.
07
01
,
-
0.4100
,
-
2.2344
,
0.
3685
,
0.4583
,
2.3794
,
1.0
3
72
,
-
0.8887
,
0.
789
2
,
-
0.4
012
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
0.
4514
,
-
0.
8968
,
-
0.
9201
,
0.1962
,
0.5771
,
-
0.5332
,
1.5289
,
1.0986
,
0.
6091
,
1.
6462
,
0.
87
20
,
0.
5349
,
-
0.
1962
,
-
1.7416
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
-
0.
9097
,
0.
7831
,
-
1.
6991
,
-
1.9498
,
-
1.256
7
,
-
0.
4114
,
-
0.8323
,
0.3998
,
0.1831
,
0.
5938
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.
8040
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
-0.5516, 0.2391, -1.6951, -0.
0910, 1.2122, -0.1952
,
0
.3
577, 1.3508, -0.5366
,
0.
4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.133
2,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892
, -0.
4
01
2
,
2.3930, -0.5221
,
-
0.
1331
,
-
1.0
9
72,
0.9816
, 0.
112
2, -0.4
100, -2.2344, 0.3685,
-0.2818, -2.3374, 1.5310};
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741,
1.5289, 1.0986
,
0.
6091, 1.6462
, 0.
5671
,
0.
0458, 0.4514, -0.8968
,
0.8720, 0.5349
,
-
0.
1962
,
-
1.
7416
,
-
0.
9
20
1
, 0.
1962
,
0.
5771, -0.5332
, -0.9912, 1.2831,
1.0896,
-0.6959};
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945,
-0.8323, 0.3998
,
0.
1831
,
0.
5938
, 1.
1055, -0.1212, -0.909
7, 0.
7831
,
2.7096, -0.1790, 0.0022
,
-
0.
8040, -1.6991, -1.9498
,
-1.2567
, -0.
4114
, 0.1578, 0.0567, 0.8069, -0.5141};
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
...
...
@@ -4108,12 +5291,12 @@ TEST_CASE(lstm_bidirectional)
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
float clip = 0.0f;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size,
seq_len,
input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx::shape ih_shape{migraphx::shape::float_type, {
batch_size, num_dirct
, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {
batch_size, num_dirct
, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// concatenation of hidden states as program output
...
...
@@ -4128,7 +5311,13 @@ TEST_CASE(lstm_bidirectional)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
mm
->
add_instruction
(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
...
...
@@ -4147,25 +5336,29 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.0459032
,
0.0414126
,
0.272303
,
0.0393149
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
0.079753, -0.289854, 0.160043, 0.115056, -0.120174, 0.043157, 0.117138,
-0.222188, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.175114, -0.00543549,
0.178681, -0.266999, 0.0459032, 0.0414126, 0.272303, 0.0393149, -0.182201,
-0.0232277, 0.235501, -0.213485, -0.058052, 0.0795391, 0.266617, -0.0128746,
-0.185038, -0.026845, 0.177273, -0.0774616, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.789732, 0.128538, 0.20909, 0.0553812, 0.421857, 0.0459771,
-0.144955, 0.0720673, 0.928866, 0.113685, 0.220626, -0.0432316, 0.218258,
0.0944405, 0.0431211, -0.132394, 0.960938, 0.133565, 0.269741, 0.130438,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.946669, 0.0868676, 0.044508,
-0.373961, 0.022618, -0.121195, -0.4065, -0.252054, -0.224905, 0.32421,
0.344048, 0.271694, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.063456,
0.148524, 0.05108, -0.0234895, 0.103489, 0.0142918, -0.123408, 0.0401075,
-0.0252804, 0.267356, 0.146353, 0.0789186, 0.187761, 0.0501726, -0.121584,
0.0606723, -0.0681467, 0.382748, 0.230211, -0.161537};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
...
...
@@ -4181,6 +5374,12 @@ TEST_CASE(lstm_bidirectional)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
...
...
@@ -4200,15 +5399,17 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
auto last_output = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
-0.058052, 0.0795391, 0.266617, -0.0128746, -0.120174, 0.043157, 0.117138, -0.222188,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.789732, 0.128538, 0.20909, 0.0553812,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.224905, 0.32421, 0.344048, 0.271694};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
...
...
@@ -4224,6 +5425,12 @@ TEST_CASE(lstm_bidirectional)
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto und = mm->add_instruction(migraphx::make_op("undefined"));
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
...
...
@@ -4243,15 +5450,17 @@ TEST_CASE(lstm_bidirectional)
ih,
ic,
pph);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
-0.077353, 0.245616, 0.361023, -0.0443759, -0.889294, 0.182463, 0.186512, -0.402334,
0.0685243, 0.20465, 0.277867, -0.112934, 1.48161, 0.524116, 0.347113, 0.181813,
0.67312, 0.120508, -0.726968, 0.113845, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
...
...
@@ -4262,7 +5471,11 @@ TEST_CASE(lstm_bidirectional)
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm
->
add_instruction
(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
...
...
@@ -4276,25 +5489,28 @@ TEST_CASE(lstm_bidirectional)
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.
031902
1
,
-
0.
0
02
98698
,
-
0.0623361
,
0.
0598866
,
0.
101585
,
0.0687269
,
-
0.1
61725
,
-
0.25617
,
-
0.1
62851
,
-
0.1
02647
,
-
0.
113827
,
-
0.1
42818
,
0.
0513685
,
0.054787
6
,
0.
0201981
,
-
0.00808453
,
-
0.00520328
,
0.
0945081
,
0.264123
,
0.
410805
,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.
071286
,
0.0
74
20
6
,
0.0124086
,
-
0.1
39544
,
0.
108016
,
-
0.00
973633
,
-
0.0552699
,
0.025268
1
,
-
0.05
62072
,
-
0.
123496
,
-
0.
15361
6
,
-
0.0
32874
,
-
0.195349
,
0.0192675
,
-
0.
10863
6
,
0.
098927
,
-
0.1
40733
,
0.162602
,
0.0143099
,
-
0.0
455534
,
0.0151574
,
-
0.
102509
,
-
0.0
372696
,
0.252
29
6
,
-
0.1
44544
,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.1
87329
,
0.0
855831
,
-
0.0
171894
,
-
0.140202
,
0.0828391
,
-
0.1073
,
-
0.
150
145
,
0.0
15065
,
-
0.
192699
,
-
0.112764
,
-
0.
120496
,
0.155754
,
0.
148
256
,
0.
208491
,
0.348432
,
0.
029110
3
,
0.
23027
5
,
-
0.
165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
-
0.
0458544
,
-
0.0
401315
,
0.0737483
,
-
0.064505
,
0.136898
,
0.0
0160
89
1
,
-
0.1
8481
2
,
0.
147774
,
-
0.
021205
,
-
0.
125423
,
0.0206439
,
-
0.187097
,
-
0.
0051453
,
-
0.0
767618
,
-
0.0735348
,
-
0.
0826436
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
-0.0327039, -0.0543852, 0.114378,
-0.0768855,
-
0.
16285
1, -0.
1
02
647, -0.113827
,
-
0.
142818
,
-
0.
0786602, -0.0613048
,
0.1
79592, -0.071286,
-0.1
23496
, -0.1
53616
,
-0.
032874
, -0.1
95349
,
-
0.
102509, -0.037269
6, 0.
252296, -0.144544, -0.1073
,
-
0.
150145, 0.015065
,
-
0.
192699, -0.165194, -0.0372928, 0.273786
, -0.
100877
,
-
0.0
21
20
5
,
-0.125423, 0.0206439,
-0.1
87097
,
0.
0319021
, -0.00
298698, -0.062336
1,
0.05
98866
,
0.
0513685
, 0.
054787
6,
0.0
201981, -0.00808453, 0.074206
,
0.
012408
6,
-
0.
139544
, 0.1
08016, 0.0192675, -0.108636
, 0.0
98927, -0.140733
, 0.
00496085
,
0.0
662588, -0.048577, -0.1873
29, -0.1
12764, -0.120496, 0.155754
,
0.1
48256
,
-
0.0
458544
, -0.0
401315, 0.0737483, -0.064505,
-0.
005
145
3
,
-
0.0
767618, -0.0735348
,
-0.
0826436, 0.101585,
0.
0687269, -0.161725
,
-
0.256
17
,
-
0.
00520328, 0.0945081
,
0.
26412
3,
0.
41080
5, -0.
00973633, -0.0552699, 0.0252681, -0.0562072,
0.
162602
,
0.0
143099, -0.0455534, 0.0151574, 0.0855831
,
-
0.0
171
89
4
, -0.1
4020
2, 0.
0828391
,
0.
208491
,
0.
348432, 0.0291103, 0.230275,
0.
136898
,
0.0
0160891, -0.184812
,
0.
147774,
0.214159, 0.262295,
0.0247127, 0.14472};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
...
...
@@ -4303,13 +5519,17 @@ TEST_CASE(lstm_bidirectional)
migraphx::program p;
auto* mm = p.get_main_module();
seq_len = 1;
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx::shape in_shape1{migraphx::shape::float_type, {batch_size,
seq_len,
input_size}};
std::vector<float> input_data1{
-0.5516, 0.2391, -1.6951, -0.4313, -0.9730, -0.2005, 2.3930, -0.5221, -0.1331};
auto seq = mm->add_literal(migraphx::literal{in_shape1, input_data1});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
mm
->
add_instruction
(
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
...
...
@@ -4323,15 +5543,19 @@ TEST_CASE(lstm_bidirectional)
seq,
w,
r);
std::vector<int64_t> perm_hid{2, 0, 1, 3};
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
p.compile(migraphx::make_target("ref"));
auto hs_concat = p.eval({}).back();
std::vector<float> output_data;
hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
std::vector<float> output_data_gold{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
-0.0327039, -0.0543852, 0.114378, -0.0768855, -0.104351, -0.0471426,
-0.0905753, 0.01506, 0.0319021, -0.00298698, -0.0623361, 0.0598866,
0.059797, 0.104239, -0.0266768, 0.0727547, 0.101585, 0.0687269,
-0.161725, -0.25617, -0.146298, 0.070535, 0.327809, 0.407388};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
...
...
@@ -4577,6 +5801,275 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
}
}
// Verifies a bidirectional LSTM with batch-first ("layout"-style) tensors and
// per-batch variable sequence lengths on the ref target. Inputs are supplied
// batch-first, transposed to the seq-first ordering the "lstm" op expects, and
// the outputs are transposed back before comparison against golden data.
TEST_CASE(lstm_bidirectional_var_seq_lens_layout)
{
std::size_t batch_size = 3;
std::size_t seq_len = 4;
std::size_t hidden_size = 4;
std::size_t input_size = 3;
// two directions: forward and reverse
std::size_t num_dirct = 2;
// input weights, shape {num_dirct, 4 * hidden_size, input_size}
std::vector<float> w_data{
0.1236, -0.3942, 0.4149, 0.0795, 0.4934, -0.2858, 0.2602, -0.3098, 0.0567, 0.3344,
0.3607, -0.0551, 0.4952, 0.3799, 0.0630, -0.3532, 0.0023, -0.0592, 0.4267, 0.2382,
-0.0784, -0.0032, -0.2476, -0.0206, -0.4963, 0.4837, 0.0827, 0.0123, -0.1203, -0.0279,
-0.0049, 0.4721, -0.3564, -0.1286, 0.4090, -0.0504, 0.0575, -0.2138, 0.1071, 0.1976,
-0.0758, 0.0139, -0.0761, 0.3991, -0.2965, -0.4845, -0.1496, 0.3285, -0.2763, -0.4715,
-0.3010, -0.2306, -0.2283, -0.2656, 0.2035, 0.3570, -0.1499, 0.4390, -0.1843, 0.2351,
0.3357, 0.1217, 0.1401, 0.3300, -0.0429, 0.3266, 0.4834, -0.3914, -0.1480, 0.3734,
-0.0372, -0.1746, 0.0550, 0.4177, -0.1332, 0.4391, -0.3287, -0.4401, 0.1486, 0.1346,
0.1048, -0.4361, 0.0886, -0.3840, -0.2730, -0.1710, 0.3274, 0.0169, -0.4462, 0.0729,
0.3983, -0.0669, 0.0756, 0.4150, -0.4684, -0.2522};
// recurrent weights, shape {num_dirct, 4 * hidden_size, hidden_size}
std::vector<float> r_data{
0.1237, 0.1229, -0.0766, -0.1144, -0.1186, 0.2922, 0.2478, 0.3159, -0.0522, 0.1685,
-0.4621, 0.1728, 0.0670, -0.2458, -0.3835, -0.4589, -0.3109, 0.4908, -0.0133, -0.1858,
-0.0590, -0.0347, -0.2353, -0.0671, -0.3812, -0.0004, -0.1432, 0.2406, 0.1033, -0.0265,
-0.3902, 0.0755, 0.3733, 0.4383, -0.3140, 0.2537, -0.1818, -0.4127, 0.3506, 0.2562,
0.2926, 0.1620, -0.4849, -0.4861, 0.4426, 0.2106, -0.0005, 0.4418, -0.2926, -0.3100,
0.1500, -0.0362, -0.3801, -0.0065, -0.0631, 0.1277, 0.2315, 0.4087, -0.3963, -0.4161,
-0.2169, -0.1344, 0.3468, -0.2260, -0.4564, -0.4432, 0.1605, 0.4387, 0.0034, 0.4116,
0.2824, 0.4775, -0.2729, -0.4707, 0.1363, 0.2218, 0.0559, 0.2828, 0.2093, 0.4687,
0.3794, -0.1069, -0.3049, 0.1430, -0.2506, 0.4644, 0.2755, -0.3645, -0.3155, 0.1425,
0.2891, 0.1786, -0.3274, 0.2365, 0.2522, -0.4312, -0.0562, -0.2748, 0.0776, -0.3154,
0.2851, -0.3930, -0.1174, 0.4360, 0.2436, 0.0164, -0.0680, 0.3403, -0.2857, -0.0459,
-0.2991, -0.2624, 0.4194, -0.3291, -0.4659, 0.3300, 0.0454, 0.4981, -0.4706, -0.4584,
0.2596, 0.2871, -0.3509, -0.1910, 0.3987, -0.1687, -0.0032, -0.1038};
// combined input + recurrent biases, shape {num_dirct, 8 * hidden_size}
std::vector<float> bias_data{
0.0088, 0.1183, 0.1642, -0.2631, -0.1330, -0.4008, 0.3881, -0.4407, -0.2760, 0.1274,
-0.0083, -0.2885, 0.3949, -0.0182, 0.4445, 0.3477, 0.2266, 0.3423, -0.0674, -0.4067,
0.0807, 0.1109, -0.2036, 0.1782, -0.2467, -0.0730, -0.4216, 0.0316, -0.3025, 0.3637,
-0.3181, -0.4655, -0.0258, 0.0073, -0.4780, -0.4101, -0.3556, -0.1017, 0.3632, -0.1823,
0.1479, 0.1677, -0.2603, 0.0381, 0.1575, 0.1896, 0.4755, -0.4794, 0.2167, -0.4474,
-0.3139, 0.1018, 0.4470, -0.4232, 0.3247, -0.1636, -0.1582, -0.1703, 0.3920, 0.2055,
-0.4386, 0.4208, 0.0717, 0.3789};
// input sequence, laid out batch-first: {batch_size, seq_len, input_size}
std::vector<float> input_data{
-0.5516, 0.2391, -1.6951, -0.0910, 1.2122, -0.1952, 0.3577, 1.3508, -0.5366,
0.4583, 2.3794, 1.0372, -0.4313, -0.9730, -0.2005, 0.4661, 0.6494, 2.1332,
1.7449, 0.5483, -0.0701, -0.8887, 0.7892, -0.4012, 2.3930, -0.5221, -0.1331,
-1.0972, 0.9816, 0.1122, -0.4100, -2.2344, 0.3685, -0.2818, -2.3374, 1.5310};
// initial hidden state, batch-first: {batch_size, num_dirct, hidden_size}
std::vector<float> ih_data{1.9104, -1.9004, 0.3337, 0.5741, 1.5289, 1.0986,
0.6091, 1.6462, 0.5671, 0.0458, 0.4514, -0.8968,
0.8720, 0.5349, -0.1962, -1.7416, -0.9201, 0.1962,
0.5771, -0.5332, -0.9912, 1.2831, 1.0896, -0.6959};
// initial cell state, batch-first: {batch_size, num_dirct, hidden_size}
std::vector<float> ic_data{0.9569, -0.5981, 1.1312, 1.0945, -0.8323, 0.3998,
0.1831, 0.5938, 1.1055, -0.1212, -0.9097, 0.7831,
2.7096, -0.1790, 0.0022, -0.8040, -1.6991, -1.9498,
-1.2567, -0.4114, 0.1578, 0.0567, 0.8069, -0.5141};
// peephole weights, shape {num_dirct, 3 * hidden_size}
std::vector<float> pph_data{1.84369764, 0.68413646, -0.44892886, -1.50904413, 0.3860796,
-0.52186625, 1.08474445, -1.80867321, 1.32594529, 0.4336262,
-0.83699064, 0.49162736, -0.8271, -0.5683, 0.4562,
-1.2545, 1.2729, -0.4082, -0.4392, -0.9406,
0.7794, 1.8194, -0.5811, 0.2166};
// clip of 0 disables activation clipping
float clip = 0.0f;
migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
migraphx::shape ih_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape ic_shape{migraphx::shape::float_type, {batch_size, num_dirct, hidden_size}};
migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
// concatenation of hidden states as program output
{
// per-batch sequence lengths shorter than seq_len; later steps should be zeroed
std::vector<int> sl_data{1, 2, 3};
migraphx::shape sl_shape{migraphx::shape::int32_type, {batch_size}};
migraphx::program p;
auto* mm = p.get_main_module();
auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
auto sql = mm->add_literal(migraphx::literal{sl_shape, sl_data});
// convert batch-first inputs to the seq-first layout required by the lstm op
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto out_hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
// extract last hidden and last cell state alongside the full hidden sequence
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), out_hs);
auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), out_hs);
// transpose outputs back to batch-first ordering for comparison
std::vector<int64_t> perm_hid{2, 0, 1, 3};
out_hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}),
out_hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
lco = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lco);
mm->add_return({out_hs, lho, lco});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto arg_hs = outputs.front();
auto arg_lho = outputs.at(1);
auto arg_lco = outputs.at(2);
std::vector<float> output_data;
std::vector<float> last_output_data;
std::vector<float> last_cell_data;
arg_hs.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
arg_lho.visit([&](auto output) { last_output_data.assign(output.begin(), output.end()); });
arg_lco.visit([&](auto output) { last_cell_data.assign(output.begin(), output.end()); });
// golden values; zeros correspond to steps beyond each batch's sequence length
std::vector<float> output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.141643, 0.0451978, 0.140804,
0.0745128, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0.294074, -0.0319677, -0.0955337,
0.104168, 0.911307, 0.11468, 0.114449, 0.0196755, 0.421857, 0.0459771,
-0.144955, 0.0720673, 0.96657, 0.0755112, 0.0620917, -0.264845, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0.022618, -0.121195, -0.4065, -0.252054, -0.262807, 0.275286,
0.358395, 0.266267, -0.0300906, -0.0890598, -0.135266, -0.0413375, -0.128254,
0.125398, 0.0665142, -0.163651, 0.103489, 0.0142918, -0.123408, 0.0401075,
-0.0644683, 0.371512, 0.212431, -0.116131, 0, 0, 0,
0, 0, 0, 0, 0};
std::vector<float> last_output_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.141643, 0.0451978, 0.140804, 0.0745128,
0.421857, 0.0459771, -0.144955, 0.0720673, 0.911307, 0.11468, 0.114449, 0.0196755,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.262807, 0.275286, 0.358395, 0.266267};
std::vector<float> last_cell_data_gold{
0.600582, -0.601197, 0.353558, 0.789097, -0.326822, 0.301121, 0.219523, 0.415242,
0.737121, 0.134902, -0.303595, 0.241948, 2.08242, 0.442513, 0.187127, 0.0577626,
0.391174, 0.0308845, -0.561745, 0.0730323, -0.611307, 0.55454, 0.4364, 0.509436};
EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_output_data, last_output_data_gold));
EXPECT(migraphx::verify::verify_rms_range(last_cell_data, last_cell_data_gold));
}
// last cell output as program output
{
migraphx::program p;
auto* mm = p.get_main_module();
auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
auto ih = mm->add_literal(migraphx::literal{ih_shape, ih_data});
auto ic = mm->add_literal(migraphx::literal{ic_shape, ic_data});
auto w = mm->add_literal(migraphx::literal{w_shape, w_data});
auto r = mm->add_literal(migraphx::literal{r_shape, r_data});
auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
auto pph = mm->add_literal(migraphx::literal{pph_shape, pph_data});
// pad the sequence with 2 extra all-zero time steps along the batch-first
// sequence axis; sequence_lens below still reports the original seq_len, so
// the padded steps must not affect the results
migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
auto seq = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
auto sql = mm->add_literal(seq_len_s, len_data);
// convert batch-first inputs to the seq-first layout required by the lstm op
std::vector<int64_t> perm{1, 0, 2};
seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
ih = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
ic = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
auto hs = mm->add_instruction(
migraphx::make_op(
"lstm",
{{"hidden_size", hidden_size},
{"actv_func",
migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
migraphx::make_op("tanh"),
migraphx::make_op("tanh")})},
{"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
{"clip", clip},
{"input_forget", 0}}),
seq,
w,
r,
bias,
sql,
ih,
ic,
pph);
auto lho = mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), hs);
auto lco = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
// transpose outputs back to batch-first ordering for comparison
std::vector<int64_t> perm_hid{2, 0, 1, 3};
hs = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
lho = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lho);
lco = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), lco);
mm->add_return({hs, lho, lco});
p.compile(migraphx::make_target("ref"));
auto outputs = p.eval({});
auto res_hs = outputs.at(0);
auto res_lho = outputs.at(1);
auto res_lco = outputs.at(2);
std::vector<float> hs_data;
std::vector<float> lho_data;
std::vector<float> lco_data;
res_hs.visit([&](auto output) { hs_data.assign(output.begin(), output.end()); });
res_lho.visit([&](auto output) { lho_data.assign(output.begin(), output.end()); });
res_lco.visit([&](auto output) { lco_data.assign(output.begin(), output.end()); });
// golden values; trailing zeros correspond to the padded time steps
std::vector<float> hs_data_gold{
0.079753, -0.289854, 0.160043, 0.115056, -0.120174, 0.043157, 0.117138,
-0.222188, 0.186991, -0.0624168, 0.205513, 0.0836373, -0.175114, -0.00543549,
0.178681, -0.266999, 0.0459033, 0.0414126, 0.272303, 0.0393149, -0.182201,
-0.0232277, 0.235501, -0.213485, -0.058052, 0.0795391, 0.266617, -0.0128746,
-0.185038, -0.026845, 0.177273, -0.0774616, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0.294074,
-0.0319677, -0.0955337, 0.104168, 0.789732, 0.128538, 0.20909, 0.0553812,
0.421857, 0.0459771, -0.144955, 0.0720673, 0.928866, 0.113685, 0.220626,
-0.0432316, 0.218258, 0.0944405, 0.0431211, -0.132394, 0.960938, 0.133565,
0.269741, 0.130438, 0.0309878, 0.0971544, 0.149294, -0.0492549, 0.946669,
0.0868676, 0.044508, -0.373961, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0.022618, -0.121195,
-0.4065, -0.252054, -0.224905, 0.32421, 0.344048, 0.271694, -0.0300906,
-0.0890598, -0.135266, -0.0413375, -0.063456, 0.148524, 0.05108, -0.0234895,
0.103489, 0.0142918, -0.123408, 0.0401075, -0.0252804, 0.267356, 0.146353,
0.0789186, 0.187761, 0.0501726, -0.121584, 0.0606723, -0.0681467, 0.382748,
0.230211, -0.161537, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0};
std::vector<float> lho_data_gold{
-0.058052, 0.0795391, 0.266617, -0.0128746, -0.120174, 0.043157, 0.117138, -0.222188,
0.0309878, 0.0971544, 0.149294, -0.0492549, 0.789732, 0.128538, 0.20909, 0.0553812,
0.187761, 0.0501726, -0.121584, 0.0606723, -0.224905, 0.32421, 0.344048, 0.271694};
std::vector<float> lco_data_gold{
-0.077353, 0.245616, 0.361023, -0.0443759, -0.889294, 0.182463, 0.186512, -0.402334,
0.0685243, 0.20465, 0.277867, -0.112934, 1.48161, 0.524116, 0.347113, 0.181813,
0.67312, 0.120508, -0.726968, 0.113845, -0.434265, 0.747833, 0.416053, 0.558713};
EXPECT(migraphx::verify::verify_rms_range(hs_data, hs_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lho_data, lho_data_gold));
EXPECT(migraphx::verify::verify_rms_range(lco_data, lco_data_gold));
}
}
TEST_CASE(lstm_bidirectional_actv_func)
{
std::size_t batch_size = 3;
...
...
test/simplify_qdq_test.cpp
View file @
61775eab
...
...
@@ -44,20 +44,34 @@ void run_pass(migraphx::module& m)
sqdq
.
apply
(
m
);
}
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
broadcast_scale
(
migraphx
::
module
&
m
,
migraphx
::
instruction_ref
scale
,
migraphx
::
instruction_ref
shift
)
const
std
::
vector
<
std
::
size_t
>&
out_lens
,
std
::
size_t
axis
)
{
auto
lens
=
x
->
get_shape
().
lens
();
if
(
scale
->
get_shape
().
lens
()
==
out_lens
)
return
scale
;
migraphx
::
instruction_ref
scale_mb
;
if
(
scale
->
get_shape
().
lens
().
front
()
==
1
)
auto
scale_lens
=
scale
->
get_shape
().
lens
();
if
(
scale_lens
.
front
()
==
1
and
scale_lens
.
size
()
==
1
)
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
scale
);
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_
lens
}}),
scale
);
else
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}}),
scale
);
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
out_lens
}}),
scale
);
return
scale_mb
;
}
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
scale
,
migraphx
::
instruction_ref
shift
,
std
::
size_t
q_axis
=
1
)
{
auto
lens
=
x
->
get_shape
().
lens
();
auto
scale_mb
=
broadcast_scale
(
m
,
scale
,
lens
,
q_axis
);
auto
shift_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
shift
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
name
),
x
,
scale_mb
,
shift_mb
);
...
...
@@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
scale
)
migraphx
::
instruction_ref
scale
,
std
::
size_t
q_axis
=
1
)
{
auto
lens
=
x
->
get_shape
().
lens
();
migraphx
::
instruction_ref
scale_mb
;
if
(
scale
->
get_shape
().
lens
().
front
()
==
1
)
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
scale
);
else
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}}),
scale
);
auto
lens
=
x
->
get_shape
().
lens
();
auto
scale_mb
=
broadcast_scale
(
m
,
scale
,
lens
,
q_axis
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
name
),
x
,
scale_mb
);
}
migraphx
::
instruction_ref
add_scale_mul
(
migraphx
::
module
&
m
,
migraphx
::
instruction_ref
scale1
,
migraphx
::
instruction_ref
scale2
,
std
::
size_t
axis1
,
std
::
size_t
axis2
,
const
std
::
vector
<
std
::
size_t
>&
out_lens
)
{
auto
scale1_mb
=
broadcast_scale
(
m
,
scale1
,
out_lens
,
axis1
);
auto
scale2_mb
=
broadcast_scale
(
m
,
scale2
,
out_lens
,
axis2
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale1_mb
,
scale2_mb
);
}
TEST_CASE
(
remove_qdq
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
100
,
100
}};
...
...
@@ -159,18 +180,62 @@ TEST_CASE(dot)
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_multi_scale
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1000
,
1024
}};
migraphx
::
shape
sh3
{
migraphx
::
shape
::
float_type
,
{
1280
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m1
.
add_literal
(
0.4
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
0
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale1
,
zero
,
0
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
1
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale2
,
zero
,
1
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
scale1
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m2
.
add_literal
(
0.4
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale1
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
0
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
1
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale1
,
scale2
,
0
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
...
...
@@ -178,6 +243,180 @@ TEST_CASE(dot)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_broadcasted
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
2
,
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1000
,
1024
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
d2_mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
1000
,
1024
}}}),
d2
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2_mb
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
q2_mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
1000
,
1024
}}}),
q2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2_mb
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_transposed
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1024
,
1000
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
d2_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d2
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2_t
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
q2_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2_t
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_multi_scale_transposed_broadcasted
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1024
,
1000
}};
migraphx
::
shape
sh3
{
migraphx
::
shape
::
float_type
,
{
1280
}};
migraphx
::
shape
sh4
{
migraphx
::
shape
::
float_type
,
{
1024
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh4
,
0
));
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
2
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale1
,
zero
,
2
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
0
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale2
,
zero
,
0
);
auto
d2_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d2
);
auto
d2_mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
1000
,
1024
}}}),
d2_t
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2_mb
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
sh4
,
0
));
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
2
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
0
);
auto
q2_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q2
);
auto
q2_mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
1000
,
1024
}}}),
q2_t
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2_mb
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale1
,
scale2
,
2
,
3
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_multi_scale_unsupported_axis
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1000
,
1024
}};
migraphx
::
shape
sh3
{
migraphx
::
shape
::
float_type
,
{
1000
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m1
.
add_literal
(
0.4
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
1
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale1
,
zero
,
1
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
1
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale2
,
zero
,
1
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
t1
,
t2
);
m2
.
add_return
({
dot
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_non_zero_point
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
...
...
@@ -269,18 +508,18 @@ TEST_CASE(dot_add)
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
ab
=
m2
.
add_parameter
(
"ab"
,
sh3
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q
1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t
1
,
scale
,
zero
);
auto
q2
=
add_
quantiz
e_op
(
m2
,
"quant
izelinear"
,
t2
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale
1
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d3
,
ab
);
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
ab
=
m2
.
add_parameter
(
"ab"
,
sh3
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q
2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t
2
,
scale
,
zero
);
auto
dot
=
m2
.
add_
instruction
(
migraphx
::
mak
e_op
(
"quant
_dot"
),
q1
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
()
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_
scale
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d3
,
ab
);
m2
.
add_return
({
add
});
}
...
...
@@ -320,26 +559,80 @@ TEST_CASE(conv)
auto
weights
=
m2
.
add_parameter
(
"weights"
,
s4
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
weights
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale
);
m2
.
add_return
({
d6
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
conv_multi_scale
)
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
migraphx
::
shape
s7
{
migraphx
::
shape
::
float_type
,
{
1
,
320
,
7
,
7
}};
migraphx
::
shape
s8
{
migraphx
::
shape
::
float_type
,
{
1280
}};
migraphx
::
module
m1
;
{
auto
input
=
m1
.
add_parameter
(
"input"
,
s7
);
auto
weights
=
m1
.
add_parameter
(
"weights"
,
s4
);
auto
w_scale
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s8
,
0
));
auto
inp_scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
weights
,
w_scale
,
zero
,
0
);
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
input
,
inp_scale
,
zero
);
auto
d5
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
inp_scale
,
zero
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
d5
,
d1
);
m1
.
add_return
({
c1
});
}
migraphx
::
module
m2
;
{
auto
input
=
m2
.
add_parameter
(
"input"
,
s7
);
auto
weights
=
m2
.
add_parameter
(
"weights"
,
s4
);
auto
w_scale
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s8
,
0
));
auto
inp_scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q_inp
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
inp_scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q_inp
,
weights
);
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
m2
.
add_return
({
d6
});
auto
out_scale
=
add_scale_mul
(
m2
,
inp_scale
,
w_scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d1
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale
);
m2
.
add_return
({
d1
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
conv_multi_scale
)
TEST_CASE
(
conv_multi_scale
_unsupported_axis
)
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
migraphx
::
shape
s7
{
migraphx
::
shape
::
float_type
,
{
1
,
320
,
7
,
7
}};
...
...
@@ -430,20 +723,20 @@ TEST_CASE(conv_bias_add)
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero32
=
m2
.
add_literal
(
std
::
int32_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
weights
);
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
auto
b1
=
m2
.
add_instruction
(
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale
);
auto
b1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
1280
,
7
,
7
}}}),
d2
);
auto
a1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d6
,
b1
);
m2
.
add_return
({
a1
});
...
...
@@ -519,22 +812,21 @@ TEST_CASE(conv_pooling_dot)
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero32
=
m2
.
add_literal
(
std
::
int32_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
scale2
=
m2
.
add_literal
(
0.25
f
);
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
ab
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
ab
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
weights
);
auto
d5
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
auto
bc1
=
m2
.
add_instruction
(
auto
out_scale1
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d5
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale1
);
auto
bc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
1280
,
7
,
7
}}}),
d2
);
auto
a1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d5
,
bc1
);
auto
ap
=
...
...
@@ -545,10 +837,11 @@ TEST_CASE(conv_pooling_dot)
{
"lengths"
,
{
7
,
7
}},
{
"ceil_mode"
,
0
}}),
a1
);
auto
fl
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
1
}}),
ap
);
auto
q4
=
add_quantize_op
(
m2
,
"quantizelinear"
,
fl
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q4
,
db
);
auto
d9
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale2
);
auto
fl
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
1
}}),
ap
);
auto
q4
=
add_quantize_op
(
m2
,
"quantizelinear"
,
fl
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q4
,
db
);
auto
out_scale2
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
0
,
dot
->
get_shape
().
lens
());
auto
d9
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale2
);
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
1000
}}}),
d3
);
auto
a2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d9
,
mb1
);
...
...
test/verify/test_lstm_bidirct_3args_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Verifies a bidirectional LSTM driven with only the three required inputs
// (seq, w, r). The sequence arrives batch-first ({batch, seq, input}) and is
// transposed into the sequence-first layout before the op, and the
// hidden-state output is transposed back to a batch-leading layout.
// NOTE(review): only two activation functions are supplied here; presumably
// the op fills in the remainder with defaults -- confirm against the lstm
// operator's actv_func handling.
struct test_lstm_bidirct_3args_layout : verify_program<test_lstm_bidirct_3args_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 2;
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        const migraphx::shape wt_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, ninput}};
        const migraphx::shape rc_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, nhidden}};

        auto seq = mod->add_parameter("seq", seq_shape);
        auto w   = mod->add_parameter("w", wt_shape);
        auto r   = mod->add_parameter("r", rc_shape);

        // Batch-first -> sequence-first.
        const std::vector<int64_t> to_seq_first{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), seq);

        auto hs = mod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"), migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip_val}}),
            seq,
            w,
            r);

        // Restore a batch-leading layout for the hidden-state output.
        const std::vector<int64_t> hid_perm{2, 0, 1, 3};
        mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", hid_perm}}), hs);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
test/verify/test_lstm_bidirct_last_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Bidirectional LSTM with all optional inputs (bias, initial hidden/cell
// state, peephole weights) supplied. Batch-first tensors are transposed into
// the op's sequence-first layout; the last hidden-state output is extracted
// with rnn_last_hs_output and transposed back before being returned.
struct test_lstm_bidirct_last_layout : verify_program<test_lstm_bidirct_last_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 2;
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        const migraphx::shape wt_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, ninput}};
        const migraphx::shape rc_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, nhidden}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {ndirct, 8 * nhidden}};
        const migraphx::shape ih_shape{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        const migraphx::shape ic_shape{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        const migraphx::shape pph_shape{migraphx::shape::float_type, {ndirct, 3 * nhidden}};

        auto seq  = mod->add_parameter("seq", seq_shape);
        auto w    = mod->add_parameter("w", wt_shape);
        auto r    = mod->add_parameter("r", rc_shape);
        auto bias = mod->add_parameter("bias", b_shape);
        auto ih   = mod->add_parameter("ih", ih_shape);
        auto ic   = mod->add_parameter("ic", ic_shape);
        auto pph  = mod->add_parameter("pph", pph_shape);
        // Placeholder for the unused sequence-length input slot.
        auto und = mod->add_instruction(migraphx::make_op("undefined"));

        // Batch-first -> sequence-first for seq and the initial states.
        const std::vector<int64_t> to_seq_first{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), seq);
        ih = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), ih);
        ic = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), ic);

        auto output = mod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);

        auto last_hs =
            mod->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
        // Back to the batch-leading layout.
        mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), last_hs);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
test/verify/test_lstm_forward_hs_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Forward-direction LSTM with all optional inputs supplied, exercising the
// layout path: batch-first seq/ih/ic are transposed to sequence-first before
// the op, and the full hidden-state sequence output is transposed back to a
// batch-leading layout.
struct test_lstm_forward_hs_layout : verify_program<test_lstm_forward_hs_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1;
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        const migraphx::shape wt_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, ninput}};
        const migraphx::shape rc_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, nhidden}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {ndirct, 8 * nhidden}};
        const migraphx::shape ih_shape{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        const migraphx::shape ic_shape{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        const migraphx::shape pph_shape{migraphx::shape::float_type, {ndirct, 3 * nhidden}};

        auto seq  = mod->add_parameter("seq", seq_shape);
        auto w    = mod->add_parameter("w", wt_shape);
        auto r    = mod->add_parameter("r", rc_shape);
        auto bias = mod->add_parameter("bias", b_shape);
        auto ih   = mod->add_parameter("ih", ih_shape);
        auto ic   = mod->add_parameter("ic", ic_shape);
        auto pph  = mod->add_parameter("pph", pph_shape);
        // Placeholder for the unused sequence-length input slot.
        auto und = mod->add_instruction(migraphx::make_op("undefined"));

        // Batch-first -> sequence-first for seq and the initial states.
        const std::vector<int64_t> to_seq_first{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), seq);
        ih = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), ih);
        ic = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), ic);

        auto hs = mod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);

        // Restore a batch-leading layout for the hidden-state output.
        const std::vector<int64_t> hid_perm{2, 0, 1, 3};
        mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", hid_perm}}), hs);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
test/verify/test_lstm_forward_last_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Forward-direction LSTM using explicit per-batch sequence lengths (the
// int32 literal {1, 2}) instead of an undefined placeholder. Batch-first
// tensors are transposed to sequence-first, the last hidden state is taken
// with rnn_last_hs_output (which also receives the lengths), and the result
// is transposed back to a batch-leading layout.
struct test_lstm_forward_last_layout : verify_program<test_lstm_forward_last_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1;
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        const migraphx::shape wt_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, ninput}};
        const migraphx::shape rc_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, nhidden}};
        const migraphx::shape b_shape{migraphx::shape::float_type, {ndirct, 8 * nhidden}};
        const migraphx::shape ih_shape{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        const migraphx::shape len_shape{migraphx::shape::int32_type, {nbatch}};
        const migraphx::shape ic_shape{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        const migraphx::shape pph_shape{migraphx::shape::float_type, {ndirct, 3 * nhidden}};

        auto seq  = mod->add_parameter("seq", seq_shape);
        auto w    = mod->add_parameter("w", wt_shape);
        auto r    = mod->add_parameter("r", rc_shape);
        auto bias = mod->add_parameter("bias", b_shape);
        auto ih   = mod->add_parameter("ih", ih_shape);
        // Actual sequence lengths per batch entry.
        auto len = mod->add_literal(migraphx::literal(len_shape, {1, 2}));
        auto ic  = mod->add_parameter("ic", ic_shape);
        auto pph = mod->add_parameter("pph", pph_shape);

        // Batch-first -> sequence-first for seq and the initial states.
        const std::vector<int64_t> to_seq_first{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), seq);
        ih = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), ih);
        ic = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), ic);

        auto output = mod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            len,
            ih,
            ic,
            pph);

        auto last_hs =
            mod->add_instruction(migraphx::make_op("rnn_last_hs_output"), output, len);
        // Back to the batch-leading layout.
        mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), last_hs);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Reverse-direction LSTM with only the three required inputs (seq, w, r),
// exercising the layout path for the last cell-state output: the batch-first
// sequence is transposed to sequence-first, the final cell state is taken
// with rnn_last_cell_output, and the result is transposed back.
struct test_lstm_reverse_3args_cell_layout
    : verify_program<test_lstm_reverse_3args_cell_layout>
{
    migraphx::program create_program() const
    {
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1;
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mod = prog.get_main_module();

        const migraphx::shape seq_shape{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        const migraphx::shape wt_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, ninput}};
        const migraphx::shape rc_shape{migraphx::shape::float_type,
                                       {ndirct, 4 * nhidden, nhidden}};

        auto seq = mod->add_parameter("seq", seq_shape);
        auto w   = mod->add_parameter("w", wt_shape);
        auto r   = mod->add_parameter("r", rc_shape);

        // Batch-first -> sequence-first.
        const std::vector<int64_t> to_seq_first{1, 0, 2};
        seq = mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), seq);

        auto hs = mod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip_val}}),
            seq,
            w,
            r);

        auto last_cell =
            mod->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        // Back to the batch-leading layout.
        mod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_first}}), last_cell);
        return prog;
    }
    std::string section() const { return "rnn"; }
};
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment