quant_utils.cuh 15 KB
Newer Older
1
#pragma once
2
#ifndef USE_ROCM
3
#include <hip/hip_fp8.h>
4
#endif
5
6
7
8
9

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

10
#include "../../../attention/attention_dtypes.h"
11

12
namespace vllm {
13
14
15
#ifdef USE_ROCM

namespace fp8 {
zhuwenwen's avatar
zhuwenwen committed
16
  // #ifdef ENABLE_FP8
17

zhuwenwen's avatar
zhuwenwen committed
18
19
20
21
22
23
24
25
26
// fp8 (E4M3FN bit layout: 1 sign, 4 exponent, 3 mantissa bits) -> float.
//
// The byte is moved to the top of a 32-bit word; the exponent/mantissa field
// is then renormalized (needed only for fp8 subnormals) and re-biased from the
// fp8 exponent bias (7) to the fp32 one (127).
//
// +0 / -0 are handled explicitly: for a zero magnitude __clz() returns 32, and
// the re-bias arithmetic below would otherwise yield about +/-2^-35 instead of
// a signed zero.
//
// NOTE(review): unlike c10's fp8e4m3fn decoder, the 0x7F / 0xFF encoding is not
// mapped to NaN here and decodes to +/-480.0f. This pairs with float_to_fp8()
// clamping magnitudes >= 480.0f to 0x7F. Confirm this saturating behaviour is
// intended.
static inline __device__ float fp8_to_float(uint8_t input) {
  const uint32_t w = (uint32_t)input << 24;
  const uint32_t sign = w & UINT32_C(0x80000000);
  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
  if (nonsign == 0) {
    // Signed zero: only the sign bit survives.
    return c10::detail::fp32_from_bits(sign);
  }
  // Leading zeros beyond the sign+exponent field; non-zero only for fp8
  // subnormals, which must be shifted into normalized fp32 form.
  uint32_t renorm_shift = __clz(nonsign);
  renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
  // Move exponent/mantissa into fp32 position and add the bias difference
  // (0x78 = 127 - 7), compensating for any renormalization shift.
  const uint32_t result =
      sign | ((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23));
  return c10::detail::fp32_from_bits(result);
}

zhuwenwen's avatar
zhuwenwen committed
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// float -> fp8 (E4M3FN bit layout), round-to-nearest-even.
//
// Magnitudes >= 480.0f (bit pattern 1087 << 20) are clamped to the 0x7F
// encoding. Magnitudes below 2^-6 (bit pattern 121 << 23) take the subnormal
// path, which rounds via a float addition against 2^14. All other values are
// re-biased and rounded on the mantissa bits that are dropped. The input's
// sign bit is copied to bit 7 of the result.
static inline __device__ uint8_t float_to_fp8(float f) {
  constexpr uint32_t kClampBits = UINT32_C(1087) << 20;    // 480.0f
  constexpr uint32_t kDenormMagic = UINT32_C(141) << 23;   // 2^14
  constexpr uint32_t kMinNormalBits = UINT32_C(121) << 23; // 2^-6

  uint32_t bits = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = bits & UINT32_C(0x80000000);
  bits ^= sign_bit;  // continue with |f|

  uint8_t encoded;
  if (bits >= kClampBits) {
    encoded = 0x7f;
  } else if (bits < kMinNormalBits) {
    // The fp32 ulp at 2^14 equals the fp8 subnormal step (2^-9), so the
    // hardware add performs the rounding; subtracting the magic bits leaves
    // the fp8 subnormal code in the low byte.
    const float shifted = c10::detail::fp32_from_bits(bits) +
                          c10::detail::fp32_from_bits(kDenormMagic);
    encoded = static_cast<uint8_t>(c10::detail::fp32_to_bits(shifted) -
                                   kDenormMagic);
  } else {
    // Re-bias the exponent (127 -> 7), then round to nearest-even on the
    // 20 mantissa bits that the final shift discards.
    const uint8_t keep_lsb = (bits >> 20) & 1;
    bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
    bits += keep_lsb;
    encoded = static_cast<uint8_t>(bits >> 20);
  }
  return encoded | static_cast<uint8_t>(sign_bit >> 24);
}

56
57

// Generic fallback for (Tout, Tin) pairs without a specialization below.
// `scale` is ignored and `x` is returned through implicit Tin -> Tout
// conversion. NOTE(review): an unintended type pair silently lands here
// instead of failing to compile.
template <typename Tout, typename Tin>
__inline__ __device__ Tout scaled_vec_conversion(const Tin& x,
                                                 const float scale) {
  (void)scale;
  const Tout converted = x;
  return converted;
}

// HIP names its bf16 type __hip_bfloat16; alias it to the CUDA spelling used
// throughout the rest of this file.
using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16: decode to float, apply the dequantization scale, then
// round to bf16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) {
  const float scaled = fp8_to_float(a) * scale;
  return __float2bfloat16(scaled);
}

// fp8x2 -> __nv_bfloat162
// The low byte of `a` becomes .x and the high byte becomes .y.
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
                                                float scale) {
  const uint8_t lo_byte = static_cast<uint8_t>(a);
  const uint8_t hi_byte = static_cast<uint8_t>(a >> 8U);
  __nv_bfloat162 out;
  out.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>(lo_byte, scale);
  out.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>(hi_byte, scale);
  return out;
}

// fp8x4 -> bf16_4_t
// Bytes 0-1 of `a` fill .x and bytes 2-3 fill .y.
template <>
__inline__ __device__ bf16_4_t
scaled_vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a, float scale) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  bf16_4_t out;
  out.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>(lo_pair, scale);
  out.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>(hi_pair, scale);
  return out;
}

// fp8x8 -> bf16_8_t
// a.x (fp8 lanes 0-3) fills .x/.y of the result; a.y (lanes 4-7) fills .z/.w.
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, float scale) {
  const bf16_4_t first = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  const bf16_4_t second = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);
  bf16_8_t out;
  out.x = first.x;
  out.y = first.y;
  out.z = second.x;
  out.w = second.y;
  return out;
}

// fp8 -> float: decode the byte and multiply by the dequantization scale.
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, float scale) {
  const float decoded = fp8_to_float(a);
  return decoded * scale;
}

// fp8x2 -> float2 (low byte -> .x, high byte -> .y)
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, float scale) {
  const uint8_t lo_byte = static_cast<uint8_t>(a);
  const uint8_t hi_byte = static_cast<uint8_t>(a >> 8U);
  float2 out;
  out.x = scaled_vec_conversion<float, uint8_t>(lo_byte, scale);
  out.y = scaled_vec_conversion<float, uint8_t>(hi_byte, scale);
  return out;
}

// fp8x4 -> Float4_
// The low 16 bits of `a` decode into .x (a float2), the high 16 bits into .y.
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  Float4_ out;
  out.x = scaled_vec_conversion<float2, uint16_t>(lo_pair, scale);
  out.y = scaled_vec_conversion<float2, uint16_t>(hi_pair, scale);
  return out;
}

138
139
140
141
142
143
144
145
// fp8x4 -> float4: reuse the Float4_ conversion and flatten its two float2
// halves into a single float4.
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, float scale) {
  const Float4_ halves = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  const float2 lo = halves.x;
  const float2 hi = halves.y;
  return {lo.x, lo.y, hi.x, hi.y};
}

146
147
// fp8x8 -> Float8_
// a.x (fp8 lanes 0-3) fills .x/.y of the result; a.y (lanes 4-7) fills .z/.w.
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, float scale) {
  const Float4_ first = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  const Float4_ second = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);
  Float8_ out;
  out.x = first.x;
  out.y = first.y;
  out.z = second.x;
  out.w = second.y;
  return out;
}

161
162
163
164
// fp8 -> half, returned in a uint16_t (the value produced by float_to_half).
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, float scale) {
  return float_to_half(fp8_to_float(a) * scale);
}

// fp8x2 -> half2, packed into a uint32_t:
// the low fp8 byte becomes the low 16 bits, the high byte the high 16 bits.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint16_t>(const uint16_t& a, float scale) {
  const uint32_t lo_half =
      scaled_vec_conversion<uint16_t, uint8_t>(static_cast<uint8_t>(a), scale);
  const uint32_t hi_half = scaled_vec_conversion<uint16_t, uint8_t>(
      static_cast<uint8_t>(a >> 8U), scale);
  return lo_half | (hi_half << 16);
}

// fp8x4 -> half2x2
// Bytes 0-1 of `a` become the half2 in .x, bytes 2-3 the half2 in .y.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, float scale) {
  const uint16_t lo_pair = static_cast<uint16_t>(a);
  const uint16_t hi_pair = static_cast<uint16_t>(a >> 16U);
  uint2 out;
  out.x = scaled_vec_conversion<uint32_t, uint16_t>(lo_pair, scale);
  out.y = scaled_vec_conversion<uint32_t, uint16_t>(hi_pair, scale);
  return out;
}
194

195
196
197
198
199
200
201
202
203
204
205
206
// fp8x8 -> half2x4
// a.x expands into result words .x/.y, a.y into words .z/.w.
template <>
__inline__ __device__ uint4 scaled_vec_conversion<uint4, uint2>(const uint2& a,
                                                                float scale) {
  const uint2 lo_words = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  const uint2 hi_words = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);
  uint4 out;
  out.x = lo_words.x;
  out.y = lo_words.y;
  out.z = hi_words.x;
  out.w = hi_words.y;
  return out;
}
207
208
209

// half (held in a uint16_t, read via half_to_float) -> fp8:
// divide by the quantization scale, then encode.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, float scale) {
  return float_to_fp8(half_to_float(a) / scale);
}
215

216
217
218
219
220
// halfx2 -> fp8x2
//
// `a` holds two fp16 values packed into one 32-bit word: the low 16 bits are
// element 0 and the high 16 bits are element 1. This is the same packing the
// fp8x2 -> half2 (uint32_t) conversion in this file produces. Each half is
// quantized independently; element 0 lands in the low byte of the result.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint32_t>(const uint32_t& a, float scale) {
  // Split on raw bits rather than reading half2 lanes. On HIP a half2 lane is
  // presumably a floating-point (_Float16) value, so binding it to the
  // `const uint16_t&` parameter would convert the value (e.g. 1.5 -> 1)
  // instead of forwarding the fp16 bit pattern that half_to_float() decodes.
  const uint16_t h0 = static_cast<uint16_t>(a);
  const uint16_t h1 = static_cast<uint16_t>(a >> 16U);
  const uint32_t q0 = scaled_vec_conversion<uint8_t, uint16_t>(h0, scale);
  const uint32_t q1 = scaled_vec_conversion<uint8_t, uint16_t>(h1, scale);
  return static_cast<uint16_t>(q0 | (q1 << 8));
}

// half2x2 -> fp8x4
// a.x (a half2) becomes fp8 bytes 0-1, a.y becomes bytes 2-3.
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, uint2>(const uint2& a, float scale) {
  const uint32_t lo_pair = scaled_vec_conversion<uint16_t, uint32_t>(a.x, scale);
  const uint32_t hi_pair = scaled_vec_conversion<uint16_t, uint32_t>(a.y, scale);
  return lo_pair | (hi_pair << 16);
}

// half2x4 -> fp8x8
// Words x,y of `a` (two half2) quantize into result .x; words z,w into .y.
template <>
__inline__ __device__ uint2 scaled_vec_conversion<uint2, uint4>(const uint4& a,
                                                                float scale) {
  uint2 low_words;
  low_words.x = a.x;
  low_words.y = a.y;
  uint2 high_words;
  high_words.x = a.z;
  high_words.y = a.w;
  uint2 out;
  out.x = scaled_vec_conversion<uint32_t, uint2>(low_words, scale);
  out.y = scaled_vec_conversion<uint32_t, uint2>(high_words, scale);
  return out;
}

// bf16 -> fp8: widen to float, divide by the quantization scale, encode.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, float scale) {
  const float widened = static_cast<float>(a);
  return float_to_fp8(widened / scale);
}

// bf16x2 -> fp8x2 (.x -> low byte, .y -> high byte)
template <>
__inline__ __device__ uint16_t scaled_vec_conversion<uint16_t, __nv_bfloat162>(
    const __nv_bfloat162& a, float scale) {
  const uint32_t lo_byte =
      scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.x, scale);
  const uint32_t hi_byte =
      scaled_vec_conversion<uint8_t, __nv_bfloat16>(a.y, scale);
  return static_cast<uint16_t>(lo_byte | (hi_byte << 8));
}

// bf16x4 -> fp8x4 (.x -> bytes 0-1, .y -> bytes 2-3)
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, bf16_4_t>(const bf16_4_t& a, float scale) {
  const uint32_t lo_pair =
      scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.x, scale);
  const uint32_t hi_pair =
      scaled_vec_conversion<uint16_t, __nv_bfloat162>(a.y, scale);
  return lo_pair | (hi_pair << 16);
}

// bf16x8 -> fp8x8
// Members .x/.y of `a` form the low fp8x4 (result .x); .z/.w the high one (.y).
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, bf16_8_t>(const bf16_8_t& a, float scale) {
  const bf16_4_t low_quad = {a.x, a.y};
  const bf16_4_t high_quad = {a.z, a.w};
  uint2 out;
  out.x = scaled_vec_conversion<uint32_t, bf16_4_t>(low_quad, scale);
  out.y = scaled_vec_conversion<uint32_t, bf16_4_t>(high_quad, scale);
  return out;
}

// float -> fp8: divide by the quantization scale, then encode.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, float scale) {
  const float scaled = a / scale;
  return float_to_fp8(scaled);
}

313
// floatx2 -> fp8x2 (.x -> low byte, .y -> high byte)
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, float2>(const float2& a, float scale) {
  const uint32_t lo_byte = scaled_vec_conversion<uint8_t, float>(a.x, scale);
  const uint32_t hi_byte = scaled_vec_conversion<uint8_t, float>(a.y, scale);
  return static_cast<uint16_t>(lo_byte | (hi_byte << 8));
}

// floatx4 -> fp8x4 (.x,.y -> bytes 0-1; .z,.w -> bytes 2-3)
template <>
__inline__ __device__ uint32_t
scaled_vec_conversion<uint32_t, float4>(const float4& a, float scale) {
  const float2 front = {a.x, a.y};
  const float2 back = {a.z, a.w};
  const uint32_t lo_pair = scaled_vec_conversion<uint16_t, float2>(front, scale);
  const uint32_t hi_pair = scaled_vec_conversion<uint16_t, float2>(back, scale);
  return lo_pair | (hi_pair << 16);
}

yangql's avatar
yangql committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
// float -> fp8 (E5M2 bit layout: 1 sign, 5 exponent, 2 mantissa bits),
// round-to-nearest-even.
//
// Magnitudes >= 65536.0f (bit pattern 143 << 23) map to 0x7C (E5M2 infinity),
// except fp32 NaN inputs (bits above the fp32 infinity pattern), which map to
// 0x7F. Magnitudes below 2^-14 take the subnormal path, which rounds via a
// float addition against 2^7. The input's sign bit is copied to bit 7.
inline __device__ uint8_t float_to_fp8e5m2(float f) {
  constexpr uint32_t kFp32InfBits = UINT32_C(255) << 23;
  constexpr uint32_t kOverflowBits = UINT32_C(143) << 23;   // 65536.0f
  constexpr uint32_t kDenormMagic = UINT32_C(134) << 23;    // 2^7
  constexpr uint32_t kMinNormalBits = UINT32_C(113) << 23;  // 2^-14

  uint32_t bits = c10::detail::fp32_to_bits(f);
  const uint32_t sign_bit = bits & UINT32_C(0x80000000);
  bits ^= sign_bit;  // continue with |f|

  uint8_t encoded;
  if (bits >= kOverflowBits) {
    encoded = bits > kFp32InfBits ? UINT8_C(0x7F) : UINT8_C(0x7C);
  } else if (bits < kMinNormalBits) {
    // The fp32 ulp at 2^7 equals the E5M2 subnormal step (2^-16), so the
    // hardware add performs the rounding.
    const float shifted = c10::detail::fp32_from_bits(bits) +
                          c10::detail::fp32_from_bits(kDenormMagic);
    encoded = static_cast<uint8_t>(c10::detail::fp32_to_bits(shifted) -
                                   kDenormMagic);
  } else {
    // Re-bias the exponent (127 -> 15), then round to nearest-even on the
    // 21 mantissa bits that the final shift discards.
    const uint32_t keep_lsb = (bits >> 21) & 1;
    bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
    bits += keep_lsb;
    encoded = static_cast<uint8_t>(bits >> 21);
  }
  return encoded | static_cast<uint8_t>(sign_bit >> 24);
}

// Generic E5M2 encoder fallback for input types without a specialization
// below: both arguments are ignored and 0 is returned.
// NOTE(review): an unsupported Tin silently quantizes to 0.
template <typename Tin>
__inline__ __device__ uint8_t
scaled_vec_conversion_e5m2(const Tin& a, float scale) {
  (void)a;
  (void)scale;
  return uint8_t{0};
}

// float -> fp8 (E5M2): divide by the quantization scale, then encode.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<float>(const float& a, float scale) {
  const float scaled = a / scale;
  return float_to_fp8e5m2(scaled);
}

// half (held in a uint16_t, read via half_to_float) -> fp8 (E5M2).
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<uint16_t>(const uint16_t& a, float scale) {
  return float_to_fp8e5m2(half_to_float(a) / scale);
}

// bf16 -> fp8 (E5M2): widen to float, divide by the scale, encode.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion_e5m2<__nv_bfloat16>(const __nv_bfloat16& a, float scale) {
  const float widened = static_cast<float>(a);
  return float_to_fp8e5m2(widened / scale);
}

395
396

// Compile-time dispatch of a scaled conversion by KV-cache fp8 format:
//   * kFp8E4M3: forwards to scaled_vec_conversion<Tout, Tin>, whose
//     specializations cover quantization (x / scale) and dequantization
//     (x * scale).
//   * kFp8E5M2 with a 1-byte Tout: forwards to the scalar
//     scaled_vec_conversion_e5m2<Tin> encoder.
//   * anything else (including kFp8E5M2 with a wider Tout): returns a
//     value-initialized Tout.
// NOTE(review): the last case silently yields zeros, and no E5M2 decode
// specialization is visible here. Confirm that callers never reach it for
// E5M2 dequantization.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  constexpr bool kIsE4M3 = kv_dt == Fp8KVCacheDataType::kFp8E4M3;
  constexpr bool kIsE5M2ByteOut =
      kv_dt == Fp8KVCacheDataType::kFp8E5M2 && sizeof(Tout) == 1;
  if constexpr (kIsE4M3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale);
  } else if constexpr (kIsE5M2ByteOut) {
    return scaled_vec_conversion_e5m2<Tin>(x, scale);
  } else {
    return {};
  }
}

411
412
413
414
  // Dispatches FN on (source dtype, kv-cache dtype string). FN is a macro that
  // calls a function with template<typename scalar_t, typename cache_t,
  // Fp8KVCacheDataType kv_dt>:
  //   KV_DTYPE "auto"             -> cache_t same as scalar_t, kAuto
  //   KV_DTYPE "fp8" / "fp8_e4m3" -> cache_t = uint8_t, kFp8E4M3
  //   KV_DTYPE "fp8_e5m2"         -> cache_t = uint8_t, kFp8E5M2
  // SRC_DTYPE must be Float, Half (passed as uint16_t) or BFloat16. Any other
  // source dtype, or an unknown KV_DTYPE, fails a TORCH_CHECK.
  // Expands to a single if / else-if chain (not wrapped in do { } while (0)).
#define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                    \
  if (KV_DTYPE == "auto") {                                                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                       \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);                 \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);       \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                    \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);          \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else if (KV_DTYPE == "fp8_e5m2") {                                         \
    if (SRC_DTYPE == at::ScalarType::Float) {                                  \
      FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);                  \
    } else if (SRC_DTYPE == at::ScalarType::Half) {                            \
      FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);               \
    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                        \
      FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E5M2);          \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);   \
    }                                                                          \
  } else {                                                                     \
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);       \
  }
453

yangql's avatar
yangql committed
454

455
456
}  // namespace fp8
#endif  // USE_ROCM
yangql's avatar
yangql committed
457
}  // namespace vllm