Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
8cae675e
Commit
8cae675e
authored
Nov 07, 2018
by
Khalique
Browse files
Merge branch 'master' of
https://github.com/ROCmSoftwarePlatform/MIGraph
into transpose
parents
0643952e
414e2fac
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
278 additions
and
381 deletions
+278
-381
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+1
-0
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+5
-0
test/auto_contiguous_test.cpp
test/auto_contiguous_test.cpp
+7
-14
test/common_subexpression_elimination_test.cpp
test/common_subexpression_elimination_test.cpp
+5
-11
test/constant_propagate_test.cpp
test/constant_propagate_test.cpp
+4
-9
test/cpu_ops_test.cpp
test/cpu_ops_test.cpp
+37
-72
test/dead_code_elimination_test.cpp
test/dead_code_elimination_test.cpp
+7
-15
test/eliminate_allocation_test.cpp
test/eliminate_allocation_test.cpp
+6
-9
test/eliminate_concat_test.cpp
test/eliminate_concat_test.cpp
+3
-7
test/eliminate_contiguous_test.cpp
test/eliminate_contiguous_test.cpp
+3
-7
test/eval_test.cpp
test/eval_test.cpp
+13
-26
test/fwd_conv_batchnorm_rewrite_test.cpp
test/fwd_conv_batchnorm_rewrite_test.cpp
+2
-6
test/include/test.hpp
test/include/test.hpp
+91
-4
test/literal_test.cpp
test/literal_test.cpp
+5
-10
test/matcher.cpp
test/matcher.cpp
+29
-68
test/memory_coloring_test.cpp
test/memory_coloring_test.cpp
+40
-82
test/op_shape_test.cpp
test/op_shape_test.cpp
+8
-17
test/operation.cpp
test/operation.cpp
+6
-13
test/output_alias.cpp
test/output_alias.cpp
+4
-9
test/program_test.cpp
test/program_test.cpp
+2
-2
No files found.
src/include/migraph/operators.hpp
View file @
8cae675e
...
...
@@ -613,6 +613,7 @@ struct identity
{
return
{
std
::
move
(
output_shape
),
std
::
move
(
args
.
at
(
0
).
data
)};
}
int
output_alias
(
const
std
::
vector
<
shape
>&
)
const
{
return
0
;
}
};
struct
abs
:
unary
...
...
src/targets/gpu/fuse_ops.cpp
View file @
8cae675e
...
...
@@ -155,6 +155,7 @@ struct hip_triadd
device
::
add
(
ctx
.
get_stream
().
get
(),
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
return
args
.
at
(
3
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
struct
hip_triadd_relu
...
...
@@ -170,6 +171,7 @@ struct hip_triadd_relu
device
::
add_relu
(
ctx
.
get_stream
().
get
(),
args
.
at
(
3
),
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
return
args
.
at
(
3
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
struct
hip_add_relu
...
...
@@ -185,6 +187,7 @@ struct hip_add_relu
device
::
add_relu
(
ctx
.
get_stream
().
get
(),
args
.
at
(
2
),
args
.
at
(
0
),
args
.
at
(
1
));
return
args
.
at
(
2
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
struct
find_add_relu
...
...
@@ -271,6 +274,7 @@ struct miopen_conv_bias
f
.
compile
(
ctx
);
return
f
.
get_workspace
(
ctx
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
struct
miopen_conv_bias_relu
...
...
@@ -314,6 +318,7 @@ struct miopen_conv_bias_relu
f
.
compile
(
ctx
);
return
f
.
get_workspace
(
ctx
);
}
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
};
template
<
class
...
Ms
>
...
...
test/auto_contiguous_test.cpp
View file @
8cae675e
...
...
@@ -14,6 +14,7 @@ struct contiguous_target
migraph
::
context
get_context
()
const
{
return
{};
}
};
// TODO: Add this test case
void
literal_broadcast
()
{
migraph
::
program
p
;
...
...
@@ -25,7 +26,7 @@ void literal_broadcast()
EXPECT
(
not
p
.
get_shape
().
broadcasted
());
}
void
literal_transpose
(
)
TEST_CASE
(
literal_transpose
)
{
migraph
::
program
p
;
p
.
add_literal
(
get_2x2_transposed
());
...
...
@@ -36,7 +37,7 @@ void literal_transpose()
EXPECT
(
not
p
.
get_shape
().
transposed
());
}
void
after_literal_transpose
(
)
TEST_CASE
(
after_literal_transpose
)
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
...
...
@@ -51,7 +52,7 @@ void after_literal_transpose()
EXPECT
(
not
p
.
get_shape
().
transposed
());
}
void
after_literal_broadcast
(
)
TEST_CASE
(
after_literal_broadcast
)
{
migraph
::
program
p
;
auto
l1
=
p
.
add_literal
(
get_2x2
());
...
...
@@ -67,7 +68,7 @@ void after_literal_broadcast()
EXPECT
(
not
p
.
get_shape
().
broadcasted
());
}
void
after_param_transpose
(
)
TEST_CASE
(
after_param_transpose
)
{
migraph
::
program
p
;
auto
l
=
p
.
add_parameter
(
"2x2"
,
{
migraph
::
shape
::
float_type
,
{
2
,
2
}});
...
...
@@ -82,7 +83,7 @@ void after_param_transpose()
EXPECT
(
not
p
.
get_shape
().
transposed
());
}
void
after_param_broadcast
(
)
TEST_CASE
(
after_param_broadcast
)
{
migraph
::
program
p
;
auto
l1
=
p
.
add_parameter
(
"2x2"
,
{
migraph
::
shape
::
float_type
,
{
2
,
2
}});
...
...
@@ -98,12 +99,4 @@ void after_param_broadcast()
EXPECT
(
not
p
.
get_shape
().
broadcasted
());
}
int
main
()
{
// literal_broadcast();
literal_transpose
();
after_literal_transpose
();
after_literal_broadcast
();
after_param_transpose
();
after_param_broadcast
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/common_subexpression_elimination_test.cpp
View file @
8cae675e
...
...
@@ -14,7 +14,7 @@ struct cse_target
migraph
::
context
get_context
()
const
{
return
{};
}
};
void
cse_test1
(
)
TEST_CASE
(
cse_test1
)
{
migraph
::
program
p1
;
{
...
...
@@ -38,7 +38,7 @@ void cse_test1()
EXPECT
(
p1
==
p2
);
}
void
cse_test2
(
)
TEST_CASE
(
cse_test2
)
{
migraph
::
program
p1
;
{
...
...
@@ -63,7 +63,7 @@ void cse_test2()
EXPECT
(
p1
==
p2
);
}
void
cse_test3
(
)
TEST_CASE
(
cse_test3
)
{
migraph
::
program
p1
;
{
...
...
@@ -86,7 +86,7 @@ void cse_test3()
EXPECT
(
p1
==
p2
);
}
void
cse_test4
(
)
TEST_CASE
(
cse_test4
)
{
migraph
::
program
p1
;
{
...
...
@@ -112,10 +112,4 @@ void cse_test4()
EXPECT
(
p1
==
p2
);
}
int
main
()
{
cse_test1
();
cse_test2
();
cse_test3
();
cse_test4
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/constant_propagate_test.cpp
View file @
8cae675e
...
...
@@ -14,7 +14,7 @@ struct const_prop_target
migraph
::
context
get_context
()
const
{
return
{};
}
};
void
const_add1
(
)
TEST_CASE
(
const_add1
)
{
migraph
::
program
p1
;
auto
one
=
p1
.
add_literal
(
1
);
...
...
@@ -29,7 +29,7 @@ void const_add1()
EXPECT
(
p1
==
p2
);
}
void
const_add2
(
)
TEST_CASE
(
const_add2
)
{
migraph
::
program
p1
;
auto
one
=
p1
.
add_parameter
(
"one"
,
{
migraph
::
shape
::
int32_type
,
{
1
}});
...
...
@@ -44,7 +44,7 @@ void const_add2()
EXPECT
(
p1
!=
p2
);
}
void
const_add3
(
)
TEST_CASE
(
const_add3
)
{
migraph
::
program
p1
;
auto
one
=
p1
.
add_literal
(
1
);
...
...
@@ -60,9 +60,4 @@ void const_add3()
EXPECT
(
p1
==
p2
);
}
int
main
()
{
const_add1
();
const_add2
();
const_add3
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/cpu_ops_test.cpp
View file @
8cae675e
...
...
@@ -7,7 +7,7 @@
#include <migraph/verify.hpp>
#include "test.hpp"
void
slice_test
(
)
TEST_CASE
(
slice_test
)
{
{
migraph
::
program
p
;
...
...
@@ -47,7 +47,7 @@ void slice_test()
}
}
void
concat_test
(
)
TEST_CASE
(
concat_test
)
{
{
migraph
::
program
p
;
...
...
@@ -97,7 +97,7 @@ void concat_test()
}
}
void
squeeze_test
(
)
TEST_CASE
(
squeeze_test
)
{
{
migraph
::
program
p
;
...
...
@@ -134,7 +134,7 @@ void squeeze_test()
}
}
void
unsqueeze_test
(
)
TEST_CASE
(
unsqueeze_test
)
{
{
migraph
::
program
p
;
...
...
@@ -160,7 +160,7 @@ void unsqueeze_test()
}
}
void
globalavgpool_test
(
)
TEST_CASE
(
globalavgpool_test
)
{
migraph
::
program
p
;
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
}};
...
...
@@ -180,7 +180,7 @@ void globalavgpool_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
globalmaxpool_test
(
)
TEST_CASE
(
globalmaxpool_test
)
{
migraph
::
program
p
;
auto
s
=
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
}};
...
...
@@ -200,7 +200,7 @@ void globalmaxpool_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
im2col_3x3_no_pad_identity_test
(
)
TEST_CASE
(
im2col_3x3_no_pad_identity_test
)
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
3
,
3
};
...
...
@@ -229,7 +229,7 @@ void im2col_3x3_no_pad_identity_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
input
));
}
void
im2col_3x3_no_pad_test
(
)
TEST_CASE
(
im2col_3x3_no_pad_test
)
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
4
,
4
};
...
...
@@ -261,7 +261,7 @@ void im2col_3x3_no_pad_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
correct
));
}
void
im2col_3x3_stride_2_no_pad_test
(
)
TEST_CASE
(
im2col_3x3_stride_2_no_pad_test
)
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
6
,
6
};
...
...
@@ -294,7 +294,7 @@ void im2col_3x3_stride_2_no_pad_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
correct
));
}
void
im2col_3x3_with_padding_test
(
)
TEST_CASE
(
im2col_3x3_with_padding_test
)
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
2
,
2
};
...
...
@@ -326,7 +326,7 @@ void im2col_3x3_with_padding_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
correct
));
}
void
batch_norm_inference_test
(
)
TEST_CASE
(
batch_norm_inference_test
)
{
migraph
::
program
p
;
const
size_t
width
=
2
,
height
=
2
,
channels
=
4
,
batches
=
2
;
...
...
@@ -366,7 +366,7 @@ void batch_norm_inference_test()
EXPECT
(
migraph
::
verify_range
(
result_vector
,
gold
));
}
void
im2col_3x3_with_channels_identity_test
(
)
TEST_CASE
(
im2col_3x3_with_channels_identity_test
)
{
std
::
size_t
f
[
2
]
=
{
3
,
3
};
std
::
size_t
size
[
2
]
=
{
3
,
3
};
...
...
@@ -395,7 +395,7 @@ void im2col_3x3_with_channels_identity_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
input
));
}
void
exp_test
(
)
TEST_CASE
(
exp_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -409,7 +409,7 @@ void exp_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
sin_test
(
)
TEST_CASE
(
sin_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -423,7 +423,7 @@ void sin_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
cos_test
(
)
TEST_CASE
(
cos_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -437,7 +437,7 @@ void cos_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
tan_test
(
)
TEST_CASE
(
tan_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -451,7 +451,7 @@ void tan_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
add_test
(
)
TEST_CASE
(
add_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -466,7 +466,7 @@ void add_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
broadcast_test
(
)
TEST_CASE
(
broadcast_test
)
{
migraph
::
program
p
;
migraph
::
shape
a_shape
{
migraph
::
shape
::
int32_type
,
{
2
,
2
}};
...
...
@@ -485,7 +485,7 @@ void broadcast_test()
EXPECT
(
output
(
1
,
0
)
==
-
3
);
EXPECT
(
output
(
1
,
1
)
==
-
3
);
}
void
add_broadcast_test
(
)
TEST_CASE
(
add_broadcast_test
)
{
migraph
::
program
p
;
migraph
::
shape
a_shape
{
migraph
::
shape
::
float_type
,
{
2
,
2
,
3
}};
...
...
@@ -506,7 +506,7 @@ void add_broadcast_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
sub_test
(
)
TEST_CASE
(
sub_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -521,7 +521,7 @@ void sub_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
mul_test
(
)
TEST_CASE
(
mul_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -536,7 +536,7 @@ void mul_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
div_test
(
)
TEST_CASE
(
div_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -551,7 +551,7 @@ void div_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
relu_test
(
)
TEST_CASE
(
relu_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -565,7 +565,7 @@ void relu_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
leaky_relu_test
(
)
TEST_CASE
(
leaky_relu_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
...
...
@@ -579,7 +579,7 @@ void leaky_relu_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
imagescaler_test
(
)
TEST_CASE
(
imagescaler_test
)
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
}};
...
...
@@ -626,7 +626,7 @@ void imagescaler_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
void
reshape_test
(
)
TEST_CASE
(
reshape_test
)
{
migraph
::
shape
a_shape
{
migraph
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
std
::
vector
<
float
>
data
(
24
);
...
...
@@ -716,8 +716,10 @@ void gemm_test()
EXPECT
(
std
::
abs
(
results_vector
[
i
]
-
c
[
i
])
<
tol
);
}
}
TEST_CASE_REGISTER
(
gemm_test
<
float
>
)
TEST_CASE_REGISTER
(
gemm_test
<
double
>
)
void
maxpool_test
(
)
TEST_CASE
(
maxpool_test
)
{
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
...
...
@@ -763,7 +765,7 @@ void maxpool_test()
p
.
add_instruction
(
migraph
::
op
::
pooling
{
"max"
,
{{
0
,
0
}},
{{
2
,
2
}},
{{
3
,
2
}}},
al
);
p
.
compile
(
migraph
::
cpu
::
target
{});
auto
result
=
p
.
eval
({});
std
::
cout
<<
result
.
get_shape
()
<<
std
::
endl
;
//
std::cout << result.get_shape() << std::endl;
std
::
vector
<
float
>
results_vector
(
36
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
float
tol
=
1e-6
;
...
...
@@ -774,7 +776,7 @@ void maxpool_test()
}
}
void
softmax_test
(
)
TEST_CASE
(
softmax_test
)
{
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
...
...
@@ -833,7 +835,7 @@ void softmax_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
s
));
}
void
conv2d_test
(
)
TEST_CASE
(
conv2d_test
)
{
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
...
...
@@ -896,7 +898,7 @@ void conv2d_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
s
));
}
void
conv2d_padding_test
(
)
TEST_CASE
(
conv2d_padding_test
)
{
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
...
...
@@ -952,7 +954,7 @@ void conv2d_padding_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
s
));
}
void
conv2d_padding_stride_test
(
)
TEST_CASE
(
conv2d_padding_stride_test
)
{
migraph
::
program
p
;
std
::
vector
<
float
>
a
=
{
...
...
@@ -1013,7 +1015,7 @@ void conv2d_padding_stride_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
s
));
}
void
transpose_test
(
)
TEST_CASE
(
transpose_test
)
{
migraph
::
shape
a_shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
,
2
,
3
}};
std
::
vector
<
float
>
data
(
12
);
...
...
@@ -1048,7 +1050,7 @@ void transpose_test()
}
}
void
contiguous_test
(
)
TEST_CASE
(
contiguous_test
)
{
migraph
::
shape
a_shape
{
migraph
::
shape
::
float_type
,
{
1
,
3
,
2
,
2
},
{
12
,
1
,
6
,
3
}};
std
::
vector
<
float
>
data
(
12
);
...
...
@@ -1068,41 +1070,4 @@ void contiguous_test()
EXPECT
(
migraph
::
verify_range
(
results_vector
,
gold
));
}
int
main
()
{
concat_test
();
slice_test
();
squeeze_test
();
unsqueeze_test
();
exp_test
();
sin_test
();
cos_test
();
tan_test
();
add_test
();
broadcast_test
();
add_broadcast_test
();
imagescaler_test
();
sub_test
();
mul_test
();
div_test
();
relu_test
();
leaky_relu_test
();
gemm_test
<
float
>
();
gemm_test
<
double
>
();
reshape_test
();
transpose_test
();
// contiguous_test();
softmax_test
();
// maxpool_test();
conv2d_test
();
conv2d_padding_test
();
conv2d_padding_stride_test
();
batch_norm_inference_test
();
globalavgpool_test
();
globalmaxpool_test
();
im2col_3x3_no_pad_identity_test
();
im2col_3x3_no_pad_test
();
im2col_3x3_stride_2_no_pad_test
();
im2col_3x3_with_channels_identity_test
();
im2col_3x3_with_padding_test
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/dead_code_elimination_test.cpp
View file @
8cae675e
...
...
@@ -12,7 +12,7 @@ struct dce_target
migraph
::
context
get_context
()
const
{
return
{};
}
};
void
simple_test
(
)
TEST_CASE
(
simple_test
)
{
migraph
::
program
p
;
...
...
@@ -27,7 +27,7 @@ void simple_test()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
simple_test_nop
(
)
TEST_CASE
(
simple_test_nop
)
{
migraph
::
program
p
;
...
...
@@ -43,7 +43,7 @@ void simple_test_nop()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
simple_test_nop2
(
)
TEST_CASE
(
simple_test_nop2
)
{
migraph
::
program
p
;
...
...
@@ -59,7 +59,7 @@ void simple_test_nop2()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
duplicate_test1
(
)
TEST_CASE
(
duplicate_test1
)
{
migraph
::
program
p
;
...
...
@@ -75,7 +75,7 @@ void duplicate_test1()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
duplicate_test2
(
)
TEST_CASE
(
duplicate_test2
)
{
migraph
::
program
p
;
...
...
@@ -92,7 +92,7 @@ void duplicate_test2()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
depth_test
(
)
TEST_CASE
(
depth_test
)
{
migraph
::
program
p
;
...
...
@@ -111,12 +111,4 @@ void depth_test()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
int
main
()
{
simple_test
();
simple_test_nop
();
simple_test_nop2
();
duplicate_test1
();
duplicate_test2
();
depth_test
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/eliminate_allocation_test.cpp
View file @
8cae675e
...
...
@@ -32,7 +32,7 @@ struct allocate
}
};
void
basic
(
)
TEST_CASE
(
basic
)
{
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
8
}}});
...
...
@@ -49,7 +49,7 @@ void basic()
EXPECT
(
p
.
get_parameter_shape
(
"memory"
).
bytes
()
==
(
8
*
4
+
40
*
4
+
200
*
4
));
}
void
aligned
(
)
TEST_CASE
(
aligned
)
{
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
}}});
...
...
@@ -66,7 +66,7 @@ void aligned()
EXPECT
(
p
.
get_parameter_shape
(
"memory"
).
bytes
()
==
(
32
+
32
+
200
*
4
));
}
void
unaligned
(
)
TEST_CASE
(
unaligned
)
{
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
}}});
...
...
@@ -83,7 +83,7 @@ void unaligned()
EXPECT
(
p
.
get_parameter_shape
(
"memory"
).
bytes
()
==
(
1
*
4
+
2
*
4
+
200
*
4
));
}
void
float_aligned
(
)
TEST_CASE
(
float_aligned
)
{
migraph
::
program
p
;
auto
a1
=
p
.
add_instruction
(
allocate
{
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
}}});
...
...
@@ -100,11 +100,8 @@ void float_aligned()
EXPECT
(
p
.
get_parameter_shape
(
"memory"
).
bytes
()
==
(
1
*
4
+
2
*
4
+
200
*
4
));
}
int
main
()
int
main
(
int
argc
,
const
char
*
argv
[]
)
{
setenv
(
"MIGRAPH_DISABLE_MEMORY_COLORING"
,
"1"
,
1
);
basic
();
aligned
();
unaligned
();
float_aligned
();
test
::
run
(
argc
,
argv
);
}
test/eliminate_concat_test.cpp
View file @
8cae675e
...
...
@@ -79,7 +79,7 @@ struct fred_op
}
};
void
basic
(
)
TEST_CASE
(
basic
)
{
auto
create_test_program
=
[]()
{
migraph
::
program
p
;
...
...
@@ -123,7 +123,7 @@ void basic()
EXPECT
(
p1
==
p2
);
}
void
wont_work
(
)
TEST_CASE
(
wont_work
)
{
auto
create_test_program
=
[]()
{
migraph
::
program
p
;
...
...
@@ -167,8 +167,4 @@ void wont_work()
EXPECT
(
p1
==
p2
);
}
int
main
()
{
basic
();
wont_work
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/eliminate_contiguous_test.cpp
View file @
8cae675e
...
...
@@ -14,7 +14,7 @@ struct eliminate_contiguous_target
migraph
::
context
get_context
()
const
{
return
{};
}
};
void
standard_op
(
)
TEST_CASE
(
standard_op
)
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
...
...
@@ -26,7 +26,7 @@ void standard_op()
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
count
);
}
void
non_standard_op
(
)
TEST_CASE
(
non_standard_op
)
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
get_2x2
());
...
...
@@ -38,8 +38,4 @@ void non_standard_op()
EXPECT
(
std
::
distance
(
p
.
begin
(),
p
.
end
())
==
(
count
-
1
));
}
int
main
()
{
standard_op
();
non_standard_op
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/eval_test.cpp
View file @
8cae675e
...
...
@@ -50,7 +50,7 @@ struct double_reverse_target
migraph
::
context
get_context
()
const
{
return
{};
}
};
void
literal_test1
(
)
TEST_CASE
(
literal_test1
)
{
migraph
::
program
p
;
...
...
@@ -62,7 +62,7 @@ void literal_test1()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
literal_test2
(
)
TEST_CASE
(
literal_test2
)
{
migraph
::
program
p
;
...
...
@@ -76,7 +76,7 @@ void literal_test2()
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
print_test
(
)
TEST_CASE
(
print_test
)
{
migraph
::
program
p
;
...
...
@@ -90,7 +90,7 @@ void print_test()
EXPECT
(
!
s
.
empty
());
}
void
param_test
(
)
TEST_CASE
(
param_test
)
{
migraph
::
program
p
;
...
...
@@ -104,7 +104,7 @@ void param_test()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
param_error_test
(
)
TEST_CASE
(
param_error_test
)
{
migraph
::
program
p
;
...
...
@@ -119,7 +119,7 @@ void param_error_test()
"Parameter not found: y"
));
}
void
replace_test
(
)
TEST_CASE
(
replace_test
)
{
migraph
::
program
p
;
...
...
@@ -134,7 +134,7 @@ void replace_test()
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
replace_ins_test
(
)
TEST_CASE
(
replace_ins_test
)
{
migraph
::
program
p
;
...
...
@@ -150,7 +150,7 @@ void replace_ins_test()
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
replace_ins_test2
(
)
TEST_CASE
(
replace_ins_test2
)
{
migraph
::
program
p
;
...
...
@@ -167,7 +167,7 @@ void replace_ins_test2()
EXPECT
(
result
!=
migraph
::
literal
{
3
});
}
void
insert_replace_test
(
)
TEST_CASE
(
insert_replace_test
)
{
migraph
::
program
p
;
...
...
@@ -185,7 +185,7 @@ void insert_replace_test()
EXPECT
(
result
!=
migraph
::
literal
{
5
});
}
void
target_test
(
)
TEST_CASE
(
target_test
)
{
migraph
::
program
p
;
...
...
@@ -198,7 +198,7 @@ void target_test()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
reverse_target_test
(
)
TEST_CASE
(
reverse_target_test
)
{
migraph
::
program
p
;
...
...
@@ -211,7 +211,7 @@ void reverse_target_test()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
void
double_reverse_target_test
(
)
TEST_CASE
(
double_reverse_target_test
)
{
migraph
::
program
p
;
...
...
@@ -224,17 +224,4 @@ void double_reverse_target_test()
EXPECT
(
result
!=
migraph
::
literal
{
4
});
}
int
main
()
{
literal_test1
();
literal_test2
();
print_test
();
param_test
();
param_error_test
();
replace_test
();
replace_ins_test
();
replace_ins_test2
();
insert_replace_test
();
target_test
();
reverse_target_test
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/fwd_conv_batchnorm_rewrite_test.cpp
View file @
8cae675e
...
...
@@ -6,7 +6,7 @@
#include <test.hpp>
#include <migraph/verify.hpp>
void
fwd_conv_batchnorm_rewrite_test
(
)
TEST_CASE
(
fwd_conv_batchnorm_rewrite_test
)
{
std
::
vector
<
float
>
xdata
=
{
0.26485917
,
0.61703885
,
0.32762103
,
0.2503367
,
0.6552712
,
0.07947932
,
0.95442678
,
...
...
@@ -64,8 +64,4 @@ void fwd_conv_batchnorm_rewrite_test()
EXPECT
(
migraph
::
verify_range
(
results_vector1
,
results_vector2
));
}
int
main
()
{
fwd_conv_batchnorm_rewrite_test
();
return
0
;
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/include/test.hpp
View file @
8cae675e
...
...
@@ -2,7 +2,10 @@
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>
#ifndef MIGRAPH_GUARD_TEST_TEST_HPP
#define MIGRAPH_GUARD_TEST_TEST_HPP
...
...
@@ -154,11 +157,75 @@ bool throws(F f, const std::string& msg = "")
}
}
template
<
class
T
>
void
run_test
()
using
string_map
=
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
template
<
class
Keyword
>
string_map
parse
(
std
::
vector
<
std
::
string
>
as
,
Keyword
keyword
)
{
string_map
result
;
std
::
string
flag
;
for
(
auto
&&
x
:
as
)
{
auto
f
=
keyword
(
x
);
if
(
f
.
empty
())
{
result
[
flag
].
push_back
(
x
);
}
else
{
flag
=
f
.
front
();
result
[
flag
];
// Ensure the flag exists
}
}
return
result
;
}
inline
auto
&
get_test_cases
()
{
static
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
function
<
void
()
>>>
cases
;
return
cases
;
}
inline
void
add_test_case
(
std
::
string
name
,
std
::
function
<
void
()
>
f
)
{
T
t
=
{};
t
.
run
();
get_test_cases
().
emplace_back
(
name
,
f
);
}
struct
auto_register
{
template
<
class
F
>
auto_register
(
const
char
*
name
,
F
f
)
noexcept
{
add_test_case
(
name
,
f
);
}
};
inline
void
run_test_case
(
const
std
::
string
&
name
,
const
std
::
function
<
void
()
>&
f
)
{
std
::
cout
<<
"[ RUN ] "
<<
name
<<
std
::
endl
;
f
();
std
::
cout
<<
"[ COMPLETE ] "
<<
name
<<
std
::
endl
;
}
inline
void
run
(
int
argc
,
const
char
*
argv
[])
{
std
::
vector
<
std
::
string
>
as
(
argv
+
1
,
argv
+
argc
);
auto
args
=
parse
(
as
,
[](
auto
&&
)
->
std
::
vector
<
std
::
string
>
{
return
{};
});
auto
cases
=
args
[
""
];
if
(
cases
.
empty
())
{
for
(
auto
&&
tc
:
get_test_cases
())
run_test_case
(
tc
.
first
,
tc
.
second
);
}
else
{
std
::
unordered_map
<
std
::
string
,
std
::
function
<
void
()
>>
m
(
get_test_cases
().
begin
(),
get_test_cases
().
end
());
for
(
auto
&&
name
:
cases
)
run_test_case
(
name
,
m
[
name
]);
}
}
}
// namespace test
...
...
@@ -179,4 +246,24 @@ void run_test()
// NOLINTNEXTLINE
#define STATUS(...) EXPECT((__VA_ARGS__) == 0)
// NOLINTNEXTLINE
#define TEST_CAT(x, ...) TEST_PRIMITIVE_CAT(x, __VA_ARGS__)
#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__
// NOLINTNEXTLINE
#define TEST_CASE_REGISTER(...) \
static test::auto_register TEST_CAT(register_test_case_, __LINE__) = \
test::auto_register(#__VA_ARGS__, &__VA_ARGS__);
// NOLINTNEXTLINE
#define TEST_CASE(...) \
void __VA_ARGS__(); \
TEST_CASE_REGISTER(__VA_ARGS__) \
void __VA_ARGS__()
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wglobal-constructors"
#endif
#endif
test/literal_test.cpp
View file @
8cae675e
...
...
@@ -4,7 +4,7 @@
#include <string>
#include "test.hpp"
void
literal_test
(
)
TEST_CASE
(
literal_test
)
{
EXPECT
(
migraph
::
literal
{
1
}
==
migraph
::
literal
{
1
});
EXPECT
(
migraph
::
literal
{
1
}
!=
migraph
::
literal
{
2
});
...
...
@@ -25,7 +25,7 @@ void literal_test()
EXPECT
(
l4
.
empty
());
}
void
literal_os1
(
)
TEST_CASE
(
literal_os1
)
{
migraph
::
literal
l
{
1
};
std
::
stringstream
ss
;
...
...
@@ -33,7 +33,7 @@ void literal_os1()
EXPECT
(
ss
.
str
()
==
"1"
);
}
void
literal_os2
(
)
TEST_CASE
(
literal_os2
)
{
migraph
::
literal
l
{};
std
::
stringstream
ss
;
...
...
@@ -41,7 +41,7 @@ void literal_os2()
EXPECT
(
ss
.
str
().
empty
());
}
void
literal_os3
(
)
TEST_CASE
(
literal_os3
)
{
migraph
::
shape
s
{
migraph
::
shape
::
int64_type
,
{
3
}};
migraph
::
literal
l
{
s
,
{
1
,
2
,
3
}};
...
...
@@ -50,9 +50,4 @@ void literal_os3()
EXPECT
(
ss
.
str
()
==
"1, 2, 3"
);
}
int
main
()
{
literal_test
();
literal_os1
();
literal_os2
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/matcher.cpp
View file @
8cae675e
...
...
@@ -27,7 +27,7 @@ void match1()
EXPECT
(
bool
{
r
.
result
==
l
});
}
void
match_name1
(
)
TEST_CASE
(
match_name1
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -39,7 +39,7 @@ void match_name1()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_name2
(
)
TEST_CASE
(
match_name2
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -51,7 +51,7 @@ void match_name2()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_name3
(
)
TEST_CASE
(
match_name3
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -63,7 +63,7 @@ void match_name3()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_arg1
(
)
TEST_CASE
(
match_arg1
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -75,7 +75,7 @@ void match_arg1()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_arg2
(
)
TEST_CASE
(
match_arg2
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -87,7 +87,7 @@ void match_arg2()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_arg3
(
)
TEST_CASE
(
match_arg3
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -99,7 +99,7 @@ void match_arg3()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_arg4
(
)
TEST_CASE
(
match_arg4
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -111,7 +111,7 @@ void match_arg4()
EXPECT
(
bool
{
r
.
result
==
pass
});
}
void
match_arg5
(
)
TEST_CASE
(
match_arg5
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -123,7 +123,7 @@ void match_arg5()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_arg6
(
)
TEST_CASE
(
match_arg6
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -135,7 +135,7 @@ void match_arg6()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_arg7
(
)
TEST_CASE
(
match_arg7
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -148,7 +148,7 @@ void match_arg7()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_args1
(
)
TEST_CASE
(
match_args1
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -161,7 +161,7 @@ void match_args1()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_args2
(
)
TEST_CASE
(
match_args2
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -174,7 +174,7 @@ void match_args2()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_args3
(
)
TEST_CASE
(
match_args3
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -186,7 +186,7 @@ void match_args3()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_args4
(
)
TEST_CASE
(
match_args4
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -200,7 +200,7 @@ void match_args4()
EXPECT
(
bool
{
r
.
result
==
sum2
});
}
void
match_args5
(
)
TEST_CASE
(
match_args5
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -213,7 +213,7 @@ void match_args5()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_args6
(
)
TEST_CASE
(
match_args6
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -225,7 +225,7 @@ void match_args6()
EXPECT
(
bool
{
r
.
result
==
pass
});
}
void
match_args7
(
)
TEST_CASE
(
match_args7
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -239,7 +239,7 @@ void match_args7()
EXPECT
(
bool
{
r
.
result
==
pass
});
}
void
match_either_args1
(
)
TEST_CASE
(
match_either_args1
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -253,7 +253,7 @@ void match_either_args1()
EXPECT
(
bool
{
r
.
result
==
sum2
});
}
void
match_either_args2
(
)
TEST_CASE
(
match_either_args2
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -267,7 +267,7 @@ void match_either_args2()
EXPECT
(
bool
{
r
.
result
==
sum2
});
}
void
match_either_args3
(
)
TEST_CASE
(
match_either_args3
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -281,7 +281,7 @@ void match_either_args3()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_all_of1
(
)
TEST_CASE
(
match_all_of1
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -294,7 +294,7 @@ void match_all_of1()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_all_of2
(
)
TEST_CASE
(
match_all_of2
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -307,7 +307,7 @@ void match_all_of2()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_any_of1
(
)
TEST_CASE
(
match_any_of1
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -320,7 +320,7 @@ void match_any_of1()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_any_of2
(
)
TEST_CASE
(
match_any_of2
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -333,7 +333,7 @@ void match_any_of2()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_none_of1
(
)
TEST_CASE
(
match_none_of1
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -346,7 +346,7 @@ void match_none_of1()
EXPECT
(
bool
{
r
.
result
==
sum
});
}
void
match_none_of2
(
)
TEST_CASE
(
match_none_of2
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -359,7 +359,7 @@ void match_none_of2()
EXPECT
(
bool
{
r
.
result
==
p
.
end
()});
}
void
match_bind1
(
)
TEST_CASE
(
match_bind1
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -400,7 +400,7 @@ struct match_find_literal
}
};
void
match_finder
(
)
TEST_CASE
(
match_finder
)
{
migraph
::
program
p
;
auto
one
=
p
.
add_literal
(
1
);
...
...
@@ -410,43 +410,4 @@ void match_finder()
match
::
find_matches
(
p
,
match_find_sum
{
sum
},
match_find_literal
{
sum
});
}
int
main
()
{
match1
();
match_name1
();
match_name2
();
match_name3
();
match_arg1
();
match_arg2
();
match_arg3
();
match_arg4
();
match_arg5
();
match_arg6
();
match_arg7
();
match_args1
();
match_args2
();
match_args3
();
match_args4
();
match_args5
();
match_args6
();
match_args7
();
match_either_args1
();
match_either_args2
();
match_either_args3
();
match_all_of1
();
match_all_of2
();
match_any_of1
();
match_any_of2
();
match_none_of1
();
match_none_of2
();
match_bind1
();
match_finder
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/memory_coloring_test.cpp
View file @
8cae675e
...
...
@@ -43,7 +43,7 @@ bool no_allocate(const migraph::program& p)
return
std
::
none_of
(
p
.
begin
(),
p
.
end
(),
[](
auto
&&
ins
)
{
return
ins
.
name
()
==
"allocate"
;
});
}
void
test1
(
)
TEST_CASE
(
test1
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -55,7 +55,7 @@ void test1()
CHECK
(
no_allocate
(
p
));
}
void
test2
(
)
TEST_CASE
(
test2
)
{
migraph
::
program
p
;
auto
input
=
p
.
add_parameter
(
"input"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
16
}});
...
...
@@ -69,7 +69,7 @@ void test2()
CHECK
(
no_allocate
(
p
));
}
void
test3
(
)
TEST_CASE
(
test3
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -82,7 +82,7 @@ void test3()
CHECK
(
no_allocate
(
p
));
}
void
test4
(
)
TEST_CASE
(
test4
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
0
}});
...
...
@@ -95,7 +95,7 @@ void test4()
CHECK
(
no_allocate
(
p
));
}
void
test5
(
)
TEST_CASE
(
test5
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
40
}});
...
...
@@ -107,7 +107,7 @@ void test5()
CHECK
(
no_allocate
(
p
));
}
void
test6
(
)
TEST_CASE
(
test6
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -120,7 +120,7 @@ void test6()
CHECK
(
no_allocate
(
p
));
}
void
test7
(
)
TEST_CASE
(
test7
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -133,7 +133,7 @@ void test7()
CHECK
(
no_allocate
(
p
));
}
void
test8
(
)
TEST_CASE
(
test8
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -146,7 +146,7 @@ void test8()
CHECK
(
no_allocate
(
p
));
}
void
test9
(
)
TEST_CASE
(
test9
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -159,7 +159,7 @@ void test9()
CHECK
(
no_allocate
(
p
));
}
void
test10
(
)
TEST_CASE
(
test10
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -169,7 +169,7 @@ void test10()
CHECK
(
no_allocate
(
p
));
}
void
test11
(
)
TEST_CASE
(
test11
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -183,7 +183,7 @@ void test11()
CHECK
(
no_allocate
(
p
));
}
void
test12
(
)
TEST_CASE
(
test12
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
40
}});
...
...
@@ -197,7 +197,7 @@ void test12()
CHECK
(
no_allocate
(
p
));
}
void
test13
(
)
TEST_CASE
(
test13
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -211,7 +211,7 @@ void test13()
CHECK
(
no_allocate
(
p
));
}
void
test14
(
)
TEST_CASE
(
test14
)
{
migraph
::
program
p
;
auto
a3
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -225,7 +225,7 @@ void test14()
CHECK
(
no_allocate
(
p
));
}
void
test15
(
)
TEST_CASE
(
test15
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -239,7 +239,7 @@ void test15()
CHECK
(
no_allocate
(
p
));
}
void
test16
(
)
TEST_CASE
(
test16
)
{
migraph
::
program
p
;
auto
a1
=
p
.
add_literal
(
migraph
::
generate_literal
({
migraph
::
shape
::
float_type
,
{
8
}}));
...
...
@@ -253,7 +253,7 @@ void test16()
CHECK
(
no_allocate
(
p
));
}
void
test17
(
)
TEST_CASE
(
test17
)
{
migraph
::
program
p
;
auto
a3
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
40
}});
...
...
@@ -267,7 +267,7 @@ void test17()
CHECK
(
no_allocate
(
p
));
}
void
test18
(
)
TEST_CASE
(
test18
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -281,7 +281,7 @@ void test18()
CHECK
(
no_allocate
(
p
));
}
void
test19
(
)
TEST_CASE
(
test19
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -295,7 +295,7 @@ void test19()
CHECK
(
no_allocate
(
p
));
}
void
test20
(
)
TEST_CASE
(
test20
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
32
}});
...
...
@@ -309,7 +309,7 @@ void test20()
CHECK
(
no_allocate
(
p
));
}
void
test21
(
)
TEST_CASE
(
test21
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
32
}});
...
...
@@ -323,7 +323,7 @@ void test21()
CHECK
(
no_allocate
(
p
));
}
void
test22
(
)
TEST_CASE
(
test22
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
32
}});
...
...
@@ -337,7 +337,7 @@ void test22()
CHECK
(
no_allocate
(
p
));
}
void
test23
(
)
TEST_CASE
(
test23
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -351,7 +351,7 @@ void test23()
CHECK
(
no_allocate
(
p
));
}
void
test24
(
)
TEST_CASE
(
test24
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
32
}});
...
...
@@ -365,7 +365,7 @@ void test24()
CHECK
(
no_allocate
(
p
));
}
void
test25
(
)
TEST_CASE
(
test25
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -379,7 +379,7 @@ void test25()
CHECK
(
no_allocate
(
p
));
}
void
test26
(
)
TEST_CASE
(
test26
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -393,7 +393,7 @@ void test26()
CHECK
(
no_allocate
(
p
));
}
void
test27
(
)
TEST_CASE
(
test27
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -405,7 +405,7 @@ void test27()
CHECK
(
no_allocate
(
p
));
}
void
test28
(
)
TEST_CASE
(
test28
)
{
migraph
::
program
p
;
auto
output
=
p
.
add_parameter
(
"output"
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -419,7 +419,7 @@ void test28()
CHECK
(
no_allocate
(
p
));
}
void
test29
(
)
TEST_CASE
(
test29
)
{
migraph
::
program
p
;
auto
output
=
p
.
add_parameter
(
"output"
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -434,7 +434,7 @@ void test29()
CHECK
(
no_allocate
(
p
));
}
void
test30
(
)
TEST_CASE
(
test30
)
{
migraph
::
program
p
;
auto
output
=
p
.
add_parameter
(
"x"
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -449,7 +449,7 @@ void test30()
CHECK
(
no_allocate
(
p
));
}
void
test31
(
)
TEST_CASE
(
test31
)
{
migraph
::
program
p
;
auto
output
=
p
.
add_parameter
(
"output"
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -463,7 +463,7 @@ void test31()
CHECK
(
no_allocate
(
p
));
}
void
test32
(
)
TEST_CASE
(
test32
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -477,7 +477,7 @@ void test32()
CHECK
(
no_allocate
(
p
));
}
void
test33
(
)
TEST_CASE
(
test33
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
8
}});
...
...
@@ -491,7 +491,7 @@ void test33()
CHECK
(
no_allocate
(
p
));
}
void
test34
(
)
TEST_CASE
(
test34
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
40
}});
...
...
@@ -505,7 +505,7 @@ void test34()
CHECK
(
no_allocate
(
p
));
}
void
test35
(
)
TEST_CASE
(
test35
)
{
migraph
::
program
p
;
auto
a1
=
add_alloc
(
p
,
{
migraph
::
shape
::
float_type
,
{
40
}});
...
...
@@ -519,7 +519,7 @@ void test35()
CHECK
(
no_allocate
(
p
));
}
void
test36
(
)
TEST_CASE
(
test36
)
{
migraph
::
program
p
;
auto
output
=
p
.
add_parameter
(
"output"
,
{
migraph
::
shape
::
float_type
,
{
20
}});
...
...
@@ -536,7 +536,7 @@ void test36()
CHECK
(
no_allocate
(
p
));
}
void
test37
(
)
TEST_CASE
(
test37
)
{
migraph
::
program
p
;
auto
output
=
p
.
add_parameter
(
"output"
,
{
migraph
::
shape
::
float_type
,
{
20
}});
...
...
@@ -553,7 +553,7 @@ void test37()
CHECK
(
no_allocate
(
p
));
}
void
test38
(
)
TEST_CASE
(
test38
)
{
migraph
::
program
p
;
auto
output
=
p
.
add_parameter
(
"output"
,
{
migraph
::
shape
::
float_type
,
{
1
,
64
,
56
,
56
}});
...
...
@@ -598,7 +598,7 @@ void test38()
CHECK
(
no_allocate
(
p
));
}
void
literal_test
(
)
TEST_CASE
(
literal_test
)
{
migraph
::
program
p
;
auto
lit
=
generate_literal
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}});
...
...
@@ -608,46 +608,4 @@ void literal_test()
CHECK
(
lit
==
result
);
}
int
main
()
{
test1
();
test2
();
test3
();
test4
();
test5
();
test6
();
test7
();
test8
();
test9
();
test10
();
test11
();
test12
();
test13
();
test14
();
test15
();
test16
();
test17
();
test18
();
test19
();
test20
();
test21
();
test22
();
test23
();
test24
();
test25
();
test26
();
test27
();
test28
();
test29
();
test30
();
test31
();
test32
();
test33
();
test34
();
test35
();
test36
();
test37
();
test38
();
literal_test
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/op_shape_test.cpp
View file @
8cae675e
...
...
@@ -52,7 +52,7 @@ void throws_shape(const migraph::shape&, Ts...)
"An expected shape should not be passed to throws_shape function"
);
}
void
batch_norm_inference_shape
(
)
TEST_CASE
(
batch_norm_inference_shape
)
{
const
size_t
channels
=
3
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
4
,
channels
,
3
,
3
}};
...
...
@@ -62,7 +62,7 @@ void batch_norm_inference_shape()
throws_shape
(
migraph
::
op
::
batch_norm_inference
{},
s
,
vars
,
vars
,
vars
,
vars
,
vars
);
}
void
convolution_shape
(
)
TEST_CASE
(
convolution_shape
)
{
migraph
::
shape
output
{
migraph
::
shape
::
float_type
,
{
4
,
4
,
1
,
1
}};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
3
,
3
}};
...
...
@@ -76,7 +76,7 @@ void convolution_shape()
throws_shape
(
migraph
::
op
::
convolution
{},
input2
,
weights
);
}
void
transpose_shape
(
)
TEST_CASE
(
transpose_shape
)
{
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
2
}};
migraph
::
shape
output
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
...
...
@@ -85,7 +85,7 @@ void transpose_shape()
throws_shape
(
migraph
::
op
::
transpose
{{
1
,
2
}},
input
);
}
void
contiguous_shape
(
)
TEST_CASE
(
contiguous_shape
)
{
migraph
::
shape
output
{
migraph
::
shape
::
float_type
,
{
2
,
2
}};
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
2
},
{
1
,
2
}};
...
...
@@ -96,7 +96,7 @@ void contiguous_shape()
expect_shape
(
single
,
migraph
::
op
::
contiguous
{},
single
);
}
void
reshape_shape
(
)
TEST_CASE
(
reshape_shape
)
{
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
24
,
1
,
1
,
1
}};
for
(
auto
&&
new_shape
:
...
...
@@ -114,7 +114,7 @@ void reshape_shape()
}
}
void
flatten_shape
(
)
TEST_CASE
(
flatten_shape
)
{
migraph
::
shape
input
{
migraph
::
shape
::
float_type
,
{
2
,
4
,
6
,
8
}};
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
1
,
2
*
4
*
6
*
8
}},
...
...
@@ -132,7 +132,7 @@ void flatten_shape()
throws_shape
(
migraph
::
op
::
flatten
{
5
},
input
);
}
void
slice_shape
(
)
TEST_CASE
(
slice_shape
)
{
migraph
::
shape
input
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
expect_shape
(
migraph
::
shape
{
migraph
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
6
,
3
,
1
}},
...
...
@@ -145,13 +145,4 @@ void slice_shape()
migraph
::
op
::
slice
{{
2
},
{
2
},
{
10
}},
input
);
}
int
main
()
{
batch_norm_inference_shape
();
convolution_shape
();
transpose_shape
();
contiguous_shape
();
reshape_shape
();
flatten_shape
();
slice_shape
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/operation.cpp
View file @
8cae675e
...
...
@@ -43,7 +43,7 @@ struct simple_operation_no_print
}
};
void
operation_copy_test
(
)
TEST_CASE
(
operation_copy_test
)
{
simple_operation
s
{};
migraph
::
operation
op1
=
s
;
// NOLINT
...
...
@@ -54,7 +54,7 @@ void operation_copy_test()
EXPECT
(
op2
==
op1
);
}
void
operation_equal_test
(
)
TEST_CASE
(
operation_equal_test
)
{
simple_operation
s
{};
migraph
::
operation
op1
=
s
;
...
...
@@ -72,7 +72,7 @@ struct not_operation
{
};
void
operation_any_cast
(
)
TEST_CASE
(
operation_any_cast
)
{
migraph
::
operation
op1
=
simple_operation
{};
EXPECT
(
migraph
::
any_cast
<
simple_operation
>
(
op1
).
data
==
1
);
...
...
@@ -83,7 +83,7 @@ void operation_any_cast()
EXPECT
(
migraph
::
any_cast
<
not_operation
*>
(
&
op2
)
==
nullptr
);
}
void
operation_print
(
)
TEST_CASE
(
operation_print
)
{
migraph
::
operation
op
=
simple_operation
{};
std
::
stringstream
ss
;
...
...
@@ -92,7 +92,7 @@ void operation_print()
EXPECT
(
s
==
"simple[1]"
);
}
void
operation_default_print
(
)
TEST_CASE
(
operation_default_print
)
{
migraph
::
operation
op
=
simple_operation_no_print
{};
std
::
stringstream
ss
;
...
...
@@ -101,11 +101,4 @@ void operation_default_print()
EXPECT
(
s
==
"simple"
);
}
int
main
()
{
operation_copy_test
();
operation_equal_test
();
operation_any_cast
();
operation_print
();
operation_default_print
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/output_alias.cpp
View file @
8cae675e
...
...
@@ -3,7 +3,7 @@
#include <test.hpp>
#include <basic_ops.hpp>
void
simple_alias
(
)
TEST_CASE
(
simple_alias
)
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
1
);
...
...
@@ -12,7 +12,7 @@ void simple_alias()
EXPECT
(
bool
{
migraph
::
instruction
::
get_output_alias
(
p1
)
==
l
});
}
void
cascade_alias
(
)
TEST_CASE
(
cascade_alias
)
{
migraph
::
program
p
;
auto
l
=
p
.
add_literal
(
1
);
...
...
@@ -25,7 +25,7 @@ void cascade_alias()
EXPECT
(
bool
{
migraph
::
instruction
::
get_output_alias
(
p3
)
==
l
});
}
void
no_alias
(
)
TEST_CASE
(
no_alias
)
{
migraph
::
program
p
;
auto
x
=
p
.
add_literal
(
1
);
...
...
@@ -34,9 +34,4 @@ void no_alias()
EXPECT
(
bool
{
migraph
::
instruction
::
get_output_alias
(
sum
)
==
sum
});
}
int
main
()
{
simple_alias
();
cascade_alias
();
no_alias
();
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/program_test.cpp
View file @
8cae675e
...
...
@@ -20,11 +20,11 @@ migraph::program create_program()
return
p
;
}
void
program_equality
(
)
TEST_CASE
(
program_equality
)
{
migraph
::
program
x
=
create_program
();
migraph
::
program
y
=
create_program
();
EXPECT
(
x
==
y
);
}
int
main
(
)
{
program_equality
(
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment