quant_utils.cuh 18.3 KB
Newer Older
1
#pragma once
2
#include <hip/hip_fp8.h>
3
4
5
6
7

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

8
#include "../../../attention/attention_dtypes.h"
9

10
namespace vllm {
11
12
13
#ifdef USE_ROCM

namespace fp8 {
14
  #ifdef ENABLE_FP8
15

16
// Generic fallback: copy-initializes a Tout from the input, i.e. relies on an
// implicit Tin -> Tout conversion. Concrete fp8 type pairs are handled by the
// explicit specializations in this header.
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
  Tout converted = x;
  return converted;
}

// Generic fallback for scaled conversions: ignores `scale` and copy-initializes
// a Tout from the input. Concrete fp8 type pairs are handled by the explicit
// specializations in this header.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
                                                 const float scale) {
  (void)scale;  // unused by the fallback
  Tout converted = x;
  return converted;
}

27
    // Scalar and packed-pair E4M3 storage types used by every conversion
    // below: the plain e4m3 types when HIP_FP8_TYPE_OCP is set, otherwise the
    // *_fnuz variants.
    #if HIP_FP8_TYPE_OCP
typedef __hip_fp8_e4m3 fp8_type;
typedef __hip_fp8x2_e4m3 fp8x2_type;
    #else
typedef __hip_fp8_e4m3_fnuz fp8_type;
typedef __hip_fp8x2_e4m3_fnuz fp8x2_type;
    #endif

35
36
// fp8 -> half: decodes one fp8 byte and returns the raw 16-bit half pattern.
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
  const __half_raw decoded =
      __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret);
  return decoded.x;
}

// fp8x2 -> half2: decodes two packed fp8 values and returns the resulting
// pair of raw half bit patterns packed into a uint32_t.
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
  // Reinterpret the __half2_raw result as a plain 32-bit word.
  union {
    uint32_t bits;
    __half2_raw halves;
  } pun;
  pun.halves = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
  return pun.bits;
}

// fp8x4 -> half2x2: the low 16 bits of `a` (two fp8 values) produce .x, the
// high 16 bits produce .y.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
  const uint16_t lo = static_cast<uint16_t>(a);
  const uint16_t hi = static_cast<uint16_t>(a >> 16U);
  uint2 out;
  out.x = vec_conversion<uint32_t, uint16_t>(lo);
  out.y = vec_conversion<uint32_t, uint16_t>(hi);
  return out;
}

// fp8x8 -> half2x4: a.x (four fp8 values) fills out.x/out.y, a.y fills
// out.z/out.w.
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
  const uint2 lo = vec_conversion<uint2, uint32_t>(a.x);
  const uint2 hi = vec_conversion<uint2, uint32_t>(a.y);
  uint4 out;
  out.x = lo.x;
  out.y = lo.y;
  out.z = hi.x;
  out.w = hi.y;
  return out;
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16: widen the fp8 byte to float, then convert to bf16.
template <>
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
  fp8_type value;
  value.__x = a;
  const float widened = static_cast<float>(value);
  return __float2bfloat16(widened);
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162: the low byte of `a` becomes .x, the high byte .y.
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
  const uint8_t lo = static_cast<uint8_t>(a);
  const uint8_t hi = static_cast<uint8_t>(a >> 8U);
  __nv_bfloat162 out;
  out.x = vec_conversion<__nv_bfloat16, uint8_t>(lo);
  out.y = vec_conversion<__nv_bfloat16, uint8_t>(hi);
  return out;
}

// fp8x4 -> bf16_4_t: the low 16 bits of `a` produce .x, the high 16 bits .y.
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
  const uint16_t lo = static_cast<uint16_t>(a);
  const uint16_t hi = static_cast<uint16_t>(a >> 16U);
  bf16_4_t out;
  out.x = vec_conversion<__nv_bfloat162, uint16_t>(lo);
  out.y = vec_conversion<__nv_bfloat162, uint16_t>(hi);
  return out;
}

// fp8x8 -> bf16_8_t: a.x supplies lanes x/y of the result, a.y lanes z/w.
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
  const bf16_4_t lo = vec_conversion<bf16_4_t, uint32_t>(a.x);
  const bf16_4_t hi = vec_conversion<bf16_4_t, uint32_t>(a.y);
  bf16_8_t out;
  out.x = lo.x;
  out.y = lo.y;
  out.z = hi.x;
  out.w = hi.y;
  return out;
}

// fp8 -> float: reinterpret the byte as fp8_type and widen it to float.
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
  fp8_type value;
  value.__x = a;
  const float widened = static_cast<float>(value);
  return widened;
}

// fp8x2 -> float2: reinterpret the two packed bytes as fp8x2_type and widen.
template <>
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
  fp8x2_type pair;
  pair.__x = a;
  const float2 widened = static_cast<float2>(pair);
  return widened;
}

// fp8x4 -> Float4_: the low 16 bits of `a` produce .x, the high 16 bits .y.
template <>
__inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
  const uint16_t lo = static_cast<uint16_t>(a);
  const uint16_t hi = static_cast<uint16_t>(a >> 16U);
  Float4_ out;
  out.x = vec_conversion<float2, uint16_t>(lo);
  out.y = vec_conversion<float2, uint16_t>(hi);
  return out;
}

152
153
154
155
156
157
158
159
160
// fp8x4 -> float4 (HIP vector type): decode via the Float4_ specialization and
// flatten its two float2 halves into one float4.
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
  const Float4_ pairs = vec_conversion<Float4_, uint32_t>(a);
  return make_float4(pairs.x.x, pairs.x.y, pairs.y.x, pairs.y.y);
}

161
162
// fp8x8 -> Float8_: a.x supplies lanes x/y of the result, a.y lanes z/w.
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
  const Float4_ lo = vec_conversion<Float4_, uint32_t>(a.x);
  const Float4_ hi = vec_conversion<Float4_, uint32_t>(a.y);
  Float8_ out;
  out.x = lo.x;
  out.y = lo.y;
  out.z = hi.x;
  out.w = hi.y;
  return out;
}

// half -> fp8 (unscaled): `a` is a raw half bit pattern; it is encoded with the
// fp8 type's default saturation and interpretation.
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
  __half_raw h;
  h.x = a;
  const uint8_t encoded = __hip_cvt_halfraw_to_fp8(
      h, fp8_type::__default_saturation, fp8_type::__default_interpret);
  return encoded;
}
184

185
186
187
188
189
190
191
192
193
194
// half2 -> fp8x2 (unscaled): `a` carries two raw half bit patterns; both are
// encoded with the fp8 type's default saturation and interpretation.
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
  // Reinterpret the 32-bit word as a __half2_raw pair.
  union {
    __half2_raw halves;
    uint32_t bits;
  } pun;
  pun.bits = a;
  const uint16_t encoded = __hip_cvt_halfraw2_to_fp8x2(
      pun.halves, fp8_type::__default_saturation,
      fp8_type::__default_interpret);
  return encoded;
}

// bf16 -> fp8 (unscaled): widen bf16 to float, then encode as fp8.
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
  const float widened = __bfloat162float(a);
  return __hip_cvt_float_to_fp8(widened, fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
}

// float -> fp8 (unscaled): encode with the fp8 type's default saturation and
// interpretation.
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
  const uint8_t encoded = __hip_cvt_float_to_fp8(
      a, fp8_type::__default_saturation, fp8_type::__default_interpret);
  return encoded;
}

// float2 -> half2: round both lanes to half (__float22half2_rn) and return the
// result's bits as a uint32_t.
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
  union {
    half2 packed;
    uint32_t bits;
  } pun;
  pun.packed = __float22half2_rn(a);
  return pun.bits;
}

// Float4_ -> half2x2: each float2 half of `a` becomes one packed half2 word.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
  uint2 out;
  out.x = vec_conversion<uint32_t, float2>(a.x);
  out.y = vec_conversion<uint32_t, float2>(a.y);
  return out;
}

// Float4_ -> float4: flatten the two float2 halves into one float4.
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
  return make_float4(a.x.x, a.x.y, a.y.x, a.y.y);
}

// Float8_ -> half2x4: each float2 lane of `a` becomes one packed half2 word.
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
  uint4 packed;
  packed.x = vec_conversion<uint32_t, float2>(a.x);
  packed.y = vec_conversion<uint32_t, float2>(a.y);
  packed.z = vec_conversion<uint32_t, float2>(a.z);
  packed.w = vec_conversion<uint32_t, float2>(a.w);
  return packed;
}

// float2 -> bfloat162: convert both lanes with __float22bfloat162_rn.
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
  return __float22bfloat162_rn(a);
}

// Float4_ -> bf16_4_t: convert each float2 half with __float22bfloat162_rn.
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
  bf16_4_t out;
  out.x = __float22bfloat162_rn(a.x);
  out.y = __float22bfloat162_rn(a.y);
  return out;
}

// Float8_ -> bf16_8_t: convert each float2 lane with __float22bfloat162_rn.
template <>
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
  bf16_8_t out;
  out.x = __float22bfloat162_rn(a.x);
  out.y = __float22bfloat162_rn(a.y);
  out.z = __float22bfloat162_rn(a.z);
  out.w = __float22bfloat162_rn(a.w);
  return out;
}

293
294
/* Scaled and vectorized conversions, for data exchange between the high- and
   low-precision domains.

   Scale convention used by this API:
     FP8_data            = Quantize(High_Precision_data / scale)
     High_Precision_data ~= Dequantize(FP8_data) * scale
   i.e. quantizing divides by `scale` and dequantizing multiplies by it.
 */

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16 (scaled): dequantize as float(fp8) * scale, then
// convert to bf16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
  fp8_type value;
  value.__x = a;
  const float dequantized = static_cast<float>(value) * scale;
  return __float2bfloat16(dequantized);
}

// fp8x2 -> __nv_bfloat162 (scaled): the low byte of `a` becomes .x, the high
// byte .y; both are multiplied by `scale`.
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
                                                float scale) {
  const uint8_t lo = static_cast<uint8_t>(a);
  const uint8_t hi = static_cast<uint8_t>(a >> 8U);
  __nv_bfloat162 out;
  out.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>(lo, scale);
  out.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>(hi, scale);
  return out;
}

// fp8x4 -> bf16_4_t (scaled): the low 16 bits of `a` produce .x, the high 16
// bits .y.
template <>
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
  const uint16_t lo = static_cast<uint16_t>(a);
  const uint16_t hi = static_cast<uint16_t>(a >> 16U);
  bf16_4_t out;
  out.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>(lo, scale);
  out.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>(hi, scale);
  return out;
}

// fp8x8 -> bf16_8_t (scaled): a.x supplies lanes x/y of the result, a.y
// lanes z/w.
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
  const bf16_4_t lo = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  const bf16_4_t hi = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  bf16_8_t out;
  out.x = lo.x;
  out.y = lo.y;
  out.z = hi.x;
  out.w = hi.y;
  return out;
}

// fp8 -> float (scaled): dequantize as float(fp8) * scale.
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, float scale) {
  fp8_type value;
  value.__x = a;
  const float widened = static_cast<float>(value);
  return widened * scale;
}

// fp8x2 -> float2 (scaled): widen both packed fp8 values, then multiply each
// lane by `scale`.
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
  fp8x2_type pair;
  pair.__x = a;
  const float2 widened = static_cast<float2>(pair);
  return widened * scale;
}

// fp8x4 -> Float4_ (scaled): the low 16 bits of `a` produce .x, the high 16
// bits .y.
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
  const uint16_t lo = static_cast<uint16_t>(a);
  const uint16_t hi = static_cast<uint16_t>(a >> 16U);
  Float4_ out;
  out.x = scaled_vec_conversion<float2, uint16_t>(lo, scale);
  out.y = scaled_vec_conversion<float2, uint16_t>(hi, scale);
  return out;
}

379
380
381
382
383
384
385
386
// fp8x4 -> float4 (HIP vector type, scaled): decode via the Float4_
// specialization and flatten its two float2 halves.
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
  const Float4_ pairs = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  return make_float4(pairs.x.x, pairs.x.y, pairs.y.x, pairs.y.y);
}

387
388
// fp8x8 -> Float8_ (scaled): a.x supplies lanes x/y of the result, a.y
// lanes z/w.
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
  const Float4_ lo = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  const Float4_ hi = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
  Float8_ out;
  out.x = lo.x;
  out.y = lo.y;
  out.z = hi.x;
  out.w = hi.y;
  return out;
}

402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
// fp8 -> half (scaled): dequantize to float (fp8 * scale), convert to half via
// the _Float16 member, and return the raw 16-bit pattern.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
  const float dequantized = scaled_vec_conversion<float, uint8_t>(a, scale);
  __half_raw h;
  h.data = dequantized;
  return h.x;
}

// fp8x2 -> half2 (scaled): decode both packed fp8 values to half, multiply
// each lane by `scale`, and return the pair's raw bits as a uint32_t.
// `_Float16 *= float` is evaluated in float and rounded back to half.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
  // Reinterpret the __half2_raw pair as a plain 32-bit word.
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
  // Decode once, directly into the union.
  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
  tmp.h2r.x.data *= scale;
  tmp.h2r.y.data *= scale;
  return tmp.ui32;
}

// fp8x4 -> half2x2 (scaled): the low 16 bits of `a` produce .x, the high 16
// bits .y.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
  const uint16_t lo = static_cast<uint16_t>(a);
  const uint16_t hi = static_cast<uint16_t>(a >> 16U);
  uint2 out;
  out.x = scaled_vec_conversion<uint32_t, uint16_t>(lo, scale);
  out.y = scaled_vec_conversion<uint32_t, uint16_t>(hi, scale);
  return out;
}
440

441
442
443
444
445
446
447
448
449
450
451
452
// fp8x8 -> half2x4 (scaled): a.x fills out.x/out.y, a.y fills out.z/out.w.
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
                                                                float scale) {
  const uint2 lo = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  const uint2 hi = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  uint4 out;
  out.x = lo.x;
  out.y = lo.y;
  out.z = hi.x;
  out.w = hi.y;
  return out;
}
453
454
455

// half -> fp8 (scaled): quantize as fp8(half / scale). `a` is a raw half bit
// pattern.
// The quotient is computed in float and rounded to fp8 exactly once, like the
// bf16 -> fp8 and float -> fp8 paths. Dividing in place on the _Float16 value
// would first round the quotient to half precision (and can exceed half's
// finite range for small scales) before the fp8 rounding, i.e. double
// rounding.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
  __half_raw tmp;
  tmp.x = a;
  const float scaled = static_cast<float>(tmp.data) / scale;
  return __hip_cvt_float_to_fp8(scaled, fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
}
464

465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
// halfx2 -> fp8x2 (scaled): `a` carries two raw half bit patterns; each lane
// is quantized as fp8(half / scale).
// Both quotients are computed in float and rounded to fp8 exactly once, as in
// the float2 -> fp8x2 path. Dividing the _Float16 lanes in place would round
// to half first (double rounding, possible half-range overflow).
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
  // Reinterpret the 32-bit word as a __half2_raw pair.
  union {
    uint32_t ui32;
    __half2_raw h2r;
  } tmp;
  tmp.ui32 = a;
  float2 scaled;
  scaled.x = static_cast<float>(tmp.h2r.x.data) / scale;
  scaled.y = static_cast<float>(tmp.h2r.y.data) / scale;
  return __hip_cvt_float2_to_fp8x2(scaled, fp8_type::__default_saturation,
                                   fp8_type::__default_interpret);
}

// half2x2 -> fp8x4 (scaled): a.x yields the first fp8x2 pair and a.y the
// second, laid out in memory order in the returned word.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
  union {
    uint16_t pairs[2];
    uint32_t word;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
  out.pairs[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
  return out.word;
}

// half2x4 -> fp8x8 (scaled): lanes x/y of `a` produce res.x, lanes z/w
// produce res.y.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
                                                                float scale) {
  uint2 lo;
  lo.x = a.x;
  lo.y = a.y;
  uint2 hi;
  hi.x = a.z;
  hi.y = a.w;
  uint2 res;
  res.x = scaled_vec_conversion<uint32_t, uint2>(lo, scale);
  res.y = scaled_vec_conversion<uint32_t, uint2>(hi, scale);
  return res;
}

// bf16 -> fp8 (scaled): widen to float, divide by `scale`, encode as fp8.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, float scale) {
  const float quotient = __bfloat162float(a) / scale;
  return __hip_cvt_float_to_fp8(quotient, fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
}

// bf16x2 -> fp8x2 (scaled): a.x is encoded into the first byte (memory
// order), a.y into the second.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
    const __nv_bfloat162& a, float scale) {
  union {
    uint8_t bytes[2];
    uint16_t half_word;
  } out;
  out.bytes[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
  out.bytes[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
  return out.half_word;
}

// bf16x4 -> fp8x4 (scaled): a.x yields the first fp8x2 pair and a.y the
// second, in memory order.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
  union {
    uint16_t pairs[2];
    uint32_t word;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
  out.pairs[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
  return out.word;
}

// bf16x8 -> fp8x8 (scaled): lanes x/y of `a` produce res.x, lanes z/w res.y.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
  const bf16_4_t lo = {a.x, a.y};
  const bf16_4_t hi = {a.z, a.w};
  uint2 res;
  res.x = scaled_vec_conversion<uint32_t, bf16_4_t>(lo, scale);
  res.y = scaled_vec_conversion<uint32_t, bf16_4_t>(hi, scale);
  return res;
}

// float -> fp8 (scaled): quantize as fp8(a / scale).
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
  const float quotient = a / scale;
  return __hip_cvt_float_to_fp8(quotient, fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
}

561
// floatx2 -> fp8x2 (scaled): quantize both lanes as fp8(a / scale).
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
  const float2 quotient = a / scale;
  return __hip_cvt_float2_to_fp8x2(quotient, fp8_type::__default_saturation,
                                   fp8_type::__default_interpret);
}

// floatx4 -> fp8x4 (scaled): lanes x/y yield the first fp8x2 pair and lanes
// z/w the second, in memory order.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
  const float2 lo = {a.x, a.y};
  const float2 hi = {a.z, a.w};
  union {
    uint16_t pairs[2];
    uint32_t word;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, float2>(lo, scale);
  out.pairs[1] = scaled_vec_conversion<uint16_t, float2>(hi, scale);
  return out.word;
}
581
  #endif  // ENABLE_FP8
582

583
// Unscaled conversion dispatched on the KV-cache data type.
// Only kFp8E4M3 (with ENABLE_FP8 defined) forwards to vec_conversion; every
// other instantiation asserts at runtime and returns a value-initialized Tout.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
  #ifdef ENABLE_FP8
  constexpr bool kIsE4M3 = (kv_dt == Fp8KVCacheDataType::kFp8E4M3);
  if constexpr (kIsE4M3) {
    return vec_conversion<Tout, Tin>(x);
  }
  #endif
  // No conversion exists for this cache type (or FP8 support is disabled).
  assert(false);
  return {};  // keeps the compiler from warning about a missing return
}
593
594

// Scaled conversion dispatched on the KV-cache data type.
// Only kFp8E4M3 (with ENABLE_FP8 defined) forwards to scaled_vec_conversion;
// every other instantiation asserts at runtime and returns a value-initialized
// Tout.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  #ifdef ENABLE_FP8
  constexpr bool kIsE4M3 = (kv_dt == Fp8KVCacheDataType::kFp8E4M3);
  if constexpr (kIsE4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale);
  }
  #endif
  // No conversion exists for this cache type (or FP8 support is disabled).
  assert(false);
  return {};  // keeps the compiler from warning about a missing return
}

605
606
607
608
609
610
  // Dispatches on the source dtype and the KV-cache dtype string. FN must be a
  // macro that invokes a function templated as
  //   template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
  // "auto": cache_t is the same as scalar_t, kv_dt = kAuto.
  // "fp8" / "fp8_e4m3": cache_t is uint8_t, kv_dt = kFp8E4M3.
  // Any other source dtype or cache dtype fails a TORCH_CHECK.
  #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \
    if (KV_DTYPE == "auto") {                                                  \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
      }                                                                        \
    } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                  \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);             \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);        \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
      }                                                                        \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);     \
    }
636

637
638
639
}  // namespace fp8
#endif  // USE_ROCM
}  // namespace vllm