Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
7bffbe43
Unverified
Commit
7bffbe43
authored
Nov 16, 2023
by
Zakor Gyula
Committed by
GitHub
Nov 16, 2023
Browse files
Merge branch 'develop' into qlinearmul
parents
566907fc
0039b11a
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3274 additions
and
410 deletions
+3274
-410
src/include/migraphx/matcher.hpp
src/include/migraphx/matcher.hpp
+19
-0
src/onnx/parse_lstm.cpp
src/onnx/parse_lstm.cpp
+47
-0
src/simplify_qdq.cpp
src/simplify_qdq.cpp
+103
-35
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+171
-0
test/onnx/lstm_bi_layout_cell_test.onnx
test/onnx/lstm_bi_layout_cell_test.onnx
+0
-0
test/onnx/lstm_bi_layout_last_test.onnx
test/onnx/lstm_bi_layout_last_test.onnx
+0
-0
test/onnx/lstm_f_layout_cell_test.onnx
test/onnx/lstm_f_layout_cell_test.onnx
+0
-0
test/onnx/lstm_f_layout_hs_test.onnx
test/onnx/lstm_f_layout_hs_test.onnx
+0
-0
test/onnx/lstm_r_layout_hs_cell_test.onnx
test/onnx/lstm_r_layout_hs_cell_test.onnx
+0
-0
test/onnx/lstm_r_layout_test.onnx
test/onnx/lstm_r_layout_test.onnx
+0
-0
test/onnx/onnx_rnn_test.cpp
test/onnx/onnx_rnn_test.cpp
+332
-0
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-1
test/quantization.cpp
test/quantization.cpp
+70
-70
test/ref/rnn_ops.cpp
test/ref/rnn_ops.cpp
+1727
-234
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+363
-70
test/verify/test_lstm_bidirct_3args_layout.cpp
test/verify/test_lstm_bidirct_3args_layout.cpp
+77
-0
test/verify/test_lstm_bidirct_last_layout.cpp
test/verify/test_lstm_bidirct_last_layout.cpp
+95
-0
test/verify/test_lstm_forward_hs_layout.cpp
test/verify/test_lstm_forward_hs_layout.cpp
+95
-0
test/verify/test_lstm_forward_last_layout.cpp
test/verify/test_lstm_forward_last_layout.cpp
+97
-0
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
+78
-0
No files found.
src/include/migraphx/matcher.hpp
View file @
7bffbe43
...
...
@@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins)
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[
&
](
auto
x
)
{
return
x
->
get_shape
()
==
s
;
});
}
MIGRAPHX_PRED_MATCHER
(
has_same_value
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
bool
all_same
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
s
)
{
all_same
=
std
::
all_of
(
s
.
begin
()
+
1
,
s
.
end
(),
[
&
](
const
auto
&
scale
)
{
return
float_equal
(
scale
,
s
.
front
());
});
});
return
all_same
;
}
MIGRAPHX_BASIC_MATCHER
(
output
,
const
matcher_context
&
,
instruction_ref
ins
)
{
if
(
ins
->
outputs
().
size
()
==
1
)
...
...
@@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms)
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"convert"
))(
ms
...);
}
template
<
class
...
Ms
>
auto
skip_broadcasts_transposes_contiguous
(
Ms
...
ms
)
{
return
skip
(
name
(
"broadcast"
,
"multibroadcast"
,
"contiguous"
,
"transpose"
))(
ms
...);
}
template
<
class
T
>
inline
auto
has_value
(
T
x
,
float
tolerance
=
1e-6
)
{
...
...
src/onnx/parse_lstm.cpp
View file @
7bffbe43
...
...
@@ -116,6 +116,37 @@ void lstm_actv_functions(op::rnn_direction dirct, std::vector<std::string>& actv
}
}
void
lstm_transpose_inputs
(
onnx_parser
::
node_info
&
info
,
std
::
vector
<
instruction_ref
>&
args
)
{
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
args
[
0
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
0
]);
if
(
args
.
size
()
>=
6
and
not
args
[
5
]
->
is_undefined
())
{
args
[
5
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
5
]);
}
if
(
args
.
size
()
>=
7
and
not
args
[
6
]
->
is_undefined
())
{
args
[
6
]
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
args
[
6
]);
}
}
void
lstm_transpose_outputs
(
onnx_parser
::
node_info
&
info
,
instruction_ref
&
hidden_states
,
instruction_ref
&
last_output
,
instruction_ref
&
last_cell_output
)
{
std
::
vector
<
int64_t
>
perm_hs
{
2
,
0
,
1
,
3
};
hidden_states
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hs
}}),
hidden_states
);
std
::
vector
<
int64_t
>
perm_last
{
1
,
0
,
2
};
last_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_output
);
last_cell_output
=
info
.
add_instruction
(
make_op
(
"transpose"
,
{{
"permutation"
,
perm_last
}}),
last_cell_output
);
}
struct
parse_lstm
:
op_parser
<
parse_lstm
>
{
std
::
vector
<
op_desc
>
operators
()
const
{
return
{{
"LSTM"
}};
}
...
...
@@ -202,6 +233,12 @@ struct parse_lstm : op_parser<parse_lstm>
input_forget
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"input_forget"
)).
at
<
int
>
();
}
int
layout
=
0
;
if
(
contains
(
info
.
attributes
,
"layout"
))
{
layout
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"layout"
)).
at
<
int
>
();
}
// append undefined opeator to make 6 arguments
if
(
args
.
size
()
<
8
)
{
...
...
@@ -209,6 +246,11 @@ struct parse_lstm : op_parser<parse_lstm>
args
.
insert
(
args
.
end
(),
8
-
args
.
size
(),
ins
);
}
if
(
layout
!=
0
)
{
lstm_transpose_inputs
(
info
,
args
);
}
// first output for concatenation of hidden states
auto
hidden_states
=
info
.
add_instruction
(
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
...
...
@@ -224,6 +266,11 @@ struct parse_lstm : op_parser<parse_lstm>
auto
last_cell_output
=
info
.
add_instruction
(
make_op
(
"rnn_last_cell_output"
),
hidden_states
);
if
(
layout
!=
0
)
{
lstm_transpose_outputs
(
info
,
hidden_states
,
last_output
,
last_cell_output
);
}
return
{
hidden_states
,
last_output
,
last_cell_output
};
}
};
...
...
src/simplify_qdq.cpp
View file @
7bffbe43
...
...
@@ -45,77 +45,145 @@ std::unordered_set<std::string> get_quantizable_op_names()
return
s
;
}
MIGRAPHX_PRED_MATCHER
(
has_same_value
,
instruction_ref
ins
)
struct
match_find_quantizable_ops
{
if
(
ins
->
name
()
!=
"@literal"
)
return
false
;
bool
all_same
=
false
;
ins
->
get_literal
().
visit
([
&
](
auto
s
)
{
all_same
=
std
::
all_of
(
s
.
begin
()
+
1
,
s
.
end
(),
[
&
](
const
auto
&
scale
)
{
return
float_equal
(
scale
,
s
.
front
());
static
bool
is_valid_scale
(
instruction_ref
scale
,
std
::
vector
<
std
::
size_t
>
lens
,
std
::
size_t
axis
)
{
return
scale
->
get_shape
().
scalar
()
or
scale
->
get_shape
().
elements
()
==
lens
.
at
(
axis
);
}
static
bool
is_valid_zero_point
(
instruction_ref
zp
)
{
if
(
not
zp
->
can_eval
())
return
false
;
bool
all_zeros
=
false
;
zp
->
eval
().
visit
([
&
](
auto
z
)
{
all_zeros
=
std
::
all_of
(
z
.
begin
(),
z
.
end
(),
[
&
](
auto
val
)
{
return
float_equal
(
val
,
0
);
});
});
});
return
all_same
;
}
return
all_zeros
;
}
struct
match_find_quantizable_ops
{
static
auto
scale_broadcast_op
(
instruction_ref
scale
,
std
::
vector
<
std
::
size_t
>
lens
,
std
::
size_t
axis
)
{
if
(
scale
->
get_shape
().
scalar
())
{
return
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}});
}
else
{
return
migraphx
::
make_op
(
"broadcast"
,
{{
"out_lens"
,
lens
},
{
"axis"
,
axis
}});
}
}
static
auto
dequantizelinear_op
(
const
std
::
string
&
name
,
const
std
::
string
&
scale
)
// Helper function to insert quantized versions of any broadcasts and transpose ops that
// occur between dequantizelinear and the quantized op
static
auto
propagate_quantized_ins
(
module
&
m
,
const
instruction_ref
dqins
,
const
instruction_ref
qop
)
{
auto
qinp
=
dqins
->
inputs
().
front
();
auto
next_ins
=
dqins
;
while
(
next_ins
!=
qop
)
{
if
(
next_ins
->
name
()
!=
"dequantizelinear"
)
{
qinp
=
m
.
insert_instruction
(
qop
,
next_ins
->
get_operator
(),
qinp
);
}
next_ins
=
next_ins
->
outputs
().
front
();
}
return
qinp
;
}
static
auto
dequantizelinear_op
(
const
std
::
string
&
scale
,
const
std
::
string
&
zp
)
{
return
match
::
name
(
"dequantizelinear"
)(
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"quantizelinear"
))(
match
::
any
()
.
bind
(
name
)
)),
match
::
arg
(
1
)(
match
::
skip_broadcasts
(
has_same_value
().
bind
(
scale
))),
match
::
arg
(
2
)(
match
::
skip_broadcasts
(
match
::
all_of
(
match
::
has_value
(
0
)
))));
match
::
arg
(
0
)(
match
::
skip
(
match
::
name
(
"quantizelinear"
))(
match
::
any
())),
match
::
arg
(
1
)(
match
::
skip_broadcasts
(
match
::
is_constant
().
bind
(
scale
))),
match
::
arg
(
2
)(
match
::
skip_broadcasts
(
match
::
is_constant
().
bind
(
zp
))));
}
auto
matcher
()
const
{
return
match
::
name
(
get_quantizable_op_names
())(
match
::
arg
(
0
)(
dequantizelinear_op
(
"x1"
,
"scale1"
)),
match
::
arg
(
1
)(
dequantizelinear_op
(
"x2"
,
"scale2"
)));
match
::
arg
(
0
)(
match
::
skip_broadcasts_transposes_contiguous
(
dequantizelinear_op
(
"scale1"
,
"zp1"
).
bind
(
"dq1"
))),
match
::
arg
(
1
)(
match
::
skip_broadcasts_transposes_contiguous
(
dequantizelinear_op
(
"scale2"
,
"zp2"
).
bind
(
"dq2"
))));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
qop
=
r
.
result
;
auto
q1
=
r
.
instructions
[
"
x
1"
];
auto
q2
=
r
.
instructions
[
"
x
2"
];
auto
d
q1
=
r
.
instructions
[
"
dq
1"
];
auto
d
q2
=
r
.
instructions
[
"
dq
2"
];
auto
scale1
=
r
.
instructions
[
"scale1"
];
auto
scale2
=
r
.
instructions
[
"scale2"
];
auto
zp1
=
r
.
instructions
[
"zp1"
];
auto
zp2
=
r
.
instructions
[
"zp2"
];
// Only INT8 type currently supported
if
(
q1
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
q2
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
if
(
d
q1
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
or
d
q2
->
inputs
().
front
()
->
get_shape
().
type
()
!=
migraphx
::
shape
::
int8_type
)
return
;
double
scale
;
visit_all
(
scale1
->
get_literal
(),
scale2
->
get_literal
(
))
(
[
&
](
const
auto
s1
,
const
auto
s2
)
{
scale
=
s1
.
front
()
*
s2
.
front
();
})
;
// Only symmetric quantization supported (ie. non-zero zero_points not allowed)
if
(
not
(
is_valid_zero_point
(
zp1
)
and
is_valid_zero_point
(
zp2
))
)
return
;
// Only support scalar and 1D scales
if
(
scale1
->
get_shape
().
lens
().
size
()
!=
1
or
scale2
->
get_shape
().
lens
().
size
()
!=
1
)
return
;
// Propagate q1 and q2 through any broadcasts and transposes before qop
auto
qop_args
=
qop
->
inputs
();
qop_args
.
at
(
0
)
=
q1
;
qop_args
.
at
(
1
)
=
q2
;
qop_args
.
at
(
0
)
=
propagate_quantized_ins
(
m
,
dq1
,
qop
)
;
qop_args
.
at
(
1
)
=
propagate_quantized_ins
(
m
,
dq2
,
qop
)
;
instruction_ref
dq
;
instruction_ref
dq
_scale
;
instruction_ref
out
_scale
;
instruction_ref
zero_point
;
if
(
qop
->
name
()
==
"convolution"
)
{
auto
conv_val
=
qop
->
get_operator
().
to_value
();
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_convolution"
,
conv_val
),
qop_args
);
auto
out_lens
=
dq
->
get_shape
().
lens
();
// Input scale should always be scalar and weight scale can be scalar or 1D of the
// same lens as the output channel dim (dim 1 in the output)
if
(
not
(
is_valid_scale
(
scale1
,
out_lens
,
1
)
and
is_valid_scale
(
scale2
,
out_lens
,
1
)))
return
;
auto
s1_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale1
,
out_lens
,
1
),
scale1
);
auto
s2_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale2
,
out_lens
,
1
),
scale2
);
out_scale
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"mul"
),
s1_bcast
,
s2_bcast
);
}
else
if
(
qop
->
name
()
==
"dot"
)
{
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_dot"
),
qop_args
);
dq
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"quant_dot"
),
qop_args
);
auto
out_lens
=
dq
->
get_shape
().
lens
();
// For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M
// for input1 and K for input 2
if
(
not
(
is_valid_scale
(
scale1
,
out_lens
,
out_lens
.
size
()
-
2
)
and
is_valid_scale
(
scale2
,
out_lens
,
out_lens
.
size
()
-
1
)))
return
;
auto
s1_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale1
,
out_lens
,
out_lens
.
size
()
-
2
),
scale1
);
auto
s2_bcast
=
m
.
insert_instruction
(
qop
,
scale_broadcast_op
(
scale2
,
out_lens
,
out_lens
.
size
()
-
1
),
scale2
);
out_scale
=
m
.
insert_instruction
(
qop
,
migraphx
::
make_op
(
"mul"
),
s1_bcast
,
s2_bcast
);
}
auto
ins_type
=
qop
->
get_shape
().
type
();
dq_scale
=
m
.
add_literal
(
literal
({
ins_type
},
{
scale
}));
auto
lens
=
dq
->
get_shape
().
lens
();
auto
scale_mb
=
m
.
insert_instruction
(
qop
,
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
dq_scale
);
dq
=
m
.
insert_instruction
(
qop
,
make_op
(
"dequantizelinear"
),
dq
,
scale_mb
);
dq
=
m
.
insert_instruction
(
qop
,
make_op
(
"dequantizelinear"
),
dq
,
out_scale
);
m
.
replace_instruction
(
qop
,
dq
);
}
};
...
...
test/onnx/gen_onnx.py
View file @
7bffbe43
...
...
@@ -4484,6 +4484,177 @@ def lrn_test():
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
lstm_bi_layout_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
2
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
2
,
60
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
''
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'bidirectional'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
cellout
])
@
onnx_test
()
def
lstm_bi_layout_last_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
2
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
2
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
2
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
2
,
20
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
2
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
,
'output'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'bidirectional'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
,
output
])
@
onnx_test
()
def
lstm_f_layout_hs_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
1
,
20
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
,
'output'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'forward'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
,
output
])
@
onnx_test
()
def
lstm_f_layout_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
''
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'forward'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
cellout
])
@
onnx_test
()
def
lstm_r_layout_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
hs
=
helper
.
make_tensor_value_info
(
'hs'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
'hs'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'reverse'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
hs
])
@
onnx_test
()
def
lstm_r_layout_hs_cell_test
():
seq
=
helper
.
make_tensor_value_info
(
'seq'
,
TensorProto
.
FLOAT
,
[
3
,
5
,
10
])
w
=
helper
.
make_tensor_value_info
(
'w'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
10
])
r
=
helper
.
make_tensor_value_info
(
'r'
,
TensorProto
.
FLOAT
,
[
1
,
80
,
20
])
bias
=
helper
.
make_tensor_value_info
(
'bias'
,
TensorProto
.
FLOAT
,
[
1
,
160
])
seq_len
=
helper
.
make_tensor_value_info
(
'seq_len'
,
TensorProto
.
INT32
,
[
3
])
h0
=
helper
.
make_tensor_value_info
(
'h0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
c0
=
helper
.
make_tensor_value_info
(
'c0'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
pph
=
helper
.
make_tensor_value_info
(
'pph'
,
TensorProto
.
FLOAT
,
[
1
,
60
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
cellout
=
helper
.
make_tensor_value_info
(
'cellout'
,
TensorProto
.
FLOAT
,
[
3
,
1
,
20
])
node
=
onnx
.
helper
.
make_node
(
'LSTM'
,
inputs
=
[
'seq'
,
'w'
,
'r'
,
'bias'
,
'seq_len'
,
'h0'
,
'c0'
,
'pph'
],
outputs
=
[
''
,
'output'
,
'cellout'
],
activations
=
[
'sigmoid'
,
'tanh'
,
'tanh'
],
clip
=
0
,
direction
=
'reverse'
,
hidden_size
=
20
,
input_forget
=
1
,
layout
=
1
)
return
([
node
],
[
seq
,
w
,
r
,
bias
,
seq_len
,
h0
,
c0
,
pph
],
[
output
,
cellout
])
@
onnx_test
()
def
matmul_bmbm_test
():
m1
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
FLOAT
,
[
3
,
6
,
7
])
...
...
test/onnx/lstm_bi_layout_cell_test.onnx
0 → 100644
View file @
7bffbe43
File added
test/onnx/lstm_bi_layout_last_test.onnx
0 → 100644
View file @
7bffbe43
File added
test/onnx/lstm_f_layout_cell_test.onnx
0 → 100644
View file @
7bffbe43
File added
test/onnx/lstm_f_layout_hs_test.onnx
0 → 100644
View file @
7bffbe43
File added
test/onnx/lstm_r_layout_hs_cell_test.onnx
0 → 100644
View file @
7bffbe43
File added
test/onnx/lstm_r_layout_test.onnx
0 → 100644
View file @
7bffbe43
File added
test/onnx/onnx_rnn_test.cpp
View file @
7bffbe43
...
...
@@ -1092,6 +1092,115 @@ TEST_CASE(lstm_forward)
}
}
TEST_CASE
(
lstm_forward_layout
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
1
;
// num directions
float
clip
=
0.0
f
;
int
input_forget
=
1
;
migraphx
::
shape
seq_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
sl
,
is
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
is
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
hs
}};
migraphx
::
shape
bias_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
8
*
hs
}};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
nd
,
hs
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
}};
// 8 args, hs and last output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
auto
prog
=
optimize_onnx
(
"lstm_f_layout_hs_test.onnx"
);
EXPECT
(
p
==
prog
);
}
// 8 args, cell output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_cell
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_cell
);
auto
prog
=
optimize_onnx
(
"lstm_f_layout_cell_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
// activation functions
TEST_CASE
(
lstm_forward_actv_func
)
{
...
...
@@ -1342,6 +1451,117 @@ TEST_CASE(lstm_reverse)
}
}
TEST_CASE
(
lstm_reverse_layout
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
1
;
// num directions
float
clip
=
0.0
f
;
int
input_forget
=
1
;
migraphx
::
shape
seq_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
sl
,
is
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
is
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
hs
}};
migraphx
::
shape
bias_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
8
*
hs
}};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
nd
,
hs
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
}};
// 8 args, hs output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
auto
prog
=
optimize_onnx
(
"lstm_r_layout_test.onnx"
);
EXPECT
(
p
==
prog
);
}
// 8 args, last and cell output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
auto
last_cell
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_cell
);
auto
prog
=
optimize_onnx
(
"lstm_r_layout_hs_cell_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
lstm_bidirectional
)
{
std
::
size_t
sl
=
5
;
// sequence len
...
...
@@ -1594,6 +1814,118 @@ TEST_CASE(lstm_bidirectional)
}
}
TEST_CASE
(
lstm_bidirectional_layout
)
{
std
::
size_t
sl
=
5
;
// sequence len
std
::
size_t
bs
=
3
;
// batch size
std
::
size_t
hs
=
20
;
// hidden size
std
::
size_t
is
=
10
;
// input size
std
::
size_t
nd
=
2
;
// num directions
float
clip
=
0.0
f
;
int
input_forget
=
1
;
migraphx
::
shape
seq_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
sl
,
is
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
is
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
4
*
hs
,
hs
}};
migraphx
::
shape
bias_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
8
*
hs
}};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
bs
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
bs
,
nd
,
hs
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
nd
,
3
*
hs
}};
// 0 activation function
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
auto
prog
=
optimize_onnx
(
"lstm_bi_layout_last_test.onnx"
);
EXPECT
(
p
==
prog
);
}
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_parameter
(
"seq"
,
seq_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
bias_shape
);
auto
seq_len
=
mm
->
add_parameter
(
"seq_len"
,
sl_shape
);
auto
ih
=
mm
->
add_parameter
(
"h0"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"c0"
,
ih_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hs
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
input_forget
}}),
seq
,
w
,
r
,
bias
,
seq_len
,
ih
,
ic
,
pph
);
auto
last_cell
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_cell
);
auto
prog
=
optimize_onnx
(
"lstm_bi_layout_cell_test.onnx"
);
EXPECT
(
p
==
prog
);
}
}
TEST_CASE
(
lstm_bi_actv_funcs
)
{
std
::
size_t
sl
=
5
;
// sequence len
...
...
test/py/onnx_backend_test.py
View file @
7bffbe43
...
...
@@ -574,7 +574,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
# fails
# from OnnxBackendNodeModelTest
backend_test
.
exclude
(
r
'test_gru_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_lstm_batchwise_cpu'
)
backend_test
.
exclude
(
r
'test_simple_rnn_batchwise_cpu'
)
# from OnnxBackendPyTorchConvertedModelTest
backend_test
.
exclude
(
r
'test_MaxPool1d_stride_padding_dilation_cpu'
)
...
...
test/quantization.cpp
View file @
7bffbe43
...
...
@@ -636,13 +636,12 @@ TEST_CASE(dot_float)
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale
);
auto
zp_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
auto
quant_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
quant_a
,
quant_b
);
std
::
vector
<
float
>
vec
(
sc
.
elements
(),
100.0
f
);
auto
dc
=
mm
->
add_literal
(
100.0
f
);
auto
mdc
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sc
.
lens
()}}),
dc
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
mdc
);
auto
quant_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
quant_a
,
quant_b
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args)
auto
pa
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
pb
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
scale_a
=
mm
->
add_literal
(
10.0
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
scale_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale_a
);
auto
scale_a
_lit
=
mm
->
add_literal
(
10.0
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale_a
_lit
);
auto
zp_a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
auto
qa
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pa
,
scale_a
,
zp_a
);
auto
scale_b
=
mm
->
add_literal
(
5.0
);
scale_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale_b
);
auto
qa
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pa
,
scale_a
,
zp_a
);
auto
scale_b
_lit
=
mm
->
add_literal
(
5.0
);
auto
scale_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
scale_b
_lit
);
auto
zp_b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sb
.
lens
()}}),
zp
);
auto
qb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
scale
=
mm
->
add_literal
(
50.0
);
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
scale
);
auto
qb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pb
,
scale_b
,
zp_b
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
scale_a_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale_a_lit
);
auto
scale_b_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
scale_b_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_a_mb
,
scale_b_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
};
...
...
@@ -798,19 +801,16 @@ TEST_CASE(dot_half_1arg)
migraphx
::
shape
sa
{
migraphx
::
shape
::
half_type
,
{
9
,
9
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
sa
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
10.0
}));
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
_lit
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
10.0
}));
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sa
.
lens
()}}),
zp
);
auto
qx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
scale
,
zp
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qx
,
qx
);
auto
dq_scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sa
.
type
()},
{
100.0
}));
dq_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
dq_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
dq_scale
);
auto
qx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
scale
,
zp
);
auto
qdot
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qx
,
qx
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale
,
scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
};
...
...
@@ -851,10 +851,10 @@ TEST_CASE(conv_float)
auto
px
=
mm
->
add_parameter
(
"x"
,
sx
);
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
10.0
f
);
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
_lit
=
mm
->
add_literal
(
10.0
f
);
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
...
...
@@ -862,13 +862,11 @@ TEST_CASE(conv_float)
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
migraphx
::
shape
sc
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
std
::
vector
<
float
>
vec
(
sc
.
elements
(),
100.0
f
);
migraphx
::
shape
s_scale
{
migraphx
::
shape
::
float_type
,
sc
.
lens
()};
auto
d_scale
=
mm
->
add_literal
(
100.0
f
);
d_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
4
,
1
,
1
}}}),
d_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
d_scale
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -930,20 +928,21 @@ TEST_CASE(conv_half)
auto
px
=
mm
->
add_parameter
(
"x"
,
sx
);
auto
pw
=
mm
->
add_parameter
(
"w"
,
sw
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
10.0
}));
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
);
auto
zp
=
mm
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
auto
scale
_lit
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
10.0
}));
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
scale
_lit
);
zp
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sx
.
lens
()}}),
zp
);
auto
quant_x
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
px
,
scale
,
zp
);
auto
quant_w
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
pw
,
scale
,
zp
);
auto
quant
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
quant_x
,
quant_w
);
auto
d_scale
=
mm
->
add_literal
(
migraphx
::
literal
({
sx
.
type
()},
{
100.0
}));
d_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
4
,
4
,
1
,
1
}}}),
d_scale
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
d_scale
);
auto
scale_mb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
quant
->
get_shape
().
lens
()}}),
scale_lit
);
auto
out_scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mb
,
scale_mb
);
auto
r
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
quant
,
out_scale
);
mm
->
add_return
({
r
});
return
p
;
...
...
@@ -1185,12 +1184,12 @@ TEST_CASE(int8_subgraph)
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sy
.
lens
()}}),
s1
);
auto
zpb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sy
.
lens
()}}),
zp1
);
auto
qb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
b
,
sb
,
zpb
);
auto
qdot
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
s
o
=
then_mod
->
add_
literal
(
100.0
f
);
so
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sout
.
lens
()}}),
so
);
auto
r
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
so
);
auto
qb
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
b
,
sb
,
zpb
);
auto
qdot
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
qa
,
qb
);
auto
s
1_mb
=
then_mod
->
add_
instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qdot
->
get_shape
().
lens
()}}),
s1
);
auto
so
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
s1_mb
,
s1_mb
);
auto
r
=
then_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qdot
,
so
);
then_mod
->
add_return
({
r
});
migraphx
::
shape
sd
{
migraphx
::
shape
::
float_type
,
{
2
,
2
,
4
,
6
}};
...
...
@@ -1199,24 +1198,25 @@ TEST_CASE(int8_subgraph)
auto
w
=
mm
->
add_parameter
(
"w"
,
sw
);
// else submod
auto
*
else_mod
=
p
.
create_module
(
"If_6_else"
);
auto
sax
=
else_mod
->
add_literal
(
2.0
f
);
auto
sax
_lit
=
else_mod
->
add_literal
(
2.0
f
);
auto
zp
=
else_mod
->
add_literal
(
static_cast
<
int8_t
>
(
0
));
sax
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
sax
);
auto
sax
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
sax
_lit
);
auto
zpx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sd
.
lens
()}}),
zp
);
auto
qx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
sax
,
zpx
);
auto
ssw
=
else_mod
->
add_literal
(
1.66667
f
);
ssw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
ssw
);
auto
qx
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
x
,
sax
,
zpx
);
auto
ssw
_lit
=
else_mod
->
add_literal
(
1.66667
f
);
auto
ssw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
ssw
_lit
);
auto
zpw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sw
.
lens
()}}),
zp
);
auto
qw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
w
,
ssw
,
zpw
);
auto
qconv
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
qx
,
qw
);
auto
so1
=
else_mod
->
add_literal
(
3.33333
f
);
so1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
sout
.
lens
()}}),
so1
);
auto
r1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qconv
,
so1
);
auto
qw
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quantizelinear"
),
w
,
ssw
,
zpw
);
auto
qconv
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
),
qx
,
qw
);
auto
ssw_mb
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
qconv
->
get_shape
().
lens
()}}),
ssw_lit
);
auto
so1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
sax
,
ssw_mb
);
auto
r1
=
else_mod
->
add_instruction
(
migraphx
::
make_op
(
"dequantizelinear"
),
qconv
,
so1
);
else_mod
->
add_return
({
r1
});
auto
ret
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"if"
),
{
cond
},
{
then_mod
,
else_mod
});
...
...
test/ref/rnn_ops.cpp
View file @
7bffbe43
...
...
@@ -3228,6 +3228,264 @@ TEST_CASE(lstm_forward)
}
}
TEST_CASE
(
lstm_forward_layout
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
0.1236
,
-
0.3942
,
0.4149
,
0.0795
,
0.4934
,
-
0.2858
,
0.2602
,
-
0.3098
,
0.0567
,
0.3344
,
0.3607
,
-
0.0551
,
0.4952
,
0.3799
,
0.0630
,
-
0.3532
,
0.0023
,
-
0.0592
,
0.4267
,
0.2382
,
-
0.0784
,
-
0.0032
,
-
0.2476
,
-
0.0206
,
-
0.4963
,
0.4837
,
0.0827
,
0.0123
,
-
0.1203
,
-
0.0279
,
-
0.0049
,
0.4721
,
-
0.3564
,
-
0.1286
,
0.4090
,
-
0.0504
,
0.0575
,
-
0.2138
,
0.1071
,
0.1976
,
-
0.0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4845
,
-
0.1496
,
0.3285
};
std
::
vector
<
float
>
r_data
{
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
};
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.3577
,
1.3508
,
-
0.5366
,
0.4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.1332
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.4012
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
1.0972
,
0.9816
,
0.1122
,
-
0.4100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// forward, hidden state concatenation as output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
und
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
hs_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
hs_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
hs_data_gold
{
0.0417273
,
-
0.272355
,
0.206765
,
0.223879
,
0.0742487
,
-
0.0800085
,
0.259897
,
0.0670196
,
-
0.00532985
,
0.0440265
,
0.29654
,
-
0.0463156
,
-
0.0847427
,
0.0874114
,
0.304256
,
-
0.0585745
,
0.138193
,
-
0.0322939
,
-
0.0891815
,
0.15773
,
0.184266
,
0.0610048
,
-
0.138041
,
0.0963885
,
0.0498799
,
0.125772
,
0.0533032
,
-
0.131413
,
-
0.0223018
,
0.131113
,
0.135643
,
-
0.056620
,
0.19139
,
-
0.127708
,
-
0.409371
,
-
0.136186
,
0.0213755
,
-
0.146027
,
-
0.0324509
,
-
0.0620429
,
0.0988431
,
-
0.018085
,
-
0.159434
,
0.030266
,
0.142701
,
0.0342236
,
-
0.198664
,
0.0702607
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
hs_data
,
hs_data_gold
));
}
// forward, last_output as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
und
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
last_hs
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
last_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0847427
,
0.0874114
,
0.304256
,
-
0.0585745
,
-
0.0223018
,
0.131113
,
0.135643
,
-
0.0566208
,
0.142701
,
0.0342236
,
-
0.198664
,
0.0702607
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// forward, last_cell_output as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
und
);
auto
cell_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
cell_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
last_hs
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
last_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.111454
,
0.247794
,
0.471087
,
-
0.220574
,
-
0.048196
,
0.263184
,
0.283258
,
-
0.14882
,
0.605585
,
0.078598
,
-
0.64457
,
0.119811
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
TEST_CASE
(
lstm_forward_more
)
{
std
::
size_t
batch_size
=
3
;
...
...
@@ -3519,7 +3777,7 @@ TEST_CASE(lstm_forward_more)
}
}
TEST_CASE
(
lstm_
reverse
)
TEST_CASE
(
lstm_
forward_more_layout
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
...
...
@@ -3527,32 +3785,668 @@ TEST_CASE(lstm_reverse)
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
-
0.
2763
,
-
0.
4715
,
-
0.
3010
,
-
0.
2306
,
-
0.
2283
,
-
0.2
656
,
0.2
035
,
0.3
570
,
-
0.1499
,
0.4390
,
-
0.
1843
,
0.23
51
,
0.
3357
,
0.
1217
,
0.
1401
,
0.3
300
,
-
0.0429
,
0.3266
,
0.4
834
,
-
0.
3914
,
-
0.
1480
,
0.373
4
,
-
0.03
7
2
,
-
0.
174
6
,
0.0
550
,
0.4
17
7
,
-
0.
1332
,
0.4391
,
-
0.
3287
,
-
0.
4401
,
0.
1486
,
0.1346
,
0.
1048
,
-
0.4361
,
0.0886
,
-
0.
3840
,
-
0.
2730
,
-
0.1710
,
0.3274
,
0.
0169
,
-
0.
4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4
150
,
-
0.4684
,
-
0.
2522
};
0.
1236
,
-
0.
3942
,
0.
4149
,
0.
0795
,
0.
4934
,
-
0.2
858
,
0.2
602
,
-
0.3
098
,
0.0567
,
0.3344
,
0.
3607
,
-
0.05
51
,
0.
4952
,
0.
3799
,
0.
0630
,
-
0.3
532
,
0.0023
,
-
0.0592
,
0.4
267
,
0.
2382
,
-
0.
078
4
,
-
0.
0
032
,
-
0.
247
6
,
-
0.0
206
,
-
0.4963
,
0.4
83
7
,
0.
0827
,
0.0123
,
-
0.
1203
,
-
0.
0279
,
-
0.
0049
,
0.4721
,
-
0.
3564
,
-
0.1286
,
0.4090
,
-
0.
0504
,
0.
0575
,
-
0.2138
,
0.1071
,
0.
1976
,
-
0.
0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4
845
,
-
0.1496
,
0.
3285
};
std
::
vector
<
float
>
r_data
{
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
};
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.3577
,
1.3508
,
-
0.5366
,
0.4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.1332
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.4012
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
1.0972
,
0.9816
,
0.1122
,
-
0.4100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// forward, 3 args
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
last_hs
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
last_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.071286
,
-
0.102509
,
-
0.0372696
,
0.252296
,
-
0.144544
,
-
0.165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.074206
,
0.0124086
,
-
0.139544
,
0.108016
,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.187329
,
-
0.0458544
,
-
0.0401315
,
0.0737483
,
-
0.064505
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.00973633
,
-
0.0552699
,
0.0252681
,
-
0.0562072
,
0.0855831
,
-
0.0171894
,
-
0.140202
,
0.0828391
,
0.136898
,
0.00160891
,
-
0.184812
,
0.147774
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// forward, 8 args
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
hs_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
hs_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
hs_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
0.0459033
,
0.0414126
,
0.272303
,
0.0393149
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
hs_data
,
hs_data_gold
));
}
// forward, last_output as program output, sequence length shorter
// than max_seq_len
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq_orig
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
2
,
input_size
}};
std
::
vector
<
float
>
pad_data
(
pad_seq_s
.
elements
(),
0.0
f
);
auto
seq_p
=
mm
->
add_literal
(
migraphx
::
literal
{
pad_seq_s
,
pad_data
});
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
seq_orig
,
seq_p
);
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
(
batch_size
,
static_cast
<
int32_t
>
(
seq_len
));
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
sql
,
ih
,
ic
,
und
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
last_hs
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
last_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0847427
,
0.0874114
,
0.304256
,
-
0.0585745
,
-
0.0223018
,
0.131113
,
0.135643
,
-
0.0566208
,
0.142701
,
0.0342236
,
-
0.198664
,
0.0702607
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// seq_len = 1
{
seq_len
=
1
;
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
hs_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
hs_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
hs_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
hs_data
,
hs_data_gold
));
}
}
TEST_CASE
(
lstm_reverse
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.1122
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.4100
,
-
2.2344
,
0.3685
,
0.4583
,
2.3794
,
1.0372
,
-
0.8887
,
0.7892
,
-
0.4012
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.5289
,
1.0986
,
0.6091
,
1.6462
,
0.8720
,
0.5349
,
-
0.1962
,
-
1.7416
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
-
0.8323
,
0.3998
,
0.1831
,
0.5938
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.8040
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
float
clip
=
0.0
f
;
// reverse, concatenation of hidden states as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// reverse, sequence lengths are the same, but less than max_seq_lens
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq_orig
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
2
,
batch_size
,
input_size
}};
std
::
vector
<
float
>
pad_data
(
pad_seq_s
.
elements
(),
0.0
f
);
auto
seq_p
=
mm
->
add_literal
(
migraphx
::
literal
{
pad_seq_s
,
pad_data
});
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
seq_orig
,
seq_p
);
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
(
batch_size
,
static_cast
<
int32_t
>
(
seq_len
));
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
sql
,
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// variable sequence lengths
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
{
3
,
2
,
1
};
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
sql
,
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.126517
,
0.0359124
,
0.107453
,
-
0.0617278
,
0.911307
,
0.11468
,
0.114449
,
0.0196755
,
-
0.102969
,
0.295872
,
0.515859
,
0.246501
,
-
0.168327
,
0.00023761
,
0.167567
,
-
0.0621982
,
0.96657
,
0.0755112
,
0.0620917
,
-
0.264845
,
0
,
0
,
0
,
0
,
-
0.204545
,
0.0146403
,
0.210057
,
0.0296268
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// reverse, 3 args, last cell output as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.443077
,
-
0.325425
,
-
0.249367
,
-
0.270812
,
0.122913
,
0.118537
,
0.0370199
,
-
0.0164687
,
-
0.00754759
,
0.141613
,
0.348002
,
0.667298
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// reverse, 3 args, 0 actv function
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
{}},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.443077
,
-
0.325425
,
-
0.249367
,
-
0.270812
,
0.122913
,
0.118537
,
0.0370199
,
-
0.0164687
,
-
0.00754759
,
0.141613
,
0.348002
,
0.667298
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
TEST_CASE
(
lstm_reverse_layout
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
1
;
std
::
vector
<
float
>
w_data
{
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.
4313
,
-
0.9730
,
-
0.2005
,
2
.3
930
,
-
0.5221
,
-
0.1331
,
-
0.
0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.112
2
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.
07
01
,
-
0.4100
,
-
2.2344
,
0.
3685
,
0.4583
,
2.3794
,
1.0
3
72
,
-
0.8887
,
0.
789
2
,
-
0.4
012
,
-
0.2818
,
-
2.3374
,
1.5310
};
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.
0910
,
1.2122
,
-
0.1952
,
0
.3
577
,
1.3508
,
-
0.5366
,
0.
4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.133
2
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.
4
01
2
,
2.3930
,
-
0.5221
,
-
0.
1331
,
-
1.0
9
72
,
0.9816
,
0.
112
2
,
-
0.4
100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.5289
,
1.0986
,
...
...
@@ -3593,14 +4487,15 @@ TEST_CASE(lstm_reverse)
-
0.5811
,
0.2166
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
float
clip
=
0.0
f
;
// reverse, concatenation of hidden states as program output
{
migraphx
::
program
p
;
...
...
@@ -3614,7 +4509,13 @@ TEST_CASE(lstm_reverse)
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
...
...
@@ -3633,18 +4534,21 @@ TEST_CASE(lstm_reverse)
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
...
...
@@ -3661,14 +4565,20 @@ TEST_CASE(lstm_reverse)
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
2
,
batch_size
,
input_size
}};
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
2
,
input_size
}};
std
::
vector
<
float
>
pad_data
(
pad_seq_s
.
elements
(),
0.0
f
);
auto
seq_p
=
mm
->
add_literal
(
migraphx
::
literal
{
pad_seq_s
,
pad_data
});
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
0
}}),
seq_orig
,
seq_p
);
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
seq_orig
,
seq_p
);
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
(
batch_size
,
static_cast
<
int32_t
>
(
seq_len
));
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
...
...
@@ -3687,22 +4597,26 @@ TEST_CASE(lstm_reverse)
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
};
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
...
...
@@ -3722,7 +4636,13 @@ TEST_CASE(lstm_reverse)
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
{
3
,
2
,
1
};
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
...
...
@@ -3741,18 +4661,22 @@ TEST_CASE(lstm_reverse)
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.126517
,
0.0359124
,
0.107453
,
-
0.0617278
,
0.911307
,
0.11468
,
0.114449
,
0.0196755
,
-
0.102969
,
0.295872
,
0.515859
,
0.246501
,
-
0.168327
,
0.00023761
,
0.167567
,
-
0.0621982
,
0.96657
,
0.0755112
,
0.0620917
,
-
0.264845
,
0
,
0
,
0
,
0
,
-
0.204545
,
0.0146403
,
0.210057
,
0.0296268
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
-
0.126517
,
0.0359124
,
0.107453
,
-
0.0617278
,
-
0.168327
,
0.00023761
,
0.167567
,
-
0.0621982
,
-
0.204545
,
0.0146403
,
0.210057
,
0.0296268
,
0
,
0
,
0
,
0
,
0.911307
,
0.11468
,
0.114449
,
0.0196755
,
0.96657
,
0.0755112
,
0.0620917
,
-
0.264845
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
-
0.102969
,
0.295872
,
0.515859
,
0.246501
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
...
...
@@ -3763,6 +4687,10 @@ TEST_CASE(lstm_reverse)
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
...
...
@@ -3777,46 +4705,8 @@ TEST_CASE(lstm_reverse)
seq
,
w
,
r
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.443077
,
-
0.325425
,
-
0.249367
,
-
0.270812
,
0.122913
,
0.118537
,
0.0370199
,
-
0.0164687
,
-
0.00754759
,
0.141613
,
0.348002
,
0.667298
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// reverse, 3 args, 0 actv function
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
{}},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
auto
cell_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
cell_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
...
...
@@ -3900,19 +4790,214 @@ TEST_CASE(lstm_reverse_actv)
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
std
::
vector
<
float
>
pph_data
{
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
float
clip
=
0.0
f
;
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
0.246078
,
0.199709
,
0.303753
,
0.301178
,
0.264634
,
0.304661
,
0.349371
,
0.288934
,
0.405483
,
0.445586
,
0.515814
,
0.473186
,
0.301937
,
0.264893
,
0.254353
,
0.269231
,
0.359258
,
0.400097
,
0.288884
,
0.247329
,
0.276519
,
0.264249
,
0.1769
,
0.23213
,
0.310306
,
0.262902
,
0.276964
,
0.295002
,
0.373802
,
0.366785
,
0.419791
,
0.393216
,
0.262827
,
0.371441
,
0.369022
,
0.298262
,
0.334143
,
0.309444
,
0.174822
,
0.251634
,
0.244564
,
0.214386
,
0.185994
,
0.226699
,
0.28445
,
0.376092
,
0.338326
,
0.259502
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// reverse, 3 args, 2 actv functions
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.132123
,
-
0.37531
,
-
0.12943
,
-
0.00798307
,
-
0.133882
,
-
0.0251383
,
0.0486486
,
-
0.0220606
,
0.292495
,
0.233866
,
0.48646
,
0.481844
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// reverse, 3 args, seq_len = 1, concatenation of hidden states as program output
{
seq_len
=
1
;
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
TEST_CASE
(
lstm_bidirectional
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
std
::
vector
<
float
>
w_data
{
0.1236
,
-
0.3942
,
0.4149
,
0.0795
,
0.4934
,
-
0.2858
,
0.2602
,
-
0.3098
,
0.0567
,
0.3344
,
0.3607
,
-
0.0551
,
0.4952
,
0.3799
,
0.0630
,
-
0.3532
,
0.0023
,
-
0.0592
,
0.4267
,
0.2382
,
-
0.0784
,
-
0.0032
,
-
0.2476
,
-
0.0206
,
-
0.4963
,
0.4837
,
0.0827
,
0.0123
,
-
0.1203
,
-
0.0279
,
-
0.0049
,
0.4721
,
-
0.3564
,
-
0.1286
,
0.4090
,
-
0.0504
,
0.0575
,
-
0.2138
,
0.1071
,
0.1976
,
-
0.0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4845
,
-
0.1496
,
0.3285
,
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
,
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
,
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.1122
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.4100
,
-
2.2344
,
0.3685
,
0.4583
,
2.3794
,
1.0372
,
-
0.8887
,
0.7892
,
-
0.4012
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
,
1.5289
,
1.0986
,
0.6091
,
1.6462
,
0.8720
,
0.5349
,
-
0.1962
,
-
1.7416
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
,
-
0.8323
,
0.3998
,
0.1831
,
0.5938
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.8040
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
,
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
...
...
@@ -3920,95 +5005,200 @@ TEST_CASE(lstm_reverse_actv)
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
float
clip
=
0.0
f
;
// concatenation of hidden states as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
r
,
bias
,
und
,
ih
,
ic
,
pph
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
0.246078
,
0.199709
,
0.303753
,
0.301178
,
0.264634
,
0.304661
,
0.349371
,
0.288934
,
0.405483
,
0.445586
,
0.515814
,
0.473186
,
0.301937
,
0.264893
,
0.254353
,
0.269231
,
0.359258
,
0.400097
,
0.288884
,
0.247329
,
0.276519
,
0.264249
,
0.1769
,
0.23213
,
0.310306
,
0.262902
,
0.276964
,
0.295002
,
0.373802
,
0.366785
,
0.419791
,
0.393216
,
0.262827
,
0.371441
,
0.369022
,
0.298262
,
0.334143
,
0.309444
,
0.174822
,
0.251634
,
0.244564
,
0.214386
,
0.185994
,
0.226699
,
0.28445
,
0.376092
,
0.338326
,
0.259502
};
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.0459032
,
0.0414126
,
0.272303
,
0.0393149
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// reverse, 3 args, 2 actv functions
// last hidden state as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// last cell output as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
// 3 args, concatenation of hidden states as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
hs
=
mm
->
add_instruction
(
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"sigmoid"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.132123
,
-
0.37531
,
-
0.12943
,
-
0.00798307
,
-
0.133882
,
-
0.0251383
,
0.0486486
,
-
0.0220606
,
0.292495
,
0.233866
,
0.48646
,
0.481844
};
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.162851
,
-
0.102647
,
-
0.113827
,
-
0.142818
,
0.0513685
,
0.0547876
,
0.0201981
,
-
0.00808453
,
-
0.00520328
,
0.0945081
,
0.264123
,
0.410805
,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.071286
,
0.074206
,
0.0124086
,
-
0.139544
,
0.108016
,
-
0.00973633
,
-
0.0552699
,
0.0252681
,
-
0.0562072
,
-
0.123496
,
-
0.153616
,
-
0.032874
,
-
0.195349
,
0.0192675
,
-
0.108636
,
0.098927
,
-
0.140733
,
0.162602
,
0.0143099
,
-
0.0455534
,
0.0151574
,
-
0.102509
,
-
0.0372696
,
0.252296
,
-
0.144544
,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.187329
,
0.0855831
,
-
0.0171894
,
-
0.140202
,
0.0828391
,
-
0.1073
,
-
0.150145
,
0.015065
,
-
0.192699
,
-
0.112764
,
-
0.120496
,
0.155754
,
0.148256
,
0.208491
,
0.348432
,
0.0291103
,
0.230275
,
-
0.165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
-
0.0458544
,
-
0.0401315
,
0.0737483
,
-
0.064505
,
0.136898
,
0.00160891
,
-
0.184812
,
0.147774
,
-
0.021205
,
-
0.125423
,
0.0206439
,
-
0.187097
,
-
0.0051453
,
-
0.0767618
,
-
0.0735348
,
-
0.0826436
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
//
reverse, 3 args, seq_len =
1, con
ca
tenation of hidden state
s
as program output
//
sequence length is
1, contenation of hidden state as program output
{
seq_len
=
1
;
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
seq_len
=
1
;
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
...
...
@@ -4017,7 +5207,7 @@ TEST_CASE(lstm_reverse_actv)
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
...
...
@@ -4027,23 +5217,16 @@ TEST_CASE(lstm_reverse_actv)
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
TEST_CASE
(
lstm_bidirectional
)
TEST_CASE
(
lstm_bidirectional
_layout
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
...
...
@@ -4087,20 +5270,20 @@ TEST_CASE(lstm_bidirectional)
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.
4313
,
-
0.9730
,
-
0.2005
,
2
.3
930
,
-
0.5221
,
-
0.1331
,
-
0.
0910
,
1.2122
,
-
0.1952
,
0.4661
,
0.6494
,
2.1332
,
-
1.0972
,
0.9816
,
0.112
2
,
0.3577
,
1.3508
,
-
0.5366
,
1.7449
,
0.5483
,
-
0.
07
01
,
-
0.4100
,
-
2.2344
,
0.
3685
,
0.4583
,
2.3794
,
1.0
3
72
,
-
0.8887
,
0.
789
2
,
-
0.4
012
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
0.5671
,
0.0458
,
0.
4514
,
-
0.
8968
,
-
0.
9201
,
0.1962
,
0.5771
,
-
0.5332
,
1.5289
,
1.0986
,
0.
6091
,
1.
6462
,
0.
87
20
,
0.
5349
,
-
0.
1962
,
-
1.7416
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
1.1055
,
-
0.1212
,
-
0.
9097
,
0.
7831
,
-
1.
6991
,
-
1.9498
,
-
1.256
7
,
-
0.
4114
,
-
0.8323
,
0.3998
,
0.1831
,
0.
5938
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.
8040
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.
0910
,
1.2122
,
-
0.1952
,
0
.3
577
,
1.3508
,
-
0.5366
,
0.
4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.133
2
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.
4
01
2
,
2.3930
,
-
0.5221
,
-
0.
1331
,
-
1.0
9
72
,
0.9816
,
0.
112
2
,
-
0.4
100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
1.5289
,
1.0986
,
0.
6091
,
1.6462
,
0.
5671
,
0.
0458
,
0.4514
,
-
0.8968
,
0.8720
,
0.5349
,
-
0.
1962
,
-
1.
7416
,
-
0.
9
20
1
,
0.
1962
,
0.
5771
,
-
0.5332
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
-
0.8323
,
0.3998
,
0.
1831
,
0.
5938
,
1.
1055
,
-
0.1212
,
-
0.909
7
,
0.
7831
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.
8040
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.
4114
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
...
...
@@ -4108,12 +5291,12 @@ TEST_CASE(lstm_bidirectional)
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
batch_size
,
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// concatenation of hidden states as program output
...
...
@@ -4128,7 +5311,13 @@ TEST_CASE(lstm_bidirectional)
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
...
...
@@ -4147,25 +5336,29 @@ TEST_CASE(lstm_bidirectional)
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.0459032
,
0.0414126
,
0.272303
,
0.0393149
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.0459032
,
0.0414126
,
0.272303
,
0.0393149
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
...
...
@@ -4181,6 +5374,12 @@ TEST_CASE(lstm_bidirectional)
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
...
...
@@ -4200,15 +5399,17 @@ TEST_CASE(lstm_bidirectional)
ih
,
ic
,
pph
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
...
...
@@ -4224,6 +5425,12 @@ TEST_CASE(lstm_bidirectional)
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
...
...
@@ -4243,15 +5450,17 @@ TEST_CASE(lstm_bidirectional)
ih
,
ic
,
pph
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
auto
cell_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
cell_output
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
...
...
@@ -4262,7 +5471,11 @@ TEST_CASE(lstm_bidirectional)
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
...
...
@@ -4276,25 +5489,28 @@ TEST_CASE(lstm_bidirectional)
seq
,
w
,
r
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.
031902
1
,
-
0.
0
02
98698
,
-
0.0623361
,
0.
0598866
,
0.
101585
,
0.0687269
,
-
0.1
61725
,
-
0.25617
,
-
0.1
62851
,
-
0.1
02647
,
-
0.
113827
,
-
0.1
42818
,
0.
0513685
,
0.054787
6
,
0.
0201981
,
-
0.00808453
,
-
0.00520328
,
0.
0945081
,
0.264123
,
0.
410805
,
-
0.0786602
,
-
0.0613048
,
0.179592
,
-
0.
071286
,
0.0
74
20
6
,
0.0124086
,
-
0.1
39544
,
0.
108016
,
-
0.00
973633
,
-
0.0552699
,
0.025268
1
,
-
0.05
62072
,
-
0.
123496
,
-
0.
15361
6
,
-
0.0
32874
,
-
0.195349
,
0.0192675
,
-
0.
10863
6
,
0.
098927
,
-
0.1
40733
,
0.162602
,
0.0143099
,
-
0.0
455534
,
0.0151574
,
-
0.
102509
,
-
0.0
372696
,
0.252
29
6
,
-
0.1
44544
,
0.00496085
,
0.0662588
,
-
0.048577
,
-
0.1
87329
,
0.0
855831
,
-
0.0
171894
,
-
0.140202
,
0.0828391
,
-
0.1073
,
-
0.
150
145
,
0.0
15065
,
-
0.
192699
,
-
0.112764
,
-
0.
120496
,
0.155754
,
0.
148
256
,
0.
208491
,
0.348432
,
0.
029110
3
,
0.
23027
5
,
-
0.
165194
,
-
0.0372928
,
0.273786
,
-
0.100877
,
-
0.
0458544
,
-
0.0
401315
,
0.0737483
,
-
0.064505
,
0.136898
,
0.0
0160
89
1
,
-
0.1
8481
2
,
0.
147774
,
-
0.
021205
,
-
0.
125423
,
0.0206439
,
-
0.187097
,
-
0.
0051453
,
-
0.0
767618
,
-
0.0735348
,
-
0.
0826436
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
-
0.
16285
1
,
-
0.
1
02
647
,
-
0.113827
,
-
0.
142818
,
-
0.
0786602
,
-
0.0613048
,
0.1
79592
,
-
0.071286
,
-
0.1
23496
,
-
0.1
53616
,
-
0.
032874
,
-
0.1
95349
,
-
0.
102509
,
-
0.037269
6
,
0.
252296
,
-
0.144544
,
-
0.1073
,
-
0.
150145
,
0.015065
,
-
0.
192699
,
-
0.165194
,
-
0.0372928
,
0.273786
,
-
0.
100877
,
-
0.0
21
20
5
,
-
0.125423
,
0.0206439
,
-
0.1
87097
,
0.
0319021
,
-
0.00
298698
,
-
0.062336
1
,
0.05
98866
,
0.
0513685
,
0.
054787
6
,
0.0
201981
,
-
0.00808453
,
0.074206
,
0.
012408
6
,
-
0.
139544
,
0.1
08016
,
0.0192675
,
-
0.108636
,
0.0
98927
,
-
0.140733
,
0.
00496085
,
0.0
662588
,
-
0.048577
,
-
0.1873
29
,
-
0.1
12764
,
-
0.120496
,
0.155754
,
0.1
48256
,
-
0.0
458544
,
-
0.0
401315
,
0.0737483
,
-
0.064505
,
-
0.
005
145
3
,
-
0.0
767618
,
-
0.0735348
,
-
0.
0826436
,
0.101585
,
0.
0687269
,
-
0.161725
,
-
0.256
17
,
-
0.
00520328
,
0.0945081
,
0.
26412
3
,
0.
41080
5
,
-
0.
00973633
,
-
0.0552699
,
0.0252681
,
-
0.0562072
,
0.
162602
,
0.0
143099
,
-
0.0455534
,
0.0151574
,
0.0855831
,
-
0.0
171
89
4
,
-
0.1
4020
2
,
0.
0828391
,
0.
208491
,
0.
348432
,
0.0291103
,
0.230275
,
0.
136898
,
0.0
0160891
,
-
0.184812
,
0.
147774
,
0.214159
,
0.262295
,
0.0247127
,
0.14472
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
...
...
@@ -4303,13 +5519,17 @@ TEST_CASE(lstm_bidirectional)
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
seq_len
=
1
;
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
seq_len
,
batch_size
,
input_size
}};
migraphx
::
shape
in_shape1
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
std
::
vector
<
float
>
input_data1
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
2.3930
,
-
0.5221
,
-
0.1331
};
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape1
,
input_data1
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
mm
->
add_instruction
(
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
...
...
@@ -4323,15 +5543,19 @@ TEST_CASE(lstm_bidirectional)
seq
,
w
,
r
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
hs_concat
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
output_data
;
hs_concat
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
-
0.0327039
,
-
0.0543852
,
0.114378
,
-
0.0768855
,
-
0.104351
,
-
0.0471426
,
-
0.0905753
,
0.01506
,
0.0319021
,
-
0.00298698
,
-
0.0623361
,
0.0598866
,
0.059797
,
0.104239
,
-
0.0266768
,
0.0727547
,
0.101585
,
0.0687269
,
-
0.161725
,
-
0.25617
,
-
0.146298
,
0.070535
,
0.327809
,
0.407388
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
}
}
...
...
@@ -4577,6 +5801,275 @@ TEST_CASE(lstm_bidirectional_var_seq_lens)
}
}
TEST_CASE
(
lstm_bidirectional_var_seq_lens_layout
)
{
std
::
size_t
batch_size
=
3
;
std
::
size_t
seq_len
=
4
;
std
::
size_t
hidden_size
=
4
;
std
::
size_t
input_size
=
3
;
std
::
size_t
num_dirct
=
2
;
std
::
vector
<
float
>
w_data
{
0.1236
,
-
0.3942
,
0.4149
,
0.0795
,
0.4934
,
-
0.2858
,
0.2602
,
-
0.3098
,
0.0567
,
0.3344
,
0.3607
,
-
0.0551
,
0.4952
,
0.3799
,
0.0630
,
-
0.3532
,
0.0023
,
-
0.0592
,
0.4267
,
0.2382
,
-
0.0784
,
-
0.0032
,
-
0.2476
,
-
0.0206
,
-
0.4963
,
0.4837
,
0.0827
,
0.0123
,
-
0.1203
,
-
0.0279
,
-
0.0049
,
0.4721
,
-
0.3564
,
-
0.1286
,
0.4090
,
-
0.0504
,
0.0575
,
-
0.2138
,
0.1071
,
0.1976
,
-
0.0758
,
0.0139
,
-
0.0761
,
0.3991
,
-
0.2965
,
-
0.4845
,
-
0.1496
,
0.3285
,
-
0.2763
,
-
0.4715
,
-
0.3010
,
-
0.2306
,
-
0.2283
,
-
0.2656
,
0.2035
,
0.3570
,
-
0.1499
,
0.4390
,
-
0.1843
,
0.2351
,
0.3357
,
0.1217
,
0.1401
,
0.3300
,
-
0.0429
,
0.3266
,
0.4834
,
-
0.3914
,
-
0.1480
,
0.3734
,
-
0.0372
,
-
0.1746
,
0.0550
,
0.4177
,
-
0.1332
,
0.4391
,
-
0.3287
,
-
0.4401
,
0.1486
,
0.1346
,
0.1048
,
-
0.4361
,
0.0886
,
-
0.3840
,
-
0.2730
,
-
0.1710
,
0.3274
,
0.0169
,
-
0.4462
,
0.0729
,
0.3983
,
-
0.0669
,
0.0756
,
0.4150
,
-
0.4684
,
-
0.2522
};
std
::
vector
<
float
>
r_data
{
0.1237
,
0.1229
,
-
0.0766
,
-
0.1144
,
-
0.1186
,
0.2922
,
0.2478
,
0.3159
,
-
0.0522
,
0.1685
,
-
0.4621
,
0.1728
,
0.0670
,
-
0.2458
,
-
0.3835
,
-
0.4589
,
-
0.3109
,
0.4908
,
-
0.0133
,
-
0.1858
,
-
0.0590
,
-
0.0347
,
-
0.2353
,
-
0.0671
,
-
0.3812
,
-
0.0004
,
-
0.1432
,
0.2406
,
0.1033
,
-
0.0265
,
-
0.3902
,
0.0755
,
0.3733
,
0.4383
,
-
0.3140
,
0.2537
,
-
0.1818
,
-
0.4127
,
0.3506
,
0.2562
,
0.2926
,
0.1620
,
-
0.4849
,
-
0.4861
,
0.4426
,
0.2106
,
-
0.0005
,
0.4418
,
-
0.2926
,
-
0.3100
,
0.1500
,
-
0.0362
,
-
0.3801
,
-
0.0065
,
-
0.0631
,
0.1277
,
0.2315
,
0.4087
,
-
0.3963
,
-
0.4161
,
-
0.2169
,
-
0.1344
,
0.3468
,
-
0.2260
,
-
0.4564
,
-
0.4432
,
0.1605
,
0.4387
,
0.0034
,
0.4116
,
0.2824
,
0.4775
,
-
0.2729
,
-
0.4707
,
0.1363
,
0.2218
,
0.0559
,
0.2828
,
0.2093
,
0.4687
,
0.3794
,
-
0.1069
,
-
0.3049
,
0.1430
,
-
0.2506
,
0.4644
,
0.2755
,
-
0.3645
,
-
0.3155
,
0.1425
,
0.2891
,
0.1786
,
-
0.3274
,
0.2365
,
0.2522
,
-
0.4312
,
-
0.0562
,
-
0.2748
,
0.0776
,
-
0.3154
,
0.2851
,
-
0.3930
,
-
0.1174
,
0.4360
,
0.2436
,
0.0164
,
-
0.0680
,
0.3403
,
-
0.2857
,
-
0.0459
,
-
0.2991
,
-
0.2624
,
0.4194
,
-
0.3291
,
-
0.4659
,
0.3300
,
0.0454
,
0.4981
,
-
0.4706
,
-
0.4584
,
0.2596
,
0.2871
,
-
0.3509
,
-
0.1910
,
0.3987
,
-
0.1687
,
-
0.0032
,
-
0.1038
};
std
::
vector
<
float
>
bias_data
{
0.0088
,
0.1183
,
0.1642
,
-
0.2631
,
-
0.1330
,
-
0.4008
,
0.3881
,
-
0.4407
,
-
0.2760
,
0.1274
,
-
0.0083
,
-
0.2885
,
0.3949
,
-
0.0182
,
0.4445
,
0.3477
,
0.2266
,
0.3423
,
-
0.0674
,
-
0.4067
,
0.0807
,
0.1109
,
-
0.2036
,
0.1782
,
-
0.2467
,
-
0.0730
,
-
0.4216
,
0.0316
,
-
0.3025
,
0.3637
,
-
0.3181
,
-
0.4655
,
-
0.0258
,
0.0073
,
-
0.4780
,
-
0.4101
,
-
0.3556
,
-
0.1017
,
0.3632
,
-
0.1823
,
0.1479
,
0.1677
,
-
0.2603
,
0.0381
,
0.1575
,
0.1896
,
0.4755
,
-
0.4794
,
0.2167
,
-
0.4474
,
-
0.3139
,
0.1018
,
0.4470
,
-
0.4232
,
0.3247
,
-
0.1636
,
-
0.1582
,
-
0.1703
,
0.3920
,
0.2055
,
-
0.4386
,
0.4208
,
0.0717
,
0.3789
};
std
::
vector
<
float
>
input_data
{
-
0.5516
,
0.2391
,
-
1.6951
,
-
0.0910
,
1.2122
,
-
0.1952
,
0.3577
,
1.3508
,
-
0.5366
,
0.4583
,
2.3794
,
1.0372
,
-
0.4313
,
-
0.9730
,
-
0.2005
,
0.4661
,
0.6494
,
2.1332
,
1.7449
,
0.5483
,
-
0.0701
,
-
0.8887
,
0.7892
,
-
0.4012
,
2.3930
,
-
0.5221
,
-
0.1331
,
-
1.0972
,
0.9816
,
0.1122
,
-
0.4100
,
-
2.2344
,
0.3685
,
-
0.2818
,
-
2.3374
,
1.5310
};
std
::
vector
<
float
>
ih_data
{
1.9104
,
-
1.9004
,
0.3337
,
0.5741
,
1.5289
,
1.0986
,
0.6091
,
1.6462
,
0.5671
,
0.0458
,
0.4514
,
-
0.8968
,
0.8720
,
0.5349
,
-
0.1962
,
-
1.7416
,
-
0.9201
,
0.1962
,
0.5771
,
-
0.5332
,
-
0.9912
,
1.2831
,
1.0896
,
-
0.6959
};
std
::
vector
<
float
>
ic_data
{
0.9569
,
-
0.5981
,
1.1312
,
1.0945
,
-
0.8323
,
0.3998
,
0.1831
,
0.5938
,
1.1055
,
-
0.1212
,
-
0.9097
,
0.7831
,
2.7096
,
-
0.1790
,
0.0022
,
-
0.8040
,
-
1.6991
,
-
1.9498
,
-
1.2567
,
-
0.4114
,
0.1578
,
0.0567
,
0.8069
,
-
0.5141
};
std
::
vector
<
float
>
pph_data
{
1.84369764
,
0.68413646
,
-
0.44892886
,
-
1.50904413
,
0.3860796
,
-
0.52186625
,
1.08474445
,
-
1.80867321
,
1.32594529
,
0.4336262
,
-
0.83699064
,
0.49162736
,
-
0.8271
,
-
0.5683
,
0.4562
,
-
1.2545
,
1.2729
,
-
0.4082
,
-
0.4392
,
-
0.9406
,
0.7794
,
1.8194
,
-
0.5811
,
0.2166
};
float
clip
=
0.0
f
;
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
// concatenation of hidden states as program output
{
std
::
vector
<
int
>
sl_data
{
1
,
2
,
3
};
migraphx
::
shape
sl_shape
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
auto
sql
=
mm
->
add_literal
(
migraphx
::
literal
{
sl_shape
,
sl_data
});
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
sql
,
ih
,
ic
,
pph
);
auto
lho
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
out_hs
);
auto
lco
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
out_hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
out_hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
out_hs
);
lho
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
lho
);
lco
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
lco
);
mm
->
add_return
({
out_hs
,
lho
,
lco
});
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
outputs
=
p
.
eval
({});
auto
arg_hs
=
outputs
.
front
();
auto
arg_lho
=
outputs
.
at
(
1
);
auto
arg_lco
=
outputs
.
at
(
2
);
std
::
vector
<
float
>
output_data
;
std
::
vector
<
float
>
last_output_data
;
std
::
vector
<
float
>
last_cell_data
;
arg_hs
.
visit
([
&
](
auto
output
)
{
output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
arg_lho
.
visit
([
&
](
auto
output
)
{
last_output_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
arg_lco
.
visit
([
&
](
auto
output
)
{
last_cell_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
output_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
-
0.141643
,
0.0451978
,
0.140804
,
0.0745128
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.911307
,
0.11468
,
0.114449
,
0.0196755
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.96657
,
0.0755112
,
0.0620917
,
-
0.264845
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.262807
,
0.275286
,
0.358395
,
0.266267
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.128254
,
0.125398
,
0.0665142
,
-
0.163651
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.0644683
,
0.371512
,
0.212431
,
-
0.116131
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
float
>
last_output_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
-
0.141643
,
0.0451978
,
0.140804
,
0.0745128
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.911307
,
0.11468
,
0.114449
,
0.0196755
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.262807
,
0.275286
,
0.358395
,
0.266267
};
std
::
vector
<
float
>
last_cell_data_gold
{
0.600582
,
-
0.601197
,
0.353558
,
0.789097
,
-
0.326822
,
0.301121
,
0.219523
,
0.415242
,
0.737121
,
0.134902
,
-
0.303595
,
0.241948
,
2.08242
,
0.442513
,
0.187127
,
0.0577626
,
0.391174
,
0.0308845
,
-
0.561745
,
0.0730323
,
-
0.611307
,
0.55454
,
0.4364
,
0.509436
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
output_data
,
output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
last_output_data
,
last_output_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
last_cell_data
,
last_cell_data_gold
));
}
// last cell output as program output
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
seq_orig
=
mm
->
add_literal
(
migraphx
::
literal
{
in_shape
,
input_data
});
auto
ih
=
mm
->
add_literal
(
migraphx
::
literal
{
ih_shape
,
ih_data
});
auto
ic
=
mm
->
add_literal
(
migraphx
::
literal
{
ic_shape
,
ic_data
});
auto
w
=
mm
->
add_literal
(
migraphx
::
literal
{
w_shape
,
w_data
});
auto
r
=
mm
->
add_literal
(
migraphx
::
literal
{
r_shape
,
r_data
});
auto
bias
=
mm
->
add_literal
(
migraphx
::
literal
{
b_shape
,
bias_data
});
auto
pph
=
mm
->
add_literal
(
migraphx
::
literal
{
pph_shape
,
pph_data
});
migraphx
::
shape
pad_seq_s
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
2
,
input_size
}};
std
::
vector
<
float
>
pad_data
(
pad_seq_s
.
elements
(),
0.0
f
);
auto
seq_p
=
mm
->
add_literal
(
migraphx
::
literal
{
pad_seq_s
,
pad_data
});
auto
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"concat"
,
{{
"axis"
,
1
}}),
seq_orig
,
seq_p
);
migraphx
::
shape
seq_len_s
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
std
::
vector
<
int32_t
>
len_data
(
batch_size
,
static_cast
<
int32_t
>
(
seq_len
));
auto
sql
=
mm
->
add_literal
(
seq_len_s
,
len_data
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
},
{
"input_forget"
,
0
}}),
seq
,
w
,
r
,
bias
,
sql
,
ih
,
ic
,
pph
);
auto
lho
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
hs
);
auto
lco
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
lho
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
lho
);
lco
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
lco
);
mm
->
add_return
({
hs
,
lho
,
lco
});
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
outputs
=
p
.
eval
({});
auto
res_hs
=
outputs
.
at
(
0
);
auto
res_lho
=
outputs
.
at
(
1
);
auto
res_lco
=
outputs
.
at
(
2
);
std
::
vector
<
float
>
hs_data
;
std
::
vector
<
float
>
lho_data
;
std
::
vector
<
float
>
lco_data
;
res_hs
.
visit
([
&
](
auto
output
)
{
hs_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_lho
.
visit
([
&
](
auto
output
)
{
lho_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_lco
.
visit
([
&
](
auto
output
)
{
lco_data
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
hs_data_gold
{
0.079753
,
-
0.289854
,
0.160043
,
0.115056
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.186991
,
-
0.0624168
,
0.205513
,
0.0836373
,
-
0.175114
,
-
0.00543549
,
0.178681
,
-
0.266999
,
0.0459033
,
0.0414126
,
0.272303
,
0.0393149
,
-
0.182201
,
-
0.0232277
,
0.235501
,
-
0.213485
,
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
-
0.185038
,
-
0.026845
,
0.177273
,
-
0.0774616
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.294074
,
-
0.0319677
,
-
0.0955337
,
0.104168
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.421857
,
0.0459771
,
-
0.144955
,
0.0720673
,
0.928866
,
0.113685
,
0.220626
,
-
0.0432316
,
0.218258
,
0.0944405
,
0.0431211
,
-
0.132394
,
0.960938
,
0.133565
,
0.269741
,
0.130438
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.946669
,
0.0868676
,
0.044508
,
-
0.373961
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0.022618
,
-
0.121195
,
-
0.4065
,
-
0.252054
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
,
-
0.0300906
,
-
0.0890598
,
-
0.135266
,
-
0.0413375
,
-
0.063456
,
0.148524
,
0.05108
,
-
0.0234895
,
0.103489
,
0.0142918
,
-
0.123408
,
0.0401075
,
-
0.0252804
,
0.267356
,
0.146353
,
0.0789186
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.0681467
,
0.382748
,
0.230211
,
-
0.161537
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
};
std
::
vector
<
float
>
lho_data_gold
{
-
0.058052
,
0.0795391
,
0.266617
,
-
0.0128746
,
-
0.120174
,
0.043157
,
0.117138
,
-
0.222188
,
0.0309878
,
0.0971544
,
0.149294
,
-
0.0492549
,
0.789732
,
0.128538
,
0.20909
,
0.0553812
,
0.187761
,
0.0501726
,
-
0.121584
,
0.0606723
,
-
0.224905
,
0.32421
,
0.344048
,
0.271694
};
std
::
vector
<
float
>
lco_data_gold
{
-
0.077353
,
0.245616
,
0.361023
,
-
0.0443759
,
-
0.889294
,
0.182463
,
0.186512
,
-
0.402334
,
0.0685243
,
0.20465
,
0.277867
,
-
0.112934
,
1.48161
,
0.524116
,
0.347113
,
0.181813
,
0.67312
,
0.120508
,
-
0.726968
,
0.113845
,
-
0.434265
,
0.747833
,
0.416053
,
0.558713
};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
hs_data
,
hs_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
lho_data
,
lho_data_gold
));
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
lco_data
,
lco_data_gold
));
}
}
TEST_CASE
(
lstm_bidirectional_actv_func
)
{
std
::
size_t
batch_size
=
3
;
...
...
test/simplify_qdq_test.cpp
View file @
7bffbe43
...
...
@@ -44,20 +44,34 @@ void run_pass(migraphx::module& m)
sqdq
.
apply
(
m
);
}
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
broadcast_scale
(
migraphx
::
module
&
m
,
migraphx
::
instruction_ref
scale
,
migraphx
::
instruction_ref
shift
)
const
std
::
vector
<
std
::
size_t
>&
out_lens
,
std
::
size_t
axis
)
{
auto
lens
=
x
->
get_shape
().
lens
();
if
(
scale
->
get_shape
().
lens
()
==
out_lens
)
return
scale
;
migraphx
::
instruction_ref
scale_mb
;
if
(
scale
->
get_shape
().
lens
().
front
()
==
1
)
auto
scale_lens
=
scale
->
get_shape
().
lens
();
if
(
scale_lens
.
front
()
==
1
and
scale_lens
.
size
()
==
1
)
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
scale
);
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_
lens
}}),
scale
);
else
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}}),
scale
);
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
axis
},
{
"out_lens"
,
out_lens
}}),
scale
);
return
scale_mb
;
}
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
scale
,
migraphx
::
instruction_ref
shift
,
std
::
size_t
q_axis
=
1
)
{
auto
lens
=
x
->
get_shape
().
lens
();
auto
scale_mb
=
broadcast_scale
(
m
,
scale
,
lens
,
q_axis
);
auto
shift_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
shift
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
name
),
x
,
scale_mb
,
shift_mb
);
...
...
@@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m,
migraphx
::
instruction_ref
add_quantize_op
(
migraphx
::
module
&
m
,
const
std
::
string
&
name
,
migraphx
::
instruction_ref
x
,
migraphx
::
instruction_ref
scale
)
migraphx
::
instruction_ref
scale
,
std
::
size_t
q_axis
=
1
)
{
auto
lens
=
x
->
get_shape
().
lens
();
migraphx
::
instruction_ref
scale_mb
;
if
(
scale
->
get_shape
().
lens
().
front
()
==
1
)
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
lens
}}),
scale
);
else
scale_mb
=
m
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
lens
}}),
scale
);
auto
lens
=
x
->
get_shape
().
lens
();
auto
scale_mb
=
broadcast_scale
(
m
,
scale
,
lens
,
q_axis
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
name
),
x
,
scale_mb
);
}
migraphx
::
instruction_ref
add_scale_mul
(
migraphx
::
module
&
m
,
migraphx
::
instruction_ref
scale1
,
migraphx
::
instruction_ref
scale2
,
std
::
size_t
axis1
,
std
::
size_t
axis2
,
const
std
::
vector
<
std
::
size_t
>&
out_lens
)
{
auto
scale1_mb
=
broadcast_scale
(
m
,
scale1
,
out_lens
,
axis1
);
auto
scale2_mb
=
broadcast_scale
(
m
,
scale2
,
out_lens
,
axis2
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale1_mb
,
scale2_mb
);
}
TEST_CASE
(
remove_qdq
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
100
,
100
}};
...
...
@@ -159,18 +180,62 @@ TEST_CASE(dot)
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_multi_scale
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1000
,
1024
}};
migraphx
::
shape
sh3
{
migraphx
::
shape
::
float_type
,
{
1280
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m1
.
add_literal
(
0.4
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
0
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale1
,
zero
,
0
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
1
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale2
,
zero
,
1
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
scale1
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m2
.
add_literal
(
0.4
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale1
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
0
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
1
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale1
,
scale2
,
0
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
...
...
@@ -178,6 +243,180 @@ TEST_CASE(dot)
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_broadcasted
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
2
,
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1000
,
1024
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
d2_mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
1000
,
1024
}}}),
d2
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2_mb
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
q2_mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
1000
,
1024
}}}),
q2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2_mb
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_transposed
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1024
,
1000
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale
,
zero
);
auto
d2_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d2
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2_t
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale
,
zero
);
auto
q2_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2_t
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_multi_scale_transposed_broadcasted
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1024
,
1000
}};
migraphx
::
shape
sh3
{
migraphx
::
shape
::
float_type
,
{
1280
}};
migraphx
::
shape
sh4
{
migraphx
::
shape
::
float_type
,
{
1024
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh4
,
0
));
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
2
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale1
,
zero
,
2
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
0
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale2
,
zero
,
0
);
auto
d2_t
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
d2
);
auto
d2_mb
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
1000
,
1024
}}}),
d2_t
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2_mb
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
sh4
,
0
));
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
2
);
auto
q2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
0
);
auto
q2_t
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
q2
);
auto
q2_mb
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
3
,
1000
,
1024
}}}),
q2_t
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2_mb
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale1
,
scale2
,
2
,
3
,
dot
->
get_shape
().
lens
());
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale
);
m2
.
add_return
({
d3
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_multi_scale_unsupported_axis
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
migraphx
::
shape
sh2
{
migraphx
::
shape
::
float_type
,
{
1000
,
1024
}};
migraphx
::
shape
sh3
{
migraphx
::
shape
::
float_type
,
{
1000
}};
migraphx
::
module
m1
;
{
auto
t1
=
m1
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m1
.
add_parameter
(
"t2"
,
sh2
);
auto
scale1
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
sh3
,
0
));
auto
scale2
=
m1
.
add_literal
(
0.4
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t1
,
scale1
,
zero
,
1
);
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
scale1
,
zero
,
1
);
auto
q2
=
add_quantize_op
(
m1
,
"quantizelinear"
,
t2
,
scale2
,
zero
,
1
);
auto
d2
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q2
,
scale2
,
zero
,
1
);
auto
dot
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
d1
,
d2
);
m1
.
add_return
({
dot
});
}
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
t1
,
t2
);
m2
.
add_return
({
dot
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
dot_non_zero_point
)
{
migraphx
::
shape
sh1
{
migraphx
::
shape
::
float_type
,
{
1280
,
1000
}};
...
...
@@ -269,18 +508,18 @@ TEST_CASE(dot_add)
migraphx
::
module
m2
;
{
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
ab
=
m2
.
add_parameter
(
"ab"
,
sh3
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q
1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t
1
,
scale
,
zero
);
auto
q2
=
add_
quantiz
e_op
(
m2
,
"quant
izelinear"
,
t2
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q1
,
q2
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale
1
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d3
,
ab
);
auto
t1
=
m2
.
add_parameter
(
"t1"
,
sh1
);
auto
t2
=
m2
.
add_parameter
(
"t2"
,
sh2
);
auto
ab
=
m2
.
add_parameter
(
"ab"
,
sh3
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t1
,
scale
,
zero
);
auto
q
2
=
add_quantize_op
(
m2
,
"quantizelinear"
,
t
2
,
scale
,
zero
);
auto
dot
=
m2
.
add_
instruction
(
migraphx
::
mak
e_op
(
"quant
_dot"
),
q1
,
q2
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
dot
->
get_shape
().
lens
()
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_
scale
);
auto
add
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d3
,
ab
);
m2
.
add_return
({
add
});
}
...
...
@@ -320,26 +559,80 @@ TEST_CASE(conv)
auto
weights
=
m2
.
add_parameter
(
"weights"
,
s4
);
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
weights
);
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale
);
m2
.
add_return
({
d6
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
conv_multi_scale
)
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
migraphx
::
shape
s7
{
migraphx
::
shape
::
float_type
,
{
1
,
320
,
7
,
7
}};
migraphx
::
shape
s8
{
migraphx
::
shape
::
float_type
,
{
1280
}};
migraphx
::
module
m1
;
{
auto
input
=
m1
.
add_parameter
(
"input"
,
s7
);
auto
weights
=
m1
.
add_parameter
(
"weights"
,
s4
);
auto
w_scale
=
m1
.
add_literal
(
migraphx
::
generate_literal
(
s8
,
0
));
auto
inp_scale
=
m1
.
add_literal
(
0.5
f
);
auto
zero
=
m1
.
add_literal
(
std
::
int8_t
{
0
});
auto
d1
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
weights
,
w_scale
,
zero
,
0
);
auto
q1
=
add_quantize_op
(
m1
,
"quantizelinear"
,
input
,
inp_scale
,
zero
);
auto
d5
=
add_quantize_op
(
m1
,
"dequantizelinear"
,
q1
,
inp_scale
,
zero
);
auto
c1
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
d5
,
d1
);
m1
.
add_return
({
c1
});
}
migraphx
::
module
m2
;
{
auto
input
=
m2
.
add_parameter
(
"input"
,
s7
);
auto
weights
=
m2
.
add_parameter
(
"weights"
,
s4
);
auto
w_scale
=
m2
.
add_literal
(
migraphx
::
generate_literal
(
s8
,
0
));
auto
inp_scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
q_inp
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
inp_scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q_inp
,
weights
);
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
m2
.
add_return
({
d6
});
auto
out_scale
=
add_scale_mul
(
m2
,
inp_scale
,
w_scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d1
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale
);
m2
.
add_return
({
d1
});
}
run_pass
(
m1
);
EXPECT
(
m1
==
m2
);
}
TEST_CASE
(
conv_multi_scale
)
TEST_CASE
(
conv_multi_scale
_unsupported_axis
)
{
migraphx
::
shape
s4
{
migraphx
::
shape
::
int8_type
,
{
1280
,
320
,
1
,
1
}};
migraphx
::
shape
s7
{
migraphx
::
shape
::
float_type
,
{
1
,
320
,
7
,
7
}};
...
...
@@ -430,20 +723,20 @@ TEST_CASE(conv_bias_add)
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero32
=
m2
.
add_literal
(
std
::
int32_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
weights
);
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
auto
b1
=
m2
.
add_instruction
(
auto
out_scale
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d6
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale
);
auto
b1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
1280
,
7
,
7
}}}),
d2
);
auto
a1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d6
,
b1
);
m2
.
add_return
({
a1
});
...
...
@@ -519,22 +812,21 @@ TEST_CASE(conv_pooling_dot)
auto
scale
=
m2
.
add_literal
(
0.5
f
);
auto
zero
=
m2
.
add_literal
(
std
::
int8_t
{
0
});
auto
zero32
=
m2
.
add_literal
(
std
::
int32_t
{
0
});
auto
scale1
=
m2
.
add_literal
(
0.25
f
);
auto
scale2
=
m2
.
add_literal
(
0.25
f
);
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
ab
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
auto
d2
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
bias
,
scale
,
zero32
);
auto
d3
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
ab
,
scale
,
zero
);
auto
q1
=
add_quantize_op
(
m2
,
"quantizelinear"
,
input
,
scale
,
zero
);
auto
c1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_convolution"
,
{{
"padding"
,
{
0
,
0
,
0
,
0
}},
{
"stride"
,
{
1
,
1
}},
{
"dilation"
,
{
1
,
1
}},
{
"group"
,
1
},
{
"padding_mode"
,
0
}}),
q1
,
weights
);
auto
d5
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
scale1
);
auto
bc1
=
m2
.
add_instruction
(
auto
out_scale1
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
1
,
c1
->
get_shape
().
lens
());
auto
d5
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
c1
,
out_scale1
);
auto
bc1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"broadcast"
,
{{
"axis"
,
1
},
{
"out_lens"
,
{
1
,
1280
,
7
,
7
}}}),
d2
);
auto
a1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d5
,
bc1
);
auto
ap
=
...
...
@@ -545,10 +837,11 @@ TEST_CASE(conv_pooling_dot)
{
"lengths"
,
{
7
,
7
}},
{
"ceil_mode"
,
0
}}),
a1
);
auto
fl
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
1
}}),
ap
);
auto
q4
=
add_quantize_op
(
m2
,
"quantizelinear"
,
fl
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q4
,
db
);
auto
d9
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
scale2
);
auto
fl
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"flatten"
,
{{
"axis"
,
1
}}),
ap
);
auto
q4
=
add_quantize_op
(
m2
,
"quantizelinear"
,
fl
,
scale
,
zero
);
auto
dot
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"quant_dot"
),
q4
,
db
);
auto
out_scale2
=
add_scale_mul
(
m2
,
scale
,
scale
,
1
,
0
,
dot
->
get_shape
().
lens
());
auto
d9
=
add_quantize_op
(
m2
,
"dequantizelinear"
,
dot
,
out_scale2
);
auto
mb1
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
1000
}}}),
d3
);
auto
a2
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
d9
,
mb1
);
...
...
test/verify/test_lstm_bidirct_3args_layout.cpp
0 → 100644
View file @
7bffbe43
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_bidirct_3args_layout
:
verify_program
<
test_lstm_bidirct_3args_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
2
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
test/verify/test_lstm_bidirct_last_layout.cpp
0 → 100644
View file @
7bffbe43
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_bidirct_last_layout
:
verify_program
<
test_lstm_bidirct_last_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
2
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
b_shape
);
auto
ih
=
mm
->
add_parameter
(
"ih"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"ic"
,
ic_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
bidirectional
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
output
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
test/verify/test_lstm_forward_hs_layout.cpp
0 → 100644
View file @
7bffbe43
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_forward_hs_layout
:
verify_program
<
test_lstm_forward_hs_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
1
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
b_shape
);
auto
ih
=
mm
->
add_parameter
(
"ih"
,
ih_shape
);
auto
ic
=
mm
->
add_parameter
(
"ic"
,
ic_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
auto
und
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"undefined"
));
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
,
bias
,
und
,
ih
,
ic
,
pph
);
std
::
vector
<
int64_t
>
perm_hid
{
2
,
0
,
1
,
3
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm_hid
}}),
hs
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
test/verify/test_lstm_forward_last_layout.cpp
0 → 100644
View file @
7bffbe43
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_forward_last_layout
:
verify_program
<
test_lstm_forward_last_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
1
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
8
*
hidden_size
}};
migraphx
::
shape
ih_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
l_shape
{
migraphx
::
shape
::
int32_type
,
{
batch_size
}};
migraphx
::
shape
ic_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
num_dirct
,
hidden_size
}};
migraphx
::
shape
pph_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
3
*
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
auto
bias
=
mm
->
add_parameter
(
"bias"
,
b_shape
);
auto
ih
=
mm
->
add_parameter
(
"ih"
,
ih_shape
);
auto
len
=
mm
->
add_literal
(
migraphx
::
literal
(
l_shape
,
{
1
,
2
}));
auto
ic
=
mm
->
add_parameter
(
"ic"
,
ic_shape
);
auto
pph
=
mm
->
add_parameter
(
"pph"
,
pph_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
ih
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ih
);
ic
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
ic
);
auto
output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
forward
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
,
bias
,
len
,
ih
,
ic
,
pph
);
auto
last_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_hs_output"
),
output
,
len
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
last_output
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
test/verify/test_lstm_reverse_3args_cell_output_layout.cpp
0 → 100644
View file @
7bffbe43
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/common.hpp>
struct
test_lstm_reverse_3args_cell_layout
:
verify_program
<
test_lstm_reverse_3args_cell_layout
>
{
migraphx
::
program
create_program
()
const
{
std
::
size_t
batch_size
=
2
;
std
::
size_t
seq_len
=
3
;
std
::
size_t
hidden_size
=
5
;
std
::
size_t
input_size
=
8
;
std
::
size_t
num_dirct
=
1
;
float
clip
=
0.0
f
;
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
in_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
seq_len
,
input_size
}};
migraphx
::
shape
w_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
input_size
}};
migraphx
::
shape
r_shape
{
migraphx
::
shape
::
float_type
,
{
num_dirct
,
4
*
hidden_size
,
hidden_size
}};
auto
seq
=
mm
->
add_parameter
(
"seq"
,
in_shape
);
auto
w
=
mm
->
add_parameter
(
"w"
,
w_shape
);
auto
r
=
mm
->
add_parameter
(
"r"
,
r_shape
);
std
::
vector
<
int64_t
>
perm
{
1
,
0
,
2
};
seq
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
seq
);
auto
hs
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"lstm"
,
{{
"hidden_size"
,
hidden_size
},
{
"actv_func"
,
migraphx
::
to_value
(
std
::
vector
<
migraphx
::
operation
>
{
migraphx
::
make_op
(
"sigmoid"
),
migraphx
::
make_op
(
"tanh"
),
migraphx
::
make_op
(
"tanh"
)})},
{
"direction"
,
migraphx
::
to_value
(
migraphx
::
op
::
rnn_direction
::
reverse
)},
{
"clip"
,
clip
}}),
seq
,
w
,
r
);
auto
cell_output
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"rnn_last_cell_output"
),
hs
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
cell_output
);
return
p
;
}
std
::
string
section
()
const
{
return
"rnn"
;
}
};
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment