Commit c2bafa5d authored by Alan Turner's avatar Alan Turner
Browse files

Merge branch 'ck-flash-attn' of...

Merge branch 'ck-flash-attn' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into ck-flash-attn
parents 370b2cce 250d3c87
...@@ -49,7 +49,7 @@ TEST_CASE(logical_or_test) ...@@ -49,7 +49,7 @@ TEST_CASE(logical_or_test)
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 or n2; return n1 or n2;
}); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(logical_or_dyn_test) TEST_CASE(logical_or_dyn_test)
...@@ -78,5 +78,5 @@ TEST_CASE(logical_or_dyn_test) ...@@ -78,5 +78,5 @@ TEST_CASE(logical_or_dyn_test)
right_data.begin(), right_data.begin(),
gold.begin(), gold.begin(),
[](bool n1, bool n2) -> bool { return n1 or n2; }); [](bool n1, bool n2) -> bool { return n1 or n2; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -49,7 +49,7 @@ TEST_CASE(logical_xor_test) ...@@ -49,7 +49,7 @@ TEST_CASE(logical_xor_test)
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool { data1.begin(), data1.end(), data2.begin(), gold.begin(), [](bool n1, bool n2) -> bool {
return n1 ^ n2; return n1 ^ n2;
}); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(logical_xor_dyn_test) TEST_CASE(logical_xor_dyn_test)
...@@ -78,5 +78,5 @@ TEST_CASE(logical_xor_dyn_test) ...@@ -78,5 +78,5 @@ TEST_CASE(logical_xor_dyn_test)
right_data.begin(), right_data.begin(),
gold.begin(), gold.begin(),
[](bool n1, bool n2) -> bool { return n1 ^ n2; }); [](bool n1, bool n2) -> bool { return n1 ^ n2; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -61,7 +61,7 @@ TEST_CASE(logsoftmax_test_axis_0) ...@@ -61,7 +61,7 @@ TEST_CASE(logsoftmax_test_axis_0)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, s));
} }
TEST_CASE(logsoftmax_test_axis_1) TEST_CASE(logsoftmax_test_axis_1)
...@@ -95,7 +95,7 @@ TEST_CASE(logsoftmax_test_axis_1) ...@@ -95,7 +95,7 @@ TEST_CASE(logsoftmax_test_axis_1)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, s));
} }
TEST_CASE(logsoftmax_test_axis_2) TEST_CASE(logsoftmax_test_axis_2)
...@@ -129,7 +129,7 @@ TEST_CASE(logsoftmax_test_axis_2) ...@@ -129,7 +129,7 @@ TEST_CASE(logsoftmax_test_axis_2)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, s));
} }
TEST_CASE(logsoftmax_test_axis_3) TEST_CASE(logsoftmax_test_axis_3)
...@@ -163,5 +163,5 @@ TEST_CASE(logsoftmax_test_axis_3) ...@@ -163,5 +163,5 @@ TEST_CASE(logsoftmax_test_axis_3)
auto result = p.eval({}).back(); auto result = p.eval({}).back();
std::vector<float> results_vector; std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify::verify_range(results_vector, s)); EXPECT(migraphx::verify::verify_rms_range(results_vector, s));
} }
...@@ -43,5 +43,5 @@ TEST_CASE(lrn_test) ...@@ -43,5 +43,5 @@ TEST_CASE(lrn_test)
std::vector<float> results_vector(5); std::vector<float> results_vector(5);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075}; std::vector<float> gold = {-2 / 1.000075, 1 / 1.00009, 0 / 1.000145, 1 / 1.00009, 2 / 1.000075};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(max_test) ...@@ -45,7 +45,7 @@ TEST_CASE(max_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{7, 8, 9}; std::vector<float> gold{7, 8, 9};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(max_dyn_test) TEST_CASE(max_dyn_test)
...@@ -73,5 +73,5 @@ TEST_CASE(max_dyn_test) ...@@ -73,5 +73,5 @@ TEST_CASE(max_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{7, 8, 9}; std::vector<float> gold{7, 8, 9};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(min_test) ...@@ -45,7 +45,7 @@ TEST_CASE(min_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 4, 3}; std::vector<float> gold{1, 4, 3};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(min_dyn_test) TEST_CASE(min_dyn_test)
...@@ -73,5 +73,5 @@ TEST_CASE(min_dyn_test) ...@@ -73,5 +73,5 @@ TEST_CASE(min_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1, 4, 3}; std::vector<float> gold{1, 4, 3};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(mod_test) ...@@ -45,7 +45,7 @@ TEST_CASE(mod_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 2}; std::vector<float> gold{0, 0, 2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(mod_dyn_test) TEST_CASE(mod_dyn_test)
...@@ -73,7 +73,7 @@ TEST_CASE(mod_dyn_test) ...@@ -73,7 +73,7 @@ TEST_CASE(mod_dyn_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 0, 2}; std::vector<float> gold{0, 0, 2};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(mod_float_test) TEST_CASE(mod_float_test)
...@@ -92,5 +92,5 @@ TEST_CASE(mod_float_test) ...@@ -92,5 +92,5 @@ TEST_CASE(mod_float_test)
std::vector<float> results_vector(4); std::vector<float> results_vector(4);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{1.0f, 2.5f, 2.0f}; std::vector<float> gold{1.0f, 2.5f, 2.0f};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -49,7 +49,7 @@ TEST_CASE(mul_test) ...@@ -49,7 +49,7 @@ TEST_CASE(mul_test)
data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> float { data1.begin(), data1.end(), data2.begin(), gold.begin(), [](float n1, float n2) -> float {
return n1 * n2; return n1 * n2;
}); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(mul_dyn_test) TEST_CASE(mul_dyn_test)
...@@ -78,5 +78,5 @@ TEST_CASE(mul_dyn_test) ...@@ -78,5 +78,5 @@ TEST_CASE(mul_dyn_test)
y_data.begin(), y_data.begin(),
gold.begin(), gold.begin(),
[](float n1, float n2) -> float { return n1 * n2; }); [](float n1, float n2) -> float { return n1 * n2; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
...@@ -78,5 +78,6 @@ TEST_CASE(multinomial_test) ...@@ -78,5 +78,6 @@ TEST_CASE(multinomial_test)
std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) { std::transform(res_dist.begin(), res_dist.end(), res_norm.begin(), [&](auto n) {
return static_cast<double>(n) / res_dist_sum; return static_cast<double>(n) / res_dist_sum;
}); });
EXPECT(migraphx::verify::verify_range(norm, res_norm, 100000)); EXPECT(migraphx::verify::verify_range_with_tolerance(
res_norm, migraphx::verify::expected{norm}, migraphx::verify::tolerance{0.01}));
} }
...@@ -45,7 +45,7 @@ TEST_CASE(neg_test) ...@@ -45,7 +45,7 @@ TEST_CASE(neg_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>()); std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>());
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
TEST_CASE(neg_dyn_test) TEST_CASE(neg_dyn_test)
...@@ -67,5 +67,5 @@ TEST_CASE(neg_dyn_test) ...@@ -67,5 +67,5 @@ TEST_CASE(neg_dyn_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = a; std::vector<float> gold = a;
std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>()); std::transform(gold.begin(), gold.end(), gold.begin(), std::negate<float>());
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
This diff is collapsed.
...@@ -46,5 +46,5 @@ TEST_CASE(nonzero_test) ...@@ -46,5 +46,5 @@ TEST_CASE(nonzero_test)
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, std::vector<int64_t> gold = {0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0}; 1, 1, 0, 0, 0, 0, 0, 1, 0, 2, 0, 2, 0, 2, 0, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(result_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
} }
...@@ -44,7 +44,7 @@ TEST_CASE(not_test_int32) ...@@ -44,7 +44,7 @@ TEST_CASE(not_test_int32)
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold{1, 0, 0, 0}; std::vector<char> gold{1, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(not_test_bool) TEST_CASE(not_test_bool)
...@@ -62,7 +62,7 @@ TEST_CASE(not_test_bool) ...@@ -62,7 +62,7 @@ TEST_CASE(not_test_bool)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<bool> gold(data.size()); std::vector<bool> gold(data.size());
std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return not n; }); std::transform(data.begin(), data.end(), gold.begin(), [](bool n) -> bool { return not n; });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(not_dyn_test) TEST_CASE(not_dyn_test)
...@@ -83,5 +83,5 @@ TEST_CASE(not_dyn_test) ...@@ -83,5 +83,5 @@ TEST_CASE(not_dyn_test)
std::vector<char> results_vector; std::vector<char> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<char> gold{1, 0, 0, 0}; std::vector<char> gold{1, 0, 0, 0};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
This diff is collapsed.
...@@ -47,5 +47,5 @@ TEST_CASE(pointwise_test) ...@@ -47,5 +47,5 @@ TEST_CASE(pointwise_test)
std::vector<float> results_vector(3); std::vector<float> results_vector(3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); }); result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {0, 2, 4}; std::vector<float> gold = {0, 2, 4};
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
This diff is collapsed.
...@@ -46,7 +46,7 @@ TEST_CASE(pow_test) ...@@ -46,7 +46,7 @@ TEST_CASE(pow_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
TEST_CASE(pow_dyn_test) TEST_CASE(pow_dyn_test)
...@@ -70,5 +70,5 @@ TEST_CASE(pow_dyn_test) ...@@ -70,5 +70,5 @@ TEST_CASE(pow_dyn_test)
std::vector<float> gold = data; std::vector<float> gold = data;
std::transform( std::transform(
gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); }); gold.begin(), gold.end(), gold.begin(), [](float n) -> float { return std::pow(n, n); });
EXPECT(migraphx::verify::verify_range(results_vector, gold)); EXPECT(migraphx::verify::verify_rms_range(results_vector, gold));
} }
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment