// gemm_gemm.cpp
#include <array>
#include <cstring>
#include <iostream>
#include <stdexcept>
#include <string>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "reference_gemm.hpp"
#include "gemm_gemm.hpp"

int main(int argc, char* argv[])
{
    using A0DataType   = ck::half_t;
    using B0DataType   = ck::half_t;
    using Acc0DataType = float;
    using C0DataType   = ck::half_t;
    using B1DataType   = ck::half_t;
    using Acc1DataType = float;
    using C1DataType   = ck::half_t;

    ck::index_t M0 = 13312;
    ck::index_t N0 = 4096;
    ck::index_t K0 = 128;
    ck::index_t N1 = 128;

    if(argc == 5)
    {
        M0 = std::stoi(argv[1]);
        N0 = std::stoi(argv[2]);
        K0 = std::stoi(argv[3]);
        N1 = std::stoi(argv[4]);
    }

    std::array<ck::index_t, 2> a0_lengths{M0, K0};
    std::array<ck::index_t, 2> a0_strides{K0, 1};

    std::array<ck::index_t, 2> b0_lengths{N0, K0};
    std::array<ck::index_t, 2> b0_strides{K0, 1};

    std::array<ck::index_t, 2> c0_lengths{M0, N0};
    std::array<ck::index_t, 2> c0_strides{N0, 1};

    std::array<ck::index_t, 2> b1_lengths{N1, N0};
    std::array<ck::index_t, 2> b1_strides{N0, 1};

    std::array<ck::index_t, 2> c1_lengths{M0, N1};
    std::array<ck::index_t, 2> c1_strides{N1, 1};

    // host verify
    Tensor<A0DataType> a0_host(a0_lengths, a0_strides);
    Tensor<B0DataType> b0_host(b0_lengths, b0_strides);
    Tensor<B1DataType> b1_host(b1_lengths, b1_strides);
    Tensor<C0DataType> c0_host_ref(c0_lengths, c0_strides);
    Tensor<C1DataType> c1_host_ref(c1_lengths, c1_strides);
    Tensor<C1DataType> c1_host_dev(c1_lengths, c1_strides);

    ck::utils::FillUniformDistributionIntegerValue<A0DataType>{-3.f, 3.f}(a0_host);
    ck::utils::FillUniformDistributionIntegerValue<B0DataType>{-3.f, 3.f}(b0_host);
    ck::utils::FillUniformDistributionIntegerValue<B1DataType>{-3.f, 3.f}(b1_host);

    // reference gemm
    reference_gemm<A0DataType, B0DataType, C0DataType, float>(a0_host, b0_host, c0_host_ref);
    reference_gemm<C0DataType, B1DataType, C1DataType, float>(c0_host_ref, b1_host, c1_host_ref);

    DeviceMem a0_buf(sizeof(A0DataType) * a0_host.GetElementSpaceSize());
    DeviceMem b0_buf(sizeof(B0DataType) * b0_host.GetElementSpaceSize());
    DeviceMem b1_buf(sizeof(B1DataType) * b1_host.GetElementSpaceSize());
    DeviceMem c1_buf(sizeof(C1DataType) * c1_host_ref.GetElementSpaceSize());

    a0_buf.ToDevice(a0_host.mData.data());
    b0_buf.ToDevice(b0_host.mData.data());
    b1_buf.ToDevice(b1_host.mData.data());

    constexpr ck::index_t kM0PerBlock = 128;
    constexpr ck::index_t kN0PerBlock = 128;
    constexpr ck::index_t kK0PerBlock = 32;
    constexpr ck::index_t kN1PerBlock = 128;

    constexpr ck::index_t kBlockSize = 256;
    ck::index_t kGridSize            = (M0 / kM0PerBlock) * (N1 / kN1PerBlock);

    std::cout << "grid size " << kGridSize << std::endl;

Chao Liu's avatar
Chao Liu committed
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
    float ave_time =
        launch_kernel<kBlockSize, 2>(StreamConfig{nullptr, true},
                                     GemmGemm<A0DataType,
                                              B0DataType,
                                              Acc0DataType,
                                              C0DataType,
                                              B1DataType,
                                              Acc1DataType,
                                              C1DataType,
                                              kBlockSize,
                                              kM0PerBlock,
                                              kN0PerBlock,
                                              kK0PerBlock,
                                              kN1PerBlock>{},
                                     kGridSize,
Chao Liu's avatar
Chao Liu committed
107
                                     kBlockSize,
Chao Liu's avatar
Chao Liu committed
108
109
110
111
112
113
114
115
116
117
118
119
120
                                     0,
                                     static_cast<A0DataType*>(a0_buf.GetDeviceBuffer()),
                                     static_cast<B0DataType*>(b0_buf.GetDeviceBuffer()),
                                     static_cast<B1DataType*>(b1_buf.GetDeviceBuffer()),
                                     static_cast<C1DataType*>(c1_buf.GetDeviceBuffer()),
                                     M0,
                                     N0,
                                     K0,
                                     N1,
                                     K0,  // Lda0
                                     K0,  // Ldb0
                                     N0,  // Ldb1
                                     N1); // Ldc1
Chao Liu's avatar
Chao Liu committed
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136

    c1_buf.FromDevice(c1_host_dev.mData.data());

    std::size_t flop      = std::size_t(2) * M0 * N0 * K0 + std::size_t(2) * M0 * N1 * N0;
    std::size_t num_btype = sizeof(A0DataType) * M0 * K0 + sizeof(B0DataType) * N0 * K0 +
                            sizeof(B1DataType) * N1 * N0 + sizeof(C1DataType) * M0 * N1;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    return !ck::utils::check_err(c1_host_dev, c1_host_ref);
}