nvfp4_utils.cuh 13.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*
 * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cstdint>
#include <type_traits>

#include <cuda_runtime.h>
#include <cuda_fp8.h>

// Number of input elements each thread converts per step. The 16-element path
// must be opted into with NVFP4_ENABLE_ELTS16 and additionally requires
// CUDA >= 12.9 and an SM100-enabled build; otherwise each thread handles 8.
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
     defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
  #define ELTS_PER_THREAD 16
constexpr int CVT_FP4_ELTS_PER_THREAD = 16;
constexpr bool CVT_FP4_PACK16 = true;
#else
  #define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr bool CVT_FP4_PACK16 = false;
#endif

// Number of consecutive elements that share one scale factor (SF).
constexpr int CVT_FP4_SF_VEC_SIZE = 16;

// The SF helpers assume one SF vector is covered by a whole number of threads
// (1 thread with 16 elements per thread, 2 threads with 8).
static_assert(CVT_FP4_SF_VEC_SIZE % CVT_FP4_ELTS_PER_THREAD == 0,
              "an SF vector must be split evenly across threads");
static_assert(CVT_FP4_PACK16 == (CVT_FP4_ELTS_PER_THREAD == 16),
              "CVT_FP4_PACK16 must agree with CVT_FP4_ELTS_PER_THREAD");

namespace vllm {

// Maps a PyTorch C++ scalar type to the CUDA type used in device code:
// at::Half -> half, at::BFloat16 -> __nv_bfloat16, anything else -> itself.
template <typename T>
struct CUDATypeConverter {
  typedef T Type;
};

template <>
struct CUDATypeConverter<at::BFloat16> {
  typedef __nv_bfloat16 Type;
};

template <>
struct CUDATypeConverter<at::Half> {
  typedef half Type;
};

// Maps a 16-bit float scalar to its 2-wide vector type and a vector back to
// its scalar: half <-> half2, __nv_bfloat16 <-> __nv_bfloat162.
// Any other T falls back to half2 (kept for generality).
template <typename T>
struct TypeConverter {
  typedef half2 Type;
};

// Scalar -> pair.
template <>
struct TypeConverter<half> {
  typedef half2 Type;
};

template <>
struct TypeConverter<__nv_bfloat16> {
  typedef __nv_bfloat162 Type;
};

// Pair -> scalar.
template <>
struct TypeConverter<half2> {
  typedef half Type;
};

template <>
struct TypeConverter<__nv_bfloat162> {
  typedef __nv_bfloat16 Type;
};

// One thread's slice of CVT_FP4_ELTS_PER_THREAD input values, stored as
// 2-wide pairs (TypeConverter<Type>::Type, e.g. half -> half2). The size and
// alignment match the single 128-bit / 256-bit accesses used to fill it.
#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \
     defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100)
// 16 elements per thread: 8 pairs, 32 bytes.
template <class Type>
struct alignas(32) PackedVec {
  typename TypeConverter<Type>::Type elts[8];
};
#else
// 8 elements per thread: 4 pairs, 16 bytes.
template <class Type>
struct alignas(16) PackedVec {
  typename TypeConverter<Type>::Type elts[4];
};
#endif

// FP8 (e4m3) specialization: 8 fp8x2 pairs = 16 bytes.
// NOTE(review): unlike the primary template this has no alignas, so it is
// not guaranteed 16-byte aligned; confirm it is never filled through the
// uint4-based load helpers.
template <>
struct PackedVec<__nv_fp8_e4m3> {
  __nv_fp8x2_e4m3 elts[8];
};

// Rounds x up to the nearest multiple of y. Intended for x >= 0 and y > 0;
// x + y - 1 must not overflow Int.
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
  static_assert(std::is_integral_v<Int>,
                "round_up argument must be integral type");
  return ((x + y - 1) / y) * y;
}

// Integer ceiling division: the smallest q with q * y >= x, for x >= 0 and
// y > 0 (for negative x the truncating division is not a ceiling).
// Integral-only, like round_up.
template <typename Int>
__host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) {
  static_assert(std::is_integral_v<Int>,
                "div_round_up argument must be integral type");
  return (x + y - 1) / y;
}

// Number of rows the launch grid should cover when the SF output uses the
// swizzled layout: m rounded up to a whole number of 128-row tiles.
inline int computeEffectiveRows(int m) {
  constexpr int kRowTile = 128;
  const int numTiles = (m + kRowTile - 1) / kRowTile;
  return numTiles * kRowTile;
}

// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
// Values are rounded to nearest-even and saturated to the finite e2m1 range.
// array[i] lands in bits [4*i, 4*i + 4) of the result: each cvt puts its
// first source operand in the high nibble, so the odd element is passed
// first, and byte0 is the least significant byte of the mov.
// NOTE: the .e2m1x2 form of cvt is only available on arch-specific Blackwell
// targets (e.g. sm_100a); see the PTX ISA `cvt` documentation.
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
}

// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
// Element order is array[0].x, array[0].y, array[1].x, ...; element i lands
// in bits [4*i, 4*i + 4). Rounding is nearest-even with saturation to the
// finite e2m1 range.
// NOTE: the .e2m1x2 form of cvt is only available on arch-specific Blackwell
// targets (e.g. sm_100a); see the PTX ISA `cvt` documentation.
__device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) {
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}\n"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
}

// Two 32-bit words holding 16 packed e2m1 values
// (lo = elements 0..7, hi = elements 8..15).
struct u32x2 {
  uint32_t lo, hi;
};

// Packed e2m1 output of one thread: a u32x2 (16 values) when CVT_FP4_PACK16,
// otherwise a single uint32_t (8 values).
using fp4_packed_t = std::conditional_t<CVT_FP4_PACK16, u32x2, uint32_t>;

// Convert 8 float2 values (16 floats) into 16 e2m1 values packed into two
// 32-bit words: .lo takes array[0..3], .hi takes array[4..7]. Within each
// word the x component of a pair goes to the low nibble of its byte.
__device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) {
  uint32_t lo_word;
  uint32_t hi_word;
  asm volatile(
      "{\n"
      ".reg .b8 p0, p1, p2, p3, p4, p5, p6, p7;\n"
      "cvt.rn.satfinite.e2m1x2.f32   p0,  %3,  %2;\n"
      "cvt.rn.satfinite.e2m1x2.f32   p1,  %5,  %4;\n"
      "cvt.rn.satfinite.e2m1x2.f32   p2,  %7,  %6;\n"
      "cvt.rn.satfinite.e2m1x2.f32   p3,  %9,  %8;\n"
      "cvt.rn.satfinite.e2m1x2.f32   p4, %11, %10;\n"
      "cvt.rn.satfinite.e2m1x2.f32   p5, %13, %12;\n"
      "cvt.rn.satfinite.e2m1x2.f32   p6, %15, %14;\n"
      "cvt.rn.satfinite.e2m1x2.f32   p7, %17, %16;\n"
      "mov.b32 %0, {p0, p1, p2, p3};\n"
      "mov.b32 %1, {p4, p5, p6, p7};\n"
      "}\n"
      : "=r"(lo_word), "=r"(hi_word)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y),
        "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y),
        "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y));
  return u32x2{lo_word, hi_word};
}

// Overload set that picks the e2m1 packer from the array length:
// 4 pairs -> one uint32_t, 8 pairs -> u32x2.
__device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) {
  const uint32_t packed = fp32_vec8_to_e2m1(v);
  return packed;
}

__device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) {
  const u32x2 packed = fp32_vec16_to_e2m1(v);
  return packed;
}

// Fast approximate reciprocal 1/a via rcp.approx.ftz.f32. The result is not
// correctly rounded, and subnormal inputs/results are flushed to zero.
// The asm is deliberately not volatile: the instruction has no side effects,
// so the compiler may CSE or hoist repeated calls such as
// reciprocal_approximate_ftz(6.0f).
__device__ __forceinline__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a));
  return b;
}

// Predicated, zero-filling 128-bit global load into `out`.
// If `pred` is true, loads 16 bytes from `ptr` with ld.global.cg (cached in
// L2, not L1). If `pred` is false, no memory is accessed and the first 16
// bytes of `out` are zeroed.
// Precondition: `ptr` is a global-memory address aligned to 16 bytes, as a
// .v4.u32 vector load requires.
template <class Type>
__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec<Type>& out,
                                                     const void* ptr,
                                                     bool pred) {
  static_assert(sizeof(PackedVec<Type>) >= sizeof(uint4),
                "PackedVec must hold at least 16 bytes");
  uint32_t r0, r1, r2, r3;

  asm volatile(
      "{\n"
      "  .reg .pred pr;\n"
      "  setp.ne.u32 pr, %4, 0;\n"
      "  mov.u32 %0, 0;\n"
      "  mov.u32 %1, 0;\n"
      "  mov.u32 %2, 0;\n"
      "  mov.u32 %3, 0;\n"
      "  @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
      "}\n"
      : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
      : "r"((int)pred), "l"(ptr));

  *reinterpret_cast<uint4*>(&out) = uint4{r0, r1, r2, r3};
}

// Predicated, zero-filling 256-bit global load into `out`.
// If `pred` is true, loads 32 bytes from `ptr` with ld.global.cg (cached in
// L2, not L1). If `pred` is false, no memory is accessed and `out` is zeroed.
// Preconditions: `ptr` is a global-memory address aligned to 32 bytes, and
// PackedVec<Type> is at least 32 bytes (the 16-elements-per-thread layout).
// NOTE: .v8.u32 loads need a recent PTX ISA / Blackwell-class target; see the
// PTX ISA `ld` documentation.
template <class Type>
__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec<Type>& out,
                                                     const void* ptr,
                                                     bool pred) {
  // The two uint4 stores below write 32 bytes. A 16-byte PackedVec (8
  // elements per thread, or the fp8 specialization) would be overrun.
  static_assert(sizeof(PackedVec<Type>) >= 2 * sizeof(uint4),
                "ld256_or_zero_cg_u32 needs a 32-byte PackedVec "
                "(16 elements per thread)");
  uint32_t r0, r1, r2, r3, r4, r5, r6, r7;

  asm volatile(
      "{\n"
      "  .reg .pred pr;\n"
      "  setp.ne.u32 pr, %8, 0;\n"
      "  mov.u32 %0, 0;\n"
      "  mov.u32 %1, 0;\n"
      "  mov.u32 %2, 0;\n"
      "  mov.u32 %3, 0;\n"
      "  mov.u32 %4, 0;\n"
      "  mov.u32 %5, 0;\n"
      "  mov.u32 %6, 0;\n"
      "  mov.u32 %7, 0;\n"
      "  @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
      "}\n"
      : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6),
        "=r"(r7)
      : "r"((int)pred), "l"(ptr));

  reinterpret_cast<uint4*>(&out)[0] = uint4{r0, r1, r2, r3};
  reinterpret_cast<uint4*>(&out)[1] = uint4{r4, r5, r6, r7};
}

// Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64
//
// rowIdx is the row (M index). colIdx is this thread's chunk index along the
// row; colIdx / CVT_FP4_NUM_THREADS_PER_SF is the SF (K) index, with 16
// elements per SF. Returns a byte pointer into SFout for the one thread of
// each SF group that should store the SF
// (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0), and nullptr for the others.
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
    int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) {
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                    CVT_FP4_NUM_THREADS_PER_SF == 2,
                "an SF is shared by one or two threads");

  // Only one thread per SF group writes the SF (STG.8); the other, if any,
  // gets nullptr.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
    return nullptr;
  }

  // SF vector index (16 elements share one SF in the K dimension).
  int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
  int32_t mIdx = rowIdx;

  // Decompose indices using bitwise ops (all divisors are powers of 2).
  // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4 (kTile)]
  int32_t mTileIdx = mIdx >> 7;         // mIdx / 128
  int32_t outerMIdx = mIdx & 31;        // mIdx % 32
  int32_t innerMIdx = (mIdx >> 5) & 3;  // (mIdx / 32) % 4
  int32_t kTileIdx = kIdx >> 2;         // kIdx / 4
  int32_t innerKIdx = kIdx & 3;         // kIdx % 4

  // offset = (mTileIdx * numKTiles + kTileIdx) * 512
  //        + outerMIdx * 16 + innerMIdx * 4 + innerKIdx
  // The low terms sum to at most 511, so they can be OR-ed in below bit 9.
  int64_t SFOffset =
      ((static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx) << 9) |
      (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;

  return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}

// Row-major SF output address: SFout + row * packs_per_row_sf + sf_col, where
// sf_col = pack / THREADS_PER_SF. packs_per_row_sf is therefore the row
// stride in SF bytes. Only the first thread of each SF group
// (threadIdx.x % THREADS_PER_SF == 0) gets a pointer; the others get nullptr.
template <class SFType>
__device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack,
                                                       int packs_per_row_sf,
                                                       SFType* SFout) {
  constexpr int PACK = CVT_FP4_ELTS_PER_THREAD;
  // Threads sharing one SF: 1 when PACK == 16, 2 when PACK == 8.
  constexpr int THREADS_PER_SF = CVT_FP4_SF_VEC_SIZE / PACK;
  static_assert(CVT_FP4_SF_VEC_SIZE % PACK == 0,
                "an SF vector must span a whole number of threads");

  if (threadIdx.x % THREADS_PER_SF != 0) return nullptr;

  // PACK=16 => sf_col = pack; PACK=8 => sf_col = pack / 2.
  int sf_col = pack / THREADS_PER_SF;
  int64_t off = static_cast<int64_t>(row) * packs_per_row_sf + sf_col;

  return reinterpret_cast<uint8_t*>(SFout) + off;
}

// Quantizes one thread's PackedVec to NVFP4 (e2m1) and stores its block scale
// factor (SF) through SFout when SFout is non-null.
//
// Template parameters:
//   Type                        half, or (any other type) treated as
//                               __nv_bfloat16 in the float conversion.
//   CVT_FP4_NUM_THREADS_PER_SF  lanes sharing one SF (1 or 2). With 2, lane
//                               pairs (lane ^ 1) merge their maxima with a
//                               full-mask warp shuffle, so all 32 lanes of the
//                               warp must call this function together.
//   UE8M0_SF                    store the SF as a UE8M0 exponent byte instead
//                               of E4M3.
// Parameters:
//   vec         this thread's CVT_FP4_ELTS_PER_THREAD input values.
//   SFScaleVal  global scale folded into the per-block SF.
//   SFout       destination byte for the SF, or nullptr to skip the store.
// Returns the packed e2m1 values (fp4_packed_t: uint32_t for 8 elements,
// u32x2 for 16).
template <class Type, int CVT_FP4_NUM_THREADS_PER_SF, bool UE8M0_SF = false>
__device__ __forceinline__ fp4_packed_t
cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal, uint8_t* SFout) {
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                    CVT_FP4_NUM_THREADS_PER_SF == 2,
                "an SF is shared by one or two threads");

  // Absolute maximum over this thread's values, tracked per pair component.
  auto localMax = __habs2(vec.elts[0]);
#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // When two threads share an SF, combine their maxima so both see the
  // maximum over the whole SF vector.
  if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) {
    localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax);
  }

  // Final scalar maximum of the two pair components.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // SF = SFScaleVal * (vecMax / 6), where 6.0 is the largest e2m1 magnitude.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));

  // 8-bit encoding of the SF. SFValue is then replaced by the value that
  // encoding represents, so the output scale below matches what is stored.
  uint8_t fp8SFVal;
  if constexpr (UE8M0_SF) {
    // float32 = 1 sign bit + 8 exponent bits + 23 mantissa bits. Keep the
    // exponent byte; dropping the mantissa rounds a normal value down to a
    // power of two.
    uint32_t bits = __float_as_uint(SFValue) >> 23;
    fp8SFVal = static_cast<uint8_t>(bits & 0xff);
    // Convert back to fp32.
    SFValue = __uint_as_float(bits << 23);
  } else {
    // vecMax >= 0, so SFValue is non-negative whenever SFScaleVal is
    // (assumed positive) and E4M3 encodes it like UE4M3.
    __nv_fp8_e4m3 sf8 = __nv_fp8_e4m3(SFValue);
    fp8SFVal = sf8.__x;
    // Convert back to fp32.
    SFValue = float(sf8);
  }

  // Write the SF to global memory (STG.8).
  if (SFout) *SFout = fp8SFVal;

  // Scale applied to the inputs before e2m1 conversion:
  //   outputScale = 1 / (SFValue * (1 / SFScaleVal))
  // using the round-tripped SFValue. A zero SF (all-zero block) yields 0
  // instead of a reciprocal of zero.
  float outputScale =
      SFValue != 0.0f ? reciprocal_approximate_ftz(
                            SFValue * reciprocal_approximate_ftz(SFScaleVal))
                      : 0.0f;

  // Convert the input to float and apply the output scale.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    if constexpr (std::is_same_v<Type, half>) {
      fp2Vals[i] = __half22float2(vec.elts[i]);
    } else {
      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
    }
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values (overload chosen by the array length).
  return pack_fp4(fp2Vals);
}

// SiLU in float32: x * sigmoid(x), evaluated as x / (1 + exp(-x)) using the
// fast __expf / __fdividef intrinsics, which are less accurate than expf and
// IEEE division.
__device__ __forceinline__ float silu(float x) {
  return __fdividef(x, (1.f + __expf(-x)));
}

// Applies silu independently to both components of a float2.
__device__ __forceinline__ float2 silu2(float2 x) {
  float2 out;
  out.x = silu(x.x);
  out.y = silu(x.y);
  return out;
}

// Element-wise silu(x) * y over one thread's PackedVec. Each pair is computed
// in float32 and rounded back to Type with round-to-nearest. Type half uses
// the half2 conversions; any other Type is treated as __nv_bfloat16.
template <class Type>
__inline__ __device__ PackedVec<Type> compute_silu_mul(
    const PackedVec<Type>& x_vec, const PackedVec<Type>& y_vec) {
  PackedVec<Type> result;

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
    // silu_mul in float32
    if constexpr (std::is_same_v<Type, half>) {
      float2 silu_vec = silu2(__half22float2(x_vec.elts[i]));
      result.elts[i] = __float22half2_rn(
          __fmul2_rn(silu_vec, __half22float2(y_vec.elts[i])));
    } else {
      float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i]));
      result.elts[i] = __float22bfloat162_rn(
          __fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i])));
    }
  }
  return result;
}

}  // namespace vllm