// nvfp4_utils.cuh
/*
 * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_runtime.h>
#include <cuda_fp8.h>

#include <type_traits>

#include "../../cuda_vec_utils.cuh"

// Number of input elements each thread converts per call.
// The 16-wide path is opt-in (NVFP4_ENABLE_ELTS16) and additionally requires
// CUDA_VERSION >= 12090; every other build uses the 8-wide path.
// ELTS_PER_THREAD (macro) and CVT_FP4_ELTS_PER_THREAD (constant) always carry
// the same value; CVT_FP4_PACK16 is true exactly when that value is 16.
#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \
    CUDA_VERSION >= 12090
  #define ELTS_PER_THREAD 16
constexpr int CVT_FP4_ELTS_PER_THREAD = 16;
constexpr bool CVT_FP4_PACK16 = true;
#else
  #define ELTS_PER_THREAD 8
constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
constexpr bool CVT_FP4_PACK16 = false;
#endif

// Number of consecutive elements that share one scale factor (SF).
constexpr int CVT_FP4_SF_VEC_SIZE = 16;

namespace vllm {

// Rounds `x` up to the nearest multiple of `y`.
// Written for non-negative `x` and positive `y`; the intermediate
// `x + y - 1` must be representable in `Int`.
template <typename Int>
__host__ __device__ inline Int round_up(Int x, Int y) {
  static_assert(std::is_integral_v<Int>,
                "round_up argument must be integral type");
  return ((x + y - 1) / y) * y;
}

// Ceiling division: the smallest q such that q * y >= x.
// Written for non-negative `x` and positive `y`; the intermediate
// `x + y - 1` must be representable in `Int`.
template <typename Int>
__host__ __device__ __forceinline__ Int div_round_up(Int x, Int y) {
  // Same guard as round_up: the formula is only meaningful for integers.
  static_assert(std::is_integral_v<Int>,
                "div_round_up argument must be integral type");
  return (x + y - 1) / y;
}

// Row count to use for grid configuration when scale factors are stored in a
// swizzled layout: `m` padded up to a whole number of 128-row tiles.
inline int computeEffectiveRows(int m) {
  constexpr int kRowsPerTile = 128;
  const int paddedRows = round_up(m, kRowsPerTile);
  return paddedRows;
}

// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
//
// Each cvt packs two floats into one byte. The sources are passed odd element
// first because the PTX packing rule puts the first source in the upper
// nibble and the second in the lower nibble; as a result array[2*i] occupies
// the low nibble of byte i and array[2*i + 1] the high nibble, with byte 0
// being the least significant byte of the result.
// Conversion rounds to nearest (.rn) and saturates to the finite range
// (.satfinite).
// NOTE(review): the e2m1x2 form of cvt is only available on recent GPU
// targets — confirm this header is only compiled for architectures that
// support it.
inline __device__ uint32_t fp32_vec8_to_e2m1(float (&array)[8]) {
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}"
      : "=r"(val)
      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
  return val;
}

// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
//
// Each float2 becomes one output byte: the cvt sources are passed as (.y, .x)
// because the PTX packing rule puts the first source in the upper nibble and
// the second in the lower nibble, so array[i].x occupies the low nibble of
// byte i and array[i].y the high nibble. Byte 0 is the least significant byte
// of the result.
// Conversion rounds to nearest (.rn) and saturates to the finite range
// (.satfinite).
// NOTE(review): the e2m1x2 form of cvt is only available on recent GPU
// targets — confirm this header is only compiled for architectures that
// support it.
__device__ __forceinline__ uint32_t fp32_vec8_to_e2m1(float2 (&array)[4]) {
  uint32_t val;
  asm volatile(
      "{\n"
      ".reg .b8 byte0;\n"
      ".reg .b8 byte1;\n"
      ".reg .b8 byte2;\n"
      ".reg .b8 byte3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte0, %2, %1;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte1, %4, %3;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte2, %6, %5;\n"
      "cvt.rn.satfinite.e2m1x2.f32   byte3, %8, %7;\n"
      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
      "}\n"
      : "=r"(val)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
  return val;
}

// Two 32-bit words returned together; used as the packed result of the
// 16-element conversion (8 e2m1 values per word).
struct u32x2 {
  uint32_t lo, hi;
};

// Per-thread packed FP4 result type: u32x2 when CVT_FP4_PACK16 is set
// (16 elements per thread), otherwise a single uint32_t (8 elements).
using fp4_packed_t = std::conditional_t<CVT_FP4_PACK16, u32x2, uint32_t>;

// Convert 8 float2 values into 16 e2m1 values, returned as two uint32_t words.
//
// Operand map: %0/%1 are the outputs (out.lo, out.hi); %2..%17 are
// array[0].x, array[0].y, ..., array[7].x, array[7].y in order. Each cvt is
// given (.y, .x) of one float2 and produces one byte b0..b7; bytes b0..b3 are
// assembled into out.lo and b4..b7 into out.hi.
// Conversion rounds to nearest (.rn) and saturates to the finite range
// (.satfinite).
// NOTE(review): the e2m1x2 form of cvt is only available on recent GPU
// targets — confirm this header is only compiled for architectures that
// support it.
__device__ __forceinline__ u32x2 fp32_vec16_to_e2m1(float2 (&array)[8]) {
  u32x2 out;
  asm volatile(
      "{\n"
      ".reg .b8 b0;\n"
      ".reg .b8 b1;\n"
      ".reg .b8 b2;\n"
      ".reg .b8 b3;\n"
      ".reg .b8 b4;\n"
      ".reg .b8 b5;\n"
      ".reg .b8 b6;\n"
      ".reg .b8 b7;\n"
      "cvt.rn.satfinite.e2m1x2.f32   b0,  %3,  %2;\n"
      "cvt.rn.satfinite.e2m1x2.f32   b1,  %5,  %4;\n"
      "cvt.rn.satfinite.e2m1x2.f32   b2,  %7,  %6;\n"
      "cvt.rn.satfinite.e2m1x2.f32   b3,  %9,  %8;\n"
      "cvt.rn.satfinite.e2m1x2.f32   b4, %11, %10;\n"
      "cvt.rn.satfinite.e2m1x2.f32   b5, %13, %12;\n"
      "cvt.rn.satfinite.e2m1x2.f32   b6, %15, %14;\n"
      "cvt.rn.satfinite.e2m1x2.f32   b7, %17, %16;\n"
      "mov.b32 %0, {b0, b1, b2, b3};\n"
      "mov.b32 %1, {b4, b5, b6, b7};\n"
      "}\n"
      : "=r"(out.lo), "=r"(out.hi)
      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y),
        "f"(array[4].x), "f"(array[4].y), "f"(array[5].x), "f"(array[5].y),
        "f"(array[6].x), "f"(array[6].y), "f"(array[7].x), "f"(array[7].y));
  return out;
}

// Overload set that picks the conversion routine from the array extent, so
// callers can pack either 8 or 16 values through one name.

// 4 x float2 (8 values) -> one packed 32-bit word.
__device__ __forceinline__ uint32_t pack_fp4(float2 (&v)[4]) {
  const uint32_t packedWord = fp32_vec8_to_e2m1(v);
  return packedWord;
}

// 8 x float2 (16 values) -> two packed 32-bit words.
__device__ __forceinline__ u32x2 pack_fp4(float2 (&v)[8]) {
  const u32x2 packedPair = fp32_vec16_to_e2m1(v);
  return packedPair;
}

// Fast reciprocal.
// Issues the PTX instruction rcp.approx.ftz.f32: an approximate 1/a with
// flush-to-zero, trading accuracy for speed compared with an exact division.
__device__ __forceinline__ float reciprocal_approximate_ftz(float a) {
  float b;
  asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(b) : "f"(a));
  return b;
}

// Compute SF output offset for swizzled tensor core layout.
// SF layout: [numMTiles, numKTiles, 32, 4, 4]
// Caller must precompute: numKTiles = (numCols + 63) / 64
//
//  rowIdx    - row (M) index of the scale factor.
//  colIdx    - per-thread column index; divided by
//              CVT_FP4_NUM_THREADS_PER_SF to obtain the SF (K) index.
//  numKTiles - number of 4-SF-wide tiles along K (see above).
//  SFout     - base of the SF buffer; addressed as raw bytes.
//
// Returns the byte address for this SF, or nullptr for threads whose
// threadIdx.x is not a multiple of CVT_FP4_NUM_THREADS_PER_SF (those threads
// do not write).
// NOTE(review): the writer check uses threadIdx.x while the SF index uses
// colIdx — presumably both have the same parity in every caller; confirm.
template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
__device__ __forceinline__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(
    int rowIdx, int colIdx, int32_t numKTiles, SFType* SFout) {
  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
                CVT_FP4_NUM_THREADS_PER_SF == 2);

  // One thread per SF group (a single thread, or the first of a pair) writes
  // the SF to global memory.
  // TODO: stage through smem for packed STG.32
  // is it better than STG.8 from 4 threads ?
  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF != 0) {
    return nullptr;
  }

  // SF vector index (16 elements share one SF in the K dimension).
  int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
  int32_t mIdx = rowIdx;

  // Decompose indices using bitwise ops (all divisors are powers of 2).
  // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
  int32_t mTileIdx = mIdx >> 7;         // mIdx / 128
  int32_t outerMIdx = mIdx & 31;        // mIdx % 32
  int32_t innerMIdx = (mIdx >> 5) & 3;  // (mIdx / 32) % 4
  int32_t kTileIdx = kIdx >> 2;         // kIdx / 4
  int32_t innerKIdx = kIdx & 3;         // kIdx % 4

  // Compute global SF offset: mTileIdx * (numKTiles * 512) + kTileIdx * 512 +
  //                           outerMIdx * 16 + innerMIdx * 4 + innerKIdx
  // Use bitwise OR for non-overlapping lower bits. The tile term is widened
  // to 64 bits before the multiply so large tensors do not overflow.
  int64_t SFOffset = (static_cast<int64_t>(mTileIdx) * numKTiles + kTileIdx)
                         << 9 |
                     (outerMIdx << 4) | (innerMIdx << 2) | innerKIdx;

  return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
}

// Returns the byte address of the scale factor for (`row`, `pack`) in a
// row-major SF buffer, or nullptr for threads that do not write an SF.
//
//  row              - row index of the scale factor.
//  pack             - index of this thread's element pack within the row.
//  packs_per_row_sf - row stride of `SFout`, counted in one-byte SFs.
//  SFout            - base of the SF buffer; addressed as raw bytes.
template <class SFType>
__device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack,
                                                       int packs_per_row_sf,
                                                       SFType* SFout) {
  constexpr int PACK = CVT_FP4_ELTS_PER_THREAD;
  constexpr int THREADS_PER_SF =
      CVT_FP4_SF_VEC_SIZE / PACK;  // 1 if PACK=16, 2 if PACK=8

  // Only the first thread of each SF group writes the scale factor.
  if (threadIdx.x % THREADS_PER_SF != 0) return nullptr;

  // PACK=16 => sf_col=pack; PACK=8 => sf_col=pack/2
  const int sf_col = pack / THREADS_PER_SF;
  // Widen before multiplying so row * stride cannot overflow 32 bits.
  const int64_t off = static_cast<int64_t>(row) * packs_per_row_sf + sf_col;

  return (uint8_t*)SFout + off;
}

// Quantizes the provided PackedVec into packed e2m1 (FP4) values and, when
// `SFout` is non-null, stores the 8-bit scale factor (SF) of the group.
//
//  vec        - this thread's CVT_FP4_ELTS_PER_THREAD input values, stored as
//               CVT_FP4_ELTS_PER_THREAD / 2 packed pairs.
//  SFScaleVal - scale multiplied into the per-group scale factor.
//  SFout      - destination byte for the scale factor; nullptr skips the
//               store (the quantized values are still returned).
//  returns    - fp4_packed_t: one uint32_t for 8 values, u32x2 for 16.
//
// UE8M0_SF selects an exponent-only SF (the 8 exponent bits of the fp32
// value); otherwise the SF is stored as e4m3.
//
// When CVT_FP4_NUM_THREADS_PER_SF == 2 the group maximum is exchanged with a
// full-mask __shfl_xor_sync, so all 32 lanes of the warp must reach this call.
template <class Type, int CVT_FP4_NUM_THREADS_PER_SF, bool UE8M0_SF = false>
__device__ __forceinline__ fp4_packed_t cvt_warp_fp16_to_fp4(
    PackedVec<Type, CVT_FP4_PACK16>& vec, float SFScaleVal, uint8_t* SFout) {
  // Pair-wise absolute maximum over this thread's local values.
  auto localMax = __habs2(vec.elts[0]);

#pragma unroll
  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
  }

  // Two threads share one SF: combine with the neighbouring lane so both hold
  // the maximum over all 16 values.
  if constexpr (CVT_FP4_NUM_THREADS_PER_SF == 2) {
    localMax = __hmax2(__shfl_xor_sync(0xffffffffu, localMax, 1), localMax);
  }

  // Get the final absolute maximum values.
  float vecMax = float(__hmax(localMax.x, localMax.y));

  // Get the SF (max value of the vector / max value of e2m1).
  // maximum value of e2m1 = 6.0.
  // TODO: use half as compute data type.
  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
  // 8 bits representation of the SF.
  uint8_t fp8SFVal;
  // Encode the SF in 8 bits and round SFValue to the value that encoding
  // represents, so the output scale below matches what is stored.
  if constexpr (UE8M0_SF) {
    // Extract the 8 exponent bits from float32.
    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
    // The bit-cast intrinsics avoid aliasing a float through a uint32_t
    // reference.
    uint32_t tmp = __float_as_uint(SFValue) >> 23;
    fp8SFVal = tmp & 0xff;
    // Convert back to fp32.
    SFValue = __uint_as_float(tmp << 23);
  } else {
    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
    // Convert back to fp32.
    SFValue = float(tmp);
  }

  // Write the SF to global memory (STG.8).
  if (SFout) *SFout = fp8SFVal;

  // Get the output scale.
  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal)) *
  //                                  reciprocal(SFScaleVal))
  // A zero SF (all-zero group) maps to a zero output scale instead of the
  // reciprocal of zero.
  float outputScale =
      SFValue != 0.0f ? reciprocal_approximate_ftz(
                            SFValue * reciprocal_approximate_ftz(SFScaleVal))
                      : 0.0f;

  // Convert the input to float.
  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
    fp2Vals[i] = cast_to_float2(vec.elts[i]);
    fp2Vals[i].x *= outputScale;
    fp2Vals[i].y *= outputScale;
  }

  // Convert to e2m1 values.
  return pack_fp4(fp2Vals);
}

// silu in float32: x / (1 + exp(-x)).
// Uses the fast-math intrinsics __expf and __fdividef, which trade accuracy
// for speed compared with expf and a regular division.
__device__ __forceinline__ float silu(float x) {
  return __fdividef(x, (1.f + __expf(-x)));
}

// Applies silu independently to both components of a float2.
__device__ __forceinline__ float2 silu2(float2 x) {
  float2 activated;
  activated.x = silu(x.x);
  activated.y = silu(x.y);
  return activated;
}

// Computes silu(x) * y element-wise over two packed vectors.
//
// Every packed pair is widened to float2, the activation and the product are
// evaluated in float32, and the result is narrowed back to the packed type
// that PackedTypeConverter associates with `Type`.
//
//  x_vec   - values passed through silu.
//  y_vec   - values multiplied with silu(x_vec).
//  returns - packed vector holding silu(x) * y.
template <class Type>
__inline__ __device__ PackedVec<Type, CVT_FP4_PACK16> compute_silu_mul(
    const PackedVec<Type, CVT_FP4_PACK16>& x_vec,
    const PackedVec<Type, CVT_FP4_PACK16>& y_vec) {
  // Loop-invariant: the packed storage type does not depend on the index.
  using packed_t = typename PackedTypeConverter<Type>::Type;

  PackedVec<Type, CVT_FP4_PACK16> result;

#pragma unroll
  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
    // silu_mul in float32
    float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i]));
    float2 y_f2 = cast_to_float2(y_vec.elts[i]);
    result.elts[i] = cast_to_packed<packed_t>(
        make_float2(silu_vec.x * y_f2.x, silu_vec.y * y_f2.y));
  }
  return result;
}

}  // namespace vllm