// quant_utils.cuh
#pragma once

#include "../../../attention/attention_dtypes.h"
#include <assert.h>
#include <float.h>
#include <stdint.h>
#include <type_traits>

namespace vllm {
#ifndef USE_ROCM

namespace fp8 {
  #ifdef ENABLE_FP8
// Primary template for unscaled conversions between fp8 storage and
// high-precision types. The generic version performs no fp8 decoding at all:
// it simply returns `x` implicitly converted to Tout and ignores `fp8_type`.
// Type pairs that need a real conversion provide explicit specializations.
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(
    const Tin& x, const __nv_fp8_interpretation_t fp8_type = __NV_E4M3) {
  return x;
}

// float -> c10::Float8_e4m3fn
//
// The destination type fixes the encoding to E4M3, so the value is always
// encoded as E4M3. `fp8_type` is accepted only so the signature matches the
// primary template. Encoding E5M2 bits into an e4m3fn container would make the
// stored byte decode to a different value.
//
// On SM80+ (and in the host pass) the conversion uses __nv_cvt_float_to_fp8
// with __NV_SATFINITE. Below SM80 it falls back to c10's own
// float -> Float8_e4m3fn conversion.
template <>
__inline__ __device__ c10::Float8_e4m3fn
vec_conversion<c10::Float8_e4m3fn, float>(
    const float& a, const __nv_fp8_interpretation_t fp8_type) {
  (void)fp8_type;  // Encoding is dictated by the return type (E4M3).
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  return static_cast<c10::Float8_e4m3fn>(a);
    #else
  return c10::Float8_e4m3fn(__nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E4M3),
                            c10::Float8_e4m3fn::from_bits());
    #endif
}

    #if 0  // Disable the following code to reduce the binary size.
// fp8 -> half
//
// Decodes one fp8 byte, interpreted according to `fp8_type`, and returns the
// raw fp16 bit pattern.
template <>
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(
    const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return res.x;
}

// fp8x2 -> half2
//
// Decodes two packed fp8 bytes into two fp16 values and returns them packed in
// one 32-bit word: halves.x occupies the low 16 bits, halves.y the high 16 bits.
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  const __half2_raw halves = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  const uint32_t low_bits = static_cast<uint32_t>(halves.x);
  const uint32_t high_bits = static_cast<uint32_t>(halves.y);
  return low_bits | (high_bits << 16);
}

// fp8x4 -> half2x2
//
// Splits the four packed fp8 bytes into two 16-bit pairs and decodes each pair
// into a packed half2. The low pair of `a` lands in `.x`, the high pair in `.y`.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);
  uint2 out;
  out.x = vec_conversion<uint32_t, uint16_t>(low_pair, fp8_type);
  out.y = vec_conversion<uint32_t, uint16_t>(high_pair, fp8_type);
  return out;
}

// fp8x8 -> half2x4
//
// Decodes the two 32-bit words of `a` independently. `a.x` fills the first two
// lanes of the result and `a.y` the last two.
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  const uint2 first = vec_conversion<uint2, uint32_t>(a.x, fp8_type);
  const uint2 second = vec_conversion<uint2, uint32_t>(a.y, fp8_type);
  uint4 out;
  out.x = first.x;
  out.y = first.y;
  out.z = second.x;
  out.w = second.y;
  return out;
}

// fp8 -> __nv_bfloat16
//
// No direct fp8 -> bf16 conversion is used here. The byte is decoded to fp16,
// widened to fp32 and then narrowed to bf16.
template <>
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t &a, const __nv_fp8_interpretation_t fp8_type) {
  const __half_raw as_half = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  const float as_float = half_to_float(as_half.x);
  return __float2bfloat16(as_float);
}

// fp8x2 -> __nv_bfloat162
//
// The low byte of `a` becomes `.x` and the high byte becomes `.y`. Each byte
// goes through the scalar fp8 -> bf16 conversion.
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  const uint8_t low_byte = static_cast<uint8_t>(a);
  const uint8_t high_byte = static_cast<uint8_t>(a >> 8U);
  __nv_bfloat162 out;
  out.x = vec_conversion<__nv_bfloat16, uint8_t>(low_byte, fp8_type);
  out.y = vec_conversion<__nv_bfloat16, uint8_t>(high_byte, fp8_type);
  return out;
}

// fp8x4 -> bf16_4_t
//
// The low 16 bits of `a` fill `.x` and the high 16 bits fill `.y`. Each is
// decoded as an fp8 pair.
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);
  bf16_4_t out;
  out.x = vec_conversion<__nv_bfloat162, uint16_t>(low_pair, fp8_type);
  out.y = vec_conversion<__nv_bfloat162, uint16_t>(high_pair, fp8_type);
  return out;
}

// fp8x8 -> bf16_8_t
//
// `a.x` is decoded into the first two bf16 pairs and `a.y` into the last two.
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  const bf16_4_t front = vec_conversion<bf16_4_t, uint32_t>(a.x, fp8_type);
  const bf16_4_t back = vec_conversion<bf16_4_t, uint32_t>(a.y, fp8_type);
  bf16_8_t out;
  out.x = front.x;
  out.y = front.y;
  out.z = back.x;
  out.w = back.y;
  return out;
}

// fp8 -> float
//
// Decodes to fp16 bits via the fp8 -> half specialization, then widens to fp32.
template <>
__inline__ __device__ float
vec_conversion<float, uint8_t>(const uint8_t &a,
                               const __nv_fp8_interpretation_t fp8_type) {
  return half_to_float(vec_conversion<uint16_t, uint8_t>(a, fp8_type));
}

// fp8x2 -> float2
//
// Decodes the fp8 pair to a packed half2, then widens both lanes to fp32.
template <>
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  return half2_to_float2(vec_conversion<uint32_t, uint16_t>(a, fp8_type));
}

// fp8x4 -> float4 (as Float4_)
//
// The low 16 bits of `a` become `.x`; the high 16 bits become `.y`.
template <>
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);
  Float4_ out;
  out.x = vec_conversion<float2, uint16_t>(low_pair, fp8_type);
  out.y = vec_conversion<float2, uint16_t>(high_pair, fp8_type);
  return out;
}

// fp8x8 -> float8 (as Float8_)
//
// `a.x` produces the first four floats and `a.y` the last four.
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(
    const uint2 &a, const __nv_fp8_interpretation_t fp8_type) {
  const Float4_ front = vec_conversion<Float4_, uint32_t>(a.x, fp8_type);
  const Float4_ back = vec_conversion<Float4_, uint32_t>(a.y, fp8_type);
  Float8_ out;
  out.x = front.x;
  out.y = front.y;
  out.z = back.x;
  out.w = back.y;
  return out;
}

// half -> fp8
//
// Treats `a` as raw fp16 bits and encodes it as fp8 with __NV_SATFINITE
// saturation.
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(
    const uint16_t &a, const __nv_fp8_interpretation_t fp8_type) {
  __half_raw in_half;
  in_half.x = a;
  return static_cast<uint8_t>(
      __nv_cvt_halfraw_to_fp8(in_half, __NV_SATFINITE, fp8_type));
}

// bf16 -> fp8
//
// Encodes `a` as fp8 with __NV_SATFINITE saturation. Below SM80 this
// conversion is not supported: that path only executes assert(false) and is
// not expected to be reached.
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16 &a, const __nv_fp8_interpretation_t fp8_type) {
      #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
      #else
  __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(
      __nv_bfloat16_raw(a), __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
      #endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// float -> fp8
//
// Encodes `a` as fp8 with __NV_SATFINITE saturation and returns the raw byte.
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(
    const float &a, const __nv_fp8_interpretation_t fp8_type) {
  return static_cast<uint8_t>(
      __nv_cvt_float_to_fp8(a, __NV_SATFINITE, fp8_type));
}

// fp8x4 -> float4
//
// Same as the Float4_ variant, but flattened into a native float4.
template <>
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(
    const uint32_t &a, const __nv_fp8_interpretation_t fp8_type) {
  const Float4_ pairs = vec_conversion<Float4_, uint32_t>(a, fp8_type);
  float4 out;
  out.x = pairs.x.x;
  out.y = pairs.x.y;
  out.z = pairs.y.x;
  out.w = pairs.y.y;
  return out;
}

// float2 -> half2 bit pattern (no fp8 involved; `fp8_type` is unused).
//
// Rounds both lanes to fp16 with __float22half2_rn and returns the packed
// 32-bit representation.
template <>
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(
    const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
  union {
    half2 as_half2;
    uint32_t as_bits;
  } pun;
  pun.as_half2 = __float22half2_rn(a);
  return pun.as_bits;
}

// Float4_ -> two packed half2 words.
//
// Each float2 half of `a` is rounded to a packed half2 by the float2 -> uint32
// specialization.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  uint2 packed;
  packed.x = vec_conversion<uint32_t, float2>(a.x, fp8_type);
  packed.y = vec_conversion<uint32_t, float2>(a.y, fp8_type);
  return packed;
}

// Float4_ -> float4: a pure re-layout of the four floats (`fp8_type` unused).
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  return make_float4(a.x.x, a.x.y, a.y.x, a.y.y);
}

// Float8_ -> four packed half2 words, one per float2 lane of `a`.
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(
    const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
  return make_uint4(vec_conversion<uint32_t, float2>(a.x, fp8_type),
                    vec_conversion<uint32_t, float2>(a.y, fp8_type),
                    vec_conversion<uint32_t, float2>(a.z, fp8_type),
                    vec_conversion<uint32_t, float2>(a.w, fp8_type));
}

// float2 -> __nv_bfloat162, delegating to the project's from_float helper
// (`fp8_type` is unused).
template <>
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(
    const float2 &a, const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 out;
  from_float(out, a);
  return out;
}

// Float4_ -> bf16_4_t, delegating to the project's from_float helper
// (`fp8_type` is unused).
template <>
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(
    const Float4_ &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t out;
  from_float(out, a);
  return out;
}

// Float8_ -> bf16_8_t, delegating to the project's from_float helper
// (`fp8_type` is unused).
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(
    const Float8_ &a, const __nv_fp8_interpretation_t fp8_type) {
  bf16_8_t b;
  from_float(b, a);
  return b;
}
    #endif
/* Scaled and vectorized conversions, for exchanging data between the high- and
   low-precision domains.

   Scale convention used throughout this API:
     quantize:    FP8 = Quantize(HP / scale)
     dequantize:  HP  = Dequantize(FP8) * scale
 */

// Primary template. The generic version returns `x` implicitly converted to
// Tout and ignores both `scale` and `fp8_type`. Type pairs that need a real
// conversion provide explicit specializations.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(
    const Tin& x, const float scale, const __nv_fp8_interpretation_t fp8_type) {
  return x;
}

// fp8 -> half, scaled.
//
// Decodes `a` to fp16 and multiplies by `scale` in fp32. The product is
// converted back to fp16 bits with float_to_half.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __half_raw tmp = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  return float_to_half(half_to_float(tmp.x) * scale);
}

// fp8x2 -> half2, scaled.
//
// Both bytes are decoded to fp16, scaled in fp32 and converted back to fp16
// with float_to_half. The results are packed with res.x in u16[0] and res.y in
// u16[1].
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint16_t u16[2];
    uint32_t u32;
  } tmp;
  __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  tmp.u16[0] = float_to_half(half_to_float(res.x) * scale);
  tmp.u16[1] = float_to_half(half_to_float(res.y) * scale);
  return tmp.u32;
}

// fp8x4 -> half2x2, scaled.
//
// The low 16 bits of `a` produce u32[0] and the high 16 bits produce u32[1].
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] =
      scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)a, scale, fp8_type);
  tmp.u32[1] = scaled_vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U),
                                                         scale, fp8_type);
  return tmp.u32x2;
}

// fp8x8 -> half2x4, scaled.
//
// `a.x` fills the first two 32-bit lanes of the result and `a.y` the last two.
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale,
                                    const __nv_fp8_interpretation_t fp8_type) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, fp8_type);
  tmp.u64[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, fp8_type);
  return tmp.u64x2;
}

// fp8 -> __nv_bfloat16, scaled.
//
// No direct fp8 -> bf16 conversion is used here. The byte is decoded to fp16,
// widened to fp32, multiplied by `scale` and narrowed to bf16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  // half -> float -> bf16
  float tmp = half_to_float(res.x);
  return __float2bfloat16(tmp * scale);
}

// fp8x2 -> __nv_bfloat162, scaled.
//
// The low byte of `a` becomes `.x` and the high byte becomes `.y`.
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale,
                                                        fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U),
                                                        scale, fp8_type);
  return res;
}

// fp8x4 -> bf16_4_t, scaled.
//
// The low 16 bits of `a` become `.x` and the high 16 bits become `.y`.
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale,
                                                          fp8_type);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale, fp8_type);
  return res;
}

// fp8x8 -> bf16_8_t, scaled.
//
// `a.x` produces the first two bf16 pairs and `a.y` the last two.
template <>
__inline__ __device__ bf16_8_t scaled_vec_conversion<bf16_8_t, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, fp8_type);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// fp8 -> float, scaled.
//
// Decodes `a` to fp16, widens to fp32 and applies `scale` in fp32.
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  // fp8 -> half
  __half_raw res = __nv_cvt_fp8_to_halfraw(a, fp8_type);
  uint16_t tmp = res.x;

  // half -> float
  return half_to_float(tmp) * scale;
}

// fp8x2 -> float2, scaled.
//
// Both bytes are decoded to fp16 without scaling. Every E4M3/E5M2 value is
// exactly representable in fp16, so this step is lossless. The scale is then
// applied in fp32, matching the scalar fp8 -> float path.
//
// The previous implementation scaled before widening, i.e. it rounded the
// product to fp16 first. That double-rounded, and it overflowed to inf whenever
// |value * scale| exceeded the fp16 range.
// The low byte of `a` becomes `.x` and the high byte becomes `.y`.
template <>
__inline__ __device__ float2 scaled_vec_conversion<float2, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  const __half2_raw halves = __nv_cvt_fp8x2_to_halfraw2(a, fp8_type);
  float2 res;
  res.x = half_to_float(halves.x) * scale;
  res.y = half_to_float(halves.y) * scale;
  return res;
}

// fp8x4 -> float4 (as Float4_), scaled.
//
// Built from two scaled fp8x2 -> float2 conversions. The low 16 bits of `a`
// become `.x` and the high 16 bits become `.y`.
template <>
__inline__ __device__ Float4_ scaled_vec_conversion<Float4_, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, fp8_type);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale,
                                                  fp8_type);
  return res;
}

// fp8x8 -> float8 (as Float8_), scaled.
//
// `a.x` produces the first four floats and `a.y` the last four.
template <>
__inline__ __device__ Float8_ scaled_vec_conversion<Float8_, uint2>(
    const uint2& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, fp8_type);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, fp8_type);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
}

// half -> fp8, scaled.
//
// Treats `a` as fp16 bits, computes a / scale in fp32 and encodes the result
// as fp8 with __NV_SATFINITE saturation.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, uint16_t>(
    const uint16_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(half_to_float(a) / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// bf16 -> fp8, scaled.
//
// Computes a / scale in fp32 and encodes it as fp8 with __NV_SATFINITE
// saturation. Below SM80 this conversion is not supported: that path only
// executes assert(false) and is not expected to be reached.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
    #else
  __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(__bfloat162float(a) / scale,
                                                 __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
    #endif
  __builtin_unreachable();  // Suppress missing return statement warning
}

// float -> fp8, scaled.
//
// Encodes a / scale as fp8 with __NV_SATFINITE saturation.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, float>(
    const float& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  __nv_fp8_storage_t res =
      __nv_cvt_float_to_fp8(a / scale, __NV_SATFINITE, fp8_type);
  return (uint8_t)res;
}

// fp8x4 -> float4, scaled.
//
// Same as the Float4_ variant, flattened into a native float4.
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, const float scale,
    const __nv_fp8_interpretation_t fp8_type) {
  Float4_ tmp = scaled_vec_conversion<Float4_, uint32_t>(a, scale, fp8_type);
  float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
  return res;
}
  #endif  // ENABLE_FP8
// Unscaled conversion dispatched on the KV-cache fp8 dtype.
//
// The dispatch body is compiled out (`#if 0`) to reduce binary size. Every
// instantiation therefore reaches assert(false); scaled conversions go through
// scaled_convert.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
  #if 0  // Disable the following code to reduce the binary size.
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return vec_conversion<Tout, Tin>(x, __NV_E5M2);
  }
  #endif
  assert(false);
  __builtin_unreachable();  // Suppress missing return statement warning
}

// Scaled conversion dispatched on the KV-cache fp8 dtype.
//
// kFp8E4M3 and kFp8E5M2 select the matching __nv_fp8_interpretation_t for
// scaled_vec_conversion. Any other kv_dt, or a build without ENABLE_FP8,
// reaches assert(false).
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  #ifdef ENABLE_FP8
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E4M3);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, __NV_E5M2);
  }
  #endif
  assert(false);
  __builtin_unreachable();  // Suppress missing return statement warning
}

  // The following macro is used to dispatch the conversion function based on
  // the data type of the key and value cache. The FN is a macro that calls a
  // function with template<typename scalar_t, typename cache_t,
  // Fp8KVCacheDataType kv_dt>.
  //
  // KV_DTYPE "auto" keeps the cache in the source dtype. "fp8" and "fp8_e4m3"
  // select E4M3, and "fp8_e5m2" selects E5M2, with uint8_t cache storage. Any
  // other KV_DTYPE or SRC_DTYPE fails via TORCH_CHECK.
  #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \
    if (KV_DTYPE == "auto") {                                                  \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
      }                                                                        \
    } else {                                                                   \
      if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                       \
        if (SRC_DTYPE == at::ScalarType::Float) {                              \
          FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);              \
        } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
          FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);           \
        } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
          FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);      \
        } else {                                                               \
          TORCH_CHECK(false,                                                   \
                      "Unsupported input type of kv cache: ", SRC_DTYPE);      \
        }                                                                      \
      } else if (KV_DTYPE == "fp8_e5m2") {                                     \
        if (SRC_DTYPE == at::ScalarType::Float) {                              \
          FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);              \
        } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
          FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);           \
        } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
          FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);      \
        } else {                                                               \
          TORCH_CHECK(false,                                                   \
                      "Unsupported input type of kv cache: ", SRC_DTYPE);      \
        }                                                                      \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);   \
      }                                                                        \
    }
}  // namespace fp8
#endif  // not USE_ROCM
}  // namespace vllm