#include "ck/host/device_batched_gemm_softmax_gemm/problem.hpp"
#include "ck/host/stringutils.hpp"
#include "ck/host/utils.hpp"
#include "common.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <string>

using half = _Float16;
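
// Source template for a small test kernel: the ${...} placeholders are filled in
// with ck::host::InterpolateString below before the kernel is JIT-compiled via rtc.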

const std::string gemm_compile_check = R"__ck__(
#include <${include}>

extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, const ck::half_t* b1, ck::half_t* c) {
    using G = ${template};
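    // Build naive tensor descriptors for A (MxK), B (NxK), B1 (NxO) and the output
    // C (MxO); make_descriptor is constexpr so the instance can be validated below.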
    constexpr auto desc = G::make_descriptor(ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${k}), ck::make_tuple(${m}, 1)),
                                             ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(${n}, 1)),
                                             ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${o}), ck::make_tuple(1, ${n})),
                                             ck::make_naive_tensor_descriptor(ck::make_tuple(${m}, ${o}), ck::make_tuple(${m}, 1)));

    static_assert(desc.IsValid(), "Invalid ck gemm.");
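    // The if constexpr below keeps an invalid tile configuration from instantiating
    // Run; the static_assert above already reports the failure at compile time.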

    if constexpr(desc.IsValid())
    {
        G::Run(desc,
               1.0,
               a,
               b,
               b1,
               c);
    }
}

)__ck__";

TEST_CASE(test_problem_kernel)
{
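    // A 1024x1024x1024x1024 gemm + softmax + gemm problem with a transposed B operand.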
    ck::host::device_batched_gemm_softmax_gemm::Problem prob;
    prob.M      = 1024;
    prob.N      = 1024;
    prob.K      = 1024;
    prob.O      = 1024;
    prob.TransB = true;
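    // Separate result checkers for masked and unmasked solutions, since the two
    // variants produce different outputs.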
    check_all<half> check1, check2;
    auto a  = to_gpu(generate_buffer<half>(1024 * 1024, 0));
    auto b  = to_gpu(generate_buffer<half>(1024 * 1024, 1));
    auto b1 = to_gpu(generate_buffer<half>(1024 * 1024, 2));
    auto c  = to_gpu(generate_buffer<half>(1024 * 1024, 3));

    std::string epilogue = "";
    std::string prologue = "";

    auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
    std::cout << "Num solutions: " << solutions.size() << std::endl;
    for(std::size_t i = 0; i < solutions.size(); ++i)
    {
        std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
        auto&& solution = solutions[i];
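        // Substitute this solution's instance string and the problem sizes into the
        // kernel template, then JIT-compile it with the headers from get_headers_for_test().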
        auto src        = ck::host::InterpolateString(gemm_compile_check,
                                                      {{"include", prob.GetIncludeHeader()},
                                                       {"template", solution.ToTemplateString()},
                                                       {"m", std::to_string(prob.M)},
                                                       {"n", std::to_string(prob.N)},
                                                       {"k", std::to_string(prob.K)},
                                                       {"o", std::to_string(prob.O)}});
        auto srcs       = get_headers_for_test();
        srcs.push_back({"main.cpp", src});
        rtc::compile_options options;
        options.kernel_name = "f";
        auto k              = rtc::compile_kernel(srcs, options);
        auto block_size     = solution.GetTemplateParameter<std::size_t>("BlockSize");
        auto m_per_block    = solution.GetTemplateParameter<std::size_t>("Gemm01MPerBlock");
        auto n_per_block    = solution.GetTemplateParameter<std::size_t>("Gemm1NPerBlock");
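        // Launch one workgroup per Gemm01MPerBlock x Gemm1NPerBlock output tile.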
        auto grid_size      = ck::host::integer_divide_ceil(prob.M, m_per_block) *
                              ck::host::integer_divide_ceil(prob.N, n_per_block);
        k.launch(nullptr, grid_size * block_size, block_size)(
            a.data(), b.data(), b1.data(), c.data());

        if(solution.GetTemplateParameter<bool>("MaskOutUpperTriangle"))
            CHECK(report(solution, check1(rtc::from_gpu(c))));
        else
            CHECK(report(solution, check2(rtc::from_gpu(c))));
    }
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }