Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
16fc0314
Commit
16fc0314
authored
Mar 07, 2019
by
Khalique
Browse files
Merge branch 'develop' of
https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
into multibcast_check
parents
39d4398f
3499ec7d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
156 additions
and
253 deletions
+156
-253
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+0
-3
src/targets/cpu/target.cpp
src/targets/cpu/target.cpp
+3
-2
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+148
-243
test/include/test.hpp
test/include/test.hpp
+5
-5
No files found.
src/rewrite_rnn.cpp
View file @
16fc0314
...
...
@@ -987,15 +987,12 @@ std::vector<instruction_ref> rewrite_rnn::lstm_cell(bool is_forward,
auto
spph
=
prog
.
insert_instruction
(
ins
,
op
::
squeeze
{{
0
}},
pph
);
auto
pphi
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
0
},
{
hs
}},
spph
);
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphi
);
pphi_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
pphi_brcst
);
auto
ppho
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
hs
},
{
2
*
hs
}},
spph
);
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
ppho
);
ppho_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
ppho_brcst
);
auto
pphf
=
prog
.
insert_instruction
(
ins
,
op
::
slice
{{
0
},
{
2
*
hs
},
{
3
*
hs
}},
spph
);
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
broadcast
{
1
,
ic_shape
},
pphf
);
pphf_brcst
=
prog
.
insert_instruction
(
ins
,
op
::
contiguous
{},
pphf_brcst
);
}
for
(
long
i
=
0
;
i
<
seq_len
;
++
i
)
...
...
src/targets/cpu/target.cpp
View file @
16fc0314
...
...
@@ -14,8 +14,9 @@ std::string target::name() const { return "cpu"; }
std
::
vector
<
pass
>
target
::
get_passes
(
migraphx
::
context
&
)
const
{
return
{
auto_contiguous
{},
rewrite_rnn
{},
return
{
rewrite_rnn
{},
dead_code_elimination
{},
auto_contiguous
{},
dead_code_elimination
{},
lowering
{},
dead_code_elimination
{}};
...
...
test/gpu/miopen.cpp
View file @
16fc0314
...
...
@@ -16,7 +16,7 @@
#include <future>
#include <thread>
#include
"
test.hpp
"
#include
<
test.hpp
>
#ifdef __clang__
#pragma clang diagnostic push
...
...
@@ -134,7 +134,7 @@ migraphx::argument run_gpu(migraphx::program& p)
}
template
<
class
V
>
void
verify_program
()
void
run_
verify_program
()
{
auto_print
::
set_terminate_handler
(
migraphx
::
get_type_name
<
V
>
());
// std::cout << migraphx::get_type_name<V>() << std::endl;
...
...
@@ -156,7 +156,27 @@ void verify_program()
std
::
set_terminate
(
nullptr
);
}
struct
test_literals
template
<
class
T
>
int
auto_register_verify_program
()
{
test
::
add_test_case
(
migraphx
::
get_type_name
<
T
>
(),
[]
{
run_verify_program
<
T
>
();
});
return
0
;
}
template
<
class
T
>
struct
verify_program
{
static
int
static_register
;
// This typedef ensures that the static member will be instantiated if
// the class itself is instantiated
using
static_register_type
=
std
::
integral_constant
<
decltype
(
&
static_register
),
&
static_register
>
;
};
template
<
class
T
>
int
verify_program
<
T
>::
static_register
=
auto_register_verify_program
<
T
>
();
// NOLINT
struct
test_literals
:
verify_program
<
test_literals
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -171,7 +191,7 @@ struct test_literals
}
};
struct
test_add
struct
test_add
:
verify_program
<
test_add
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -184,7 +204,7 @@ struct test_add
}
};
struct
test_add_half
struct
test_add_half
:
verify_program
<
test_add_half
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -197,7 +217,7 @@ struct test_add_half
}
};
struct
test_mul
struct
test_mul
:
verify_program
<
test_mul
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -210,7 +230,7 @@ struct test_mul
}
};
struct
test_exp
struct
test_exp
:
verify_program
<
test_exp
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -223,7 +243,7 @@ struct test_exp
}
};
struct
test_log
struct
test_log
:
verify_program
<
test_log
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -236,7 +256,7 @@ struct test_log
}
};
struct
test_sin
struct
test_sin
:
verify_program
<
test_sin
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -248,7 +268,7 @@ struct test_sin
}
};
struct
test_cos
struct
test_cos
:
verify_program
<
test_cos
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -260,7 +280,7 @@ struct test_cos
}
};
struct
test_tan
struct
test_tan
:
verify_program
<
test_tan
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -272,7 +292,7 @@ struct test_tan
}
};
struct
test_sinh
struct
test_sinh
:
verify_program
<
test_sinh
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -284,7 +304,7 @@ struct test_sinh
}
};
struct
test_cosh
struct
test_cosh
:
verify_program
<
test_cosh
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -296,7 +316,7 @@ struct test_cosh
}
};
struct
test_tanh
struct
test_tanh
:
verify_program
<
test_tanh
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -307,7 +327,7 @@ struct test_tanh
}
};
struct
test_asin
struct
test_asin
:
verify_program
<
test_asin
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -319,7 +339,7 @@ struct test_asin
}
};
struct
test_acos
struct
test_acos
:
verify_program
<
test_acos
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -331,7 +351,7 @@ struct test_acos
}
};
struct
test_atan
struct
test_atan
:
verify_program
<
test_atan
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -343,7 +363,7 @@ struct test_atan
}
};
struct
test_scale
struct
test_scale
:
verify_program
<
test_scale
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -357,7 +377,7 @@ struct test_scale
}
};
struct
test_slice
struct
test_slice
:
verify_program
<
test_slice
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -372,7 +392,7 @@ struct test_slice
}
};
struct
test_triadd
struct
test_triadd
:
verify_program
<
test_triadd
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -387,7 +407,7 @@ struct test_triadd
}
};
struct
test_triadd2
struct
test_triadd2
:
verify_program
<
test_triadd2
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -404,7 +424,7 @@ struct test_triadd2
}
};
struct
test_add_broadcast
struct
test_add_broadcast
:
verify_program
<
test_add_broadcast
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -418,7 +438,7 @@ struct test_add_broadcast
}
};
struct
test_add_broadcast2
struct
test_add_broadcast2
:
verify_program
<
test_add_broadcast2
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -432,7 +452,7 @@ struct test_add_broadcast2
}
};
struct
test_add_broadcast3
struct
test_add_broadcast3
:
verify_program
<
test_add_broadcast3
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -446,7 +466,7 @@ struct test_add_broadcast3
}
};
struct
test_add_broadcast4
struct
test_add_broadcast4
:
verify_program
<
test_add_broadcast4
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -460,7 +480,7 @@ struct test_add_broadcast4
}
};
struct
test_add_broadcast5
struct
test_add_broadcast5
:
verify_program
<
test_add_broadcast5
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -474,7 +494,7 @@ struct test_add_broadcast5
}
};
struct
test_triadd_broadcast
struct
test_triadd_broadcast
:
verify_program
<
test_triadd_broadcast
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -490,7 +510,7 @@ struct test_triadd_broadcast
}
};
struct
test_sub
struct
test_sub
:
verify_program
<
test_sub
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -505,7 +525,7 @@ struct test_sub
}
};
struct
test_sub2
struct
test_sub2
:
verify_program
<
test_sub2
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -522,7 +542,7 @@ struct test_sub2
}
};
struct
test_softmax
struct
test_softmax
:
verify_program
<
test_softmax
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -533,7 +553,7 @@ struct test_softmax
}
};
struct
test_softmax2
struct
test_softmax2
:
verify_program
<
test_softmax2
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -545,7 +565,7 @@ struct test_softmax2
}
};
struct
test_conv
struct
test_conv
:
verify_program
<
test_conv
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -559,7 +579,7 @@ struct test_conv
}
};
struct
test_conv2
struct
test_conv2
:
verify_program
<
test_conv2
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -573,7 +593,7 @@ struct test_conv2
}
};
struct
test_group_conv
struct
test_group_conv
:
verify_program
<
test_group_conv
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -589,7 +609,7 @@ struct test_group_conv
}
};
struct
test_conv_relu
struct
test_conv_relu
:
verify_program
<
test_conv_relu
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -604,7 +624,7 @@ struct test_conv_relu
}
};
struct
test_conv_relu_half
struct
test_conv_relu_half
:
verify_program
<
test_conv_relu_half
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -619,7 +639,7 @@ struct test_conv_relu_half
}
};
struct
test_add_relu
struct
test_add_relu
:
verify_program
<
test_add_relu
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -632,7 +652,7 @@ struct test_add_relu
}
};
struct
test_sigmoid
struct
test_sigmoid
:
verify_program
<
test_sigmoid
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -643,7 +663,7 @@ struct test_sigmoid
}
};
struct
test_abs
struct
test_abs
:
verify_program
<
test_abs
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -654,7 +674,7 @@ struct test_abs
}
};
struct
test_leaky_relu
struct
test_leaky_relu
:
verify_program
<
test_leaky_relu
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -665,7 +685,7 @@ struct test_leaky_relu
}
};
struct
test_elu
struct
test_elu
:
verify_program
<
test_elu
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -676,7 +696,7 @@ struct test_elu
}
};
struct
test_relu_lrn
struct
test_relu_lrn
:
verify_program
<
test_relu_lrn
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -688,7 +708,7 @@ struct test_relu_lrn
}
};
struct
test_conv_pooling
struct
test_conv_pooling
:
verify_program
<
test_conv_pooling
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -704,7 +724,7 @@ struct test_conv_pooling
}
};
struct
test_global_avg_pooling
struct
test_global_avg_pooling
:
verify_program
<
test_global_avg_pooling
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -719,7 +739,7 @@ struct test_global_avg_pooling
}
};
struct
test_global_max_pooling
struct
test_global_max_pooling
:
verify_program
<
test_global_max_pooling
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -734,7 +754,7 @@ struct test_global_max_pooling
}
};
struct
test_gemm
struct
test_gemm
:
verify_program
<
test_gemm
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -746,7 +766,7 @@ struct test_gemm
}
};
struct
test_gemm_ex
struct
test_gemm_ex
:
verify_program
<
test_gemm_ex
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -758,7 +778,7 @@ struct test_gemm_ex
}
};
struct
test_gemm_half
struct
test_gemm_half
:
verify_program
<
test_gemm_half
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -770,7 +790,7 @@ struct test_gemm_half
}
};
struct
test_gemm_ld
struct
test_gemm_ld
//: verify_program<test_gemm_ld>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -784,7 +804,7 @@ struct test_gemm_ld
}
};
struct
test_gemm_transposeb
struct
test_gemm_transposeb
:
verify_program
<
test_gemm_transposeb
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -797,7 +817,7 @@ struct test_gemm_transposeb
}
};
struct
test_gemm_transposeb_ex
struct
test_gemm_transposeb_ex
:
verify_program
<
test_gemm_transposeb_ex
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -810,7 +830,7 @@ struct test_gemm_transposeb_ex
}
};
struct
test_gemm_transposea
struct
test_gemm_transposea
:
verify_program
<
test_gemm_transposea
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -823,7 +843,7 @@ struct test_gemm_transposea
}
};
struct
test_gemm_transposea_ex
struct
test_gemm_transposea_ex
:
verify_program
<
test_gemm_transposea_ex
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -836,7 +856,7 @@ struct test_gemm_transposea_ex
}
};
struct
test_gemm_transposeab
struct
test_gemm_transposeab
:
verify_program
<
test_gemm_transposeab
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -882,7 +902,7 @@ struct gemm_mutli_dim_2_3
}
};
struct
test_contiguous
struct
test_contiguous
:
verify_program
<
test_contiguous
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -895,7 +915,7 @@ struct test_contiguous
}
};
struct
test_eliminate_contiguous
struct
test_eliminate_contiguous
:
verify_program
<
test_eliminate_contiguous
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -911,7 +931,7 @@ struct test_eliminate_contiguous
}
};
struct
test_transpose
struct
test_transpose
:
verify_program
<
test_transpose
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -925,7 +945,7 @@ struct test_transpose
}
};
struct
test_batchnorm_inference_2
struct
test_batchnorm_inference_2
:
verify_program
<
test_batchnorm_inference_2
>
{
const
size_t
width
=
14
;
const
size_t
height
=
14
;
...
...
@@ -948,7 +968,7 @@ struct test_batchnorm_inference_2
}
};
struct
test_batchnorm_inference
struct
test_batchnorm_inference
:
verify_program
<
test_batchnorm_inference
>
{
const
size_t
width
=
3
;
const
size_t
height
=
3
;
...
...
@@ -971,7 +991,7 @@ struct test_batchnorm_inference
}
};
struct
test_conv_bn
struct
test_conv_bn
:
verify_program
<
test_conv_bn
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -992,7 +1012,7 @@ struct test_conv_bn
}
};
struct
test_conv_bn_relu_pooling
struct
test_conv_bn_relu_pooling
:
verify_program
<
test_conv_bn_relu_pooling
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1016,7 +1036,7 @@ struct test_conv_bn_relu_pooling
}
};
struct
test_concat
struct
test_concat
:
verify_program
<
test_concat
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1033,7 +1053,7 @@ struct test_concat
}
};
struct
test_concat2
struct
test_concat2
:
verify_program
<
test_concat2
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1050,7 +1070,7 @@ struct test_concat2
}
};
struct
test_concat_relu
struct
test_concat_relu
:
verify_program
<
test_concat_relu
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1071,7 +1091,7 @@ struct test_concat_relu
}
};
struct
test_pad
struct
test_pad
:
verify_program
<
test_pad
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1090,7 +1110,7 @@ struct test_pad
}
};
struct
test_pooling_autopad
struct
test_pooling_autopad
:
verify_program
<
test_pooling_autopad
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1106,7 +1126,7 @@ struct test_pooling_autopad
}
};
struct
test_gather
struct
test_gather
:
verify_program
<
test_gather
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1122,7 +1142,7 @@ struct test_gather
}
};
struct
test_gather_neg_axis
struct
test_gather_neg_axis
:
verify_program
<
test_gather_neg_axis
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1138,7 +1158,7 @@ struct test_gather_neg_axis
}
};
struct
test_gather_scalar_output
struct
test_gather_scalar_output
:
verify_program
<
test_gather_scalar_output
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1154,7 +1174,7 @@ struct test_gather_scalar_output
}
};
struct
test_gather_scalar_index
struct
test_gather_scalar_index
:
verify_program
<
test_gather_scalar_index
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1170,7 +1190,7 @@ struct test_gather_scalar_index
}
};
struct
test_gather_1d_index
struct
test_gather_1d_index
:
verify_program
<
test_gather_1d_index
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1232,7 +1252,7 @@ void manual_test_concat_relu()
std
::
cout
<<
result
<<
std
::
endl
;
}
struct
test_conv_bn_relu_pooling2
struct
test_conv_bn_relu_pooling2
:
verify_program
<
test_conv_bn_relu_pooling2
>
{
static
migraphx
::
instruction_ref
add_bn
(
migraphx
::
program
&
p
,
migraphx
::
instruction_ref
x
,
std
::
size_t
channels
)
...
...
@@ -1269,7 +1289,7 @@ struct test_conv_bn_relu_pooling2
}
};
struct
test_rnn_forward
struct
test_rnn_forward
:
verify_program
<
test_rnn_forward
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1311,7 +1331,7 @@ struct test_rnn_forward
}
};
struct
test_rnn_forward10
struct
test_rnn_forward10
:
verify_program
<
test_rnn_forward10
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1353,7 +1373,7 @@ struct test_rnn_forward10
}
};
struct
test_rnn_reverse
struct
test_rnn_reverse
:
verify_program
<
test_rnn_reverse
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1393,7 +1413,7 @@ struct test_rnn_reverse
}
};
struct
test_rnn_reverse2
struct
test_rnn_reverse2
:
verify_program
<
test_rnn_reverse2
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1433,7 +1453,7 @@ struct test_rnn_reverse2
}
};
struct
test_rnn_3args
struct
test_rnn_3args
:
verify_program
<
test_rnn_3args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1465,7 +1485,7 @@ struct test_rnn_3args
}
};
struct
test_rnn_4args
struct
test_rnn_4args
:
verify_program
<
test_rnn_4args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1500,7 +1520,7 @@ struct test_rnn_4args
}
};
struct
test_rnn_5args
struct
test_rnn_5args
:
verify_program
<
test_rnn_5args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1539,7 +1559,7 @@ struct test_rnn_5args
}
};
struct
test_rnn_bidirectional
struct
test_rnn_bidirectional
:
verify_program
<
test_rnn_bidirectional
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1581,7 +1601,7 @@ struct test_rnn_bidirectional
}
};
struct
test_rnn_bidirectional10
struct
test_rnn_bidirectional10
:
verify_program
<
test_rnn_bidirectional10
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1622,7 +1642,7 @@ struct test_rnn_bidirectional10
}
};
struct
test_rnn_bi_3args
struct
test_rnn_bi_3args
:
verify_program
<
test_rnn_bi_3args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1657,7 +1677,7 @@ struct test_rnn_bi_3args
}
};
struct
test_gru_forward_last
struct
test_gru_forward_last
:
verify_program
<
test_gru_forward_last
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1701,7 +1721,7 @@ struct test_gru_forward_last
}
};
struct
test_gru_forward_hs
struct
test_gru_forward_hs
:
verify_program
<
test_gru_forward_hs
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1743,7 +1763,7 @@ struct test_gru_forward_hs
}
};
struct
test_gru_forward_3args_und
struct
test_gru_forward_3args_und
:
verify_program
<
test_gru_forward_3args_und
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1779,7 +1799,7 @@ struct test_gru_forward_3args_und
}
};
struct
test_gru_forward_3args
struct
test_gru_forward_3args
:
verify_program
<
test_gru_forward_3args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1811,7 +1831,7 @@ struct test_gru_forward_3args
}
};
struct
test_gru_forward_seq1
struct
test_gru_forward_seq1
:
verify_program
<
test_gru_forward_seq1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1843,7 +1863,7 @@ struct test_gru_forward_seq1
}
};
struct
test_gru_forward_default_actv
struct
test_gru_forward_default_actv
:
verify_program
<
test_gru_forward_default_actv
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1873,7 +1893,7 @@ struct test_gru_forward_default_actv
}
};
struct
test_gru_forward_default_actv1
struct
test_gru_forward_default_actv1
:
verify_program
<
test_gru_forward_default_actv1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1914,7 +1934,7 @@ struct test_gru_forward_default_actv1
}
};
struct
test_gru_reverse_last
struct
test_gru_reverse_last
:
verify_program
<
test_gru_reverse_last
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1958,7 +1978,7 @@ struct test_gru_reverse_last
}
};
struct
test_gru_reverse_3args
struct
test_gru_reverse_3args
:
verify_program
<
test_gru_reverse_3args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -1990,7 +2010,7 @@ struct test_gru_reverse_3args
}
};
struct
test_gru_bidirct_last
struct
test_gru_bidirct_last
:
verify_program
<
test_gru_bidirct_last
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2034,7 +2054,7 @@ struct test_gru_bidirct_last
}
};
struct
test_gru_bidirct_hs
struct
test_gru_bidirct_hs
:
verify_program
<
test_gru_bidirct_hs
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2076,7 +2096,7 @@ struct test_gru_bidirct_hs
}
};
struct
test_gru_bidirct_3args_und
struct
test_gru_bidirct_3args_und
:
verify_program
<
test_gru_bidirct_3args_und
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2112,7 +2132,7 @@ struct test_gru_bidirct_3args_und
}
};
struct
test_gru_bidirct_3args
struct
test_gru_bidirct_3args
:
verify_program
<
test_gru_bidirct_3args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2144,7 +2164,7 @@ struct test_gru_bidirct_3args
}
};
struct
test_gru_bidirct_seq1
struct
test_gru_bidirct_seq1
:
verify_program
<
test_gru_bidirct_seq1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2176,7 +2196,7 @@ struct test_gru_bidirct_seq1
}
};
struct
test_gru_bidirct_default_actv
struct
test_gru_bidirct_default_actv
:
verify_program
<
test_gru_bidirct_default_actv
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2206,7 +2226,7 @@ struct test_gru_bidirct_default_actv
}
};
struct
test_gru_bidirct_default_actv1
struct
test_gru_bidirct_default_actv1
:
verify_program
<
test_gru_bidirct_default_actv1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2248,7 +2268,7 @@ struct test_gru_bidirct_default_actv1
}
};
struct
test_lstm_forward_last
struct
test_lstm_forward_last
:
verify_program
<
test_lstm_forward_last
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2298,7 +2318,7 @@ struct test_lstm_forward_last
}
};
struct
test_lstm_forward_hs
struct
test_lstm_forward_hs
:
verify_program
<
test_lstm_forward_hs
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2348,7 +2368,7 @@ struct test_lstm_forward_hs
}
};
struct
test_lstm_forward_3args_und
struct
test_lstm_forward_3args_und
:
verify_program
<
test_lstm_forward_3args_und
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2388,7 +2408,7 @@ struct test_lstm_forward_3args_und
}
};
struct
test_lstm_forward_3args
struct
test_lstm_forward_3args
:
verify_program
<
test_lstm_forward_3args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2422,7 +2442,7 @@ struct test_lstm_forward_3args
}
};
struct
test_lstm_forward_seq1
struct
test_lstm_forward_seq1
:
verify_program
<
test_lstm_forward_seq1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2456,7 +2476,7 @@ struct test_lstm_forward_seq1
}
};
struct
test_lstm_forward_default_actv
struct
test_lstm_forward_default_actv
:
verify_program
<
test_lstm_forward_default_actv
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2486,7 +2506,7 @@ struct test_lstm_forward_default_actv
}
};
struct
test_lstm_forward_default_actv1
struct
test_lstm_forward_default_actv1
:
verify_program
<
test_lstm_forward_default_actv1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2527,7 +2547,7 @@ struct test_lstm_forward_default_actv1
}
};
struct
test_lstm_reverse_last
struct
test_lstm_reverse_last
:
verify_program
<
test_lstm_reverse_last
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2578,7 +2598,7 @@ struct test_lstm_reverse_last
}
};
struct
test_lstm_reverse_3args
struct
test_lstm_reverse_3args
:
verify_program
<
test_lstm_reverse_3args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2612,7 +2632,7 @@ struct test_lstm_reverse_3args
}
};
struct
test_lstm_reverse_3args_cell_output
struct
test_lstm_reverse_3args_cell_output
:
verify_program
<
test_lstm_reverse_3args_cell_output
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2647,7 +2667,7 @@ struct test_lstm_reverse_3args_cell_output
}
};
struct
test_lstm_bidirct_last
struct
test_lstm_bidirct_last
:
verify_program
<
test_lstm_bidirct_last
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2698,7 +2718,7 @@ struct test_lstm_bidirct_last
}
};
struct
test_lstm_bidirct_hs
struct
test_lstm_bidirct_hs
:
verify_program
<
test_lstm_bidirct_hs
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2740,7 +2760,7 @@ struct test_lstm_bidirct_hs
}
};
struct
test_lstm_bidirct_3args_und
struct
test_lstm_bidirct_3args_und
:
verify_program
<
test_lstm_bidirct_3args_und
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2779,7 +2799,7 @@ struct test_lstm_bidirct_3args_und
}
};
struct
test_lstm_bidirct_3args
struct
test_lstm_bidirct_3args
:
verify_program
<
test_lstm_bidirct_3args
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2811,7 +2831,7 @@ struct test_lstm_bidirct_3args
}
};
struct
test_lstm_bidirct_seq1
struct
test_lstm_bidirct_seq1
:
verify_program
<
test_lstm_bidirct_seq1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2843,7 +2863,7 @@ struct test_lstm_bidirct_seq1
}
};
struct
test_lstm_bidirct_default_actv
struct
test_lstm_bidirct_default_actv
:
verify_program
<
test_lstm_bidirct_default_actv
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2873,7 +2893,7 @@ struct test_lstm_bidirct_default_actv
}
};
struct
test_lstm_bidirct_default_actv1
struct
test_lstm_bidirct_default_actv1
:
verify_program
<
test_lstm_bidirct_default_actv1
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2915,7 +2935,7 @@ struct test_lstm_bidirct_default_actv1
}
};
struct
test_lstm_bidirct_default_actv2
struct
test_lstm_bidirct_default_actv2
:
verify_program
<
test_lstm_bidirct_default_actv2
>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2958,7 +2978,7 @@ struct test_lstm_bidirct_default_actv2
};
template
<
int
Axis
>
struct
test_logsoftmax
struct
test_logsoftmax
:
verify_program
<
test_logsoftmax
<
Axis
>>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2971,8 +2991,14 @@ struct test_logsoftmax
}
};
template
struct
test_logsoftmax
<
0
>;
template
struct
test_logsoftmax
<
1
>;
template
struct
test_logsoftmax
<
2
>;
template
struct
test_logsoftmax
<
3
>;
template
struct
test_logsoftmax
<
4
>;
template
<
int
Axis
>
struct
test_logsoftmax_1
struct
test_logsoftmax_1
:
verify_program
<
test_logsoftmax_1
<
Axis
>>
{
migraphx
::
program
create_program
()
const
{
...
...
@@ -2985,128 +3011,7 @@ struct test_logsoftmax_1
}
};
int
main
()
{
verify_program
<
test_relu_lrn
>
();
verify_program
<
test_pooling_autopad
>
();
verify_program
<
test_abs
>
();
verify_program
<
test_concat
>
();
verify_program
<
test_concat2
>
();
verify_program
<
test_concat_relu
>
();
verify_program
<
test_pad
>
();
verify_program
<
test_add
>
();
verify_program
<
test_add_half
>
();
verify_program
<
test_mul
>
();
verify_program
<
test_exp
>
();
verify_program
<
test_log
>
();
verify_program
<
test_sin
>
();
verify_program
<
test_cos
>
();
verify_program
<
test_tan
>
();
verify_program
<
test_sinh
>
();
verify_program
<
test_cosh
>
();
verify_program
<
test_tanh
>
();
verify_program
<
test_asin
>
();
verify_program
<
test_acos
>
();
verify_program
<
test_atan
>
();
verify_program
<
test_scale
>
();
verify_program
<
test_triadd
>
();
verify_program
<
test_triadd2
>
();
verify_program
<
test_add_broadcast
>
();
verify_program
<
test_add_broadcast2
>
();
verify_program
<
test_add_broadcast3
>
();
verify_program
<
test_add_broadcast4
>
();
verify_program
<
test_add_broadcast5
>
();
verify_program
<
test_triadd_broadcast
>
();
verify_program
<
test_sub
>
();
verify_program
<
test_sub2
>
();
verify_program
<
test_softmax
>
();
verify_program
<
test_softmax2
>
();
verify_program
<
test_conv
>
();
verify_program
<
test_conv2
>
();
verify_program
<
test_group_conv
>
();
verify_program
<
test_conv_relu
>
();
verify_program
<
test_conv_relu_half
>
();
verify_program
<
test_add_relu
>
();
verify_program
<
test_leaky_relu
>
();
verify_program
<
test_sigmoid
>
();
verify_program
<
test_elu
>
();
verify_program
<
test_conv_pooling
>
();
verify_program
<
test_global_avg_pooling
>
();
verify_program
<
test_global_max_pooling
>
();
verify_program
<
test_gemm
>
();
verify_program
<
test_gemm_ex
>
();
verify_program
<
test_gemm_half
>
();
// verify_program<test_gemm_ld>();
verify_program
<
test_gemm_transposeb
>
();
verify_program
<
test_gemm_transposeb_ex
>
();
verify_program
<
test_gemm_transposea
>
();
verify_program
<
test_gemm_transposea_ex
>
();
verify_program
<
test_gemm_transposeab
>
();
verify_program
<
gemm_mutli_dim_2
>
();
verify_program
<
gemm_mutli_dim_2_3
>
();
verify_program
<
test_contiguous
>
();
verify_program
<
test_eliminate_contiguous
>
();
verify_program
<
test_transpose
>
();
verify_program
<
test_batchnorm_inference
>
();
verify_program
<
test_batchnorm_inference_2
>
();
verify_program
<
test_conv_bn
>
();
verify_program
<
test_conv_bn_relu_pooling
>
();
verify_program
<
test_conv_bn_relu_pooling2
>
();
verify_program
<
test_slice
>
();
verify_program
<
test_gather
>
();
verify_program
<
test_gather_neg_axis
>
();
verify_program
<
test_gather_scalar_output
>
();
verify_program
<
test_gather_scalar_index
>
();
verify_program
<
test_gather_1d_index
>
();
verify_program
<
test_rnn_forward
>
();
verify_program
<
test_rnn_forward10
>
();
verify_program
<
test_rnn_reverse
>
();
verify_program
<
test_rnn_reverse2
>
();
verify_program
<
test_rnn_3args
>
();
verify_program
<
test_rnn_4args
>
();
verify_program
<
test_rnn_5args
>
();
verify_program
<
test_rnn_bidirectional
>
();
verify_program
<
test_rnn_bidirectional10
>
();
verify_program
<
test_rnn_bi_3args
>
();
verify_program
<
test_gru_forward_last
>
();
verify_program
<
test_gru_forward_hs
>
();
verify_program
<
test_gru_forward_3args_und
>
();
verify_program
<
test_gru_forward_3args
>
();
verify_program
<
test_gru_forward_seq1
>
();
verify_program
<
test_gru_forward_default_actv
>
();
verify_program
<
test_gru_forward_default_actv1
>
();
verify_program
<
test_gru_reverse_last
>
();
verify_program
<
test_gru_reverse_3args
>
();
verify_program
<
test_gru_bidirct_last
>
();
verify_program
<
test_gru_bidirct_hs
>
();
verify_program
<
test_gru_bidirct_3args_und
>
();
verify_program
<
test_gru_bidirct_3args
>
();
verify_program
<
test_gru_bidirct_seq1
>
();
verify_program
<
test_gru_bidirct_default_actv
>
();
verify_program
<
test_gru_bidirct_default_actv1
>
();
verify_program
<
test_lstm_forward_last
>
();
verify_program
<
test_lstm_forward_hs
>
();
verify_program
<
test_lstm_forward_3args_und
>
();
verify_program
<
test_lstm_forward_3args
>
();
verify_program
<
test_lstm_forward_seq1
>
();
verify_program
<
test_lstm_forward_default_actv
>
();
verify_program
<
test_lstm_forward_default_actv1
>
();
verify_program
<
test_lstm_reverse_last
>
();
verify_program
<
test_lstm_reverse_3args
>
();
verify_program
<
test_lstm_reverse_3args_cell_output
>
();
verify_program
<
test_lstm_bidirct_last
>
();
verify_program
<
test_lstm_bidirct_hs
>
();
verify_program
<
test_lstm_bidirct_3args_und
>
();
verify_program
<
test_lstm_bidirct_3args
>
();
verify_program
<
test_lstm_bidirct_seq1
>
();
verify_program
<
test_lstm_bidirct_default_actv
>
();
verify_program
<
test_lstm_bidirct_default_actv1
>
();
verify_program
<
test_lstm_bidirct_default_actv2
>
();
verify_program
<
test_logsoftmax
<
0
>>
();
verify_program
<
test_logsoftmax
<
1
>>
();
verify_program
<
test_logsoftmax
<
2
>>
();
verify_program
<
test_logsoftmax
<
3
>>
();
verify_program
<
test_logsoftmax
<
4
>>
();
verify_program
<
test_logsoftmax_1
<
0
>>
();
verify_program
<
test_logsoftmax_1
<
1
>>
();
}
template
struct
test_logsoftmax_1
<
0
>;
template
struct
test_logsoftmax_1
<
1
>;
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/include/test.hpp
View file @
16fc0314
...
...
@@ -192,10 +192,10 @@ inline void add_test_case(std::string name, std::function<void()> f)
get_test_cases
().
emplace_back
(
std
::
move
(
name
),
std
::
move
(
f
));
}
struct
auto_register
struct
auto_register
_test_case
{
template
<
class
F
>
auto_register
(
const
char
*
name
,
F
f
)
noexcept
auto_register
_test_case
(
const
char
*
name
,
F
f
)
noexcept
{
add_test_case
(
name
,
f
);
}
...
...
@@ -259,8 +259,8 @@ inline void run(int argc, const char* argv[])
// NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__);
static test::auto_register
_test_case
TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register
_test_case
(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE
#define TEST_CASE(...) \
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment