gemm_kernels.cu 20.7 KB
Newer Older
James Fleming's avatar
James Fleming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/*
 * Modified by Neural Magic
 * Adapted from https://github.com/Vahe1994/AQLM
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
21
#include <torch/all.h>
James Fleming's avatar
James Fleming committed
22
23
24
25
26
27
28
29
30
31
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>

#include <iostream>
#include <cstdlib>

namespace vllm {
namespace aqlm {

// Matrix-vector product for the 1x16 scheme. Each warp handles one row of A:
// it reads the row's 16-bit codes, fetches the referenced int4 codebook
// entries (four half2 values each) straight from global memory, multiplies
// them with a tile of B staged in shared memory, and lane 0 stores the
// warp-reduced sum into C as a half.
__global__ void Code1x16MatVec(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int4* __restrict__ codebook, const int prob_m,
    const int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  const int lane_id = threadIdx.x % 32;
  const int words_per_row = prob_k / 8 / 8;
  int a_idx = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  const bool row_in_range = a_idx < prob_m;

  if (row_in_range) {
    // Move the codebook pointer forward once for every cumulative boundary
    // this row lies at or beyond.
    for (const int* boundary = &codebook_a_sizes.x; a_idx >= *boundary;
         ++boundary) {
      codebook += codebook_stride;
    }
  }

  const int out_idx = a_idx;
  a_idx = words_per_row * a_idx + lane_id;
  const int a_end = a_idx + words_per_row - lane_id;
  int b_offset = 0;

  __shared__ int4 sh_b[32 * 9];
  float acc = 0;

  int remaining = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
  while (remaining--) {
    // The B tile is stored with a stride of 9 int4 per group of 8 so that
    // reads do not collide on shared-memory banks.
    __syncthreads();
    for (int t = threadIdx.x; t < 32 * 8; t += blockDim.x) {
      const int src = b_offset + t;
      if (src < prob_k / 8) sh_b[9 * (t / 8) + t % 8] = B[src];
    }
    __syncthreads();
    b_offset += 32 * 8;

    if (row_in_range && a_idx < a_end) {
      const uint16_t* codes = reinterpret_cast<const uint16_t*>(&A[a_idx]);
      const int tile_base = 9 * lane_id;
#pragma unroll
      for (int c = 0; c < 8; c++) {
        uint32_t entry[4];
        // The .cg load skips L1: streaming codebook entries through it does
        // not help and costs more than 2x in speed.
        asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
                     : "=r"(entry[0]), "=r"(entry[1]), "=r"(entry[2]),
                       "=r"(entry[3])
                     : "l"((void*)&codebook[codes[c]]));

        half2* w = reinterpret_cast<half2*>(&entry);
        half2* x = reinterpret_cast<half2*>(&sh_b[tile_base + c]);
        half2 partial = {};
#pragma unroll
        for (int j = 0; j < 4; j++) partial = __hfma2(w[j], x[j], partial);
        acc += __half2float(partial.x) + __half2float(partial.y);
      }
      a_idx += 32;
    }
  }

  if (row_in_range) {
    // Tree reduction across the warp; lane 0 ends up with the row total.
#pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
      acc += __shfl_down_sync(0xffffffff, acc, offset);
    }
    if (lane_id == 0) {
      reinterpret_cast<__half*>(C)[out_idx] = __float2half(acc);
    }
  }
}

// Matrix-vector product for the 2x8 scheme. The block first copies both
// 256-entry codebooks into dynamic shared memory (each entry replicated into
// 8 slots), then each warp handles one row of A: every pair of 8-bit codes
// selects one entry per codebook, the two entries are added and multiplied
// with a tile of B staged in shared memory. Lane 0 stores the warp-reduced
// sum into C as a half.
__global__ void Code2x8MatVec(
    const int4* __restrict__ A, const int4* __restrict__ B,
    int4* __restrict__ C, const int4* __restrict__ codebook, int prob_m,
    int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.

) {
  const int lane_id = threadIdx.x % 32;
  const int words_per_row = prob_k / 8 / 8;
  int a_idx = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  const bool row_in_range = a_idx < prob_m;

  if (row_in_range) {
    // Move the codebook pointer forward once for every cumulative boundary
    // this row lies at or beyond.
    for (const int* boundary = &codebook_a_sizes.x; a_idx >= *boundary;
         ++boundary) {
      codebook += codebook_stride;
    }
  }

  const int out_idx = a_idx;
  a_idx = words_per_row * a_idx + lane_id;
  const int a_end = a_idx + words_per_row - lane_id;
  const int slot = threadIdx.x % 8;
  int b_offset = 0;

  // Dynamic shared memory layout: [B tile: 32 * 9][codebook 0: 256 * 8]
  // [codebook 1: 256 * 8], all counted in int4.
  extern __shared__ int4 sh[];
  int4* sh_b = sh;
  int4* sh_code = sh_b + 32 * 9;
  int4* sh_code0 = sh_code;
  int4* sh_code1 = sh_code + 256 * 8;

  for (int e = threadIdx.x; e < 2 * 256; e += blockDim.x) {
    const int4 entry = codebook[e];
    // Fill all 8 copies of this entry, starting from a per-thread rotated
    // slot.
#pragma unroll
    for (int j = 0; j < 8; j++) sh_code[8 * e + (j + slot) % 8] = entry;
  }
  __syncthreads();

  float acc = 0;

  int remaining = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
  while (remaining--) {
    // The B tile is stored with a stride of 9 int4 per group of 8 so that
    // reads do not collide on shared-memory banks.
    __syncthreads();
    for (int t = threadIdx.x; t < 32 * 8; t += blockDim.x) {
      const int src = b_offset + t;
      if (src < prob_k / 8) sh_b[9 * (t / 8) + t % 8] = B[src];
    }
    __syncthreads();
    b_offset += 32 * 8;

    if (row_in_range && a_idx < a_end) {
      const uint8_t* codes = reinterpret_cast<const uint8_t*>(&A[a_idx]);
      const int tile_base = 9 * lane_id;
#pragma unroll
      for (int c = 0; c < 8; c++) {
        half2* w0 =
            reinterpret_cast<half2*>(&sh_code0[8 * codes[2 * c + 0] + slot]);
        half2* w1 =
            reinterpret_cast<half2*>(&sh_code1[8 * codes[2 * c + 1] + slot]);
        half2* x = reinterpret_cast<half2*>(&sh_b[tile_base + c]);
        half2 partial = {};
#pragma unroll
        for (int j = 0; j < 4; j++) {
          partial = __hfma2(__hadd2(w0[j], w1[j]), x[j], partial);
        }
        acc += __half2float(partial.x) + __half2float(partial.y);
      }
      a_idx += 32;
    }
  }

  if (row_in_range) {
    // Tree reduction across the warp; lane 0 ends up with the row total.
#pragma unroll
    for (int offset = 16; offset > 0; offset /= 2) {
      acc += __shfl_down_sync(0xffffffff, acc, offset);
    }
    if (lane_id == 0) {
      reinterpret_cast<__half*>(C)[out_idx] = __float2half(acc);
    }
  }
}

// Expands 1x16 codes into weights. Each warp handles one row of A: every
// lane reads int4 words holding eight 16-bit codes, loads the referenced int4
// codebook entry for each code from global memory, and stores it to C at
// index (code word index * 8 + code position).
__global__ void Code1x16Dequant(
    const int4* __restrict__ A, int4* __restrict__ C,
    const int4* __restrict__ codebook, int prob_m, int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long, sums to m.
    const int codebook_stride     // as int4
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // Advance to the codebook that covers this row: step once for every
    // cumulative boundary the row lies at or beyond.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
  }

  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;

  // The output index is derived directly from a_gl_rd below, so no separate
  // write cursor is kept.
  int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
  while (iters--) {
    if (pred && a_gl_rd < a_gl_end) {
      const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
#pragma unroll
      for (int i = 0; i < 8; i++) {
        int4 chunk;
        auto dec = reinterpret_cast<uint32_t*>(&chunk);
        // We bypass the L1 cache to avoid massive amounts of memory streaming
        // that doesn't actually help us; this brings > 2x speedup.
        asm volatile("ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
                     : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
                     : "l"((void*)&codebook[enc[i]]));

        C[a_gl_rd * 8 + i] = chunk;
      }
    }
    // Each lane strides over the row in steps of one warp.
    a_gl_rd += 32;
  }
}

// Expands 2x8 codes into weights. The block first copies both 256-entry
// codebooks into dynamic shared memory (each entry replicated into 8 slots).
// Each warp then handles one row of A: every pair of 8-bit codes selects one
// entry per codebook, the two entries are added as half2 values, and the sum
// is stored to C at index (code word index * 8 + pair position).
__global__ void Code2x8Dequant(
    const int4* __restrict__ A, int4* __restrict__ C,
    const int4* __restrict__ codebook, int prob_m, int prob_k,
    const int4
        codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                           // most 3 long, corresponds to cols.
    const int codebook_stride  // as int4
) {
  int a_gl_stride = prob_k / 8 / 8;
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred) {
    // Advance to the codebook that covers this row: step once for every
    // cumulative boundary the row lies at or beyond.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size) {
      codebook += codebook_stride;
      ++codebook_size;
    }
  }

  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;
  int lane = threadIdx.x % 8;

  // Dynamic shared memory layout: [codebook 0: 256 * 8][codebook 1: 256 * 8],
  // counted in int4. The output index is derived directly from a_gl_rd below,
  // so no separate write cursor is kept.
  extern __shared__ int4 sh[];
  int4* sh_code = sh;
  int4* sh_code0 = sh_code;
  int4* sh_code1 = sh_code + 256 * 8;

  for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
    int4 dec = codebook[i];
    // Fill all 8 copies of this entry, starting from a per-thread rotated
    // slot.
#pragma unroll
    for (int j = 0; j < 8; j++) sh_code[8 * i + (j + lane) % 8] = dec;
  }
  __syncthreads();

  int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
  while (iters--) {
    if (pred && a_gl_rd < a_gl_end) {
      const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
#pragma unroll
      for (int i = 0; i < 8; i++) {
        int4 chunk;
        half2* a0 =
            reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
        half2* a1 =
            reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
#pragma unroll
        for (int j = 0; j < 4; j++)
          reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
        C[a_gl_rd * 8 + i] = chunk;
      }
    }
    // Each lane strides over the row in steps of one warp.
    a_gl_rd += 32;
  }
}

297
// Integer division of a by b, rounded up by adding (b - 1) to the numerator
// before dividing.
inline int ceildiv(int a, int b) {
  const int padded_numerator = a + b - 1;
  return padded_numerator / b;
}
James Fleming's avatar
James Fleming committed
298
299
300

const int THREAD_M = 16;

301
302
303
304
305
// Launches Code1x16MatVec on the current CUDA stream.
//
// One warp is assigned per row of the prob_m-row problem. The number of warps
// per block (thread_m) is the smallest value obtained by spreading the rows
// over an increasing number of waves across all SMs of the current device,
// capped at THREAD_M.
//
// Raises (via TORCH_CHECK) if a CUDA runtime call or the launch itself fails.
void code1x16_matvec_cuda(const void* __restrict__ A,
                          const void* __restrict__ B, void* __restrict__ C,
                          const void* __restrict__ codebook, int prob_m,
                          int prob_k, const int4 codebook_a_sizes,
                          const int codebook_stride) {
  // Nothing to compute; this also keeps thread_m from becoming 0, which would
  // make the block count below divide by zero.
  if (prob_m <= 0) return;

  // Query the device the caller made current rather than assuming device 0.
  int device = 0;
  cudaError_t status = cudaGetDevice(&device);
  TORCH_CHECK(status == cudaSuccess,
              "AQLM: cudaGetDevice failed: ", cudaGetErrorString(status));
  int sms = 0;
  status =
      cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(status == cudaSuccess && sms > 0,
              "AQLM: failed to query the multiprocessor count: ",
              cudaGetErrorString(status));

  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  // The dynamic shared-memory request is kept from the original launch
  // configuration; the kernel stages its B tile in a statically sized
  // __shared__ array.
  Code1x16MatVec<<<blocks, threads, 16 * 32 * 9, stream>>>(
      (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
      prob_k, codebook_a_sizes, codebook_stride);
  status = cudaGetLastError();
  TORCH_CHECK(status == cudaSuccess, "AQLM: Code1x16MatVec launch failed: ",
              cudaGetErrorString(status));
}

323
324
325
326
327
// Launches Code2x8MatVec on the current CUDA stream.
//
// One warp is assigned per row of the prob_m-row problem. The number of warps
// per block (thread_m) is the smallest value obtained by spreading the rows
// over an increasing number of waves across all SMs of the current device,
// capped at THREAD_M. The kernel's dynamic shared memory holds two 256-entry
// codebooks replicated 8 times plus a 32 * 9 int4 tile of B.
//
// Raises (via TORCH_CHECK) if a CUDA runtime call or the launch itself fails.
void code2x8_matvec_cuda(const void* __restrict__ A, const void* __restrict__ B,
                         void* __restrict__ C,
                         const void* __restrict__ codebook, int prob_m,
                         int prob_k, const int4 codebook_a_sizes,
                         const int codebook_stride) {
  // Nothing to compute; this also keeps thread_m from becoming 0, which would
  // make the block count below divide by zero.
  if (prob_m <= 0) return;

  // Query the device the caller made current rather than assuming device 0.
  int device = 0;
  cudaError_t status = cudaGetDevice(&device);
  TORCH_CHECK(status == cudaSuccess,
              "AQLM: cudaGetDevice failed: ", cudaGetErrorString(status));
  int sms = 0;
  status =
      cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(status == cudaSuccess && sms > 0,
              "AQLM: failed to query the multiprocessor count: ",
              cudaGetErrorString(status));

  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  // Bytes: 16 per int4 * (two codebooks of 256 entries x 8 copies + B tile).
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  // This exceeds the default dynamic shared-memory limit, so opt in first.
  status = cudaFuncSetAttribute(
      Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  TORCH_CHECK(status == cudaSuccess,
              "AQLM: failed to raise the dynamic shared memory limit: ",
              cudaGetErrorString(status));

  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code2x8MatVec<<<blocks, threads, shared, stream>>>(
      (const int4*)A, (const int4*)B, (int4*)C, (const int4*)codebook, prob_m,
      prob_k, codebook_a_sizes, codebook_stride);
  status = cudaGetLastError();
  TORCH_CHECK(status == cudaSuccess, "AQLM: Code2x8MatVec launch failed: ",
              cudaGetErrorString(status));
}

// Launches Code1x16Dequant on the current CUDA stream.
//
// One warp is assigned per row of the prob_m-row problem. The number of warps
// per block (thread_m) is the smallest value obtained by spreading the rows
// over an increasing number of waves across all SMs of the current device,
// capped at THREAD_M.
//
// Raises (via TORCH_CHECK) if a CUDA runtime call or the launch itself fails.
void code1x16_dequant_cuda(
    const void* __restrict__ A, void* __restrict__ C,
    const void* __restrict__ codebook, int prob_m, int prob_k,
    const int4 codebook_a_sizes,  // cumulative sizes of A spanning each
                                  // codebook, at most 3 long.
    const int codebook_stride     // as int4.
) {
  // Nothing to compute; this also keeps thread_m from becoming 0, which would
  // make the block count below divide by zero.
  if (prob_m <= 0) return;

  // Query the device the caller made current rather than assuming device 0.
  int device = 0;
  cudaError_t status = cudaGetDevice(&device);
  TORCH_CHECK(status == cudaSuccess,
              "AQLM: cudaGetDevice failed: ", cudaGetErrorString(status));
  int sms = 0;
  status =
      cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(status == cudaSuccess && sms > 0,
              "AQLM: failed to query the multiprocessor count: ",
              cudaGetErrorString(status));

  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16Dequant<<<blocks, threads, 0, stream>>>(
      (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
      codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                         // most 3 long.
      codebook_stride    // as int4.
  );
  status = cudaGetLastError();
  TORCH_CHECK(status == cudaSuccess, "AQLM: Code1x16Dequant launch failed: ",
              cudaGetErrorString(status));
}

// Dequantizes the code and codebook into weights.
376
377
378
379
380
381
382
// Launches Code2x8Dequant on the current CUDA stream.
//
// One warp is assigned per row of the prob_m-row problem. The number of warps
// per block (thread_m) is the smallest value obtained by spreading the rows
// over an increasing number of waves across all SMs of the current device,
// capped at THREAD_M.
//
// Raises (via TORCH_CHECK) if a CUDA runtime call or the launch itself fails.
void code2x8_dequant_cuda(
    const void* __restrict__ A, void* __restrict__ C,
    const void* __restrict__ codebook, int prob_m, int prob_k,
    const int4
        codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at
                           // most 3 long, corresponds to cols.
    const int codebook_stride  // as int4
) {
  // Nothing to compute; this also keeps thread_m from becoming 0, which would
  // make the block count below divide by zero.
  if (prob_m <= 0) return;

  // Query the device the caller made current rather than assuming device 0.
  int device = 0;
  cudaError_t status = cudaGetDevice(&device);
  TORCH_CHECK(status == cudaSuccess,
              "AQLM: cudaGetDevice failed: ", cudaGetErrorString(status));
  int sms = 0;
  status =
      cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(status == cudaSuccess && sms > 0,
              "AQLM: failed to query the multiprocessor count: ",
              cudaGetErrorString(status));

  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  // Same byte count as the original launch configuration: 16 bytes per int4
  // times (two codebooks of 256 entries x 8 copies + 32 * 9).
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  // This exceeds the default dynamic shared-memory limit, so opt in first.
  status = cudaFuncSetAttribute(
      Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared);
  TORCH_CHECK(status == cudaSuccess,
              "AQLM: failed to raise the dynamic shared memory limit: ",
              cudaGetErrorString(status));
  Code2x8Dequant<<<blocks, threads, shared, stream>>>(
      (const int4*)A, (int4*)C, (const int4*)codebook, prob_m, prob_k,
      codebook_a_sizes, codebook_stride);
  status = cudaGetLastError();
  TORCH_CHECK(status == cudaSuccess, "AQLM: Code2x8Dequant launch failed: ",
              cudaGetErrorString(status));
}

405
// Distance between consecutive slices along dimension 0 of `codebooks`,
// converted from elements to bytes and then expressed in units of int4.
int codebook_stride(const torch::Tensor& codebooks) {
  const auto bytes_per_slice = codebooks.stride(0) * codebooks.element_size();
  return bytes_per_slice / sizeof(int4);
}

// Tensor-level entry point for the 1x16 matrix-vector product. Makes A's
// device current, takes the row count from C and the reduction length from B,
// and forwards raw data pointers to the CUDA launcher.
void code1x16_matvec(
    const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C,
    const torch::Tensor& codebook,
    const int4 codebook_a_sizes  // cumulative sizes of A spanning each
                                 // codebook, at most 3 long.
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));

  const int rows = C.size(0);
  const int depth = B.size(0);
  const int stride_in_int4 = codebook_stride(codebook);

  code1x16_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
                       codebook.data_ptr(), rows, depth, codebook_a_sizes,
                       stride_in_int4);
}

424
425
426
427
428
429
// Applies the 1x16 quantized weights to a batch of input vectors.
//
// `input` is flattened to 2-D over its last dimension and one matrix-vector
// product is issued per flattened row. The result is multiplied by the
// flattened `scales`, `bias` is added when present, and the output takes
// `input`'s leading dimensions with the last dimension replaced by
// out_features (= codes.size(0) * codebooks.size(2)).
torch::Tensor code1x16_matmat(const torch::Tensor& input,
                              const torch::Tensor& codes,
                              const torch::Tensor& codebooks,
                              const torch::Tensor& scales,
                              const int4 codebook_a_sizes,
                              const std::optional<torch::Tensor>& bias) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  // The squeezed view of the codes is the same for every row, so build it
  // once instead of once per iteration.
  const auto squeezed_codes = codes.squeeze(2);
  const int64_t num_vectors = flat_input.size(0);
  for (int64_t i = 0; i < num_vectors; ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code1x16_matvec(squeezed_codes, input_vec, output_vec, codebooks,
                    codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);

  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }

  // Spell out the last dimension rather than using -1: with an empty batch
  // the inferred dimension would be ambiguous and reshape would throw.
  auto output_sizes = input_sizes.vec();
  output_sizes.pop_back();
  output_sizes.push_back(out_features);
  auto output = flat_output.reshape(output_sizes);
  return output;
}

456
457
458
// Tensor-level entry point for the 2x8 matrix-vector product. Makes A's
// device current, takes the row count from C and the reduction length from B,
// and forwards raw data pointers to the CUDA launcher. The stride handed to
// the launcher is twice the per-slice codebook stride.
void code2x8_matvec(const torch::Tensor& A, const torch::Tensor& B,
                    torch::Tensor& C, const torch::Tensor& codebook,
                    const int4 codebook_a_sizes) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));

  const int rows = C.size(0);
  const int depth = B.size(0);
  const int stride_in_int4 = 2 * codebook_stride(codebook);

  code2x8_matvec_cuda(A.data_ptr(), B.data_ptr(), C.data_ptr(),
                      codebook.data_ptr(), rows, depth, codebook_a_sizes,
                      stride_in_int4);
}

467
468
469
470
471
472
// Applies the 2x8 quantized weights to a batch of input vectors.
//
// `input` is flattened to 2-D over its last dimension and one matrix-vector
// product is issued per flattened row. The result is multiplied by the
// flattened `scales`, `bias` is added when present, and the output takes
// `input`'s leading dimensions with the last dimension replaced by
// out_features (= codes.size(0) * codebooks.size(2)).
torch::Tensor code2x8_matmat(const torch::Tensor& input,
                             const torch::Tensor& codes,
                             const torch::Tensor& codebooks,
                             const torch::Tensor& scales,
                             const int4 codebook_a_sizes,
                             const std::optional<torch::Tensor>& bias) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty(
      {flat_input.size(0), out_features},
      torch::TensorOptions().dtype(input.dtype()).device(input.device()));

  // The squeezed view of the codes is the same for every row, so build it
  // once instead of once per iteration.
  const auto squeezed_codes = codes.squeeze(2);
  const int64_t num_vectors = flat_input.size(0);
  for (int64_t i = 0; i < num_vectors; ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code2x8_matvec(squeezed_codes, input_vec, output_vec, codebooks,
                   codebook_a_sizes);
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }

  // Spell out the last dimension rather than using -1: with an empty batch
  // the inferred dimension would be ambiguous and reshape would throw.
  auto output_sizes = input_sizes.vec();
  output_sizes.pop_back();
  output_sizes.push_back(out_features);
  auto output = flat_output.reshape(output_sizes);
  return output;
}

// Accumulate the partition sizes.
499
// Converts per-partition sizes into running totals packed into an int4
// (x, y, z, w).
//
// The tensor must hold between 1 and 4 entries along dimension 0; this is
// enforced with TORCH_CHECK because the result has exactly four slots and the
// kernels scan these boundaries until one exceeds the row index. Slots beyond
// the provided partitions are filled with 10x the final total so that no row
// index reaches them.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) {
  const int64_t num_partitions = codebook_partition_sizes.size(0);
  TORCH_CHECK(num_partitions >= 1 && num_partitions <= 4,
              "AQLM expects between 1 and 4 codebook partition sizes, got ",
              num_partitions);

  // Build the totals in a plain array, then copy them into the int4 members.
  int totals[4];
  int running = 0;
  int i = 0;
  for (; i < num_partitions; ++i) {
    running += codebook_partition_sizes[i].item<int>();
    totals[i] = running;
  }
  // fill in the rest with unreachable.
  for (; i < 4; ++i) {
    totals[i] = running * 10;
  }

  int4 cumulative_sizes;
  cumulative_sizes.x = totals[0];
  cumulative_sizes.y = totals[1];
  cumulative_sizes.z = totals[2];
  cumulative_sizes.w = totals[3];
  return cumulative_sizes;
}

516
517
}  // namespace aqlm
}  // namespace vllm
James Fleming's avatar
James Fleming committed
518

519
520
521
522
523
524
525
// Entry point for the AQLM quantized matmul.
//
// Dispatches on the codebook layout: one codebook per partition with 2^16
// entries goes to the 1x16 path, two codebooks per partition with 2^8 entries
// go to the 2x8 path. Any other combination raises via TORCH_CHECK, as does
// an empty `codebook_partition_sizes`.
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
                        const torch::Tensor& codebooks,
                        const torch::Tensor& scales,
                        const torch::Tensor& codebook_partition_sizes,
                        const std::optional<torch::Tensor>& bias) {
  // The partition count is used as a divisor below; reject zero up front.
  TORCH_CHECK(codebook_partition_sizes.size(0) > 0,
              "AQLM requires at least one codebook partition size.");

  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);

  if (nbooks == 1 && entries == (1 << 16)) {
    return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales,
                                       cumulative_sizes, bias);
  }
  if (nbooks == 2 && entries == (1 << 8)) {
    return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales,
                                      cumulative_sizes, bias);
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.");
  return {};
}

544
545
546
547
548
// Expands AQLM codes into a dense weight tensor.
//
// Returns a tensor of shape {codes.size(0), codes.size(1) * 8} with the dtype
// and device of `codebooks`. Dispatches on the codebook layout: one codebook
// per partition with 2^16 entries uses the 1x16 kernel, two codebooks per
// partition with 2^8 entries use the 2x8 kernel. Any other combination raises
// via TORCH_CHECK, as does an empty `codebook_partition_sizes`.
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
                           const torch::Tensor& codebooks,
                           const torch::Tensor& codebook_partition_sizes) {
  // The partition count is used as a divisor below; reject zero up front.
  TORCH_CHECK(codebook_partition_sizes.size(0) > 0,
              "AQLM requires at least one codebook partition size.");

  int4 cumulative_sizes =
      vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0);
  int const entries = codebooks.size(1);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));

  auto in_features = codes.size(1) * 8;
  auto out_features = codes.size(0);

  // Compare, do not assign: the partition sizes are expected to add up to the
  // number of output rows.
  assert(out_features == codebook_partition_sizes.sum().item<int>());

  auto weights = torch::empty({out_features, in_features},
                              torch::TensorOptions()
                                  .dtype(codebooks.dtype())
                                  .device(codebooks.device()));

  if (nbooks == 1 && entries == (1 << 16)) {
    vllm::aqlm::code1x16_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                      codebooks.data_ptr(), out_features,
                                      in_features, cumulative_sizes,
                                      vllm::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
    // and not consistent with gemv implementation.) weights *=
    // scales.index({"...", 0, 0});

    return weights;
  }

  if (nbooks == 2 && entries == (1 << 8)) {
    vllm::aqlm::code2x8_dequant_cuda(codes.data_ptr(), weights.data_ptr(),
                                     codebooks.data_ptr(), out_features,
                                     in_features, cumulative_sizes,
                                     vllm::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish slower
    // and not consistent with gemv implementation) weights *=
    // scales.index({"...", 0, 0});

    return weights;
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries,
              " entries is not currently supported.");
  return {};
}