quant_utils.cuh 19.2 KB
Newer Older
1
#pragma once
2
#ifndef USE_ROCM
3
#include <hip/hip_fp8.h>
4
#endif
5
6
7
8
9

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

10
#include "../../../attention/attention_dtypes.h"
11

12
namespace vllm {
13
14
15
#ifdef USE_ROCM

namespace fp8 {
zhuwenwen's avatar
zhuwenwen committed
16
  // #ifdef ENABLE_FP8
17

zhuwenwen's avatar
zhuwenwen committed
18
19
20
21
22
23
24
25
26
// fp8 (E4M3) -> float
//
// Decodes a single byte laid out as 1 sign / 4 exponent / 3 mantissa bits
// into an fp32 value using integer bit manipulation only.
//
// Zero inputs (0x00 / 0x80) decode to +0.0f / -0.0f. The all-ones magnitude
// 0x7F is NOT special-cased: it decodes as an ordinary finite value
// (480.0f), which matches the saturation code written by float_to_fp8_e4m3.
static inline __device__ float fp8_to_float(uint8_t input) {
  // Move the byte to the top of the word so its sign bit lines up with fp32's.
  const uint32_t w = (uint32_t)input << 24;
  const uint32_t sign = w & UINT32_C(0x80000000);
  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
  // For subnormal inputs the leading one sits inside the mantissa field;
  // renorm_shift is how far it must move left to become the implicit one.
  uint32_t renorm_shift = __clz(nonsign);
  renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
  // All ones when nonsign == 0, all zeros otherwise. Without this mask a zero
  // input would still receive the exponent term below and decode to 2^-35.
  const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
  // Shift exponent/mantissa into fp32 position and rebias the exponent
  // (0x78 = 127 - 7), compensating for any renormalisation shift.
  const uint32_t magnitude =
      (nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23);
  const uint32_t result = sign | (magnitude & ~(uint32_t)zero_mask);
  return c10::detail::fp32_from_bits(result);
}

zhuwenwen's avatar
zhuwenwen committed
29
// float -> fp8
30
static inline __device__ uint8_t float_to_fp8_e4m3(float f) {
zhuwenwen's avatar
zhuwenwen committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
  constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
  uint32_t f_bits = c10::detail::fp32_to_bits(f);
  uint8_t result = 0u;
  const uint32_t sign = f_bits & UINT32_C(0x80000000);
  f_bits ^= sign;
  if (f_bits >= fp8_max) {
    result = 0x7f;
  } else {
    if (f_bits < (UINT32_C(121) << 23)) {
      f_bits =
        c10::detail::fp32_to_bits(c10::detail::fp32_from_bits(f_bits) + c10::detail::fp32_from_bits(denorm_mask));
      result = static_cast<uint8_t>(f_bits - denorm_mask);
    } else {
      uint8_t mant_odd = (f_bits >> 20) & 1;
      f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
      f_bits += mant_odd;
      result = static_cast<uint8_t>(f_bits >> 20);
    }
  }
51

zhuwenwen's avatar
zhuwenwen committed
52
53
  result |= static_cast<uint8_t>(sign >> 24);
  return result;
54
55
}

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
// float -> fp8 (E5M2)
//
// Encodes an fp32 value into one byte (1 sign / 5 exponent / 2 mantissa bits)
// by operating on the fp32 bit pattern.
static inline __device__ uint8_t float_to_fp8_e5m2(float f) {
  constexpr uint32_t kFp32InfBits = UINT32_C(255) << 23;
  constexpr uint32_t kOverflowFrom = UINT32_C(143) << 23;
  constexpr uint32_t kDenormMagic = UINT32_C(134) << 23;
  constexpr uint32_t kSmallestNormal = UINT32_C(113) << 23;

  const uint32_t raw = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = raw & UINT32_C(0x80000000);
  uint32_t magnitude = raw ^ sign_bit;  // |f| as a bit pattern

  uint8_t encoded = 0u;
  if (magnitude >= kOverflowFrom) {
    // Patterns above the fp32 infinity encoding are NaNs -> 0x7F;
    // every other out-of-range magnitude -> 0x7C.
    if (magnitude > kFp32InfBits) {
      encoded = UINT8_C(0x7F);
    } else {
      encoded = UINT8_C(0x7C);
    }
  } else if (magnitude < kSmallestNormal) {
    // Subnormal in fp8: let a float addition align and round the mantissa,
    // then strip the magic constant's bit pattern again.
    const float shifted = c10::detail::fp32_from_bits(magnitude) +
                          c10::detail::fp32_from_bits(kDenormMagic);
    magnitude = c10::detail::fp32_to_bits(shifted);
    encoded = static_cast<uint8_t>(magnitude - kDenormMagic);
  } else {
    // Normal range: rebias the exponent (127 -> 15) and add the rounding bias
    // 0xFFFFF plus the lowest kept mantissa bit so ties go to even.
    const uint32_t lowest_kept_bit = (magnitude >> 21) & 1;
    magnitude += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
    magnitude += lowest_kept_bit;
    encoded = static_cast<uint8_t>(magnitude >> 21);
  }

  encoded |= static_cast<uint8_t>(sign_bit >> 24);
  return encoded;
}
81
82

template <typename Tout, typename Tin>
83
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
84
                                                 const float scale, Fp8KVCacheDataType kv_type) {
85
  return x;
86
87
88
89
90
91
}

// Make the NVIDIA-style bf16 type name resolve to the HIP bfloat16 type
// within this namespace.
using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
template <>
92
__inline__ __device__ __nv_bfloat16
93
94
95
96
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
  if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
    assert(false);
  }
zhuwenwen's avatar
zhuwenwen committed
97
  return __float2bfloat16(fp8_to_float(a) * scale);
98
99
100
101
}

// fp8x2 -> __nv_bfloat162
template <>
102
103
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
104
                                                float scale, Fp8KVCacheDataType kv_type) {
105
  __nv_bfloat162 res;
106
  res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale, kv_type);
107
  res.y =
108
      scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
109
  return res;
110
111
112
113
}

// fp8x4 -> bf16_4_t
template <>
114
__inline__ __device__ bf16_4_t
115
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
116
  bf16_4_t res;
117
  res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale, kv_type);
118
  res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U),
119
                                                          scale, kv_type);
120
  return res;
121
122
123
124
}

// fp8x8 -> bf16_8_t
template <>
125
__inline__ __device__ bf16_8_t
126
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
127
  bf16_4_t tmp1, tmp2;
128
129
  tmp1 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale, kv_type);
  tmp2 = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale, kv_type);
130
131
132
133
134
135
  bf16_8_t res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
136
137
138
139
}

// fp8 -> float
template <>
140
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
141
142
143
144
    const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
    if (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2) {
      assert(false);
    }
zhuwenwen's avatar
zhuwenwen committed
145
    return fp8_to_float(a) * scale;
146
147
148
149
}

// fp8x2 -> float2
template <>
150
__inline__ __device__ float2
151
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
zhuwenwen's avatar
zhuwenwen committed
152
    float2 f2r;
153
154
    f2r.x = scaled_vec_conversion<float, uint8_t>((uint8_t)a, scale, kv_type);
    f2r.y = scaled_vec_conversion<float, uint8_t>((uint8_t)(a >> 8U), scale, kv_type);
zhuwenwen's avatar
zhuwenwen committed
155
    return f2r;
156
157
158
159
}

// fp8x4 -> float4
template <>
160
__inline__ __device__ Float4_
161
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale, Fp8KVCacheDataType kv_type) {
162
  Float4_ res;
zhuwenwen's avatar
zhuwenwen committed
163
164
  res.x = scaled_vec_conversion<float2, uint16_t>((uint16_t)a, scale, kv_type);
  res.y = scaled_vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U), scale, kv_type);
165
  return res;
166
167
}

168
169
170
// fp8x4 -> float4
// Converts through the Float4_ specialization and flattens its two float2
// halves into x, y, z, w in that order.
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
  const Float4_ halves = scaled_vec_conversion<Float4_, uint32_t>(a, scale, kv_type);

  float4 flat;
  flat.x = halves.x.x;
  flat.y = halves.x.y;
  flat.z = halves.y.x;
  flat.w = halves.y.y;
  return flat;
}

176
177
// fp8x8 -> float8
template <>
178
__inline__ __device__ Float8_
179
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
180
  Float4_ tmp1, tmp2;
181
182
  tmp1 = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale, kv_type);
  tmp2 = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale, kv_type);
183
184
185
186
187
188
  Float8_ res;
  res.x = tmp1.x;
  res.y = tmp1.y;
  res.z = tmp2.x;
  res.w = tmp2.y;
  return res;
189
190
}

191
192
193
// fp8 -> half
// Decodes one fp8 byte with fp8_to_float, multiplies by `scale` and returns
// the result of float_to_half as a uint16_t. E5M2 input trips a device assert.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale, Fp8KVCacheDataType kv_type) {
  const bool is_e5m2 = (kv_type == vllm::Fp8KVCacheDataType::kFp8E5M2);
  if (is_e5m2) {
    assert(false);
  }
  const float dequantized = fp8_to_float(a) * scale;
  return float_to_half(dequantized);
}

// fp8x2 -> half2 (returned packed in a uint32_t)
// The low byte of `a` is converted into slot 0 of the packed result, the
// high byte into slot 1.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
  const uint8_t low_byte = (uint8_t)a;
  const uint8_t high_byte = (uint8_t)(a >> 8U);

  // Two 16-bit half patterns viewed as a single 32-bit word.
  union {
    uint16_t halves[2];
    uint32_t packed;
  } out;
  out.halves[0] = scaled_vec_conversion<uint16_t, uint8_t>(low_byte, scale, kv_type);
  out.halves[1] = scaled_vec_conversion<uint16_t, uint8_t>(high_byte, scale, kv_type);
  return out.packed;
}

// fp8x4 -> half2x2 (returned as a uint2)
// The low 16 bits of `a` produce word 0, the high 16 bits produce word 1.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
  const uint16_t low_pair = (uint16_t)a;
  const uint16_t high_pair = (uint16_t)(a >> 16U);

  // Two packed half2 words viewed as one uint2.
  union {
    uint2 as_vec;
    uint32_t words[2];
  } out;
  out.words[0] = scaled_vec_conversion<uint32_t, uint16_t>(low_pair, scale, kv_type);
  out.words[1] = scaled_vec_conversion<uint32_t, uint16_t>(high_pair, scale, kv_type);
  return out.as_vec;
}
227

228
229
230
// fp8x8 -> half2x4 (returned as a uint4)
// a.x produces the first uint2 of the result, a.y the second.
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(
    const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
  // Two uint2 halves viewed as one uint4.
  union {
    uint4 as_vec;
    uint2 halves[2];
  } out;
  out.halves[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale, kv_type);
  out.halves[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale, kv_type);
  return out.as_vec;
}
240
241
242

// half -> fp8
template <>
243
__inline__ __device__ uint8_t
244
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale, Fp8KVCacheDataType kv_type) {
zhuwenwen's avatar
zhuwenwen committed
245
  float res_f = half_to_float(a) / scale;
246
247
248
249
250
  if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
    return float_to_fp8_e4m3(res_f);
  } else {
    return float_to_fp8_e5m2(res_f);
  }
251
}
252

253
254
255
// halfx2 -> fp8x2
//
// `a` is treated as two packed 16-bit half bit patterns: bits [15:0] are
// element 0 and bits [31:16] are element 1. Each is quantized through the
// scalar half -> fp8 specialization; element 0 lands in byte 0 of the
// result and element 1 in byte 1.
//
// The raw half bits are extracted with shifts/masks rather than through
// `half2::data[i]`: the scalar specialization takes a raw `uint16_t` bit
// pattern, and an element of a floating-point vector type would be
// value-converted (not bit-copied) when bound to `const uint16_t&`.
// NOTE(review): confirm the element type of `half2::data` on the target
// toolchain; this form is correct either way.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale, Fp8KVCacheDataType kv_type) {
  const uint16_t half_lo = static_cast<uint16_t>(a & 0xFFFFu);
  const uint16_t half_hi = static_cast<uint16_t>(a >> 16U);

  union {
    uint8_t ui8[2];
    uint16_t ui16;
  } tmp;
  tmp.ui8[0] = scaled_vec_conversion<uint8_t, uint16_t>(half_lo, scale, kv_type);
  tmp.ui8[1] = scaled_vec_conversion<uint8_t, uint16_t>(half_hi, scale, kv_type);
  return tmp.ui16;
}

// half2x2 -> fp8x4
// a.x is quantized into slot 0 of the packed result, a.y into slot 1.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale, Fp8KVCacheDataType kv_type) {
  // Two fp8x2 pairs viewed as one 32-bit word.
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale, kv_type);
  out.pairs[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale, kv_type);
  return out.packed;
}

// half2x4 -> fp8x8
// The uint4 input is viewed as two uint2 halves; the first half is
// quantized into .x of the result, the second into .y.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(
    const uint4& a, float scale, Fp8KVCacheDataType kv_type) {
  union {
    uint2 halves[2];
    uint4 whole;
  } in;
  in.whole = a;

  uint2 packed;
  packed.x = scaled_vec_conversion<uint32_t, uint2>(in.halves[0], scale, kv_type);
  packed.y = scaled_vec_conversion<uint32_t, uint2>(in.halves[1], scale, kv_type);
  return packed;
}

// bf16 -> fp8
template <>
301
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
302
    const __nv_bfloat16& a, float scale, Fp8KVCacheDataType kv_type) {
zhuwenwen's avatar
zhuwenwen committed
303
      float res_f = (static_cast<float>(a)) / scale;
304
305
306
307
308
      if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
        return float_to_fp8_e4m3(res_f);
      } else {
        return float_to_fp8_e5m2(res_f);
      }
309
310
311
312
313
}

// bf16x2 -> fp8x2
// a.x is quantized into byte 0 of the packed result, a.y into byte 1.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
    const __nv_bfloat162& a, float scale, Fp8KVCacheDataType kv_type) {
  // Two fp8 bytes viewed as one 16-bit word.
  union {
    uint8_t bytes[2];
    uint16_t packed;
  } out;
  out.bytes[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale, kv_type);
  out.bytes[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale, kv_type);
  return out.packed;
}

// bf16x4 -> fp8x4
// a.x is quantized into slot 0 of the packed result, a.y into slot 1.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale, Fp8KVCacheDataType kv_type) {
  // Two fp8x2 pairs viewed as one 32-bit word.
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale, kv_type);
  out.pairs[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale, kv_type);
  return out.packed;
}

// bf16x8 -> fp8x8
// (a.x, a.y) are quantized into .x of the result, (a.z, a.w) into .y.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale, Fp8KVCacheDataType kv_type) {
  bf16_4_t first_half;
  first_half.x = a.x;
  first_half.y = a.y;

  bf16_4_t second_half;
  second_half.x = a.z;
  second_half.y = a.w;

  uint2 packed;
  packed.x = scaled_vec_conversion<uint32_t, bf16_4_t>(first_half, scale, kv_type);
  packed.y = scaled_vec_conversion<uint32_t, bf16_4_t>(second_half, scale, kv_type);
  return packed;
}

// float -> fp8
template <>
349
__inline__ __device__ uint8_t
350
351
352
353
354
355
356
scaled_vec_conversion<uint8_t, float>(const float& a, float scale, Fp8KVCacheDataType kv_type) {
  if (kv_type == vllm::Fp8KVCacheDataType::kFp8E4M3) {
    return float_to_fp8_e4m3(a / scale);
  } else {
    return float_to_fp8_e5m2(a / scale);
  }

357
358
}

359
// floatx2 -> fp8x2
360
template <>
361
__inline__ __device__ uint16_t
362
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale, Fp8KVCacheDataType kv_type) {
zhuwenwen's avatar
zhuwenwen committed
363
364
365
366
  union {
    uint8_t ui8[2];
    uint16_t ui16;
  } tmp;
367
368
  tmp.ui8[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale, kv_type);
  tmp.ui8[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale, kv_type);
zhuwenwen's avatar
zhuwenwen committed
369
  return tmp.ui16;
370
371
372
373
374
}

// floatx4 -> fp8x4
// (a.x, a.y) are quantized into slot 0 of the packed result, (a.z, a.w)
// into slot 1.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale, Fp8KVCacheDataType kv_type) {
  float2 first_pair;
  first_pair.x = a.x;
  first_pair.y = a.y;

  float2 second_pair;
  second_pair.x = a.z;
  second_pair.y = a.w;

  // Two fp8x2 pairs viewed as one 32-bit word.
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, float2>(first_pair, scale, kv_type);
  out.pairs[1] = scaled_vec_conversion<uint16_t, float2>(second_pair, scale, kv_type);
  return out.packed;
}

zhuwenwen's avatar
zhuwenwen committed
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
// float -> fp8 (E5M2): encodes an fp32 value into one byte
// (1 sign / 5 exponent / 2 mantissa bits) via its bit pattern.
inline __device__ uint8_t float_to_fp8e5m2(float f) {
  constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
  constexpr uint32_t fp8_max = UINT32_C(143) << 23;
  constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
  constexpr uint32_t min_normal = UINT32_C(113) << 23;

  const uint32_t all_bits = c10::detail::fp32_to_bits(f);
  const uint32_t sign = all_bits & UINT32_C(0x80000000);
  const uint8_t sign_byte = static_cast<uint8_t>(sign >> 24);
  uint32_t abs_bits = all_bits ^ sign;

  // Out of range: NaN patterns map to 0x7F, everything else to 0x7C.
  if (abs_bits >= fp8_max) {
    const uint8_t special = abs_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
    return static_cast<uint8_t>(special | sign_byte);
  }

  // Below the smallest fp8 normal: align/round via a float addition with the
  // magic constant, then remove the constant's bit pattern.
  if (abs_bits < min_normal) {
    abs_bits = c10::detail::fp32_to_bits(
        c10::detail::fp32_from_bits(abs_bits) +
        c10::detail::fp32_from_bits(denorm_mask));
    const uint8_t subnormal = static_cast<uint8_t>(abs_bits - denorm_mask);
    return static_cast<uint8_t>(subnormal | sign_byte);
  }

  // Normal range: rebias the exponent and round with ties to even.
  const uint32_t mant_odd = (abs_bits >> 21) & 1;
  abs_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
  abs_bits += mant_odd;
  const uint8_t normal = static_cast<uint8_t>(abs_bits >> 21);
  return static_cast<uint8_t>(normal | sign_byte);
}

// X -> fp8 (E5M2), primary template.
// Fallback for input types without a specialization: ignores its arguments
// and returns 0.
template <typename Tin>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2(const Tin& a, float scale) {
  (void)a;      // unused in the fallback
  (void)scale;  // unused in the fallback
  return 0;
}

// float -> fp8 (E5M2): divides by `scale`, then encodes.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<float>(const float& a, float scale) {
  const float quant_input = a / scale;
  return float_to_fp8e5m2(quant_input);
}

// half -> fp8 (E5M2): widens with half_to_float, divides by `scale`, then
// encodes.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<uint16_t>(const uint16_t& a, float scale) {
  return float_to_fp8e5m2(half_to_float(a) / scale);
}

// bf16 -> fp8 (E5M2): converts to float, divides by `scale`, then encodes.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) {
  return float_to_fp8e5m2((static_cast<float>(a)) / scale);
}

// fp8 (E5M2) -> float
// Places the fp8 byte in the upper 8 bits of a 16-bit pattern, reinterprets
// that pattern as a _Float16 and widens it to float.
inline __device__ float fp8e5m2_to_fp32(const uint8_t& input) {
  union HalfBits {
    uint16_t bits;
    _Float16 value;
  };
  HalfBits widened;
  widened.bits = (uint16_t)input << 8;
  const float as_fp32 = (float)widened.value;
  return as_fp32;
}

// fp8 (E5M2) -> X, primary template.
// Fallback for output types without a specialization: ignores its arguments
// and returns 0 converted to Tout.
template <typename Tout>
__inline__ __device__ Tout
scaled_vec_conversion_from_e5m2(const uint8_t& a, float scale) {
  (void)a;      // unused in the fallback
  (void)scale;  // unused in the fallback
  return 0;
}

// fp8 (E5M2) -> float: decodes, then multiplies by `scale`.
template <>
__inline__ __device__ float
scaled_vec_conversion_from_e5m2<float>(const uint8_t& a, float scale) {
  const float decoded = fp8e5m2_to_fp32(a);
  return decoded * scale;
}

// fp8 (E5M2) -> half: decodes, multiplies by `scale`, then passes the
// product through float_to_half.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion_from_e5m2<uint16_t>(const uint8_t& a, float scale) {
  const float dequantized = fp8e5m2_to_fp32(a) * scale;
  return float_to_half(dequantized);
}

// fp8 (E5M2) -> bf16: decodes, multiplies by `scale`, then rounds to bf16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
  const float dequantized = fp8e5m2_to_fp32(a) * scale;
  return __float2bfloat16(dequantized);
}


479
480

template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
481
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
482
483
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3 || kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    return scaled_vec_conversion<Tout, Tin>(x, scale, kv_dt);
zhuwenwen's avatar
zhuwenwen committed
484
485
486
487
488
489
490
  }
  else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout)==1){
    return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
  }
  else if constexpr(kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tin)==1){
    return scaled_vec_conversion_from_e5m2<Tout>(x, scale);
  }
491
  return {};  // Squash missing return statement warning
492
493
}

494
495
496
497
  // Dispatches on the KV-cache dtype string (KV_DTYPE) and the source tensor
  // dtype (SRC_DTYPE). FN is a macro that calls a function templated as
  // template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>;
  // it is invoked as FN(scalar_t, cache_t, kv_dt). Unsupported source or
  // cache dtypes are reported through TORCH_CHECK.
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                    \
  if (KV_DTYPE == "auto") {                                                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                       \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                 \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);       \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "int8") {                                             \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kInt8);                     \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kInt8);                  \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kInt8);             \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);          \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "fp8_e5m2") {                                         \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);          \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else {                                                                     \
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);       \
  }
546

zhuwenwen's avatar
zhuwenwen committed
547

548
549
}  // namespace fp8
#endif  // USE_ROCM
zhuwenwen's avatar
zhuwenwen committed
550
}  // namespace vllm