nvfp4_quant_kernels.cu 9.69 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*
 * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <torch/all.h>

#include <cuda_runtime_api.h>
#include <cuda_runtime.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp8.h>
26
#include "dispatch_utils.h"
27
28

#include "cuda_utils.h"
29
#include "launch_bounds_utils.h"
30
31
32
33

// Define before including nvfp4_utils.cuh so the header
// can use this macro during compilation.
#define NVFP4_ENABLE_ELTS16 1
34
#include "nvfp4_utils.cuh"
35

36
namespace vllm {
37
38
39

// Quantizes a row-major matrix of `Type` elements to packed FP4 and writes the
// matching scale factors at the address returned by
// cvt_quant_to_fp4_get_sf_out_offset.
//
// Launch layout (as read from the indexing below):
//   - blockIdx.x selects the starting row; each block then advances by
//     gridDim.x rows until round_up(numRows, 128) rows have been visited.
//   - blockDim.x * blockIdx.y + threadIdx.x selects one packed column, i.e.
//     a group of CVT_FP4_ELTS_PER_THREAD consecutive input elements.
//
// Parameters:
//   numRows, numCols  Logical (unpadded) shape of `in`.
//   num_padded_cols   Number of packed columns to visit, including padding.
//   in                Input matrix, read with 128- or 256-bit loads.
//   SFScale           Optional pointer to one global scale; nullptr means 1.0f.
//   out               Packed FP4 output; only written for in-range elements.
//   SFout             Scale-factor output buffer.
//
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, int32_t num_padded_cols,
                    Type const* __restrict__ in,
                    float const* __restrict__ SFScale,
                    uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) {
  using PackedVec = vllm::PackedVec<Type, CVT_FP4_PACK16>;

  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  // Precompute SF layout parameter (constant for entire kernel).
  int32_t const numKTiles = (numCols + 63) / 64;

  // Row count rounded up to a multiple of 128; rows in [numRows, sf_m) are
  // padding rows that only produce scale-factor writes.
  int sf_m = round_up<int>(numRows, 128);
  int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
  int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0];

  // Iterate over all rows and cols including padded ones -
  //  ensures we visit every single scale factor address to initialize it.
  for (int rowIdx = blockIdx.x; rowIdx < sf_m; rowIdx += gridDim.x) {
    if (colIdx < num_padded_cols) {
      PackedVec in_vec;
      // Promote to 64-bit BEFORE multiplying: rowIdx * (numCols / N) would
      // otherwise be evaluated in 32-bit int and could wrap for very large
      // inputs even though the destination is int64_t.
      int64_t const inOffset =
          static_cast<int64_t>(rowIdx) *
              (numCols / CVT_FP4_ELTS_PER_THREAD) +
          colIdx;

      // If we are outside valid rows OR outside valid columns -> Use Zeros
      bool const valid = (rowIdx < numRows) && (elem_idx < numCols);
      if constexpr (CVT_FP4_PACK16) {
        // Each packed vector is addressed as 8 consecutive 32-bit words.
        ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
                         &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
                         valid);
      } else {
        // Each packed vector is addressed as 4 consecutive 32-bit words.
        ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
                         &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
                         valid);
      }

      auto sf_out =
          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
                                             CVT_FP4_NUM_THREADS_PER_SF>(
              rowIdx, colIdx, numKTiles, SFout);

      // Called for padded positions as well (with a zeroed in_vec) so that
      // the scale-factor location for this position is still processed.
      auto out_val =
          cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
              in_vec, global_scale, sf_out);

      // We do NOT write output for padding because the 'out' tensor is not
      // padded.
      if (valid) {
        if constexpr (CVT_FP4_PACK16) {
          // Offset is in 32-bit words; computed in 64-bit to avoid wrap.
          int64_t const outOffset =
              static_cast<int64_t>(rowIdx) * (numCols / 8) +
              static_cast<int64_t>(colIdx) * 2;
          uint64_t const packed64 =
              (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
          // Two 32-bit words are stored with a single 64-bit write.
          reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
        } else {
          out[inOffset] = out_val;
        }
      }
    }
  }
}

// Quantizes a row-major matrix of `Type` elements to packed FP4 and writes the
// matching scale factors at the address returned by sf_out_rowmajor_u8.
//
// Launch layout (as read from the indexing below):
//   - blockIdx.x selects the starting row; each block then advances by
//     gridDim.x rows until numRows rows have been visited.
//   - blockDim.x * blockIdx.y + threadIdx.x selects one packed column, i.e.
//     a group of CVT_FP4_ELTS_PER_THREAD consecutive input elements.
//
// Parameters:
//   numRows, numCols  Logical shape of `in`.
//   sf_n_unpadded     Forwarded to sf_out_rowmajor_u8 to locate the scale
//                     factor for (rowIdx, colIdx).
//   num_packed_cols   Number of packed columns to visit.
//   in                Input matrix, read with 128- or 256-bit loads.
//   SFScale           Optional pointer to one global scale; nullptr means 1.0f.
//   out               Packed FP4 output.
//   SFout             Scale-factor output buffer.
//
// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols,
                             int32_t sf_n_unpadded, int32_t num_packed_cols,
                             Type const* __restrict__ in,
                             float const* __restrict__ SFScale,
                             uint32_t* __restrict__ out,
                             uint32_t* __restrict__ SFout) {
  using PackedVec = vllm::PackedVec<Type, CVT_FP4_PACK16>;

  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
                "Vec size is not matched.");

  int32_t const colIdx = blockDim.x * blockIdx.y + threadIdx.x;
  int elem_idx = colIdx * CVT_FP4_ELTS_PER_THREAD;

  // Get the global scaling factor, which will be applied to the SF.
  // Note SFScale is the same as next GEMM's alpha, which is
  // (448.f / (Alpha_A / 6.f)).
  float const global_scale = (SFScale == nullptr) ? 1.0f : SFScale[0];

  // Iterate over the numRows input rows and the num_packed_cols packed
  // columns only; unlike a padded traversal, no rows beyond numRows are
  // visited here.
  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
    if (colIdx < num_packed_cols) {
      PackedVec in_vec;
      // Promote to 64-bit BEFORE multiplying: rowIdx * (numCols / N) would
      // otherwise be evaluated in 32-bit int and could wrap for very large
      // inputs even though the destination is int64_t.
      int64_t const inOffset =
          static_cast<int64_t>(rowIdx) *
              (numCols / CVT_FP4_ELTS_PER_THREAD) +
          colIdx;

      // Zero-fill the vector if this thread's elements fall outside the
      // logical matrix. The row test always holds inside this loop; the
      // column test depends on the num_packed_cols supplied by the caller.
      bool const valid = (rowIdx < numRows) && (elem_idx < numCols);
      if constexpr (CVT_FP4_PACK16) {
        // Each packed vector is addressed as 8 consecutive 32-bit words.
        ld256_cg_or_zero(reinterpret_cast<u32x8_t&>(in_vec),
                         &reinterpret_cast<const uint32_t*>(in)[inOffset * 8],
                         valid);
      } else {
        // Each packed vector is addressed as 4 consecutive 32-bit words.
        ld128_cg_or_zero(reinterpret_cast<uint4&>(in_vec),
                         &reinterpret_cast<const uint32_t*>(in)[inOffset * 4],
                         valid);
      }

      auto sf_out =
          sf_out_rowmajor_u8<uint32_t>(rowIdx, colIdx, sf_n_unpadded, SFout);

      auto out_val =
          cvt_warp_fp16_to_fp4<Type, CVT_FP4_NUM_THREADS_PER_SF, UE8M0_SF>(
              in_vec, global_scale, sf_out);

      // Only in-range positions are stored to 'out'.
      if (valid) {
        if constexpr (CVT_FP4_PACK16) {
          // Offset is in 32-bit words; computed in 64-bit to avoid wrap.
          int64_t const outOffset =
              static_cast<int64_t>(rowIdx) * (numCols / 8) +
              static_cast<int64_t>(colIdx) * 2;
          uint64_t const packed64 =
              (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo);
          // Two 32-bit words are stored with a single 64-bit write.
          reinterpret_cast<uint64_t*>(out)[outOffset >> 1] = packed64;
        } else {
          out[inOffset] = out_val;
        }
      }
    }
  }
}

174
175
}  // namespace vllm

176
// Host entry point: quantizes the 2-D half/bfloat16 tensor `input` to packed
// FP4 in `output` and writes the scale factors to `output_sf`.
//
// Parameters:
//   output                 Destination buffer for the packed FP4 values.
//   input                  Source tensor; size(0) is used as m, size(1) as n.
//                          n must be a multiple of 16 and the dtype must be
//                          Half or BFloat16.
//   output_sf              Destination buffer for the scale factors.
//   input_sf               Buffer whose first float is read by the kernel as
//                          the global scale.
//   is_sf_swizzled_layout  true  -> launch vllm::cvt_fp16_to_fp4 (visits
//                                   padded rows/columns);
//                          false -> launch vllm::cvt_fp16_to_fp4_sf_major.
//
// Raises (via TORCH_CHECK / C10_CUDA_KERNEL_LAUNCH_CHECK) on an unsupported
// shape or dtype, or if the kernel launch itself is rejected.
void scaled_fp4_quant_sm1xxa(torch::Tensor const& output,
                             torch::Tensor const& input,
                             torch::Tensor const& output_sf,
                             torch::Tensor const& input_sf,
                             bool is_sf_swizzled_layout) {
  int32_t m = input.size(0);
  int32_t n = input.size(1);

  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Half ||
                  input.scalar_type() == at::ScalarType::BFloat16,
              "Unsupported input data type for quantize_to_fp4.");

  // Empty input: nothing to quantize. Returning here also avoids a zero-sized
  // block (n == 0 would make block.x == 0 and feed a zero divisor to
  // div_round_up) and a zero-sized grid (m == 0 in the row-major branch).
  if (m == 0 || n == 0) {
    return;
  }

  // Make input's device current BEFORE querying device attributes, so the
  // SM count below describes the device the kernel will actually run on.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());

  // NOTE(review): -1 presumably selects the current device — confirm against
  // get_device_attribute's definition.
  int multiProcessorCount =
      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);

  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
  auto output_ptr = static_cast<int64_t*>(output.data_ptr());

  int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE);

  // Grid, Block size. Each thread converts 8 values.
  // NOTE(review): the block size uses ELTS_PER_THREAD while the column counts
  // below use CVT_FP4_ELTS_PER_THREAD — confirm the two are meant to differ.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
  int const numBlocksPerSM =
      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));

  if (is_sf_swizzled_layout) {
    // Pad the scale-factor column count to a multiple of 4, then convert it
    // back to a number of per-thread packed columns.
    int sf_n_int = int(vllm::round_up(sf_n_unpadded, 4) / 4);
    int32_t num_padded_cols =
        sf_n_int * 4 * CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;

    int grid_y = vllm::div_round_up(num_padded_cols, static_cast<int>(block.x));
    // Cap the row dimension so the total block count stays near
    // (SM count * blocks per SM); the kernel strides over remaining rows.
    int grid_x =
        std::min(vllm::computeEffectiveRows(m),
                 std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
    dim3 grid(grid_x, grid_y);

    VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
      using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
      auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
      // NOTE: We don't support e8m0 scales at this moment.
      vllm::cvt_fp16_to_fp4<cuda_type, false><<<grid, block, 0, stream>>>(
          m, n, num_padded_cols, input_ptr, input_sf_ptr,
          reinterpret_cast<uint32_t*>(output_ptr),
          reinterpret_cast<uint32_t*>(sf_out));
    });
  } else {
    int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD;
    int grid_y = vllm::div_round_up(num_packed_cols, static_cast<int>(block.x));
    // Same block-count cap as above, bounded by the actual row count.
    int grid_x = std::min(
        m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y));
    dim3 grid(grid_x, grid_y);

    VLLM_DISPATCH_HALF_TYPES(input.scalar_type(), "nvfp4_quant_kernel", [&] {
      using cuda_type = vllm::CUDATypeConverter<scalar_t>::Type;
      auto input_ptr = static_cast<cuda_type const*>(input.data_ptr());
      // NOTE: We don't support e8m0 scales at this moment.
      vllm::cvt_fp16_to_fp4_sf_major<cuda_type, false>
          <<<grid, block, 0, stream>>>(m, n, sf_n_unpadded, num_packed_cols,
                                       input_ptr, input_sf_ptr,
                                       reinterpret_cast<uint32_t*>(output_ptr),
                                       reinterpret_cast<uint32_t*>(sf_out));
    });
  }

  // Kernel launches do not report configuration errors by themselves;
  // surface them here instead of letting a later CUDA call fail mysteriously.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}