#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// half-precision tensor support
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16

namespace vllm {
namespace squeezellm {

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

// 4-bit matvec kernel (LUT-based)
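// Each 32-bit word of `mat` packs eight 4-bit codes. Every code indexes a
// 16-entry lookup table of fp16 centroids specific to one output column
// (SqueezeLLM non-uniform quantization). For each batch element the kernel
// accumulates, per output column, the dot product of the dequantized weight
// column with the input vector.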
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const  half2* __restrict__ vec,
#else
    const  __half2* __restrict__ vec,
#endif
    const    int* __restrict__ mat,
#ifndef USE_ROCM
           half2* __restrict__ mul,
#else
          float2* __restrict__ mul,
#endif
    const  __half* __restrict__ lookup_table,
    int height,
    int width,
    int batch,
    int vec_height
) {

  const int blockwidth2 = BLOCKWIDTH / 2;

  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col =  BLOCKWIDTH * blockIdx.y + threadIdx.x;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

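  // Stage the 16-entry dequantization LUT for this thread's output column
  // into shared memory, indexed as deq2[code][column-within-block].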
  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;
  int column_offset = col * 16;
  for (int val = 0; val < 16; val += 1) {
    int lut_index = column_offset + val;
    deq2[val][off] = lookup_table[lut_index];
  }

  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;

  unsigned int tmp1;
  unsigned int lut_index1, lut_index2;

  for (int b = 0; b < batch; ++b){
    i = width * row + col;
    res = __int2half_rd(0);
    k = 0;

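    // Cooperatively load the slice of the input vector consumed by this row
    // block (BLOCKWIDTH scalars, stored as BLOCKWIDTH / 2 half2 values).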
    __syncthreads();
    if (threadIdx.x < blockwidth2)
      blockvec[threadIdx.x] = vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + threadIdx.x];
    __syncthreads();

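    // Each iteration consumes one packed 32-bit word (eight 4-bit codes) and
    // issues four half2 FMAs against consecutive half2 elements of blockvec.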
    while (k < blockwidth2) {
      tmp1 = as_unsigned(mat[i]);

#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
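      // ROCm path: the half2 components are handled as raw 16-bit storage,
      // so values go through __half_as_ushort / __ushort_as_half bit-casts.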
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
#endif

      i += width;
      k += 4;
    }

    // The output is packed two adjacent columns per half2 element, so each
    // thread sets only its own half (selected by col % 2) and leaves the
    // other at zero.
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif

#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

} // namespace squeezellm
} // namespace vllm

// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table
) {
  int height = mat.size(0);
  int width = mat.size(1);

  int batch = vec.size(0);
  int vec_height = vec.size(1);

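  // One thread block per tile of BLOCKHEIGHT4 packed weight rows by
  // BLOCKWIDTH output columns, with BLOCKWIDTH threads per block.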
  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
#ifndef USE_ROCM
    (half2*) vec.data_ptr<at::Half>(),
#else
    (__half2*) vec.data_ptr<at::Half>(),
#endif
    mat.data_ptr<int>(),
#ifndef USE_ROCM
    (half2*) mul.data_ptr<at::Half>(),
    (__half*) lookup_table.data_ptr<at::Half>(),
#else
    (float2*) mul.data_ptr<float>(),
    (__half*) lookup_table.data_ptr<at::Half>(),
#endif
    height, width, batch, vec_height
  );
}
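
// A minimal host-side usage sketch (illustrative only; the shapes and the
// 4096 sizes below are assumptions, not part of this file). It assumes `vec`
// is a [batch, vec_height] fp16 tensor, `mat` is a [vec_height / 8, width]
// int32 tensor with eight 4-bit codes per word, `lookup_table` holds 16 fp16
// centroids per output column ([width, 16]), and `mul` is a zero-initialized
// [batch, width] output (fp16 on CUDA, fp32 on ROCm), since results are
// accumulated with atomicAdd:
//
//   auto opts_h = torch::dtype(torch::kHalf).device(torch::kCUDA);
//   auto opts_i = torch::dtype(torch::kInt).device(torch::kCUDA);
//   torch::Tensor vec = torch::randn({1, 4096}, opts_h);
//   torch::Tensor mat = torch::randint(0, 1 << 30, {4096 / 8, 4096}, opts_i);
//   torch::Tensor mul = torch::zeros({1, 4096}, opts_h);
//   torch::Tensor lut = torch::randn({4096, 16}, opts_h);
//   squeezellm_gemm(vec, mat, mul, lut);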

#undef BLOCKWIDTH
#undef BLOCKHEIGHT4