gemm_w8a8.cu 6.8 KB
Newer Older
muyangli's avatar
muyangli committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
#include "zgemm.h"
#include "gemm_w8a8.cuh"

namespace nunchaku::kernels {

// Quantizes a half-precision activation matrix to INT8 for the W8A8 GEMM path by launching
// GEMM_W8A8::quantize_w8a8_act_kernel<fuse_glu>.
//
//   input    : GEMM::half_t tensor; the last dim is K and all leading dims are flattened into M rows.
//   output   : INT8 tensor with M rows; its last dim must be K / 2 when fuse_glu is set, K otherwise.
//   oscales  : GEMM::half_t tensor with exactly M elements, filled by the kernel.
//   fuse_glu : selects the FUSE_GLU=true kernel instantiation (presumably a fused GLU that halves
//              the output width — TODO confirm against the kernel).
//
// The launch passes no explicit stream, so it goes to the default stream. Launch errors are
// surfaced through cudaGetLastError().
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu) {
    using GEMM = GEMM_W8A8;

    const int M = input.numel() / input.shape[-1];
    const int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    // The parentheses are required. Without them `==` binds tighter than `?:`, and the assertion
    // became `(output.shape[-1] == fuse_glu) ? K / 2 : K`, which never checks the output width.
    assert(output.shape[-1] == (fuse_glu ? K / 2 : K));

    assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
    assert(oscales.numel() == M * 1);

    auto launch = [&]<bool FUSE_GLU>() {
        using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;

        assert(kernel::check(M, K));
        dim3 grid  = kernel::gridSize(M, K);
        dim3 block = kernel::blockSize(M, K);

        auto func =
            invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;

        // Raise the dynamic shared-memory cap to 92160 bytes (90 KiB). The default is 48 KB, which
        // kernel::smemSize(M, K) may exceed.
        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, 92160));

        // The trailing bool is always passed as `false`. Its meaning is defined by the kernel's
        // signature (not visible here) — NOTE(review): confirm.
        func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
                                                      output.data_ptr<GEMM::packed_act_t>(),
                                                      oscales.data_ptr<GEMM::packed_ascale_t>(),
                                                      K,
                                                      false);
        checkCUDA(cudaGetLastError());
    };

    // Turn the runtime flag into the compile-time FUSE_GLU template argument.
    if (fuse_glu) {
        launch.template operator()<true>();
    } else {
        launch.template operator()<false>();
    }
}

Muyang Li's avatar
Muyang Li committed
46
47
48
49
50
51
// Launches GEMM_W8A8::gemm_w8a8_kernel on INT8 activations and weights. The output is written by
// GEMM::EpilogueDefault, with an optional bias epilogue in front of it.
//
//   act     : packed activations; last dim is K and leading dims are flattened into M rows.
//   wgt     : packed weights, shape [N, K].
//   out     : output of GEMM::half_t. If valid, its row count (actualM) and last dim (actualN) may
//             be smaller than M / N, but by less than one tile (BLOCK_M / BLOCK_N).
//   ascales : activation scales, passed through to the kernel (shape per the comment above).
//   wscales : weight scales, passed through to the kernel (shape per the comment above).
//   bias    : optional. If valid, it must have N elements and is applied via EpilogueBias.
//
// The launch passes no explicit stream and no dynamic shared memory. Launch errors are surfaced
// through cudaGetLastError().
void gemm_w8a8(Tensor act,     // [M, K]
               Tensor wgt,     // [N, K]
               Tensor out,     // [M, N]
               Tensor ascales, // [1, M]
               Tensor wscales, // [1, N]
               Tensor bias) {
    using GEMM = GEMM_W8A8;

    const int M = act.numel() / act.shape[-1];
    const int N = wgt.shape[0];
    const int K = act.shape[-1];
    assert(K == wgt.shape[1]);

    // The launch grid is (M / BLOCK_M, N / BLOCK_N) with truncating division. Any remainder rows
    // or columns would never be computed, so the padded problem size must be a whole number of
    // tiles. The actualM/actualN checks below likewise assume padding to the tile size.
    assert(M % GEMM::BLOCK_M == 0);
    assert(N % GEMM::BLOCK_N == 0);

    int actualM = 0;
    int actualN = 0;
    if (out.valid()) {
        actualM = out.numel() / out.shape[-1];
        actualN = out.shape[-1];

        assert(actualM <= M && M - actualM < GEMM::BLOCK_M);
        assert(actualN <= N && N - actualN < GEMM::BLOCK_N);
    }

    // Launch the GEMM kernel with the given epilogue type and its arguments.
    auto launch = [&]<typename Epilogue>(Epilogue::Arguments args) {
        dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        // When M is more than twice N, swap grid.x / grid.y. The kernel is told via swapBlockMN.
        bool swapBlockMN = M > N * 2;
        if (swapBlockMN) {
            std::swap(grid.x, grid.y);
        }

        // The trailing bool is always passed as `false`. Its meaning is defined by the kernel's
        // signature (not visible here) — NOTE(review): confirm.
        invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>
            <<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(act.data_ptr<GEMM::packed_act_t>(),
                                                          wgt.data_ptr<GEMM::packed_wgt_t>(),
                                                          ascales.data_ptr<GEMM::packed_ascale_t>(),
                                                          wscales.data_ptr<GEMM::packed_wscale_t>(),
                                                          M,
                                                          N,
                                                          K,
                                                          args,
                                                          swapBlockMN,
                                                          false);
        checkCUDA(cudaGetLastError());
    };

    // If a bias is given, prepend EpilogueBias to NextEpilogue. Otherwise launch NextEpilogue alone.
    auto launch_bias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        if (!bias.valid()) {
            return launch.template operator()<NextEpilogue>(nextArgs);
        }

        assert(bias.numel() == N);

        // Append EpilogueNop to work around a mismatched memory layout of std::tuple between device
        // and host code on Windows.
        // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
        using Epilogue = GEMM::EpilogueCombination<GEMM::EpilogueBias<true, false>, NextEpilogue, GEMM::EpilogueNop>;
        return launch.template operator()<Epilogue>({GEMM::EpilogueBias<true, false>::Arguments{
                                                         .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
                                                     },
                                                     nextArgs,
                                                     {}});
    };

    launch_bias.template operator()<GEMM::EpilogueDefault>(GEMM::EpilogueDefault::Arguments{
        .out     = out.data_ptr<GEMM::half_t>(),
        .actualM = actualM,
        .actualN = actualN,
    });
}

#if 0
// W8A8 GEMM with the LiteLA epilogue, followed by Epilogue::vk_mul_q_kernel. This is currently
// compiled out.
//
// The GEMM writes out_q (GEMM::half_t) and accumulates into out_vk (float), which is zeroed
// first. vk_mul_q_kernel is then launched over (batch_m / 128, num_heads, batch_size) blocks of
// 128 threads.
// NOTE(review): the grid truncates batch_m / 128. Confirm that batch_m is a multiple of 128.
void gemm_w8a8_fuse_litela(
    Tensor act,      // [B, (M), K]
    Tensor wgt,      // [N, K]
    Tensor out_q,    // [B, (M), N / 3]
    Tensor out_vk,   // [B, num_heads, head_dim + 1, head_dim]
    Tensor ascales,  // [1, M]
    Tensor wscales   // [1, N]
) {
    using GEMM = GEMM_W8A8;
    using Epilogue = GEMM::EpilogueLiteLA;

    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1];
    assert(K == wgt.shape[1]);

    assert(out_vk.ndims() == 4);
    assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
    assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
    assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);

    int batch_size = out_vk.shape[0];
    int num_heads = out_vk.shape[1];

    assert(M % batch_size == 0);
    int batch_m = M / batch_size;

    Epilogue::Arguments epilogueArgs;
    // Use the per-batch row count derived from the flattened M. act.shape[1] is only correct for
    // a 3-D act; for an already-flattened [B*M, K] act it would read K.
    epilogueArgs.batch_m = batch_m;
    epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
    epilogueArgs.out_vk = out_vk.data_ptr<float>();

    // Zero exactly this tensor's float elements. Sizing by out_vk.buffer->getSize() can write past
    // the tensor when out_vk is a view at a non-zero offset into its buffer (assuming data_ptr()
    // includes that offset — TODO confirm).
    checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.numel() * sizeof(float)));

    auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
        const GEMM::packed_act_t *,
        const GEMM::packed_wgt_t *,
        const GEMM::packed_ascale_t *,
        const GEMM::packed_wscale_t *,
        int, int, int,
        Epilogue::Arguments,
        bool,
        bool>;

    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    // When M is more than twice N, swap grid.x / grid.y. The kernel is told via swapBlockMN.
    bool swapBlockMN = M > N * 2;
    if (swapBlockMN) {
        std::swap(grid.x, grid.y);
    }

    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
        act.data_ptr<GEMM::packed_act_t>(),
        wgt.data_ptr<GEMM::packed_wgt_t>(),
        ascales.data_ptr<GEMM::packed_ascale_t>(),
        wscales.data_ptr<GEMM::packed_wscale_t>(),
        M, N, K, epilogueArgs,
        swapBlockMN,
        false
    );
    checkCUDA(cudaGetLastError());

    invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
        out_q.data_ptr<GEMM::half_t>(),
        out_vk.data_ptr<float>(),
        1e-6f
    );
    checkCUDA(cudaGetLastError());
}
#endif

Muyang Li's avatar
Muyang Li committed
193
}; // namespace nunchaku::kernels