/*
 * Modified from NVIDIA
 * [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
 *
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include "gemv_awq.h"
#include "../dispatch_utils.h"

#include "../utils.cuh"

#include <cuda_fp16.h>
#include <cuda_bf16.h>

#include <cassert>
#include <cstdint>
#include <stdio.h>
#include <utility>

#include "dequantize.cuh"

// Number of 4-bit weights packed into one 32-bit word of the quantized weight tensor.
#define PACK_FACTOR 8
// Lanes per warp assumed by warp_reduce and by the per-warp shared-memory layout.
#define WARP_SIZE 32
// Width in bits of one vectorized (float4) global load: 4 x uint32 = 32 int4 weights.
#define MEM_ACCESS_SIZE 128

// Warp-level tree reduction of per-thread partial sums, published to shared memory.
//
// Each thread contributes psum[0..Num). After XOR-shuffles with offsets 16, 8 and 1,
// lane l holds the sum over lanes {l, l^1, l^8, l^9, l^16, l^17, l^24, l^25}. Lanes
// 0, 2, 4 and 6 therefore hold the totals for the four lane classes (lane / 2) % 4,
// which gemv_kernel maps to the kInterleave (= 4) interleaved output rows.
//
// Results are stored as out_smem[warp][i * 4 + lane / 2] for every i in [0, Num).
//
// Contains two __syncthreads(): every thread of the block must call this function.
//
// NOTE(review): the mask-less __shfl_xor is the HIP/DCU form used by this port. On
// NVIDIA sm_70+ it is not available and __shfl_xor_sync(0xffffffff, ...) is required.
template<typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t *psum, float (*out_smem)[Num * 4]) {
    // Reduce in float regardless of the accumulator type.
    float fpsum[Num];
#pragma unroll
    for (int i = 0; i < Num; ++i) {
        fpsum[i] = static_cast<float>(psum[i]);
    }

#pragma unroll
    for (int i = 0; i < Num; ++i) {
        // Lane 0 ends with T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4).
        fpsum[i] += __shfl_xor(fpsum[i], 16);
        fpsum[i] += __shfl_xor(fpsum[i], 8);
        fpsum[i] += __shfl_xor(fpsum[i], 1);
    }

    // Block-wide barrier before any thread writes to shared memory.
    __syncthreads();
    const int warp = threadIdx.x / WarpSize;
    const int lane = threadIdx.x % WarpSize;
    if (lane == 0 || lane == 2 || lane == 4 || lane == 6) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            out_smem[warp][i * 4 + lane / 2] = fpsum[i];
        }
    }
    // Make the per-warp results visible to all threads before the caller reads them.
    __syncthreads();
}

// Ceiling division: returns ceil(c / divisor) for non-negative c and positive divisor.
// NOTE: despite its name, this does NOT round c up to a multiple of divisor; it returns
// the number of divisor-sized chunks needed to cover c.
__device__ __forceinline__ int make_divisible(int c, int divisor) {
    const int padded = c + (divisor - 1);
    return padded / divisor;
}

// Broadcasts a scalar 16-bit value into both lanes of its packed 2-wide type
// (packed_as<T, 2>::type). Only the half and __nv_bfloat16 specializations below are
// defined; any other T fails at link time.
//
// `typename` is required on the dependent return type for compilers/standards that do
// not implement C++20's optional-typename rule (P0634).
template<typename half_t>
__device__ __forceinline__ typename packed_as<half_t, 2>::type half2half2(half_t x);

// half -> packed half pair, via the CUDA intrinsic __half2half2.
template<>
__device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
    return __half2half2(x);
}

// __nv_bfloat16 -> packed bf16 pair, via the CUDA intrinsic __bfloat162bfloat162.
template<>
__device__ __forceinline__ packed_as<__nv_bfloat16, 2>::type half2half2<__nv_bfloat16>(__nv_bfloat16 x) {
    return __bfloat162bfloat162(x);
}

// Widens a packed pair of 16-bit floating-point values to a float2.
// Only the half2 and __nv_bfloat162 specializations below are defined.
template<typename Packed2>
__device__ __forceinline__ float2 half22float2(Packed2 val);

// half2 -> float2 via the CUDA intrinsic __half22float2.
template<>
__device__ __forceinline__ float2 half22float2<half2>(half2 packed) {
    const float2 widened = __half22float2(packed);
    return widened;
}

// __nv_bfloat162 -> float2 via the CUDA intrinsic __bfloat1622float2.
template<>
__device__ __forceinline__ float2 half22float2<__nv_bfloat162>(__nv_bfloat162 packed) {
    const float2 widened = __bfloat1622float2(packed);
    return widened;
}

/*
 * W4A16 batched GEMV: outputs[b][oc] = sum_k inputs[b][k] * dequant(weight)[oc][k].
 *
 * Template parameters:
 *   half_t    : activation / scale / zero / output element type (half or __nv_bfloat16,
 *               the types half2half2 is specialized for).
 *   NPerBlock : number of interleaved row groups per block (must be even). Each group
 *               covers kInterleave (= 4) output channels, so a block writes
 *               NPerBlock * 4 consecutive output channels.
 *   Batch     : number of activation rows.
 *   BlockSize : threads per block; must equal blockDim.x.
 *   GroupSize : quantization group size along IC (one scale/zero per GroupSize inputs).
 *
 * Memory layout, as indexed below:
 *   inputs  : Batch rows of IC elements, row stride IC.
 *   weight  : OC * IC / 8 packed 32-bit words. Each group of 4 output rows occupies
 *             IC * 4 / 8 consecutive words, interleaved in the packed format that
 *             dequantize_s4_to_fp16x2 and the shuffle below undo.
 *   scales, zeros : indexed [IC / GroupSize][OC]. Dequantized w = q * scale + zero.
 *   outputs : Batch rows of OC elements, row stride OC.
 *
 * Launch: 1-D grid of OC / (NPerBlock * 4) blocks, 1-D blocks of BlockSize threads,
 * static shared memory only.
 *
 * Preconditions (not checked here):
 *   - IC is a multiple of 64 (kStride) and of GroupSize;
 *   - OC is a multiple of NPerBlock * 4;
 *   - inputs and weight are 16-byte aligned for the float4 loads.
 * No tail handling is performed.
 */
template<typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(const half_t *__restrict__ inputs,
                            const uint32_t *__restrict__ weight,
                            const half_t *__restrict__ scales,
                            const half_t *__restrict__ zeros,
                            half_t *__restrict__ outputs,
                            const int IC,
                            const int OC) {
    // The guard below would trap bf16 instantiations on pre-sm_80 devices; it is
    // currently commented out.
    // #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    //     if constexpr (std::is_same_v<half_t, __nv_bfloat16>) {
    //         trap_unsupported_arch();
    //         return;
    //     }
    // #endif

    using half2_t  = typename packed_as<half_t, 2>::type;
    using accum_t  = float;
    using accum2_t = typename packed_as<accum_t, 2>::type;

    constexpr int kStride            = 64;                           // K span of one tile of threads
    constexpr int kElemsPerThread    = MEM_ACCESS_SIZE / 4;          // int4 values per 128-bit load (32)
    constexpr int kThreadsNumPerTile = kStride / kElemsPerThread;    // threads sharing one tile row (2)
    constexpr int kInterleave        = 4;                            // output rows interleaved per row group

    // Layout of one 32-element dequantized chunk: kShuffleBasicTile-wide pairs arranged
    // as kShuffleContinous x kShuffleStrided.
    constexpr int kShuffleBasicTile = 2;
    constexpr int kShuffleContinous = 4;
    constexpr int kShuffleStrided   = 4;

    constexpr int Num = NPerBlock * Batch;

    // Activation elements consumed per main-loop iteration by the whole block.
    constexpr int act_forward_step = BlockSize * kElemsPerThread / kInterleave;

    static_assert(MEM_ACCESS_SIZE == 128, "kernel assumes one float4 (128-bit) weight load per row group");
    static_assert(kShuffleBasicTile * kShuffleContinous * kShuffleStrided == kElemsPerThread,
                  "shuffle tile must cover exactly one thread's chunk");
    static_assert(NPerBlock % 2 == 0, "the MAC loop consumes output channels in pairs");
    static_assert(BlockSize % WARP_SIZE == 0, "block must consist of whole warps");
    static_assert(act_forward_step % GroupSize == 0, "per-iteration K step must be a whole number of groups");

    alignas(16) half_t local_inputs[kElemsPerThread];
    alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
    alignas(16) half_t half_weight_buffer[kElemsPerThread];
    alignas(16) half_t dequantized_weight[kElemsPerThread * NPerBlock];
    alignas(16) half_t local_scale[NPerBlock];
    alignas(16) half_t local_scaled_zeros[NPerBlock];

    // psum is accessed as accum2_t pairs below, so it needs accum2_t alignment.
    alignas(alignof(accum2_t)) accum_t psum[Num];
#pragma unroll
    for (int i = 0; i < Num; ++i) {
        psum[i] = static_cast<accum_t>(0.f);
    }

    // Per-warp partial results written by warp_reduce (first dimension is oversized by 2x).
    __shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave];

    // First output channel of this block, and the interleaved row (0..3) this thread serves.
    const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
    const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
    // First activation (K) index read by this thread.
    const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride +
                             (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
    const int group_offset = act_k_offset / GroupSize;

    // 64-bit product: blk_row_offset * IC can exceed INT_MAX for large weight matrices.
    const uint32_t *blk_weight_ptr = weight + static_cast<int64_t>(blk_row_offset) * IC / PACK_FACTOR;
    const half_t *scale_ptr        = scales + blk_row_offset + thd_row_offset + group_offset * OC;
    const half_t *zeros_ptr        = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
    const half_t *inputs_ptr       = inputs + act_k_offset;

    const int scale_forward_step = act_forward_step / GroupSize * OC;

    // Main loop: kk walks this thread's position in the interleaved weight rows (IC * 4 long).
    for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) {
        // Load qweights, scales and zeros for each row group, then dequantize.
#pragma unroll
        for (int idx = 0; idx < NPerBlock; ++idx) {
            // One float4 (128 bits) = 32 int4 weights.
            *reinterpret_cast<float4 *>(local_qweights) =
                *reinterpret_cast<const float4 *>(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR);
            local_scale[idx]        = *(scale_ptr + idx * kInterleave);
            local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);

            // Map int4 qweights to the 16-bit format: each 32-bit word -> 8 values.
#pragma unroll
            for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) {
                dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i),
                                        reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
            }

            // Apply w = q * scale + zero and un-shuffle into a [k][NPerBlock] layout.
            const half2_t scale2 = half2half2(local_scale[idx]);
            const half2_t zero2  = half2half2(local_scaled_zeros[idx]);
#pragma unroll
            for (int i = 0; i < kShuffleContinous; ++i) {
#pragma unroll
                for (int j = 0; j < kShuffleStrided; ++j) {
                    half2_t w = *reinterpret_cast<half2_t *>(half_weight_buffer +
                                                             (i + j * kShuffleContinous) * kShuffleBasicTile);
                    w         = __hfma2(w, scale2, zero2);
                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
                }
            }
        }

#pragma unroll
        for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) {
            const half_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
            // Load activations, 8 x 16-bit (128 bits) per step.
#pragma unroll
            for (int idx = 0; idx < kElemsPerThread / 8; ++idx) {
                *reinterpret_cast<float4 *>(local_inputs + idx * 8) =
                    *reinterpret_cast<const float4 *>(local_inputs_ptr + idx * 8);
            }

            // MACs: products formed in 16-bit pairs (__hmul2), accumulated in float.
#pragma unroll
            for (int x = 0; x < NPerBlock / 2; ++x) {
#pragma unroll
                for (int y = 0; y < kElemsPerThread; ++y) {
                    accum2_t prod = cuda_cast<accum2_t>(
                        __hmul2(*reinterpret_cast<half2_t *>(dequantized_weight + y * NPerBlock + x * 2),
                                half2half2(local_inputs[y])));
                    accum2_t *acc2 = reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2);
                    *acc2          = prod + *acc2;
                }
            }
        }

        inputs_ptr += act_forward_step;
        scale_ptr += scale_forward_step;
        zeros_ptr += scale_forward_step;
    }

    // Every thread reaches this point (warp_reduce contains block barriers).
    warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);

    // Sum the per-warp partials. Slot i covers batch i / (NPerBlock * 4) and
    // block-local output channel i % (NPerBlock * 4).
    for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize) {
        const int batch_idx = i / (NPerBlock * kInterleave);
        const int oc_idx    = i % (NPerBlock * kInterleave);
        float acc           = 0.f;
        for (int j = 0; j < BlockSize / WARP_SIZE; ++j) {
            acc += out_smem[j][i];
        }
        outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<half_t>(acc);
    }
}

/*
Computes a W4A16 GEMV / small-batch GEMM: out_feats = in_feats x dequant(_kernel)^T.

Args:
  _in_feats:        activations of dtype half or bfloat16; the last dimension is k (IC).
                    Assumed to hold m rows of k contiguous elements (not checked).
  _kernel:          packed int4 weights, 8 values per 32-bit word (OC * IC / 8 words), with
                    output rows interleaved in groups of 4 as expected by gemv_kernel.
  _scaling_factors: per-group scales, same dtype as _in_feats, indexed [IC / G, OC].
  _zeros:           per-group zero terms, same dtype and layout as the scales;
                    dequantized w = q * scale + zero.
  m:                number of activation rows (batch); 1 <= m <= 8.
  n:                OC, number of output features; must be a multiple of 8.
  k:                IC, number of input features; must be a multiple of 64.
  group_size:       quantization group size G; only 64 is supported.

Returns:
  out_feats: tensor with _in_feats' shape except the last dimension replaced by n.

Launches on getCurrentCUDAStream() and does not synchronize.
*/
Tensor gemv_awq(
    Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size) {
    return dispatchFloat16(_scaling_factors.scalar_type(), [&]<typename half_t>() {
        assert(isTypeMatch<half_t>(_in_feats.dtype()));

        static constexpr int N_PER_BLOCK  = 2;
        static constexpr int K_INTERLEAVE = 4;
        static constexpr int BLOCK_SIZE   = 256;
        // gemv_kernel is only instantiated for this group size.
        constexpr int GROUP_SIZE = 64;
        // Output channels produced by one thread block.
        constexpr int OC_PER_BLOCK = N_PER_BLOCK * K_INTERLEAVE;

        assert(m > 0 && m <= 8);
        assert(group_size == GROUP_SIZE);
        // gemv_kernel has no tail handling. A trailing partial block of output channels
        // would be silently skipped, and a partial K tile would read out of bounds.
        assert(n % OC_PER_BLOCK == 0);
        assert(k % GROUP_SIZE == 0);

        auto output_shape   = _in_feats.shape.dataExtent;
        output_shape.back() = n;

        auto in_feats        = reinterpret_cast<half_t *>(_in_feats.data_ptr<half_t>());
        auto kernel          = reinterpret_cast<uint32_t *>(_kernel.data_ptr());
        auto zeros           = reinterpret_cast<half_t *>(_zeros.data_ptr<half_t>());
        auto scaling_factors = reinterpret_cast<half_t *>(_scaling_factors.data_ptr<half_t>());

        Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
        half_t *out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());

        // A zero-sized grid is an invalid launch configuration; nothing to compute.
        if (n == 0) {
            return _out_feats;
        }

        dim3 num_blocks(n / OC_PER_BLOCK);
        dim3 num_threads(BLOCK_SIZE);

        // Turn the runtime batch size into the kernel's compile-time Batch parameter.
        dispatchVal(m, std::make_integer_sequence<int, 9>(), [&]<int M>() {
            if constexpr (M == 0) {
                assert(false);
                return;
            } else {
                gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>
                    <<<num_blocks, num_threads, 0, getCurrentCUDAStream()>>>(
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                checkCUDA(cudaGetLastError());
            }
        });

        return _out_feats;
    });
}