"...composable_kernel_rocm.git" did not exist on "37cdbf4f0ec88ba5064f46c3370633b5950bc7ae"
gemm.cpp 6.53 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
#include <array>
#include <cstddef>
#include <cstring>
#include <exception>
#include <iostream>
#include <string>
#include <type_traits>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/host_utility/device_prop.hpp"
Chao Liu's avatar
Chao Liu committed
8
#include "ck/host_utility/kernel_launch.hpp"
Chao Liu's avatar
Chao Liu committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "reference_gemm.hpp"
#include "gemm.hpp"

// elementwise lambda
// Identity element-wise functor: hands back a copy of whatever value it is given.
// main() passes it to the Gemm kernel in the AElementFunction slot.
struct AElementFunction
{
    // Callable from both host and device code; the result type is deduced from `value`.
    template <typename T>
    __host__ __device__ auto operator()(const T& value) const
    {
        return value;
    }
};

// Identity element-wise functor: hands back a copy of whatever value it is given.
// main() passes it to the Gemm kernel in the BElementFunction slot.
struct BElementFunction
{
    // Callable from both host and device code; the result type is deduced from `value`.
    template <typename T>
    __host__ __device__ auto operator()(const T& value) const
    {
        return value;
    }
};

// Identity element-wise functor: hands back a copy of whatever value it is given.
// main() passes it to the Gemm kernel in the CElementFunction slot.
struct CElementFunction
{
    // Callable from both host and device code; the result type is deduced from `value`.
    template <typename T>
    __host__ __device__ auto operator()(const T& value) const
    {
        return value;
    }
};

int main(int argc, char* argv[])
{
    using ADataType   = ck::half_t;
    using BDataType   = ck::half_t;
    using AccDataType = float;
    using CDataType   = ck::half_t;

Chao Liu's avatar
Chao Liu committed
54
55
56
57
    using ALayout = ck::tensor_layout::gemm::RowMajor;
    using BLayout = ck::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck::tensor_layout::gemm::RowMajor;

Chao Liu's avatar
Chao Liu committed
58
59
60
61
62
63
64
65
66
67
68
    ck::index_t M = 3328;
    ck::index_t N = 4096;
    ck::index_t K = 4096;

    if(argc == 4)
    {
        M = std::stoi(argv[1]);
        N = std::stoi(argv[2]);
        K = std::stoi(argv[3]);
    }

Chao Liu's avatar
Chao Liu committed
69
70
71
72
73
74
75
76
    const ck::index_t Lda = std::is_same_v<ALayout, ck::tensor_layout::gemm::RowMajor> ? K : M;
    const ck::index_t Ldb = std::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor> ? K : N;
    const ck::index_t Ldc = std::is_same_v<CLayout, ck::tensor_layout::gemm::RowMajor> ? N : M;

    const auto a_lengths = std::array<ck::index_t, 2>{M, K};
    const auto a_strides = std::is_same_v<ALayout, ck::tensor_layout::gemm::RowMajor>
                               ? std::array<ck::index_t, 2>{Lda, 1}
                               : std::array<ck::index_t, 2>{1, Lda};
Chao Liu's avatar
Chao Liu committed
77

Chao Liu's avatar
Chao Liu committed
78
79
80
81
    const auto b_lengths = std::array<ck::index_t, 2>{N, K};
    const auto b_strides = std::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
                               ? std::array<ck::index_t, 2>{Ldb, 1}
                               : std::array<ck::index_t, 2>{1, Ldb};
Chao Liu's avatar
Chao Liu committed
82

Chao Liu's avatar
Chao Liu committed
83
84
85
86
    const auto c_lengths = std::array<ck::index_t, 2>{M, N};
    const auto c_strides = std::is_same_v<CLayout, ck::tensor_layout::gemm::RowMajor>
                               ? std::array<ck::index_t, 2>{Ldc, 1}
                               : std::array<ck::index_t, 2>{1, Ldc};
Chao Liu's avatar
Chao Liu committed
87
88
89
90
91
92
93
94
95
96
97

    // host verify
    Tensor<ADataType> a_host(a_lengths, a_strides);
    Tensor<BDataType> b_host(b_lengths, b_strides);
    Tensor<CDataType> c_host_ref(c_lengths, c_strides);
    Tensor<CDataType> c_host_dev(c_lengths, c_strides);

    ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
    ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);

    // reference gemm
Chao Liu's avatar
Chao Liu committed
98
    reference_gemm<ADataType, ADataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);
Chao Liu's avatar
Chao Liu committed
99
100
101
102
103
104
105
106

    DeviceMem a_buf(sizeof(ADataType) * a_host.GetElementSpaceSize());
    DeviceMem b_buf(sizeof(BDataType) * b_host.GetElementSpaceSize());
    DeviceMem c_buf(sizeof(CDataType) * c_host_dev.GetElementSpaceSize());

    a_buf.ToDevice(a_host.mData.data());
    b_buf.ToDevice(b_host.mData.data());

Chao Liu's avatar
Chao Liu committed
107
108
109
110
111
112
113
    // Alignment
    constexpr ck::index_t kAAlignment = 32;
    constexpr ck::index_t kBAlignment = 32;
    constexpr ck::index_t kCAlignment = 32;

    constexpr ck::index_t kBlockSize = 256;

Chao Liu's avatar
Chao Liu committed
114
115
116
117
    constexpr ck::index_t kGemmMPerBlock = 256;
    constexpr ck::index_t kGemmNPerBlock = 128;
    constexpr ck::index_t kGemmKPerBlock = 32;

Chao Liu's avatar
Chao Liu committed
118
    ck::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);
Chao Liu's avatar
Chao Liu committed
119
120
121

    std::cout << "grid size " << kGridSize << std::endl;

Chao Liu's avatar
Chao Liu committed
122
123
124
125
    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
    constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;

Chao Liu's avatar
Chao Liu committed
126
127
128
129
    const auto gemm_kernel = Gemm<ADataType,
                                  BDataType,
                                  AccDataType,
                                  CDataType,
Chao Liu's avatar
Chao Liu committed
130
131
132
                                  ALayout,
                                  BLayout,
                                  CLayout,
Chao Liu's avatar
Chao Liu committed
133
134
135
                                  AElementFunction,
                                  BElementFunction,
                                  CElementFunction,
Chao Liu's avatar
Chao Liu committed
136
137
138
                                  kAAlignment,
                                  kBAlignment,
                                  kCAlignment,
Chao Liu's avatar
Chao Liu committed
139
140
141
142
143
                                  kBlockSize,
                                  kGemmMPerBlock,
                                  kGemmNPerBlock,
                                  kGemmKPerBlock>{};

Chao Liu's avatar
Chao Liu committed
144
145
146
147
148
149
150
151
152
153
154
155
    float ave_time =
        launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
                                               gemm_kernel,
                                               kGridSize,
                                               kBlockSize,
                                               0,
                                               static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
                                               static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
                                               static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
                                               M,
                                               N,
                                               K,
Chao Liu's avatar
Chao Liu committed
156
157
158
                                               Lda,
                                               Ldb,
                                               Ldc,
Chao Liu's avatar
Chao Liu committed
159
160
161
                                               AElementFunction{},
                                               BElementFunction{},
                                               CElementFunction{});
Chao Liu's avatar
Chao Liu committed
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177

    c_buf.FromDevice(c_host_dev.mData.data());

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    return !ck::utils::check_err(c_host_dev, c_host_ref);
}