Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
61775eab
Commit
61775eab
authored
Nov 16, 2023
by
Umang Yadav
Browse files
Merge branch 'ref_fp8' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into ref_fp8
parents
a5c38ebe
e7e5ba23
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3274 additions
and
410 deletions
+3274
-410
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+19
-0
src/onnx/parse_lstm.cpp
src/onnx/parse_lstm.cpp
+47
-0
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+103
-35
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+171
-0
test/onnx/lstm_bi_layout_cell_test.onnx
test/onnx/lstm_bi_layout_cell_test.onnx
+0
-0
test/onnx/lstm_bi_layout_last_test.onnx
test/onnx/lstm_bi_layout_last_test.onnx
+0
-0
test/onnx/lstm_f_layout_cell_test.onnx
test/onnx/lstm_f_layout_cell_test.onnx
+0
-0
test/onnx/lstm_f_layout_hs_test.onnx
test/onnx/lstm_f_layout_hs_test.onnx
+0
-0
test/onnx/lstm_r_layout_hs_cell_test.onnx
test/onnx/lstm_r_layout_hs_cell_test.onnx
+0
-0
test/onnx/lstm_r_layout_test.onnx
test/onnx/lstm_r_layout_test.onnx
+0
-0
test/onnx/onnx_rnn_test.cpp
test/onnx/onnx_rnn_test.cpp
+332
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-1
test/quantization.cpp
test/quantization.cpp
+70
-70
test/ref/rnn_ops.cpp
test/ref/rnn_ops.cpp
+1727
-234
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+363
-70
test/verify/test_lstm_bidirct_3args_layout.cpp
test/verify/test_lstm_bidirct_3args_layout.cpp
+77
-0
test/verify/test_lstm_bidirct_last_layout.cpp
test/verify/test_lstm_bidirct_last_layout.cpp
+95
-0
test/verify/test_lstm_forward_hs_layout.cpp
test/verify/test_lstm_forward_hs_layout.cpp
+95
-0
test/verify/test_lstm_forward_last_layout.cpp
test/verify/test_lstm_forward_last_layout.cpp
+97
-0
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
+78
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
61775eab
...
@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
...
@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
}
MIGRAPHX_PRED_MATCHER
(
has_same_value
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
bool
all_same
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
s
)
{
all_same
=
std
::
all_of
(
s
.
begin
()
+
1
,
s
.
end
(),
[
&
](
const
auto
&
scale
)
{
return
float_equal
(
scale
,
s
.
front
());
});
});
return
all_same
;
}
MIGRAPHX_BASIC_MATCHER
(
output
,
const
matcher_context
&
,
instruction_ref
ins
)
MIGRAPHX_BASIC_MATCHER
(
output
,
const
matcher_context
&
,
instruction_ref
ins
)
{
{
if
(
ins
->
outputs
().
size
()
==
1
)
if
(
ins
->
outputs
().
size
()
==
1
)
...
@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
...
@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
}
template
<
class
...
Ms
>
auto
skip_broadcasts_transposes_contiguous
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"transpose"
))(
ms
...);
}
template
<
class
T
>
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
{
...
...
src/onnx/parse_lstm.cpp
View file @
61775eab
...
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
...
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
}
}
void
lstm_transpose_inputs
(
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
0
]);
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_undefined
())
{
args
[
5
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
5
]);
}
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_undefined
())
{
args
[
6
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
6
]);
}
}
void
lstm_transpose_outputs
(
onnx_parser
::
node_info
&
info
,
instruction_ref
&
hidden_states
,
instruction_ref
&
last_output
,
instruction_ref
&
last_cell_output
)
{
std
::
vector
<
int64_t
>
perm_hs
{
2
,
0
,
1
,
3
};
hidden_states
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hs
}}),
hidden_states
);
std
::
vector
<
int64_t
>
perm_last
{
1
,
0
,
2
};
last_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_output
);
last_cell_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_cell_output
);
}
struct
parse_lstm
:
op_parser
<
parse_lstm
>
struct
parse_lstm
:
op_parser
<
parse_lstm
>
{
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"LSTM"
}};
}
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"LSTM"
}};
}
...
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
input_forget
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
}
}
int
layout
=
0
;
if
(
contains
(
info
.
attributes
,
"layout"
))
{
layout
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"layout"
)).
at
<
int
>
();
}
// append undefined opeator to make 6 arguments
// append undefined opeator to make 6 arguments
if
(
args
.
size
()
<
8
)
if
(
args
.
size
()
<
8
)
{
{
...
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
}
}
if
(
layout
!=
0
)
{
lstm_transpose_inputs
(
info
,
args
);
}
// first output for concatenation of hidden states
// first output for concatenation of hidden states
auto
hidden_states
=
info
.
add_instruction
(
make_op
(
"lstm"
,
auto
hidden_states
=
info
.
add_instruction
(
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
...
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
...
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto
last_cell_output
=
auto
last_cell_output
=
info
.
add_instruction
(
make_op
(
"rnn_last_cell_output"
),
hidden_states
);
info
.
add_instruction
(
make_op
(
"rnn_last_cell_output"
),
hidden_states
);
if
(
layout
!=
0
)
{
lstm_transpose_outputs
(
info
,
hidden_states
,
last_output
,
last_cell_output
);
}
return
{
hidden_states
,
last_output
,
last_cell_output
};
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
}
};
};
...
...
src/simplify_qdq.cpp
View file @
61775eab
...
@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
...
@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return
s
;
return
s
;
}
}
MIGRAPHX_PRED_MATCHER
(
has_same_value
,
instruction_ref
ins
)
struct
match_find_quantizable_ops
{
{
if
(
ins
->
name
()
!=
"@literal"
)
static
bool
is_valid_scale
(
instruction_ref
scale
,
std
::
vector
<
std
::
size_t
>
lens
,
std
::
size_t
axis
)
{
return
scale
->
get_shape
().
scalar
()
or
scale
->
get_shape
().
elements
()
==
lens
.
at
(
axis
);
}
static
bool
is_valid_zero_point
(
instruction_ref
zp
)
{
if
(
not
zp
->
can_eval
())
return
false
;
return
false
;
bool
all_same
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
s
)
{
bool
all_zeros
=
false
;
all_same
=
std
::
all_of
(
s
.
begin
()
+
1
,
s
.
end
(),
[
&
](
const
auto
&
scale
)
{
zp
->
eval
().
visit
([
&
](
auto
z
)
{
return
float_equal
(
scale
,
s
.
front
());
all_zeros
=
});
std
::
all_of
(
z
.
begin
(),
z
.
end
(),
[
&
](
auto
val
)
{
return
float_equal
(
val
,
0
);
});
});
});
return
all_
same
;
return
all_
zeros
;
}
}
struct
match_find_quantizable_ops
static
auto
{
scale_broadcast_op
(
instruction_ref
scale
,
std
::
vector
<
std
::
size_t
>
lens
,
std
::
size_t
axis
)
{
if
(
scale
->
get_shape
().
scalar
())
{
return
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}});
}
else
{
return
migraphx
::
make_op
(
"broadcast"
,
{{
"out_lens"
,
lens
},
{
"axis"
,
axis
}});
}
}
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static
auto
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
)
{
auto
qinp
=
dqins
->
inputs
().
front
();
auto
next_ins
=
dqins
;
static
auto
dequantizelinear_op
(
const
std
::
string
&
name
,
const
std
::
string
&
scale
)
while
(
next_ins
!=
qop
)
{
if
(
next_ins
->
name
()
!=
"dequantizelinear"
)
{
qinp
=
m
.
insert_instruction
(
qop
,
next_ins
->
get_operator
(),
qinp
);
}
next_ins
=
next_ins
->
outputs
().
front
();
}
return
qinp
;
}
static
auto
dequantizelinear_op
(
const
std
::
string
&
scale
,
const
std
::
string
&
zp
)
{
{
return
match
::
name
(
"dequantizelinear"
)(
return
match
::
name
(
"dequantizelinear"
)(
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"quantizelinear"
))(
match
::
any
()
.
bind
(
name
)
)),
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"quantizelinear"
))(
match
::
any
())),
match
::
arg
(
1
)(
match
::
skip_broadcasts
(
has_same_value
().
bind
(
scale
))),
match
::
arg
(
1
)(
match
::
skip_broadcasts
(
match
::
is_constant
().
bind
(
scale
))),
match
::
arg
(
2
)(
match
::
skip_broadcasts
(
match
::
all_of
(
match
::
has_value
(
0
)
))));
match
::
arg
(
2
)(
match
::
skip_broadcasts
(
match
::
is_constant
().
bind
(
zp
))));
}
}
auto
matcher
()
const
auto
matcher
()
const
{
{
return
match
::
name
(
get_quantizable_op_names
())(
return
match
::
name
(
get_quantizable_op_names
())(
match
::
arg
(
0
)(
dequantizelinear_op
(
"x1"
,
"scale1"
)),
match
::
arg
(
0
)(
match
::
skip_broadcasts_transposes_contiguous
(
match
::
arg
(
1
)(
dequantizelinear_op
(
"x2"
,
"scale2"
)));
dequantizelinear_op
(
"scale1"
,
"zp1"
).
bind
(
"dq1"
))),
match
::
arg
(
1
)(
match
::
skip_broadcasts_transposes_contiguous
(
dequantizelinear_op
(
"scale2"
,
"zp2"
).
bind
(
"dq2"
))));
}
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
{
auto
qop
=
r
.
result
;
auto
qop
=
r
.
result
;
auto
q1
=
r
.
instructions
[
"
x
1"
];
auto
d
q1
=
r
.
instructions
[
"
dq
1"
];
auto
q2
=
r
.
instructions
[
"
x
2"
];
auto
d
q2
=
r
.
instructions
[
"
dq
2"
];
auto
scale1
=
r
.
instructions
[
"scale1"
];
auto
scale1
=
r
.
instructions
[
"scale1"
];
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
// Only INT8 type currently supported
// Only INT8 type currently supported
if
(
q1
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
if
(
dq1
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
q2
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
dq2
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
return
;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if
(
not
(
is_valid_zero_point
(
zp1
)
and
is_valid_zero_point
(
zp2
)))
return
;
return
;
double
scale
;
// Only support scalar and 1D
scale
s
visit_all
(
scale1
->
get_
literal
(),
scale2
->
get_literal
())(
if
(
scale1
->
get_
shape
().
lens
().
size
()
!=
1
or
scale2
->
get_shape
().
lens
().
size
()
!=
1
)
[
&
](
const
auto
s1
,
const
auto
s2
)
{
scale
=
s1
.
front
()
*
s2
.
front
();
})
;
return
;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto
qop_args
=
qop
->
inputs
();
auto
qop_args
=
qop
->
inputs
();
qop_args
.
at
(
0
)
=
q1
;
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
)
;
qop_args
.
at
(
1
)
=
q2
;
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
)
;
instruction_ref
dq
;
instruction_ref
dq
;
instruction_ref
dq
_scale
;
instruction_ref
out
_scale
;
instruction_ref
zero_point
;
instruction_ref
zero_point
;
if
(
qop
->
name
()
==
"convolution"
)
if
(
qop
->
name
()
==
"convolution"
)
{
{
auto
conv_val
=
qop
->
get_operator
().
to_value
();
auto
conv_val
=
qop
->
get_operator
().
to_value
();
dq
=
m
.
insert_instruction
(
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_convolution"
,
conv_val
),
qop_args
);
qop
,
migraphx
::
make_op
(
"quant_convolution"
,
conv_val
),
qop_args
);
auto
out_lens
=
dq
->
get_shape
().
lens
();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if
(
not
(
is_valid_scale
(
scale1
,
out_lens
,
1
)
and
is_valid_scale
(
scale2
,
out_lens
,
1
)))
return
;
auto
s1_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale1
,
out_lens
,
1
),
scale1
);
auto
s2_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale2
,
out_lens
,
1
),
scale2
);
out_scale
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"mul"
),
s1_bcast
,
s2_bcast
);
}
}
else
if
(
qop
->
name
()
==
"dot"
)
else
if
(
qop
->
name
()
==
"dot"
)
{
{
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_dot"
),
qop_args
);
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_dot"
),
qop_args
);
auto
out_lens
=
dq
->
get_shape
().
lens
();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if
(
not
(
is_valid_scale
(
scale1
,
out_lens
,
out_lens
.
size
()
-
2
)
and
is_valid_scale
(
scale2
,
out_lens
,
out_lens
.
size
()
-
1
)))
return
;
auto
s1_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale1
,
out_lens
,
out_lens
.
size
()
-
2
),
scale1
);
auto
s2_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale2
,
out_lens
,
out_lens
.
size
()
-
1
),
scale2
);
out_scale
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"mul"
),
s1_bcast
,
s2_bcast
);
}
}
auto
ins_type
=
qop
->
get_shape
().
type
();
dq_scale
=
m
.
add_literal
(
literal
({
ins_type
},
{
scale
}));
auto
lens
=
dq
->
get_shape
().
lens
();
dq
=
m
.
insert_instruction
(
qop
,
make_op
(
"dequantizelinear"
),
dq
,
out_scale
);
auto
scale_mb
=
m
.
insert_instruction
(
qop
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
dq_scale
);
dq
=
m
.
insert_instruction
(
qop
,
make_op
(
"dequantizelinear"
),
dq
,
scale_mb
);
m
.
replace_instruction
(
qop
,
dq
);
m
.
replace_instruction
(
qop
,
dq
);
}
}
};
};
...
...
test/onnx/gen_onnx.py
View file @
61775eab
...
@@ -4484,6 +4484,177 @@ def lrn_test():
...
@@ -4484,6 +4484,177 @@ def lrn_test():
return
([
node
],
[
x
],
[
y
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
lstm_bi_layout_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
2
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
2
,
60
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
''
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'bidirectional'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
cellout
])
@
onnx_test
()
def
lstm_bi_layout_last_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
2
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
2
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
2
,
20
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
,
'output'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'bidirectional'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
,
output
])
@
onnx_test
()
def
lstm_f_layout_hs_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
1
,
20
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
,
'output'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'forward'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
,
output
])
@
onnx_test
()
def
lstm_f_layout_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
''
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'forward'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
cellout
])
@
onnx_test
()
def
lstm_r_layout_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'reverse'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
])
@
onnx_test
()
def
lstm_r_layout_hs_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
'output'
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'reverse'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
output
,
cellout
])
@
onnx_test
()
@
onnx_test
()
def
matmul_bmbm_test
():
def
matmul_bmbm_test
():
m1
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
FLOAT
,
[
3
,
6
,
7
])
m1
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
FLOAT
,
[
3
,
6
,
7
])
...
...
test/onnx/lstm_bi_layout_cell_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_bi_layout_last_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_f_layout_cell_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_f_layout_hs_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_r_layout_hs_cell_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/lstm_r_layout_test.onnx
0 → 100644
View file @
61775eab
File added
test/onnx/onnx_rnn_test.cpp
View file @
61775eab
...
@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
...
@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
}
}
}
}
TEST_CASE
(
lstm_forward_layout
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
1
;
// num directions
float
clip
=
0.0
f
;
int
input_forget
=
1
;
migraphx
::
shape
seq_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
sl
,
is
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
is
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
hs
}};
migraphx
::
shape
bias_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
8
*
hs
}};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
nd
,
hs
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
}};
// 8 args, hs and last output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
auto
prog
=
optimize_onnx
(
"lstm_f_layout_hs_test.onnx"
);
EXPECT
(
p
==
prog
);
}
// 8 args, cell output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_cell
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_cell
);
auto
prog
=
optimize_onnx
(
"lstm_f_layout_cell_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
// activation functions
// activation functions
TEST_CASE
(
lstm_forward_actv_func
)
TEST_CASE
(
lstm_forward_actv_func
)
{
{
...
@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
...
@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
}
}
}
}
TEST_CASE
(
lstm_reverse_layout
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
1
;
// num directions
float
clip
=
0.0
f
;
int
input_forget
=
1
;
migraphx
::
shape
seq_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
sl
,
is
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
is
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
hs
}};
migraphx
::
shape
bias_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
8
*
hs
}};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
nd
,
hs
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
}};
// 8 args, hs output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
auto
prog
=
optimize_onnx
(
"lstm_r_layout_test.onnx"
);
EXPECT
(
p
==
prog
);
}
// 8 args, last and cell output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
auto
last_cell
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_cell
);
auto
prog
=
optimize_onnx
(
"lstm_r_layout_hs_cell_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
lstm_bidirectional
)
TEST_CASE
(
lstm_bidirectional
)
{
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
sl
=
5
;
// sequence len
...
@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
...
@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
}
}
}
}
TEST_CASE
(
lstm_bidirectional_layout
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
2
;
// num directions
float
clip
=
0.0
f
;
int
input_forget
=
1
;
migraphx
::
shape
seq_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
sl
,
is
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
is
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
hs
}};
migraphx
::
shape
bias_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
8
*
hs
}};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
nd
,
hs
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
}};
// 0 activation function
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
auto
prog
=
optimize_onnx
(
"lstm_bi_layout_last_test.onnx"
);
EXPECT
(
p
==
prog
);
}
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_cell
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_cell
);
auto
prog
=
optimize_onnx
(
"lstm_bi_layout_cell_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
lstm_bi_actv_funcs
)
TEST_CASE
(
lstm_bi_actv_funcs
)
{
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
sl
=
5
;
// sequence len
...
...
test/py/onnx_backend_test.py
View file @
61775eab
...
@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
...
@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails
# fails
# from OnnxBackendNodeModelTest
# from OnnxBackendNodeModelTest
backend_test
.
exclude
(
r
'test_gru_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_gru_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_lstm_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_simple_rnn_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_simple_rnn_batchwise_cpu'
)
# from OnnxBackendPyTorchConvertedModelTest
# from OnnxBackendPyTorchConvertedModelTest
backend_test
.
exclude
(
r
'test_MaxPool1d_stride_padding_dilation_cpu'
)
backend_test
.
exclude
(
r
'test_MaxPool1d_stride_padding_dilation_cpu'
)
...
...
test/quantization.cpp
View file @
61775eab
...
@@ -638,11 +638,10 @@ TEST_CASE(dot_float)
...
@@ -638,11 +638,10 @@ TEST_CASE(dot_float)
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
auto
quant_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
quant_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
quant_a
,
quant_b
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
quant_a
,
quant_b
);
std
::
vector
<
float
>
vec
(
sc
.
elements
(),
100.0
f
);
auto
scale_mb
=
mm
->
add_instruction
(
auto
dc
=
mm
->
add_literal
(
100.0
f
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale
);
auto
mdc
=
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sc
.
lens
()}}),
dc
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
mdc
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
return
p
;
return
p
;
...
@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
...
@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto
pa
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
pa
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
pb
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
pb
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
scale_a
=
mm
->
add_literal
(
10.0
);
auto
scale_a
_lit
=
mm
->
add_literal
(
10.0
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
scale_a
=
mm
->
add_instruction
(
auto
scale_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale_a
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale_a
_lit
);
auto
zp_a
=
auto
zp_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
auto
qa
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pa
,
scale_a
,
zp_a
);
auto
qa
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pa
,
scale_a
,
zp_a
);
auto
scale_b
=
mm
->
add_literal
(
5.0
);
auto
scale_b
_lit
=
mm
->
add_literal
(
5.0
);
scale_b
=
mm
->
add_instruction
(
auto
scale_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale_b
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale_b
_lit
);
auto
zp_b
=
auto
zp_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
auto
qb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
qb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
scale
=
mm
->
add_literal
(
50.0
);
auto
scale_a_mb
=
mm
->
add_instruction
(
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale
);
scale_a_lit
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
scale
);
auto
scale_b_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale_b_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_a_mb
,
scale_b_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
out_scale
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
return
p
;
return
p
;
};
};
...
@@ -799,18 +802,15 @@ TEST_CASE(dot_half_1arg)
...
@@ -799,18 +802,15 @@ TEST_CASE(dot_half_1arg)
auto
x
=
mm
->
add_parameter
(
"x"
,
sa
);
auto
x
=
mm
->
add_parameter
(
"x"
,
sa
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
10.0
}));
auto
scale
_lit
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
10.0
}));
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
auto
scale
=
mm
->
add_instruction
(
scale
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale
_lit
);
zp
=
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
auto
qx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
scale
,
zp
);
auto
qx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
scale
,
zp
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qx
,
qx
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qx
,
qx
);
auto
dq_scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
100.0
}));
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale
,
scale
);
dq_scale
=
mm
->
add_instruction
(
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
out_scale
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
dq_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
dq_scale
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
return
p
;
return
p
;
};
};
...
@@ -852,9 +852,9 @@ TEST_CASE(conv_float)
...
@@ -852,9 +852,9 @@ TEST_CASE(conv_float)
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
10.0
f
);
auto
scale
_lit
=
mm
->
add_literal
(
10.0
f
);
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
auto
scale
=
mm
->
add_instruction
(
scale
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
_lit
);
zp
=
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
...
@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
...
@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
migraphx
::
shape
sc
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
auto
scale_mb
=
mm
->
add_instruction
(
std
::
vector
<
float
>
vec
(
sc
.
elements
(),
100.0
f
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
migraphx
::
shape
s_scale
{
migraphx
::
shape
::
float_type
,
sc
.
lens
()};
scale_lit
);
auto
d_scale
=
mm
->
add_literal
(
100.0
f
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
d_scale
=
mm
->
add_instruction
(
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
4
,
1
,
1
}}}),
d_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
d_scale
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
return
p
;
return
p
;
...
@@ -931,19 +929,20 @@ TEST_CASE(conv_half)
...
@@ -931,19 +929,20 @@ TEST_CASE(conv_half)
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
10.0
}));
auto
scale
_lit
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
10.0
}));
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
auto
scale
=
mm
->
add_instruction
(
scale
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
_lit
);
zp
=
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
auto
quant_w
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pw
,
scale
,
zp
);
auto
quant_w
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pw
,
scale
,
zp
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
auto
d_scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
100.0
}));
auto
scale_mb
=
mm
->
add_instruction
(
d_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
4
,
1
,
1
}}}),
d_scale
);
scale_lit
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
d_scale
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
mm
->
add_return
({
r
});
return
p
;
return
p
;
...
@@ -1187,9 +1186,9 @@ TEST_CASE(int8_subgraph)
...
@@ -1187,9 +1186,9 @@ TEST_CASE(int8_subgraph)
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sy
.
lens
()}}),
zp1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sy
.
lens
()}}),
zp1
);
auto
qb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
b
,
sb
,
zpb
);
auto
qb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
b
,
sb
,
zpb
);
auto
qdot
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
qdot
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
s
o
=
then_mod
->
add_
literal
(
100.0
f
);
auto
s
1_mb
=
then_mod
->
add_
instruction
(
so
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
s1
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sout
.
lens
()}}),
so
);
auto
so
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
s1_mb
,
s1_mb
);
auto
r
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
so
);
auto
r
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
so
);
then_mod
->
add_return
({
r
});
then_mod
->
add_return
({
r
});
...
@@ -1199,23 +1198,24 @@ TEST_CASE(int8_subgraph)
...
@@ -1199,23 +1198,24 @@ TEST_CASE(int8_subgraph)
auto
w
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
w
=
mm
->
add_parameter
(
"w"
,
sw
);
// else submod
// else submod
auto
*
else_mod
=
p
.
create_module
(
"If_6_else"
);
auto
*
else_mod
=
p
.
create_module
(
"If_6_else"
);
auto
sax
=
else_mod
->
add_literal
(
2.0
f
);
auto
sax
_lit
=
else_mod
->
add_literal
(
2.0
f
);
auto
zp
=
else_mod
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
zp
=
else_mod
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
sax
=
else_mod
->
add_instruction
(
auto
sax
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
sax
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
sax
_lit
);
auto
zpx
=
else_mod
->
add_instruction
(
auto
zpx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
zp
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
zp
);
auto
qx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
sax
,
zpx
);
auto
qx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
sax
,
zpx
);
auto
ssw
=
else_mod
->
add_literal
(
1.66667
f
);
auto
ssw
_lit
=
else_mod
->
add_literal
(
1.66667
f
);
ssw
=
else_mod
->
add_instruction
(
auto
ssw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
ssw
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
ssw
_lit
);
auto
zpw
=
else_mod
->
add_instruction
(
auto
zpw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
zp
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
zp
);
auto
qw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
w
,
ssw
,
zpw
);
auto
qw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
w
,
ssw
,
zpw
);
auto
qconv
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
qx
,
qw
);
auto
qconv
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
qx
,
qw
);
auto
so1
=
else_mod
->
add_literal
(
3.33333
f
);
auto
ssw_mb
=
else_mod
->
add_instruction
(
so1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qconv
->
get_shape
().
lens
()}}),
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sout
.
lens
()}}),
so1
);
ssw_lit
);
auto
so1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sax
,
ssw_mb
);
auto
r1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qconv
,
so1
);
auto
r1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qconv
,
so1
);
else_mod
->
add_return
({
r1
});
else_mod
->
add_return
({
r1
});
...
...
test/ref/rnn_ops.cpp
View file @
61775eab
...
@@ -3228,6 +3228,264 @@ TEST_CASE(lstm_forward)
...
@@ -3228,6 +3228,264 @@ TEST_CASE(lstm_forward)
}
}
}
}
TEST_CASE
(
lstm_forward_layout
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
0.1236
,
-
0.3942
,
0.4149
,
0.0795
,
0.4934
,
-
0.2858
,
0.2602
,
-
0.3098
,
0.0567
,
0.3344
,
0.3607
,
-
0.0551
,
0.4952
,
0.3799
,
0.0630
,
-
0.3532
,
0.0023
,
-
0.0592
,
0.4267
,
0.2382
,
-
0.0784
,
-
0.0032
,
-
0.2476
,
-
0.0206
,
-
0.4963
,
0.4837
,
0.0827
,
0.0123
,
-
0.1203
,
-
0.0279
,
-
0.0049
,
0.4721
,
-
0.3564
,
-
0.1286
,
0.4090
,
-
0.0504
,
0.0575
,
-
0.2138
,
0.1071
,
0.1976
,
-
0.0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4845
,
-
0.1496
,
0.3285
};
std
::
vector
<
float
>
r_data
{
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
};
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.3577
,
1.3508
,
-
0.5366
,
0.4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.1332
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.4012
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
1.0972
,
0.9816
,
0.1122
,
-
0.4100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// forward, hidden state concatenation as output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
und
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
hs_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
hs_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
hs_data_gold
{
0.0417273
,
-
0.272355
,
0.206765
,
0.223879
,
0.0742487
,
-
0.0800085
,
0.259897
,
0.0670196
,
-
0.00532985
,
0.0440265
,
0.29654
,
-
0.0463156
,
-
0.0847427
,
0.0874114
,
0.304256
,
-
0.0585745
,
0.138193
,
-
0.0322939
,
-
0.0891815
,
0.15773
,
0.184266
,
0.0610048
,
-
0.138041
,
0.0963885
,
0.0498799
,
0.125772
,
0.0533032
,
-
0.131413
,
-
0.0223018
,
0.131113
,
0.135643
,
-
0.056620
,
0.19139
,
-
0.127708
,
-
0.409371
,
-
0.136186
,
0.0213755
,
-
0.146027
,
-
0.0324509
,
-
0.0620429
,
0.0988431
,
-
0.018085
,
-
0.159434
,
0.030266
,
0.142701
,
0.0342236
,
-
0.198664
,
0.0702607
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
hs_data
,
hs_data_gold
));
}
// forward, last_output as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
und
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
last_hs
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
last_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0847427
,
0.0874114
,
0.304256
,
-
0.0585745
,
-
0.0223018
,
0.131113
,
0.135643
,
-
0.0566208
,
0.142701
,
0.0342236
,
-
0.198664
,
0.0702607
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// forward, last_cell_output as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
und
);
auto
cell_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
cell_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
last_hs
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
last_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.111454
,
0.247794
,
0.471087
,
-
0.220574
,
-
0.048196
,
0.263184
,
0.283258
,
-
0.14882
,
0.605585
,
0.078598
,
-
0.64457
,
0.119811
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
TEST_CASE
(
lstm_forward_more
)
TEST_CASE
(
lstm_forward_more
)
{
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
batch_size
=
3
;
...
@@ -3519,7 +3777,7 @@ TEST_CASE(lstm_forward_more)
...
@@ -3519,7 +3777,7 @@ TEST_CASE(lstm_forward_more)
}
}
}
}
TEST_CASE
(
lstm_
reverse
)
TEST_CASE
(
lstm_
forward_more_layout
)
{
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
seq_len
=
4
;
...
@@ -3527,26 +3785,978 @@ TEST_CASE(lstm_reverse)
...
@@ -3527,26 +3785,978 @@ TEST_CASE(lstm_reverse)
std
::
size_t
input_size
=
3
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
std
::
vector
<
float
>
w_data
{
-
0.
2763
,
-
0.
4715
,
-
0.
3010
,
-
0.
2306
,
-
0.
2283
,
-
0.2
656
,
0.2
035
,
0.3
570
,
-
0.1499
,
0.4390
,
0.
1236
,
-
0.
3942
,
0.
4149
,
0.
0795
,
0.
4934
,
-
0.2
858
,
0.2
602
,
-
0.3
098
,
0.0567
,
0.3344
,
-
0.
1843
,
0.23
51
,
0.
3357
,
0.
1217
,
0.
1401
,
0.3
300
,
-
0.0429
,
0.3266
,
0.4
834
,
-
0.
3914
,
0.
3607
,
-
0.05
51
,
0.
4952
,
0.
3799
,
0.
0630
,
-
0.3
532
,
0.0023
,
-
0.0592
,
0.4
267
,
0.
2382
,
-
0.
1480
,
0.373
4
,
-
0.03
7
2
,
-
0.
174
6
,
0.0
550
,
0.4
17
7
,
-
0.
1332
,
0.4391
,
-
0.
3287
,
-
0.
4401
,
-
0.
078
4
,
-
0.
0
032
,
-
0.
247
6
,
-
0.0
206
,
-
0.4963
,
0.4
83
7
,
0.
0827
,
0.0123
,
-
0.
1203
,
-
0.
0279
,
0.
1486
,
0.1346
,
0.
1048
,
-
0.4361
,
0.0886
,
-
0.
3840
,
-
0.
2730
,
-
0.1710
,
0.3274
,
0.
0169
,
-
0.
0049
,
0.4721
,
-
0.
3564
,
-
0.1286
,
0.4090
,
-
0.
0504
,
0.
0575
,
-
0.2138
,
0.1071
,
0.
1976
,
-
0.
4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4
150
,
-
0.4684
,
-
0.
2522
};
-
0.
0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4
845
,
-
0.1496
,
0.
3285
};
std
::
vector
<
float
>
r_data
{
std
::
vector
<
float
>
r_data
{
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
};
std
::
vector
<
float
>
bias_data
{
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.3577
,
1.3508
,
-
0.5366
,
0.4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.1332
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.4012
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
1.0972
,
0.9816
,
0.1122
,
-
0.4100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// forward, 3 args
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
last_hs
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
last_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.071286
,
-
0.102509
,
-
0.0372696
,
0.252296
,
-
0.144544
,
-
0.165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.074206
,
0.0124086
,
-
0.139544
,
0.108016
,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.187329
,
-
0.0458544
,
-
0.0401315
,
0.0737483
,
-
0.064505
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.00973633
,
-
0.0552699
,
0.0252681
,
-
0.0562072
,
0.0855831
,
-
0.0171894
,
-
0.140202
,
0.0828391
,
0.136898
,
0.00160891
,
-
0.184812
,
0.147774
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// forward, 8 args
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
hs_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
hs_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
hs_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
0.0459033
,
0.0414126
,
0.272303
,
0.0393149
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
hs_data
,
hs_data_gold
));
}
// forward, last_output as program output, sequence length shorter
// than max_seq_len
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq_orig
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
2
,
input_size
}};
std
::
vector
<
float
>
pad_data
(
pad_seq_s
.
elements
(),
0.0
f
);
auto
seq_p
=
mm
->
add_literal
(
migraphx
::
literal
{
pad_seq_s
,
pad_data
});
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
seq_orig
,
seq_p
);
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
(
batch_size
,
static_cast
<
int32_t
>
(
seq_len
));
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
sql
,
ih
,
ic
,
und
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
last_hs
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
last_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0847427
,
0.0874114
,
0.304256
,
-
0.0585745
,
-
0.0223018
,
0.131113
,
0.135643
,
-
0.0566208
,
0.142701
,
0.0342236
,
-
0.198664
,
0.0702607
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// seq_len = 1
{
seq_len
=
1
;
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
hs_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
hs_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
hs_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
hs_data
,
hs_data_gold
));
}
}
// Reference tests for the single-direction reverse-mode "lstm" operator.
// Each sub-test builds a small program on the "ref" target with fixed
// weight/input literals and checks the output against precomputed gold values.
TEST_CASE(lstm_reverse)
{
    std::size_t batch_size  = 3;
    std::size_t seq_len     = 4;
    std::size_t hidden_size = 4;
    std::size_t input_size  = 3;
    std::size_t num_dirct   = 1; // single direction (reverse)
    // Input-to-hidden weights W: {num_dirct, 4 * hidden_size, input_size}
    std::vector<float> w_data{
        -0.2763, -0.4715, -0.3010, -0.2306, -0.2283, -0.2656, 0.2035,  0.3570,
        -0.1499, 0.4390,  -0.1843, 0.2351,  0.3357,  0.1217,  0.1401,  0.3300,
        -0.0429, 0.3266,  0.4834,  -0.3914, -0.1480, 0.3734,  -0.0372, -0.1746,
        0.0550,  0.4177,  -0.1332, 0.4391,  -0.3287, -0.4401, 0.1486,  0.1346,
        0.1048,  -0.4361, 0.0886,  -0.3840, -0.2730, -0.1710, 0.3274,  0.0169,
        -0.4462, 0.0729,  0.3983,  -0.0669, 0.0756,  0.4150,  -0.4684, -0.2522};
    // Hidden-to-hidden (recurrence) weights R: {num_dirct, 4 * hidden_size, hidden_size}
    std::vector<float> r_data{
        -0.4564, -0.4432, 0.1605,  0.4387,  0.0034,  0.4116,  0.2824,  0.4775,
        -0.2729, -0.4707, 0.1363,  0.2218,  0.0559,  0.2828,  0.2093,  0.4687,
        0.3794,  -0.1069, -0.3049, 0.1430,  -0.2506, 0.4644,  0.2755,  -0.3645,
        -0.3155, 0.1425,  0.2891,  0.1786,  -0.3274, 0.2365,  0.2522,  -0.4312,
        -0.0562, -0.2748, 0.0776,  -0.3154, 0.2851,  -0.3930, -0.1174, 0.4360,
        0.2436,  0.0164,  -0.0680, 0.3403,  -0.2857, -0.0459, -0.2991, -0.2624,
        0.4194,  -0.3291, -0.4659, 0.3300,  0.0454,  0.4981,  -0.4706, -0.4584,
        0.2596,  0.2871,  -0.3509, -0.1910, 0.3987,  -0.1687, -0.0032, -0.1038};
    // Combined input+recurrence bias: {num_dirct, 8 * hidden_size}
    std::vector<float> bias_data{
        -0.0258, 0.0073,  -0.4780, -0.4101, -0.3556, -0.1017, 0.3632,  -0.1823,
        0.1479,  0.1677,  -0.2603, 0.0381,  0.1575,  0.1896,  0.4755,  -0.4794,
        0.2167,  -0.4474, -0.3139, 0.1018,  0.4470,  -0.4232, 0.3247,  -0.1636,
        -0.1582, -0.1703, 0.3920,  0.2055,  -0.4386, 0.4208,  0.0717,  0.3789};
    // Input sequence X laid out {seq_len, batch_size, input_size} (seq-major)
    std::vector<float> input_data{
        -0.5516, 0.2391,  -1.6951, -0.4313, -0.9730, -0.2005, 2.3930,  -0.5221,
        -0.1331, -0.0910, 1.2122,  -0.1952, 0.4661,  0.6494,  2.1332,  -1.0972,
        0.9816,  0.1122,  0.3577,  1.3508,  -0.5366, 1.7449,  0.5483,  -0.0701,
        -0.4100, -2.2344, 0.3685,  0.4583,  2.3794,  1.0372,  -0.8887, 0.7892,
        -0.4012, -0.2818, -2.3374, 1.5310};
    // Initial hidden state: {num_dirct, batch_size, hidden_size}
    std::vector<float> ih_data{1.5289,
                               1.0986,
                               0.6091,
                               1.6462,
                               0.8720,
                               0.5349,
                               -0.1962,
                               -1.7416,
                               -0.9912,
                               1.2831,
                               1.0896,
                               -0.6959};
    // Initial cell state: {num_dirct, batch_size, hidden_size}
    std::vector<float> ic_data{-0.8323,
                               0.3998,
                               0.1831,
                               0.5938,
                               2.7096,
                               -0.1790,
                               0.0022,
                               -0.8040,
                               0.1578,
                               0.0567,
                               0.8069,
                               -0.5141};
    // Peephole weights P: {num_dirct, 3 * hidden_size}
    std::vector<float> pph_data{-0.8271,
                                -0.5683,
                                0.4562,
                                -1.2545,
                                1.2729,
                                -0.4082,
                                -0.4392,
                                -0.9406,
                                0.7794,
                                1.8194,
                                -0.5811,
                                0.2166};
    migraphx::shape in_shape{migraphx::shape::float_type, {seq_len, batch_size, input_size}};
    migraphx::shape w_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, input_size}};
    migraphx::shape r_shape{migraphx::shape::float_type, {num_dirct, 4 * hidden_size, hidden_size}};
    migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
    migraphx::shape ih_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape ic_shape{migraphx::shape::float_type, {num_dirct, batch_size, hidden_size}};
    migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};
    float clip = 0.0f; // 0 disables cell clipping

    // reverse, concatenation of hidden states as program output
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic   = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph  = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        // Placeholder for the unused sequence-length argument (slot 5).
        auto und  = mm->add_instruction(migraphx::make_op("undefined"));
        mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{
            -0.120174,   0.043157,  0.117138,  -0.222188, 0.789732,  0.128538,
            0.20909,     0.0553812, -0.224905, 0.32421,   0.344048,  0.271694,
            -0.175114,   -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
            0.220626,    -0.0432316, -0.063456, 0.148524, 0.05108,   -0.0234895,
            -0.182201,   -0.0232277, 0.235501, -0.213485, 0.960938,  0.133565,
            0.269741,    0.130438,  -0.0252804, 0.267356, 0.146353,  0.0789186,
            -0.185038,   -0.026845, 0.177273,  -0.0774616, 0.946669, 0.0868676,
            0.044508,    -0.373961, -0.0681467, 0.382748, 0.230211,  -0.161537};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }

    // reverse, sequence lengths are the same, but less than max_seq_lens
    {
        migraphx::program p;
        auto* mm      = p.get_main_module();
        auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto ih       = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic       = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto w        = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r        = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias     = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph      = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        // Pad the sequence with two extra all-zero time steps so that
        // max_seq_len (6) exceeds the actual sequence length (4).
        migraphx::shape pad_seq_s{migraphx::shape::float_type, {2, batch_size, input_size}};
        std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
        auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
        auto seq   = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), seq_orig, seq_p);
        // Per-batch sequence lengths, all equal to the real seq_len.
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
        auto sql = mm->add_literal(seq_len_s, len_data);
        mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            sql,
            ih,
            ic,
            pph);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        // Same gold values as the unpadded run, followed by zeros for the
        // two padded time steps.
        std::vector<float> output_data_gold{
            -0.120174,   0.043157,  0.117138,  -0.222188, 0.789732,  0.128538,
            0.20909,     0.0553812, -0.224905, 0.32421,   0.344048,  0.271694,
            -0.175114,   -0.00543549, 0.178681, -0.266999, 0.928866, 0.113685,
            0.220626,    -0.0432316, -0.063456, 0.148524, 0.05108,   -0.0234895,
            -0.182201,   -0.0232277, 0.235501, -0.213485, 0.960938,  0.133565,
            0.269741,    0.130438,  -0.0252804, 0.267356, 0.146353,  0.0789186,
            -0.185038,   -0.026845, 0.177273,  -0.0774616, 0.946669, 0.0868676,
            0.044508,    -0.373961, -0.0681467, 0.382748, 0.230211,  -0.161537,
            0.0,         0.0,       0.0,       0.0,       0.0,       0.0,
            0.0,         0.0,       0.0,       0.0,       0.0,       0.0,
            0.0,         0.0,       0.0,       0.0,       0.0,       0.0,
            0.0,         0.0,       0.0,       0.0,       0.0,       0.0};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }

    // variable sequence lengths
    {
        migraphx::program p;
        auto* mm  = p.get_main_module();
        auto seq  = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
        auto ic   = mm->add_literal(migraphx::literal{ic_shape, ic_data});
        auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
        auto pph  = mm->add_literal(migraphx::literal{pph_shape, pph_data});
        // Different valid length per batch element: 3, 2 and 1 time steps.
        migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
        std::vector<int32_t> len_data{3, 2, 1};
        auto sql = mm->add_literal(seq_len_s, len_data);
        mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r,
            bias,
            sql,
            ih,
            ic,
            pph);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        // Steps beyond each element's sequence length are zero-filled.
        std::vector<float> output_data_gold{
            -0.126517, 0.0359124,  0.107453, -0.0617278, 0.911307, 0.11468,
            0.114449,  0.0196755,  -0.102969, 0.295872,  0.515859, 0.246501,
            -0.168327, 0.00023761, 0.167567, -0.0621982, 0.96657,  0.0755112,
            0.0620917, -0.264845,  0,        0,          0,        0,
            -0.204545, 0.0146403,  0.210057, 0.0296268,  0,        0,
            0,         0,          0,        0,          0,        0,
            0,         0,          0,        0,          0,        0,
            0,         0,          0,        0,          0,        0};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }

    // reverse, 3 args, last cell output as program output
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w   = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r   = mm->add_literal(migraphx::literal{r_shape, r_data});
        auto hs  = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r);
        // Extract only the final cell state from the full hidden-state output.
        mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{-0.443077,
                                            -0.325425,
                                            -0.249367,
                                            -0.270812,
                                            0.122913,
                                            0.118537,
                                            0.0370199,
                                            -0.0164687,
                                            -0.00754759,
                                            0.141613,
                                            0.348002,
                                            0.667298};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }

    // reverse, 3 args, 0 actv function
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
        auto w   = mm->add_literal(migraphx::literal{w_shape, w_data});
        auto r   = mm->add_literal(migraphx::literal{r_shape, r_data});
        // Empty actv_func: the op should fall back to its default
        // activations, matching the explicit sigmoid/tanh/tanh test above.
        auto hs  = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func", {}},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip},
                 {"input_forget", 0}}),
            seq,
            w,
            r);
        mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        p.compile(migraphx::make_target("ref"));
        auto hs_concat = p.eval({}).back();
        std::vector<float> output_data;
        hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });
        std::vector<float> output_data_gold{-0.443077,
                                            -0.325425,
                                            -0.249367,
                                            -0.270812,
                                            0.122913,
                                            0.118537,
                                            0.0370199,
                                            -0.0164687,
                                            -0.00754759,
                                            0.141613,
                                            0.348002,
                                            0.667298};
        EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
    }
}
TEST_CASE
(
lstm_reverse_layout
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.3577
,
1.3508
,
-
0.5366
,
0.4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.1332
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.4012
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
1.0972
,
0.9816
,
0.1122
,
-
0.4100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.5289
,
1.0986
,
0.6091
,
1.6462
,
0.8720
,
0.5349
,
-
0.1962
,
-
1.7416
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
-
0.8323
,
0.3998
,
0.1831
,
0.5938
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.8040
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
float
clip
=
0.0
f
;
// reverse, concatenation of hidden states as program output
// reverse, concatenation of hidden states as program output
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // Feed every lstm input as a literal so the program can be evaluated with no
    // runtime parameters.
    auto seq  = mm->add_literal(migraphx::literal{in_shape, input_data});
    auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
    auto ic   = mm->add_literal(migraphx::literal{ic_shape, ic_data});
    auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
    auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
    auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
    auto pph  = mm->add_literal(migraphx::literal{pph_shape, pph_data});
    // Placeholder for the unused sequence_lens input.
    auto und = mm->add_instruction(migraphx::make_op("undefined"));
    // Batch-first literals -> seq-first layout expected by the lstm op.
    std::vector<int64_t> perm{1, 0, 2};
    seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
    ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
    ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
    auto hs = mm->add_instruction(
        migraphx::make_op(
            "lstm",
            {{"hidden_size", hidden_size},
             {"actv_func",
              migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                  migraphx::make_op("tanh"),
                                                                  migraphx::make_op("tanh")})},
             {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
             {"clip", clip},
             {"input_forget", 0}}),
        seq,
        w,
        r,
        bias,
        und,
        ih,
        ic,
        pph);
    // Permute the full hidden-state output with {2, 0, 1, 3} so the program returns it
    // in the batch-first layout the gold data was generated for.
    std::vector<int64_t> perm_hid{2, 0, 1, 3};
    mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
    p.compile(migraphx::make_target("ref"));

    auto hs_concat = p.eval({}).back();
    std::vector<float> output_data;
    hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });

    std::vector<float> output_data_gold{
        -0.120174, 0.043157,   0.117138,  -0.222188, -0.175114,  -0.00543549, 0.178681,
        -0.266999, -0.182201,  -0.0232277, 0.235501, -0.213485,  -0.185038,   -0.026845,
        0.177273,  -0.0774616, 0.789732,  0.128538,  0.20909,    0.0553812,   0.928866,
        0.113685,  0.220626,   -0.0432316, 0.960938, 0.133565,   0.269741,    0.130438,
        0.946669,  0.0868676,  0.044508,  -0.373961, -0.224905,  0.32421,     0.344048,
        0.271694,  -0.063456,  0.148524,  0.05108,   -0.0234895, -0.0252804,  0.267356,
        0.146353,  0.0789186,  -0.0681467, 0.382748, 0.230211,   -0.161537};
    EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, sequence lengths are the same, but less than max_seq_lens
// reverse, sequence lengths are the same, but less than max_seq_lens
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto seq_orig = mm->add_literal(migraphx::literal{in_shape, input_data});
    auto ih       = mm->add_literal(migraphx::literal{ih_shape, ih_data});
    auto ic       = mm->add_literal(migraphx::literal{ic_shape, ic_data});
    auto w        = mm->add_literal(migraphx::literal{w_shape, w_data});
    auto r        = mm->add_literal(migraphx::literal{r_shape, r_data});
    auto bias     = mm->add_literal(migraphx::literal{b_shape, bias_data});
    auto pph      = mm->add_literal(migraphx::literal{pph_shape, pph_data});
    // Pad the (batch-first) sequence with two extra zero time steps on axis 1 so that
    // max_seq_len (= seq_len + 2) exceeds every actual sequence length.
    migraphx::shape pad_seq_s{migraphx::shape::float_type, {batch_size, 2, input_size}};
    std::vector<float> pad_data(pad_seq_s.elements(), 0.0f);
    auto seq_p = mm->add_literal(migraphx::literal{pad_seq_s, pad_data});
    auto seq   = mm->add_instruction(migraphx::make_op("concat", {{"axis", 1}}), seq_orig, seq_p);
    // Every batch entry uses the same (un-padded) length.
    migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
    std::vector<int32_t> len_data(batch_size, static_cast<int32_t>(seq_len));
    auto sql = mm->add_literal(seq_len_s, len_data);
    // Batch-first -> seq-first for the lstm op.
    std::vector<int64_t> perm{1, 0, 2};
    seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
    ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
    ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
    auto hs = mm->add_instruction(
        migraphx::make_op(
            "lstm",
            {{"hidden_size", hidden_size},
             {"actv_func",
              migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                  migraphx::make_op("tanh"),
                                                                  migraphx::make_op("tanh")})},
             {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
             {"clip", clip},
             {"input_forget", 0}}),
        seq,
        w,
        r,
        bias,
        sql,
        ih,
        ic,
        pph);
    // Return the hidden states in batch-first layout.
    std::vector<int64_t> perm_hid{2, 0, 1, 3};
    mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
    p.compile(migraphx::make_target("ref"));

    auto hs_concat = p.eval({}).back();
    std::vector<float> output_data;
    hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });

    // Padded time steps past each sequence length come back as zeros.
    std::vector<float> output_data_gold{
        -0.120174, 0.043157,   0.117138,   -0.222188, -0.175114, -0.00543549, 0.178681,
        -0.266999, -0.182201,  -0.0232277, 0.235501,  -0.213485, -0.185038,   -0.026845,
        0.177273,  -0.0774616, 0.0,        0.0,       0.0,       0.0,         0.0,
        0.0,       0.0,        0.0,        0.789732,  0.128538,  0.20909,     0.0553812,
        0.928866,  0.113685,   0.220626,   -0.0432316, 0.960938, 0.133565,    0.269741,
        0.130438,  0.946669,   0.0868676,  0.044508,  -0.373961, 0.0,         0.0,
        0.0,       0.0,        0.0,        0.0,       0.0,       0.0,         -0.224905,
        0.32421,   0.344048,   0.271694,   -0.063456, 0.148524,  0.05108,     -0.0234895,
        -0.0252804, 0.267356,  0.146353,   0.0789186, -0.0681467, 0.382748,   0.230211,
        -0.161537, 0.0,        0.0,        0.0,       0.0,       0.0,         0.0,
        0.0,       0.0};
    EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// variable sequence lengths
// variable sequence lengths
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto seq  = mm->add_literal(migraphx::literal{in_shape, input_data});
    auto ih   = mm->add_literal(migraphx::literal{ih_shape, ih_data});
    auto ic   = mm->add_literal(migraphx::literal{ic_shape, ic_data});
    auto w    = mm->add_literal(migraphx::literal{w_shape, w_data});
    auto r    = mm->add_literal(migraphx::literal{r_shape, r_data});
    auto bias = mm->add_literal(migraphx::literal{b_shape, bias_data});
    auto pph  = mm->add_literal(migraphx::literal{pph_shape, pph_data});
    // Each batch entry has its own sequence length: 3, 2 and 1 time steps.
    migraphx::shape seq_len_s{migraphx::shape::int32_type, {batch_size}};
    std::vector<int32_t> len_data{3, 2, 1};
    auto sql = mm->add_literal(seq_len_s, len_data);
    // Batch-first literals -> seq-first layout for the lstm op.
    std::vector<int64_t> perm{1, 0, 2};
    seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
    ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
    ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);
    auto hs = mm->add_instruction(
        migraphx::make_op(
            "lstm",
            {{"hidden_size", hidden_size},
             {"actv_func",
              migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                  migraphx::make_op("tanh"),
                                                                  migraphx::make_op("tanh")})},
             {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
             {"clip", clip},
             {"input_forget", 0}}),
        seq,
        w,
        r,
        bias,
        sql,
        ih,
        ic,
        pph);
    // Bring the hidden-state output back to batch-first layout.
    std::vector<int64_t> perm_hid{2, 0, 1, 3};
    mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm_hid}}), hs);
    p.compile(migraphx::make_target("ref"));

    auto hs_concat = p.eval({}).back();
    std::vector<float> output_data;
    hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });

    // Steps beyond a batch entry's own length are zero-filled.
    std::vector<float> output_data_gold{
        -0.126517, 0.0359124,  0.107453,  -0.0617278, -0.168327, 0.00023761, 0.167567,
        -0.0621982, -0.204545, 0.0146403, 0.210057,   0.0296268, 0,          0,
        0,         0,          0.911307,  0.11468,    0.114449,  0.0196755,  0.96657,
        0.0755112, 0.0620917,  -0.264845, 0,          0,          0,          0,
        0,         0,          0,         0,          -0.102969, 0.295872,   0.515859,
        0.246501,  0,          0,         0,          0,          0,          0,
        0,         0,          0,         0,          0,          0};
    EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
// reverse, 3 args, last cell output as program output
// reverse, 3 args, last cell output as program output
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    // Only the three mandatory lstm inputs are supplied here; bias, sequence_lens,
    // initial states and peepholes are omitted.
    auto seq = mm->add_literal(migraphx::literal{in_shape, input_data});
    auto w   = mm->add_literal(migraphx::literal{w_shape, w_data});
    auto r   = mm->add_literal(migraphx::literal{r_shape, r_data});
    // Batch-first input -> seq-first layout for the lstm op.
    std::vector<int64_t> perm{1, 0, 2};
    seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
    auto hs = mm->add_instruction(
        migraphx::make_op(
            "lstm",
            {{"hidden_size", hidden_size},
             {"actv_func",
              migraphx::to_value(std::vector<migraphx::operation>{migraphx::make_op("sigmoid"),
                                                                  migraphx::make_op("tanh"),
                                                                  migraphx::make_op("tanh")})},
             {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
             {"clip", clip},
             {"input_forget", 0}}),
        seq,
        w,
        r);
    // Extract the final cell state and return it transposed to batch-first layout.
    auto cell_output = mm->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
    mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), cell_output);
    p.compile(migraphx::make_target("ref"));

    auto hs_concat = p.eval({}).back();
    std::vector<float> output_data;
    hs_concat.visit([&](auto output) { output_data.assign(output.begin(), output.end()); });

    std::vector<float> output_data_gold{-0.443077,
                                        -0.325425,
                                        -0.249367,
                                        -0.270812,
                                        0.122913,
                                        0.118537,
                                        0.0370199,
                                        -0.0164687,
                                        -0.00754759,
                                        0.141613,
                                        0.348002,
                                        0.667298};
    EXPECT(migraphx::verify::verify_rms_range(output_data, output_data_gold));
}
}
// lstm activation function test
TEST_CASE
(
lstm_reverse_actv
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
,
...
@@ -3601,127 +4811,94 @@ TEST_CASE(lstm_reverse)
...
@@ -3601,127 +4811,94 @@ TEST_CASE(lstm_reverse)
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
// reverse, concatenation of hidden states as program output
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
to_value
(
migraphx
::
make_op
(
"tanh"
),
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
)})},
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
{
"input_forget"
,
0
}}),
seq
,
seq
,
w
,
w
,
r
,
r
);
bias
,
und
,
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.246078
,
0.199709
,
0.303753
,
0.301178
,
0.264634
,
0.304661
,
0.349371
,
0.288934
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.175114
,
-
0.00543549
,
0.405483
,
0.445586
,
0.515814
,
0.473186
,
0.301937
,
0.264893
,
0.254353
,
0.269231
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.359258
,
0.400097
,
0.288884
,
0.247329
,
0.276519
,
0.264249
,
0.1769
,
0.23213
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.310306
,
0.262902
,
0.276964
,
0.295002
,
0.373802
,
0.366785
,
0.419791
,
0.393216
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.262827
,
0.371441
,
0.369022
,
0.298262
,
0.334143
,
0.309444
,
0.174822
,
0.251634
,
0.0789186
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.244564
,
0.214386
,
0.185994
,
0.226699
,
0.28445
,
0.376092
,
0.338326
,
0.259502
};
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
// reverse,
sequence lengths are the same, but less than max_seq_le
ns
// reverse,
3 args, 2 actv functio
ns
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
seq
_orig
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
hs
=
mm
->
add_instruction
(
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
2
,
batch_size
,
input_size
}};
std
::
vector
<
float
>
pad_data
(
pad_seq_s
.
elements
(),
0.0
f
);
auto
seq_p
=
mm
->
add_literal
(
migraphx
::
literal
{
pad_seq_s
,
pad_data
});
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
seq_orig
,
seq_p
);
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
(
batch_size
,
static_cast
<
int32_t
>
(
seq_len
));
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
)})},
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
{
"input_forget"
,
0
}}),
seq
,
seq
,
w
,
w
,
r
,
r
);
bias
,
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
sql
,
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
-
0.132123
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.
20909
,
-
0.
37531
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.175114
,
-
0.00543549
,
-
0.12943
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
-
0.00798307
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.
2
13
485
,
-
0.13
3882
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.14635
3
,
-
0.025138
3
,
0.0789186
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868
67
6
,
0.0
4
86
4
86
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
,
0.0
,
-
0.0220606
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.292495
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.233866
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.48646
,
0.0
,
0.
0
};
0.
481844
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
//
variable sequence lengths
//
reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
{
{
seq_len
=
1
;
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
1
,
input_data
1
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
{
3
,
2
,
1
};
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
mm
->
add_instruction
(
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
...
@@ -3735,35 +4912,113 @@ TEST_CASE(lstm_reverse)
...
@@ -3735,35 +4912,113 @@ TEST_CASE(lstm_reverse)
{
"input_forget"
,
0
}}),
{
"input_forget"
,
0
}}),
seq
,
seq
,
w
,
w
,
r
,
r
);
bias
,
sql
,
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
-
0.104351
,
-
0.126517
,
0.0359124
,
0.107453
,
-
0.0617278
,
0.911307
,
0.11468
,
0.114449
,
-
0.0471426
,
0.0196755
,
-
0.102969
,
0.295872
,
0.515859
,
0.246501
,
-
0.168327
,
0.00023761
,
-
0.0905753
,
0.167567
,
-
0.0621982
,
0.96657
,
0.0755112
,
0.0620917
,
-
0.264845
,
0
,
0.01506
,
0
,
0
,
0
,
-
0.204545
,
0.0146403
,
0.210057
,
0.0296268
,
0.059797
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.104239
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
-
0.0266768
,
0
,
0
,
0
,
0
,
0
,
0
};
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
}
// reverse, 3 args, last cell output as program output
TEST_CASE
(
lstm_bidirectional
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
std
::
vector
<
float
>
w_data
{
0.1236
,
-
0.3942
,
0.4149
,
0.0795
,
0.4934
,
-
0.2858
,
0.2602
,
-
0.3098
,
0.0567
,
0.3344
,
0.3607
,
-
0.0551
,
0.4952
,
0.3799
,
0.0630
,
-
0.3532
,
0.0023
,
-
0.0592
,
0.4267
,
0.2382
,
-
0.0784
,
-
0.0032
,
-
0.2476
,
-
0.0206
,
-
0.4963
,
0.4837
,
0.0827
,
0.0123
,
-
0.1203
,
-
0.0279
,
-
0.0049
,
0.4721
,
-
0.3564
,
-
0.1286
,
0.4090
,
-
0.0504
,
0.0575
,
-
0.2138
,
0.1071
,
0.1976
,
-
0.0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4845
,
-
0.1496
,
0.3285
,
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
,
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
,
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.1122
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.4100
,
-
2.2344
,
0.3685
,
0.4583
,
2.3794
,
1.0372
,
-
0.8887
,
0.7892
,
-
0.4012
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
,
1.5289
,
1.0986
,
0.6091
,
1.6462
,
0.8720
,
0.5349
,
-
0.1962
,
-
1.7416
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
,
-
0.8323
,
0.3998
,
0.1831
,
0.5938
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.8040
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
,
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// concatenation of hidden states as program output
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
...
@@ -3771,242 +5026,177 @@ TEST_CASE(lstm_reverse)
...
@@ -3771,242 +5026,177 @@ TEST_CASE(lstm_reverse)
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
{
"input_forget"
,
0
}}),
seq
,
seq
,
w
,
w
,
r
);
r
,
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
bias
,
und
,
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.443077
,
std
::
vector
<
float
>
output_data_gold
{
-
0.325425
,
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
-
0.249367
,
0.104168
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.120174
,
0.043157
,
-
0.270812
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.122913
,
0.32421
,
0.344048
,
0.271694
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
0.118537
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
0.0370199
,
-
0.0413375
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
-
0.0164687
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.0459032
,
-
0.00754759
,
0.0414126
,
0.272303
,
0.0393149
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.141613
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.182201
,
-
0.0232277
,
0.235501
,
0.348002
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.667298
};
0.146353
,
0.0789186
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
//
reverse, 3 args, 0 actv function
//
last hidden state as program output
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
hs
=
mm
->
add_instruction
(
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
{}},
{
"actv_func"
,
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
{
"input_forget"
,
0
}}),
seq
,
seq
,
w
,
w
,
r
);
r
,
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
bias
,
und
,
ih
,
ic
,
pph
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.443077
,
std
::
vector
<
float
>
output_data_gold
{
-
0.325425
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
-
0.249367
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
-
0.270812
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
0.122913
,
0.118537
,
0.0370199
,
-
0.0164687
,
-
0.00754759
,
0.141613
,
0.348002
,
0.667298
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
}
// lstm activation function test
TEST_CASE
(
lstm_reverse_actv
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.1122
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.4100
,
-
2.2344
,
0.3685
,
0.4583
,
2.3794
,
1.0372
,
-
0.8887
,
0.7892
,
-
0.4012
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.5289
,
1.0986
,
0.6091
,
1.6462
,
0.8720
,
0.5349
,
-
0.1962
,
-
1.7416
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
-
0.8323
,
0.3998
,
0.1831
,
0.5938
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.8040
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
// last cell output as program output
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
float
clip
=
0.0
f
;
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
{
"actv_func"
,
migraphx
::
to_value
(
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
)})},
migraphx
::
make_op
(
"tanh"
),
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
{
"input_forget"
,
0
}}),
seq
,
seq
,
w
,
w
,
r
);
r
,
bias
,
und
,
ih
,
ic
,
pph
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
0.246078
,
0.199709
,
0.303753
,
0.301178
,
0.264634
,
0.304661
,
0.349371
,
0.288934
,
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
0.405483
,
0.445586
,
0.515814
,
0.473186
,
0.301937
,
0.264893
,
0.254353
,
0.269231
,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
0.359258
,
0.400097
,
0.288884
,
0.247329
,
0.276519
,
0.264249
,
0.1769
,
0.23213
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
0.310306
,
0.262902
,
0.276964
,
0.295002
,
0.373802
,
0.366785
,
0.419791
,
0.393216
,
0.262827
,
0.371441
,
0.369022
,
0.298262
,
0.334143
,
0.309444
,
0.174822
,
0.251634
,
0.244564
,
0.214386
,
0.185994
,
0.226699
,
0.28445
,
0.376092
,
0.338326
,
0.259502
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
//
reverse, 3 args, 2 actv functions
//
3 args, concatenation of hidden states as program output
{
{
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
)})},
migraphx
::
make_op
(
"tanh"
),
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
{
"input_forget"
,
0
}}),
seq
,
seq
,
w
,
w
,
r
);
r
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.132123
,
std
::
vector
<
float
>
output_data_gold
{
-
0.37531
,
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
-
0.12943
,
0.0598866
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.162851
,
-
0.102647
,
-
0.00798307
,
-
0.113827
,
-
0.142818
,
0.0513685
,
0.0547876
,
0.0201981
,
-
0.00808453
,
-
0.00520328
,
-
0.133882
,
0.0945081
,
0.264123
,
0.410805
,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.071286
,
-
0.0251383
,
0.074206
,
0.0124086
,
-
0.139544
,
0.108016
,
-
0.00973633
,
-
0.0552699
,
0.0252681
,
0.0486486
,
-
0.0562072
,
-
0.123496
,
-
0.153616
,
-
0.032874
,
-
0.195349
,
0.0192675
,
-
0.108636
,
-
0.0220606
,
0.098927
,
-
0.140733
,
0.162602
,
0.0143099
,
-
0.0455534
,
0.0151574
,
-
0.102509
,
0.292495
,
-
0.0372696
,
0.252296
,
-
0.144544
,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.187329
,
0.233866
,
0.0855831
,
-
0.0171894
,
-
0.140202
,
0.0828391
,
-
0.1073
,
-
0.150145
,
0.015065
,
0.48646
,
-
0.192699
,
-
0.112764
,
-
0.120496
,
0.155754
,
0.148256
,
0.208491
,
0.348432
,
0.481844
};
0.0291103
,
0.230275
,
-
0.165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
-
0.0458544
,
-
0.0401315
,
0.0737483
,
-
0.064505
,
0.136898
,
0.00160891
,
-
0.184812
,
0.147774
,
-
0.021205
,
-
0.125423
,
0.0206439
,
-
0.187097
,
-
0.0051453
,
-
0.0767618
,
-
0.0735348
,
-
0.0826436
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
//
reverse, 3 args, seq_len =
1, con
ca
tenation of hidden state
s
as program output
//
sequence length is
1, contenation of hidden state as program output
{
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
seq_len
=
1
;
seq_len
=
1
;
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
std
::
vector
<
float
>
input_data1
{
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
mm
->
add_instruction
(
...
@@ -4017,7 +5207,7 @@ TEST_CASE(lstm_reverse_actv)
...
@@ -4017,7 +5207,7 @@ TEST_CASE(lstm_reverse_actv)
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
{
"input_forget"
,
0
}}),
seq
,
seq
,
...
@@ -4027,23 +5217,16 @@ TEST_CASE(lstm_reverse_actv)
...
@@ -4027,23 +5217,16 @@ TEST_CASE(lstm_reverse_actv)
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.104351
,
std
::
vector
<
float
>
output_data_gold
{
-
0.0471426
,
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.0319021
,
-
0.00298698
,
-
0.0905753
,
-
0.0623361
,
0.0598866
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
0.01506
,
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.059797
,
0.104239
,
0.059797
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
0.104239
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
}
}
TEST_CASE
(
lstm_bidirectional
)
TEST_CASE
(
lstm_bidirectional
_layout
)
{
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
seq_len
=
4
;
...
@@ -4087,20 +5270,20 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4087,20 +5270,20 @@ TEST_CASE(lstm_bidirectional)
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.
4313
,
-
0.9730
,
-
0.2005
,
2
.3
930
,
-
0.5221
,
-
0.1331
,
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.
0910
,
1.2122
,
-
0.1952
,
0
.3
577
,
1.3508
,
-
0.5366
,
-
0.
0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.112
2
,
0.
4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.133
2
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.
07
01
,
-
0.4100
,
-
2.2344
,
0.
3685
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.
4
01
2
,
2.3930
,
-
0.5221
,
-
0.
1331
,
0.4583
,
2.3794
,
1.0
3
72
,
-
0.8887
,
0.
789
2
,
-
0.4
012
,
-
0.2818
,
-
2.3374
,
1.5310
};
-
1.0
9
72
,
0.9816
,
0.
112
2
,
-
0.4
100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
1.5289
,
1.0986
,
0.
4514
,
-
0.
8968
,
-
0.
9201
,
0.1962
,
0.5771
,
-
0.5332
,
0.
6091
,
1.6462
,
0.
5671
,
0.
0458
,
0.4514
,
-
0.8968
,
1.5289
,
1.0986
,
0.
6091
,
1.
6462
,
0.
87
20
,
0.
5349
,
0.8720
,
0.5349
,
-
0.
1962
,
-
1.
7416
,
-
0.
9
20
1
,
0.
1962
,
-
0.
1962
,
-
1.7416
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
0.
5771
,
-
0.5332
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
-
0.8323
,
0.3998
,
-
0.
9097
,
0.
7831
,
-
1.
6991
,
-
1.9498
,
-
1.256
7
,
-
0.
4114
,
0.
1831
,
0.
5938
,
1.
1055
,
-
0.1212
,
-
0.909
7
,
0.
7831
,
-
0.8323
,
0.3998
,
0.1831
,
0.
5938
,
2.7096
,
-
0.1790
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.
8040
,
-
1.6991
,
-
1.9498
,
0.0022
,
-
0.
8040
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
-
1.2567
,
-
0.
4114
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
...
@@ -4108,12 +5291,12 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4108,12 +5291,12 @@ TEST_CASE(lstm_bidirectional)
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
float
clip
=
0.0
f
;
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// concatenation of hidden states as program output
// concatenation of hidden states as program output
...
@@ -4128,7 +5311,13 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4128,7 +5311,13 @@ TEST_CASE(lstm_bidirectional)
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
...
@@ -4147,25 +5336,29 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4147,25 +5336,29 @@ TEST_CASE(lstm_bidirectional)
ih
,
ih
,
ic
,
ic
,
pph
);
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
-
0.120174
,
0.043157
,
0.117138
,
0.104168
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.120174
,
0.043157
,
-
0.222188
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
-
0.175114
,
-
0.00543549
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.178681
,
-
0.266999
,
0.0459032
,
0.0414126
,
0.272303
,
0.0393149
,
-
0.182201
,
0.32421
,
0.344048
,
0.271694
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
-
0.0232277
,
0.235501
,
-
0.213485
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
-
0.0413375
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.104168
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.421857
,
0.0459771
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.0459032
,
-
0.144955
,
0.0720673
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
0.218258
,
0.0414126
,
0.272303
,
0.0393149
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.182201
,
-
0.0232277
,
0.235501
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.946669
,
0.0868676
,
0.044508
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
-
0.373961
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.224905
,
0.32421
,
0.146353
,
0.0789186
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.344048
,
0.271694
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.063456
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
0.148524
,
0.05108
,
-
0.0234895
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
0.187761
,
0.0501726
,
-
0.121584
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
0.0606723
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
...
@@ -4181,6 +5374,12 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4181,6 +5374,12 @@ TEST_CASE(lstm_bidirectional)
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
...
@@ -4200,15 +5399,17 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4200,15 +5399,17 @@ TEST_CASE(lstm_bidirectional)
ih
,
ih
,
ic
,
ic
,
pph
);
pph
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
...
@@ -4224,6 +5425,12 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4224,6 +5425,12 @@ TEST_CASE(lstm_bidirectional)
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
...
@@ -4243,15 +5450,17 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4243,15 +5450,17 @@ TEST_CASE(lstm_bidirectional)
ih
,
ih
,
ic
,
ic
,
pph
);
pph
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
auto
cell_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
cell_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
...
@@ -4262,7 +5471,11 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4262,7 +5471,11 @@ TEST_CASE(lstm_bidirectional)
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
...
@@ -4276,25 +5489,28 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4276,25 +5489,28 @@ TEST_CASE(lstm_bidirectional)
seq
,
seq
,
w
,
w
,
r
);
r
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.
031902
1
,
-
0.
0
02
98698
,
-
0.0623361
,
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
-
0.
16285
1
,
-
0.
1
02
647
,
-
0.113827
,
0.
0598866
,
0.
101585
,
0.0687269
,
-
0.1
61725
,
-
0.25617
,
-
0.1
62851
,
-
0.1
02647
,
-
0.
142818
,
-
0.
0786602
,
-
0.0613048
,
0.1
79592
,
-
0.071286
,
-
0.1
23496
,
-
0.1
53616
,
-
0.
113827
,
-
0.1
42818
,
0.
0513685
,
0.054787
6
,
0.
0201981
,
-
0.00808453
,
-
0.00520328
,
-
0.
032874
,
-
0.1
95349
,
-
0.
102509
,
-
0.037269
6
,
0.
252296
,
-
0.144544
,
-
0.1073
,
0.
0945081
,
0.264123
,
0.
410805
,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.
071286
,
-
0.
150145
,
0.015065
,
-
0.
192699
,
-
0.165194
,
-
0.0372928
,
0.273786
,
-
0.
100877
,
0.0
74
20
6
,
0.0124086
,
-
0.1
39544
,
0.
108016
,
-
0.00
973633
,
-
0.0552699
,
0.025268
1
,
-
0.0
21
20
5
,
-
0.125423
,
0.0206439
,
-
0.1
87097
,
0.
0319021
,
-
0.00
298698
,
-
0.062336
1
,
-
0.05
62072
,
-
0.
123496
,
-
0.
15361
6
,
-
0.0
32874
,
-
0.195349
,
0.0192675
,
-
0.
10863
6
,
0.05
98866
,
0.
0513685
,
0.
054787
6
,
0.0
201981
,
-
0.00808453
,
0.074206
,
0.
012408
6
,
0.
098927
,
-
0.1
40733
,
0.162602
,
0.0143099
,
-
0.0
455534
,
0.0151574
,
-
0.
102509
,
-
0.
139544
,
0.1
08016
,
0.0192675
,
-
0.108636
,
0.0
98927
,
-
0.140733
,
0.
00496085
,
-
0.0
372696
,
0.252
29
6
,
-
0.1
44544
,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.1
87329
,
0.0
662588
,
-
0.048577
,
-
0.1873
29
,
-
0.1
12764
,
-
0.120496
,
0.155754
,
0.1
48256
,
0.0
855831
,
-
0.0
171894
,
-
0.140202
,
0.0828391
,
-
0.1073
,
-
0.
150
145
,
0.0
15065
,
-
0.0
458544
,
-
0.0
401315
,
0.0737483
,
-
0.064505
,
-
0.
005
145
3
,
-
0.0
767618
,
-
0.0735348
,
-
0.
192699
,
-
0.112764
,
-
0.
120496
,
0.155754
,
0.
148
256
,
0.
208491
,
0.348432
,
-
0.
0826436
,
0.101585
,
0.
0687269
,
-
0.161725
,
-
0.256
17
,
-
0.
00520328
,
0.0945081
,
0.
029110
3
,
0.
23027
5
,
-
0.
165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
-
0.
0458544
,
0.
26412
3
,
0.
41080
5
,
-
0.
00973633
,
-
0.0552699
,
0.0252681
,
-
0.0562072
,
0.
162602
,
-
0.0
401315
,
0.0737483
,
-
0.064505
,
0.136898
,
0.0
0160
89
1
,
-
0.1
8481
2
,
0.
147774
,
0.0
143099
,
-
0.0455534
,
0.0151574
,
0.0855831
,
-
0.0
171
89
4
,
-
0.1
4020
2
,
0.
0828391
,
-
0.
021205
,
-
0.
125423
,
0.0206439
,
-
0.187097
,
-
0.
0051453
,
-
0.0
767618
,
-
0.0735348
,
0.
208491
,
0.
348432
,
0.0291103
,
0.230275
,
0.
136898
,
0.0
0160891
,
-
0.184812
,
-
0.
0826436
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
0.
147774
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
...
@@ -4303,13 +5519,17 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4303,13 +5519,17 @@ TEST_CASE(lstm_bidirectional)
migraphx
::
program
p
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
seq_len
=
1
;
seq_len
=
1
;
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
std
::
vector
<
float
>
input_data1
{
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
migraphx
::
make_op
(
"lstm"
,
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{{
"hidden_size"
,
hidden_size
},
...
@@ -4323,15 +5543,19 @@ TEST_CASE(lstm_bidirectional)
...
@@ -4323,15 +5543,19 @@ TEST_CASE(lstm_bidirectional)
seq
,
seq
,
w
,
w
,
r
);
r
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.0319021
,
-
0.00298698
,
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
-
0.104351
,
-
0.0471426
,
-
0.0623361
,
0.0598866
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.0905753
,
0.01506
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.059797
,
0.104239
,
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
0.101585
,
0.0687269
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
-
0.161725
,
-
0.25617
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
}
}
...
@@ -4577,6 +5801,275 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
...
@@ -4577,6 +5801,275 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
}
}
}
}
TEST_CASE
(
lstm_bidirectional_var_seq_lens_layout
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
std
::
vector
<
float
>
w_data
{
0.1236
,
-
0.3942
,
0.4149
,
0.0795
,
0.4934
,
-
0.2858
,
0.2602
,
-
0.3098
,
0.0567
,
0.3344
,
0.3607
,
-
0.0551
,
0.4952
,
0.3799
,
0.0630
,
-
0.3532
,
0.0023
,
-
0.0592
,
0.4267
,
0.2382
,
-
0.0784
,
-
0.0032
,
-
0.2476
,
-
0.0206
,
-
0.4963
,
0.4837
,
0.0827
,
0.0123
,
-
0.1203
,
-
0.0279
,
-
0.0049
,
0.4721
,
-
0.3564
,
-
0.1286
,
0.4090
,
-
0.0504
,
0.0575
,
-
0.2138
,
0.1071
,
0.1976
,
-
0.0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4845
,
-
0.1496
,
0.3285
,
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
,
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
,
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.3577
,
1.3508
,
-
0.5366
,
0.4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.1332
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.4012
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
1.0972
,
0.9816
,
0.1122
,
-
0.4100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
1.5289
,
1.0986
,
0.6091
,
1.6462
,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
0.8720
,
0.5349
,
-
0.1962
,
-
1.7416
,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
-
0.8323
,
0.3998
,
0.1831
,
0.5938
,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.8040
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
,
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// concatenation of hidden states as program output
{
std
::
vector
<
int
>
sl_data
{
1
,
2
,
3
};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
sql
=
mm
->
add_literal
(
migraphx
::
literal
{
sl_shape
,
sl_data
});
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
sql
,
ih
,
ic
,
pph
);
auto
lho
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
auto
lco
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
lho
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
lho
);
lco
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
lco
);
mm
->
add_return
({
out_hs
,
lho
,
lco
});
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
outputs
=
p
.
eval
({});
auto
arg_hs
=
outputs
.
front
();
auto
arg_lho
=
outputs
.
at
(
1
);
auto
arg_lco
=
outputs
.
at
(
2
);
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
last_output_data
;
std
::
vector
<
float
>
last_cell_data
;
arg_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
arg_lho
.
visit
([
&
](
auto
output
)
{
last_output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
arg_lco
.
visit
([
&
](
auto
output
)
{
last_cell_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
-
0.141643
,
0.0451978
,
0.140804
,
0.0745128
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.911307
,
0.11468
,
0.114449
,
0.0196755
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.96657
,
0.0755112
,
0.0620917
,
-
0.264845
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.262807
,
0.275286
,
0.358395
,
0.266267
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.128254
,
0.125398
,
0.0665142
,
-
0.163651
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.0644683
,
0.371512
,
0.212431
,
-
0.116131
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
float
>
last_output_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
-
0.141643
,
0.0451978
,
0.140804
,
0.0745128
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.911307
,
0.11468
,
0.114449
,
0.0196755
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.262807
,
0.275286
,
0.358395
,
0.266267
};
std
::
vector
<
float
>
last_cell_data_gold
{
0.600582
,
-
0.601197
,
0.353558
,
0.789097
,
-
0.326822
,
0.301121
,
0.219523
,
0.415242
,
0.737121
,
0.134902
,
-
0.303595
,
0.241948
,
2.08242
,
0.442513
,
0.187127
,
0.0577626
,
0.391174
,
0.0308845
,
-
0.561745
,
0.0730323
,
-
0.611307
,
0.55454
,
0.4364
,
0.509436
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
last_output_data
,
last_output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
last_cell_data
,
last_cell_data_gold
));
}
// last cell output as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq_orig
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
2
,
input_size
}};
std
::
vector
<
float
>
pad_data
(
pad_seq_s
.
elements
(),
0.0
f
);
auto
seq_p
=
mm
->
add_literal
(
migraphx
::
literal
{
pad_seq_s
,
pad_data
});
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
seq_orig
,
seq_p
);
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
(
batch_size
,
static_cast
<
int32_t
>
(
seq_len
));
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
sql
,
ih
,
ic
,
pph
);
auto
lho
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
auto
lco
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
lho
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
lho
);
lco
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
lco
);
mm
->
add_return
({
hs
,
lho
,
lco
});
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
outputs
=
p
.
eval
({});
auto
res_hs
=
outputs
.
at
(
0
);
auto
res_lho
=
outputs
.
at
(
1
);
auto
res_lco
=
outputs
.
at
(
2
);
std
::
vector
<
float
>
hs_data
;
std
::
vector
<
float
>
lho_data
;
std
::
vector
<
float
>
lco_data
;
res_hs
.
visit
([
&
](
auto
output
)
{
hs_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_lho
.
visit
([
&
](
auto
output
)
{
lho_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_lco
.
visit
([
&
](
auto
output
)
{
lco_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
hs_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.0459033
,
0.0414126
,
0.272303
,
0.0393149
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
float
>
lho_data_gold
{
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
std
::
vector
<
float
>
lco_data_gold
{
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
lho_data
,
lho_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
lco_data
,
lco_data_gold
));
}
}
TEST_CASE
(
lstm_bidirectional_actv_func
)
TEST_CASE
(
lstm_bidirectional_actv_func
)
{
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
batch_size
=
3
;
...
...
test/simplify_qdq_test.cpp
View file @
61775eab
...
@@ -44,20 +44,34 @@ void run_pass(migraphx::module& m)
...
@@ -44,20 +44,34 @@ void run_pass(migraphx::module& m)
sqdq
.
apply
(
m
);
sqdq
.
apply
(
m
);
}
}
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
migraphx
::
instruction_ref
broadcast_scale
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
scale
,
migraphx
::
instruction_ref
scale
,
migraphx
::
instruction_ref
shift
)
const
std
::
vector
<
std
::
size_t
>&
out_lens
,
std
::
size_t
axis
)
{
{
auto
lens
=
x
->
get_shape
().
lens
();
if
(
scale
->
get_shape
().
lens
()
==
out_lens
)
return
scale
;
migraphx
::
instruction_ref
scale_mb
;
migraphx
::
instruction_ref
scale_mb
;
if
(
scale
->
get_shape
().
lens
().
front
()
==
1
)
auto
scale_lens
=
scale
->
get_shape
().
lens
();
if
(
scale_lens
.
front
()
==
1
and
scale_lens
.
size
()
==
1
)
scale_mb
=
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
scale
);
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_
lens
}}),
scale
);
else
else
scale_mb
=
m
.
add_instruction
(
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}}),
scale
);
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
out_lens
}}),
scale
);
return
scale_mb
;
}
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
scale
,
migraphx
::
instruction_ref
shift
,
std
::
size_t
q_axis
=
1
)
{
auto
lens
=
x
->
get_shape
().
lens
();
auto
scale_mb
=
broadcast_scale
(
m
,
scale
,
lens
,
q_axis
);
auto
shift_mb
=
auto
shift_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
shift
);
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
shift
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
name
),
x
,
scale_mb
,
shift_mb
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
name
),
x
,
scale_mb
,
shift_mb
);
...
@@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
...
@@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
scale
)
migraphx
::
instruction_ref
scale
,
std
::
size_t
q_axis
=
1
)
{
{
auto
lens
=
x
->
get_shape
().
lens
();
auto
lens
=
x
->
get_shape
().
lens
();
migraphx
::
instruction_ref
scale_mb
;
auto
scale_mb
=
broadcast_scale
(
m
,
scale
,
lens
,
q_axis
);
if
(
scale
->
get_shape
().
lens
().
front
()
==
1
)
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
scale
);
else
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}}),
scale
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
name
),
x
,
scale_mb
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
name
),
x
,
scale_mb
);
}
}
migraphx
::
instruction_ref
add_scale_mul
(
migraphx
::
module
&
m
,
migraphx
::
instruction_ref
scale1
,
migraphx
::
instruction_ref
scale2
,
std
::
size_t
axis1
,
std
::
size_t
axis2
,
const
std
::
vector
<
std
::
size_t
>&
out_lens
)
{
auto
scale1_mb
=
broadcast_scale
(
m
,
scale1
,
out_lens
,
axis1
);
auto
scale2_mb
=
broadcast_scale
(
m
,
scale2
,
out_lens
,
axis2
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale1_mb
,
scale2_mb
);
}
TEST_CASE
(
remove_qdq
)
TEST_CASE
(
remove_qdq
)
{
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
100
,
100
}};
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
100
,
100
}};
...
@@ -165,12 +186,144 @@ TEST_CASE(dot)
...
@@ -165,12 +186,144 @@ TEST_CASE(dot)
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale1
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_multi_scale
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1000
,
1024
}};
migraphx
::
shape
sh3
{
migraphx
::
shape
::
float_type
,
{
1280
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m1
.
add_literal
(
0.4
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
0
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale1
,
zero
,
0
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
1
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale2
,
zero
,
1
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m2
.
add_literal
(
0.4
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
0
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
1
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale1
,
scale2
,
0
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_broadcasted
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
2
,
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1000
,
1024
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
d2_mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
1000
,
1024
}}}),
d2
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2_mb
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
q2_mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
1000
,
1024
}}}),
q2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2_mb
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_transposed
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1024
,
1000
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
d2_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d2
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2_t
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
q2_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2_t
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
m2
.
add_return
({
d3
});
}
}
...
@@ -178,6 +331,92 @@ TEST_CASE(dot)
...
@@ -178,6 +331,92 @@ TEST_CASE(dot)
EXPECT
(
m1
==
m2
);
EXPECT
(
m1
==
m2
);
}
}
// Verifies simplify_qdq on a dot whose inputs use per-axis scales: lhs scale
// is per-axis on axis 2 ({1280} elements), rhs scale per-axis on axis 0
// ({1024} elements). The transpose and multibroadcast sitting between the
// dequantize and the dot must be moved onto the quantized tensor, and the
// final dequantizelinear uses a scale combined via add_scale_mul on axes 2
// and 3 of the quant_dot output.
TEST_CASE(dot_multi_scale_transposed_broadcasted)
{
    migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}};
    migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}};
    migraphx::shape sh3{migraphx::shape::float_type, {1280}};
    migraphx::shape sh4{migraphx::shape::float_type, {1024}};

    // Input module: Q/DQ pairs with per-axis scales, then transpose +
    // multibroadcast of the dequantized rhs feeding a float dot.
    migraphx::module m1;
    {
        auto t1     = m1.add_parameter("t1", sh1);
        auto t2     = m1.add_parameter("t2", sh2);
        auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0));
        auto scale2 = m1.add_literal(migraphx::generate_literal(sh4, 0));
        auto zero   = m1.add_literal(std::int8_t{0});

        auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2);
        auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2);
        auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0);
        auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0);
        auto d2_t =
            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2);
        auto d2_mb = m1.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), d2_t);
        auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb);
        m1.add_return({dot});
    }

    // Expected module: quant_dot directly on the int8 tensors, with the
    // transpose/multibroadcast applied before the dot, followed by a single
    // dequantizelinear using the combined output scale.
    migraphx::module m2;
    {
        auto t1     = m2.add_parameter("t1", sh1);
        auto t2     = m2.add_parameter("t2", sh2);
        auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0));
        auto scale2 = m2.add_literal(migraphx::generate_literal(sh4, 0));
        auto zero   = m2.add_literal(std::int8_t{0});

        auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2);
        auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0);
        auto q2_t =
            m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2);
        auto q2_mb = m2.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), q2_t);
        auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb);
        // Scales are folded together on output axes 2 and 3.
        auto out_scale = add_scale_mul(m2, scale1, scale2, 2, 3, dot->get_shape().lens());
        auto d3        = add_quantize_op(m2, "dequantizelinear", dot, out_scale);
        m2.add_return({d3});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}
// When a per-axis quantization scale sits on the dot's contraction axis
// (axis 1 of the {1280, 1000} lhs here), the pass cannot form a quant_dot;
// it is expected to strip the whole Q/DQ chain and leave a plain fp32 dot.
TEST_CASE(dot_multi_scale_unsupported_axis)
{
    migraphx::shape lhs_sh{migraphx::shape::float_type, {1280, 1000}};
    migraphx::shape rhs_sh{migraphx::shape::float_type, {1000, 1024}};
    migraphx::shape scale_sh{migraphx::shape::float_type, {1000}};

    // Input module: per-axis lhs scale and scalar rhs scale, both quantized
    // on axis 1 (the contraction axis).
    migraphx::module m1;
    {
        auto lhs       = m1.add_parameter("t1", lhs_sh);
        auto rhs       = m1.add_parameter("t2", rhs_sh);
        auto lhs_scale = m1.add_literal(migraphx::generate_literal(scale_sh, 0));
        auto rhs_scale = m1.add_literal(0.4f);
        auto zp        = m1.add_literal(std::int8_t{0});

        auto lhs_q  = add_quantize_op(m1, "quantizelinear", lhs, lhs_scale, zp, 1);
        auto lhs_dq = add_quantize_op(m1, "dequantizelinear", lhs_q, lhs_scale, zp, 1);
        auto rhs_q  = add_quantize_op(m1, "quantizelinear", rhs, rhs_scale, zp, 1);
        auto rhs_dq = add_quantize_op(m1, "dequantizelinear", rhs_q, rhs_scale, zp, 1);
        auto prod   = m1.add_instruction(migraphx::make_op("dot"), lhs_dq, rhs_dq);
        m1.add_return({prod});
    }

    // Expected module: the quantization ops are removed entirely.
    migraphx::module m2;
    {
        auto lhs  = m2.add_parameter("t1", lhs_sh);
        auto rhs  = m2.add_parameter("t2", rhs_sh);
        auto prod = m2.add_instruction(migraphx::make_op("dot"), lhs, rhs);
        m2.add_return({prod});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}
TEST_CASE
(
dot_non_zero_point
)
TEST_CASE
(
dot_non_zero_point
)
{
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
...
@@ -274,12 +513,12 @@ TEST_CASE(dot_add)
...
@@ -274,12 +513,12 @@ TEST_CASE(dot_add)
auto
ab
=
m2
.
add_parameter
(
"ab"
,
sh3
);
auto
ab
=
m2
.
add_parameter
(
"ab"
,
sh3
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale1
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d3
,
ab
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d3
,
ab
);
m2
.
add_return
({
add
});
m2
.
add_return
({
add
});
}
}
...
@@ -320,7 +559,6 @@ TEST_CASE(conv)
...
@@ -320,7 +559,6 @@ TEST_CASE(conv)
auto
weights
=
m2
.
add_parameter
(
"weights"
,
s4
);
auto
weights
=
m2
.
add_parameter
(
"weights"
,
s4
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
...
@@ -331,7 +569,8 @@ TEST_CASE(conv)
...
@@ -331,7 +569,8 @@ TEST_CASE(conv)
{
"padding_mode"
,
0
}}),
{
"padding_mode"
,
0
}}),
q1
,
q1
,
weights
);
weights
);
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale
);
m2
.
add_return
({
d6
});
m2
.
add_return
({
d6
});
}
}
...
@@ -340,6 +579,60 @@ TEST_CASE(conv)
...
@@ -340,6 +579,60 @@ TEST_CASE(conv)
}
}
// Verifies simplify_qdq on a convolution whose weights use a per-output-channel
// scale (axis 0, {1280} elements): the pass must produce a quant_convolution on
// the int8 tensors and a single dequantizelinear whose scale combines the input
// and weight scales.
// NOTE: the merged text carried a duplicated TEST_CASE header here (diff
// rendering artifact); a single definition is emitted.
TEST_CASE(conv_multi_scale)
{
    migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}};
    migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}};
    migraphx::shape s8{migraphx::shape::float_type, {1280}};

    // Input module: pre-quantized weights dequantized per-channel, input
    // quantized/dequantized with a scalar scale, feeding a float convolution.
    migraphx::module m1;
    {
        auto input     = m1.add_parameter("input", s7);
        auto weights   = m1.add_parameter("weights", s4);
        auto w_scale   = m1.add_literal(migraphx::generate_literal(s8, 0));
        auto inp_scale = m1.add_literal(0.5f);
        auto zero      = m1.add_literal(std::int8_t{0});

        auto d1 = add_quantize_op(m1, "dequantizelinear", weights, w_scale, zero, 0);
        auto q1 = add_quantize_op(m1, "quantizelinear", input, inp_scale, zero);
        auto d5 = add_quantize_op(m1, "dequantizelinear", q1, inp_scale, zero);
        auto c1 = m1.add_instruction(migraphx::make_op("convolution",
                                                       {{"padding", {0, 0, 0, 0}},
                                                        {"stride", {1, 1}},
                                                        {"dilation", {1, 1}},
                                                        {"group", 1},
                                                        {"padding_mode", 0}}),
                                     d5,
                                     d1);
        m1.add_return({c1});
    }

    // Expected module: quant_convolution on the quantized input and the raw
    // int8 weights, then one dequantizelinear with the combined scale.
    migraphx::module m2;
    {
        auto input     = m2.add_parameter("input", s7);
        auto weights   = m2.add_parameter("weights", s4);
        auto w_scale   = m2.add_literal(migraphx::generate_literal(s8, 0));
        auto inp_scale = m2.add_literal(0.5f);
        auto zero      = m2.add_literal(std::int8_t{0});

        auto q_inp = add_quantize_op(m2, "quantizelinear", input, inp_scale, zero);
        auto c1    = m2.add_instruction(migraphx::make_op("quant_convolution",
                                                          {{"padding", {0, 0, 0, 0}},
                                                           {"stride", {1, 1}},
                                                           {"dilation", {1, 1}},
                                                           {"group", 1},
                                                           {"padding_mode", 0}}),
                                     q_inp,
                                     weights);
        auto out_scale = add_scale_mul(m2, inp_scale, w_scale, 1, 1, c1->get_shape().lens());
        auto d1        = add_quantize_op(m2, "dequantizelinear", c1, out_scale);
        m2.add_return({d1});
    }

    run_pass(m1);
    EXPECT(m1 == m2);
}
TEST_CASE
(
conv_multi_scale_unsupported_axis
)
{
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
migraphx
::
shape
s7
{
migraphx
::
shape
::
float_type
,
{
1
,
320
,
7
,
7
}};
migraphx
::
shape
s7
{
migraphx
::
shape
::
float_type
,
{
1
,
320
,
7
,
7
}};
...
@@ -430,7 +723,6 @@ TEST_CASE(conv_bias_add)
...
@@ -430,7 +723,6 @@ TEST_CASE(conv_bias_add)
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero32
=
m2
.
add_literal
(
std
::
int32_t
{
0
});
auto
zero32
=
m2
.
add_literal
(
std
::
int32_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
...
@@ -442,7 +734,8 @@ TEST_CASE(conv_bias_add)
...
@@ -442,7 +734,8 @@ TEST_CASE(conv_bias_add)
{
"padding_mode"
,
0
}}),
{
"padding_mode"
,
0
}}),
q1
,
q1
,
weights
);
weights
);
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale
);
auto
b1
=
m2
.
add_instruction
(
auto
b1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
1280
,
7
,
7
}}}),
d2
);
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
1280
,
7
,
7
}}}),
d2
);
auto
a1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d6
,
b1
);
auto
a1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d6
,
b1
);
...
@@ -519,8 +812,6 @@ TEST_CASE(conv_pooling_dot)
...
@@ -519,8 +812,6 @@ TEST_CASE(conv_pooling_dot)
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero32
=
m2
.
add_literal
(
std
::
int32_t
{
0
});
auto
zero32
=
m2
.
add_literal
(
std
::
int32_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
scale2
=
m2
.
add_literal
(
0.25
f
);
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
ab
,
scale
,
zero
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
ab
,
scale
,
zero
);
...
@@ -533,7 +824,8 @@ TEST_CASE(conv_pooling_dot)
...
@@ -533,7 +824,8 @@ TEST_CASE(conv_pooling_dot)
{
"padding_mode"
,
0
}}),
{
"padding_mode"
,
0
}}),
q1
,
q1
,
weights
);
weights
);
auto
d5
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
auto
out_scale1
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d5
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale1
);
auto
bc1
=
m2
.
add_instruction
(
auto
bc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
1280
,
7
,
7
}}}),
d2
);
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
1280
,
7
,
7
}}}),
d2
);
auto
a1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d5
,
bc1
);
auto
a1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d5
,
bc1
);
...
@@ -548,7 +840,8 @@ TEST_CASE(conv_pooling_dot)
...
@@ -548,7 +840,8 @@ TEST_CASE(conv_pooling_dot)
auto
fl
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
1
}}),
ap
);
auto
fl
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
1
}}),
ap
);
auto
q4
=
add_quantize_op
(
m2
,
"quantizelinear"
,
fl
,
scale
,
zero
);
auto
q4
=
add_quantize_op
(
m2
,
"quantizelinear"
,
fl
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q4
,
db
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q4
,
db
);
auto
d9
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale2
);
auto
out_scale2
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
0
,
dot
->
get_shape
().
lens
());
auto
d9
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale2
);
auto
mb1
=
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
1000
}}}),
d3
);
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
1000
}}}),
d3
);
auto
a2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d9
,
mb1
);
auto
a2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d9
,
mb1
);
...
...
test/verify/test_lstm_bidirct_3args_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Bidirectional LSTM driven with only the three mandatory inputs (seq/w/r)
// and batch-first ("layout") tensors: the sequence is transposed to seq-major
// before the op and the full hidden-state output is transposed back after.
struct test_lstm_bidirct_3args_layout : verify_program<test_lstm_bidirct_3args_layout>
{
    migraphx::program create_program() const
    {
        // Problem dimensions.
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 2; // bidirectional -> two directions
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mmod = prog.get_main_module();

        // Batch-first input; w/r stack 4 * hidden rows per direction.
        migraphx::shape seq_sh{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_sh{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_sh{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};

        auto seq = mmod->add_parameter("seq", seq_sh);
        auto w   = mmod->add_parameter("w", w_sh);
        auto r   = mmod->add_parameter("r", r_sh);

        // Convert the batch-first input to the seq-major layout the op expects.
        const std::vector<int64_t> to_seq_major{1, 0, 2};
        seq = mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_major}}), seq);

        // Only two activation functions are supplied here on purpose.
        auto hs = mmod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"), migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip_val}}),
            seq,
            w,
            r);

        // Restore a batch-first arrangement of the hidden-state sequence.
        const std::vector<int64_t> to_batch_first{2, 0, 1, 3};
        mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_batch_first}}), hs);
        return prog;
    }

    std::string section() const { return "rnn"; }
};
test/verify/test_lstm_bidirct_last_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Bidirectional LSTM with all optional inputs (bias, initial hidden/cell
// state, peepholes) exercised with batch-first ("layout") tensors: seq/ih/ic
// are transposed into the op's seq-major layout, and only the last hidden
// state (rnn_last_hs_output) is taken and transposed back.
struct test_lstm_bidirct_last_layout : verify_program<test_lstm_bidirct_last_layout>
{
    migraphx::program create_program() const
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 3;
        std::size_t hidden_size = 5;
        std::size_t input_size  = 8;
        std::size_t num_dirct   = 2; // two directions
        float clip              = 0.0f;

        migraphx::program p;
        auto* mm = p.get_main_module();
        // Input is batch-first: {batch, seq, input}.
        migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
        // 4 * hidden rows per direction for w/r, 8 * hidden for the bias.
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type,
                                 {batch_size, num_dirct, hidden_size}};
        migraphx::shape ic_shape{migraphx::shape::float_type,
                                 {batch_size, num_dirct, hidden_size}};
        // Peephole weights: 3 * hidden per direction.
        migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};

        auto seq  = mm->add_parameter("seq", in_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", b_shape);
        auto ih   = mm->add_parameter("ih", ih_shape);
        auto ic   = mm->add_parameter("ic", ic_shape);
        auto pph  = mm->add_parameter("pph", pph_shape);
        // Placeholder for the unused sequence-length input slot.
        auto und = mm->add_instruction(migraphx::make_op("undefined"));

        // Batch-first -> seq-major for seq, ih and ic.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);

        auto output = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::bidirectional)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);
        // Keep only the last hidden state, then restore batch-first layout.
        auto last_output =
            mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
        return p;
    }

    std::string section() const { return "rnn"; }
};
test/verify/test_lstm_forward_hs_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
// Forward (single-direction) LSTM with all optional inputs, fed batch-first
// ("layout") tensors: seq/ih/ic are transposed to seq-major before the op and
// the full hidden-state output is transposed back afterwards.
struct test_lstm_forward_hs_layout : verify_program<test_lstm_forward_hs_layout>
{
    migraphx::program create_program() const
    {
        // Problem dimensions.
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1; // forward only
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mmod = prog.get_main_module();

        // Batch-first input; 4 * hidden rows for w/r, 8 * hidden for bias,
        // 3 * hidden peephole weights.
        migraphx::shape seq_sh{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_sh{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_sh{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};
        migraphx::shape b_sh{migraphx::shape::float_type, {ndirct, 8 * nhidden}};
        migraphx::shape ih_sh{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        migraphx::shape ic_sh{migraphx::shape::float_type, {nbatch, ndirct, nhidden}};
        migraphx::shape pph_sh{migraphx::shape::float_type, {ndirct, 3 * nhidden}};

        auto seq  = mmod->add_parameter("seq", seq_sh);
        auto w    = mmod->add_parameter("w", w_sh);
        auto r    = mmod->add_parameter("r", r_sh);
        auto bias = mmod->add_parameter("bias", b_sh);
        auto ih   = mmod->add_parameter("ih", ih_sh);
        auto ic   = mmod->add_parameter("ic", ic_sh);
        auto pph  = mmod->add_parameter("pph", pph_sh);
        // Fills the unused sequence-length input slot.
        auto und = mmod->add_instruction(migraphx::make_op("undefined"));

        // Batch-first -> seq-major for the sequence and both initial states.
        const std::vector<int64_t> to_seq_major{1, 0, 2};
        seq = mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_major}}), seq);
        ih = mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_major}}), ih);
        ic = mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_seq_major}}), ic);

        auto hs = mmod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip_val}}),
            seq,
            w,
            r,
            bias,
            und,
            ih,
            ic,
            pph);

        // Restore a batch-first arrangement of the hidden-state sequence.
        const std::vector<int64_t> to_batch_first{2, 0, 1, 3};
        mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", to_batch_first}}), hs);
        return prog;
    }

    std::string section() const { return "rnn"; }
};
test/verify/test_lstm_forward_last_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Forward LSTM with explicit per-batch sequence lengths and batch-first
// ("layout") tensors: seq/ih/ic are transposed to seq-major, the length
// literal is passed both to lstm and to rnn_last_hs_output, and the last
// hidden state is transposed back to batch-first.
struct test_lstm_forward_last_layout : verify_program<test_lstm_forward_last_layout>
{
    migraphx::program create_program() const
    {
        std::size_t batch_size  = 2;
        std::size_t seq_len     = 3;
        std::size_t hidden_size = 5;
        std::size_t input_size  = 8;
        std::size_t num_dirct   = 1; // forward only
        float clip              = 0.0f;

        migraphx::program p;
        auto* mm = p.get_main_module();
        // Input is batch-first: {batch, seq, input}.
        migraphx::shape in_shape{migraphx::shape::float_type, {batch_size, seq_len, input_size}};
        migraphx::shape w_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, input_size}};
        migraphx::shape r_shape{migraphx::shape::float_type,
                                {num_dirct, 4 * hidden_size, hidden_size}};
        migraphx::shape b_shape{migraphx::shape::float_type, {num_dirct, 8 * hidden_size}};
        migraphx::shape ih_shape{migraphx::shape::float_type,
                                 {batch_size, num_dirct, hidden_size}};
        // One int32 sequence length per batch element.
        migraphx::shape l_shape{migraphx::shape::int32_type, {batch_size}};
        migraphx::shape ic_shape{migraphx::shape::float_type,
                                 {batch_size, num_dirct, hidden_size}};
        migraphx::shape pph_shape{migraphx::shape::float_type, {num_dirct, 3 * hidden_size}};

        auto seq  = mm->add_parameter("seq", in_shape);
        auto w    = mm->add_parameter("w", w_shape);
        auto r    = mm->add_parameter("r", r_shape);
        auto bias = mm->add_parameter("bias", b_shape);
        auto ih   = mm->add_parameter("ih", ih_shape);
        // Actual lengths: batch 0 runs 1 step, batch 1 runs 2 steps.
        auto len = mm->add_literal(migraphx::literal(l_shape, {1, 2}));
        auto ic  = mm->add_parameter("ic", ic_shape);
        auto pph = mm->add_parameter("pph", pph_shape);

        // Batch-first -> seq-major for seq, ih and ic.
        std::vector<int64_t> perm{1, 0, 2};
        seq = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), seq);
        ih  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ih);
        ic  = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), ic);

        auto output = mm->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", hidden_size},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::forward)},
                 {"clip", clip}}),
            seq,
            w,
            r,
            bias,
            len,
            ih,
            ic,
            pph);
        // Last hidden state selected per the given lengths, then transposed
        // back to batch-first layout.
        auto last_output =
            mm->add_instruction(migraphx::make_op("rnn_last_hs_output"), output, len);
        mm->add_instruction(migraphx::make_op("transpose", {{"permutation", perm}}), last_output);
        return p;
    }

    std::string section() const { return "rnn"; }
};
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
0 → 100644
View file @
61775eab
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
// Reverse-direction LSTM with only the three mandatory inputs (seq/w/r) and a
// batch-first ("layout") sequence; the last cell state (rnn_last_cell_output)
// is extracted and transposed back to batch-first.
struct test_lstm_reverse_3args_cell_layout : verify_program<test_lstm_reverse_3args_cell_layout>
{
    migraphx::program create_program() const
    {
        // Problem dimensions.
        const std::size_t nbatch  = 2;
        const std::size_t nseq    = 3;
        const std::size_t nhidden = 5;
        const std::size_t ninput  = 8;
        const std::size_t ndirct  = 1; // reverse only
        const float clip_val      = 0.0f;

        migraphx::program prog;
        auto* mmod = prog.get_main_module();

        // Batch-first input; w/r stack 4 * hidden rows per direction.
        migraphx::shape seq_sh{migraphx::shape::float_type, {nbatch, nseq, ninput}};
        migraphx::shape w_sh{migraphx::shape::float_type, {ndirct, 4 * nhidden, ninput}};
        migraphx::shape r_sh{migraphx::shape::float_type, {ndirct, 4 * nhidden, nhidden}};

        auto seq = mmod->add_parameter("seq", seq_sh);
        auto w   = mmod->add_parameter("w", w_sh);
        auto r   = mmod->add_parameter("r", r_sh);

        // Convert the batch-first sequence to seq-major for the op.
        const std::vector<int64_t> layout_perm{1, 0, 2};
        seq = mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", layout_perm}}), seq);

        auto hs = mmod->add_instruction(
            migraphx::make_op(
                "lstm",
                {{"hidden_size", nhidden},
                 {"actv_func",
                  migraphx::to_value(std::vector<migraphx::operation>{
                      migraphx::make_op("sigmoid"),
                      migraphx::make_op("tanh"),
                      migraphx::make_op("tanh")})},
                 {"direction", migraphx::to_value(migraphx::op::rnn_direction::reverse)},
                 {"clip", clip_val}}),
            seq,
            w,
            r);

        // Take the final cell state and put it back in batch-first layout.
        auto cell_output =
            mmod->add_instruction(migraphx::make_op("rnn_last_cell_output"), hs);
        mmod->add_instruction(
            migraphx::make_op("transpose", {{"permutation", layout_perm}}), cell_output);
        return prog;
    }

    std::string section() const { return "rnn"; }
};
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment