#include "hip/hip_runtime.h"
/*
 * Modified from NVIDIA
 * [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
 * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include "gemv_awq.h"
#include "../dispatch_utils.h"

#include "../utils.cuh"

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <stdio.h>
#include "dequantize.cuh"

#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128

// Reduce sum within the warp using the tree reduction algorithm.
// Partially reduces the Num per-thread partial sums of each warp and stages
// the per-warp results in shared memory for the caller to accumulate.
// Because the weights are interleaved by 4, partials belonging to the same
// output row live in lanes that differ by XOR 16, 8 and 1 (XOR 2 and 4 are
// deliberately skipped); after the shuffles, lanes 0/2/4/6 each hold the
// complete sum for one of the 4 interleaved rows.
//   psum     - Num partial sums owned by the calling thread (cast to float
//              before reduction).
//   out_smem - [warps][Num * 4] staging area; only lanes 0/2/4/6 write it.
// Both __syncthreads() calls are reached uniformly by every thread.
// NOTE(review): WarpSize is instantiated with WARP_SIZE == 32; confirm this
// matches the wavefront size of the target HIP architecture.
template<typename float_t, int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(float_t *psum, float (*out_smem)[Num * 4]) {
    // kInterleave = 4
    float fpsum[Num];
#pragma unroll
    for (int i = 0; i < Num; ++i) {
        fpsum[i] = static_cast<float>(psum[i]);
    }

#pragma unroll
    for (int i = 0; i < Num; ++i) {
        // T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
        // ~0 selects every lane as a participant in the shuffle.
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
        fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
    }
    __syncthreads();
    int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
    // lane / 2 selects the interleave slot (0..3) that this lane's sum
    // belongs to.
    if (lane == 0 || lane == 2 || lane == 4 || lane == 6) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            out_smem[warp][i * 4 + lane / 2] = fpsum[i];
        }
    }
    __syncthreads();
};

// Ceiling division: the smallest number of `divisor`-sized chunks that
// fully covers `c` elements.
__device__ __forceinline__ int make_divisible(int c, int divisor) {
    const int biased = c + divisor - 1;
    return biased / divisor;
}

// Broadcasts a scalar half-precision value into both lanes of the matching
// packed 2-element vector type (half -> half2, __hip_bfloat16 -> __hip_bfloat162).
template<typename half_t>
__device__ __forceinline__ packed_as<half_t, 2>::type half2half2(half_t x);

// fp16 specialization: wraps the __half2half2 intrinsic.
template<>
__device__ __forceinline__ packed_as<half, 2>::type half2half2<half>(half x) {
    return __half2half2(x);
}

// bf16 specialization: wraps the __bfloat162bfloat162 intrinsic.
template<>
__device__ __forceinline__ packed_as<__hip_bfloat16, 2>::type half2half2<__hip_bfloat16>(__hip_bfloat16 x) {
    return __bfloat162bfloat162(x);
}

// Converts a packed 2-element half-precision vector to a float2.
template<typename T>
__device__ __forceinline__ float2 half22float2(T val);

// fp16 specialization: wraps the __half22float2 intrinsic.
template<>
__device__ __forceinline__ float2 half22float2<half2>(half2 val) {
    return __half22float2(val);
}

// bf16 specialization: wraps the __bfloat1622float2 intrinsic.
template<>
__device__ __forceinline__ float2 half22float2<__hip_bfloat162>(__hip_bfloat162 val) {
    return __bfloat1622float2(val);
}

// AWQ W4A16 GEMV kernel: multiplies fp16/bf16 activations by int4-quantized,
// interleaved weights, dequantizing on the fly with per-group scales and
// (pre-scaled) zeros, accumulating in fp32.
//
// Template parameters:
//   half_t    - activation/scale/output element type (half or __hip_bfloat16).
//   NPerBlock - output-channel tiles per block (each tile is kInterleave wide).
//   Batch     - number of activation rows processed per launch.
//   BlockSize - threads per block.
//   GroupSize - quantization group size (elements sharing one scale/zero).
//
// Arguments:
//   inputs  - activations, Batch rows of IC elements each.
//   weight  - int4 weights packed 8-per-uint32 (PACK_FACTOR), in the
//             kInterleave=4 interleaved layout this kernel un-shuffles.
//   scales  - per-group scales, indexed as group * OC + channel.
//   zeros   - per-group zero terms added after scaling (via __hfma2), so they
//             are assumed to be pre-multiplied by the scale.
//   outputs - Batch rows of OC results.
//
// Each block produces NPerBlock * kInterleave consecutive output channels
// starting at blockIdx.x * NPerBlock * kInterleave.
template<typename half_t, int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(const half_t *inputs,
                            const uint32_t *weight,
                            const half_t *scales,
                            const half_t *zeros,
                            half_t *outputs,
                            const int IC,
                            const int OC) {

#if defined(__DTK_ARCH__) && __DTK_ARCH__ < 800
    // bf16 path is not supported on pre-800 DTK architectures; trap instead
    // of silently producing garbage.
    if constexpr (std::is_same_v<half_t, __hip_bfloat16>) {
        trap_unsupported_arch();
        return;
    }
#endif
    using half2_t  = typename packed_as<half_t, 2>::type;
    using accum_t  = float;
    using accum2_t = typename packed_as<accum_t, 2>::type;

    // Each thread loads 128 bits of packed weights per access:
    // kElemsPerThread = 32 int4 values; kThreadsNumPerTile threads cover one
    // kStride-wide tile of the K dimension.
    const int kStride            = 64;
    const int kElemsPerThread    = MEM_ACCESS_SIZE / 4;
    const int kThreadsNumPerTile = kStride / kElemsPerThread;
    // assert(MEM_ACCESS_SIZE == 128);

    // static constexpr int kShuffleSize = 32;
    // Shuffle geometry used to undo the packing order of the interleaved
    // weights (4 continuous x 4 strided groups of 2-element tiles).
    static constexpr int kShuffleBasicTile = 2;
    static constexpr int kShuffleContinous = 4;
    static constexpr int kShuffleStrided   = 4;

    constexpr int Num         = NPerBlock * Batch;
    constexpr int kInterleave = 4;

    // Per-thread staging buffers; 16-byte alignment enables float4 accesses.
    alignas(16) half_t local_inputs[kElemsPerThread];
    alignas(16) uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
    alignas(16) half_t half_weight_buffer[kElemsPerThread];
    alignas(16) half_t dequantized_weight[kElemsPerThread * NPerBlock];
    alignas(16) half_t local_scale[NPerBlock];
    alignas(16) half_t local_scaled_zeros[NPerBlock];

    // fp32 partial sums, one per (batch row, output tile) pair.
    accum_t psum[Num];
    for (int i = 0; i < Num; ++i)
        psum[i] = static_cast<accum_t>(0.f);

    // extern __shared__ uint8_t shmem[];
    // float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);

    // Per-warp staging for the final cross-warp accumulation.
    // NOTE(review): sized with an extra factor of 2, but only the first
    // BlockSize / WARP_SIZE rows are written/read below — confirm whether the
    // slack is intentional.
    __shared__ float out_smem[BlockSize / WARP_SIZE * 2][Num * kInterleave];

    const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
    const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
    const int act_k_offset   = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride +
                             (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
    const int group_offset = act_k_offset / GroupSize;
    // TODO: use make_divisible
    const uint32_t *blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
    const half_t *scale_ptr        = scales + blk_row_offset + thd_row_offset + group_offset * OC;
    const half_t *zeros_ptr        = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
    const half_t *inputs_ptr       = inputs + act_k_offset;

    // How far each pointer advances per outer-loop iteration.
    const int act_forward_step   = BlockSize * kElemsPerThread / kInterleave;
    const int scale_forward_step = act_forward_step / GroupSize * OC;

    // Main loop iteration, each block completes the outputs for several OCs
    for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread) {
// Load qweight, scales and scaled_zeros
#pragma unroll
        for (int idx = 0; idx < NPerBlock; ++idx) {
            // use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
            *((float4 *)(local_qweights)) = *((float4 *)(blk_weight_ptr + (idx * kInterleave * IC + kk) / PACK_FACTOR));
            local_scale[idx]              = *(scale_ptr + idx * kInterleave);
            local_scaled_zeros[idx]       = *(zeros_ptr + idx * kInterleave);

// Map int4 qweight to fp format
#pragma unroll
            for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i) {
                // Converts 32 bits (8 x int4) to 8 fp16
                dequantize_s4_to_fp16x2(*reinterpret_cast<half2_t *>(local_qweights + i),
                                        reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
            }

// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
            for (int i = 0; i < kShuffleContinous; ++i) {
#pragma unroll
                for (int j = 0; j < kShuffleStrided; ++j) {
                    half2_t w = *reinterpret_cast<half2_t *>(half_weight_buffer +
                                                             (i + j * kShuffleContinous) * kShuffleBasicTile);
                    // w = w * scale + scaled_zero (zero is assumed pre-scaled).
                    w         = __hfma2(w, half2half2(local_scale[idx]), half2half2(local_scaled_zeros[idx]));
                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0) * NPerBlock + idx] = w.x;
                    dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1) * NPerBlock + idx] = w.y;
                }
            }
        }
#pragma unroll
        for (int batch_idx = 0; batch_idx < Batch; ++batch_idx) {
            const half_t *local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
            for (int idx = 0; idx < kElemsPerThread / 8; ++idx) {
                // load activation, 8 halves (128 bits) / step.
                *((float4 *)(local_inputs + idx * 8)) = *((float4 *)(local_inputs_ptr + idx * 8));
            }
// Perform the MACs
#pragma unroll
            for (int x = 0; x < NPerBlock / 2; ++x) {
#pragma unroll
                for (int y = 0; y < kElemsPerThread; ++y) {
                    // Multiply a pair of dequantized weights by the broadcast
                    // activation, widen to fp32, and accumulate into psum.
                    accum2_t prod = cuda_cast<accum2_t>(
                        __hmul2(*reinterpret_cast<half2_t *>(dequantized_weight + y * NPerBlock + x * 2),
                                half2half2(local_inputs[y])));
                    *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2) =
                        prod + *reinterpret_cast<accum2_t *>(psum + batch_idx * NPerBlock + x * 2);
                    // *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2)
                    //     = __hfma2(*reinterpret_cast<half2_t*>(dequantized_weight + y * NPerBlock + x * 2),
                    //         half2half2(local_inputs[y]),
                    //         *reinterpret_cast<half2_t*>(psum + batch_idx * NPerBlock + x * 2));
                }
            }
        }
        inputs_ptr += act_forward_step;
        scale_ptr += scale_forward_step;
        zeros_ptr += scale_forward_step;
    }

    // Reduce partial sums within each warp; per-warp results land in out_smem.
    warp_reduce<accum_t, Num, WARP_SIZE>(psum, out_smem);

    // Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
    for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize) {
        int batch_idx = i / (NPerBlock * kInterleave);
        int oc_idx    = i % (NPerBlock * kInterleave);
        float acc     = 0.f;
        // Sum the per-warp partials staged by warp_reduce.
        for (int j = 0; j < BlockSize / WARP_SIZE; ++j) {
            acc += out_smem[j][i];
        }
        outputs[batch_idx * OC + blk_row_offset + oc_idx] = static_cast<half_t>(acc);
    }
}

/*
Computes GEMV (PyTorch interface).

Args:
  _in_feats: tensor of shape [B, IC];
  _kernel: int tensor of shape [OC, IC // 8] (int4 weights packed 8 per uint32);
  _scaling_factors: tensor of shape [IC // G, OC] (per-group scales; layout inferred from kernel indexing — confirm);
  _zeros: pre-scaled zero terms with the same dtype and layout as _scaling_factors;
  m: batch size B (must be 1..8);
  n: number of output channels OC;
  k: number of input channels IC;
  group_size: quantization group size G (must equal 64);

Returns:
  out_feats: tensor of shape [B, OC];
*/
// Host-side launcher for the AWQ W4A16 GEMV kernel (PyTorch-style Tensor API).
//   m: batch size (must be 1..8; selects the Batch template parameter)
//   n: number of output channels (OC)
//   k: number of input channels (IC)
//   group_size: quantization group size; must equal the compiled GROUP_SIZE (64)
// Returns a newly allocated output tensor whose last dimension is n.
Tensor gemv_awq(
    Tensor _in_feats, Tensor _kernel, Tensor _scaling_factors, Tensor _zeros, int m, int n, int k, int group_size) {
    // Dispatch on the floating type (fp16 or bf16) of the scale tensor;
    // the activations must have the same dtype.
    return dispatchFloat16(_scaling_factors.scalar_type(), [&]<typename half_t>() {
        assert(isTypeMatch<half_t>(_in_feats.dtype()));

        // Output keeps the input's leading dimensions; last dim becomes n (OC).
        auto output_shape   = _in_feats.shape.dataExtent;
        output_shape.back() = n;

        auto in_feats        = reinterpret_cast<half_t *>(_in_feats.data_ptr<half_t>());
        auto kernel          = reinterpret_cast<uint32_t *>(_kernel.data_ptr());
        auto zeros           = reinterpret_cast<half_t *>(_zeros.data_ptr<half_t>());
        auto scaling_factors = reinterpret_cast<half_t *>(_scaling_factors.data_ptr<half_t>());

        Tensor _out_feats = Tensor::allocate(output_shape, _in_feats.dtype(), _in_feats.device());
        half_t *out_feats = reinterpret_cast<half_t *>(_out_feats.data_ptr());

        // Launch geometry: each block produces N_PER_BLOCK * K_INTERLEAVE (= 8)
        // output channels, so n is expected to be divisible by 8.
        static constexpr int N_PER_BLOCK  = 2;
        static constexpr int K_INTERLEAVE = 4;
        static constexpr int BLOCK_SIZE   = 256;

        dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
        dim3 num_threads(BLOCK_SIZE);

        // The kernel is compiled for a single, fixed quantization group size.
        constexpr int GROUP_SIZE = 64;

        assert(m > 0 && m <= 8);
        assert(group_size == GROUP_SIZE);

        // Instantiate the kernel for the runtime batch size m (1..8).
        dispatchVal(m, std::make_integer_sequence<int, 9>(), [&]<int M>() {
            if constexpr (M == 0) {
                // m == 0 is rejected by the assert above; this branch exists
                // only so the integer_sequence dispatch covers 0..8.
                assert(false);
                return;
            }
            if constexpr (M > 0) {
               hipLaunchKernelGGL(( gemv_kernel<half_t, N_PER_BLOCK, M, BLOCK_SIZE, GROUP_SIZE>)
                    , dim3(num_blocks), dim3(num_threads), 0, getCurrentHIPStreamMasqueradingAsCUDA(), 
                        in_feats, kernel, scaling_factors, zeros, out_feats, k, n);
                // Surface launch-configuration errors immediately.
                checkCUDA(hipGetLastError());
            }
        });

        return _out_feats;
    });
}