"...composable_kernel_rocm.git" did not exist on "3dc5db7270f3e2129d05fac756a45add2b8165f9"
gemm.cpp 5.33 KB
Newer Older
Chao Liu's avatar
Chao Liu committed
1
2
3
4
5
6
7
#include <cstring>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor/tensor_view.hpp"
#include "ck/host_utility/device_prop.hpp"
Chao Liu's avatar
Chao Liu committed
8
#include "ck/host_utility/kernel_launch.hpp"
Chao Liu's avatar
Chao Liu committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"

#include "reference_gemm.hpp"
#include "gemm.hpp"

// Element-wise functors applied to A, B and C (all three are identity here)
// Element-wise operation applied to elements of A.
// Identity: hands back a copy of its argument, unchanged and of the same type.
struct AElementFunction
{
    template <typename T>
    __host__ __device__ T operator()(const T& value) const
    {
        return value;
    }
};

// Element-wise operation applied to elements of B.
// Identity: hands back a copy of its argument, unchanged and of the same type.
struct BElementFunction
{
    template <typename T>
    __host__ __device__ T operator()(const T& value) const
    {
        return value;
    }
};

// Element-wise operation applied to the accumulated result before it is stored to C.
// Identity: hands back a copy of its argument, unchanged and of the same type.
struct CElementFunction
{
    template <typename T>
    __host__ __device__ T operator()(const T& value) const
    {
        return value;
    }
};

int main(int argc, char* argv[])
{
    using ADataType   = ck::half_t;
    using BDataType   = ck::half_t;
    using AccDataType = float;
    using CDataType   = ck::half_t;

    ck::index_t M = 3328;
    ck::index_t N = 4096;
    ck::index_t K = 4096;

    if(argc == 4)
    {
        M = std::stoi(argv[1]);
        N = std::stoi(argv[2]);
        K = std::stoi(argv[3]);
    }

    std::array<ck::index_t, 2> a_lengths{M, K};
    std::array<ck::index_t, 2> a_strides{K, 1};

    std::array<ck::index_t, 2> b_lengths{N, K};
    std::array<ck::index_t, 2> b_strides{K, 1};

    std::array<ck::index_t, 2> c_lengths{M, N};
    std::array<ck::index_t, 2> c_strides{N, 1};

    // host verify
    Tensor<ADataType> a_host(a_lengths, a_strides);
    Tensor<BDataType> b_host(b_lengths, b_strides);
    Tensor<CDataType> c_host_ref(c_lengths, c_strides);
    Tensor<CDataType> c_host_dev(c_lengths, c_strides);

    ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_host);
    ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_host);

    // reference gemm
Chao Liu's avatar
Chao Liu committed
84
    reference_gemm<ADataType, ADataType, AccDataType, CDataType>(a_host, b_host, c_host_ref);
Chao Liu's avatar
Chao Liu committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101

    DeviceMem a_buf(sizeof(ADataType) * a_host.GetElementSpaceSize());
    DeviceMem b_buf(sizeof(BDataType) * b_host.GetElementSpaceSize());
    DeviceMem c_buf(sizeof(CDataType) * c_host_dev.GetElementSpaceSize());

    a_buf.ToDevice(a_host.mData.data());
    b_buf.ToDevice(b_host.mData.data());

    constexpr ck::index_t kGemmMPerBlock = 256;
    constexpr ck::index_t kGemmNPerBlock = 128;
    constexpr ck::index_t kGemmKPerBlock = 32;

    constexpr ck::index_t kBlockSize = 256;
    ck::index_t kGridSize            = (M / kGemmMPerBlock) * (N / kGemmNPerBlock);

    std::cout << "grid size " << kGridSize << std::endl;

Chao Liu's avatar
Chao Liu committed
102
103
104
105
    constexpr ck::index_t kWarpPerCu    = 8; // 2 warps per SIMD
    constexpr ck::index_t kWarpPerBlock = kBlockSize / warpSize;
    constexpr ck::index_t kBlockPerCu   = kWarpPerCu / kWarpPerBlock;

Chao Liu's avatar
Chao Liu committed
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
    const auto gemm_kernel = Gemm<ADataType,
                                  BDataType,
                                  AccDataType,
                                  CDataType,
                                  ck::tensor_layout::gemm::RowMajor,
                                  ck::tensor_layout::gemm::ColumnMajor,
                                  ck::tensor_layout::gemm::RowMajor,
                                  AElementFunction,
                                  BElementFunction,
                                  CElementFunction,
                                  kBlockSize,
                                  kGemmMPerBlock,
                                  kGemmNPerBlock,
                                  kGemmKPerBlock>{};

Chao Liu's avatar
Chao Liu committed
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
    float ave_time =
        launch_kernel<kBlockSize, kBlockPerCu>(StreamConfig{nullptr, true},
                                               gemm_kernel,
                                               kGridSize,
                                               kBlockSize,
                                               0,
                                               static_cast<ADataType*>(a_buf.GetDeviceBuffer()),
                                               static_cast<BDataType*>(b_buf.GetDeviceBuffer()),
                                               static_cast<CDataType*>(c_buf.GetDeviceBuffer()),
                                               M,
                                               N,
                                               K,
                                               K,
                                               K,
                                               N,
                                               AElementFunction{},
                                               BElementFunction{},
                                               CElementFunction{});
Chao Liu's avatar
Chao Liu committed
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154

    c_buf.FromDevice(c_host_dev.mData.data());

    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_btype =
        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
              << std::endl;

    return !ck::utils::check_err(c_host_dev, c_host_ref);
}