quant_utils.cuh 16.2 KB
Newer Older
1
2
3
4
5
6
7
#pragma once
#include "hip_float8.h"

#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>

8
#include "../../../attention/dtype_fp8.cuh"
9
10
11
#include "../../../attention/dtype_float32.cuh"
#include "../../../attention/dtype_bfloat16.cuh"

12
namespace vllm {
13
14
15
#ifdef USE_ROCM

namespace fp8 {
16
  #ifdef ENABLE_FP8
17

18
// Primary template: fallback used when no specialization matches the
// (Tout, Tin) pair. The input is returned through an implicit conversion to
// Tout.
template <typename Tout, typename Tin>
__inline__ __device__ Tout vec_conversion(const Tin& x) {
  return x;
}

// Primary template for the scaled conversions: fallback used when no
// specialization matches. `scale` is ignored here and the input is returned
// through an implicit conversion to Tout.
template <typename Tout, typename Tin>
__inline__ __device__ Tout
scaled_vec_conversion(const Tin& x, const float scale) {
  return x;
}

// fp8 -> half
// Wraps the raw byte in an hip_fp8 (from_bits constructor), widens it to
// float, stores that into a __half_raw and returns its 16-bit pattern.
template <>
__inline__ __device__ uint16_t
vec_conversion<uint16_t, uint8_t>(const uint8_t& a) {
  hip_fp8 decoded{a, hip_fp8::from_bits()};
  const float widened = static_cast<float>(decoded);

  __half_raw packed;
  packed.data = widened;
  return packed.x;
}

// fp8x2 -> half2
// Converts the two fp8 bytes packed in `a` into two halves packed in a
// uint32_t. The low byte of `a` fills the first 16-bit lane of the result.
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, uint16_t>(const uint16_t& a) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  // Packed builtin path: one call yields both floats (indexed [0] and [1]).
  const auto floats = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  union {
    __half2_raw halves;
    uint32_t bits;
  } packed;
  packed.halves.x.data = floats[0];
  packed.halves.y.data = floats[1];
  return packed.bits;
#else
  // Generic path: convert each byte with the scalar fp8 -> half routine.
  const uint8_t low_byte = static_cast<uint8_t>(a);
  const uint8_t high_byte = static_cast<uint8_t>(a >> 8U);

  union {
    uint16_t lanes[2];
    uint32_t bits;
  } packed;
  packed.lanes[0] = vec_conversion<uint16_t, uint8_t>(low_byte);
  packed.lanes[1] = vec_conversion<uint16_t, uint8_t>(high_byte);
  return packed.bits;
#endif
}

// fp8x4 -> half2x2
// Splits the 32-bit input into two 16-bit halves and converts each with the
// fp8x2 -> half2 routine; the low half lands in .x, the high half in .y.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);

  uint2 out;
  out.x = vec_conversion<uint32_t, uint16_t>(low_pair);
  out.y = vec_conversion<uint32_t, uint16_t>(high_pair);
  return out;
}

// fp8x8 -> half2x4
// Converts a.x and a.y independently with the fp8x4 routine and lays the two
// uint2 results out back to back in a uint4 (a.x -> .x/.y, a.y -> .z/.w).
template <>
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) {
  const uint2 front = vec_conversion<uint2, uint32_t>(a.x);
  const uint2 back = vec_conversion<uint2, uint32_t>(a.y);

  uint4 out;
  out.x = front.x;
  out.y = front.y;
  out.z = back.x;
  out.w = back.y;
  return out;
}

// Alias the HIP bf16 scalar type to the __nv_bfloat16 name that the
// specializations in this file are written against.
using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16
// Wraps the raw byte in an hip_fp8, widens it to float, then narrows the
// float to bf16 with __float2bfloat16.
template <>
__inline__ __device__ __nv_bfloat16
vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) {
  hip_fp8 decoded{a, hip_fp8::from_bits()};
  const float widened = static_cast<float>(decoded);
  return __float2bfloat16(widened);
}

// Alias the HIP two-element bf16 type to the __nv_bfloat162 name that the
// specializations in this file are written against.
using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162
// Converts the low byte of `a` into .x and the high byte into .y using the
// scalar fp8 -> bf16 routine.
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) {
  const uint8_t low_byte = static_cast<uint8_t>(a);
  const uint8_t high_byte = static_cast<uint8_t>(a >> 8U);

  __nv_bfloat162 out;
  out.x = vec_conversion<__nv_bfloat16, uint8_t>(low_byte);
  out.y = vec_conversion<__nv_bfloat16, uint8_t>(high_byte);
  return out;
}

// fp8x4 -> bf16_4_t
// Converts the low 16 bits of `a` into .x and the high 16 bits into .y using
// the fp8x2 -> bf16x2 routine.
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);

  bf16_4_t out;
  out.x = vec_conversion<__nv_bfloat162, uint16_t>(low_pair);
  out.y = vec_conversion<__nv_bfloat162, uint16_t>(high_pair);
  return out;
}

// fp8x8 -> bf16_8_t
// Converts a.x and a.y with the fp8x4 -> bf16x4 routine and concatenates the
// results: a.x supplies .x/.y, a.y supplies .z/.w.
template <>
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) {
  bf16_4_t front = vec_conversion<bf16_4_t, uint32_t>(a.x);
  bf16_4_t back = vec_conversion<bf16_4_t, uint32_t>(a.y);

  bf16_8_t out;
  out.x = front.x;
  out.y = front.y;
  out.z = back.x;
  out.w = back.y;
  return out;
}

// fp8 -> float
// Wraps the raw byte in an hip_fp8 (from_bits constructor) and widens it to
// float.
template <>
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) {
  hip_fp8 decoded{a, hip_fp8::from_bits()};
  const float widened = static_cast<float>(decoded);
  return widened;
}

// fp8x2 -> float2
// Converts the two fp8 bytes packed in `a`: low byte -> .x, high byte -> .y.
template <>
__inline__ __device__ float2
vec_conversion<float2, uint16_t>(const uint16_t& a) {
  float2 out;
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  // Packed builtin path: one call yields both floats (indexed [0] and [1]).
  const auto floats = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  out.x = floats[0];
  out.y = floats[1];
#else
  // Generic path: convert each byte with the scalar fp8 -> float routine.
  const uint8_t low_byte = static_cast<uint8_t>(a);
  const uint8_t high_byte = static_cast<uint8_t>(a >> 8U);
  out.x = vec_conversion<float, uint8_t>(low_byte);
  out.y = vec_conversion<float, uint8_t>(high_byte);
#endif
  return out;
}

// fp8x4 -> Float4_
// Converts the low 16 bits of `a` into .x and the high 16 bits into .y using
// the fp8x2 -> float2 routine.
template <>
__inline__ __device__ Float4_
vec_conversion<Float4_, uint32_t>(const uint32_t& a) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);

  Float4_ out;
  out.x = vec_conversion<float2, uint16_t>(low_pair);
  out.y = vec_conversion<float2, uint16_t>(high_pair);
  return out;
}

// fp8x8 -> Float8_
// Converts a.x and a.y with the fp8x4 -> Float4_ routine and concatenates the
// results: a.x supplies .x/.y, a.y supplies .z/.w.
template <>
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) {
  Float4_ front = vec_conversion<Float4_, uint32_t>(a.x);
  Float4_ back = vec_conversion<Float4_, uint32_t>(a.y);

  Float8_ out;
  out.x = front.x;
  out.y = front.y;
  out.z = back.x;
  out.w = back.y;
  return out;
}

// half -> fp8
// Loads the 16-bit pattern into a __half_raw, widens it to float, builds an
// hip_fp8 from that float and returns its stored byte.
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, uint16_t>(const uint16_t& a) {
  __half_raw source;
  source.x = a;
  const float widened = static_cast<float>(source.data);

  hip_fp8 encoded{widened};
  return encoded.data;
}

// bf16 -> fp8
// Widens the bf16 value to float, builds an hip_fp8 from it and returns the
// stored byte.
template <>
__inline__ __device__ uint8_t
vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) {
  const float widened = __bfloat162float(a);
  hip_fp8 encoded{widened};
  return encoded.data;
}

// float -> fp8
// Builds an hip_fp8 from the float and returns the stored byte.
template <>
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) {
  hip_fp8 encoded(a);
  const uint8_t bits = encoded.data;
  return bits;
}

// fp8x4 -> float4
// Runs the fp8x4 -> Float4_ routine and flattens the two float2 pairs into a
// float4 in order (x.x, x.y, y.x, y.y).
template <>
__inline__ __device__ float4
vec_conversion<float4, uint32_t>(const uint32_t& a) {
  const Float4_ pairs = vec_conversion<Float4_, uint32_t>(a);
  return make_float4(pairs.x.x, pairs.x.y, pairs.y.x, pairs.y.y);
}

// float2 -> half2
// Converts both floats with __float22half2_rn and returns the half2 as its
// 32-bit pattern (read back through a union).
template <>
__inline__ __device__ uint32_t
vec_conversion<uint32_t, float2>(const float2& a) {
  union {
    half2 as_half2;
    uint32_t as_bits;
  } pun;

  pun.as_half2 = __float22half2_rn(a);
  return pun.as_bits;
}

// Float4 -> half2x2
// Converts each float2 half of `a` with the float2 -> half2 routine:
// a.x -> result .x, a.y -> result .y.
template <>
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) {
  float2 front;
  front.x = a.x.x;
  front.y = a.x.y;

  float2 back;
  back.x = a.y.x;
  back.y = a.y.y;

  uint2 out;
  out.x = vec_conversion<uint32_t, float2>(front);
  out.y = vec_conversion<uint32_t, float2>(back);
  return out;
}

// Float4 -> float4
// Flattens the two float2 members into a float4 in order
// (x.x, x.y, y.x, y.y).
template <>
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) {
  return make_float4(a.x.x, a.x.y, a.y.x, a.y.y);
}

// Float8 -> half2x4
// Converts each of the four float2 members with the float2 -> half2 routine
// and stores the packed results in the matching uint4 component.
template <>
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) {
  const uint32_t h0 = vec_conversion<uint32_t, float2>(a.x);
  const uint32_t h1 = vec_conversion<uint32_t, float2>(a.y);
  const uint32_t h2 = vec_conversion<uint32_t, float2>(a.z);
  const uint32_t h3 = vec_conversion<uint32_t, float2>(a.w);

  uint4 out;
  out.x = h0;
  out.y = h1;
  out.z = h2;
  out.w = h3;
  return out;
}

// float2 -> bfloat162
// Thin wrapper over __float22bfloat162_rn.
template <>
__inline__ __device__ __nv_bfloat162
vec_conversion<__nv_bfloat162, float2>(const float2& a) {
  return __float22bfloat162_rn(a);
}

// Float4 -> bfloat162x2
// Converts each float2 member with __float22bfloat162_rn:
// a.x -> result .x, a.y -> result .y.
template <>
__inline__ __device__ bf16_4_t
vec_conversion<bf16_4_t, Float4_>(const Float4_& a) {
  bf16_4_t out;
  out.x = __float22bfloat162_rn(a.x);
  out.y = __float22bfloat162_rn(a.y);
  return out;
}

// Float8 -> bfloat162x4
// Converts each of the four float2 members with __float22bfloat162_rn and
// stores the result in the matching component of the output.
template <>
__inline__ __device__ bf16_8_t
vec_conversion<bf16_8_t, Float8_>(const Float8_& a) {
  bf16_8_t out;
  out.x = __float22bfloat162_rn(a.x);
  out.y = __float22bfloat162_rn(a.y);
  out.z = __float22bfloat162_rn(a.z);
  out.w = __float22bfloat162_rn(a.w);
  return out;
}

301
302
/* Scaled and vectorized conversions, for data exchange between the high- and
   low-precision domains.

   Convention for the scale in this API (HP = high-precision value):
     Quantize:   FP8 = Quantize(HP / scale)
     Dequantize: HP  = Dequantize(FP8) * scale
 */

// fp8 -> half (scaled)
// Wraps the raw byte in an hip_fp8, widens it to float, multiplies by
// `scale`, stores the product into a __half_raw and returns its 16-bit
// pattern.
template <>
__inline__ __device__ uint16_t
scaled_vec_conversion<uint16_t, uint8_t>(const uint8_t& a, const float scale) {
  hip_fp8 decoded{a, hip_fp8::from_bits()};
  const float rescaled = static_cast<float>(decoded) * scale;

  __half_raw packed;
  packed.data = rescaled;
  return packed.x;
}

// fp8x2 -> half2 (scaled)
// Converts the two fp8 bytes packed in `a`, multiplying each by `scale`, and
// packs the two halves in a uint32_t. The low byte of `a` fills the first
// 16-bit lane of the result.
template <>
__inline__ __device__ uint32_t scaled_vec_conversion<uint32_t, uint16_t>(
    const uint16_t& a, const float scale) {
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  // Packed builtin path: one call yields both floats (indexed [0] and [1]).
  const auto floats = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  union {
    __half2_raw halves;
    uint32_t bits;
  } packed;
  packed.halves.x.data = floats[0] * scale;
  packed.halves.y.data = floats[1] * scale;
  return packed.bits;
#else
  // Generic path: convert each byte with the scalar scaled fp8 -> half
  // routine.
  const uint8_t low_byte = static_cast<uint8_t>(a);
  const uint8_t high_byte = static_cast<uint8_t>(a >> 8U);

  union {
    uint16_t lanes[2];
    uint32_t bits;
  } packed;
  packed.lanes[0] = scaled_vec_conversion<uint16_t, uint8_t>(low_byte, scale);
  packed.lanes[1] = scaled_vec_conversion<uint16_t, uint8_t>(high_byte, scale);
  return packed.bits;
#endif
}

// fp8x4 -> half2x2 (scaled)
// Splits the 32-bit input into two 16-bit halves and converts each with the
// scaled fp8x2 -> half2 routine; the low half lands in .x, the high half in
// .y.
template <>
__inline__ __device__ uint2
scaled_vec_conversion<uint2, uint32_t>(const uint32_t& a, const float scale) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);

  uint2 out;
  out.x = scaled_vec_conversion<uint32_t, uint16_t>(low_pair, scale);
  out.y = scaled_vec_conversion<uint32_t, uint16_t>(high_pair, scale);
  return out;
}

// fp8x8 -> half2x4 (scaled)
// Converts a.x and a.y independently with the scaled fp8x4 routine and lays
// the two uint2 results out back to back in a uint4
// (a.x -> .x/.y, a.y -> .z/.w).
template <>
__inline__ __device__ uint4
scaled_vec_conversion<uint4, uint2>(const uint2& a, const float scale) {
  const uint2 front = scaled_vec_conversion<uint2, uint32_t>(a.x, scale);
  const uint2 back = scaled_vec_conversion<uint2, uint32_t>(a.y, scale);

  uint4 out;
  out.x = front.x;
  out.y = front.y;
  out.z = back.x;
  out.w = back.y;
  return out;
}

// Alias the HIP bf16 scalar type to the __nv_bfloat16 name used by the scaled
// specializations below. NOTE(review): the same alias is already declared
// earlier in this namespace, so this line looks redundant.
using __nv_bfloat16 = __hip_bfloat16;

// fp8 -> __nv_bfloat16 (scaled)
// Wraps the raw byte in an hip_fp8, widens it to float, multiplies by
// `scale`, then narrows the product to bf16 with __float2bfloat16.
template <>
__inline__ __device__ __nv_bfloat16
scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a,
                                              const float scale) {
  hip_fp8 decoded{a, hip_fp8::from_bits()};
  const float widened = static_cast<float>(decoded);
  return __float2bfloat16(widened * scale);
}

// Alias the HIP two-element bf16 type to the __nv_bfloat162 name used by the
// scaled specializations below. NOTE(review): the same alias is already
// declared earlier in this namespace, so this line looks redundant.
using __nv_bfloat162 = __hip_bfloat162;

// fp8x2 -> __nv_bfloat162 (scaled)
// Converts the low byte of `a` into .x and the high byte into .y using the
// scalar scaled fp8 -> bf16 routine.
template <>
__inline__ __device__ __nv_bfloat162
scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a,
                                                const float scale) {
  const uint8_t low_byte = static_cast<uint8_t>(a);
  const uint8_t high_byte = static_cast<uint8_t>(a >> 8U);

  __nv_bfloat162 out;
  out.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>(low_byte, scale);
  out.y = scaled_vec_conversion<__nv_bfloat16, uint8_t>(high_byte, scale);
  return out;
}

// fp8x4 -> bf16_4_t (scaled)
// Converts the low 16 bits of `a` into .x and the high 16 bits into .y using
// the scaled fp8x2 -> bf16x2 routine.
template <>
__inline__ __device__ bf16_4_t scaled_vec_conversion<bf16_4_t, uint32_t>(
    const uint32_t& a, const float scale) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);

  bf16_4_t out;
  out.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>(low_pair, scale);
  out.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>(high_pair, scale);
  return out;
}

// fp8x8 -> bf16_8_t (scaled)
// Converts a.x and a.y with the scaled fp8x4 -> bf16x4 routine and
// concatenates the results: a.x supplies .x/.y, a.y supplies .z/.w.
template <>
__inline__ __device__ bf16_8_t
scaled_vec_conversion<bf16_8_t, uint2>(const uint2& a, const float scale) {
  bf16_4_t front = scaled_vec_conversion<bf16_4_t, uint32_t>(a.x, scale);
  bf16_4_t back = scaled_vec_conversion<bf16_4_t, uint32_t>(a.y, scale);

  bf16_8_t out;
  out.x = front.x;
  out.y = front.y;
  out.z = back.x;
  out.w = back.y;
  return out;
}

// fp8 -> float (scaled)
// Wraps the raw byte in an hip_fp8, widens it to float and multiplies by
// `scale`.
template <>
__inline__ __device__ float scaled_vec_conversion<float, uint8_t>(
    const uint8_t& a, const float scale) {
  hip_fp8 decoded{a, hip_fp8::from_bits()};
  const float widened = static_cast<float>(decoded);
  return widened * scale;
}

// fp8x2 -> float2 (scaled)
// Converts the two fp8 bytes packed in `a`, multiplying each by `scale`:
// low byte -> .x, high byte -> .y.
template <>
__inline__ __device__ float2
scaled_vec_conversion<float2, uint16_t>(const uint16_t& a, const float scale) {
  float2 out;
#if defined(__HIP__MI300__) && defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__)
  // Packed builtin path: one call yields both floats (indexed [0] and [1]).
  const auto floats = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0);
  out.x = floats[0] * scale;
  out.y = floats[1] * scale;
#else
  // Generic path: convert each byte with the scalar scaled fp8 -> float
  // routine.
  const uint8_t low_byte = static_cast<uint8_t>(a);
  const uint8_t high_byte = static_cast<uint8_t>(a >> 8U);
  out.x = scaled_vec_conversion<float, uint8_t>(low_byte, scale);
  out.y = scaled_vec_conversion<float, uint8_t>(high_byte, scale);
#endif
  return out;
}

// fp8x4 -> Float4_ (scaled)
// Converts the low 16 bits of `a` into .x and the high 16 bits into .y using
// the scaled fp8x2 -> float2 routine.
template <>
__inline__ __device__ Float4_
scaled_vec_conversion<Float4_, uint32_t>(const uint32_t& a, const float scale) {
  const uint16_t low_pair = static_cast<uint16_t>(a);
  const uint16_t high_pair = static_cast<uint16_t>(a >> 16U);

  Float4_ out;
  out.x = scaled_vec_conversion<float2, uint16_t>(low_pair, scale);
  out.y = scaled_vec_conversion<float2, uint16_t>(high_pair, scale);
  return out;
}

// fp8x8 -> Float8_ (scaled)
// Converts a.x and a.y with the scaled fp8x4 -> Float4_ routine and
// concatenates the results: a.x supplies .x/.y, a.y supplies .z/.w.
template <>
__inline__ __device__ Float8_
scaled_vec_conversion<Float8_, uint2>(const uint2& a, const float scale) {
  Float4_ front = scaled_vec_conversion<Float4_, uint32_t>(a.x, scale);
  Float4_ back = scaled_vec_conversion<Float4_, uint32_t>(a.y, scale);

  Float8_ out;
  out.x = front.x;
  out.y = front.y;
  out.z = back.x;
  out.w = back.y;
  return out;
}

/* Quantize(HP / scale) => FP8 */

// TODO(Hai): vectorized to add

// half -> fp8 (scaled)
// Loads the 16-bit pattern into a __half_raw, widens it to float, divides by
// `scale`, builds an hip_fp8 from the quotient and returns its stored byte.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, uint16_t>(const uint16_t& a, const float scale) {
  __half_raw source;
  source.x = a;
  const float rescaled = static_cast<float>(source.data) / scale;

  hip_fp8 encoded{rescaled};
  return encoded.data;
}

// bf16 -> fp8 (scaled)
// Widens the bf16 value to float, divides by `scale`, builds an hip_fp8 from
// the quotient and returns its stored byte.
template <>
__inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
    const __nv_bfloat16& a, const float scale) {
  const float rescaled = __bfloat162float(a) / scale;
  hip_fp8 encoded{rescaled};
  return encoded.data;
}

// float -> fp8 (scaled)
// Divides the float by `scale`, builds an hip_fp8 from the quotient and
// returns its stored byte.
template <>
__inline__ __device__ uint8_t
scaled_vec_conversion<uint8_t, float>(const float& a, const float scale) {
  const float rescaled = a / scale;
  hip_fp8 encoded(rescaled);
  return encoded.data;
}

// fp8x4 -> float4 (scaled)
// Runs the scaled fp8x4 -> Float4_ routine and flattens the two float2 pairs
// into a float4 in order (x.x, x.y, y.x, y.y).
template <>
__inline__ __device__ float4
scaled_vec_conversion<float4, uint32_t>(const uint32_t& a, const float scale) {
  const Float4_ pairs = scaled_vec_conversion<Float4_, uint32_t>(a, scale);
  return make_float4(pairs.x.x, pairs.x.y, pairs.y.x, pairs.y.y);
}
519
  #endif  // ENABLE_FP8
520

521
// Unscaled conversion entry point, selected at compile time by kv_dt.
// With ENABLE_FP8 defined and kv_dt == kFp8E4M3 this forwards to
// vec_conversion<Tout, Tin>. Every other configuration hits assert(false)
// and returns a value-initialized Tout.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout convert(const Tin& x) {
  #ifdef ENABLE_FP8
  constexpr bool is_fp8_e4m3 = (kv_dt == Fp8KVCacheDataType::kFp8E4M3);
  if constexpr (is_fp8_e4m3) {
    return vec_conversion<Tout, Tin>(x);
  }
  #endif
  // Unsupported kv_dt (or FP8 support compiled out).
  assert(false);
  return {};  // Squash missing return statement warning
}
531
532

// Scaled conversion entry point, selected at compile time by kv_dt.
// With ENABLE_FP8 defined and kv_dt == kFp8E4M3 this forwards to
// scaled_vec_conversion<Tout, Tin> with the given scale. Every other
// configuration hits assert(false) and returns a value-initialized Tout.
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
  #ifdef ENABLE_FP8
  constexpr bool is_fp8_e4m3 = (kv_dt == Fp8KVCacheDataType::kFp8E4M3);
  if constexpr (is_fp8_e4m3) {
    return scaled_vec_conversion<Tout, Tin>(x, scale);
  }
  #endif
  // Unsupported kv_dt (or FP8 support compiled out).
  assert(false);
  return {};  // Squash missing return statement warning
}

543
544
545
546
547
548
  // Dispatches FN according to the source dtype and the KV-cache dtype string.
  // FN is a macro that calls a function with
  // template<typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>.
  //   KV_DTYPE == "auto"              -> cache_t == scalar_t, kAuto
  //   KV_DTYPE == "fp8" / "fp8_e4m3"  -> cache_t == uint8_t,  kFp8E4M3
  // Any other SRC_DTYPE or KV_DTYPE is rejected through TORCH_CHECK.
  #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN)                  \
    if (KV_DTYPE == "auto") {                                                  \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, float, vllm::Fp8KVCacheDataType::kAuto);                     \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto);               \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto);     \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
      }                                                                        \
    } else if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") {                  \
      if (SRC_DTYPE == at::ScalarType::Float) {                                \
        FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);                \
      } else if (SRC_DTYPE == at::ScalarType::Half) {                          \
        FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);             \
      } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                      \
        FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3);        \
      } else {                                                                 \
        TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \
      }                                                                        \
    } else {                                                                   \
      TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE);     \
    }
574

575
576
577
}  // namespace fp8
#endif  // USE_ROCM
}  // namespace vllm