quant_utils.cuh 19.2 KB
Newer Older
1
#pragma once
2
#ifndef USE_ROCM
3
#include <hip/hip_fp8.h>
4
#endif
5
6
7
8
9

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

10
#include "../../../attention/attention_dtypes.h"
11

12
namespace vllm {
13
14
15
#ifdef USE_ROCM

namespace fp8 {
zhuwenwen's avatar
zhuwenwen committed
16
  // #ifdef ENABLE_FP8
17

zhuwenwen's avatar
zhuwenwen committed
18
19
20
21
22
23
24
25
26
// fp8 (E4M3) -> float
//
// Decodes a single E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)
// into fp32 using integer bit manipulation.
//
// Notes:
//  * +/-0 (0x00 / 0x80) decode to +/-0.0f.
//  * 0x7F / 0xFF are NOT treated as NaN; they decode to +/-480.0f. This pairs
//    with float_to_fp8_e4m3, which emits 0x7F for every magnitude >= 480, so
//    the two functions together behave as a saturating codec.
static inline __device__ float fp8_to_float(uint8_t input) {
  // Move the byte into the top 8 bits so its sign bit lines up with fp32's.
  const uint32_t w = (uint32_t)input << 24;
  const uint32_t sign = w & UINT32_C(0x80000000);
  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);

  // A normal input has at most 4 leading zeros once the sign is cleared
  // (1 sign position + up to 3 zero exponent bits). Anything beyond that means
  // the exponent field is zero (subnormal), and the excess is how far the
  // mantissa must be shifted left to renormalize it.
  uint32_t renorm_shift = __clz(nonsign);
  renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;

  // Shift exponent+mantissa into fp32 position and rebias the exponent
  // (0x78 = 127 - 7), compensating for any renormalization shift. A zero
  // payload has no leading one to renormalize, so it is mapped to zero
  // explicitly instead of going through the rebias arithmetic.
  const uint32_t magnitude =
      nonsign == 0
          ? UINT32_C(0)
          : (nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23);

  return c10::detail::fp32_from_bits(sign | magnitude);
}

zhuwenwen's avatar
zhuwenwen committed
29
// float -> fp8
30
static inline __device__ uint8_t float_to_fp8_e4m3(float f) {
zhuwenwen's avatar
zhuwenwen committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
  constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
  uint32_t f_bits = c10::detail::fp32_to_bits(f);
  uint8_t result = 0u;
  const uint32_t sign = f_bits & UINT32_C(0x80000000);
  f_bits ^= sign;
  if (f_bits >= fp8_max) {
    result = 0x7f;
  } else {
    if (f_bits < (UINT32_C(121) << 23)) {
      f_bits =
        c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits) + c10::detail::fp32_from_bits(denorm_mask));
      result = static_cast<uint8_t>(f_bits - denorm_mask);
    } else {
      uint8_t mant_odd = (f_bits >> 20) & 1;
      f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
      f_bits += mant_odd;
      result = static_cast<uint8_t>(f_bits >> 20);
    }
  }
51

zhuwenwen's avatar
zhuwenwen committed
52
53
  result |= static_cast<uint8_t>(sign >> 24);
  return result;
54
55
}

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
// float -> fp8 (E5M2)
//
// Encodes an fp32 value as one E5M2 byte with round-to-nearest-even.
// Magnitudes >= 2^16 (including inf) become +/-inf (0x7C); NaN inputs become
// 0x7F with the input's sign bit.
static inline __device__ uint8_t float_to_fp8_e5m2(float f) {
  constexpr uint32_t kSignMask = UINT32_C(0x80000000);
  constexpr uint32_t kFp32Inf = UINT32_C(255) << 23;
  constexpr uint32_t kOverflowFrom = UINT32_C(143) << 23;  // 2^16
  constexpr uint32_t kMinNormal = UINT32_C(113) << 23;     // 2^-14
  constexpr uint32_t kDenormMagic = UINT32_C(134) << 23;

  const uint32_t raw_bits = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = raw_bits & kSignMask;
  uint32_t mag = raw_bits ^ sign_bit;  // magnitude with the sign cleared

  uint8_t encoded = 0u;
  if (mag >= kOverflowFrom) {
    // Bit patterns above fp32 infinity are NaNs.
    encoded = mag > kFp32Inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
  } else if (mag < kMinNormal) {
    // Subnormal result: let the fp32 adder align and round via the magic add.
    const float aligned = c10::detail::fp32_from_bits(mag) +
                          c10::detail::fp32_from_bits(kDenormMagic);
    encoded =
        static_cast<uint8_t>(c10::detail::fp32_to_bits(aligned) - kDenormMagic);
  } else {
    // Normal result: rebias the exponent (127 -> 15) and round the 21 dropped
    // mantissa bits to nearest, ties to even.
    const uint32_t keep_lsb = (mag >> 21) & 1;
    mag += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
    mag += keep_lsb;
    encoded = static_cast<uint8_t>(mag >> 21);
  }

  encoded |= static_cast<uint8_t>(sign_bit >> 24);
  return encoded;
}
81
82

template <typename Tout, typename Tin>
83
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
84
                                                 const float scale, Fp8KVCacheDataType kv_type) {
85
  return x;
86
87
88
89
90
91
}

// Lets the code below spell the bfloat16 scalar type with its CUDA name while
// actually using HIP's __hip_bfloat16.
using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
92
__inline__ __device__ __nv_bfloat16
93
94
95
96
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
  if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
    assert(false);
  }
zhuwenwen's avatar
zhuwenwen committed
97
98

  return __float2bfloat16(fp8_to_float(a) * scale);
99
100
101
102
}

// fp8x2 -> __nv_bfloat162
template <>
103
104
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
105
                                                float scale, Fp8KVCacheDataType kv_type) {
106
  __nv_bfloat162 res;
107
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type);
108
  res.y =
109
      scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
110
  return res;
111
112
113
114
}

// fp8x4 -> bf16_4_t
template <>
115
__inline__ __device__ bf16_4_t
116
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
117
  bf16_4_t res;
118
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type);
119
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
120
                                                          scale, kv_type);
121
  return res;
122
123
124
125
}

// fp8x8 -> bf16_8_t
template <>
126
__inline__ __device__ bf16_8_t
127
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
128
  bf16_4_t tmp1, tmp2;
129
130
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, kv_type);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, kv_type);
131
132
133
134
135
136
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
137
138
139
140
}

// fp8 -> float
template <>
141
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
142
143
144
145
    const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
    if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
      assert(false);
    }
zhuwenwen's avatar
zhuwenwen committed
146
    return fp8_to_float(a) * scale;
147
148
149
150
}

// fp8x2 -> float2
template <>
151
__inline__ __device__ float2
152
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
zhuwenwen's avatar
zhuwenwen committed
153
    float2 f2r;
154
155
    f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale, kv_type);
    f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
zhuwenwen's avatar
zhuwenwen committed
156
    return f2r;
157
158
159
160
}

// fp8x4 -> float4
template <>
161
__inline__ __device__ Float4_
162
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) {
163
  Float4_ res;
164
165
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, kv_type);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
166
  return res;
167
168
}

169
170
171
// fp8x4 -> float4
// Converts through the Float4_ specialization and flattens its two float2
// members into a single float4 in (x.x, x.y, y.x, y.y) order.
template <>
__inline__ __device__ float4 scaled_vec_conversion<float4, uint32_t>(
    const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
  const Float4_ pairs =
      scaled_vec_conversion<Float4_, uint32_t>(a, scale, kv_type);
  return {pairs.x.x, pairs.x.y, pairs.y.x, pairs.y.y};
}

177
178
// fp8x8 -> float8
template <>
179
__inline__ __device__ Float8_
180
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
181
  Float4_ tmp1, tmp2;
182
183
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, kv_type);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, kv_type);
184
185
186
187
188
189
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
190
191
}

192
193
194
// fp8 -> half (returned as raw 16-bit half bits)
//
// Decodes one fp8 byte, multiplies it by `scale` in fp32, and converts the
// product with float_to_half. `kv_type` selects the byte layout: kFp8E5M2
// bytes are decoded as E5M2; every other value is decoded as E4M3 via
// fp8_to_float.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint8_t>(
    const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
  float decoded;
  if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
    // An E5M2 byte has the same sign/exponent layout as the high byte of an
    // IEEE binary16 value, so widening it to 16 bits with a zero low byte and
    // converting that half to float decodes it exactly.
    decoded =
        half_to_float(static_cast<uint16_t>(static_cast<uint16_t>(a) << 8));
  } else {
    decoded = fp8_to_float(a);
  }
  return float_to_half(decoded * scale);
}

// fp8x2 -> half2 (returned as a packed 32-bit word)
// The low byte of `a` is written to 16-bit slot 0 and the high byte to slot 1.
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint16_t lanes[2];
    uint32_t packed;
  } out;

  const uint8_t lo_byte = (uint8_t)a;
  const uint8_t hi_byte = (uint8_t)(a >> 8U);
  out.lanes[0] =
      scaled_vec_conversion<uint16_t, uint8_t>(lo_byte, scale, kv_type);
  out.lanes[1] =
      scaled_vec_conversion<uint16_t, uint8_t>(hi_byte, scale, kv_type);
  return out.packed;
}

// fp8x4 -> half2x2 (returned as uint2)
// The low 16 bits of `a` produce 32-bit slot 0, the high 16 bits slot 1.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint32_t>(
    const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint2 vec;
    uint32_t words[2];
  } out;

  const uint16_t lo_half = (uint16_t)a;
  const uint16_t hi_half = (uint16_t)(a >> 16U);
  out.words[0] =
      scaled_vec_conversion<uint32_t, uint16_t>(lo_half, scale, kv_type);
  out.words[1] =
      scaled_vec_conversion<uint32_t, uint16_t>(hi_half, scale, kv_type);
  return out.vec;
}
228

229
230
231
// fp8x8 -> half2x4 (returned as uint4)
// a.x produces the first uint2 of the result, a.y the second.
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(
    const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint4 vec;
    uint2 halves[2];
  } out;

  out.halves[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, kv_type);
  out.halves[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, kv_type);
  return out.vec;
}
241
242
243

// half -> fp8
template <>
244
__inline__ __device__ uint8_t
245
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
zhuwenwen's avatar
zhuwenwen committed
246
  float res_f = half_to_float(a) / scale;
247
248
249
250
251
  if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
    return float_to_fp8_e4m3(res_f);
  } else {
    return float_to_fp8_e5m2(res_f);
  }
252
}
253

254
255
256
// halfx2 -> fp8x2
//
// `a` packs two raw half values: lane 0 in the low 16 bits, lane 1 in the high
// 16 bits (the same slot order the fp8x2 -> half2 conversion in this file
// writes). Each lane is converted with the scalar half -> fp8 specialization
// and the two resulting bytes are packed low-lane-first.
//
// The lanes are extracted with integer shifts rather than by reading members
// of a vendor half2 type, so each lane reaches the scalar converter as raw
// half bits regardless of how that type declares its members.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, uint32_t>(
    const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint8_t ui8[2];
    uint16_t ui16;
  } tmp;

  const uint16_t lo_half_bits = static_cast<uint16_t>(a);
  const uint16_t hi_half_bits = static_cast<uint16_t>(a >> 16U);
  tmp.ui8[0] =
      scaled_vec_conversion<uint8_t, uint16_t>(lo_half_bits, scale, kv_type);
  tmp.ui8[1] =
      scaled_vec_conversion<uint8_t, uint16_t>(hi_half_bits, scale, kv_type);
  return tmp.ui16;
}

// half2x2 -> fp8x4
// a.x produces 16-bit slot 0 of the result and a.y produces slot 1.
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint2>(
    const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;

  out.pairs[0] =
      scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale, kv_type);
  out.pairs[1] =
      scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale, kv_type);
  return out.packed;
}

// half2x4 -> fp8x8
// The first uint2 of `a` produces result member x, the second produces y.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(
    const uint4& a, float scale, Fp8KVCacheDataType kv_type) {
  // View the uint4 as two consecutive uint2 values.
  union {
    uint2 halves[2];
    uint4 whole;
  } in;
  in.whole = a;

  uint2 packed;
  packed.x =
      scaled_vec_conversion<uint32_t, uint2>(in.halves[0], scale, kv_type);
  packed.y =
      scaled_vec_conversion<uint32_t, uint2>(in.halves[1], scale, kv_type);
  return packed;
}

// bf16 -> fp8
template <>
302
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
303
    const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) {
zhuwenwen's avatar
zhuwenwen committed
304
      float res_f = (static_cast<float>(a)) / scale;
305
306
307
308
309
      if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
        return float_to_fp8_e4m3(res_f);
      } else {
        return float_to_fp8_e5m2(res_f);
      }
310
311
312
313
314
}

// bf16x2 -> fp8x2
// Lane x is written to byte slot 0 of the result and lane y to slot 1.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
    const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint8_t bytes[2];
    uint16_t packed;
  } out;

  out.bytes[0] =
      scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale, kv_type);
  out.bytes[1] =
      scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale, kv_type);
  return out.packed;
}

// bf16x4 -> fp8x4
// Member x produces 16-bit slot 0 of the result and member y produces slot 1.
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, bf16_4_t>(
    const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;

  out.pairs[0] =
      scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale, kv_type);
  out.pairs[1] =
      scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale, kv_type);
  return out.packed;
}

// bf16x8 -> fp8x8
// Members (x, y) produce result member x; members (z, w) produce member y.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, bf16_8_t>(
    const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) {
  const bf16_4_t lo_quad = {a.x, a.y};
  const bf16_4_t hi_quad = {a.z, a.w};

  uint2 packed;
  packed.x =
      scaled_vec_conversion<uint32_t, bf16_4_t>(lo_quad, scale, kv_type);
  packed.y =
      scaled_vec_conversion<uint32_t, bf16_4_t>(hi_quad, scale, kv_type);
  return packed;
}

// float -> fp8
template <>
350
__inline__ __device__ uint8_t
351
352
353
354
355
356
scaled_vec_conversion<uint8_t, float>(const float& a, float scale, Fp8KVCacheDataType kv_type) {
  if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
    return float_to_fp8_e4m3(a / scale);
  } else {
    return float_to_fp8_e5m2(a / scale);
  }
357
358
}

359
// floatx2 -> fp8x2
360
template <>
361
__inline__ __device__ uint16_t
362
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale, Fp8KVCacheDataType kv_type) {
zhuwenwen's avatar
zhuwenwen committed
363
364
365
366
  union {
    uint8_t ui8[2];
    uint16_t ui16;
  } tmp;
367
368
  tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale, kv_type);
  tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale, kv_type);
zhuwenwen's avatar
zhuwenwen committed
369
  return tmp.ui16;
370
371
372
373
374
}

// floatx4 -> fp8x4
// Lanes (x, y) produce 16-bit slot 0 of the result; lanes (z, w) slot 1.
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, float4>(
    const float4& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;

  out.pairs[0] =
      scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale, kv_type);
  out.pairs[1] =
      scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale, kv_type);
  return out.packed;
}

zhuwenwen's avatar
zhuwenwen committed
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
// float -> fp8 (E5M2), round-to-nearest-even.
// Magnitudes >= 2^16 (including inf) become +/-inf (0x7C); NaN inputs become
// 0x7F with the input's sign bit.
inline __device__ uint8_t float_to_fp8e5m2(float f) {
  constexpr uint32_t kSignMask = UINT32_C(0x80000000);
  constexpr uint32_t kFp32Inf = UINT32_C(255) << 23;
  constexpr uint32_t kOverflowFrom = UINT32_C(143) << 23;  // 2^16
  constexpr uint32_t kMinNormal = UINT32_C(113) << 23;     // 2^-14
  constexpr uint32_t kDenormMagic = UINT32_C(134) << 23;

  const uint32_t raw_bits = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = raw_bits & kSignMask;
  uint32_t mag = raw_bits ^ sign_bit;  // magnitude with the sign cleared

  uint8_t encoded = 0u;
  if (mag >= kOverflowFrom) {
    // Bit patterns above fp32 infinity are NaNs.
    encoded = mag > kFp32Inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
  } else if (mag < kMinNormal) {
    // Subnormal result: let the fp32 adder align and round via the magic add.
    const float aligned = c10::detail::fp32_from_bits(mag) +
                          c10::detail::fp32_from_bits(kDenormMagic);
    encoded =
        static_cast<uint8_t>(c10::detail::fp32_to_bits(aligned) - kDenormMagic);
  } else {
    // Normal result: rebias the exponent (127 -> 15) and round the 21 dropped
    // mantissa bits to nearest, ties to even.
    const uint32_t keep_lsb = (mag >> 21) & 1;
    mag += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
    mag += keep_lsb;
    encoded = static_cast<uint8_t>(mag >> 21);
  }

  encoded |= static_cast<uint8_t>(sign_bit >> 24);
  return encoded;
}

// Tin -> fp8 (E5M2), primary template.
// Fallback for source types without a specialization: ignores both arguments
// and returns 0.
template <typename Tin>
__inline__ __device__ uint8_t scaled_vec_conversion_to_e5m2(const Tin& a,
                                                            float scale) {
  return 0;
}

// float -> fp8 (E5M2): encodes a / scale.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion_to_e5m2<float>(
    const float& a, float scale) {
  const float unscaled = a / scale;
  return float_to_fp8e5m2(unscaled);
}

// half -> fp8 (E5M2): `a` holds raw half bits; encodes half_to_float(a) / scale.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion_to_e5m2<uint16_t>(
    const uint16_t& a, float scale) {
  const float unscaled = half_to_float(a) / scale;
  return float_to_fp8e5m2(unscaled);
}

// bf16 -> fp8 (E5M2): widens to fp32, divides by scale, then encodes.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion_to_e5m2<__nv_bfloat16>(
    const __nv_bfloat16& a, float scale) {
  const float unscaled = (static_cast<float>(a)) / scale;
  return float_to_fp8e5m2(unscaled);
}

// fp8 (E5M2) -> float
// Places the byte in the upper 8 bits of a 16-bit pattern (low byte zero),
// reinterprets that pattern as a _Float16, and widens it to fp32.
inline __device__ float fp8e5m2_to_fp32(const uint8_t& input) {
  union {
    uint16_t bits;
    _Float16 value;
  } half_view;
  half_view.bits = (uint16_t)input << 8;
  return (float)half_view.value;
}

// fp8 (E5M2) -> Tout, primary template.
// Fallback for destination types without a specialization: ignores both
// arguments and returns 0 converted to Tout.
template <typename Tout>
__inline__ __device__ Tout scaled_vec_conversion_from_e5m2(const uint8_t& a,
                                                           float scale) {
  return 0;
}

// fp8 (E5M2) -> float: decodes the byte and multiplies by scale.
template <>
__inline__ __device__ float scaled_vec_conversion_from_e5m2<float>(
    const uint8_t& a, float scale) {
  const float decoded = fp8e5m2_to_fp32(a);
  return decoded * scale;
}

// fp8 (E5M2) -> half (raw half bits): decodes, scales in fp32, then converts
// with float_to_half.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion_from_e5m2<uint16_t>(
    const uint8_t& a, float scale) {
  const float scaled = fp8e5m2_to_fp32(a) * scale;
  return float_to_half(scaled);
}

// fp8 (E5M2) -> bf16: decodes, scales in fp32, then converts to bfloat16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
  const float scaled = fp8e5m2_to_fp32(a) * scale;
  return __float2bfloat16(scaled);
}


479
480

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
481
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
482
483
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, kv_dt);
zhuwenwen's avatar
zhuwenwen committed
484
485
486
487
488
489
  else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
    return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
  }
  else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tin)==1){
    return scaled_vec_conversion_from_e5m2<Tout>(x, scale);
  }
490
  return {};  // Squash missing return statement warning
491
492
}

493
494
495
496
  // Dispatches FN according to the source tensor dtype (SRC_DTYPE) and the
  // KV-cache dtype string (KV_DTYPE). FN is a macro that calls a function
  // templated as
  //   template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
  // Recognized KV_DTYPE values: "auto", "int8", "fp8", "fp8_e4m3", "fp8_e5m2";
  // anything else, or an unsupported SRC_DTYPE, fails a TORCH_CHECK.
  // The macro expands to a bare if / else-if chain and evaluates SRC_DTYPE and
  // KV_DTYPE more than once.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                    \
  if (KV_DTYPE == "auto") {                                                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                       \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                 \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);       \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "int8") {                                             \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kInt8);                     \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kInt8);                  \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kInt8);             \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);          \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "fp8_e5m2") {                                         \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);          \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else {                                                                     \
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);       \
  }
545

zhuwenwen's avatar
zhuwenwen committed
546

547
548
}  // namespace fp8
#endif  // USE_ROCM
zhuwenwen's avatar
zhuwenwen committed
549
}  // namespace vllm