quant_utils.cuh 19.7 KB
Newer Older
1
#pragma once
2
#ifndef USE_ROCM
3
#include <hip/hip_fp8.h>
4
#endif
5
6
7
8
9

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

10
#include "../../../attention/attention_dtypes.h"
11

12
namespace vllm {
13
14
15
#ifdef USE_ROCM

namespace fp8 {
16
  #ifdef ENABLE_FP8
17

18
19
20
21
22
23
// Use hardware cvt instruction for fp8 on rocm
// Primary template: ignores r and returns a value-initialized fp8_type.
// The float -> fp8 conversions themselves live in the explicit
// specializations of cvt_c10.
template <typename fp8_type>
__device__ __forceinline__ fp8_type cvt_c10(float const r) {
  return {};
}

24
25
26
27
28
29
// __hip_fp8_e4m3 only exists starting in ROCm 6.3. The macro
// HIP_FP8_TYPE_OCP comes from the hip_fp8.h header and also makes
// its first appearance in ROCm 6.3. Since VLLM_DISPATCH_FP8_TYPES
// on ROCm instantiates both OCP and FNUZ kernels, we need to replace
// the new HW cvt with something reasonable that doesn't rely on the
// ROCm 6.3 feature. This allows compiling on ROCm 6.2 or newer.
30
31
template <>
__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
32
    #if HIP_FP8_TYPE_OCP
33
34
35
36
  return c10::Float8_e4m3fn(
      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
                             __hip_fp8_e4m3::__default_interpret),
      c10::Float8_e4m3fn::from_bits());
37
38
39
40
41
    #else
  // Cast implemented by pytorch. Uses bit manipulation instead of HW cvt.
  // HW cvt above is faster when it is available (ROCm 6.3 or newer).
  return static_cast<c10::Float8_e4m3fn>(r);
    #endif
42
43
44
45
46
47
48
49
50
51
}

// float -> c10::Float8_e4m3fnuz: convert with the HIP intrinsic, then wrap the
// resulting raw fp8 bits in the c10 type.
template <>
__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
  const auto raw_bits =
      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
                             __hip_fp8_e4m3_fnuz::__default_interpret);
  return c10::Float8_e4m3fnuz(raw_bits, c10::Float8_e4m3fnuz::from_bits());
}

52
template <typename Tout, typename Tin>
53
54
__inline__ __device__ Tout vec_conversion(const Tin& x) {
  return x;
55
56
57
}

template <typename Tout, typename Tin>
58
59
60
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
                                                 const float scale) {
  return x;
61
62
}

63
    // Select the HIP fp8 scalar / packed-pair types used by the conversions
    // in this file: the OCP e4m3 types when HIP_FP8_TYPE_OCP is set, the FNUZ
    // e4m3 types otherwise.
    #if HIP_FP8_TYPE_OCP
using fp8_type = __hip_fp8_e4m3;
using fp8x2_type = __hip_fp8x2_e4m3;
    #else
using fp8_type = __hip_fp8_e4m3_fnuz;
using fp8x2_type = __hip_fp8x2_e4m3_fnuz;
    #endif

71
72
// fp8 -> half
template <>
73
74
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
75
  return __hip_cvt_fp8_to_halfraw(a, fp8_type::__default_interpret).x;
76
77
78
79
}

// fp8x2 -> half2
template <>
80
81
82
83
84
85
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
86
  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
87
  return tmp.ui32;
88
89
90
91
}

// fp8x4 -> half2x2
template <>
92
93
94
95
96
97
98
99
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
  union {
    uint2 u32x2;
    uint32_t u32[2];
  } tmp;
  tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
  tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
  return tmp.u32x2;
100
101
102
103
}

// fp8x8 -> half2x4
template <>
104
105
106
107
108
109
110
111
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
  union {
    uint4 u64x2;
    uint2 u64[2];
  } tmp;
  tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
  tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
  return tmp.u64x2;
112
113
114
115
116
117
}

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
118
119
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
120
121
122
  fp8_type f8;
  f8.__x = a;
  return __float2bfloat16(static_cast<float>(f8));
123
124
125
126
127
128
}

using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
template <>
129
130
131
132
133
134
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
  __nv_bfloat162 res;
  res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
  res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
  return res;
135
136
137
138
}

// fp8x4 -> bf16_4_t
template <>
139
140
141
142
143
144
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
  bf16_4_t res;
  res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
  res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
  return res;
145
146
147
148
}

// fp8x8 -> bf16_8_t
template <>
149
150
151
152
153
154
155
156
157
158
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
  bf16_4_t tmp1, tmp2;
  tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
  tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
159
160
161
162
}

// fp8 -> float
template <>
163
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
164
165
166
  fp8_type f8;
  f8.__x = a;
  return static_cast<float>(f8);
167
168
169
170
}

// fp8x2 -> float2
template <>
171
172
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
173
174
175
  fp8x2_type f8x2;
  f8x2.__x = a;
  return static_cast<float2>(f8x2);
176
177
178
179
}

// fp8x4 -> float4
template <>
180
181
182
183
184
185
__inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
  Float4_ res;
  res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
  res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
  return res;
186
187
}

188
189
190
191
192
193
194
195
196
// fp8x4 -> float4
// Reuses the fp8x4 -> Float4_ conversion and flattens its two float2 halves
// into a float4 in the order (x.x, x.y, y.x, y.y).
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
  const Float4_ pairs = vec_conversion<Float4_, uint32_t>(a);
  return make_float4(pairs.x.x, pairs.x.y, pairs.y.x, pairs.y.y);
}

197
198
// fp8x8 -> float8
template <>
199
200
201
202
203
204
205
206
207
208
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
  Float4_ tmp1, tmp2;
  tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
  tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
209
210
211
212
}

// half -> fp8
template <>
213
214
215
216
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
  __half_raw tmp;
  tmp.x = a;
217
218
219
  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
                                  fp8_type::__default_interpret);
}
220

221
222
223
224
225
226
227
228
229
230
// half2 -> fp8x2
// a carries the raw 32-bit pattern of a half2 (reinterpreted through a
// union); returns the two converted fp8 values packed into 16 bits.
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint32_t>(const uint32_t& a) {
  union {
    uint32_t ui32;
    __half2_raw h2r;
  } tmp;
  tmp.ui32 = a;
  return __hip_cvt_halfraw2_to_fp8x2(tmp.h2r, fp8_type::__default_saturation,
                                     fp8_type::__default_interpret);
}

// bf16 -> fp8
template <>
235
236
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
237
238
239
  return __hip_cvt_float_to_fp8(__bfloat162float(a),
                                fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
240
241
242
243
}

// float -> fp8
template <>
244
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
245
246
  return __hip_cvt_float_to_fp8(a, fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
247
248
249
250
}

// float2 -> half2
template <>
251
252
253
254
255
256
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
  union {
    half2 float16;
    uint32_t uint32;
  };
257

258
259
  float16 = __float22half2_rn(a);
  return uint32;
260
261
262
263
}

// Float4 -> half2x2
template <>
264
265
266
267
268
269
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
  uint2 b;
  float2 val;
  val.x = a.x.x;
  val.y = a.x.y;
  b.x = vec_conversion<uint32_t, float2>(val);
270

271
272
273
274
  val.x = a.y.x;
  val.y = a.y.y;
  b.y = vec_conversion<uint32_t, float2>(val);
  return b;
275
276
277
278
}

// Float4 -> float4
template <>
279
280
281
282
283
284
285
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
  float4 b;
  b.x = a.x.x;
  b.y = a.x.y;
  b.z = a.y.x;
  b.w = a.y.y;
  return b;
286
287
288
289
}

// Float8 -> half2x4
template <>
290
291
292
293
294
295
296
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
  uint4 b;
  b.x = vec_conversion<uint32_t, float2>(a.x);
  b.y = vec_conversion<uint32_t, float2>(a.y);
  b.z = vec_conversion<uint32_t, float2>(a.z);
  b.w = vec_conversion<uint32_t, float2>(a.w);
  return b;
297
298
299
300
}

// float2 -> bfloat162
template <>
301
302
303
304
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
  __nv_bfloat162 b = __float22bfloat162_rn(a);
  return b;
305
306
307
308
}

// Float4 -> bfloat162x2
template <>
309
310
311
312
313
314
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
  bf16_4_t b;
  b.x = __float22bfloat162_rn(a.x);
  b.y = __float22bfloat162_rn(a.y);
  return b;
315
316
317
318
}

// Float8 -> bfloat162x4
template <>
319
320
321
322
323
324
325
326
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
  bf16_8_t b;
  b.x = __float22bfloat162_rn(a.x);
  b.y = __float22bfloat162_rn(a.y);
  b.z = __float22bfloat162_rn(a.z);
  b.w = __float22bfloat162_rn(a.w);
  return b;
327
328
}

329
330
/* Scaled and vectorized conversions, for data exchange between high and low
   precision domains.

   Convention of the scale in the API, e.g.:
     FP8_data = Quantization(High_Precision_data / scale)
   such that
     Quantize(HP / scale) => FP8
     Dequant(FP8) * scale => HP
 */

using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
342
__inline__ __device__ __nv_bfloat16
343
344
345
346
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
  fp8_type f8;
  f8.__x = a;
  return __float2bfloat16(static_cast<float>(f8) * scale);
347
348
349
350
}

// fp8x2 -> __nv_bfloat162
template <>
351
352
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
353
                                                float scale) {
354
355
356
357
358
  __nv_bfloat162 res;
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale);
  res.y =
      scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale);
  return res;
359
360
361
362
}

// fp8x4 -> bf16_4_t
template <>
363
364
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
365
366
367
368
369
  bf16_4_t res;
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
                                                          scale);
  return res;
370
371
372
373
}

// fp8x8 -> bf16_8_t
template <>
374
__inline__ __device__ bf16_8_t
375
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
376
377
378
379
380
381
382
383
384
  bf16_4_t tmp1, tmp2;
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
385
386
387
388
}

// fp8 -> float
template <>
389
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
390
391
392
393
    const uint8_t& a, float scale) {
  fp8_type f8;
  f8.__x = a;
  return static_cast<float>(f8) * scale;
394
395
396
397
}

// fp8x2 -> float2
template <>
398
__inline__ __device__ float2
399
400
401
402
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
  fp8x2_type f8x2;
  f8x2.__x = a;
  return static_cast<float2>(f8x2) * scale;
403
404
405
406
}

// fp8x4 -> float4
template <>
407
408
409
410
411
412
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
  Float4_ res;
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale);
  return res;
413
414
}

415
416
417
418
419
420
421
422
// fp8x4 -> float4
// Reuses the scaled fp8x4 -> Float4_ conversion and flattens its two float2
// halves into a float4 in the order (x.x, x.y, y.x, y.y).
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
  const Float4_ pairs = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  const float2& lo = pairs.x;
  const float2& hi = pairs.y;
  return {lo.x, lo.y, hi.x, hi.y};
}

423
424
// fp8x8 -> float8
template <>
425
__inline__ __device__ Float8_
426
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
427
428
429
430
431
432
433
434
435
  Float4_ tmp1, tmp2;
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
436
437
}

438
439
440
441
442
443
444
445
446
447
448
449
450
// fp8 -> half
// Dequantizes through the scaled fp8 -> float conversion, stores the float
// into a __half_raw, and returns that half's raw 16-bit pattern.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
  __half_raw packed;
  packed.data = scaled_vec_conversion<float, uint8_t>(a, scale);
  const uint16_t raw_bits = packed.x;
  return raw_bits;
}

// fp8x2 -> half2
// Converts two packed fp8 values to a __half2_raw, multiplies each lane by
// scale, and returns the 32-bit pattern (read back through a union).
// The conversion is performed once, directly into the union; the previous
// extra [[maybe_unused]] local repeated the same conversion and was never read.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
  union {
    __half2_raw h2r;
    uint32_t ui32;
  } tmp;
  tmp.h2r = __hip_cvt_fp8x2_to_halfraw2(a, fp8_type::__default_interpret);
  tmp.h2r.x.data *= scale;
  tmp.h2r.y.data *= scale;
  return tmp.ui32;
}

// fp8x4 -> half2x2
// Splits a into its low and high 16-bit halves, converts each with the scaled
// fp8x2 -> half2 conversion, and packs the two 32-bit results into a uint2.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
  const uint16_t lo_pair = (uint16_t)a;
  const uint16_t hi_pair = (uint16_t)(a >> 16U);
  union {
    uint2 packed;
    uint32_t word[2];
  } out;
  out.word[0] = scaled_vec_conversion<uint32_t, uint16_t>(lo_pair, scale);
  out.word[1] = scaled_vec_conversion<uint32_t, uint16_t>(hi_pair, scale);
  return out.packed;
}
476

477
478
479
480
481
482
483
484
485
486
487
488
// fp8x8 -> half2x4
// Converts a.x and a.y with the scaled fp8x4 -> half2x2 conversion and packs
// the two uint2 results into a uint4.
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
                                                                float scale) {
  union {
    uint4 packed;
    uint2 part[2];
  } out;
  out.part[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  out.part[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  return out.packed;
}
489
490
491

// half -> fp8
template <>
492
__inline__ __device__ uint8_t
493
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
494
495
  __half_raw tmp;
  tmp.x = a;
496
497
498
499
  tmp.data /= scale;
  return __hip_cvt_halfraw_to_fp8(tmp, fp8_type::__default_saturation,
                                  fp8_type::__default_interpret);
}
500

501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
// halfx2 -> fp8x2
// a carries the raw 32-bit pattern of a half2. Each lane is divided by scale,
// then the pair is converted to two packed fp8 values.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
  union {
    uint32_t bits;
    __half2_raw halves;
  } src;
  src.bits = a;
  src.halves.x.data /= scale;
  src.halves.y.data /= scale;
  const uint16_t packed_fp8 = __hip_cvt_halfraw2_to_fp8x2(
      src.halves, fp8_type::__default_saturation,
      fp8_type::__default_interpret);
  return packed_fp8;
}

// half2x2 -> fp8x4
// a.x and a.y are each converted with the scaled half2 -> fp8x2 conversion;
// the result for a.x lands in element 0 of the 16-bit view and the result for
// a.y in element 1.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
  union {
    uint16_t pair[2];
    uint32_t word;
  } out;
  const uint16_t first = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
  const uint16_t second = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
  out.pair[0] = first;
  out.pair[1] = second;
  return out.word;
}

// half2x4 -> fp8x8
// Views the uint4 as two uint2 halves through a union and converts each with
// the scaled half2x2 -> fp8x4 conversion; the first half fills res.x and the
// second fills res.y.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
                                                                float scale) {
  union {
    uint2 ui2[2];
    uint4 ui4;
  } tmp;
  tmp.ui4 = a;
  uint2 res;
  res.x = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[0], scale);
  res.y = scaled_vec_conversion<uint32_t, uint2>(tmp.ui2[1], scale);
  return res;
}

// bf16 -> fp8
template <>
546
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
    const __nv_bfloat16& a, float scale) {
  return __hip_cvt_float_to_fp8(__bfloat162float(a) / scale,
                                fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
}

// bf16x2 -> fp8x2
// a.x and a.y are each converted with the scaled bf16 -> fp8 conversion; the
// byte for a.x lands in element 0 of the byte view and the byte for a.y in
// element 1.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
    const __nv_bfloat162& a, float scale) {
  union {
    uint8_t byte[2];
    uint16_t pair;
  } out;
  const uint8_t first = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
  const uint8_t second =
      scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
  out.byte[0] = first;
  out.byte[1] = second;
  return out.pair;
}

// bf16x4 -> fp8x4
// a.x and a.y are each converted with the scaled bf16x2 -> fp8x2 conversion;
// the result for a.x lands in element 0 of the 16-bit view and the result for
// a.y in element 1.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
  union {
    uint16_t pair[2];
    uint32_t word;
  } out;
  const uint16_t first =
      scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
  const uint16_t second =
      scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
  out.pair[0] = first;
  out.pair[1] = second;
  return out.word;
}

// bf16x8 -> fp8x8
// (a.x, a.y) is converted into res.x and (a.z, a.w) into res.y, each via the
// scaled bf16x4 -> fp8x4 conversion.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
  uint2 res;
  res.x = scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
  res.y = scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);
  return res;
}

// float -> fp8
template <>
591
__inline__ __device__ uint8_t
592
593
594
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
  return __hip_cvt_float_to_fp8(a / scale, fp8_type::__default_saturation,
                                fp8_type::__default_interpret);
595
596
}

597
// floatx2 -> fp8x2
598
template <>
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
  return __hip_cvt_float2_to_fp8x2(a / scale, fp8_type::__default_saturation,
                                   fp8_type::__default_interpret);
}

// floatx4 -> fp8x4
// (a.x, a.y) and (a.z, a.w) are each converted with the scaled
// float2 -> fp8x2 conversion; the first pair lands in element 0 of the 16-bit
// view and the second in element 1.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
  union {
    uint16_t ui16[2];
    uint32_t ui32;
  } tmp;
  tmp.ui16[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
  tmp.ui16[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
  return tmp.ui32;
}
617
  #endif  // ENABLE_FP8
618

619
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
620
621
__inline__ __device__ Tout convert(const Tin& x) {
  #ifdef ENABLE_FP8
622
623
624
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return vec_conversion<Tout, Tin>(x);
  }
625
  #endif
626
  assert(false);
627
  return {};  // Squash missing return statement warning
628
}
629
630

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
631
632
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  #ifdef ENABLE_FP8
633
634
635
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale);
  }
636
  #endif
637
  assert(false);
638
  return {};  // Squash missing return statement warning
639
640
}

641
642
643
644
645
646
  // The following macro is used to dispatch the conversion function based on
  // the data type of the key and value cache. The FN is a macro that calls a
  // function with template<typename scalar_t, typename cache_t,
  // Fp8KVCacheDataType kv_dt>.
  //
  // KV_DTYPE "auto" invokes FN with cache_t equal to the source type and
  // kAuto. KV_DTYPE "fp8" or "fp8_e4m3" invokes FN with cache_t = uint8_t and
  // kFp8E4M3. Any other KV_DTYPE or SRC_DTYPE fails a TORCH_CHECK.
  // NOTE: the expansion is a bare if/else statement (not wrapped in
  // do { } while (0)), so take care when using it next to another `else`.
  #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \
    if (KV_DTYPE == "auto") {                                                  \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
      }                                                                        \
    } else {                                                                   \
      if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                       \
        if (SRC_DTYPE == at::ScalarType::Float) {                              \
          FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);              \
        } else if (SRC_DTYPE == at::ScalarType::Half) {                        \
          FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);           \
        } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                    \
          FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);      \
        } else {                                                               \
          TORCH_CHECK(false,                                                   \
                      "Unsupported input type of kv cache: ", SRC_DTYPE);      \
        }                                                                      \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);   \
      }                                                                        \
    }
672

673
674
675
}  // namespace fp8
#endif  // USE_ROCM
}  // namespace vllm