cache_kernels.cu 46.7 KB
Newer Older
Xiaowei.zhang's avatar
Xiaowei.zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
/*
 
 * Copyright (c) 2024, The vLLM team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "hip_compat.h"
#include "hip_reduce.h"
#include "dispatch_utils.h"
#include "py_itfs_common.h"

#include "quant_utils.cuh"

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

#include <hip/hip_bf16.h>

template <typename T, typename F>
__device__ constexpr T block_reduce(T val, F reduce_f)
{
  // Two-stage block-wide reduction: every wave first folds its own lanes
  // with wave_reduce, then the per-wave partials are combined by
  // cross_wave_reduce using `smem` (256 slots) as scratch. Both helpers are
  // declared outside this block (presumably hip_reduce.h - verify).
  __shared__ T smem[256];
  return cross_wave_reduce(wave_reduce(val, reduce_f), reduce_f, smem);
}

// Copies whole cache blocks from `src` to `dst` as listed in `block_mapping`,
// issuing one cudaMemcpyAsync per mapped pair on the current stream.
//
// src, dst      : block-indexed tensors; the byte size of one block is taken
//                 from `src[0]` and assumed to be the same for `dst`.
//                 Supported combinations: GPU->GPU (same device), GPU->CPU,
//                 CPU->GPU.
// block_mapping : CPU tensor of shape (num_pairs, 2); row i holds
//                 (src_block_number, dst_block_number).
//
// Every pair is range-checked against the byte size of src/dst before any
// copy is issued. Throws (TORCH_CHECK) on an invalid device combination, a
// malformed or out-of-range mapping, or a failed memcpy enqueue. The copies
// are asynchronous with respect to the host.
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
                 const torch::Tensor &block_mapping)
{
  torch::Device src_device = src.device();
  torch::Device dst_device = dst.device();
  cudaMemcpyKind memcpy_type = cudaMemcpyDeviceToDevice;
  if (src_device.is_cuda() && dst_device.is_cuda())
  {
    TORCH_CHECK(src_device.index() == dst_device.index(),
                "src and dst must be on the same GPU");
  }
  else if (src_device.is_cuda() && dst_device.is_cpu())
  {
    memcpy_type = cudaMemcpyDeviceToHost;
  }
  else if (src_device.is_cpu() && dst_device.is_cuda())
  {
    memcpy_type = cudaMemcpyHostToDevice;
  }
  else
  {
    TORCH_CHECK(false, "Invalid device combination");
  }

  // NOTE(youkaichao): keep in mind that `block_mapping` should be
  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
  // synchronization.
  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");

  const int64_t num_blocks = block_mapping.size(0);
  if (num_blocks == 0)
  {
    // Nothing to copy; also avoids indexing `src[0]` on an empty cache.
    return;
  }

  // Read the mapping through an accessor instead of one `.item()` per entry.
  // `.to(kInt64)` is a no-op for int64 input and otherwise converts once.
  const torch::Tensor mapping = block_mapping.to(torch::kInt64);
  TORCH_CHECK(mapping.dim() == 2 && mapping.size(1) >= 2,
              "block_mapping must have shape (num_pairs, 2)");
  const auto pairs = mapping.accessor<int64_t, 2>();

  char *src_ptr = static_cast<char *>(src.data_ptr());
  char *dst_ptr = static_cast<char *>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  if (block_size_in_bytes == 0)
  {
    return;
  }
  // Number of whole blocks that fit in each tensor's storage.
  const int64_t src_capacity =
      src.numel() * src.element_size() / block_size_in_bytes;
  const int64_t dst_capacity =
      dst.numel() * dst.element_size() / block_size_in_bytes;

  // Validate every pair before enqueueing anything, so a bad mapping never
  // leaves a partially applied swap or writes out of bounds.
  for (int64_t i = 0; i < num_blocks; ++i)
  {
    const int64_t src_block_number = pairs[i][0];
    const int64_t dst_block_number = pairs[i][1];
    TORCH_CHECK(src_block_number >= 0 && src_block_number < src_capacity,
                "swap_blocks: src block number ", src_block_number,
                " is out of range [0, ", src_capacity, ")");
    TORCH_CHECK(dst_block_number >= 0 && dst_block_number < dst_capacity,
                "swap_blocks: dst block number ", dst_block_number,
                " is out of range [0, ", dst_capacity, ")");
  }

  const at::cuda::OptionalCUDAGuard device_guard(
      src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large
  // (one memcpy call per block).
  for (int64_t i = 0; i < num_blocks; ++i)
  {
    const int64_t src_offset = pairs[i][0] * block_size_in_bytes;
    const int64_t dst_offset = pairs[i][1] * block_size_in_bytes;
    const cudaError_t err =
        cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset,
                        block_size_in_bytes, memcpy_type, stream);
    TORCH_CHECK(err == cudaSuccess, "swap_blocks: cudaMemcpyAsync failed: ",
                cudaGetErrorString(err));
  }
}

namespace vllm
{

  // Copies whole cache blocks inside each layer's key cache and value cache.
  //
  // Grid: (num_layers, num_pairs). blockIdx.x selects the layer and
  // blockIdx.y the (src, dst) pair; the threads of a block stride over the
  // numel_per_block elements of that block. key_cache_ptrs/value_cache_ptrs
  // hold one device address per layer stored as int64, and block_mapping is
  // a flat array of (src_block, dst_block) pairs.
  template <typename scalar_t>
  __global__ void copy_blocks_kernel(int64_t *key_cache_ptrs,
                                     int64_t *value_cache_ptrs,
                                     const int64_t *__restrict__ block_mapping,
                                     const int numel_per_block)
  {
    const int layer = blockIdx.x;
    const int pair = blockIdx.y;

    scalar_t *const keys = reinterpret_cast<scalar_t *>(key_cache_ptrs[layer]);
    scalar_t *const values =
        reinterpret_cast<scalar_t *>(value_cache_ptrs[layer]);

    const int64_t *const entry = block_mapping + 2 * pair;
    const int64_t src_base = entry[0] * numel_per_block;
    const int64_t dst_base = entry[1] * numel_per_block;

    // Keys are copied completely before values, as two separate passes.
    for (int e = threadIdx.x; e < numel_per_block; e += blockDim.x)
    {
      keys[dst_base + e] = keys[src_base + e];
    }
    for (int e = threadIdx.x; e < numel_per_block; e += blockDim.x)
    {
      values[dst_base + e] = values[src_base + e];
    }
  }

} // namespace vllm

// Copies cache blocks inside every layer according to `block_mapping`.
//
// key_caches / value_caches : one tensor per layer on a single CUDA device;
//                             a block is `cache[0]`, whose element count is
//                             taken from key_caches[0].
// block_mapping             : int64 tensor of shape (num_pairs, 2) holding
//                             (src_block, dst_block) pairs. The kernel reads
//                             it directly, so it must be on the cache device.
//
// Note: the key_caches and value_caches vectors are constant but
// not the Tensors they contain. The vectors need to be const refs
// in order to satisfy pytorch's C++ operator registration code.
void copy_blocks(std::vector<torch::Tensor> const &key_caches,
                 std::vector<torch::Tensor> const &value_caches,
                 const torch::Tensor &block_mapping)
{
  TORCH_CHECK(key_caches.size() == value_caches.size(),
              "key_caches and value_caches must have the same number of layers");
  const int num_layers = static_cast<int>(key_caches.size());
  if (num_layers == 0)
  {
    return;
  }
  torch::Device cache_device = key_caches[0].device();
  TORCH_CHECK(cache_device.is_cuda());

  // block_mapping is a 2D tensor with shape (num_pairs, 2).
  const int num_pairs = block_mapping.size(0);
  const int numel_per_block = key_caches[0][0].numel();
  if (num_pairs == 0 || numel_per_block == 0)
  {
    // A zero grid/block dimension is an invalid launch configuration that
    // would leave a pending error behind; there is nothing to copy anyway.
    return;
  }
  TORCH_CHECK(block_mapping.device() == cache_device,
              "block_mapping must be on the same device as the caches");
  // The kernel indexes the mapping as a flat array of pairs.
  const torch::Tensor mapping = block_mapping.contiguous();

  // Per-layer cache base addresses, stored as int64 for the kernel.
  // (std::vector instead of a variable-length array, which is not standard C++.)
  std::vector<int64_t> key_cache_ptrs(num_layers);
  std::vector<int64_t> value_cache_ptrs(num_layers);
  for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx)
  {
    key_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(key_caches[layer_idx].data_ptr());
    value_cache_ptrs[layer_idx] =
        reinterpret_cast<int64_t>(value_caches[layer_idx].data_ptr());
  }

  // Move the pointer tables to the GPU.
  // NOTE: This synchronizes the CPU and GPU.
  torch::Tensor key_cache_ptrs_tensor =
      torch::from_blob(key_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);
  torch::Tensor value_cache_ptrs_tensor =
      torch::from_blob(value_cache_ptrs.data(), {num_layers}, torch::kInt64)
          .to(cache_device);

  // Launch the kernel: one block per (layer, pair).
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
      key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
        vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
            key_cache_ptrs_tensor.data_ptr<int64_t>(),
            value_cache_ptrs_tensor.data_ptr<int64_t>(),
            mapping.data_ptr<int64_t>(), numel_per_block);
      }));
  // Surface launch-configuration errors (e.g. num_pairs above the grid-y
  // limit) here instead of letting them leak into an unrelated later call.
  const cudaError_t launch_err = cudaGetLastError();
  TORCH_CHECK(launch_err == cudaSuccess,
              "copy_blocks_kernel launch failed: ",
              cudaGetErrorString(launch_err));
}

namespace vllm
{

  // Scatters one token's K/V rows into the paged caches.
  //
  // Launch: one thread block per token (blockIdx.x is the token index); its
  // threads stride over the token's num_heads * head_size elements.
  // slot_mapping[token] is the flat cache slot (block index * block_size +
  // offset in block); a negative slot marks a padding token and is skipped.
  //
  // Key cache:   [num_blocks, num_heads, head_size/x, block_size, x]
  // Value cache: [num_blocks, num_heads, head_size, block_size], or, when
  //              asmLayout, [num_blocks, num_heads, block_size/x, head_size, x]
  // For kv_dt other than kAuto the values go through
  // fp8::scaled_convert with k_scale / v_scale.
  template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt, bool asmLayout = false, typename slot_mapping_t = int64_t>
  __global__ void reshape_and_cache_kernel(
      const scalar_t *__restrict__ key,                // [num_tokens, num_heads, head_size]
      const scalar_t *__restrict__ value,              // [num_tokens, num_heads, head_size]
      cache_t *__restrict__ key_cache,                 // [num_blocks, num_heads, head_size/x,
                                                       // block_size, x]
      cache_t *__restrict__ value_cache,               // [num_blocks, num_heads, head_size,
                                                       // block_size]
      const slot_mapping_t *__restrict__ slot_mapping, // [num_tokens]
      const int key_stride, const int value_stride, const int num_heads,
      const int head_size, const int block_size, const int x, const float k_scale,
      const float v_scale)
  {
    const int64_t token_idx = blockIdx.x;
    const slot_mapping_t raw_slot = slot_mapping[token_idx];
    if (raw_slot < 0)
    {
      return; // padding token
    }

    const int64_t slot = static_cast<int64_t>(raw_slot);
    const int64_t block_idx = slot / block_size;
    const int64_t block_offset = slot % block_size;

    const scalar_t *key_row = key + token_idx * key_stride;
    const scalar_t *value_row = value + token_idx * value_stride;

    const int elems_per_token = num_heads * head_size;
    for (int elem = threadIdx.x; elem < elems_per_token; elem += blockDim.x)
    {
      const int head_idx = elem / head_size;
      const int head_offset = elem % head_size;

      // Key: head_offset splits into (x-chunk index, position inside chunk).
      const int x_idx = head_offset / x;
      const int x_offset = head_offset % x;
      const int64_t key_dst =
          block_idx * num_heads * (head_size / x) * block_size * x +
          head_idx * (head_size / x) * block_size * x +
          x_idx * block_size * x +
          block_offset * x +
          x_offset;

      int64_t value_dst;
      if constexpr (!asmLayout)
      {
        // [num_blocks, num_heads, head_size, block_size]
        value_dst =
            block_idx * num_heads * head_size * block_size +
            head_idx * head_size * block_size +
            head_offset * block_size +
            block_offset;
      }
      else
      {
        // [num_blocks, num_heads, block_size/x, head_size, x]:
        // here the token's position inside the block is split by x.
        const int slot_chunk = block_offset / x;
        const int slot_in_chunk = block_offset % x;
        value_dst =
            block_idx * num_heads * head_size * block_size +
            head_idx * head_size * block_size +
            slot_chunk * head_size * x +
            head_offset * x +
            slot_in_chunk;
      }

      const scalar_t k_val = key_row[elem];
      const scalar_t v_val = value_row[elem];
      if constexpr (kv_dt != Fp8KVCacheDataType::kAuto)
      {
        key_cache[key_dst] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val, k_scale);
        value_cache[value_dst] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(v_val, v_scale);
      }
      else
      {
        key_cache[key_dst] = k_val;
        value_cache[value_dst] = v_val;
      }
    }
  }

  // Scatters one token's K/V rows into caches laid out as
  // [num_blocks, block_size, num_heads, head_size], with block_stride
  // elements between the starts of consecutive cache blocks.
  //
  // Launch: one thread block per token (blockIdx.x); threads stride over the
  // token's num_heads * head_size elements. A negative slot marks a padded
  // token, which is skipped. For kv_dt other than kAuto the values go through
  // fp8::scaled_convert with *k_scale / *v_scale; the scale pointers are only
  // dereferenced in that case.
  template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
  __global__ void reshape_and_cache_flash_kernel(
      const scalar_t *__restrict__ key,         // [num_tokens, num_heads, head_size]
      const scalar_t *__restrict__ value,       // [num_tokens, num_heads, head_size]
      cache_t *__restrict__ key_cache,          // [num_blocks, block_size, num_heads,
                                                // head_size]
      cache_t *__restrict__ value_cache,        // [num_blocks, block_size, num_heads,
                                                // head_size]
      const int64_t *__restrict__ slot_mapping, // [num_tokens]
      const int block_stride, const int key_stride, const int value_stride,
      const int num_heads, const int head_size, const int block_size,
      const float* k_scale, const float* v_scale)
  {
    const int64_t token_idx = blockIdx.x;
    const int64_t slot_idx = slot_mapping[token_idx];
    // NOTE: slot_idx can be -1 if the token is padded
    if (slot_idx < 0)
    {
      return;
    }

    // First element of this slot's [num_heads, head_size] row in the cache.
    const int64_t slot_base = (slot_idx / block_size) * block_stride +
                              (slot_idx % block_size) * num_heads * head_size;

    const scalar_t *src_k = key + token_idx * key_stride;
    const scalar_t *src_v = value + token_idx * value_stride;

    const int row_elems = num_heads * head_size;
    for (int i = threadIdx.x; i < row_elems; i += blockDim.x)
    {
      // The cache row uses the same [num_heads, head_size] order as the
      // source row, so element i lands at slot_base + i.
      const int64_t dst = slot_base + i;
      const scalar_t k = src_k[i];
      const scalar_t v = src_v[i];
      if constexpr (kv_dt != Fp8KVCacheDataType::kAuto)
      {
        key_cache[dst] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k, *k_scale);
        value_cache[dst] =
            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(v, *v_scale);
      }
      else
      {
        key_cache[dst] = k;
        value_cache[dst] = v;
      }
    }
  }

  namespace impl
  {
    // Scalar conversion helper. Only the explicit specializations below are
    // defined; any other <DType, SType> pair fails at link time.
    template <typename DType, typename SType>
    __device__ DType type_convert(SType);

    // ---- conversions to float ----
    template <>
    __device__ float type_convert<float, float>(float x)
    {
      return x;
    }

    template <>
    __device__ float type_convert<float, __half>(__half x)
    {
      return __half2float(x);
    }

    template <>
    __device__ float type_convert<float, __hip_bfloat16>(__hip_bfloat16 x)
    {
      return __bfloat162float(x);
    }

    template <>
    __device__ float type_convert<float, hip_fp8>(hip_fp8 x)
    {
      return static_cast<float>(x);
    }

    template <>
    __device__ float type_convert<float, int8_t>(int8_t x)
    {
      return static_cast<float>(x);
    }

    // ---- conversions from float ----
    template <>
    __device__ hip_fp8 type_convert<hip_fp8, float>(float x)
    {
      return hip_fp8{x};
    }

    // NOTE(review): plain static_cast, i.e. truncation toward zero with no
    // rounding or saturation; callers are expected to pre-scale into range.
    template <>
    __device__ int8_t type_convert<int8_t, float>(float x)
    {
      return static_cast<int8_t>(x);
    }

    // XOR-butterfly all-reduce over a wavefront using ds_bpermute: after the
    // six exchange steps (partner distances 1..32) every lane holds
    // reduce_f applied across 64 lanes. Values travel as 32-bit words, so T
    // must be 4 bytes wide; a 64-lane wave is assumed.
    template <typename T, typename F>
    __device__ constexpr T wave_reduce(T local, F reduce_f)
    {
      T acc = local;
#pragma unroll
      for (int distance = 1; distance < 64; distance <<= 1)
      {
        const int partner = __lane_id() ^ distance;
        // ds_bpermute takes a byte address, hence lane * 4.
        const int32_t bits = __builtin_amdgcn_ds_bpermute(
            partner << 2, __builtin_bit_cast(int32_t, acc));
        acc = reduce_f(acc, __builtin_bit_cast(T, bits));
      }
      return acc;
    }

    // |x| by clearing the IEEE-754 sign bit (NaN payloads are kept).
    __device__ float abs(float x)
    {
      const uint32_t magnitude = __builtin_bit_cast(uint32_t, x) & 0x7fffffffu;
      return __builtin_bit_cast(float, magnitude);
    }
  }

  // Per-token KV-cache quantization with absmax scaling.
  //
  // Each wave handles one (token, head) pair. Its lanes load the head's
  // elements (local_dim_elems per lane), reduce |x| to a per-token absmax,
  // and store x / (absmax / dtypeMax) converted with impl::type_convert to
  // cache_t. The scale absmax / dtypeMax goes to k/v_dequant_scales.
  //
  // Launch layout implied by the indexing below:
  //   token = blockIdx.x * (wg_size / warpSize) + wave index
  //   head  = blockIdx.y
  //   blockDim.x is assumed to equal wg_size (TODO: confirm at call site).
  // Limits: only the first local_dim_elems * warpSize elements of a head are
  // processed (larger heads are not fully written), and impl::wave_reduce
  // hard-codes 64-lane waves.
  //
  // Scale layout: [num_blocks, num_heads, block_size] when asmLayout,
  // otherwise [num_heads, max_kv_tokens] indexed by slot.
  template <typename scalar_t, typename cache_t, typename dequant_scale_t, bool asmLayout = false, int wg_size = 256>
  __global__ void reshape_and_cache_with_per_token_quant_kernel(
      const scalar_t *__restrict__ key,               // [num_tokens, num_heads, head_size]
      const scalar_t *__restrict__ value,             // [num_tokens, num_heads, head_size]
      cache_t *__restrict__ key_cache,                // [num_blocks, num_heads, head_size/x, block_size, x]
      cache_t *__restrict__ value_cache,              // [num_blocks, num_heads, head_size, block_size]
      dequant_scale_t *__restrict__ k_dequant_scales, // [num_heads, max_kv_tokens]
      dequant_scale_t *__restrict__ v_dequant_scales, // [num_heads, max_kv_tokens]
      const int64_t *__restrict__ slot_mapping,       // [num_tokens]
      const int key_stride, const int value_stride, const int num_heads,
      const int head_size, const int block_size, const int x,
      const int num_tokens, const int max_kv_tokens,
      float dtypeMax)
  {
    const int32_t tokens_per_wg = wg_size / warpSize;

    // Every wave computes one token, one head, all of head_dim.
    const int wave_id = threadIdx.x / warpSize;
    const int lane_id = threadIdx.x % warpSize;

    // Widen before multiplying so the token index cannot wrap in 32 bits.
    const int64_t token_idx =
        static_cast<int64_t>(blockIdx.x) * tokens_per_wg + wave_id;
    const int32_t head_idx = blockIdx.y;

    // Bounds-check BEFORE reading slot_mapping: the last workgroup may hold
    // waves whose token_idx >= num_tokens, and slot_mapping has only
    // num_tokens entries.
    if (token_idx >= num_tokens)
    {
      return;
    }
    const int64_t slot_idx = slot_mapping[token_idx];
    if (slot_idx < 0)
    {
      // Padding token that should be ignored.
      return;
    }

    const int64_t block_idx = slot_idx / block_size;
    const int64_t block_offset = slot_idx % block_size;

    auto f_absmax_f32 = [](float v_0_, float v_1_)
    {
      return __builtin_fmaxf(impl::abs(v_0_), impl::abs(v_1_));
    };
    auto f_max_f32 = [](float v_0_, float v_1_)
    {
      return __builtin_fmaxf(v_0_, v_1_);
    };

    constexpr int local_dim_elems = 8;

    // Element i_d of lane l is head dimension l + i_d * warpSize; entries
    // beyond head_size stay 0 and therefore never raise the absmax.
    float k_local_dim[local_dim_elems]{0}; // up to 64*8 = 512 hdim
    float v_local_dim[local_dim_elems]{0}; // up to 64*8 = 512 hdim
#pragma unroll
    for (int i_d = 0; i_d < local_dim_elems; i_d++)
    {
      int current_d = lane_id + i_d * warpSize;
      const int64_t src_k_idx = token_idx * key_stride + head_idx * head_size + current_d;
      const int64_t src_v_idx = token_idx * value_stride + head_idx * head_size + current_d;
      if (current_d < head_size)
      {
        k_local_dim[i_d] = impl::type_convert<float>(key[src_k_idx]);
        v_local_dim[i_d] = impl::type_convert<float>(value[src_v_idx]);
      }
    }

    // Lane-local absmax, then wave-wide max -> per-token absmax.
    float k_local_max = [&]()
    {
      float max_ = k_local_dim[0];
#pragma unroll
      for (int i_d = 1; i_d < local_dim_elems; i_d++)
      {
        max_ = f_absmax_f32(max_, k_local_dim[i_d]);
      }
      return max_;
    }();
    float k_max = impl::wave_reduce(k_local_max, f_max_f32);

    float v_local_max = [&]()
    {
      float max_ = v_local_dim[0];
#pragma unroll
      for (int i_d = 1; i_d < local_dim_elems; i_d++)
      {
        max_ = f_absmax_f32(max_, v_local_dim[i_d]);
      }
      return max_;
    }();
    float v_max = impl::wave_reduce(v_local_max, f_max_f32);

    const float k_token_scale = k_max / dtypeMax;
    const float v_token_scale = v_max / dtypeMax;

    // An all-zero row (or one whose absmax underflows) has scale 0. Dividing
    // by it would give 0/0 = NaN, which is UB when converted to int8 and
    // poisons fp8 caches, so such rows quantize to 0. The stored scale stays
    // 0, which presumably dequantizes back to 0.
#pragma unroll
    for (int i_d = 0; i_d < local_dim_elems; i_d++)
    {
      k_local_dim[i_d] = k_token_scale > 0.f ? k_local_dim[i_d] / k_token_scale : 0.f;
      v_local_dim[i_d] = v_token_scale > 0.f ? v_local_dim[i_d] / v_token_scale : 0.f;
    }

    // Store the per-token dequant scale (64-bit index: large caches overflow int).
    int64_t scale_idx;
    if constexpr (asmLayout)
    {
      // [num_blocks, num_heads, block_size]
      scale_idx = block_size * num_heads * block_idx +
                  block_size * head_idx +
                  block_offset;
    }
    else
    {
      // [num_heads, max_kv_tokens]
      scale_idx = static_cast<int64_t>(head_idx) * max_kv_tokens + slot_idx;
    }
    k_dequant_scales[scale_idx] = k_token_scale;
    v_dequant_scales[scale_idx] = v_token_scale;

    // Now store the quantized values.
#pragma unroll
    for (int i = 0; i < local_dim_elems; i++)
    {
      int i_d = lane_id + i * warpSize;
      if (i_d >= head_size)
      {
        break;
      }
      const int x_idx = i_d / x;
      const int x_offset = i_d % x;

      const int64_t tgt_key_idx =
          block_idx * num_heads * (head_size / x) * block_size * x +
          head_idx * (head_size / x) * block_size * x +
          x_idx * block_size * x +
          block_offset * x +
          x_offset;
      int64_t tgt_value_idx;
      if constexpr (asmLayout)
      { //[num_blocks, num_heads, block_size/X, head_size, X]
        const int x_idx_v = block_offset / x;
        const int x_offset_v = block_offset % x;
        tgt_value_idx =
            block_idx * num_heads * head_size * block_size +
            head_idx * head_size * block_size +
            x_idx_v * head_size * x +
            i_d * x +
            x_offset_v;
      }
      else
      { //[num_blocks, num_heads, head_size, block_size]
        tgt_value_idx =
            block_idx * num_heads * head_size * block_size +
            head_idx * head_size * block_size +
            i_d * block_size +
            block_offset;
      }
      key_cache[tgt_key_idx] = impl::type_convert<cache_t>(k_local_dim[i]);
      value_cache[tgt_value_idx] = impl::type_convert<cache_t>(v_local_dim[i]);
    }
  }

  // Block-wise KV quantization: one dequant scale per (cache block, head).
  template <typename scalar_t, typename cache_t, typename dequant_scale_t, bool asmLayout = false, int wg_size = 256>
  __global__ void reshape_and_cache_with_block_quant_kernel(
      const scalar_t *__restrict__ key,               // [batch_size, seq_len, num_heads, head_size]
      const scalar_t *__restrict__ value,             // [batch_size, seq_len, num_heads, head_size]
      cache_t *__restrict__ key_cache,                // [num_blocks, num_heads, head_size/x, block_size, x]
      cache_t *__restrict__ value_cache,              // [num_blocks, num_heads, head_size, block_size]
      dequant_scale_t *__restrict__ k_dequant_scales, // [num_heads, num_blocks]
      dequant_scale_t *__restrict__ v_dequant_scales, // [num_heads, num_blocks]
      const int64_t *__restrict__ slot_mapping,       // [num_tokens]
      const int key_stride, const int value_stride, const int num_heads, const int num_blocks,
      const int head_size, const int block_size, const int x,
      const int num_tokens, const int seq_len, float dtypeMax)
  {
    int64_t first_token_idx = blockIdx.x * seq_len + blockIdx.y * block_size;
    int64_t slot_idx;
    int64_t block_idx;
    int64_t block_offset;
    if (blockIdx.y * block_size >= seq_len)
    {
      int64_t preTg_block_idx = slot_mapping[first_token_idx - block_size] / block_size;
      first_token_idx = blockIdx.x * seq_len + seq_len - 1;
      slot_idx = slot_mapping[first_token_idx];
      block_idx = slot_idx / block_size;
      if (preTg_block_idx == block_idx)
      {
        return;
      }
      block_offset = slot_idx % block_size;
    }
    else
    {
      slot_idx = slot_mapping[first_token_idx];
      block_idx = slot_idx / block_size;
      block_offset = slot_idx % block_size;
    }

    if (slot_idx < 0)
    {
      // Padding token that should be ignored.
      return;
    }
    const int32_t head_idx = blockIdx.z;

    // fix first_token_idx to real block first_token_idx
    if (blockIdx.y > 0 && block_offset > 0)
    {
      __shared__ int64_t idx_smem[2];
      if (threadIdx.x < block_size)
      {
        int64_t token_idx = first_token_idx - (threadIdx.x + 1);
        int64_t block_idx1 = slot_mapping[token_idx] / block_size;
        int64_t slot_idx2 = slot_mapping[token_idx + 1];
        int64_t block_idx2 = slot_idx2 / block_size;
        if (block_idx1 != block_idx2 && block_idx2 == block_idx)
        {
          idx_smem[0] = token_idx + 1;
          idx_smem[1] = slot_idx2;
        }
      }
      __syncthreads();
      first_token_idx = idx_smem[0];
      slot_idx = idx_smem[1];
    }

    block_offset = slot_idx % block_size;

    int tokens_in_block = 0;
    if (first_token_idx + threadIdx.x < num_tokens)
    {
      tokens_in_block = slot_mapping[first_token_idx + threadIdx.x] / block_size;
      tokens_in_block = tokens_in_block == block_idx ? 1 : 0;
    }
    int numtokens_in_block = block_reduce(tokens_in_block, [](float a, float b)
                                          { return a + b; });

    auto f_absmax_f32 = [](float v_0_, float v_1_)
    {
      return __builtin_fmaxf(impl::abs(v_0_), impl::abs(v_1_));
    };
    auto f_max_f32 = [](float v_0_, float v_1_)
    {
      return __builtin_fmaxf(v_0_, v_1_);
    };

    float k_max_val = 1e-6;
    float v_max_val = 1e-6;
#pragma unroll
    for (int id = 0; id < numtokens_in_block * head_size; id += blockDim.x)
    {
      if ((id + threadIdx.x) < numtokens_in_block * head_size)
      {
        int64_t token_idx = (id + threadIdx.x) / head_size + first_token_idx;
        int current_d = (id + threadIdx.x) % head_size;

        const int64_t src_k_idx = token_idx * key_stride + head_idx * head_size + current_d;
        const int64_t src_v_idx = token_idx * value_stride + head_idx * head_size + current_d;

        k_max_val = f_absmax_f32(k_max_val, impl::type_convert<float>(key[src_k_idx]));
        v_max_val = f_absmax_f32(v_max_val, impl::type_convert<float>(value[src_v_idx]));
      }
    }

    k_max_val = block_reduce(k_max_val, f_max_f32);
    v_max_val = block_reduce(v_max_val, f_max_f32);

    float k_block_scale = k_max_val / dtypeMax;
    float v_block_scale = v_max_val / dtypeMax;

    int64_t scale_idx;
    if constexpr (asmLayout)
    {
      scale_idx = block_idx * num_heads + head_idx;
    }
    else
    {
      scale_idx = head_idx * num_blocks + block_idx;
    }

    if (block_offset > 0)
    {
      float k_block_scale_global = k_dequant_scales[scale_idx];
      float v_block_scale_global = v_dequant_scales[scale_idx];

      if (k_block_scale_global < k_block_scale)
      {
        int64_t tgt_value_idx =
            block_idx * num_heads * head_size * block_size +
            head_idx * head_size * block_size;
#pragma unroll
        for (int id = 0; id < block_offset * head_size; id += blockDim.x)
        {
          if (id + threadIdx.x < block_offset * head_size)
          {
            int block_offset_local = (id + threadIdx.x) / head_size;
            int x_idx = (id + threadIdx.x) % head_size / x;
            int x_offset = (id + threadIdx.x) % x;
            int64_t cache_idx = tgt_value_idx +
                                x_idx * block_size * x +
                                block_offset_local * x +
                                x_offset;
            float tmp = impl::type_convert<float>(key_cache[cache_idx]);
            tmp = tmp * k_block_scale_global / k_block_scale;
            key_cache[cache_idx] = impl::type_convert<cache_t>(tmp);
          }
        }
        k_dequant_scales[scale_idx] = k_block_scale;
      }
      else
      {
        k_block_scale = k_block_scale_global;
      }

      if (v_block_scale_global < v_block_scale)
      {
        int64_t tgt_value_idx =
            block_idx * num_heads * head_size * block_size +
            head_idx * head_size * block_size;
#pragma unroll
        for (int id = 0; id < block_offset * head_size; id += blockDim.x)
        {
          if (id + threadIdx.x < block_offset * head_size)
          {
            int64_t cache_idx;
            if constexpr (asmLayout)
            {
              int block_offset_local = (id + threadIdx.x) / head_size;
              int head_offset = (id + threadIdx.x) % head_size;
              int block_offset_local_divX = block_offset_local / x;
              int x_idx = block_offset_local % x;
              cache_idx = tgt_value_idx +
                          block_offset_local_divX * head_size * x +
                          head_offset * x +
                          x_idx;
            }
            else
            {
              int block_offset_local = (id + threadIdx.x) / head_size;
              int head_offset = (id + threadIdx.x) % head_size;
              cache_idx = tgt_value_idx +
                          head_offset * block_size +
                          block_offset_local;
            }
            float tmp = impl::type_convert<float>(value_cache[cache_idx]);
            tmp = tmp * v_block_scale_global / v_block_scale;
            value_cache[cache_idx] = impl::type_convert<cache_t>(tmp);
          }
        }
        v_dequant_scales[scale_idx] = v_block_scale;
      }
      else
      {
        v_block_scale = v_block_scale_global;
      }
    }
    else
    {
      k_dequant_scales[scale_idx] = k_block_scale;
      v_dequant_scales[scale_idx] = v_block_scale;
    }

    // now let's store out
    for (int id = 0; id < numtokens_in_block * head_size; id += blockDim.x)
    {
      if ((id + threadIdx.x) < numtokens_in_block * head_size)
      {
        int token_idx = (id + threadIdx.x) / head_size + first_token_idx;
        int current_d = (id + threadIdx.x) % head_size;
        int block_offset_local = token_idx - first_token_idx + block_offset;

        const int64_t src_k_idx = token_idx * key_stride + head_idx * head_size + current_d;
        const int64_t src_v_idx = token_idx * value_stride + head_idx * head_size + current_d;
        float tmp_k = impl::type_convert<float>(key[src_k_idx]) / k_block_scale;
        float tmp_v = impl::type_convert<float>(value[src_v_idx]) / v_block_scale;

        const int x_idx = current_d / x;
        const int x_offset = current_d % x;
        //[num_blocks, num_heads, head_size/X, block_size, X]
        const int64_t tgt_key_idx =
            block_idx * num_heads * head_size * block_size +
            head_idx * head_size * block_size +
            x_idx * block_size * x +
            block_offset_local * x +
            x_offset;

        int64_t tgt_value_idx;
        if constexpr (asmLayout)
        { //[num_blocks, num_heads, block_size/X, head_size, X]
          const int x_idx = block_offset_local / x;
          const int x_offset = block_offset_local % x;
          tgt_value_idx =
              block_idx * num_heads * head_size * block_size +
              head_idx * head_size * block_size +
              x_idx * head_size * x +
              current_d * x +
              x_offset;
        }
        else
        { //[num_blocks, num_heads, head_size, block_size]
          tgt_value_idx =
              block_idx * num_heads * head_size * block_size +
              head_idx * head_size * block_size +
              current_d * block_size +
              block_offset_local;
        }
        key_cache[tgt_key_idx] = impl::type_convert<cache_t>(tmp_k);
        value_cache[tgt_value_idx] = impl::type_convert<cache_t>(tmp_v);
      }
    }
  }
} // namespace vllm

// Launches vllm::reshape_and_cache_kernel without the optional fourth template
// argument (CALL_RESHAPE_AND_CACHE_ASM passes `true` in that position).
// Meant to be expanded inside reshape_and_cache(): it uses that function's
// locals (grid, block, stream, key, value, key_cache, value_cache,
// slot_mapping, key_stride, value_stride, num_heads, head_size, block_size,
// x, k_scale, v_scale).
// KV_T is the element type of the incoming key and value tensors.
// CACHE_T is the stored element type of key_cache and value_cache.
// KV_DTYPE is the real data type of kv-cache, forwarded as the kernel's third
// template argument (presumably supplied by DISPATCH_BY_KV_CACHE_DTYPE -- confirm).
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE)               \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE>             \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T *>(key.data_ptr()),                   \
          reinterpret_cast<KV_T *>(value.data_ptr()),                 \
          reinterpret_cast<CACHE_T *>(key_cache.data_ptr()),          \
          reinterpret_cast<CACHE_T *>(value_cache.data_ptr()),        \
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
          num_heads, head_size, block_size, x, k_scale, v_scale);

// Same launch as CALL_RESHAPE_AND_CACHE, but instantiates
// vllm::reshape_and_cache_kernel with `true` as its fourth template argument
// (the asm-layout variant selected by reshape_and_cache() when asm_layout is set).
// KV_T is the element type of the incoming key and value tensors.
// CACHE_T is the stored element type of key_cache and value_cache.
// KV_DTYPE is the real data type of kv-cache, forwarded as the third template
// argument. Uses the expanding function's locals (grid, block, stream, key,
// value, key_cache, value_cache, slot_mapping, strides, sizes, k/v scales).
#define CALL_RESHAPE_AND_CACHE_ASM(KV_T, CACHE_T, KV_DTYPE)           \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, KV_DTYPE, true>       \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T *>(key.data_ptr()),                   \
          reinterpret_cast<KV_T *>(value.data_ptr()),                 \
          reinterpret_cast<CACHE_T *>(key_cache.data_ptr()),          \
          reinterpret_cast<CACHE_T *>(value_cache.data_ptr()),        \
          slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
          num_heads, head_size, block_size, x, k_scale, v_scale);

// Scatters `key` / `value` rows into the paged key_cache / value_cache through
// vllm::reshape_and_cache_kernel.
// Launch: one thread block per token (grid.x = key.size(0)) with
// min(num_heads * head_size, 512) threads.
//   key_stride / value_stride : stride(0) of key / value (per-token row stride)
//   block_size, x             : key_cache.size(3), key_cache.size(4)
//   kv_cache_dtype            : resolved by DISPATCH_BY_KV_CACHE_DTYPE
//   k_scale, v_scale          : forwarded unchanged to the kernel
//   asm_layout                : selects the `true` (asm-layout) kernel instantiation
// Empty inputs are a no-op.
void reshape_and_cache(
    torch::Tensor &key,   // [num_tokens, num_heads, head_size]
    torch::Tensor &value, // [num_tokens, num_heads, head_size]
    torch::Tensor &
        key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor &
        value_cache,             // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor &slot_mapping, // [num_tokens]
    const std::string &kv_cache_dtype, const double k_scale,
    const double v_scale,
    const bool asm_layout)
{
  // Nothing to cache. Returning here also avoids launching with a zero-sized
  // grid or block, which is an invalid launch configuration.
  if (key.numel() == 0)
  {
    return;
  }

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (asm_layout)
  {
    DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                               CALL_RESHAPE_AND_CACHE_ASM)
  }
  else
  {
    DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                               CALL_RESHAPE_AND_CACHE)
  }
}

// Launches vllm::reshape_and_cache_flash_kernel. Meant to be expanded inside
// reshape_and_cache_flash(): it uses that function's locals (grid, block,
// stream, key, value, key_cache, value_cache, slot_mapping, block_stride,
// key_stride, value_stride, num_heads, head_size, block_size, k_scale, v_scale).
// KV_T is the element type of the incoming key and value tensors.
// CACHE_T is the stored element type of key_cache and value_cache.
// KV_DTYPE is the real data type of kv-cache, forwarded as the third template argument.
// k_scale / v_scale are passed as device pointers via data_ptr<float>(), so
// they must be float32 tensors.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE)         \
  vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE>       \
      <<<grid, block, 0, stream>>>(                                   \
          reinterpret_cast<KV_T *>(key.data_ptr()),                   \
          reinterpret_cast<KV_T *>(value.data_ptr()),                 \
          reinterpret_cast<CACHE_T *>(key_cache.data_ptr()),          \
          reinterpret_cast<CACHE_T *>(value_cache.data_ptr()),        \
          slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
          value_stride, num_heads, head_size, block_size, k_scale.data_ptr<float>(), v_scale.data_ptr<float>());

// Scatters `key` / `value` rows into flash-layout paged caches through
// vllm::reshape_and_cache_flash_kernel.
// Launch: one thread block per token (grid.x = key.size(0)) with
// min(num_heads * head_size, 512) threads.
//   block_size   : key_cache.size(1)
//   block_stride : key_cache.stride(0); value_cache must share it
//   k_scale, v_scale : tensors read on device as float32 (data_ptr<float>())
// Empty inputs are a no-op.
void reshape_and_cache_flash(
    torch::Tensor &key,       // [num_tokens, num_heads, head_size]
    torch::Tensor &value,     // [num_tokens, num_heads, head_size]
    torch::Tensor &key_cache, // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor &
        value_cache,             // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor &slot_mapping, // [num_tokens]
    const std::string &kv_cache_dtype,
    torch::Tensor& k_scale,
    torch::Tensor& v_scale)
{
  // A single block_stride is passed to the kernel for both caches.
  TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0),
              "key_cache and value_cache must have the same stride(0), got ",
              key_cache.stride(0), " and ", value_cache.stride(0));

  // Nothing to cache. Returning here also avoids launching with a zero-sized
  // grid or block, which is an invalid launch configuration.
  if (key.numel() == 0)
  {
    return;
  }

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
  int block_stride = key_cache.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_FLASH);
}

// Launches vllm::reshape_and_cache_with_per_token_quant_kernel. The `true`
// (asm-layout) instantiation is used when the caller's `asm_layout` is set.
// KV_T is the element type of the incoming key and value tensors.
// CACHE_T is the stored (quantized) element type of key_cache and value_cache.
// dequant_scale_t is the element type of k_dequant_scales and v_dequant_scales.
// Meant to be expanded inside reshape_and_cache_with_pertoken_quant(). It uses
// that function's locals:
//   asm_layout, grid, block, stream, key, value, key_cache, value_cache,
//   k_dequant_scales, v_dequant_scales, slot_mapping, key_stride,
//   value_stride, num_heads, head_size, block_size, x, num_tokens,
//   max_kv_tokens, dtypeMax.
// The do { ... } while (0) wrapper makes `MACRO(...);` exactly one statement,
// so the inner if/else can never pair with an `else` at the call site.
#define CALL_RESHAPE_AND_CACHE_WITH_PERTOKEN_QUANT(KV_T, CACHE_T, dequant_scale_t)              \
  do                                                                                            \
  {                                                                                             \
    if (asm_layout)                                                                             \
    {                                                                                           \
      vllm::reshape_and_cache_with_per_token_quant_kernel<KV_T, CACHE_T, dequant_scale_t, true> \
          <<<grid, block, 0, stream>>>(                                                         \
              reinterpret_cast<KV_T *>(key.data_ptr()),                                         \
              reinterpret_cast<KV_T *>(value.data_ptr()),                                       \
              reinterpret_cast<CACHE_T *>(key_cache.data_ptr()),                                \
              reinterpret_cast<CACHE_T *>(value_cache.data_ptr()),                              \
              reinterpret_cast<dequant_scale_t *>(k_dequant_scales.data_ptr()),                 \
              reinterpret_cast<dequant_scale_t *>(v_dequant_scales.data_ptr()),                 \
              slot_mapping.data_ptr<int64_t>(), key_stride, value_stride,                       \
              num_heads, head_size, block_size, x, num_tokens, max_kv_tokens, dtypeMax);        \
    }                                                                                           \
    else                                                                                        \
    {                                                                                           \
      vllm::reshape_and_cache_with_per_token_quant_kernel<KV_T, CACHE_T, dequant_scale_t>       \
          <<<grid, block, 0, stream>>>(                                                         \
              reinterpret_cast<KV_T *>(key.data_ptr()),                                         \
              reinterpret_cast<KV_T *>(value.data_ptr()),                                       \
              reinterpret_cast<CACHE_T *>(key_cache.data_ptr()),                                \
              reinterpret_cast<CACHE_T *>(value_cache.data_ptr()),                              \
              reinterpret_cast<dequant_scale_t *>(k_dequant_scales.data_ptr()),                 \
              reinterpret_cast<dequant_scale_t *>(v_dequant_scales.data_ptr()),                 \
              slot_mapping.data_ptr<int64_t>(), key_stride, value_stride,                       \
              num_heads, head_size, block_size, x, num_tokens, max_kv_tokens, dtypeMax);        \
    }                                                                                           \
  } while (0)

// Launches vllm::reshape_and_cache_with_block_quant_kernel. The `true`
// (asm-layout) instantiation is used when the caller's `asm_layout` is set.
// KV_T is the element type of the incoming key and value tensors.
// CACHE_T is the stored (quantized) element type of key_cache and value_cache.
// dequant_scale_t is the element type of k_dequant_scales and v_dequant_scales.
// Meant to be expanded inside reshape_and_cache_with_block_quant(). It uses
// that function's locals:
//   asm_layout, grid, block, stream, key, value, key_cache, value_cache,
//   k_dequant_scales, v_dequant_scales, slot_mapping, key_stride,
//   value_stride, num_heads, num_blocks, head_size, block_size, x,
//   num_tokens, seq_len, dtypeMax.
// The do { ... } while (0) wrapper makes `MACRO(...);` exactly one statement,
// so the inner if/else can never pair with an `else` at the call site.
#define CALL_RESHAPE_AND_CACHE_WITH_BLOCK_QUANT(KV_T, CACHE_T, dequant_scale_t)             \
  do                                                                                        \
  {                                                                                         \
    if (asm_layout)                                                                         \
    {                                                                                       \
      vllm::reshape_and_cache_with_block_quant_kernel<KV_T, CACHE_T, dequant_scale_t, true> \
          <<<grid, block, 0, stream>>>(                                                     \
              reinterpret_cast<KV_T *>(key.data_ptr()),                                     \
              reinterpret_cast<KV_T *>(value.data_ptr()),                                   \
              reinterpret_cast<CACHE_T *>(key_cache.data_ptr()),                            \
              reinterpret_cast<CACHE_T *>(value_cache.data_ptr()),                          \
              reinterpret_cast<dequant_scale_t *>(k_dequant_scales.data_ptr()),             \
              reinterpret_cast<dequant_scale_t *>(v_dequant_scales.data_ptr()),             \
              slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, num_heads,        \
              num_blocks, head_size, block_size, x, num_tokens, seq_len, dtypeMax);         \
    }                                                                                       \
    else                                                                                    \
    {                                                                                       \
      vllm::reshape_and_cache_with_block_quant_kernel<KV_T, CACHE_T, dequant_scale_t>       \
          <<<grid, block, 0, stream>>>(                                                     \
              reinterpret_cast<KV_T *>(key.data_ptr()),                                     \
              reinterpret_cast<KV_T *>(value.data_ptr()),                                   \
              reinterpret_cast<CACHE_T *>(key_cache.data_ptr()),                            \
              reinterpret_cast<CACHE_T *>(value_cache.data_ptr()),                          \
              reinterpret_cast<dequant_scale_t *>(k_dequant_scales.data_ptr()),             \
              reinterpret_cast<dequant_scale_t *>(v_dequant_scales.data_ptr()),             \
              slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, num_heads,        \
              num_blocks, head_size, block_size, x, num_tokens, seq_len, dtypeMax);         \
    }                                                                                       \
  } while (0)

// Quantizes `key` / `value` and scatters them into the paged caches through
// vllm::reshape_and_cache_with_per_token_quant_kernel. Dequant scales are
// written through k_dequant_scales / v_dequant_scales, which the parameter
// comments describe as [num_heads, max_kv_tokens].
// Launch: grid = (ceil(num_tokens / 4), num_heads), 256 threads per block.
// Supported combinations:
//   key_cache dtype torch_fp8 -> dtypeMax = FP8_MAX
//   key_cache dtype int8      -> dtypeMax = 127
//   key / value dtype         -> float32, float16 or bfloat16
// asm_layout selects the `true` (asm-layout) kernel instantiation.
// Raises if the scale tensors are not float32 or a dtype combination is
// unsupported. Empty inputs are a no-op.
void reshape_and_cache_with_pertoken_quant(
    torch::Tensor &key,   // [num_tokens, num_heads, head_size]
    torch::Tensor &value, // [num_tokens, num_heads, head_size]
    torch::Tensor &
        key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor &
        value_cache,                 // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor &k_dequant_scales, // [num_heads, max_kv_tokens]
    torch::Tensor &v_dequant_scales, // [num_heads, max_kv_tokens]
    torch::Tensor &slot_mapping,     // [num_tokens]
    const bool asm_layout)
{
  using dequant_scale_t = float; // should align with k_dequant_scales/v_dequant_scales dtype

  // The launch macro reinterpret_casts both scale buffers to dequant_scale_t*.
  // Any other dtype would be misread, and a narrower one accessed out of bounds.
  TORCH_CHECK(k_dequant_scales.scalar_type() == at::ScalarType::Float &&
                  v_dequant_scales.scalar_type() == at::ScalarType::Float,
              "k_dequant_scales and v_dequant_scales must be float32, got ",
              k_dequant_scales.dtype(), " and ", v_dequant_scales.dtype());

  // Nothing to quantize. Returning here also avoids launching with a
  // zero-sized grid, which is an invalid launch configuration.
  if (key.numel() == 0)
  {
    return;
  }

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);
  int max_kv_tokens = k_dequant_scales.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid((num_tokens + 3) / 4, num_heads);
  dim3 block(256);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  float dtypeMax;
  if (key_cache.dtype() == torch_fp8)
  {
    dtypeMax = FP8_MAX;
    if (key.dtype() == at::ScalarType::Float)
    {
      CALL_RESHAPE_AND_CACHE_WITH_PERTOKEN_QUANT(float, hip_fp8, dequant_scale_t);
    }
    else if (key.dtype() == at::ScalarType::Half)
    {
      CALL_RESHAPE_AND_CACHE_WITH_PERTOKEN_QUANT(__half, hip_fp8, dequant_scale_t);
    }
    else if (key.dtype() == at::ScalarType::BFloat16)
    {
      CALL_RESHAPE_AND_CACHE_WITH_PERTOKEN_QUANT(__hip_bfloat16, hip_fp8, dequant_scale_t);
    }
    else
    {
      TORCH_CHECK(false,
                  "Unsupported input type of kv: ", key.dtype());
    }
  }
  else if (key_cache.dtype() == at::ScalarType::Char)
  {
    dtypeMax = 127;
    if (key.dtype() == at::ScalarType::Float)
    {
      CALL_RESHAPE_AND_CACHE_WITH_PERTOKEN_QUANT(float, int8_t, dequant_scale_t);
    }
    else if (key.dtype() == at::ScalarType::Half)
    {
      CALL_RESHAPE_AND_CACHE_WITH_PERTOKEN_QUANT(__half, int8_t, dequant_scale_t);
    }
    else if (key.dtype() == at::ScalarType::BFloat16)
    {
      CALL_RESHAPE_AND_CACHE_WITH_PERTOKEN_QUANT(__hip_bfloat16, int8_t, dequant_scale_t);
    }
    else
    {
      TORCH_CHECK(false,
                  "Unsupported input type of kv: ", key.dtype(), " kv cache: ", key_cache.dtype());
    }
  }
  else
  {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", key_cache.dtype());
  }
}

// Quantizes `key` / `value` and scatters them into the paged caches through
// vllm::reshape_and_cache_with_block_quant_kernel. Dequant scales are written
// through k_dequant_scales / v_dequant_scales, which the parameter comments
// describe as [num_heads, num_blocks].
// Launch:
//   grid      = (batch_size, ceil(seq_len / block_size) + 1, num_heads)
//   blockDim.x = block_size rounded up to a multiple of 256
// Supported combinations:
//   key_cache dtype torch_fp8 -> dtypeMax = FP8_MAX
//   key_cache dtype int8      -> dtypeMax = 127
//   key / value dtype         -> float32, float16 or bfloat16
// asm_layout selects the `true` (asm-layout) kernel instantiation.
// Raises if the scale tensors are not float32 or a dtype combination is
// unsupported. Empty inputs (including seq_len == 0) are a no-op.
void reshape_and_cache_with_block_quant(
    torch::Tensor &key,   // [batch_size, seq_len, num_heads, head_size]
    torch::Tensor &value, // [batch_size, seq_len, num_heads, head_size]
    torch::Tensor &
        key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
    torch::Tensor &
        value_cache,                 // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor &k_dequant_scales, // [num_heads, num_blocks]
    torch::Tensor &v_dequant_scales, // [num_heads, num_blocks]
    torch::Tensor &slot_mapping,     // [num_tokens]
    const bool asm_layout)
{
  using dequant_scale_t = float; // should align with k_dequant_scales/v_dequant_scales dtype

  // The launch macro reinterpret_casts both scale buffers to dequant_scale_t*.
  // Any other dtype would be misread, and a narrower one accessed out of bounds.
  TORCH_CHECK(k_dequant_scales.scalar_type() == at::ScalarType::Float &&
                  v_dequant_scales.scalar_type() == at::ScalarType::Float,
              "k_dequant_scales and v_dequant_scales must be float32, got ",
              k_dequant_scales.dtype(), " and ", v_dequant_scales.dtype());

  // Nothing to quantize. Returning here also prevents the integer division by
  // seq_len below from trapping when seq_len == 0. It further avoids a
  // zero-sized grid launch, which is an invalid configuration.
  if (key.numel() == 0)
  {
    return;
  }

  int batch_size = key.size(0);
  int seq_len = key.size(1);
  int num_heads = key.size(2);
  int head_size = key.size(3);
  int num_blocks = key_cache.size(0);
  int block_size = key_cache.size(3);
  int x = key_cache.size(4);
  int num_tokens = batch_size * seq_len;

  // Per-token row stride. This assumes stride(0) == seq_len * stride(1), i.e.
  // the batch and sequence dims are laid out contiguously relative to each
  // other (NOTE(review): confirm callers guarantee this).
  int key_stride = key.stride(0) / seq_len;
  int value_stride = value.stride(0) / seq_len;
  int blockDimx = (block_size + 255) / 256 * 256;

  // The extra "+ 1" on grid.y presumably covers a sequence that straddles one
  // more cache block than ceil(seq_len / block_size) -- confirm against kernel.
  dim3 grid(batch_size, (seq_len + block_size - 1) / block_size + 1, num_heads);
  dim3 block(blockDimx);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  float dtypeMax;
  if (key_cache.dtype() == torch_fp8)
  {
    dtypeMax = FP8_MAX;
    if (key.dtype() == at::ScalarType::Float)
    {
      CALL_RESHAPE_AND_CACHE_WITH_BLOCK_QUANT(float, hip_fp8, dequant_scale_t);
    }
    else if (key.dtype() == at::ScalarType::Half)
    {
      CALL_RESHAPE_AND_CACHE_WITH_BLOCK_QUANT(__half, hip_fp8, dequant_scale_t);
    }
    else if (key.dtype() == at::ScalarType::BFloat16)
    {
      CALL_RESHAPE_AND_CACHE_WITH_BLOCK_QUANT(__hip_bfloat16, hip_fp8, dequant_scale_t);
    }
    else
    {
      TORCH_CHECK(false,
                  "Unsupported input type of kv: ", key.dtype());
    }
  }
  else if (key_cache.dtype() == at::ScalarType::Char)
  {
    dtypeMax = 127;
    if (key.dtype() == at::ScalarType::Float)
    {
      CALL_RESHAPE_AND_CACHE_WITH_BLOCK_QUANT(float, int8_t, dequant_scale_t);
    }
    else if (key.dtype() == at::ScalarType::Half)
    {
      CALL_RESHAPE_AND_CACHE_WITH_BLOCK_QUANT(__half, int8_t, dequant_scale_t);
    }
    else if (key.dtype() == at::ScalarType::BFloat16)
    {
      CALL_RESHAPE_AND_CACHE_WITH_BLOCK_QUANT(__hip_bfloat16, int8_t, dequant_scale_t);
    }
    else
    {
      TORCH_CHECK(false,
                  "Unsupported input type of kv: ", key.dtype(), " kv cache: ", key_cache.dtype());
    }
  }
  else
  {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", key_cache.dtype());
  }
}

namespace vllm
{

  // Element-wise conversion of a cache buffer via fp8::scaled_convert.
  // Work is split so that blockIdx.x handles the flat element range
  // [blockIdx.x * block_stride, (blockIdx.x + 1) * block_stride). The block's
  // threads step through that range in increments of blockDim.x, so any
  // blockDim.x > 0 covers it.
  // Expected launch: grid.x = number of cache blocks (src_cache.size(0)).
  template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
  __global__ void convert_fp8_kernel(const Tin *__restrict__ src_cache,
                                     Tout *__restrict__ dst_cache,
                                     const float scale,
                                     const int64_t block_stride)
  {
    const int64_t block_idx = blockIdx.x;
    // 64-bit counter: block_stride is int64_t, and an `int` counter would
    // overflow before reaching a stride above INT_MAX.
    for (int64_t i = threadIdx.x; i < block_stride; i += blockDim.x)
    {
      const int64_t idx = block_idx * block_stride + i;
      dst_cache[idx] =
          fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
    }
  }

} // namespace vllm

// Launches vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE>.
// Tin is the element type read from src_cache; Tout is the element type
// written to dst_cache. Uses the expanding scope's grid, block, stream,
// src_cache, dst_cache, scale and block_stride.
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE)                                \
  vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
      reinterpret_cast<Tin *>(src_cache.data_ptr()),                         \
      reinterpret_cast<Tout *>(dst_cache.data_ptr()), scale, block_stride);

// Launches convert_fp8_kernel<Tout, Tin, kv_dt> for the first matching dtype
// pair. The src dtype is tried before the dst dtype; an fp32 / fp16 / bf16
// source converts to a byte cache, otherwise an fp32 / fp16 / bf16 destination
// is filled from a byte cache. Raises if no pair matches.
template <vllm::Fp8KVCacheDataType kv_dt>
static void dispatch_convert_fp8(torch::Tensor &dst_cache,
                                 torch::Tensor &src_cache, const double scale,
                                 const int64_t block_stride, const dim3 grid,
                                 const dim3 block, const cudaStream_t stream)
{
  if (src_cache.dtype() == at::ScalarType::Float)
  {
    CALL_CONVERT_FP8(uint8_t, float, kv_dt);
  }
  else if (src_cache.dtype() == at::ScalarType::Half)
  {
    CALL_CONVERT_FP8(uint8_t, uint16_t, kv_dt);
  }
  else if (src_cache.dtype() == at::ScalarType::BFloat16)
  {
    CALL_CONVERT_FP8(uint8_t, __hip_bfloat16, kv_dt);
  }
  else if (dst_cache.dtype() == at::ScalarType::Float)
  {
    CALL_CONVERT_FP8(float, uint8_t, kv_dt);
  }
  else if (dst_cache.dtype() == at::ScalarType::Half)
  {
    CALL_CONVERT_FP8(uint16_t, uint8_t, kv_dt);
  }
  else if (dst_cache.dtype() == at::ScalarType::BFloat16)
  {
    CALL_CONVERT_FP8(__hip_bfloat16, uint8_t, kv_dt);
  }
  else
  {
    // Previously an unmatched pair silently launched nothing.
    TORCH_CHECK(false, "Unsupported dtype combination for convert_fp8: src ",
                src_cache.dtype(), ", dst ", dst_cache.dtype());
  }
}

// Only for testing.
// Converts src_cache into dst_cache element-wise on src's GPU, with one thread
// block per src_cache.size(0) slice and up to 512 threads each.
//   scale          : forwarded to fp8::scaled_convert
//   kv_cache_dtype : "auto" (kAuto) or "fp8" / "fp8_e4m3" (kFp8E4M3)
// Both tensors must be on the same GPU and hold the same number of elements,
// because dst is written at the same flat offsets that src is read from.
// Empty tensors are a no-op.
void convert_fp8(torch::Tensor &dst_cache, torch::Tensor &src_cache,
                 const double scale, const std::string &kv_cache_dtype)
{
  torch::Device src_device = src_cache.device();
  torch::Device dst_device = dst_cache.device();
  TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU");
  TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU");
  TORCH_CHECK(src_device.index() == dst_device.index(),
              "src and dst must be on the same GPU");

  const bool is_auto = kv_cache_dtype == "auto";
  const bool is_fp8_e4m3 =
      kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3";
  TORCH_CHECK(is_auto || is_fp8_e4m3, "Unsupported data type: ", kv_cache_dtype);

  // The kernel writes dst at the same flat indices it reads from src.
  TORCH_CHECK(dst_cache.numel() == src_cache.numel(),
              "src and dst must have the same number of elements, got ",
              src_cache.numel(), " and ", dst_cache.numel());

  at::cuda::OptionalCUDAGuard device_guard(src_device);

  // Nothing to convert. Returning here also avoids a zero-sized grid or block,
  // which is an invalid launch configuration.
  if (src_cache.numel() == 0)
  {
    return;
  }

  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (is_auto)
  {
    dispatch_convert_fp8<vllm::Fp8KVCacheDataType::kAuto>(
        dst_cache, src_cache, scale, block_stride, grid, block, stream);
  }
  else
  {
    dispatch_convert_fp8<vllm::Fp8KVCacheDataType::kFp8E4M3>(
        dst_cache, src_cache, scale, block_stride, grid, block, stream);
  }
}