#include "hip/hip_runtime.h"
#include "zgemm.h"
#include "gemm_w8a8.cuh"

namespace nunchaku::kernels {

// Quantize a half-precision activation tensor of logical shape [M, K] into
// packed int8 activations plus one scale per row (written to `oscales`).
// When `fuse_glu` is true the kernel fuses a GLU gate into quantization and
// the output's last dimension is K/2 instead of K.
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu) {
    using GEMM = GEMM_W8A8;

    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    // BUG FIX: `==` binds tighter than `?:`, so the original
    // `assert(output.shape[-1] == fuse_glu ? K / 2 : K)` parsed as
    // `(output.shape[-1] == fuse_glu) ? K / 2 : K`, which always yields a
    // nonzero value and never actually checked the output width.
    assert(output.shape[-1] == (fuse_glu ? K / 2 : K));

    assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
    assert(oscales.numel() == M * 1);  // one scale per output row

    // Generic launcher; FUSE_GLU selects the kernel specialization at
    // compile time so the hot loop carries no runtime branch.
    auto launch = [&]<bool FUSE_GLU>() {
        using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;

        assert(kernel::check(M, K));
        dim3 grid  = kernel::gridSize(M, K);
        dim3 block = kernel::blockSize(M, K);

        auto func =
            invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;

        // Opt in to >48 KB of dynamic shared memory (here up to 90 KB);
        // required before launching with a large dynamic smem size.
        checkCUDA(hipFuncSetAttribute(func, hipFuncAttributeMaxDynamicSharedMemorySize, 92160));

        hipLaunchKernelGGL(func,
                           grid,
                           block,
                           kernel::smemSize(M, K),
                           0,
                           input.data_ptr<GEMM::half_t>(),
                           output.data_ptr<GEMM::packed_act_t>(),
                           oscales.data_ptr<GEMM::packed_ascale_t>(),
                           K,
                           false);
        checkCUDA(hipGetLastError());
    };

    if (fuse_glu) {
        launch.template operator()<true>();
    } else {
        launch.template operator()<false>();
    }
}

// W8A8 GEMM: out[M, N] = act[M, K] x wgt[N, K]^T, dequantized with per-row
// activation scales and per-column weight scales, with an optional bias.
void gemm_w8a8(Tensor act,     // [M, K]
               Tensor wgt,     // [N, K]
               Tensor out,     // [M, N]
               Tensor ascales, // [1, M]
               Tensor wscales, // [1, N]
               Tensor bias) {
    using GEMM = GEMM_W8A8;

    const int M = act.numel() / act.shape[-1];
    const int N = wgt.shape[0];
    const int K = act.shape[-1];
    assert(K == wgt.shape[1]);

    // A provided output may be smaller than the padded (M, N) problem, but
    // only by less than one tile in each dimension; the epilogue clips to it.
    int actualM = 0;
    int actualN = 0;
    if (out.valid()) {
        actualM = out.numel() / out.shape[-1];
        actualN = out.shape[-1];

        assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
        assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
    }

    // Launch the GEMM kernel specialized for a given epilogue type.
    auto runKernel = [&]<typename Epilogue>(Epilogue::Arguments epilogueArgs) {
        dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        // Tall problems rasterize better with the block axes exchanged;
        // the kernel is told via the flag so it can undo the mapping.
        const bool swapBlockMN = M > N * 2;
        if (swapBlockMN) {
            std::swap(grid.x, grid.y);
        }

        hipLaunchKernelGGL((invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>),
                           grid,
                           dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS),
                           0,
                           0,
                           act.data_ptr<GEMM::packed_act_t>(),
                           wgt.data_ptr<GEMM::packed_wgt_t>(),
                           ascales.data_ptr<GEMM::packed_ascale_t>(),
                           wscales.data_ptr<GEMM::packed_wscale_t>(),
                           M,
                           N,
                           K,
                           epilogueArgs,
                           swapBlockMN,
                           false);
        checkCUDA(hipGetLastError());
    };

    // Prepend a bias epilogue when a bias tensor is supplied; otherwise
    // launch with the next epilogue unchanged.
    auto withBias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        if (!bias.valid()) {
            return runKernel.template operator()<NextEpilogue>(nextArgs);
        }

        assert(bias.numel() == N);

        // Append EpilogueNop to work around a mismatched std::tuple memory
        // layout between device and host code on Windows
        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
        return runKernel.template operator()<Epilogue>({GEMM::EpilogueBias<true, false>::Arguments{
                                                            .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
                                                        },
                                                        nextArgs,
                                                        {}});
    };

    withBias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{
        .out     = out.data_ptr<GEMM::half_t>(),
        .actualM = actualM,
        .actualN = actualN,
    });
}

#if 0
// Disabled: W8A8 GEMM with a fused LiteLA (linear attention) epilogue that
// produces Q directly and accumulates the VK outer-product per head, then a
// second small kernel computes q * vk with an epsilon-stabilized divide.
void gemm_w8a8_fuse_litela(
    Tensor act,      // [B, (M), K]
    Tensor wgt,      // [N, K]
    Tensor out_q,    // [B, (M), N / 3]
    Tensor out_vk,   // [B, num_heads, head_dim + 1, head_dim]
    Tensor ascales,  // [1, M]
    Tensor wscales   // [1, N]
) {
    using GEMM     = GEMM_W8A8;
    using Epilogue = GEMM::EpilogueLiteLA;

    const int M = act.numel() / act.shape[-1];
    const int N = wgt.shape[0];
    const int K = act.shape[-1];
    assert(K == wgt.shape[1]);

    // out_vk layout checks: per head, (head_dim + 1) rows of head_dim floats;
    // the weight packs Q, K and V projections, hence the factor of 3.
    assert(out_vk.ndims() == 4);
    assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
    assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
    assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);

    const int batch_size = out_vk.shape[0];
    const int num_heads  = out_vk.shape[1];

    assert(M % batch_size == 0);
    const int batch_m = M / batch_size;

    Epilogue::Arguments epilogueArgs;
    epilogueArgs.batch_m = act.shape[1];
    epilogueArgs.out_q   = out_q.data_ptr<GEMM::half_t>();
    epilogueArgs.out_vk  = out_vk.data_ptr<float>();

    // The epilogue accumulates into out_vk with atomics, so it must start
    // zeroed.
    checkCUDA(hipMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));

    auto gemmFunc = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
        const GEMM::packed_act_t *,
        const GEMM::packed_wgt_t *,
        const GEMM::packed_ascale_t *,
        const GEMM::packed_wscale_t *,
        // GEMM::half_t *,
        int, int, int,
        Epilogue::Arguments,
        bool,
        bool>;

    // Opt in to the epilogue's dynamic shared-memory requirement (>48 KB).
    checkCUDA(hipFuncSetAttribute(gemmFunc, hipFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    // Tall problems rasterize better with the block axes exchanged.
    const bool swapBlockMN = M > N * 2;
    if (swapBlockMN) {
        std::swap(grid.x, grid.y);
    }

    hipLaunchKernelGGL(gemmFunc,
                       grid,
                       dim3(GEMM::WARP_SIZE * GEMM::NUM_WARPS),
                       Epilogue::SHMEM_SIZE,
                       0,
                       act.data_ptr<GEMM::packed_act_t>(),
                       wgt.data_ptr<GEMM::packed_wgt_t>(),
                       ascales.data_ptr<GEMM::packed_ascale_t>(),
                       wscales.data_ptr<GEMM::packed_wscale_t>(),
                       // nullptr,
                       M, N, K,
                       epilogueArgs,
                       swapBlockMN,
                       false);
    checkCUDA(hipGetLastError());

    // Second pass: q <- q * vk, stabilized with a 1e-6 epsilon denominator.
    hipLaunchKernelGGL((invoke_kernel<Epilogue::vk_mul_q_kernel>),
                       dim3(batch_m / 128, num_heads, batch_size),
                       dim3(128),
                       0,
                       0,
                       out_q.data_ptr<GEMM::half_t>(),
                       out_vk.data_ptr<float>(),
                       1e-6f);
    checkCUDA(hipGetLastError());
}
#endif

}; // namespace nunchaku::kernels