// quant_cuda_kernel.cu
// SqueezeLLM 4-bit LUT-based quantized matvec kernel and its PyTorch launcher.
#include <torch/all.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>
// Tile geometry shared by NUQ4MatMulKernel and its launcher:
//   BLOCKWIDTH   - threads per block (one output column per thread) and the
//                  number of fp16 input elements consumed per tile.
//   BLOCKHEIGHT4 - packed int32 rows of `mat` reduced per tile; each packed
//                  row holds eight 4-bit codes.
#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16

// One tile of BLOCKHEIGHT4 packed rows must cover exactly BLOCKWIDTH inputs.
static_assert(BLOCKHEIGHT4 * 8 == BLOCKWIDTH,
              "BLOCKHEIGHT4 packed rows must hold BLOCKWIDTH 4-bit codes");

namespace vllm {
namespace squeezellm {

// Returns the bit pattern of a packed int32 word as an unsigned int, so the
// shifts used to extract 4-bit codes operate on an unsigned value.
__device__ inline unsigned int as_unsigned(int i) {
  return static_cast<unsigned int>(i);
}

// 4-bit matvec kernel (LUT-based).
//
// Launch layout (as set up by squeezellm_gemm):
//   grid  = (ceil(height / BLOCKHEIGHT4), ceil(width / BLOCKWIDTH))
//   block = BLOCKWIDTH threads; thread x owns output column
//           col = BLOCKWIDTH * blockIdx.y + x.
// Each block reduces BLOCKHEIGHT4 packed int32 rows of `mat` (eight 4-bit
// codes per word) against the matching slice of `vec`, then atomically adds
// its partial result into `mul`. `mul` is therefore accumulated into, not
// overwritten; presumably the caller zero-initialises it (TODO confirm).
//
//   vec          : `batch` rows of `vec_height` fp16 values, read as half2.
//                  Assumes vec_height >= 8 * height (TODO confirm w/ caller).
//   mat          : `height` x `width` packed int32 codes, row-major.
//   mul          : `batch` rows of `width` outputs (half2 on CUDA, float2 on
//                  ROCm), addressed as b * width / 2 + col / 2.
//   lookup_table : 16 fp16 entries per output column, at col * 16 + code.
//
// Shared memory: BLOCKWIDTH/2 half2 for the vec slice, plus a 16 x BLOCKWIDTH
// fp16 table holding each thread's column LUT.
//
// Tail tiles are masked. Columns with col >= width never touch
// `lookup_table`, `mat` or `mul`. Packed rows at or past `height` are not read.
// Every thread still reaches every __syncthreads().
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table, int height, int width, int batch,
    int vec_height) {
  // Number of half2 entries of `vec` that cover one full tile.
  const int blockwidth2 = BLOCKWIDTH / 2;

  int row = BLOCKHEIGHT4 * blockIdx.x;              // first packed row of tile
  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;  // this thread's column

  // Each packed row consumes 4 half2 (8 fp16) entries of blockvec. Clamp to the
  // rows that actually exist; for a full tile this equals blockwidth2, which
  // reproduces the unmasked behaviour exactly.
  const int k_end = min(blockwidth2, 4 * (height - row));
  // Threads past the right edge of the matrix still join the barriers below,
  // but skip every global read of mat/lookup_table and every write to mul.
  const bool col_valid = col < width;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

  // Copy this thread's 16-entry LUT into column `off` of deq2. Each thread only
  // ever reads its own column back, so no barrier is needed after the copy.
  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;
  if (col_valid) {
    int column_offset = col * 16;
    for (int val = 0; val < 16; val += 1) {
      int lut_index = column_offset + val;
      deq2[val][off] = lookup_table[lut_index];
    }
  }

  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;

  unsigned int tmp1;
  unsigned int lut_index1, lut_index2;

  for (int b = 0; b < batch; ++b) {
    // Wait until all threads are done reading the previous batch's blockvec
    // before overwriting it, then wait for the new slice to be complete.
    __syncthreads();
    if (threadIdx.x < k_end)
      blockvec[threadIdx.x] =
          vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
              threadIdx.x];
    __syncthreads();

    // Safe after the barriers: every thread runs the same number of
    // __syncthreads() per iteration whether or not it does the math.
    if (!col_valid) continue;

    i = width * row + col;
    res = __int2half_rd(0);
    k = 0;

    // One iteration per packed row: eight 4-bit codes, dequantised in pairs
    // through the LUT and multiplied against four half2 entries of blockvec.
    while (k < k_end) {
      tmp1 = as_unsigned(mat[i]);

#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

      // Fold this packed row's two fp16 lanes into the running fp16 sum.
#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
                   res);
#endif

      i += width;
      k += 4;
    }

    // mul is addressed in pairs; col % 2 selects which lane of the pair this
    // thread contributes to, the other lane is added as zero.
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif

    // Several blocks (one per row tile) contribute to the same outputs, hence
    // the atomic accumulation.
#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

}  // namespace squeezellm
}  // namespace vllm
// Host launcher for the 4-bit LUT-based matvec kernel
// (vllm::squeezellm::NUQ4MatMulKernel).
//
// Shapes are taken from the tensors:
//   mat          : (height, width), int32 packed 4-bit codes.
//   vec          : (batch, vec_height), fp16.
//   mul          : output; fp16 on CUDA, fp32 on ROCm. The kernel accumulates
//                  into it atomically, so the caller must pre-initialise it
//                  (presumably to zero; TODO confirm).
//   lookup_table : fp16; 16 entries per output column.
// data_ptr<T>() raises if a tensor's dtype does not match T.
// The kernel runs on the current stream of vec's device.
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table) {
  int height = mat.size(0);
  int width = mat.size(1);

  int batch = vec.size(0);
  int vec_height = vec.size(1);

  // An empty problem has nothing to accumulate. Launching it would also
  // request a zero-sized grid, which is an invalid launch configuration.
  if (height == 0 || width == 0 || batch == 0) {
    return;
  }

  // One block per (BLOCKHEIGHT4 packed rows) x (BLOCKWIDTH columns) tile; the
  // kernel masks the partial tiles produced by the ceil-division.
  dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
              (width + BLOCKWIDTH - 1) / BLOCKWIDTH);
  dim3 threads(BLOCKWIDTH);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
      (half2*)vec.data_ptr<at::Half>(),
#else
      (__half2*)vec.data_ptr<at::Half>(),
#endif
      mat.data_ptr<int>(),
#ifndef USE_ROCM
      (half2*)mul.data_ptr<at::Half>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#else
      (float2*)mul.data_ptr<float>(),
      (__half*)lookup_table.data_ptr<at::Half>(),
#endif
      height, width, batch, vec_height);
}

#undef BLOCKWIDTH
#undef BLOCKHEIGHT4