// Adapted from
// https://github.com/vllm-project/vllm/blob/eb59b5a6cba6727d3727c0372258db9002f687c1/csrc/quantization/awq/gemm_kernels.cu#L350
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
#include <torch/all.h>

#include <cassert>
#include <type_traits>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

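// Thin wrapper around the PTX lop3.b32 instruction: computes an arbitrary three-input bitwise
// function of (a, b, c), selected by the immediate 8-bit lookup table `lut`.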
template <int lut>
__device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

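// Dequantizes 8 packed 4-bit values (one int32) into 8 fp16 values, returned as a uint4.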
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
  uint4 result;

  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);

  // First, we extract the i4s and construct an intermediate fp16 number.
  static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
  static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
  static constexpr uint32_t TOP_MASK = 0x00f000f0;
  static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;

  // Note that the entire sequence only requires 1 shift instruction. This is
  // thanks to the register packing format and the fact that we force our
  // integers to be unsigned, and account for this in the fp16 subtractions. In
  // addition, I exploit the fact that sub and fma have the same throughput in
  // order to convert elt_23 and elt_67 to fp16 without having to shift them to
  // the bottom bits before hand.

  // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
  // dependency if we issue immediately before required.
  const uint32_t top_i4s = i4s >> 8;
  // Extract elt_01: (i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[0])
               : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[1])
               : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[2])
               : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
  // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(h[3])
               : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));

  // This is the half2 {1024, 1024} represented as an integer.
  static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
  // This is the half2 {1 / 16, 1 / 16} represented as an integer.
  static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
  // This is the half2 {-64, -64} represented as an integer.
  static constexpr uint32_t NEG_64 = 0xd400d400;

  // Finally, we construct the output numbers.
  // Convert elt_01
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_23
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
  // Convert elt_45
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
  // Convert elt_67
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));

  return result;
#else
  assert(false);
  return {};
#endif
}

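// Dequantizes 8 packed 4-bit values to bf16 with an exponent-offset trick similar to the fp16
// path above: OR each nibble into the mantissa of bf16 128.0, yielding 128 + v, then use an
// fma with {1, 1} and {-128, -128} to recover v.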
__device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  uint4 result;
  uint32_t* h = reinterpret_cast<uint32_t*>(&result);
  uint32_t const i4s = source;

  // Masks and conversion constants.
  static constexpr uint32_t MASK = 0x000f000f;  // selects the low 4-bit nibble of each 16-bit lane
  static constexpr uint32_t EX = 0x43004300;    // bf16x2 {128, 128}; OR-ing a nibble into the mantissa gives 128 + v
  static constexpr uint32_t MUL = 0x3F803F80;   // bf16x2 {1, 1}
  static constexpr uint32_t ADD = 0xC300C300;   // bf16x2 {-128, -128}; fma(x, 1, -128) removes the 128 offset

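  // lut 0xEA computes (a & b) | c: isolate one nibble per 16-bit lane and OR in the exponent bits of 128.0.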
  int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s, MASK, EX);
  int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 4, MASK, EX);
  int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 8, MASK, EX);
  int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(i4s >> 12, MASK, EX);

  nv_bfloat162* res = reinterpret_cast<nv_bfloat162*>(h);
  res[0] = __hfma2(
      *reinterpret_cast<nv_bfloat162*>(&lo0),
      *reinterpret_cast<const nv_bfloat162*>(&MUL),
      *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[1] = __hfma2(
      *reinterpret_cast<nv_bfloat162*>(&hi0),
      *reinterpret_cast<const nv_bfloat162*>(&MUL),
      *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[2] = __hfma2(
      *reinterpret_cast<nv_bfloat162*>(&lo1),
      *reinterpret_cast<const nv_bfloat162*>(&MUL),
      *reinterpret_cast<const nv_bfloat162*>(&ADD));
  res[3] = __hfma2(
      *reinterpret_cast<nv_bfloat162*>(&hi1),
      *reinterpret_cast<const nv_bfloat162*>(&MUL),
      *reinterpret_cast<const nv_bfloat162*>(&ADD));

  return result;
#else
  assert(false);
  return {};
#endif
}

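// Each thread dequantizes one packed int32 of qweight (8 x 4-bit values) and writes 8
// contiguous OutputT values: output[row][8*col + j] = (w_j - zero_j) * scale_j for its group.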
template <typename OutputT>
__global__ void __launch_bounds__(256) dequantize_weights(
    int* __restrict__ qweight,
    OutputT* __restrict__ scales,
    int* __restrict__ qzeros,
    OutputT* __restrict__ output,
    int group_size,
    int qweight_cols) {
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  int row = blockIdx.y * blockDim.y + threadIdx.y;

  // Each packed int32 expands to 8 values, so scales and output have 8x the column count of qweight.
  int group_idx = row / group_size;
  int scale_offset = 8 * col + group_idx * qweight_cols * 8;
  uint4 loaded_scale = *(uint4*)(scales + scale_offset);  // 8 scales for this thread's output columns

  // Handle different data types
  if constexpr (std::is_same<OutputT, half>::value) {
    // FP16 path
    uint4 zeros = dequantize_s4_to_fp16x2(qzeros[col + group_idx * qweight_cols]);
    uint4 weight_fp16 = dequantize_s4_to_fp16x2(qweight[col + row * qweight_cols]);

    // (w - zero) * scale, two fp16 lanes per sub/mul instruction
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(zeros.x));
    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.x) : "r"(weight_fp16.x), "r"(loaded_scale.x));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(zeros.y));
    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.y) : "r"(weight_fp16.y), "r"(loaded_scale.y));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(zeros.z));
    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.z) : "r"(weight_fp16.z), "r"(loaded_scale.z));
    asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(zeros.w));
    asm volatile("mul.rn.f16x2 %0, %1, %2;\n" : "=r"(weight_fp16.w) : "r"(weight_fp16.w), "r"(loaded_scale.w));

    OutputT* output_ptr = output + 8 * col + 8 * row * qweight_cols;
    *(uint4*)output_ptr = weight_fp16;
  } else if constexpr (std::is_same<OutputT, __nv_bfloat16>::value) {
    // BF16 path
    uint4 weight_raw = dequantize_s4_to_bf16x2(qweight[col + row * qweight_cols]);
    uint4 zero_raw = dequantize_s4_to_bf16x2(qzeros[col + group_idx * qweight_cols]);

    // Vectorized processing: each uint4 holds 4 nv_bfloat162 values.
    nv_bfloat162* weight_vec = reinterpret_cast<nv_bfloat162*>(&weight_raw);
    nv_bfloat162* zero_vec = reinterpret_cast<nv_bfloat162*>(&zero_raw);
    nv_bfloat162* scale_vec = reinterpret_cast<nv_bfloat162*>(&loaded_scale);

// Subtract the zero point and apply the scale, two bf16 lanes per instruction.
#pragma unroll
    for (int i = 0; i < 4; ++i) {  // uint4 = 4 * nv_bfloat162
      weight_vec[i] = __hmul2(__hsub2(weight_vec[i], zero_vec[i]), scale_vec[i]);
    }

    // Store the 8 dequantized values as a single 16-byte write.
    OutputT* output_ptr = output + 8 * col + 8 * row * qweight_cols;
    static_assert(sizeof(uint4) == 8 * sizeof(OutputT), "Memory layout mismatch");
    *reinterpret_cast<uint4*>(output_ptr) = weight_raw;
  }
}

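// Tensor layouts inferred from the indexing above:
//   qweight: [K, N] int32, each element packing 8 4-bit weights
//   qzeros:  [K / group_size, N] int32, packed the same way
//   scales:  [K / group_size, N * 8] in fp16 or bf16
// Returns a dense [K, N * 8] tensor in the dtype of `scales`. Assuming the op is exposed to
// Python under the same name, a call would look like: w = awq_dequantize(qweight, scales, qzeros).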
torch::Tensor awq_dequantize(torch::Tensor qweight, torch::Tensor scales, torch::Tensor qzeros) {
  int qweight_rows = qweight.size(0);
  int qweight_cols = qweight.size(1);
  int group_size = qweight_rows / scales.size(0);

  // One thread per packed int32; assumes qweight_rows and qweight_cols are multiples of 16.
  int x_num_threads = 16;
  int y_num_threads = 16;
  int x_blocks = qweight_cols / x_num_threads;
  int y_blocks = qweight_rows / y_num_threads;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(qweight));

  auto output_tensor_options = torch::TensorOptions().dtype(scales.dtype()).device(scales.device());
  at::Tensor output = torch::empty({qweight_rows, qweight_cols * 8}, output_tensor_options);

  auto _qweight = qweight.data_ptr<int>();
  auto _zeros = qzeros.data_ptr<int>();

  dim3 num_blocks(x_blocks, y_blocks);
  dim3 threads_per_block(x_num_threads, y_num_threads);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(
      scales.scalar_type() == at::ScalarType::Half || scales.scalar_type() == at::ScalarType::BFloat16,
      "awq_dequantize expects fp16 or bf16 scales");
  if (scales.scalar_type() == at::ScalarType::Half) {
    auto _scales = reinterpret_cast<half*>(scales.data_ptr<at::Half>());
    auto _output = reinterpret_cast<half*>(output.data_ptr<at::Half>());
    dequantize_weights<half>
        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
  } else {
    auto _scales = reinterpret_cast<__nv_bfloat16*>(scales.data_ptr<at::BFloat16>());
    auto _output = reinterpret_cast<__nv_bfloat16*>(output.data_ptr<at::BFloat16>());
    dequantize_weights<__nv_bfloat16>
        <<<num_blocks, threads_per_block, 0, stream>>>(_qweight, _scales, _zeros, _output, group_size, qweight_cols);
  }

  return output;
}