// quant_cuda_kernel.cu
// SqueezeLLM 4-bit lookup-table (LUT) based matrix-vector kernel (originally contributed by chooper1).
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
// OptionalCUDAGuard: pins the kernel launch to the input tensor's device.
#include <c10/cuda/CUDAGuard.h>
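// Tile sizes. Each thread block has BLOCKWIDTH threads (one per output column).
// Each block covers BLOCKHEIGHT4 packed int32 rows of `mat`; each packed word holds eight 4-bit codes.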

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16

namespace vllm {
namespace squeezellm {

// Reinterpret a signed 32-bit word as unsigned, so that the right shifts used
// to extract 4-bit codes are logical (zero-filling) rather than arithmetic.
__device__ inline unsigned int as_unsigned(int i) {
  return static_cast<unsigned int>(i);
}

// 4-bit matvec kernel (LUT-based).
//
// For every batch index b and output column c this accumulates
//   mul[b][c] += sum_r vec[b][r] * lookup_table[c * 16 + code(r, c)]
// code(r, c) is the 4-bit code for unpacked row r, stored in bits
// 4*(r%8)..4*(r%8)+3 of mat[(r/8) * width + c].
//
// Launch layout (as configured by squeezellm_gemm):
//   gridDim.x  = ceil(height / BLOCKHEIGHT4)  -> one tile of packed rows
//   gridDim.y  = ceil(width  / BLOCKWIDTH)    -> one tile of output columns
//   blockDim.x = BLOCKWIDTH                   -> one thread per output column
// Shared memory is static:
//   - BLOCKWIDTH/2 half2 values for the vec slice of the current tile;
//   - a 16 x BLOCKWIDTH half table holding each column's LUT.
//
// Parameters:
//   vec          half2 view of a vec buffer whose row stride is vec_height
//                halves. The tile indexing expects vec_height == 8 * height
//                (one half per unpacked row).
//   mat          height x width packed int32 codes.
//   mul          output, one element per (b, column). It is updated with
//                atomicAdd: half2 on CUDA, per-component float on ROCm.
//                Results are added to the existing contents, so the caller
//                presumably zero-initializes it. NOTE(review): confirm.
//   lookup_table 16 half entries per output column.
//   batch        number of vec rows to process.
//
// Partial tiles (width % BLOCKWIDTH != 0 or height % BLOCKHEIGHT4 != 0) are
// masked. Out-of-range threads still reach every __syncthreads().
//
// Products are accumulated in half precision (__hfma2 / __hadd).
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const  half2* __restrict__ vec,
#else
    const  __half2* __restrict__ vec,
#endif
    const    int* __restrict__ mat,
#ifndef USE_ROCM
           half2* __restrict__ mul,
#else
          float2* __restrict__ mul,
#endif
    const  __half* __restrict__ lookup_table,
    int height,
    int width,
    int batch,
    int vec_height
) {

  const int blockwidth2 = BLOCKWIDTH / 2;

  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col =  BLOCKWIDTH * blockIdx.y + threadIdx.x;

  // The grid is rounded up in both dimensions, so the last tile may overhang
  // the matrix. Threads are masked rather than returned early, because every
  // thread must reach the __syncthreads() calls in the batch loop.
  const bool col_in_range = col < width;
  // Packed rows of this tile that actually exist (1..BLOCKHEIGHT4).
  const int rows_in_tile = min(BLOCKHEIGHT4, height - row);
  // half2 offset of this tile's slice within one vec row, and the number of
  // half2 values in a vec row.
  const int vec_tile_offset = (row / BLOCKHEIGHT4) * blockwidth2;
  const int vec_row_half2 = vec_height / 2;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

  // Per-column 16-entry LUT. Each thread writes and later reads only its own
  // column (index `off`), so no barrier is needed between fill and use.
  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;
  if (col_in_range) {
    int column_offset = col * 16;
    for (int val = 0; val < 16; val += 1) {
      int lut_index = column_offset + val;
      deq2[val][off] = lookup_table[lut_index];
    }
  }

  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;

  unsigned int tmp1;
  unsigned int lut_index1, lut_index2;

  for (int b = 0; b < batch; ++b){
    i = width * row + col;
    res = __int2half_rd(0);
    k = 0;

    // First barrier: nobody may still be reading blockvec from the previous
    // batch row. Second barrier: the new slice is visible to all threads.
    __syncthreads();
    // Entries past the end of the vec row belong to packed rows beyond
    // `height`. They are skipped below, so leaving them unloaded is safe.
    if (threadIdx.x < blockwidth2 &&
        vec_tile_offset + static_cast<int>(threadIdx.x) < vec_row_half2)
      blockvec[threadIdx.x] = vec[b * vec_height / 2 + vec_tile_offset + threadIdx.x];
    __syncthreads();

    // Out-of-range columns have no data to read and no output to write.
    if (!col_in_range) continue;

    // Each iteration consumes one packed int32 (8 codes) and 4 half2 vec
    // values. Stop after the last packed row that exists.
    while (k < 4 * rows_in_tile) {
      tmp1 = as_unsigned(mat[i]);

#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      // Codes 0 and 1 (bits 0-7) pair with blockvec[k + 0].
      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      // Codes 2 and 3 (bits 8-15) pair with blockvec[k + 1].
      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      // Codes 4 and 5 (bits 16-23) pair with blockvec[k + 2].
      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      // Codes 6 and 7 (bits 24-31) pair with blockvec[k + 3].
      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

      // Fold both lanes of the partial product into the scalar accumulator.
#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
#endif

      i += width;  // next packed row, same column
      k += 4;
    }

    // mul is addressed in pairs of columns. col%2 -> only set one of the two
    // values; the other lane adds zero.
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif

    // Several row-tiles (blockIdx.x) contribute to the same output, hence the
    // atomic add.
#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

} // namespace squeezellm
} // namespace vllm

// 4-bit matvec kernel (LUT-based): host launcher for
// vllm::squeezellm::NUQ4MatMulKernel.
//
//   vec          half tensor; dims (batch, vec_height) are read from size(0)
//                and size(1).
//   mat          int32 tensor of packed 4-bit codes; dims (height, width) are
//                read from size(0) and size(1).
//   mul          output. The kernel atomically adds into it, so its existing
//                contents are accumulated. It is viewed as half on CUDA and as
//                float on ROCm.
//   lookup_table half tensor, 16 entries per column of `mat`.
//
// NOTE(review): dtype, contiguity and same-device requirements are not checked
// here; raw data pointers are passed straight to the kernel.
void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table
) {
  int height = mat.size(0);
  int width = mat.size(1);

  int batch = vec.size(0);
  int vec_height = vec.size(1);

  // Nothing to compute for empty operands. A zero-sized grid would also be an
  // invalid launch configuration.
  if (height == 0 || width == 0 || batch == 0) {
    return;
  }

  // Ceil-div so partial tiles are covered; the kernel masks the overhang.
  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  // Launch on vec's device and on PyTorch's current stream for that device.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
    (half2*) vec.data_ptr<at::Half>(),
#else
    (__half2*) vec.data_ptr<at::Half>(),
#endif
    mat.data_ptr<int>(),
#ifndef USE_ROCM
    (half2*) mul.data_ptr<at::Half>(),
    (__half*) lookup_table.data_ptr<at::Half>(),
#else
    (float2*) mul.data_ptr<float>(),
    (__half*) lookup_table.data_ptr<at::Half>(),
#endif
    height, width, batch, vec_height
  );
}

#undef BLOCKWIDTH
#undef BLOCKHEIGHT4