#pragma once #include #include #include #include #include #include #include #include #include std::vector get_headers_for_test() { std::vector result; auto hs = ck::host::GetHeaders(); std::transform( hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { return {p.first, p.second}; }); return result; } template std::size_t GetSize(V mLens, V mStrides) { std::size_t space = 1; for(std::size_t i = 0; i < mLens.Size(); ++i) { if(mLens[i] == 0) continue; space += (mLens[i] - 1) * mStrides[i]; } return space; } template rtc::buffer generate_buffer(V mLens, V mStrides, std::size_t seed = 0) { std::size_t space = GetSize(mLens, mStrides); rtc::buffer result(space); std::mt19937 gen(seed); std::uniform_real_distribution dis(-1.0); std::generate(result.begin(), result.end(), [&] { return dis(gen); }); // std::fill(result.begin(), result.end(), 1); return result; } template bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) { return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { return fabs(x - y) < atol + rtol * fabs(y); }); } std::string classify(double x) { switch(std::fpclassify(x)) { case FP_INFINITE: return "inf"; case FP_NAN: return "nan"; case FP_NORMAL: return "normal"; case FP_SUBNORMAL: return "subnormal"; case FP_ZERO: return "zero"; default: return "unknown"; } } template void print_classification(const Buffer& x) { std::unordered_set result; for(const auto& i : x) result.insert(classify(i)); for(const auto& c : result) std::cout << c << ", "; std::cout << std::endl; } template void print_statistics(const Buffer& x) { std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; double num_elements = x.size(); auto mean = std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; auto stddev = std::sqrt( std::accumulate(x.begin(), x.end(), double{0.0}, [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / num_elements); std::cout << "Mean: " << mean << ", "; std::cout << "StdDev: " << stddev << "\n"; } template void print_preview(const Buffer& x) { if(x.size() <= 10) { std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); } else { std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); std::cout << "..., "; std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); } std::cout << std::endl; } template struct check_all { rtc::buffer data{}; bool operator()(const rtc::buffer& x) { if(data.empty()) { data = x; return true; } return allclose(data, x); } }; template auto report(const Solution& solution, bool pass) { return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); }