// gemm_w8a8.cu
#include "zgemm.h"
#include "gemm_w8a8.cuh"

namespace nunchaku::kernels {

/**
 * Launches the W8A8 activation-quantization kernel.
 *
 * The input is viewed as an [M, K] matrix, where K is the last dimension of
 * `input` and M is the product of all leading dimensions.
 *
 * @param input    Source activations, read as GEMM::half_t.
 * @param output   INT8 destination with M rows; its last dimension must be K,
 *                 or K / 2 when `fuse_glu` is true.
 * @param oscales  Output scales; dtype must match GEMM::half_t and it must hold
 *                 exactly M elements.
 * @param fuse_glu Selects the FUSE_GLU=true instantiation of the kernel.
 */
void quantize_w8a8_act(Tensor input, Tensor output, Tensor oscales, bool fuse_glu) {
    using GEMM = GEMM_W8A8;

    int M = input.numel() / input.shape[-1];
    int K = input.shape[-1];

    assert(output.dtype() == Tensor::INT8);
    assert(output.numel() / output.shape[-1] == M);
    // The conditional must be parenthesized: `a == b ? x : y` parses as
    // `(a == b) ? x : y`, which would make this check pass for any K >= 2.
    assert(output.shape[-1] == (fuse_glu ? K / 2 : K));

    assert(isTypeMatch<GEMM::half_t>(oscales.dtype()));
    assert(oscales.numel() == M * 1);

    auto launch = [&]<bool FUSE_GLU>() {
        using kernel = GEMM::quantize_w8a8_act_kernel<FUSE_GLU>;

        assert(kernel::check(M, K));

        dim3 grid  = kernel::gridSize(M, K);
        dim3 block = kernel::blockSize(M, K);

        auto func =
            invoke_kernel<kernel, const GEMM::half_t *, GEMM::packed_act_t *, GEMM::packed_ascale_t *, int, bool>;

        // Opt-in ceiling for dynamic shared memory, in bytes.
        // NOTE(review): this limit is hard-coded while the launch below requests
        // kernel::smemSize(M, K) -- confirm smemSize can never exceed it.
        constexpr int kMaxDynamicSmemBytes = 92160;
        checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, kMaxDynamicSmemBytes));

        // The trailing `false` is forwarded unchanged; its meaning is defined by the kernel.
        func<<<grid, block, kernel::smemSize(M, K)>>>(input.data_ptr<GEMM::half_t>(),
                                                      output.data_ptr<GEMM::packed_act_t>(),
                                                      oscales.data_ptr<GEMM::packed_ascale_t>(),
                                                      K,
                                                      false);
        // Launch-configuration errors only surface through the last-error state.
        checkCUDA(cudaGetLastError());
    };

    if (fuse_glu) {
        launch.template operator()<true>();
    } else {
        launch.template operator()<false>();
    }
}

/**
 * Launches the W8A8 GEMM kernel, optionally adding a bias in the epilogue.
 *
 * M is the number of rows of `act` (all leading dimensions flattened), K its
 * last dimension, and N the first dimension of `wgt`. When `out` is valid its
 * own row/column counts are forwarded to the epilogue as actualM / actualN and
 * must be smaller than M / N by less than one block in each direction. When
 * `bias` is valid it must hold N elements and a bias epilogue is chained in
 * front of the default one.
 */
void gemm_w8a8(Tensor act,     // [M, K]
               Tensor wgt,     // [N, K]
               Tensor out,     // [M, N]
               Tensor ascales, // [1, M]
               Tensor wscales, // [1, N]
               Tensor bias) {
    using GEMM = GEMM_W8A8;

    int K = act.shape[-1];
    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    assert(K == wgt.shape[1]);

    // Extents of the caller-visible output; both stay 0 when no output tensor is given.
    int outRows = 0;
    int outCols = 0;
    if (out.valid()) {
        outCols = out.shape[-1];
        outRows = out.numel() / out.shape[-1];

        assert(outRows <= M && M - outRows < GEMM::BLOCK_M);
        assert(outCols <= N && N - outCols < GEMM::BLOCK_N);
    }

    // Launches the GEMM kernel for a given epilogue type and its arguments.
    auto runKernel = [&]<typename Epilogue>(Epilogue::Arguments args) {
        // When M is more than twice N the two grid axes are exchanged; the kernel is
        // told about it through the swapBlockMN flag.
        bool swapBlockMN = M > N * 2;
        dim3 grid = swapBlockMN ? dim3(N / GEMM::BLOCK_N, M / GEMM::BLOCK_M)
                                : dim3(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

        invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>>
            <<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS>>>(act.data_ptr<GEMM::packed_act_t>(),
                                                          wgt.data_ptr<GEMM::packed_wgt_t>(),
                                                          ascales.data_ptr<GEMM::packed_ascale_t>(),
                                                          wscales.data_ptr<GEMM::packed_wscale_t>(),
                                                          // out.valid() ? out.data_ptr<GEMM::half_t>() : nullptr,
                                                          M,
                                                          N,
                                                          K,
                                                          args,
                                                          swapBlockMN,
                                                          false);
        checkCUDA(cudaGetLastError());
    };

    // Prepends the bias epilogue when a bias tensor was supplied; otherwise forwards as-is.
    auto runWithOptionalBias = [&]<typename NextEpilogue>(NextEpilogue::Arguments nextArgs) {
        if (bias.valid()) {
            assert(bias.numel() == N);

            using BiasEpilogue = GEMM::EpilogueBias<true, false>;
            // Append EpilogueNop to work around the mismatched memory layout of std::tuple
            // between device and host code on Windows.
            // ** sizeof(std::tuple<std::tuple<int>>) == 8 on device **
            using Epilogue = GEMM::EpilogueCombination<BiasEpilogue, NextEpilogue, GEMM::EpilogueNop>;

            runKernel.template operator()<Epilogue>({BiasEpilogue::Arguments{
                                                         .bias = bias.data_ptr<GEMM::packed_wscale_t>(),
                                                     },
                                                     nextArgs,
                                                     {}});
        } else {
            runKernel.template operator()<NextEpilogue>(nextArgs);
        }
    };

    GEMM::EpilogueDefault::Arguments defaultArgs{
        .out     = out.data_ptr<GEMM::half_t>(),
        .actualM = outRows,
        .actualN = outCols,
    };
    runWithOptionalBias.template operator()<GEMM::EpilogueDefault>(defaultArgs);
}

#if 0
/**
 * W8A8 GEMM with the LiteLA epilogue, followed by the vk_mul_q kernel.
 *
 * This function is currently compiled out by the enclosing `#if 0`.
 *
 * The GEMM writes `out_q` and accumulates into `out_vk` (zeroed beforehand);
 * the second kernel then consumes both. The batch size and head count are read
 * from `out_vk`, and the number of rows per batch is derived from the
 * flattened row count of `act`.
 */
void gemm_w8a8_fuse_litela(
    Tensor act,      // [B, (M), K]
    Tensor wgt,      // [N, K]
    Tensor out_q,    // [B, (M), N / 3]
    Tensor out_vk,   // [B, num_heads, head_dim + 1, head_dim]
    Tensor ascales,  // [1, M]
    Tensor wscales   // [1, N]
) {
    using GEMM = GEMM_W8A8;
    using Epilogue = GEMM::EpilogueLiteLA;

    int M = act.numel() / act.shape[-1];
    int N = wgt.shape[0];
    int K = act.shape[-1];
    assert(K == wgt.shape[1]);

    assert(out_vk.ndims() == 4);
    assert(out_vk.shape[2] == Epilogue::LITELA_HEAD_DIM + 1);
    assert(out_vk.shape[3] == Epilogue::LITELA_HEAD_DIM);
    assert(out_vk.shape[1] * Epilogue::LITELA_HEAD_DIM * 3 == N);

    int batch_size = out_vk.shape[0];
    int num_heads = out_vk.shape[1];

    assert(M % batch_size == 0);
    int batch_m = M / batch_size;

    Epilogue::Arguments epilogueArgs;
    // Use the rows-per-batch value derived above rather than act.shape[1]: the two only
    // coincide when act is exactly rank 3, and the second launch below already relies on batch_m.
    epilogueArgs.batch_m = batch_m;
    epilogueArgs.out_q = out_q.data_ptr<GEMM::half_t>();
    epilogueArgs.out_vk = out_vk.data_ptr<float>();

    // out_vk is cleared before the GEMM launch that receives it through epilogueArgs.
    checkCUDA(cudaMemsetAsync(out_vk.data_ptr(), 0, out_vk.buffer->getSize()));

    auto func = invoke_kernel<GEMM::gemm_w8a8_kernel<Epilogue>,
        const GEMM::packed_act_t *,
        const GEMM::packed_wgt_t *,
        const GEMM::packed_ascale_t *,
        const GEMM::packed_wscale_t *,
        // GEMM::half_t *,
        int, int, int,
        Epilogue::Arguments,
        bool,
        bool>;

    checkCUDA(cudaFuncSetAttribute(func, cudaFuncAttributeMaxDynamicSharedMemorySize, Epilogue::SHMEM_SIZE));

    dim3 grid(M / GEMM::BLOCK_M, N / GEMM::BLOCK_N);

    bool swapBlockMN = M > N * 2;
    if (swapBlockMN) {
        std::swap(grid.x, grid.y);
    }

    func<<<grid, GEMM::WARP_SIZE * GEMM::NUM_WARPS, Epilogue::SHMEM_SIZE>>>(
        act.data_ptr<GEMM::packed_act_t>(),
        wgt.data_ptr<GEMM::packed_wgt_t>(),
        ascales.data_ptr<GEMM::packed_ascale_t>(),
        wscales.data_ptr<GEMM::packed_wscale_t>(),
        // nullptr,
        M, N, K, epilogueArgs,
        swapBlockMN,
        false
    );
    checkCUDA(cudaGetLastError());

    // NOTE(review): the grid uses batch_m / 128 with integer division -- confirm batch_m is
    // always a multiple of 128 for callers of this path.
    invoke_kernel<Epilogue::vk_mul_q_kernel><<<dim3(batch_m / 128, num_heads, batch_size), 128>>>(
        out_q.data_ptr<GEMM::half_t>(),
        out_vk.data_ptr<float>(),
        1e-6f
    );
    checkCUDA(cudaGetLastError());
}
#endif

}; // namespace nunchaku::kernels