/*
 * Modified from NVIDIA
 * [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include "gemv_awq.h"
#include "../dispatch_utils.h"

#include "../utils.cuh"

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <stdio.h>
#include "dequantize.cuh"

#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128
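// PACK_FACTOR: number of 4-bit weights packed into one uint32_t.
// MEM_ACCESS_SIZE: width in bits of each vectorized global-memory access (one float4/uint4),
// i.e. 32 packed int4 weights or 8 fp16/bf16 activations per load.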

// Reduce sum within the warp using the tree reduction algorithm.
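// Lanes that differ only in bits 0, 3 and 4 of their index process the same interleaved
// output row (row = bits 1-2), so the XOR shuffles with offsets 16, 8 and 1 below sum
// exactly the partial results of one row (e.g. lanes 0,1,8,9,16,17,24,25 for row 0).
// One lane per row (0, 2, 4, 6) then stores the per-warp sum to shared memory for the
// final cross-warp reduction done by the caller.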
template<typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t *psum, float (*out_smem)[Num * 4]) {
    // kInterleave = 4
    float fpsum[Num];
#pragma unroll
    for (int i = 0; i < Num; ++i) {
        fpsum[i] = static_cast<float>(psum[i]);
    }

#pragma unroll
    for (int i = 0; i < Num; ++i) {
        // T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
    }
    __syncthreads();
    int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
    if (lane == 0 || lane == 2 || lane == 4 || lane == 6) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            out_smem[warp][i * 4 + lane / 2] = fpsum[i];
        }
    }
    __syncthreads();
};

__device__ __forceinline__ int make_divisible(int c, int divisor) {
    return (c + divisor - 1) / divisor;
}

template<typename half_t>
__device__ __forceinline__ packed_as<half_t, 2>::type half2half2(half_t x);

template<>
__device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
    return __half2half2(x);
}

template<>
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
}

template<typename T>
__device__ __forceinline__ float2 half22float2(T val);

template<>
__device__ __forceinline__ float2 half22float2<half2>(half2 val) {
    return __half22float2(val);
}

template<>
__device__ __forceinline__ float2 half22float2<__nv_bfloat162>(__nv_bfloat162 val) {
    return __bfloat1622float2(val);
}

template<typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(const half_t *inputs,
                            const uint32_t *weight,
                            const half_t *scales,
                            const half_t *zeros,
                            half_t *outputs,
                            const int IC,
                            const int OC) {
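    // The bf16 conversion/math intrinsics used below require sm_80 or newer; on older
    // architectures the bf16 instantiation traps instead of silently computing garbage.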

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
        trap_unsupported_arch();
        return;
    }
#endif
    using half2_t  = typename packed_as<half_t, 2>::type;
    using accum_t  = float;
    using accum2_t = typename packed_as<accum_t, 2>::type;

    const int kStride            = 64;
    const int kElemsPerThread    = MEM_ACCESS_SIZE / 4;
    const int kThreadsNumPerTile = kStride / kElemsPerThread;
    // assert(MEM_ACCESS_SIZE == 128);

    // static constexpr int kShuffleSize = 32;
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 4;
    static constexpr int kShuffleStrided   = 4;

    constexpr int Num         = NPerBlock * Batch;
    constexpr int kInterleave = 4;

    alignas(16) half_t local_inputs[kElemsPerThread];
    alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
    alignas(16) half_t half_weight_buffer[kElemsPerThread];
    alignas(16) half_t dequantized_weight[kElemsPerThread * NPerBlock];
    alignas(16) half_t local_scale[NPerBlock];
    alignas(16) half_t local_scaled_zeros[NPerBlock];

    accum_t psum[Num];
    for (int i = 0; i < Num; ++i)
        psum[i] = static_cast<accum_t>(0.f);

    // extern __shared__ uint8_t shmem[];
    // float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);

    __shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave];

    const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
    const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
    const int act_k_offset   = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride +
                             (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
    const int group_offset = act_k_offset / GroupSize;
    // TODO: use make_divisible
    const uint32_t *blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
    const half_t *scale_ptr        = scales + blk_row_offset + thd_row_offset + group_offset * OC;
    const half_t *zeros_ptr        = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
    const half_t *inputs_ptr       = inputs + act_k_offset;

    const int act_forward_step   = BlockSize * kElemsPerThread / kInterleave;
    const int scale_forward_step = act_forward_step / GroupSize * OC;

    // Main loop: each block computes the outputs for several output channels (OCs).
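    // Per iteration, each thread dequantizes one 128-bit chunk (kElemsPerThread int4 weights)
    // for each of the NPerBlock interleaved rows and multiplies it with the matching
    // activations; the block advances BlockSize * kElemsPerThread positions along the
    // interleaved K dimension.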
    for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) {
// Load qweight, scales and scaled_zeros
#pragma unroll
        for (int idx = 0; idx < NPerBlock; ++idx) {
            // Use float4 to load weights: each thread loads 32 int4 values (1 x float4, 128 bits).
            *((float4 *)(local_qweights)) = *((float4 *)(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR));
            local_scale[idx]              = *(scale_ptr + idx * kInterleave);
            local_scaled_zeros[idx]       = *(zeros_ptr + idx * kInterleave);

// Map int4 qweight to fp format
#pragma unroll
            for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) {
                // Converts 32 bits (8 x int4) to 8 fp16
                dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i),
                                        reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
            }

// Dequantize (apply s/z) and shuffle elements to match the weight packing format
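            // The offline packing reorders weights in 2-element tiles over a
            // kShuffleContinous x kShuffleStrided (4 x 4) pattern; the indexing below undoes
            // that reordering and stores the result with the NPerBlock output rows interleaved
            // element-wise, so the MAC loop can read a half2 covering two adjacent output
            // channels per step.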
#pragma unroll
            for (int i = 0; i < kShuffleContinous; ++i) {
#pragma unroll
                for (int j = 0; j < kShuffleStrided; ++j) {
                    half2_t w = *reinterpret_cast<half2_t *>(half_weight_buffer +
                                                             (i + j * kShuffleContinous) * kShuffleBasicTile);
                    w         = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
                }
            }
        }
#pragma unroll
        for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) {
            const half_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
            for (int idx = 0; idx < kElemsPerThread / 8; ++idx) {
                // Load activations, 8 halves (128 bits) per step.
                *((float4 *)(local_inputs + idx * 8)) = *((float4 *)(local_inputs_ptr + idx * 8));
            }
// Perform the MACs
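            // Products are formed in half2 (two adjacent output channels at once) and then
            // accumulated into float2 (accum2_t) to avoid fp16/bf16 accumulation error.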
#pragma unroll
            for (int x = 0; x < NPerBlock / 2; ++x) {
#pragma unroll
                for (int y = 0; y < kElemsPerThread; ++y) {
                    accum2_t prod = cuda_cast<accum2_t>(
                        __hmul2(*reinterpret_cast<half2_t *>(dequantized_weight + y * NPerBlock + x * 2),
                                half2half2(local_inputs[y])));
                    *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2) =
                        prod + *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2);
                    // *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2)
                    //     = __hfma2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2),
                    //         half2half2(local_inputs[y]),
                    //         *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2));
                }
            }
        }
        inputs_ptr += act_forward_step;
        scale_ptr += scale_forward_step;
        zeros_ptr += scale_forward_step;
    }

    warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);

    // Each block writes back Num * kInterleave = Batch * NPerBlock * kInterleave outputs,
    // accumulating the per-warp partial sums stored in out_smem.
    for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize) {
        int batch_idx = i / (NPerBlock * kInterleave);
        int oc_idx    = i % (NPerBlock * kInterleave);
        float acc     = 0.f;
        for (int j = 0; j < BlockSize / WARP_SIZE; ++j) {
            acc += out_smem[j][i];
        }
        outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<half_t>(acc);
    }
}

/*
Computes a 4-bit AWQ quantized GEMV (PyTorch interface).

Args:
  _in_feats: activation tensor of shape [B, IC] (fp16 or bf16);
  _kernel: packed int4 weight tensor stored as uint32, shape [OC, IC // 8], rows interleaved in groups of 4;
  _scaling_factors: per-group scales of shape [IC // group_size, OC];
  _zeros: per-group scaled zero points, same shape and dtype as _scaling_factors;
  m: batch size B (1 <= m <= 8);
  n: number of output channels OC;
  k: number of input channels IC;
  group_size: quantization group size (must be 64).

Returns:
  out_feats: tensor of shape [B, OC].
*/
Tensor gemv_awq(
    Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size) {
    return dispatchFloat16(_scaling_factors.scalar_type(), [&]<typename half_t>() {
        assert(isTypeMatch<half_t>(_in_feats.dtype()));

        auto output_shape   = _in_feats.shape.dataExtent;
        output_shape.back() = n;

        auto in_feats        = reinterpret_cast<half_t *>(_in_feats.data_ptr<half_t>());
        auto kernel          = reinterpret_cast<uint32_t *>(_kernel.data_ptr());
        auto zeros           = reinterpret_cast<half_t *>(_zeros.data_ptr<half_t>());
        auto scaling_factors = reinterpret_cast<half_t *>(_scaling_factors.data_ptr<half_t>());

        Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
        half_t *out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());

        static constexpr int N_PER_BLOCK  = 2;
        static constexpr int K_INTERLEAVE = 4;
        static constexpr int BLOCK_SIZE   = 256;

        dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
        dim3 num_threads(BLOCK_SIZE);

        constexpr int GROUP_SIZE = 64;

        assert(m > 0 && m <= 8);
        assert(group_size == GROUP_SIZE);

        dispatchVal(m, std::make_integer_sequence<int, 9>(), [&]<int M>() {
            if constexpr (M == 0) {
                assert(false);
                return;
            }
            if constexpr (M > 0) {
                gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>
                    <<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                checkCUDA(cudaGetLastError());
            }
        });

        return _out_feats;
    });
}
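
// Usage sketch (hypothetical, for illustration only; how the Tensor arguments are built is an
// assumption, not part of this file):
//
//   Tensor in_feats = ...;  // [m, k] fp16/bf16 activations, 1 <= m <= 8
//   Tensor qweight  = ...;  // packed 4-bit weights, uint32 storage
//   Tensor scales   = ...;  // [k / 64, n] per-group scales
//   Tensor zeros    = ...;  // [k / 64, n] per-group scaled zero points
//   Tensor out      = gemv_awq(in_feats, qweight, scales, zeros, m, n, k, /*group_size=*/64);
//   // out has shape [m, n]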