Commit c2bafa5d authored by Alan Turner's avatar Alan Turner
Browse files

Merge branch 'ck-flash-attn' of...

Merge branch 'ck-flash-attn' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into ck-flash-attn
parents 370b2cce 250d3c87
...@@ -40,7 +40,6 @@ ...@@ -40,7 +40,6 @@
TEST_CASE(gpu_target_copy) TEST_CASE(gpu_target_copy)
{ {
migraphx::target gpu_t = migraphx::make_target("gpu"); migraphx::target gpu_t = migraphx::make_target("gpu");
migraphx::target ref_t = migraphx::make_target("ref");
migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}}; migraphx::shape s{migraphx::shape::int8_type, {2, 3, 4, 5}};
auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L); auto ref_arg_orig = migraphx::generate_argument(s, 0x123456L);
...@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy) ...@@ -52,7 +51,7 @@ TEST_CASE(gpu_target_copy)
std::vector<int8_t> val_final; std::vector<int8_t> val_final;
ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); }); ref_arg_final.visit([&](auto v) { val_final.assign(v.begin(), v.end()); });
EXPECT(migraphx::verify::verify_range(val_orig, val_final)); EXPECT(migraphx::verify::verify_rms_range(val_orig, val_final));
} }
TEST_CASE(int8_quantization) TEST_CASE(int8_quantization)
...@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization) ...@@ -118,9 +117,12 @@ TEST_CASE(int8_quantization)
// the regular pipeline uses the rewrite_quantization in the much // the regular pipeline uses the rewrite_quantization in the much
// earlier stage. // earlier stage.
if(migraphx::gpu::mlir_enabled()) if(migraphx::gpu::mlir_enabled())
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result, 1e5)); EXPECT(migraphx::verify::verify_range_with_tolerance(
gpu_result,
migraphx::verify::expected{ref_result},
migraphx::verify::tolerance{0.01}));
else else
EXPECT(migraphx::verify::verify_range(ref_result, gpu_result)); EXPECT(migraphx::verify::verify_rms_range(gpu_result, ref_result));
} }
} }
......
...@@ -24,16 +24,16 @@ ...@@ -24,16 +24,16 @@
#ifndef MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP #ifndef MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#define MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP #define MIGRAPHX_GUARD_TEST_INCLUDE_POINTWISE_HPP
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp> #include <migraphx/program.hpp>
#include <migraphx/module.hpp> #include <migraphx/module.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
template <class F> template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p, migraphx::module_ref create_pointwise_module(migraphx::program& p,
migraphx::module_ref mm, const std::string& name,
const std::string& name, std::vector<migraphx::instruction_ref> inputs,
std::vector<migraphx::instruction_ref> inputs, F f)
F f)
{ {
auto* pm = p.create_module(name); auto* pm = p.create_module(name);
pm->set_bypass(); pm->set_bypass();
...@@ -44,6 +44,17 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p, ...@@ -44,6 +44,17 @@ migraphx::instruction_ref add_pointwise(migraphx::program& p,
}); });
auto r = f(pm, params); auto r = f(pm, params);
pm->add_return({r}); pm->add_return({r});
return pm;
}
template <class F>
migraphx::instruction_ref add_pointwise(migraphx::program& p,
migraphx::module_ref mm,
const std::string& name,
std::vector<migraphx::instruction_ref> inputs,
F f)
{
auto* pm = create_pointwise_module(p, name, inputs, f);
return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm}); return mm->add_instruction(migraphx::make_op("pointwise"), inputs, {pm});
} }
......
This diff is collapsed.
...@@ -88,6 +88,13 @@ TEST_CASE(allocate_static) ...@@ -88,6 +88,13 @@ TEST_CASE(allocate_static)
expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}})); expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}}));
} }
TEST_CASE(allocate_static_input_error)
{
migraphx::shape input{migraphx::shape::int64_type, {3}};
migraphx::shape out_shape{migraphx::shape::float_type, {2, 3, 4}};
expect_shape(out_shape, migraphx::make_op("allocate", {{"shape", to_value(out_shape)}}), input);
}
TEST_CASE(allocate_dyn) TEST_CASE(allocate_dyn)
{ {
migraphx::shape input{migraphx::shape::int64_type, {2}}; migraphx::shape input{migraphx::shape::int64_type, {2}};
...@@ -109,6 +116,14 @@ TEST_CASE(allocate_dyn_with_shape_attr) ...@@ -109,6 +116,14 @@ TEST_CASE(allocate_dyn_with_shape_attr)
input); input);
} }
TEST_CASE(allocate_dyn_no_input_error)
{
migraphx::shape shape_attr{migraphx::shape::float_type,
{{1, 4}, {3, 3}, {4, 8, {4, 6}}, {4, 8}, {4, 6}}};
expect_shape(shape_attr,
migraphx::make_op("allocate", {{"shape", migraphx::to_value(shape_attr)}}));
}
TEST_CASE(argmax_axis0) TEST_CASE(argmax_axis0)
{ {
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}}; migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
...@@ -2524,13 +2539,21 @@ TEST_CASE(reshape_shape) ...@@ -2524,13 +2539,21 @@ TEST_CASE(reshape_shape)
migraphx::shape output{migraphx::shape::float_type, lens}; migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", new_shape}}), input);
} }
}
TEST_CASE(reshape_shape_invalid)
{
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape : for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0}, {3, 2}}) std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0}, {3, 2}})
{ {
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input); throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
} }
}
TEST_CASE(reshape_shape_minus1_reshapes)
{
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{ std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
{{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}}, {{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}}, {{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
...@@ -2654,11 +2677,11 @@ TEST_CASE(reshape_broadcast_squeeze) ...@@ -2654,11 +2677,11 @@ TEST_CASE(reshape_broadcast_squeeze)
expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_broadcast_squeeze_error) TEST_CASE(reshape_broadcast_squeeze_memlayout_change)
{ {
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}}; migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
std::vector<int64_t> new_shape = {2, 16, 20480}; migraphx::shape output{migraphx::shape::float_type, {2, 16, 256, 80}, {0, 0, 0, 16}};
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input); expect_shape(output, migraphx::make_op("reshape", {{"dims", output.lens()}}), input);
} }
TEST_CASE(reshape_dyn_shape) TEST_CASE(reshape_dyn_shape)
...@@ -2706,6 +2729,199 @@ TEST_CASE(reshape_non_fixed_not_matching_error) ...@@ -2706,6 +2729,199 @@ TEST_CASE(reshape_non_fixed_not_matching_error)
throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input); throws_shape(migraphx::make_op("reshape", {{"dims", new_shape}}), input);
} }
TEST_CASE(reshape_lazy_shape)
{
migraphx::shape input{migraphx::shape::float_type, {24, 1, 1, 1}};
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 1, 1}, {1, 3, 4, 2}, {1, 3, 4, 2}})
{
std::vector<std::size_t> lens(new_shape.size());
std::copy(new_shape.begin(), new_shape.end(), lens.begin());
migraphx::shape output{migraphx::shape::float_type, lens};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
}
for(auto&& new_shape :
std::vector<std::vector<int64_t>>{{8, 3, 2, 2}, {1, 3, -1, -1}, {3, 0}, {3, 2}})
{
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
}
std::vector<std::pair<std::vector<int64_t>, migraphx::shape>> minus1_tests{
{{2, -1, 3}, {migraphx::shape::float_type, {2, 4, 3}}},
{{0, -1, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, -1, 0}, {migraphx::shape::float_type, {2, 12, 1}}},
{{0, 0, -1}, {migraphx::shape::float_type, {24, 1, 1}}},
{{2, 0, -1}, {migraphx::shape::float_type, {2, 1, 12}}},
{{-1, 2, 3}, {migraphx::shape::float_type, {4, 2, 3}}},
{{-1, 0, 3}, {migraphx::shape::float_type, {8, 1, 3}}},
{{-1, 0, 0}, {migraphx::shape::float_type, {24, 1, 1}}},
{{-1, 3, 0}, {migraphx::shape::float_type, {8, 3, 1}}}};
for(auto& it : minus1_tests)
{
expect_shape(it.second, migraphx::make_op("reshape_lazy", {{"dims", it.first}}), input);
}
}
// This uses the permutation to compute the reshape_lazy since its simpler than
// trying to calculate strides. As we collapse or expand dimensions, we
// remove the collapsed dimensions or duplicate the expanded dimensions in
// the permutation. Then we renumber the permutation. So for dimensions of 4,
// 24, 1, 1, 1 with a permutation of 1, 0, 2, 3, 4 that reshape_lazys to 4, 1, 3,
// 4, 2, we first remove the collapsed dimensions or duplicate the expanded
// dimensions which gives 1, 0, 0, 0, 0. Then after renumbering we get a
// final permutation of 4, 0, 1, 2, 3.
TEST_CASE(reshape_lazy_nonstandard)
{
auto input = migraphx::shape::from_permutation(migraphx::shape::float_type,
{4, 24, 1, 1, 1},
migraphx::invert_permutation({1, 0, 2, 3, 4}));
std::vector<std::pair<std::vector<std::size_t>, std::vector<int64_t>>> tests{
{{4, 24}, {1, 0}},
{{4, 24, 1, 1, 1, 1}, {1, 0, 2, 3, 4, 5}},
{{4, 8, 3, 1, 1}, {2, 0, 1, 3, 4}},
{{4, 1, 3, 4, 2}, {4, 0, 1, 2, 3}},
{{4, 1, 4, 3, 2}, {4, 0, 1, 2, 3}},
{{4, 2, 4, 3}, {3, 0, 1, 2}},
{{4, 2, 12, 1}, {2, 0, 1, 3}},
{{4, 2, 1, 12}, {3, 0, 1, 2}},
{{4, 4, 2, 3}, {3, 0, 1, 2}},
{{4, 8, 1, 3}, {3, 0, 1, 2}},
{{4, 8, 3, 1}, {2, 0, 1, 3}}};
for(const auto& [dims, perm] : tests)
{
migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, dims, migraphx::invert_permutation(perm));
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", dims}}), input);
}
}
TEST_CASE(reshape_lazy_nonstandard_squeeze)
{
auto input = migraphx::shape::from_permutation(
migraphx::shape::float_type, {2, 16, 16, 1280}, migraphx::invert_permutation({0, 2, 3, 1}));
std::vector<std::size_t> lens = {2, 256, 1280};
migraphx::shape output = migraphx::shape::from_permutation(
migraphx::shape::float_type, lens, migraphx::invert_permutation({0, 2, 1}));
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", lens}}), input);
}
TEST_CASE(reshape_lazy_nonstandard_error)
{
auto input = migraphx::shape::from_permutation(migraphx::shape::float_type,
{4, 24, 1, 1, 1},
migraphx::invert_permutation({1, 0, 2, 3, 4}));
for(auto&& new_shape : std::vector<std::vector<int64_t>>{{4, 8, 3, 2, 2},
{1},
{4, 8, 4},
{4, 24, 1, 1, 1, 1, 2},
{8, 4, 4},
{4, 1, 3, -1, -1},
{4, 3, 0},
{4, 3, 2},
{3, 0},
{3, 2}})
{
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
}
}
TEST_CASE(reshape_lazy_nonpacked_unsqueeze1)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {4, 2, 8}, {32, 16, 2}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_nonpacked_unsqueeze2)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {2, 2, 16}, {64, 32, 2}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_nonpacked_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {4, 16}, {32, 2}};
migraphx::shape output{migraphx::shape::float_type, {64}, {2}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_broadcast_unsqueeze1)
{
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_broadcast_unsqueeze2)
{
migraphx::shape input{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 16, 80}, {0, 0, 80, 1}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
migraphx::shape output{migraphx::shape::float_type, {2, 256, 1280}, {0, 0, 1}};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", output.lens()}}), input);
}
TEST_CASE(reshape_lazy_broadcast_squeeze_error)
{
migraphx::shape input{migraphx::shape::float_type, {2, 16, 16, 1280}, {0, 0, 0, 1}};
std::vector<int64_t> new_shape = {2, 16, 20480};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_lazy_dyn_shape)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
for(auto&& new_shape : std::vector<std::vector<int64_t>>{
{-1, 1, 1, 24}, {0, 8, 3, 1}, {-1, 3, 4, 2}, {0, 2, 4, 3}})
{
std::vector<migraphx::shape::dynamic_dimension> out_dyn_dims{};
for(std::size_t i = 0; i < new_shape.size(); ++i)
{
if(new_shape[i] == 0 or new_shape[i] == -1)
{
out_dyn_dims.push_back(input.dyn_dims().at(i));
}
else
{
std::size_t d = new_shape[i];
out_dyn_dims.push_back({d, d});
}
}
migraphx::shape output{migraphx::shape::float_type, out_dyn_dims};
expect_shape(output, migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
}
}
TEST_CASE(reshape_lazy_multiple_non_fixed_error)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 20}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 1, 0, 24};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_lazy_fixed_ele_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {10, 10}, {1, 1}}};
std::vector<int64_t> new_shape = {0, 1, 5, 24};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
}
TEST_CASE(reshape_lazy_non_fixed_not_matching_error)
{
migraphx::shape input{migraphx::shape::float_type, {{1, 4}, {24, 24}, {1, 1}, {1, 1}}};
std::vector<int64_t> new_shape = {2, 1, 1, 24};
throws_shape(migraphx::make_op("reshape_lazy", {{"dims", new_shape}}), input);
}
TEST_CASE(return_shape_tuple) TEST_CASE(return_shape_tuple)
{ {
using migraphx::shape; using migraphx::shape;
......
...@@ -83,7 +83,7 @@ TEST_CASE(param_add) ...@@ -83,7 +83,7 @@ TEST_CASE(param_add)
auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2); auto hs = mm->add_instruction(migraphx::make_op("add"), hp1, hp2);
auto fs = mm->add_instruction( auto fs = mm->add_instruction(
migraphx::make_op("convert", migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}), {{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
hs); hs);
if(add_return) if(add_return)
{ {
...@@ -1013,7 +1013,7 @@ TEST_CASE(target_copy) ...@@ -1013,7 +1013,7 @@ TEST_CASE(target_copy)
std::vector<float> orig_result; std::vector<float> orig_result;
run_prog(p, ref_t, m, orig_result); run_prog(p, ref_t, m, orig_result);
EXPECT(migraphx::verify::verify_range(ref_result, orig_result)); EXPECT(migraphx::verify::verify_rms_range(ref_result, orig_result));
} }
} }
...@@ -1077,7 +1077,10 @@ TEST_CASE(int8_quantization_dot) ...@@ -1077,7 +1077,10 @@ TEST_CASE(int8_quantization_dot)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, m, no_quant_result); run_prog(p, ref_t, m, no_quant_result);
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result, 30000)); EXPECT(migraphx::verify::verify_range_with_tolerance(
quant_result,
migraphx::verify::expected{no_quant_result},
migraphx::verify::tolerance{0.003}));
} }
} }
...@@ -1122,7 +1125,7 @@ TEST_CASE(int8_quantization_conv) ...@@ -1122,7 +1125,7 @@ TEST_CASE(int8_quantization_conv)
std::vector<float> no_quant_result; std::vector<float> no_quant_result;
run_prog(p, ref_t, no_quant_result); run_prog(p, ref_t, no_quant_result);
EXPECT(migraphx::verify::verify_range(quant_result, no_quant_result)); EXPECT(migraphx::verify::verify_rms_range(quant_result, no_quant_result));
} }
} }
...@@ -1274,7 +1277,7 @@ TEST_CASE(test_op_capture) ...@@ -1274,7 +1277,7 @@ TEST_CASE(test_op_capture)
cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); }); cap_res.visit([&](auto output) { cap_vec.assign(output.begin(), output.end()); });
res.visit([&](auto output) { vec.assign(output.begin(), output.end()); }); res.visit([&](auto output) { vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(vec, cap_vec)); EXPECT(migraphx::verify::verify_rms_range(vec, cap_vec));
} }
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -42,7 +42,7 @@ TEST_CASE(abs_test) ...@@ -42,7 +42,7 @@ TEST_CASE(abs_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4}; std::vector<float> gold{1, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(abs_dyn_test) TEST_CASE(abs_dyn_test)
...@@ -62,5 +62,5 @@ TEST_CASE(abs_dyn_test) ...@@ -62,5 +62,5 @@ TEST_CASE(abs_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 2, 3, 4}; std::vector<float> gold{1, 2, 3, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(acos_test) ...@@ -45,7 +45,7 @@ TEST_CASE(acos_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(acos_dyn_test) TEST_CASE(acos_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(acos_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(acos_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acosf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(acosh_test) ...@@ -45,7 +45,7 @@ TEST_CASE(acosh_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(acosh_dyn_test) TEST_CASE(acosh_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(acosh_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(acosh_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return acoshf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -51,7 +51,7 @@ TEST_CASE(add_broadcast_test) ...@@ -51,7 +51,7 @@ TEST_CASE(add_broadcast_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(add_multibroadcast_test) TEST_CASE(add_multibroadcast_test)
...@@ -75,7 +75,7 @@ TEST_CASE(add_multibroadcast_test) ...@@ -75,7 +75,7 @@ TEST_CASE(add_multibroadcast_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8}; std::vector<float> gold = {0, 1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 8};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(add_test) TEST_CASE(add_test)
...@@ -91,7 +91,7 @@ TEST_CASE(add_test) ...@@ -91,7 +91,7 @@ TEST_CASE(add_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(add_dyn_test) TEST_CASE(add_dyn_test)
...@@ -115,7 +115,7 @@ TEST_CASE(add_dyn_test) ...@@ -115,7 +115,7 @@ TEST_CASE(add_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(fp16_test) TEST_CASE(fp16_test)
...@@ -134,7 +134,7 @@ TEST_CASE(fp16_test) ...@@ -134,7 +134,7 @@ TEST_CASE(fp16_test)
std::vector<migraphx::half> results_vector(1); std::vector<migraphx::half> results_vector(1);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<migraphx::half> gold{c}; std::vector<migraphx::half> gold{c};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(fp32_fp16_test) TEST_CASE(fp32_fp16_test)
...@@ -159,7 +159,7 @@ TEST_CASE(fp32_fp16_test) ...@@ -159,7 +159,7 @@ TEST_CASE(fp32_fp16_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> res; std::vector<float> res;
result.visit([&](auto output) { res.assign(output.begin(), output.end()); }); result.visit([&](auto output) { res.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(res, gold_res)); EXPECT(migraphx::verify::verify_rms_range(res, gold_res));
}; };
test_case({"all"}); test_case({"all"});
......
...@@ -47,7 +47,7 @@ TEST_CASE(argmax_test_0) ...@@ -47,7 +47,7 @@ TEST_CASE(argmax_test_0)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_1) TEST_CASE(argmax_test_1)
...@@ -66,7 +66,7 @@ TEST_CASE(argmax_test_1) ...@@ -66,7 +66,7 @@ TEST_CASE(argmax_test_1)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_2) TEST_CASE(argmax_test_2)
...@@ -85,7 +85,7 @@ TEST_CASE(argmax_test_2) ...@@ -85,7 +85,7 @@ TEST_CASE(argmax_test_2)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_neg_2) TEST_CASE(argmax_test_neg_2)
...@@ -104,7 +104,7 @@ TEST_CASE(argmax_test_neg_2) ...@@ -104,7 +104,7 @@ TEST_CASE(argmax_test_neg_2)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_dyn_test) TEST_CASE(argmax_dyn_test)
...@@ -126,7 +126,7 @@ TEST_CASE(argmax_dyn_test) ...@@ -126,7 +126,7 @@ TEST_CASE(argmax_dyn_test)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1}; std::vector<int64_t> res_gold = {0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1};
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmax_test_nonstd_shape) TEST_CASE(argmax_test_nonstd_shape)
...@@ -145,5 +145,5 @@ TEST_CASE(argmax_test_nonstd_shape) ...@@ -145,5 +145,5 @@ TEST_CASE(argmax_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold_vec));
} }
...@@ -47,7 +47,7 @@ TEST_CASE(argmin_test_0) ...@@ -47,7 +47,7 @@ TEST_CASE(argmin_test_0)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmin_test_1) TEST_CASE(argmin_test_1)
...@@ -66,7 +66,7 @@ TEST_CASE(argmin_test_1) ...@@ -66,7 +66,7 @@ TEST_CASE(argmin_test_1)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmin_test_2) TEST_CASE(argmin_test_2)
...@@ -85,7 +85,7 @@ TEST_CASE(argmin_test_2) ...@@ -85,7 +85,7 @@ TEST_CASE(argmin_test_2)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmin_test_neg_1) TEST_CASE(argmin_test_neg_1)
...@@ -104,7 +104,7 @@ TEST_CASE(argmin_test_neg_1) ...@@ -104,7 +104,7 @@ TEST_CASE(argmin_test_neg_1)
std::vector<int64_t> result_vec; std::vector<int64_t> result_vec;
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold));
} }
TEST_CASE(argmin_test_nonstd_shape) TEST_CASE(argmin_test_nonstd_shape)
...@@ -123,5 +123,5 @@ TEST_CASE(argmin_test_nonstd_shape) ...@@ -123,5 +123,5 @@ TEST_CASE(argmin_test_nonstd_shape)
result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vec.assign(output.begin(), output.end()); });
std::vector<int64_t> res_gold_vec; std::vector<int64_t> res_gold_vec;
res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); }); res_gold.visit([&](auto output) { res_gold_vec.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(result_vec, res_gold_vec)); EXPECT(migraphx::verify::verify_rms_range(result_vec, res_gold_vec));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(asin_test) ...@@ -45,7 +45,7 @@ TEST_CASE(asin_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(asin_dyn_test) TEST_CASE(asin_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(asin_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(asin_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(asinh_test) ...@@ -45,7 +45,7 @@ TEST_CASE(asinh_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(asinh_dyn_test) TEST_CASE(asinh_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(asinh_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(asinh_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return asinhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(atan_test) ...@@ -45,7 +45,7 @@ TEST_CASE(atan_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(atan_dyn_test) TEST_CASE(atan_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(atan_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(atan_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(atanh_test) ...@@ -45,7 +45,7 @@ TEST_CASE(atanh_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(atanh_dyn_test) TEST_CASE(atanh_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(atanh_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(atanh_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanhf(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return atanhf(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -114,6 +114,6 @@ TEST_CASE(isnan_broadcast_test) ...@@ -114,6 +114,6 @@ TEST_CASE(isnan_broadcast_test)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> correct = {0, 0, 0, 0, 1, 1}; std::vector<float> gold = {0, 0, 0, 0, 1, 1};
EXPECT(migraphx::verify::verify_range(results_vector, correct)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(ceil_test) ...@@ -45,7 +45,7 @@ TEST_CASE(ceil_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::ceil(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::ceil(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(ceil_dyn_test) TEST_CASE(ceil_dyn_test)
...@@ -68,5 +68,5 @@ TEST_CASE(ceil_dyn_test) ...@@ -68,5 +68,5 @@ TEST_CASE(ceil_dyn_test)
std::vector<float> gold = input_data; std::vector<float> gold = input_data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::ceil(n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::ceil(n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -48,7 +48,7 @@ TEST_CASE(clip_test) ...@@ -48,7 +48,7 @@ TEST_CASE(clip_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0, 0.0, 6.0}; std::vector<float> gold = {0.0, 0.0, 6.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(clip_dyn_test) TEST_CASE(clip_dyn_test)
...@@ -73,5 +73,5 @@ TEST_CASE(clip_dyn_test) ...@@ -73,5 +73,5 @@ TEST_CASE(clip_dyn_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0.0, 0.0, 6.0}; std::vector<float> gold = {0.0, 0.0, 6.0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -50,11 +50,11 @@ TEST_CASE(concat_test_1) ...@@ -50,11 +50,11 @@ TEST_CASE(concat_test_1)
std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20}; std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20};
std::vector<int> results_vector(2 * 6); std::vector<int> results_vector(2 * 6);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(migraphx::verify::verify_range(result.get_shape().lens(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().lens(),
std::vector<std::size_t>({2, 6}))); std::vector<std::size_t>({2, 6})));
EXPECT(migraphx::verify::verify_range(result.get_shape().strides(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().strides(),
std::vector<std::size_t>({6, 1}))); std::vector<std::size_t>({6, 1})));
} }
TEST_CASE(concat_test_2) TEST_CASE(concat_test_2)
...@@ -77,11 +77,11 @@ TEST_CASE(concat_test_2) ...@@ -77,11 +77,11 @@ TEST_CASE(concat_test_2)
std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20}; std::vector<int> gold = {0, 1, 2, 3, 4, 10, 5, 6, 7, 8, 9, 20};
std::vector<int> results_vector(2 * 6); std::vector<int> results_vector(2 * 6);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(migraphx::verify::verify_range(result.get_shape().lens(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().lens(),
std::vector<std::size_t>({2, 6}))); std::vector<std::size_t>({2, 6})));
EXPECT(migraphx::verify::verify_range(result.get_shape().strides(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().strides(),
std::vector<std::size_t>({6, 1}))); std::vector<std::size_t>({6, 1})));
} }
TEST_CASE(concat_test_3) TEST_CASE(concat_test_3)
...@@ -104,11 +104,11 @@ TEST_CASE(concat_test_3) ...@@ -104,11 +104,11 @@ TEST_CASE(concat_test_3)
std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
std::vector<int> results_vector(6 * 2); std::vector<int> results_vector(6 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(migraphx::verify::verify_range(result.get_shape().lens(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().lens(),
std::vector<std::size_t>({6, 2}))); std::vector<std::size_t>({6, 2})));
EXPECT(migraphx::verify::verify_range(result.get_shape().strides(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().strides(),
std::vector<std::size_t>({2, 1}))); std::vector<std::size_t>({2, 1})));
} }
TEST_CASE(concat_test_4) TEST_CASE(concat_test_4)
...@@ -131,11 +131,11 @@ TEST_CASE(concat_test_4) ...@@ -131,11 +131,11 @@ TEST_CASE(concat_test_4)
std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<int> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
std::vector<int> results_vector(6 * 2); std::vector<int> results_vector(6 * 2);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(migraphx::verify::verify_range(result.get_shape().lens(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().lens(),
std::vector<std::size_t>({6, 2}))); std::vector<std::size_t>({6, 2})));
EXPECT(migraphx::verify::verify_range(result.get_shape().strides(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().strides(),
std::vector<std::size_t>({2, 1}))); std::vector<std::size_t>({2, 1})));
} }
TEST_CASE(concat_dyn_test) TEST_CASE(concat_dyn_test)
...@@ -169,7 +169,7 @@ TEST_CASE(concat_dyn_test) ...@@ -169,7 +169,7 @@ TEST_CASE(concat_dyn_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<float> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
EXPECT(migraphx::verify::verify_range(result.get_shape().lens(), EXPECT(migraphx::verify::verify_rms_range(result.get_shape().lens(),
std::vector<std::size_t>({6, 2}))); std::vector<std::size_t>({6, 2})));
} }
...@@ -50,7 +50,7 @@ TEST_CASE(contiguous_test) ...@@ -50,7 +50,7 @@ TEST_CASE(contiguous_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; std::vector<float> gold = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(contiguous_param_test) TEST_CASE(contiguous_param_test)
...@@ -74,7 +74,7 @@ TEST_CASE(contiguous_param_test) ...@@ -74,7 +74,7 @@ TEST_CASE(contiguous_param_test)
std::vector<float> results_vector(12); std::vector<float> results_vector(12);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(contiguous_dyn_test) TEST_CASE(contiguous_dyn_test)
...@@ -100,5 +100,5 @@ TEST_CASE(contiguous_dyn_test) ...@@ -100,5 +100,5 @@ TEST_CASE(contiguous_dyn_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11}; std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment