// gemm_multiple_d.cpp
#include "ck/host/device_gemm_multiple_d/problem.hpp"
#include "ck/host/device_gemm_multiple_d/operation.hpp"
#include "ck/host/headers.hpp"
#include "ck/host/stringutils.hpp"
#include <algorithm>
#include <iterator>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>

std::vector<rtc::src_file> get_headers_for_test()
{
    std::vector<rtc::src_file> result;
    auto hs = ck::host::GetHeaders();
    std::transform(
        hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
            auto s = p.second;
            std::string content{s.first, s.second};
            return {p.first, content};
        });
    return result;
}

// Device-side source template used by test_problem_kernel. The placeholders
// ${include}, ${template}, ${m}, ${n} and ${k} are filled in with
// ck::host::InterpolateString before the text is handed to the runtime
// compiler.
//
// The generated kernel `f` aliases the instantiated template as `G`, builds a
// constexpr descriptor from packed (m, k), (n, k) and (m, n) tensor
// descriptors plus an empty tuple, statically asserts that the descriptor is
// valid, and calls G::Run only inside an `if constexpr` on that same validity
// check.
//
// Everything between the raw-string delimiters is kernel source: do not put
// host-side comments inside it, they would become part of the compiled text.
const std::string gemm_compile_check = R"__ck__(
#include <${include}>

extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) {
    using G = ${template};
    constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})),
                                             ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${n}, ${k})),
                                             ck::make_tuple(),
                                             ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n})));

    static_assert(desc.IsValid(), "Invalid ck gemm.");

    if constexpr(desc.IsValid())
    {
        G::Run(desc,
               a,
               b,
               ck::make_tuple(),
               c);
    }
}

)__ck__";

// Checks that every solution reported for a 256x256x256 problem on "gfx90a"
// produces a kernel that the runtime compiler accepts: the solution's template
// string and the problem sizes are substituted into gemm_compile_check, the
// result is added as "main.cpp" next to the headers from
// get_headers_for_test(), and the whole set is passed to rtc::compile_kernel
// with kernel name "f". The value returned by compile_kernel is not inspected.
TEST_CASE(test_problem_kernel)
{
    ck::host::device_gemm_multiple_d::Problem prob;
    prob.M = 256;
    prob.N = 256;
    prob.K = 256;

    // The header set does not depend on the solution, so build it once and
    // reuse it for every compilation instead of regenerating it per iteration.
    auto srcs = get_headers_for_test();

    // NOTE(review): binding by const reference assumes ToTemplateString() is a
    // const member -- confirm against the Solution declaration.
    for(const auto& solution : prob.GetSolutions("gfx90a"))
    {
        auto src = ck::host::InterpolateString(gemm_compile_check,
                                               {{"include", prob.GetIncludeHeader()},
                                                {"template", solution.ToTemplateString()},
                                                {"m", std::to_string(prob.M)},
                                                {"n", std::to_string(prob.N)},
                                                {"k", std::to_string(prob.K)}});
        srcs.push_back({"main.cpp", src});
        rtc::compile_options options;
        options.kernel_name = "f";
        rtc::compile_kernel(srcs, options);
        // Remove this solution's main.cpp so the next iteration starts from
        // the bare header set again.
        srcs.pop_back();
    }
}

// Program entry point: hands the command-line arguments to test::run.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
    return 0;
}