Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
68a9a23f
Unverified
Commit
68a9a23f
authored
Jul 16, 2023
by
Umang Yadav
Committed by
GitHub
Jul 16, 2023
Browse files
add verify namespace (#1952)
parent
c4765a6d
Changes
17
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
514 additions
and
507 deletions
+514
-507
docs/dev_intro.rst
docs/dev_intro.rst
+1
-1
src/include/migraphx/verify.hpp
src/include/migraphx/verify.hpp
+2
-0
src/verify_args.cpp
src/verify_args.cpp
+13
-13
test/gpu/codegen_literal.cpp
test/gpu/codegen_literal.cpp
+1
-1
test/gpu/manage_host_buffer.cpp
test/gpu/manage_host_buffer.cpp
+1
-1
test/gpu/quantization.cpp
test/gpu/quantization.cpp
+3
-3
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+68
-68
test/quantization.cpp
test/quantization.cpp
+4
-4
test/ref_dev_examples.cpp
test/ref_dev_examples.cpp
+1
-1
test/ref_dot_op_test.cpp
test/ref_dot_op_test.cpp
+41
-41
test/ref_ops_nonstd_shape_test.cpp
test/ref_ops_nonstd_shape_test.cpp
+3
-3
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+274
-269
test/ref_rnn_ops_test.cpp
test/ref_rnn_ops_test.cpp
+90
-90
test/rewrite_pooling_test.cpp
test/rewrite_pooling_test.cpp
+2
-2
test/run_on_target_test.cpp
test/run_on_target_test.cpp
+1
-1
test/shape_test.cpp
test/shape_test.cpp
+7
-7
test/simplify_qdq_test.cpp
test/simplify_qdq_test.cpp
+2
-2
No files found.
docs/dev_intro.rst
View file @
68a9a23f
...
@@ -131,7 +131,7 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro
...
@@ -131,7 +131,7 @@ In this case, we can create `argument <migraphx::argument>` objects directly fro
std::vector<float> results_vector(64);
std::vector<float> results_vector(64);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, sol));
EXPECT(migraphx::verify
::verify
_range(results_vector, sol));
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU.
An `argument <migraphx::argument>` can handle memory buffers from either the GPU or the CPU.
By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target.
By default when running the `program <migraphx::program>`, buffers are allocated on the corresponding target.
...
...
src/include/migraphx/verify.hpp
View file @
68a9a23f
...
@@ -35,6 +35,7 @@
...
@@ -35,6 +35,7 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
verify
{
// Compute the value of a range
// Compute the value of a range
template
<
class
R
>
template
<
class
R
>
...
@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
...
@@ -196,6 +197,7 @@ bool verify_range(const R1& r1, const R2& r2, double tolerance = 80, double* out
return
error
<=
threshold
;
return
error
<=
threshold
;
}
}
}
// namespace verify
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
#endif
#endif
src/verify_args.cpp
View file @
68a9a23f
...
@@ -35,7 +35,7 @@ bool verify_args(const std::string& name,
...
@@ -35,7 +35,7 @@ bool verify_args(const std::string& name,
bool
passed
=
true
;
bool
passed
=
true
;
visit_all
(
ref_arg
,
target_arg
)([
&
](
auto
ref
,
auto
target
)
{
visit_all
(
ref_arg
,
target_arg
)([
&
](
auto
ref
,
auto
target
)
{
double
error
;
double
error
;
passed
=
verify_range
(
ref
,
target
,
tolerance
,
&
error
);
passed
=
verify
::
verify_range
(
ref
,
target
,
tolerance
,
&
error
);
if
(
not
passed
)
if
(
not
passed
)
{
{
// TODO: Check for nans
// TODO: Check for nans
...
@@ -45,27 +45,27 @@ bool verify_args(const std::string& name,
...
@@ -45,27 +45,27 @@ bool verify_args(const std::string& name,
std
::
cout
<<
"ref:"
<<
ref
<<
std
::
endl
;
std
::
cout
<<
"ref:"
<<
ref
<<
std
::
endl
;
if
(
target
.
size
()
<
32
)
if
(
target
.
size
()
<
32
)
std
::
cout
<<
"target:"
<<
target
<<
std
::
endl
;
std
::
cout
<<
"target:"
<<
target
<<
std
::
endl
;
if
(
range_zero
(
ref
))
if
(
verify
::
range_zero
(
ref
))
std
::
cout
<<
"Ref data is all zeros"
<<
std
::
endl
;
std
::
cout
<<
"Ref data is all zeros"
<<
std
::
endl
;
if
(
range_zero
(
target
))
if
(
verify
::
range_zero
(
target
))
std
::
cout
<<
"Target data is all zeros"
<<
std
::
endl
;
std
::
cout
<<
"Target data is all zeros"
<<
std
::
endl
;
auto
mxdiff
=
max_diff
(
ref
,
target
);
auto
mxdiff
=
verify
::
max_diff
(
ref
,
target
);
std
::
cout
<<
"Max diff: "
<<
mxdiff
<<
std
::
endl
;
std
::
cout
<<
"Max diff: "
<<
mxdiff
<<
std
::
endl
;
auto
idx
=
mismatch_idx
(
ref
,
target
,
float_equal
);
auto
idx
=
verify
::
mismatch_idx
(
ref
,
target
,
float_equal
);
if
(
idx
<
range_distance
(
ref
))
if
(
idx
<
verify
::
range_distance
(
ref
))
{
{
std
::
cout
<<
"Mismatch at "
<<
idx
<<
": "
<<
ref
[
idx
]
<<
" != "
<<
target
[
idx
]
std
::
cout
<<
"Mismatch at "
<<
idx
<<
": "
<<
ref
[
idx
]
<<
" != "
<<
target
[
idx
]
<<
std
::
endl
;
<<
std
::
endl
;
}
}
auto
ref_nan_idx
=
find_idx
(
ref
,
not_finite
);
auto
ref_nan_idx
=
find_idx
(
ref
,
verify
::
not_finite
);
if
(
ref_nan_idx
>=
0
)
if
(
ref_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in ref at "
<<
ref_nan_idx
<<
": "
std
::
cout
<<
"Non finite number found in ref at "
<<
ref_nan_idx
<<
": "
<<
ref
[
ref_nan_idx
]
<<
std
::
endl
;
<<
ref
[
ref_nan_idx
]
<<
std
::
endl
;
auto
target_nan_idx
=
find_idx
(
target
,
not_finite
);
auto
target_nan_idx
=
find_idx
(
target
,
verify
::
not_finite
);
if
(
target_nan_idx
>=
0
)
if
(
target_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in target at "
<<
target_nan_idx
<<
": "
std
::
cout
<<
"Non finite number found in target at "
<<
target_nan_idx
<<
": "
<<
target
[
target_nan_idx
]
<<
std
::
endl
;
<<
target
[
target_nan_idx
]
<<
std
::
endl
;
...
@@ -73,27 +73,27 @@ bool verify_args(const std::string& name,
...
@@ -73,27 +73,27 @@ bool verify_args(const std::string& name,
}
}
else
else
{
{
if
(
range_zero
(
ref
))
if
(
verify
::
range_zero
(
ref
))
std
::
cout
<<
"Ref data is all zeros"
<<
std
::
endl
;
std
::
cout
<<
"Ref data is all zeros"
<<
std
::
endl
;
if
(
range_zero
(
target
))
if
(
verify
::
range_zero
(
target
))
std
::
cout
<<
"Target data is all zeros"
<<
std
::
endl
;
std
::
cout
<<
"Target data is all zeros"
<<
std
::
endl
;
// auto mxdiff = max_diff(ref, target);
// auto mxdiff = max_diff(ref, target);
// std::cout << "Max diff: " << mxdiff << std::endl;
// std::cout << "Max diff: " << mxdiff << std::endl;
// auto idx = mismatch_idx(ref, target, float_equal);
// auto idx = mismatch_idx(ref, target, float_equal);
// if(idx < range_distance(ref))
// if(idx <
verify::
range_distance(ref))
// {
// {
// std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
// std::cout << "Mismatch at " << idx << ": " << ref[idx] << " != " << target[idx]
// << std::endl;
// << std::endl;
// }
// }
auto
ref_nan_idx
=
find_idx
(
ref
,
not_finite
);
auto
ref_nan_idx
=
find_idx
(
ref
,
verify
::
not_finite
);
if
(
ref_nan_idx
>=
0
)
if
(
ref_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in ref at "
<<
ref_nan_idx
<<
": "
std
::
cout
<<
"Non finite number found in ref at "
<<
ref_nan_idx
<<
": "
<<
ref
[
ref_nan_idx
]
<<
std
::
endl
;
<<
ref
[
ref_nan_idx
]
<<
std
::
endl
;
auto
target_nan_idx
=
find_idx
(
target
,
not_finite
);
auto
target_nan_idx
=
find_idx
(
target
,
verify
::
not_finite
);
if
(
target_nan_idx
>=
0
)
if
(
target_nan_idx
>=
0
)
std
::
cout
<<
"Non finite number found in target at "
<<
target_nan_idx
<<
": "
std
::
cout
<<
"Non finite number found in target at "
<<
target_nan_idx
<<
": "
<<
target
[
target_nan_idx
]
<<
std
::
endl
;
<<
target
[
target_nan_idx
]
<<
std
::
endl
;
...
...
test/gpu/codegen_literal.cpp
View file @
68a9a23f
...
@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
...
@@ -80,7 +80,7 @@ TEST_CASE(mul_literal_round_test)
migraphx
::
target
gpu_t
=
migraphx
::
make_target
(
"gpu"
);
migraphx
::
target
gpu_t
=
migraphx
::
make_target
(
"gpu"
);
run_prog
(
p
,
gpu_t
,
m
,
gpu_result
);
run_prog
(
p
,
gpu_t
,
m
,
gpu_result
);
EXPECT
(
migraphx
::
verify_range
(
ref_result
,
gpu_result
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
ref_result
,
gpu_result
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/gpu/manage_host_buffer.cpp
View file @
68a9a23f
...
@@ -64,7 +64,7 @@ TEST_CASE(host_same_buffer_copy)
...
@@ -64,7 +64,7 @@ TEST_CASE(host_same_buffer_copy)
auto
result
=
p
.
eval
(
pp
).
back
();
auto
result
=
p
.
eval
(
pp
).
back
();
std
::
vector
<
float
>
results_vector
(
ss
.
elements
(),
-
1
);
std
::
vector
<
float
>
results_vector
(
ss
.
elements
(),
-
1
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c_vec
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c_vec
,
results_vector
));
}
}
TEST_CASE
(
arguments_lifetime
)
TEST_CASE
(
arguments_lifetime
)
...
...
test/gpu/quantization.cpp
View file @
68a9a23f
...
@@ -52,7 +52,7 @@ TEST_CASE(gpu_target_copy)
...
@@ -52,7 +52,7 @@ TEST_CASE(gpu_target_copy)
std
::
vector
<
int8_t
>
val_final
;
std
::
vector
<
int8_t
>
val_final
;
ref_arg_final
.
visit
([
&
](
auto
v
)
{
val_final
.
assign
(
v
.
begin
(),
v
.
end
());
});
ref_arg_final
.
visit
([
&
](
auto
v
)
{
val_final
.
assign
(
v
.
begin
(),
v
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
val_orig
,
val_final
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
val_orig
,
val_final
));
}
}
TEST_CASE
(
int8_quantization
)
TEST_CASE
(
int8_quantization
)
...
@@ -118,9 +118,9 @@ TEST_CASE(int8_quantization)
...
@@ -118,9 +118,9 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much
// the regular pipeline uses the rewrite_quantization in the much
// earlier stage.
// earlier stage.
if
(
migraphx
::
gpu
::
mlir_enabled
())
if
(
migraphx
::
gpu
::
mlir_enabled
())
EXPECT
(
migraphx
::
verify_range
(
ref_result
,
gpu_result
,
1e5
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
ref_result
,
gpu_result
,
1e5
));
else
else
EXPECT
(
migraphx
::
verify_range
(
ref_result
,
gpu_result
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
ref_result
,
gpu_result
));
}
}
}
}
...
...
test/onnx/verify_onnx.cpp
View file @
68a9a23f
This diff is collapsed.
Click to expand it.
test/quantization.cpp
View file @
68a9a23f
...
@@ -1020,7 +1020,7 @@ TEST_CASE(target_copy)
...
@@ -1020,7 +1020,7 @@ TEST_CASE(target_copy)
std
::
vector
<
float
>
orig_result
;
std
::
vector
<
float
>
orig_result
;
run_prog
(
p
,
ref_t
,
m
,
orig_result
);
run_prog
(
p
,
ref_t
,
m
,
orig_result
);
EXPECT
(
migraphx
::
verify_range
(
ref_result
,
orig_result
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
ref_result
,
orig_result
));
}
}
}
}
...
@@ -1084,7 +1084,7 @@ TEST_CASE(int8_quantization_dot)
...
@@ -1084,7 +1084,7 @@ TEST_CASE(int8_quantization_dot)
std
::
vector
<
float
>
no_quant_result
;
std
::
vector
<
float
>
no_quant_result
;
run_prog
(
p
,
ref_t
,
m
,
no_quant_result
);
run_prog
(
p
,
ref_t
,
m
,
no_quant_result
);
EXPECT
(
migraphx
::
verify_range
(
quant_result
,
no_quant_result
,
30000
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
quant_result
,
no_quant_result
,
30000
));
}
}
}
}
...
@@ -1129,7 +1129,7 @@ TEST_CASE(int8_quantization_conv)
...
@@ -1129,7 +1129,7 @@ TEST_CASE(int8_quantization_conv)
std
::
vector
<
float
>
no_quant_result
;
std
::
vector
<
float
>
no_quant_result
;
run_prog
(
p
,
ref_t
,
no_quant_result
);
run_prog
(
p
,
ref_t
,
no_quant_result
);
EXPECT
(
migraphx
::
verify_range
(
quant_result
,
no_quant_result
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
quant_result
,
no_quant_result
));
}
}
}
}
...
@@ -1281,7 +1281,7 @@ TEST_CASE(test_op_capture)
...
@@ -1281,7 +1281,7 @@ TEST_CASE(test_op_capture)
cap_res
.
visit
([
&
](
auto
output
)
{
cap_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
cap_res
.
visit
([
&
](
auto
output
)
{
cap_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res
.
visit
([
&
](
auto
output
)
{
vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res
.
visit
([
&
](
auto
output
)
{
vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
vec
,
cap_vec
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
vec
,
cap_vec
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/ref_dev_examples.cpp
View file @
68a9a23f
...
@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
...
@@ -168,7 +168,7 @@ TEST_CASE(handling_tensors)
std
::
vector
<
float
>
results_vector
(
64
);
std
::
vector
<
float
>
results_vector
(
64
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
sol
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
results_vector
,
sol
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/ref_dot_op_test.cpp
View file @
68a9a23f
...
@@ -80,7 +80,7 @@ void dot_2d_test()
...
@@ -80,7 +80,7 @@ void dot_2d_test()
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
T
>
results_vector
;
std
::
vector
<
T
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c
,
results_vector
));
}
}
TEST_CASE_REGISTER
(
dot_2d_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_2d_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_2d_test
<
double
>
)
TEST_CASE_REGISTER
(
dot_2d_test
<
double
>
)
...
@@ -131,7 +131,7 @@ void dot_4d_test()
...
@@ -131,7 +131,7 @@ void dot_4d_test()
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
T
>
results_vector
;
std
::
vector
<
T
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c
,
results_vector
));
}
}
TEST_CASE_REGISTER
(
dot_4d_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_4d_test
<
float
>
)
TEST_CASE_REGISTER
(
dot_4d_test
<
double
>
)
TEST_CASE_REGISTER
(
dot_4d_test
<
double
>
)
...
@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
...
@@ -186,7 +186,7 @@ TEST_CASE(dot_3D_test)
0.40245487
,
0.40245487
,
1.80182751
};
1.80182751
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_3D_C_test0
)
TEST_CASE
(
dot_3D_C_test0
)
...
@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
...
@@ -262,7 +262,7 @@ TEST_CASE(dot_3D_C_test0)
0.40245487
,
0.40245487
,
1.80182751
};
1.80182751
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_3D_C_test1
)
TEST_CASE
(
dot_3D_C_test1
)
...
@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
...
@@ -321,7 +321,7 @@ TEST_CASE(dot_3D_C_test1)
-
0.95536130
,
-
0.95536130
,
2.27996211
};
2.27996211
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_4D_test1
)
TEST_CASE
(
dot_4D_test1
)
...
@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
...
@@ -360,7 +360,7 @@ TEST_CASE(dot_4D_test1)
-
0.95467340
,
-
1.74728628
,
-
2.42477030
,
0.76262372
,
0.15539164
,
-
0.95467340
,
-
1.74728628
,
-
2.42477030
,
0.76262372
,
0.15539164
,
3.32281958
,
0.96769613
,
0.43727545
,
2.43019906
};
3.32281958
,
0.96769613
,
0.43727545
,
2.43019906
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_4D_alpha_beta_test
)
TEST_CASE
(
dot_4D_alpha_beta_test
)
...
@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
...
@@ -414,7 +414,7 @@ TEST_CASE(dot_4D_alpha_beta_test)
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_4D_alpha_beta_C_test
)
TEST_CASE
(
dot_4D_alpha_beta_C_test
)
...
@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
...
@@ -466,7 +466,7 @@ TEST_CASE(dot_4D_alpha_beta_C_test)
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
-
0.17183724
,
0.10858734
,
0.39406289
,
0.04662959
,
1.07979824
,
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
0.40355016
,
0.52410648
,
-
0.31728447
,
1.09550845
};
EXPECT
(
migraphx
::
verify_range
(
m
,
m_res
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
m_res
));
}
}
TEST_CASE
(
dot_2D_C_test0
)
TEST_CASE
(
dot_2D_C_test0
)
...
@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
...
@@ -529,7 +529,7 @@ TEST_CASE(dot_2D_C_test0)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
...
@@ -567,7 +567,7 @@ TEST_CASE(dot_vv_inner_product)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
...
@@ -604,7 +604,7 @@ TEST_CASE(dot_vv_inner_product)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
...
@@ -642,7 +642,7 @@ TEST_CASE(dot_vm)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
...
@@ -679,7 +679,7 @@ TEST_CASE(dot_vm)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
...
@@ -726,7 +726,7 @@ TEST_CASE(dot_vm)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
...
@@ -774,7 +774,7 @@ TEST_CASE(dot_vm)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
...
@@ -813,7 +813,7 @@ TEST_CASE(dot_mv)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
...
@@ -851,7 +851,7 @@ TEST_CASE(dot_mv)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
...
@@ -895,7 +895,7 @@ TEST_CASE(dot_mv)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
...
@@ -949,7 +949,7 @@ TEST_CASE(dot_mm1)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
...
@@ -1002,7 +1002,7 @@ TEST_CASE(dot_mm1)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
...
@@ -1047,7 +1047,7 @@ TEST_CASE(dot_mm2)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
...
@@ -1089,7 +1089,7 @@ TEST_CASE(dot_mm2)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
...
@@ -1141,7 +1141,7 @@ TEST_CASE(dot_mm2)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
...
@@ -1189,7 +1189,7 @@ TEST_CASE(dot_mm2)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
...
@@ -1242,7 +1242,7 @@ TEST_CASE(dot_dyn_2D_test)
-
1.29885596e+00
,
-
1.29885596e+00
,
2.16294914e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
-
1.48101497e-01
};
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c
,
results_vector
));
}
}
TEST_CASE
(
dot_dyn_4D_test
)
TEST_CASE
(
dot_dyn_4D_test
)
...
@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
...
@@ -1296,7 +1296,7 @@ TEST_CASE(dot_dyn_4D_test)
-
1.29885596e+00
,
-
1.29885596e+00
,
2.16294914e+00
,
2.16294914e+00
,
-
1.48101497e-01
};
-
1.48101497e-01
};
EXPECT
(
migraphx
::
verify_range
(
c
,
results_vector
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
c
,
results_vector
));
}
}
TEST_CASE
(
quant_dot_2args_multi4
)
TEST_CASE
(
quant_dot_2args_multi4
)
...
@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
...
@@ -1324,7 +1324,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
...
@@ -1352,7 +1352,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
...
@@ -1380,7 +1380,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
...
@@ -1410,7 +1410,7 @@ TEST_CASE(quant_dot_2args_multi4)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
...
@@ -1438,7 +1438,7 @@ TEST_CASE(quant_dot_2args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
...
@@ -1465,7 +1465,7 @@ TEST_CASE(quant_dot_2args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
...
@@ -1493,7 +1493,7 @@ TEST_CASE(quant_dot_2args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
...
@@ -1522,7 +1522,7 @@ TEST_CASE(quant_dot_2args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1554,7 +1554,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1582,7 +1582,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1613,7 +1613,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1644,7 +1644,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
...
@@ -1677,7 +1677,7 @@ TEST_CASE(quant_dot_3args_general)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
...
@@ -1713,7 +1713,7 @@ TEST_CASE(quant_dot_3args_batch)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
{
{
...
@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
...
@@ -1751,7 +1751,7 @@ TEST_CASE(quant_dot_3args_batch)
auto
result
=
p
.
eval
({}).
back
();
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
m
;
std
::
vector
<
float
>
m
;
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
m
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
m
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
m
,
gold
));
}
}
}
}
...
...
test/ref_ops_nonstd_shape_test.cpp
View file @
68a9a23f
...
@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
...
@@ -49,7 +49,7 @@ TEST_CASE(argmax_test_nonstd_shape)
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold_vec
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
result_vec
,
res_gold_vec
));
}
}
TEST_CASE
(
argmin_test_nonstd_shape
)
TEST_CASE
(
argmin_test_nonstd_shape
)
...
@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
...
@@ -68,7 +68,7 @@ TEST_CASE(argmin_test_nonstd_shape)
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
result_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
int64_t
>
res_gold_vec
;
std
::
vector
<
int64_t
>
res_gold_vec
;
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
res_gold
.
visit
([
&
](
auto
output
)
{
res_gold_vec
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
res_gold_vec
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
result_vec
,
res_gold_vec
));
}
}
TEST_CASE
(
isnan_broadcast_test
)
TEST_CASE
(
isnan_broadcast_test
)
...
@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
...
@@ -88,7 +88,7 @@ TEST_CASE(isnan_broadcast_test)
std
::
vector
<
float
>
results_vector
;
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
correct
=
{
0
,
0
,
0
,
0
,
1
,
1
};
std
::
vector
<
float
>
correct
=
{
0
,
0
,
0
,
0
,
1
,
1
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
correct
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
results_vector
,
correct
));
}
}
TEST_CASE
(
squeeze_transpose_test
)
TEST_CASE
(
squeeze_transpose_test
)
...
...
test/ref_ops_test.cpp
View file @
68a9a23f
This diff is collapsed.
Click to expand it.
test/ref_rnn_ops_test.cpp
View file @
68a9a23f
This diff is collapsed.
Click to expand it.
test/rewrite_pooling_test.cpp
View file @
68a9a23f
...
@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
...
@@ -198,8 +198,8 @@ TEST_CASE(literal_rewrite_pooling_test)
p2
.
compile
(
migraphx
::
make_target
(
"ref"
));
p2
.
compile
(
migraphx
::
make_target
(
"ref"
));
auto
result1
=
p1
.
eval
({}).
back
();
auto
result1
=
p1
.
eval
({}).
back
();
auto
result2
=
p2
.
eval
({}).
back
();
auto
result2
=
p2
.
eval
({}).
back
();
visit_all
(
result1
,
visit_all
(
result1
,
result2
)(
result2
)(
[
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify_range
(
r1
,
r2
));
});
[
&
](
auto
r1
,
auto
r2
)
{
EXPECT
(
migraphx
::
verify
::
verify
_range
(
r1
,
r2
));
});
};
};
test_rewrite_pooling
(
migraphx
::
op
::
pooling_mode
::
max
,
test_rewrite_pooling
(
migraphx
::
op
::
pooling_mode
::
max
,
...
...
test/run_on_target_test.cpp
View file @
68a9a23f
...
@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
...
@@ -68,7 +68,7 @@ TEST_CASE(eval_run_on_target)
std
::
vector
<
float
>
results_vector
(
3
);
std
::
vector
<
float
>
results_vector
(
3
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0.5
,
0.25
,
0.125
};
std
::
vector
<
float
>
gold
=
{
0.5
,
0.25
,
0.125
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
results_vector
,
gold
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/shape_test.cpp
View file @
68a9a23f
...
@@ -947,13 +947,13 @@ TEST_CASE(test_with_type)
...
@@ -947,13 +947,13 @@ TEST_CASE(test_with_type)
TEST_CASE
(
test_multi_index
)
TEST_CASE
(
test_multi_index
)
{
{
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
6
}};
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
0
),
std
::
vector
<
size_t
>
{
0
,
0
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
0
),
std
::
vector
<
size_t
>
{
0
,
0
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
4
),
std
::
vector
<
size_t
>
{
0
,
0
,
4
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
4
),
std
::
vector
<
size_t
>
{
0
,
0
,
4
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
6
),
std
::
vector
<
size_t
>
{
0
,
1
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
6
),
std
::
vector
<
size_t
>
{
0
,
1
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
8
),
std
::
vector
<
size_t
>
{
0
,
1
,
2
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
8
),
std
::
vector
<
size_t
>
{
0
,
1
,
2
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
24
),
std
::
vector
<
size_t
>
{
1
,
0
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
24
),
std
::
vector
<
size_t
>
{
1
,
0
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
30
),
std
::
vector
<
size_t
>
{
1
,
1
,
0
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
30
),
std
::
vector
<
size_t
>
{
1
,
1
,
0
}));
EXPECT
(
migraphx
::
verify_range
(
s
.
multi
(
34
),
std
::
vector
<
size_t
>
{
1
,
1
,
4
}));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
s
.
multi
(
34
),
std
::
vector
<
size_t
>
{
1
,
1
,
4
}));
}
}
TEST_CASE
(
find_permutation_2d_standard
)
TEST_CASE
(
find_permutation_2d_standard
)
...
...
test/simplify_qdq_test.cpp
View file @
68a9a23f
...
@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
...
@@ -700,7 +700,7 @@ TEST_CASE(conv_correctness)
auto
result2
=
p2
.
eval
({{
"input"
,
input
},
{
"weights"
,
weights
}}).
back
();
auto
result2
=
p2
.
eval
({{
"input"
,
input
},
{
"weights"
,
weights
}}).
back
();
std
::
vector
<
float
>
rv2
(
16
);
std
::
vector
<
float
>
rv2
(
16
);
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
rv1
,
rv2
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
rv1
,
rv2
));
}
}
TEST_CASE
(
dot_correctness
)
TEST_CASE
(
dot_correctness
)
...
@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
...
@@ -750,7 +750,7 @@ TEST_CASE(dot_correctness)
auto
result2
=
p2
.
eval
({{
"a"
,
a
},
{
"b"
,
b
}}).
back
();
auto
result2
=
p2
.
eval
({{
"a"
,
a
},
{
"b"
,
b
}}).
back
();
std
::
vector
<
float
>
rv2
(
sh3
.
elements
());
std
::
vector
<
float
>
rv2
(
sh3
.
elements
());
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
result2
.
visit
([
&
](
auto
output
)
{
rv2
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
rv1
,
rv2
));
EXPECT
(
migraphx
::
verify
::
verify
_range
(
rv1
,
rv2
));
}
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment