quant_utils.cuh 15.8 KB
Newer Older
1
#pragma once
2
#ifndef USE_ROCM
3
#include <hip/hip_fp8.h>
4
#endif
5
6
7
8
9

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

10
#include "../../../attention/attention_dtypes.h"
11

12
namespace vllm {
13
14
15
#ifdef USE_ROCM

namespace fp8 {
zhuwenwen's avatar
zhuwenwen committed
16
  // #ifdef ENABLE_FP8
17

zhuwenwen's avatar
zhuwenwen committed
18
19
20
21
22
23
24
25
26
// fp8 -> float
// Decodes one byte laid out as 1 sign bit, 4 exponent bits (bias 7) and
// 3 mantissa bits into fp32 by rebuilding the fp32 bit pattern.
//  - Subnormal inputs (exponent field 0, mantissa != 0) are renormalized
//    using a leading-zero count.
//  - 0x00 / 0x80 decode to exactly +0.0f / -0.0f. Without the explicit zero
//    check the renormalization path would yield +/-2^-35, because
//    __clz(0) == 32 produces a bogus shift of 28.
//  - 0x7F / 0xFF are decoded as the finite values +/-480 (no NaN handling),
//    which pairs with float_to_fp8() emitting 0x7F for saturated inputs.
static inline __device__ float fp8_to_float(uint8_t input) {
  // Put the fp8 byte in the top 8 bits so its sign bit lines up with fp32's.
  const uint32_t w = (uint32_t)input << 24;
  const uint32_t sign = w & UINT32_C(0x80000000);
  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);

  // Normal inputs have at most 4 leading zeros in `nonsign`; anything beyond
  // that is the distance a subnormal mantissa must be shifted to normalize.
  uint32_t renorm_shift = __clz(nonsign);
  renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;

  // Move exponent/mantissa into fp32 position and rebias the exponent
  // (0x78 == 127 - 7), compensating for the renormalization shift.
  const uint32_t magnitude =
      (nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23);

  // Zero has no leading one to normalize; force an all-zero magnitude.
  const uint32_t result = sign | (nonsign != 0 ? magnitude : UINT32_C(0));
  return c10::detail::fp32_from_bits(result);
}

zhuwenwen's avatar
zhuwenwen committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// float -> fp8
// Encodes an fp32 value into one byte: 1 sign bit, 4 exponent bits (bias 7),
// 3 mantissa bits, rounding to nearest-even.
//  - Magnitudes whose bit pattern is >= (1087 << 20), i.e. >= 480.0f (this
//    also covers inf and NaN), are saturated to the code 0x7F.
//  - Magnitudes below 2^-6 (bit pattern < 121 << 23) become fp8 subnormals;
//    the rounding is done by the FPU through a magic-number addition.
// The sign of the input is always preserved in bit 7 of the result.
static inline __device__ uint8_t float_to_fp8(float f) {
  constexpr uint32_t kSaturateFrom = UINT32_C(1087) << 20;
  constexpr uint32_t kDenormMagic = UINT32_C(141) << 23;
  constexpr uint32_t kMinNormal = UINT32_C(121) << 23;

  const uint32_t raw = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = raw & UINT32_C(0x80000000);
  uint32_t mag = raw ^ sign_bit;  // bit pattern of |f|

  uint8_t encoded = 0u;
  if (mag >= kSaturateFrom) {
    encoded = 0x7f;
  } else if (mag < kMinNormal) {
    // Adding the magic constant lets the fp32 adder align and round the
    // mantissa; subtracting its bits afterwards leaves the fp8 payload.
    const float shifted = c10::detail::fp32_from_bits(mag) +
                          c10::detail::fp32_from_bits(kDenormMagic);
    mag = c10::detail::fp32_to_bits(shifted);
    encoded = static_cast<uint8_t>(mag - kDenormMagic);
  } else {
    // Round-to-nearest-even on the 20 bits about to be dropped, after
    // rebiasing the exponent from 127 to 7.
    const uint8_t lsb_is_odd = (mag >> 20) & 1;
    mag += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
    mag += lsb_is_odd;
    encoded = static_cast<uint8_t>(mag >> 20);
  }

  encoded |= static_cast<uint8_t>(sign_bit >> 24);
  return encoded;
}

56
57

// Primary template: fallback for (Tout, Tin) pairs without a dedicated
// specialization. It ignores `scale` and returns `x` through an implicit
// conversion to Tout. The specializations that follow implement the actual
// fp8 quantize / dequantize paths.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
                                                 const float scale) {
  return x;
}

// Make the CUDA-style name __nv_bfloat16 used below refer to HIP's
// __hip_bfloat16 type.
using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
// Dequantize: decode the fp8 byte to fp32, multiply by `scale`, narrow to bf16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
  const float dequantized = fp8_to_float(a) * scale;
  return __float2bfloat16(dequantized);
}

// fp8x2 -> __nv_bfloat162
// The low byte of `a` becomes .x and the high byte becomes .y.
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
                                                float scale) {
  const uint8_t low_byte = (uint8_t)a;
  const uint8_t high_byte = (uint8_t)(a >> 8U);

  __nv_bfloat162 pair;
  pair.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>(low_byte, scale);
  pair.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>(high_byte, scale);
  return pair;
}

// fp8x4 -> bf16_4_t
// The low 16 bits of `a` (two fp8 values) fill .x, the high 16 bits fill .y.
template <>
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
  const uint16_t low_pair = (uint16_t)a;
  const uint16_t high_pair = (uint16_t)(a >> 16U);

  bf16_4_t quad;
  quad.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>(low_pair, scale);
  quad.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>(high_pair, scale);
  return quad;
}

// fp8x8 -> bf16_8_t
// a.x supplies the first four elements (.x, .y) and a.y the last four
// (.z, .w); each 32-bit half is converted by the fp8x4 specialization.
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
  bf16_4_t first_half, second_half;
  first_half = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  second_half = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);

  bf16_8_t octet;
  octet.x = first_half.x;
  octet.y = first_half.y;
  octet.z = second_half.x;
  octet.w = second_half.y;
  return octet;
}

// fp8 -> float
// Dequantize a single fp8 byte: decode to fp32 and multiply by `scale`.
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, float scale) {
  const float decoded = fp8_to_float(a);
  return decoded * scale;
}

// fp8x2 -> float2
// The low byte of `a` becomes .x and the high byte becomes .y.
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
  const uint8_t low_byte = (uint8_t)a;
  const uint8_t high_byte = (uint8_t)(a >> 8U);

  float2 pair;
  pair.x = scaled_vec_conversion<float, uint8_t>(low_byte, scale);
  pair.y = scaled_vec_conversion<float, uint8_t>(high_byte, scale);
  return pair;
}

// fp8x4 -> Float4_ (pair of float2)
// The low 16 bits of `a` fill .x, the high 16 bits fill .y.
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
  const uint16_t low_pair = (uint16_t)a;
  const uint16_t high_pair = (uint16_t)(a >> 16U);

  Float4_ quad;
  quad.x = scaled_vec_conversion<float2, uint16_t>(low_pair, scale);
  quad.y = scaled_vec_conversion<float2, uint16_t>(high_pair, scale);
  return quad;
}

138
139
140
141
142
143
144
145
// fp8x4 -> float4
// Converts through the Float4_ specialization, then flattens its two float2
// members into (x, y, z, w) order.
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
  const Float4_ pairs = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  const float2 front = pairs.x;
  const float2 back = pairs.y;
  return {front.x, front.y, back.x, back.y};
}

146
147
// fp8x8 -> Float8_
// a.x supplies the first four elements (.x, .y) and a.y the last four
// (.z, .w); each 32-bit half goes through the Float4_ specialization.
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
  Float4_ first_half, second_half;
  first_half = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  second_half = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);

  Float8_ octet;
  octet.x = first_half.x;
  octet.y = first_half.y;
  octet.z = second_half.x;
  octet.w = second_half.y;
  return octet;
}

161
162
163
164
// fp8 -> half
// Dequantize one fp8 byte and return the result of float_to_half(), i.e. the
// half value carried in a uint16_t.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
  const float dequantized = fp8_to_float(a) * scale;
  return float_to_half(dequantized);
}

// fp8x2 -> half2
// The low byte of `a` is converted into 16-bit element 0 of the result and
// the high byte into element 1; both are returned packed in one uint32_t.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
  union {
    uint16_t halves[2];
    uint32_t packed;
  } out;
  const uint8_t low_byte = (uint8_t)a;
  const uint8_t high_byte = (uint8_t)(a >> 8U);
  out.halves[0] = scaled_vec_conversion<uint16_t, uint8_t>(low_byte, scale);
  out.halves[1] = scaled_vec_conversion<uint16_t, uint8_t>(high_byte, scale);
  return out.packed;
}

// fp8x4 -> half2x2
// The low 16 bits of `a` produce 32-bit element 0 of the result (two halves),
// the high 16 bits produce element 1; the pair is returned as a uint2.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
  union {
    uint2 as_uint2;
    uint32_t words[2];
  } out;
  const uint16_t low_pair = (uint16_t)a;
  const uint16_t high_pair = (uint16_t)(a >> 16U);
  out.words[0] = scaled_vec_conversion<uint32_t, uint16_t>(low_pair, scale);
  out.words[1] = scaled_vec_conversion<uint32_t, uint16_t>(high_pair, scale);
  return out.as_uint2;
}
194

195
196
197
198
199
200
201
202
203
204
205
206
// fp8x8 -> half2x4
// a.x is converted into the first uint2 of the result and a.y into the
// second; the two uint2 values are returned viewed as a single uint4.
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
                                                                float scale) {
  union {
    uint4 as_uint4;
    uint2 halves[2];
  } out;
  out.halves[0] = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  out.halves[1] = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  return out.as_uint4;
}
207
208
209

// half -> fp8
// Quantize: widen the half (passed as uint16_t) to fp32 with half_to_float(),
// divide by `scale`, and encode the quotient as fp8.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
  const float widened = half_to_float(a);
  const float unscaled = widened / scale;
  return float_to_fp8(unscaled);
}
215

216
217
218
219
220
// halfx2 -> fp8x2
// `a` carries two half bit patterns: element 0 in bits [15:0] and element 1
// in bits [31:16]. Each is quantized with the scalar half -> fp8 conversion;
// element 0 lands in the low byte of the result, element 1 in the high byte
// (the order the fp8x2 decoders in this file read them back).
//
// The 16-bit lanes are extracted with integer shifts rather than through a
// half2 view of `a`: indexing half2's storage can yield a floating-point
// element, and binding that to the `const uint16_t&` parameter would convert
// the numeric value instead of forwarding the raw half bits.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
  const uint16_t half_lo = static_cast<uint16_t>(a);
  const uint16_t half_hi = static_cast<uint16_t>(a >> 16U);

  const uint8_t fp8_lo =
      scaled_vec_conversion<uint8_t, uint16_t>(half_lo, scale);
  const uint8_t fp8_hi =
      scaled_vec_conversion<uint8_t, uint16_t>(half_hi, scale);

  return static_cast<uint16_t>(static_cast<uint16_t>(fp8_lo) |
                               (static_cast<uint16_t>(fp8_hi) << 8U));
}

// half2x2 -> fp8x4
// a.x (two halves) is quantized into 16-bit element 0 of the result and a.y
// into element 1; both are returned packed in one uint32_t.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
  out.pairs[1] = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
  return out.packed;
}

// half2x4 -> fp8x8
// Views the uint4 input as two uint2 halves; the first is quantized into .x
// of the result and the second into .y.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
                                                                float scale) {
  union {
    uint2 halves[2];
    uint4 whole;
  } in;
  in.whole = a;

  uint2 packed;
  packed.x = scaled_vec_conversion<uint32_t, uint2>(in.halves[0], scale);
  packed.y = scaled_vec_conversion<uint32_t, uint2>(in.halves[1], scale);
  return packed;
}

// bf16 -> fp8
// Quantize: convert the bf16 value to fp32, divide by `scale`, encode as fp8.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, float scale) {
  const float widened = static_cast<float>(a);
  const float unscaled = widened / scale;
  return float_to_fp8(unscaled);
}

// bf16x2 -> fp8x2
// a.x is quantized into byte 0 of the result and a.y into byte 1; both are
// returned packed in one uint16_t.
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
    const __nv_bfloat162& a, float scale) {
  union {
    uint8_t bytes[2];
    uint16_t packed;
  } out;
  out.bytes[0] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
  out.bytes[1] = scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
  return out.packed;
}

// bf16x4 -> fp8x4
// a.x (two bf16 values) is quantized into 16-bit element 0 of the result and
// a.y into element 1; both are returned packed in one uint32_t.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
  out.pairs[1] = scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
  return out.packed;
}

// bf16x8 -> fp8x8
// (a.x, a.y) are quantized into .x of the result and (a.z, a.w) into .y,
// each through the bf16x4 specialization.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
  const uint32_t front =
      scaled_vec_conversion<uint32_t, bf16_4_t>({a.x, a.y}, scale);
  const uint32_t back =
      scaled_vec_conversion<uint32_t, bf16_4_t>({a.z, a.w}, scale);

  uint2 packed;
  packed.x = front;
  packed.y = back;
  return packed;
}

// float -> fp8
// Quantize: divide by `scale`, then encode the quotient as fp8.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
  const float unscaled = a / scale;
  return float_to_fp8(unscaled);
}

313
// floatx2 -> fp8x2
// a.x is quantized into byte 0 of the result and a.y into byte 1; both are
// returned packed in one uint16_t.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
  union {
    uint8_t bytes[2];
    uint16_t packed;
  } out;
  out.bytes[0] = scaled_vec_conversion<uint8_t, float>(a.x, scale);
  out.bytes[1] = scaled_vec_conversion<uint8_t, float>(a.y, scale);
  return out.packed;
}

// floatx4 -> fp8x4
// (a.x, a.y) are quantized into 16-bit element 0 of the result and
// (a.z, a.w) into element 1; both are returned packed in one uint32_t.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
  union {
    uint16_t pairs[2];
    uint32_t packed;
  } out;
  out.pairs[0] = scaled_vec_conversion<uint16_t, float2>({a.x, a.y}, scale);
  out.pairs[1] = scaled_vec_conversion<uint16_t, float2>({a.z, a.w}, scale);
  return out.packed;
}

yangql's avatar
yangql committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
// float -> fp8 (E5M2)
// Encodes an fp32 value into one byte: 1 sign bit, 5 exponent bits (bias 15),
// 2 mantissa bits, rounding to nearest-even.
//  - Magnitudes whose bit pattern is >= (143 << 23), i.e. >= 2^16, map to
//    0x7C, except NaN inputs (bit pattern above fp32 infinity), which map
//    to 0x7F.
//  - Magnitudes below 2^-14 (bit pattern < 113 << 23) become fp8 subnormals;
//    the rounding is done by the FPU through a magic-number addition.
// The sign of the input is always preserved in bit 7 of the result.
inline __device__ uint8_t float_to_fp8e5m2(float f) {
  constexpr uint32_t kFp32InfBits = UINT32_C(255) << 23;
  constexpr uint32_t kOverflowFrom = UINT32_C(143) << 23;
  constexpr uint32_t kDenormMagic = UINT32_C(134) << 23;
  constexpr uint32_t kMinNormal = UINT32_C(113) << 23;

  const uint32_t raw = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = raw & UINT32_C(0x80000000);
  uint32_t mag = raw ^ sign_bit;  // bit pattern of |f|

  uint8_t encoded = 0u;
  if (mag >= kOverflowFrom) {
    if (mag > kFp32InfBits) {
      encoded = UINT8_C(0x7F);
    } else {
      encoded = UINT8_C(0x7C);
    }
  } else if (mag < kMinNormal) {
    // Adding the magic constant lets the fp32 adder align and round the
    // mantissa; subtracting its bits afterwards leaves the fp8 payload.
    const float shifted = c10::detail::fp32_from_bits(mag) +
                          c10::detail::fp32_from_bits(kDenormMagic);
    mag = c10::detail::fp32_to_bits(shifted);
    encoded = static_cast<uint8_t>(mag - kDenormMagic);
  } else {
    // Round-to-nearest-even on the 21 bits about to be dropped, after
    // rebiasing the exponent from 127 to 15.
    const uint32_t lsb_is_odd = (mag >> 21) & 1;
    mag += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
    mag += lsb_is_odd;
    encoded = static_cast<uint8_t>(mag >> 21);
  }

  encoded |= static_cast<uint8_t>(sign_bit >> 24);
  return encoded;
}

// Tin -> fp8 (E5M2), primary template.
// Fallback for input types without a dedicated specialization: ignores both
// arguments and returns 0.
template <typename Tin>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2(const Tin& a, float scale) {
  return 0;
}

// float -> fp8 (E5M2)
// Quantize: divide by `scale`, then encode the quotient as E5M2.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<float>(const float& a, float scale) {
  const float unscaled = a / scale;
  return float_to_fp8e5m2(unscaled);
}

// half -> fp8 (E5M2)
// Quantize: widen the half (passed as uint16_t) to fp32 with half_to_float(),
// divide by `scale`, and encode the quotient as E5M2.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<uint16_t>(const uint16_t& a, float scale) {
  const float widened = half_to_float(a);
  const float unscaled = widened / scale;
  return float_to_fp8e5m2(unscaled);
}

// bf16 -> fp8 (E5M2)
// Quantize: convert the bf16 value to fp32, divide by `scale`, encode as E5M2.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_to_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) {
  const float widened = static_cast<float>(a);
  const float unscaled = widened / scale;
  return float_to_fp8e5m2(unscaled);
}

395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
// fp8 (E5M2) -> float
// Places the fp8 byte in the upper 8 bits of a 16-bit pattern (lower 8 bits
// zero), reads that pattern back as a _Float16 and widens it to fp32.
inline __device__ float fp8e5m2_to_fp32(const uint8_t& input) {
  union {
    uint16_t bits;
    _Float16 value;
  } half_view;
  half_view.bits = (uint16_t)input << 8;
  const _Float16 as_half = half_view.value;
  return (float)as_half;
}

// fp8 (E5M2) -> Tout, primary template.
// Fallback for output types without a dedicated specialization: ignores both
// arguments and returns 0 converted to Tout.
template <typename Tout>
__inline__ __device__ Tout
scaled_vec_conversion_from_e5m2(const uint8_t& a, float scale) {
  return 0;
}

// fp8 (E5M2) -> float
// Dequantize: decode the byte to fp32 and multiply by `scale`.
template <>
__inline__ __device__ float
scaled_vec_conversion_from_e5m2<float>(const uint8_t& a, float scale) {
  const float decoded = fp8e5m2_to_fp32(a);
  return decoded * scale;
}

// fp8 (E5M2) -> half
// Dequantize: decode the byte to fp32, multiply by `scale`, and return the
// result of float_to_half() carried in a uint16_t.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion_from_e5m2<uint16_t>(const uint8_t& a, float scale) {
  const float dequantized = fp8e5m2_to_fp32(a) * scale;
  return float_to_half(dequantized);
}

// fp8 (E5M2) -> bf16
// Dequantize: decode the byte to fp32, multiply by `scale`, narrow to bf16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion_from_e5m2<__nv_bfloat16>(const uint8_t& a, float scale) {
  const float dequantized = fp8e5m2_to_fp32(a) * scale;
  return __float2bfloat16(dequantized);
}


433
434

// Compile-time dispatch between the fp8 conversion families.
//  - kFp8E4M3: forwards to scaled_vec_conversion<Tout, Tin>.
//  - kFp8E5M2 with a 1-byte Tout: quantizes via
//    scaled_vec_conversion_to_e5m2<Tin>.
//  - kFp8E5M2 with a 1-byte Tin (and Tout wider than 1 byte): dequantizes via
//    scaled_vec_conversion_from_e5m2<Tout>.
//  - Any other combination returns a value-initialized Tout.
// NOTE(review): for kFp8E5M2 only 1-byte (scalar) fp8 operands reach a real
// conversion; packed fp8 types fall through to the default return.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale);
  } else if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E5M2) {
    if constexpr (sizeof(Tout) == 1) {
      return scaled_vec_conversion_to_e5m2<Tin>(x, scale);
    } else if constexpr (sizeof(Tin) == 1) {
      return scaled_vec_conversion_from_e5m2<Tout>(x, scale);
    }
  }
  return {};  // Squash missing return statement warning
}

448
449
450
451
  // Runtime dispatch on the KV-cache dtype string and the source tensor
  // dtype. FN is a macro invoked as FN(scalar_t, cache_t, kv_dt), where kv_dt
  // is a vllm::Fp8KVCacheDataType value:
  //   "auto"              -> cache_t == scalar_t, kAuto
  //   "fp8" / "fp8_e4m3"  -> cache_t == uint8_t,  kFp8E4M3
  //   "fp8_e5m2"          -> cache_t == uint8_t,  kFp8E5M2
  // Supported SRC_DTYPE values are Float, Half and BFloat16; anything else,
  // or an unknown KV_DTYPE, fails a TORCH_CHECK. The expansion is a bare
  // if/else chain (not wrapped in do { } while (0)).
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                    \
  if (KV_DTYPE == "auto") {                                                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                       \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                 \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);       \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);          \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "fp8_e5m2") {                                         \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);          \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else {                                                                     \
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);       \
  }
490

yangql's avatar
yangql committed
491

492
493
}  // namespace fp8
#endif  // USE_ROCM
yangql's avatar
yangql committed
494
}  // namespace vllm