// quant_cuda_kernel.cu - SqueezeLLM 4-bit LUT-based matvec kernel and its
// PyTorch launcher.
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// c10/ATen CUDA helpers: current stream, device guard, half-tensor methods.
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>

// Output columns per thread block; also the number of threads per block
// (one thread per output column).
#define BLOCKWIDTH (128)
// Packed int32 rows of the weight matrix consumed by one thread block.
#define BLOCKHEIGHT4 (16)

namespace vllm {
namespace squeezellm {

// The kernel walks BLOCKHEIGHT4 packed rows of `mat` per block. Each row holds
// eight 4-bit codes, so a block must consume exactly the BLOCKWIDTH inputs
// staged in `blockvec` (BLOCKWIDTH / 2 half2 values). Keep the two in sync.
static_assert(BLOCKWIDTH == 8 * BLOCKHEIGHT4,
              "BLOCKWIDTH must equal 8 * BLOCKHEIGHT4");

// Reinterprets the bits of `i` as an unsigned int so that the packed 4-bit
// fields can be extracted with logical (zero-filling) right shifts.
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

// 4-bit matvec kernel (LUT-based).
//
// For every batch entry b and output column col this accumulates
//   mul[b][col] += sum_r sum_f LUT[col][field_f(mat[r][col])] * vec[b][8*r + f]
// where:
//   - field_f is the f-th 4-bit field of the int32 (f = 0..7, LSB first);
//   - LUT[col][v] is lookup_table[col * 16 + v].
// The sum is accumulated in half precision.
//
// Launch layout (matches squeezellm_gemm):
//   gridDim.x  = ceil(height / BLOCKHEIGHT4)  one block per BLOCKHEIGHT4 rows
//   gridDim.y  = ceil(width / BLOCKWIDTH)     one block per BLOCKWIDTH columns
//   blockDim.x = BLOCKWIDTH                   one thread per output column
//
// Static shared memory per block:
//   - BLOCKWIDTH / 2 half2 for the staged input slice;
//   - a 16 x BLOCKWIDTH half table holding each thread's column of the LUT.
//
// Partial sums from different row blocks are combined with atomicAdd, so
// `mul` is added into rather than overwritten (presumably zeroed by the
// caller). The CUDA path uses half2 atomicAdd, which requires compute
// capability 6.0+.
//
// Partial tiles are handled:
//   - threads with col >= width only join the barriers and the cooperative
//     load;
//   - the last row block reads only the rows that exist.
// NOTE(review): assumes each vec row has at least 8 * height entries and that
// width is even (two output columns share one half2/float2 of `mul`). TODO
// confirm against callers.
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table, int height, int width, int batch,
    int vec_height) {
  const int blockwidth2 = BLOCKWIDTH / 2;
  // One packed int32 row carries 8 codes, i.e. 4 half2 slots of blockvec.
  const int half2_per_row = 4;

  const int row = BLOCKHEIGHT4 * blockIdx.x;
  const int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // Threads past the last column must still reach every __syncthreads() and
  // help stage blockvec; they only skip the LUT load, the math and the store.
  const bool col_in_range = col < width;
  // The last row block may own fewer than BLOCKHEIGHT4 packed rows.
  const int rows_in_block =
      (height - row < BLOCKHEIGHT4) ? (height - row) : BLOCKHEIGHT4;
  // Never exceeds blockwidth2 thanks to the static_assert above.
  const int half2_in_block = rows_in_block * half2_per_row;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

  // Column `off` of deq2 holds the 16-entry LUT of this thread's output
  // column. No barrier is needed after filling it: each thread reads back
  // only its own column.
  __shared__ __half deq2[16][BLOCKWIDTH];
  const int off = threadIdx.x;
  if (col_in_range) {
    const int column_offset = col * 16;
    for (int val = 0; val < 16; val += 1) {
      deq2[val][off] = lookup_table[column_offset + val];
    }
  }

  for (int b = 0; b < batch; ++b) {
    // Barrier before restaging: other threads may still be reading blockvec
    // from the previous batch entry.
    __syncthreads();
    if (threadIdx.x < half2_in_block) {
      blockvec[threadIdx.x] =
          vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
              threadIdx.x];
    }
    __syncthreads();

    if (!col_in_range) {
      // Every thread runs the same number of batch iterations, so the
      // barriers above stay aligned even for threads that skip the math.
      continue;
    }

    __half res = __int2half_rd(0);
    int i = width * row + col;
    for (int r = 0; r < rows_in_block; ++r, i += width) {
      const unsigned int tmp1 = as_unsigned(mat[i]);
      const int k = r * half2_per_row;

#ifndef USE_ROCM
      half2 res2 = {};
      half2 tmp2 = {};
#else
      __half2 res2;
      __half2 tmp2;
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      // Nibbles are consumed LSB first, two per half2 slot: step j pairs
      // nibbles 2j and 2j+1 with blockvec[k + j].
#pragma unroll
      for (int j = 0; j < half2_per_row; ++j) {
        const unsigned int lut_index1 = (tmp1 >> (8 * j)) & 0xF;
        const unsigned int lut_index2 = (tmp1 >> (8 * j + 4)) & 0xF;
#ifndef USE_ROCM
        tmp2.x = deq2[lut_index1][off];
        tmp2.y = deq2[lut_index2][off];
#else
        tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
        tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
        res2 = __hfma2(tmp2, blockvec[k + j], res2);
      }

#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
                   res);
#endif
    }

    // Each element of `mul` packs two adjacent columns; set only this
    // thread's lane (col % 2) and leave the other at zero so the atomic add
    // does not disturb the neighbouring column.
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

}  // namespace squeezellm
}  // namespace vllm
// 4-bit matvec kernel (LUT-based): PyTorch entry point.
//
// Launches vllm::squeezellm::NUQ4MatMulKernel on the current CUDA stream of
// vec's device and accumulates the result into `mul`.
//
// Sizes and dtypes read here:
//   mat          : (height, width) int32, packed 4-bit codes
//   vec          : (batch, vec_height) half
//   mul          : half on CUDA, float on ROCm (reinterpreted as half2 / float2)
//   lookup_table : half, indexed as lookup_table[col * 16 + code]
// NOTE(review): tensors are assumed contiguous and on the same device; this is
// not checked here.
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table) {
  int height = mat.size(0);
  int width = mat.size(1);

  int batch = vec.size(0);
  int vec_height = vec.size(1);

  // Nothing to accumulate. Launching anyway would also fail: a zero grid
  // dimension is an invalid launch configuration.
  if (height == 0 || width == 0 || batch == 0) {
    return;
  }

  dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
              (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
  dim3 threads(BLOCKWIDTH);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
      reinterpret_cast<half2*>(vec.data_ptr<at::Half>()),
#else
      reinterpret_cast<__half2*>(vec.data_ptr<at::Half>()),
#endif
      mat.data_ptr<int>(),
#ifndef USE_ROCM
      reinterpret_cast<half2*>(mul.data_ptr<at::Half>()),
      reinterpret_cast<__half*>(lookup_table.data_ptr<at::Half>()),
#else
      reinterpret_cast<float2*>(mul.data_ptr<float>()),
      reinterpret_cast<__half*>(lookup_table.data_ptr<at::Half>()),
#endif
      height, width, batch, vec_height);

  // Kernel launches do not return errors directly; surface launch failures
  // (e.g. a bad configuration) here instead of at some later, unrelated call.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "squeezellm_gemm: NUQ4MatMulKernel launch failed: ",
              cudaGetErrorString(launch_err));
}

// Drop the launch-geometry macros so they do not leak past the end of this
// file (e.g. if it is ever included into another translation unit).
#undef BLOCKHEIGHT4
#undef BLOCKWIDTH