gemm_kernels.cu 19.8 KB
Newer Older
James Fleming's avatar
James Fleming committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
/*
 * Modified by Neural Magic
 * Adapted from https://github.com/Vahe1994/AQLM
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *         http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/cuda/CUDAGuard.h>

#include <iostream>
#include <cstdlib>


namespace vllm {
namespace aqlm {

// Matrix-vector product for the "1x16" format: every 16-bit code in A selects one
// int4-sized codebook entry (read below as 4 half2 values), which is multiplied
// against the matching 8 halves of B. One warp produces one output element of C.
//
// Layout as used by the indexing below:
//   - warp w of block b handles row (blockDim.x / 32) * b + w; rows >= prob_m are idle.
//   - a row of A spans prob_k / 64 int4 words (8 codes per word, 8 halves per code).
//   - B is staged through a statically sized, padded shared-memory tile.
// NOTE(review): prob_k is presumably a multiple of 64 and all pointers 16-byte
// aligned (they are accessed as int4) — confirm against the callers.
__global__ void Code1x16MatVec(
  const int4* __restrict__ A,
  const int4* __restrict__ B,
        int4* __restrict__ C,
  const int4* __restrict__ codebook,
  const int prob_m,
  const int prob_k,
  const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long.
  const int codebook_stride // as int4.
) {
  // Number of int4 words per row of A.
  int a_gl_stride = prob_k / 8 / 8;
  // Row handled by this warp (one warp per row).
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred)
  {
    // Advance to the codebook that covers this row. This is easy because each row
    // uses exactly one codebook. The fields of codebook_a_sizes are walked in
    // order (.x, .y, ...) as cumulative row bounds.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size)
    {
        codebook += codebook_stride;
        ++codebook_size;
    }
  }

  int b_gl_rd = 0;
  // Output index: the row number, captured before a_gl_rd is turned into a word index.
  int c_gl_wr = a_gl_rd;
  // From here on a_gl_rd is this lane's int4 index into A; lanes are interleaved
  // across the row with a step of 32 words.
  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  // One past the last int4 word of this row.
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;

  // Tile of B: 32 groups of 8 int4 words, each group padded to 9 words.
  __shared__ int4 sh_b[32 * 9];
  float res = 0;

  // Each iteration consumes up to 32 * 8 int4 words of B.
  int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
  while (iters--) {
    // We pad shared memory to avoid bank conflicts during reads
    // The barriers are reached by every thread of the block, including idle warps:
    // the first keeps the tile from being overwritten while still in use, the
    // second publishes the freshly loaded tile.
    __syncthreads();
    for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
      if (b_gl_rd + i < prob_k / 8)
        sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
    }
    __syncthreads();
    b_gl_rd += 32 * 8;

    // Each lane reads its own padded group of 8 words from the tile.
    int b_sh_rd = 9 * (threadIdx.x % 32);
    if (pred && a_gl_rd < a_gl_end) {
      // The 8 codes packed into this lane's int4 word of A.
      const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        uint32_t dec[4];
        // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
        // actually help us; this brings > 2x speedup.
        asm volatile (
          "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
          : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
          : "l"((void*) &codebook[enc[i]])
        );
        half2* a = reinterpret_cast<half2*>(&dec);
        half2* b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
        // Dot product of the decoded entry with 8 halves of B, in half2 precision;
        // the partial result is then accumulated in float.
        half2 res2 = {};
        #pragma unroll
        for (int j = 0; j < 4; j++)
          res2 = __hfma2(a[j], b[j], res2);
        res += __half2float(res2.x) + __half2float(res2.y);
        b_sh_rd++;
      }
      a_gl_rd += 32;
    }
  }

  // pred is uniform within a warp (it only depends on threadIdx.x / 32), so all
  // 32 lanes named in the shuffle mask take part in the reduction.
  if (pred) {
    #pragma unroll
    for (int i = 16; i > 0; i /= 2)
      res += __shfl_down_sync(0xffffffff, res, i);
    // Lane 0 holds the warp total and writes the single half-precision output.
    if (threadIdx.x % 32 == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}

// Matrix-vector product for the "2x8" format: every pair of 8-bit codes in A
// selects one entry from each of two 256-entry codebooks; the two entries are
// summed and multiplied against the matching 8 halves of B. One warp produces one
// output element of C.
//
// Layout as used by the indexing below:
//   - warp w of block b handles row (blockDim.x / 32) * b + w; rows >= prob_m are idle.
//   - dynamic shared memory: 32 * 9 int4 for the padded B tile, followed by
//     2 * 256 * 8 int4 for the two codebooks (each entry stored 8 times).
//   - codebook_stride is the distance, in int4, between the codebook pairs of two
//     consecutive partitions.
// NOTE(review): prob_k is presumably a multiple of 64 and all pointers 16-byte
// aligned (they are accessed as int4) — confirm against the callers.
__global__ void Code2x8MatVec(
  const int4* __restrict__ A,
  const int4* __restrict__ B,
        int4* __restrict__ C,
  const int4* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long.
  const int codebook_stride // as int4.

) {
  const int warps_per_block = blockDim.x / 32;
  const int block_first_row = warps_per_block * blockIdx.x;
  // Block-uniform exit, so no thread is left waiting at a barrier.
  if (block_first_row >= prob_m) return;

  const int warp_lane = threadIdx.x % 32;
  const int row = block_first_row + threadIdx.x / 32;
  const bool pred = row < prob_m;

  // Cumulative row bounds of the codebook partitions.
  const int part_bounds[4] = {
    codebook_a_sizes.x, codebook_a_sizes.y, codebook_a_sizes.z, codebook_a_sizes.w
  };
  // Partition of this warp's row.
  int row_part = 0;
  while (row_part < 3 && row >= part_bounds[row_part]) ++row_part;
  // Range of partitions touched by the valid rows of this block (block-uniform).
  int part = 0;
  while (part < 3 && block_first_row >= part_bounds[part]) ++part;
  const int block_end_row =
    (block_first_row + warps_per_block < prob_m) ? block_first_row + warps_per_block : prob_m;
  int last_part = part;
  while (last_part < 3 && block_end_row - 1 >= part_bounds[last_part]) ++last_part;

  // Number of int4 words per row of A.
  const int a_gl_stride = prob_k / 8 / 8;
  // This lane's int4 index into A; lanes are interleaved with a step of 32 words.
  int a_gl_rd = a_gl_stride * row + warp_lane;
  // One past the last int4 word of this row.
  const int a_gl_end = a_gl_rd + a_gl_stride - warp_lane;
  const int c_gl_wr = row;
  // Selects which of the 8 stored copies of a codebook entry this thread reads.
  const int lane = threadIdx.x % 8;

  extern __shared__ int4 sh[];
  int4* sh_b = sh;
  int4* sh_code = sh_b + 32 * 9;
  int4* sh_code0 = sh_code;
  int4* sh_code1 = sh_code + 256 * 8;

  float res = 0;

  // The shared codebook copy is block-wide, so it must come from a single
  // partition at a time. Staging it through each thread's own per-row pointer
  // (as was done before) mixes codebooks whenever the warps of a block fall into
  // different partitions, or when out-of-range tail warps stay on partition 0.
  // Instead, run one pass per partition touched by the block; a warp only
  // accumulates during the pass of its own partition. In the common case of a
  // single partition per block this is exactly one pass.
  for (; part <= last_part; ++part) {
    const int4* part_codebook = codebook + part * codebook_stride;

    // Wait until nobody still reads the codebook staged by a previous pass.
    __syncthreads();
    for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
      int4 dec = part_codebook[i];
      // Every entry is written to all 8 slots of its group (presumably so that
      // threads reading the same entry hit different banks).
      #pragma unroll
      for (int j = 0; j < 8; j++)
        sh_code[8 * i + (j + lane) % 8] = dec;
    }
    __syncthreads();

    const bool active = pred && row_part == part;
    int b_gl_rd = 0;

    // Each iteration consumes up to 32 * 8 int4 words of B.
    int iters = (prob_k / 8 + 8 * 32 - 1) / (8 * 32);
    while (iters--) {
      // We pad shared memory to avoid bank conflicts during reads
      __syncthreads();
      for (int i = threadIdx.x; i < 32 * 8; i += blockDim.x) {
        if (b_gl_rd + i < prob_k / 8)
          sh_b[9 * (i / 8) + i % 8] = B[b_gl_rd + i];
      }
      __syncthreads();
      b_gl_rd += 32 * 8;

      int b_sh_rd = 9 * warp_lane;
      if (active && a_gl_rd < a_gl_end) {
        // 16 one-byte codes: pairs (code for codebook 0, code for codebook 1).
        const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
        #pragma unroll
        for (int i = 0; i < 8; i++) {
          half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
          half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
          half2*  b = reinterpret_cast<half2*>(&sh_b[b_sh_rd]);
          half2 res2 = {};
          #pragma unroll
          for (int j = 0; j < 4; j++)
            res2 = __hfma2(__hadd2(a0[j], a1[j]), b[j], res2);
          res += __half2float(res2.x) + __half2float(res2.y);
          b_sh_rd++;
        }
        a_gl_rd += 32;
      }
    }
  }

  // pred is uniform within a warp, so all 32 lanes take part in the reduction.
  if (pred) {
    #pragma unroll
    for (int i = 16; i > 0; i /= 2)
      res += __shfl_down_sync(0xffffffff, res, i);
    // Lane 0 holds the warp total and writes the single half-precision output.
    if (warp_lane == 0)
      reinterpret_cast<__half*>(C)[c_gl_wr] = __float2half(res);
  }
}


// Expands "1x16" codes into dense weights: every 16-bit code in A is replaced by
// the int4-sized codebook entry it selects, written to C.
//
// Layout as used by the indexing below:
//   - warp w of block b handles row (blockDim.x / 32) * b + w; rows >= prob_m are idle.
//   - a row of A spans prob_k / 64 int4 words; each word expands to 8 int4 words of C.
// NOTE(review): prob_k is presumably a multiple of 64 — confirm against the callers.
__global__ void Code1x16Dequant(
  const int4* __restrict__ A,
        int4* __restrict__ C,
  const int4* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
  const int codebook_stride // as int4
) {
  // Number of int4 words per row of A.
  int a_gl_stride = prob_k / 8 / 8;
  // Row handled by this warp (one warp per row).
  int a_gl_rd = (blockDim.x / 32) * blockIdx.x + (threadIdx.x / 32);
  bool pred = a_gl_rd < prob_m;

  if (pred)
  {
    // Advance to the codebook that covers this row; the fields of
    // codebook_a_sizes are walked in order as cumulative row bounds.
    auto codebook_size = &codebook_a_sizes.x;
    while (a_gl_rd >= *codebook_size)
    {
        codebook += codebook_stride;
        ++codebook_size;
    }
  }

  // From here on a_gl_rd is this lane's int4 index into A; lanes are interleaved
  // across the row with a step of 32 words.
  a_gl_rd = a_gl_stride * a_gl_rd + threadIdx.x % 32;
  // One past the last int4 word of this row.
  int a_gl_end = a_gl_rd + a_gl_stride - threadIdx.x % 32;

  int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
  while (iters--) {
    if (pred && a_gl_rd < a_gl_end) {
      // The 8 codes packed into this lane's int4 word of A.
      const uint16_t* enc = reinterpret_cast<const uint16_t*>(&A[a_gl_rd]);
      #pragma unroll
      for (int i = 0; i < 8; i++) {
        int4 chunk;
        auto dec = reinterpret_cast<uint32_t*>(&chunk);
        // We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
        // actually help us; this brings > 2x speedup.
        asm volatile (
          "ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
          : "=r"(dec[0]), "=r"(dec[1]), "=r"(dec[2]), "=r"(dec[3])
          : "l"((void*) &codebook[enc[i]])
        );

        // Word a_gl_rd of A expands to words [8 * a_gl_rd, 8 * a_gl_rd + 8) of C.
        C[a_gl_rd * 8 + i] = chunk;
      }
    }
    a_gl_rd += 32;
  }
}


// Expands "2x8" codes into dense weights: every pair of 8-bit codes in A selects
// one entry from each of two 256-entry codebooks; their half-precision sum is
// written to C.
//
// Layout as used by the indexing below:
//   - warp w of block b handles row (blockDim.x / 32) * b + w; rows >= prob_m are idle.
//   - dynamic shared memory: 2 * 256 * 8 int4 for the two codebooks (each entry
//     stored 8 times).
//   - codebook_stride is the distance, in int4, between the codebook pairs of two
//     consecutive partitions.
// NOTE(review): prob_k is presumably a multiple of 64 — confirm against the callers.
__global__ void Code2x8Dequant(
  const int4* __restrict__ A,
        int4* __restrict__ C,
  const int4* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
  const int codebook_stride // as int4
) {
  const int warps_per_block = blockDim.x / 32;
  const int block_first_row = warps_per_block * blockIdx.x;
  // Block-uniform exit, so no thread is left waiting at a barrier.
  if (block_first_row >= prob_m) return;

  const int warp_lane = threadIdx.x % 32;
  const int row = block_first_row + threadIdx.x / 32;
  const bool pred = row < prob_m;

  // Cumulative row bounds of the codebook partitions.
  const int part_bounds[4] = {
    codebook_a_sizes.x, codebook_a_sizes.y, codebook_a_sizes.z, codebook_a_sizes.w
  };
  // Partition of this warp's row.
  int row_part = 0;
  while (row_part < 3 && row >= part_bounds[row_part]) ++row_part;
  // Range of partitions touched by the valid rows of this block (block-uniform).
  int part = 0;
  while (part < 3 && block_first_row >= part_bounds[part]) ++part;
  const int block_end_row =
    (block_first_row + warps_per_block < prob_m) ? block_first_row + warps_per_block : prob_m;
  int last_part = part;
  while (last_part < 3 && block_end_row - 1 >= part_bounds[last_part]) ++last_part;

  // Number of int4 words per row of A.
  const int a_gl_stride = prob_k / 8 / 8;
  // This lane's int4 index into A; lanes are interleaved with a step of 32 words.
  int a_gl_rd = a_gl_stride * row + warp_lane;
  // One past the last int4 word of this row.
  const int a_gl_end = a_gl_rd + a_gl_stride - warp_lane;
  // Selects which of the 8 stored copies of a codebook entry this thread reads.
  const int lane = threadIdx.x % 8;

  extern __shared__ int4 sh[];
  int4* sh_code = sh;
  int4* sh_code0 = sh_code;
  int4* sh_code1 = sh_code + 256 * 8;

  // The shared codebook copy is block-wide, so it must come from a single
  // partition at a time. Staging it through each thread's own per-row pointer
  // (as was done before) mixes codebooks whenever the warps of a block fall into
  // different partitions, or when out-of-range tail warps stay on partition 0.
  // Instead, run one pass per partition touched by the block; a warp only decodes
  // during the pass of its own partition.
  for (; part <= last_part; ++part) {
    const int4* part_codebook = codebook + part * codebook_stride;

    // Wait until nobody still reads the codebook staged by a previous pass.
    __syncthreads();
    for (int i = threadIdx.x; i < 2 * 256; i += blockDim.x) {
      int4 dec = part_codebook[i];
      // Every entry is written to all 8 slots of its group (presumably so that
      // threads reading the same entry hit different banks).
      #pragma unroll
      for (int j = 0; j < 8; j++)
        sh_code[8 * i + (j + lane) % 8] = dec;
    }
    __syncthreads();

    if (pred && row_part == part) {
      for (; a_gl_rd < a_gl_end; a_gl_rd += 32) {
        // 16 one-byte codes: pairs (code for codebook 0, code for codebook 1).
        const uint8_t* enc = reinterpret_cast<const uint8_t*>(&A[a_gl_rd]);
        #pragma unroll
        for (int i = 0; i < 8; i++) {
          int4 chunk;
          half2* a0 = reinterpret_cast<half2*>(&sh_code0[8 * enc[2 * i + 0] + lane]);
          half2* a1 = reinterpret_cast<half2*>(&sh_code1[8 * enc[2 * i + 1] + lane]);
          #pragma unroll
          for (int j = 0; j < 4; j++)
            reinterpret_cast<half2*>(&chunk)[j] = __hadd2(a0[j], a1[j]);
          // Word a_gl_rd of A expands to words [8 * a_gl_rd, 8 * a_gl_rd + 8) of C.
          C[a_gl_rd * 8 + i] = chunk;
        }
      }
    }
  }
}

// Integer division of a by b, rounded up. Intended for non-negative a and
// positive b.
inline int ceildiv(int a, int b) {
  const int padded = a + b - 1;
  return padded / b;
}

// Upper bound on the number of rows (one warp each) that the host launchers
// assign to a single thread block, i.e. at most 32 * THREAD_M threads per block.
constexpr int THREAD_M = 16;

// Launches Code1x16MatVec on the current CUDA stream.
//
// The launch uses one warp per output row and at most THREAD_M warps per block;
// the number of rows per block is chosen as the largest value <= THREAD_M that
// results from spreading prob_m rows over a whole number of "waves" of blocks
// (one block per SM per wave).
//
// A, B, C and codebook are device pointers that the kernel accesses as int4.
// codebook_a_sizes holds the cumulative row bounds of the codebook partitions and
// codebook_stride the distance between consecutive codebooks, in int4.
void  code1x16_matvec_cuda(
  const void* __restrict__ A,
  const void* __restrict__ B,
        void* __restrict__ C,
  const void* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,
  const int codebook_stride
) {
  // No output rows: nothing to launch. This also keeps thread_m from becoming 0,
  // which would make the second ceildiv below divide by zero.
  if (prob_m <= 0) return;

  // Query the device the kernel will actually run on rather than device 0.
  int device = 0;
  cudaError_t err = cudaGetDevice(&device);
  TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed: ", cudaGetErrorString(err));
  int sms = 0;
  err = cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(err == cudaSuccess && sms > 0,
              "failed to query the multiprocessor count: ", cudaGetErrorString(err));

  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  // The kernel keeps its B tile in statically declared shared memory, so no
  // dynamic shared memory is requested.
  Code1x16MatVec<<<blocks, threads, 0, stream>>>(
    (const int4*) A,
    (const int4*) B,
    (int4*) C,
    (const int4*) codebook,
    prob_m,
    prob_k,
    codebook_a_sizes,
    codebook_stride
  );
  // Kernel launches do not return a status; pick up configuration errors here.
  err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "Code1x16MatVec launch failed: ", cudaGetErrorString(err));
}

// Launches Code2x8MatVec on the current CUDA stream.
//
// The launch uses one warp per output row and at most THREAD_M warps per block
// (see the wave loop below). The kernel needs dynamic shared memory for a padded
// B tile (32 * 9 int4) plus two 256-entry codebooks stored 8 times each
// (2 * 256 * 8 int4); that exceeds the default limit, so the kernel attribute is
// raised before launching.
//
// A, B, C and codebook are device pointers that the kernel accesses as int4.
// codebook_a_sizes holds the cumulative row bounds of the codebook partitions and
// codebook_stride the distance between the codebook pairs of consecutive
// partitions, in int4.
void  code2x8_matvec_cuda(
  const void* __restrict__ A,
  const void* __restrict__ B,
        void* __restrict__ C,
  const void* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,
  const int codebook_stride
) {
  // No output rows: nothing to launch. This also keeps thread_m from becoming 0,
  // which would make the second ceildiv below divide by zero.
  if (prob_m <= 0) return;

  // Query the device the kernel will actually run on rather than device 0.
  int device = 0;
  cudaError_t err = cudaGetDevice(&device);
  TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed: ", cudaGetErrorString(err));
  int sms = 0;
  err = cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(err == cudaSuccess && sms > 0,
              "failed to query the multiprocessor count: ", cudaGetErrorString(err));

  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  // 16 bytes per int4: codebook copies plus the padded B tile.
  int shared = 16 * (2 * 256 * 8 + 32 * 9);
  err = cudaFuncSetAttribute(
    Code2x8MatVec, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
  );
  TORCH_CHECK(err == cudaSuccess,
              "Code2x8MatVec needs ", shared, " bytes of dynamic shared memory: ",
              cudaGetErrorString(err));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code2x8MatVec<<<blocks, threads, shared, stream>>>(
    (const int4*) A,
    (const int4*) B,
    (int4*) C,
    (const int4*) codebook,
    prob_m,
    prob_k,
    codebook_a_sizes,
    codebook_stride
  );
  // Kernel launches do not return a status; pick up configuration errors here.
  err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "Code2x8MatVec launch failed: ", cudaGetErrorString(err));
}

// Launches Code1x16Dequant on the current CUDA stream, expanding the codes in A
// into dense weights in C.
//
// The launch uses one warp per row and at most THREAD_M warps per block (see the
// wave loop below). A, C and codebook are device pointers that the kernel accesses
// as int4.
void code1x16_dequant_cuda(
  const void* __restrict__ A,
        void* __restrict__ C,
  const void* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long.
  const int codebook_stride // as int4.
) {
  // No rows: nothing to launch. This also keeps thread_m from becoming 0, which
  // would make the second ceildiv below divide by zero.
  if (prob_m <= 0) return;

  // Query the device the kernel will actually run on rather than device 0.
  int device = 0;
  cudaError_t err = cudaGetDevice(&device);
  TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed: ", cudaGetErrorString(err));
  int sms = 0;
  err = cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(err == cudaSuccess && sms > 0,
              "failed to query the multiprocessor count: ", cudaGetErrorString(err));

  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
  Code1x16Dequant<<<blocks, threads, 0, stream>>>(
    (const int4*) A,
    (int4*) C,
    (const int4*) codebook,
    prob_m,
    prob_k,
    codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long.
    codebook_stride // as int4.
  );
  // Kernel launches do not return a status; pick up configuration errors here.
  err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "Code1x16Dequant launch failed: ", cudaGetErrorString(err));
}

// Dequantizes the code and codebook into weights.
//
// Launches Code2x8Dequant on the current CUDA stream. The launch uses one warp per
// row and at most THREAD_M warps per block (see the wave loop below). The kernel
// keeps two 256-entry codebooks, stored 8 times each, in dynamic shared memory
// (2 * 256 * 8 int4); that exceeds the default limit, so the kernel attribute is
// raised before launching.
//
// A, C and codebook are device pointers that the kernel accesses as int4.
// codebook_stride is the distance between the codebook pairs of consecutive
// partitions, in int4.
void  code2x8_dequant_cuda(
  const void* __restrict__ A,
        void* __restrict__ C,
  const void* __restrict__ codebook,
  int prob_m,
  int prob_k,
  const int4 codebook_a_sizes,  // cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
  const int codebook_stride // as int4
) {
  // No rows: nothing to launch. This also keeps thread_m from becoming 0, which
  // would make the second ceildiv below divide by zero.
  if (prob_m <= 0) return;

  // Query the device the kernel will actually run on rather than device 0.
  int device = 0;
  cudaError_t err = cudaGetDevice(&device);
  TORCH_CHECK(err == cudaSuccess, "cudaGetDevice failed: ", cudaGetErrorString(err));
  int sms = 0;
  err = cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device);
  TORCH_CHECK(err == cudaSuccess && sms > 0,
              "failed to query the multiprocessor count: ", cudaGetErrorString(err));

  int waves = 0;
  int thread_m;
  do {
    waves++;
    thread_m = ceildiv(prob_m, waves * sms);
  } while (thread_m > THREAD_M);

  int blocks = ceildiv(prob_m, thread_m);
  int threads = 32 * thread_m;
  // 16 bytes per int4. Unlike the matvec kernel, the dequant kernel has no B
  // tile, so only the codebook copies are needed.
  int shared = 16 * (2 * 256 * 8);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

  err = cudaFuncSetAttribute(
    Code2x8Dequant, cudaFuncAttributeMaxDynamicSharedMemorySize, shared
  );
  TORCH_CHECK(err == cudaSuccess,
              "Code2x8Dequant needs ", shared, " bytes of dynamic shared memory: ",
              cudaGetErrorString(err));
  Code2x8Dequant<<<blocks, threads, shared, stream>>>(
    (const int4*) A,
    (int4*) C,
    (const int4*) codebook,
    prob_m,
    prob_k,
    codebook_a_sizes,
    codebook_stride
  );
  // Kernel launches do not return a status; pick up configuration errors here.
  err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "Code2x8Dequant launch failed: ", cudaGetErrorString(err));
}

// Distance between two consecutive slices along dimension 0 of `codebooks`,
// expressed in int4 units (16 bytes) instead of elements.
int codebook_stride(const torch::Tensor& codebooks)
{
  const auto bytes_per_slice = codebooks.stride(0) * codebooks.element_size();
  return bytes_per_slice / sizeof(int4);
}

// Computes C = A * B for 1x16-coded weights A and a single input vector B.
// The number of output rows is taken from C and the input length from B; the
// kernel is launched on the device that holds A.
void code1x16_matvec(
  const torch::Tensor& A,
  const torch::Tensor& B,
        torch::Tensor& C,
  const torch::Tensor& codebook,
  const int4 codebook_a_sizes  // cumulative sizes of A spanning each codebook, at most 3 long.
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));

  const int num_rows = C.size(0);
  const int vec_len = B.size(0);
  // One codebook per partition, so the partition stride is a single codebook.
  const int stride_int4 = codebook_stride(codebook);

  code1x16_matvec_cuda(
    A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(),
    num_rows, vec_len, codebook_a_sizes, stride_int4
  );
}

// Applies a 1x16-coded linear layer to `input`.
//
// The input is flattened to (rows, in_features); every row is multiplied by the
// coded weights with one matvec launch, then the result is multiplied by the
// flattened `scales` and, if given, `bias` is added. The returned tensor has the
// shape of `input` with the last dimension replaced by out_features.
// NOTE(review): the kernel writes half-precision values, so `input` is presumably
// expected to be float16 and `codes` contiguous — confirm against the callers.
torch::Tensor code1x16_matmat(
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const int4 codebook_a_sizes,
  const std::optional<torch::Tensor>& bias) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty({flat_input.size(0), out_features},
    torch::TensorOptions()
      .dtype(input.dtype())
      .device(input.device())
  );

  // Loop-invariant view of the codes; no need to rebuild it for every row.
  const auto flat_codes = codes.squeeze(2);
  const int64_t num_rows = flat_input.size(0);
  for (int64_t i = 0; i < num_rows; ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code1x16_matvec(
      flat_codes,
      input_vec,
      output_vec,
      codebooks,
      codebook_a_sizes
    );
  }
  flat_output *= scales.flatten().unsqueeze(0);

  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }

  auto output_sizes = input_sizes.vec();
  output_sizes.pop_back();
  // Use the known size instead of -1: with an empty batch the inferred dimension
  // is ambiguous and reshape would throw.
  output_sizes.push_back(out_features);
  auto output = flat_output.reshape(output_sizes);
  return output;
}

// Computes C = A * B for 2x8-coded weights A and a single input vector B.
// The number of output rows is taken from C and the input length from B; the
// kernel is launched on the device that holds A.
void code2x8_matvec(
  const torch::Tensor& A,
  const torch::Tensor& B,
        torch::Tensor& C,
  const torch::Tensor& codebook,
  const int4 codebook_a_sizes
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));

  const int num_rows = C.size(0);
  const int vec_len = B.size(0);
  // Each partition owns two codebooks, so stepping to the next partition skips
  // twice the single-codebook stride.
  const int pair_stride_int4 = 2 * codebook_stride(codebook);

  code2x8_matvec_cuda(
    A.data_ptr(), B.data_ptr(), C.data_ptr(), codebook.data_ptr(),
    num_rows, vec_len, codebook_a_sizes, pair_stride_int4
  );
}

// Applies a 2x8-coded linear layer to `input`.
//
// The input is flattened to (rows, in_features); every row is multiplied by the
// coded weights with one matvec launch, then the result is multiplied by the
// flattened `scales` and, if given, `bias` is added. The returned tensor has the
// shape of `input` with the last dimension replaced by out_features.
// NOTE(review): the kernel writes half-precision values, so `input` is presumably
// expected to be float16 and `codes` contiguous — confirm against the callers.
torch::Tensor code2x8_matmat(
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const int4 codebook_a_sizes,
  const std::optional<torch::Tensor>& bias
) {
  auto input_sizes = input.sizes();
  auto out_features = codes.size(0) * codebooks.size(2);
  auto flat_input = input.reshape({-1, input.size(-1)});
  auto flat_output = torch::empty({flat_input.size(0), out_features},
    torch::TensorOptions()
      .dtype(input.dtype())
      .device(input.device())
  );

  // Loop-invariant view of the codes; no need to rebuild it for every row.
  const auto flat_codes = codes.squeeze(2);
  const int64_t num_rows = flat_input.size(0);
  for (int64_t i = 0; i < num_rows; ++i) {
    auto input_vec = flat_input.index({i});
    auto output_vec = flat_output.index({i});
    code2x8_matvec(
      flat_codes,
      input_vec,
      output_vec,
      codebooks,
      codebook_a_sizes
    );
  }
  flat_output *= scales.flatten().unsqueeze(0);
  if (bias.has_value()) {
    flat_output += bias->unsqueeze(0);
  }

  auto output_sizes = input_sizes.vec();
  output_sizes.pop_back();
  // Use the known size instead of -1: with an empty batch the inferred dimension
  // is ambiguous and reshape would throw.
  output_sizes.push_back(out_features);
  auto output = flat_output.reshape(output_sizes);
  return output;
}

// Accumulate the partition sizes.
//
// Returns the running totals of `codebook_partition_sizes` packed into an int4
// (.x = first size, .y = first + second, ...). Unused trailing fields are set to
// ten times the overall total so that no valid row index reaches them.
// Reading each element with item() copies it to the host, so this is not free
// when the tensor lives on the GPU.
int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
{
  constexpr int kMaxPartitions = 4;
  const int64_t num_partitions = codebook_partition_sizes.size(0);
  // An int4 has room for four bounds only; a plain assert would vanish in release
  // builds and let the loop below write past the result.
  TORCH_CHECK(num_partitions <= kMaxPartitions,
              "AQLM supports at most ", kMaxPartitions,
              " codebook partitions, got ", num_partitions, ".");

  int bounds[kMaxPartitions];
  int running = 0;
  int64_t i = 0;
  for (; i < num_partitions; ++i)
  {
    running += codebook_partition_sizes[i].item<int>();
    bounds[i] = running;
  }
  // fill in the rest with unreachable.
  for (; i < kMaxPartitions; ++i)
  {
    bounds[i] = running * 10;
  }
  return make_int4(bounds[0], bounds[1], bounds[2], bounds[3]);
}

} // namespace aqlm
} // namespace vllm


// Entry point for the AQLM matrix multiply.
//
// Picks the kernel family from the codebook layout: the number of codebooks per
// partition (codebooks.size(0) divided by the number of partitions) and the number
// of entries per codebook (codebooks.size(1)). Supported combinations are
// 1 codebook x 2^16 entries and 2 codebooks x 2^8 entries; anything else raises.
torch::Tensor aqlm_gemm(
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const torch::Tensor& codebook_partition_sizes,
  const std::optional<torch::Tensor>& bias
)
{
  // Guard the division below against an empty partition list.
  const int64_t num_partitions = codebook_partition_sizes.size(0);
  TORCH_CHECK(num_partitions > 0, "AQLM requires at least one codebook partition size.");

  int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / num_partitions;
  int const entries = codebooks.size(1);

  if (nbooks == 1 && entries == (1 << 16))
  {
    return vllm::aqlm::code1x16_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
  }
  if (nbooks == 2 && entries == (1 << 8))
  {
    return vllm::aqlm::code2x8_matmat(input, codes, codebooks, scales, cumulative_sizes, bias);
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.");
  return {};
}

// Entry point for AQLM dequantization: expands `codes` into a dense
// (out_features, in_features) weight tensor with the dtype and device of
// `codebooks`. Scales are not applied here.
//
// out_features is codes.size(0) and in_features is codes.size(1) * 8. The kernel
// family is picked from the codebook layout exactly as in the matrix multiply:
// 1 codebook x 2^16 entries or 2 codebooks x 2^8 entries per partition; anything
// else raises.
torch::Tensor aqlm_dequant(
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& codebook_partition_sizes
)
{
  // Guard the division below against an empty partition list.
  const int64_t num_partitions = codebook_partition_sizes.size(0);
  TORCH_CHECK(num_partitions > 0, "AQLM requires at least one codebook partition size.");

  int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes);

  int const nbooks = codebooks.size(0) / num_partitions;
  int const entries = codebooks.size(1);

  const at::cuda::OptionalCUDAGuard device_guard(device_of(codes));

  auto in_features = codes.size(1) * 8;
  auto out_features = codes.size(0);

  // The kernels pick a codebook by comparing the row index against the cumulative
  // partition sizes, so the partitions must cover exactly the rows of `codes`.
  // (This used to be an assert with `=` instead of `==`, i.e. an assignment.)
  TORCH_CHECK(out_features == codebook_partition_sizes.sum().item<int>(),
              "AQLM codebook partition sizes must sum to the number of output features (",
              out_features, ").");

  auto weights = torch::empty({out_features, in_features},
    torch::TensorOptions()
      .dtype(codebooks.dtype())
      .device(codebooks.device())
  );

  if (nbooks == 1 && entries == (1 << 16))
  {
    vllm::aqlm::code1x16_dequant_cuda(
      codes.data_ptr(),
      weights.data_ptr(),
      codebooks.data_ptr(),
      out_features,
      in_features,
      cumulative_sizes,
      vllm::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.)
    // weights *= scales.index({"...", 0, 0});

     return weights;
  }

  if (nbooks == 2 && entries == (1 << 8))
  {
     // Each partition owns two codebooks, so the step to the next partition is
     // twice the single-codebook stride (same as in the 2x8 matvec path).
     vllm::aqlm::code2x8_dequant_cuda(
        codes.data_ptr(),
        weights.data_ptr(),
        codebooks.data_ptr(),
        out_features,
        in_features,
        cumulative_sizes,
        2 * vllm::aqlm::codebook_stride(codebooks));

    // if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation)
    // weights *= scales.index({"...", 0, 0});

     return weights;
  }

  TORCH_CHECK(false, "AQLM with ", nbooks, " codebooks and ", entries, " entries is not currently supported.");
  return {};
}