/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file quantize.cuh
 *  \brief Quantize dispatcher.
 */

#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_

#include <transformer_engine/transformer_engine.h>

#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"

namespace transformer_engine {
namespace dispatch {

/*! \brief Forward-pass quantization dispatcher.
 *
 *  Converts the opaque handles into Tensor objects, resolves the optional
 *  quantization config and no-op flag tensor, and forwards the work to the
 *  implementation selected by the scaling mode of the output tensor.
 *
 *  \tparam IS_ACT   Forwarded to the FP8 and MXFP8 implementations together with
 *                   ParamOP/OP (presumably requests fusing the element-wise OP into
 *                   the cast -- confirm against those kernels). Must be false for the
 *                   NVFP4 and block-scaling modes, which reject it with an error.
 *  \tparam ParamOP  Parameter type taken by OP.
 *  \tparam OP       Element-wise function forwarded to the FP8 and MXFP8 implementations.
 *
 *  \param[in]     input         Tensor to quantize.
 *  \param[in,out] output        Destination tensor. Its scaling mode selects the
 *                               implementation; at least one of the rowwise or
 *                               columnwise data buffers must be allocated.
 *  \param[in]     quant_config  Optional quantization config. When nullptr, a
 *                               default-constructed QuantizationConfig is used.
 *  \param[in]     stream        CUDA stream forwarded to the selected implementation.
 *
 *  Errors are raised through NVTE_CHECK / NVTE_ERROR for unsupported option
 *  combinations and for scaling modes that have no forward implementation here.
 */
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
                         const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;

  const Tensor *input_tensor = convertNVTETensorCheck(input);
  Tensor *output_tensor = convertNVTETensorCheck(output);

  // Quantization config: copy the caller's settings if provided, otherwise keep defaults.
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }

  // Noop flag: fall back to a default-constructed tensor when the config does not supply one.
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }

  NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  // Dispatch to quantization kernel depending on data format
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      // The forward pass has no secondary input, dbias or workspace; pass null placeholders.
      const Tensor *dummy_input_tensor = nullptr;
      Tensor *dummy_dbias_tensor = nullptr;
      Tensor *dummy_workspace_tensor = nullptr;
      if (output_tensor->has_columnwise_data()) {
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_ACT) {
          cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
        } else {
          cast_transpose_fused</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, float, ParamOP, OP>(
              *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor,
              dummy_workspace_tensor, stream);
        }
      } else if (output_tensor->has_data()) {
        fp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
            *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
            dummy_workspace_tensor, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      const Tensor *dummy_input_tensor = nullptr;
      Tensor *dummy_dbias_tensor = nullptr;
      Tensor *dummy_workspace_tensor = nullptr;
      mxfp8::quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
          *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
          dummy_workspace_tensor, stream);
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");

      // Check tensors
      CheckNoopTensor(*noop_tensor, "cast_noop");
      CheckInputTensor(*input_tensor, "input");
      CheckOutputTensor(*output_tensor, "output", false);

      // Choose kernel: the optimized path is taken only for BF16 inputs whose flattened
      // dimensions are both multiples of 32 and whose output has rowwise data.
      // The dimensions keep the type returned by the Tensor accessors (no narrowing).
      const auto rows = input_tensor->flat_first_dim();
      const auto cols = input_tensor->flat_last_dim();
      const auto dtype = input_tensor->dtype();
      const bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                                        (cols % 32 == 0) && output_tensor->has_data();

      // Launch NVFP4 quantize kernel
      if (use_optimized_kernel) {
        if (quant_config_cpp.nvfp4_2d_quantization) {
          nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        } else {
          nvfp4::quantize_transpose</*use_2d_quantization=*/false>(
              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        }
      } else {
        // Use the rowwise amax when it is set, otherwise the columnwise one.
        auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
                                                                  : output_tensor->columnwise_amax;
        quantize_transpose_vector_blockwise_fp4(
            /*input=*/input_tensor->data, /*global_amax=*/global_amax,
            /*scale_inv=*/output_tensor->scale_inv,
            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
            /*swizzled_scale=*/false,
            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
            /*rng_state=*/quant_config_cpp.rng_state,
            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
      }
      break;
    }
    case NVTE_BLOCK_SCALING_2D: {
      // TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
          /*noop_tensor=*/noop_tensor->data, stream);
      break;
    }
    case NVTE_BLOCK_SCALING_1D: {
      // TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      // Each direction is requested only if the corresponding output buffer is allocated.
      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
      if (output_tensor->has_data()) {
        rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
      }
      if (output_tensor->has_columnwise_data()) {
        columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
      }
      quantize_transpose_vector_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
  }
}

/*! \brief Backward-pass quantization dispatcher.
 *
 *  Converts the opaque handles into Tensor objects, resolves the optional
 *  quantization config and no-op flag tensor, and forwards the work to the
 *  implementation selected by the scaling mode of the output tensor.
 *
 *  \tparam IS_DBIAS  Forwarded to the FP8 and MXFP8 implementations (presumably requests
 *                    a fused bias-gradient computation -- confirm against those kernels).
 *  \tparam IS_DACT   Forwarded to the FP8 and MXFP8 implementations together with
 *                    ParamOP/OP (presumably requests a fused activation-gradient
 *                    computation -- confirm against those kernels).
 *  \tparam ParamOP   Parameter type taken by OP.
 *  \tparam OP        Element-wise function forwarded to the FP8 and MXFP8 implementations.
 *
 *  IS_DBIAS and IS_DACT must both be false for the NVFP4 and block-scaling modes,
 *  which reject them with an error.
 *
 *  \param[in]     grad          Tensor to quantize.
 *  \param[in]     input         Secondary input forwarded to the fused FP8/MXFP8
 *                               implementations; converted without a validity check.
 *  \param[in,out] output        Destination tensor. Its scaling mode selects the
 *                               implementation; at least one of the rowwise or
 *                               columnwise data buffers must be allocated.
 *  \param[out]    dbias         Forwarded to the fused FP8/MXFP8 implementations;
 *                               converted without a validity check.
 *  \param[out]    workspace     Forwarded to the fused FP8/MXFP8 implementations;
 *                               converted without a validity check.
 *  \param[in]     quant_config  Optional quantization config. When nullptr, a
 *                               default-constructed QuantizationConfig is used.
 *  \param[in]     stream        CUDA stream forwarded to the selected implementation.
 *
 *  Errors are raised through NVTE_CHECK / NVTE_ERROR for unsupported option
 *  combinations and for scaling modes that have no backward implementation here.
 */
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                         NVTETensor dbias, NVTETensor workspace,
                         const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;

  const Tensor *grad_tensor = convertNVTETensorCheck(grad);
  const Tensor *input_tensor = convertNVTETensor(input);

  Tensor *output_tensor = convertNVTETensorCheck(output);
  Tensor *dbias_tensor = convertNVTETensor(dbias);
  Tensor *workspace_tensor = convertNVTETensor(workspace);

  // Quantization config: copy the caller's settings if provided, otherwise keep defaults.
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }

  // Noop flag: fall back to a default-constructed tensor when the config does not supply one.
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }

  NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  // Dispatch to quantization kernel depending on data format
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      if (output_tensor->has_columnwise_data()) {
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_DBIAS && !IS_DACT) {
          cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream);
        } else {
          cast_transpose_fused<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, float, ParamOP, OP>(
              *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream);
        }
      } else if (output_tensor->has_data()) {
        fp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
            *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
            stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8::quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
          *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
          stream);
      break;
    }
    case NVTE_NVFP4_1D_SCALING: {
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING");

      // Check tensors
      CheckNoopTensor(*noop_tensor, "cast_noop");
      CheckInputTensor(*grad_tensor, "input");
      CheckOutputTensor(*output_tensor, "output", false);

      // Choose kernel: the optimized path is taken only for BF16 inputs whose flattened
      // dimensions are both multiples of 32 and whose output has rowwise data.
      // The dimensions keep the type returned by the Tensor accessors (no narrowing).
      const auto rows = grad_tensor->flat_first_dim();
      const auto cols = grad_tensor->flat_last_dim();
      const auto dtype = grad_tensor->dtype();
      const bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                                        (cols % 32 == 0) && output_tensor->has_data();

      // Launch NVFP4 quantize kernel
      if (use_optimized_kernel) {
        if (quant_config_cpp.nvfp4_2d_quantization) {
          nvfp4::quantize_transpose</*use_2d_quantization=*/true>(
              *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        } else {
          nvfp4::quantize_transpose</*use_2d_quantization=*/false>(
              *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        }
      } else {
        // Use the rowwise amax when it is set, otherwise the columnwise one.
        auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
                                                                  : output_tensor->columnwise_amax;
        quantize_transpose_vector_blockwise_fp4(
            /*input=*/grad_tensor->data, /*global_amax=*/global_amax,
            /*scale_inv=*/output_tensor->scale_inv,
            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
            /*swizzled_scale=*/false,
            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
            /*rng_state=*/quant_config_cpp.rng_state,
            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
      }
      break;
    }
    case NVTE_BLOCK_SCALING_2D: {
      // TODO(kwyss): IS_DBIAS, IS_DACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
          /*noop_tensor=*/noop_tensor->data, stream);
      break;
    }
    case NVTE_BLOCK_SCALING_1D: {
      // TODO(kwyss): IS_DBIAS, IS_DACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      // Each direction is requested only if the corresponding output buffer is allocated.
      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
      if (output_tensor->has_data()) {
        rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
      }
      if (output_tensor->has_columnwise_data()) {
        columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
      }
      quantize_transpose_vector_blockwise(
          grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
  }
}

310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
                               const size_t *split_sections, const size_t num_tensors,
                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;

  const Tensor *input_tensor = convertNVTETensorCheck(input);
  std::vector<Tensor *> output_tensors;
  for (size_t i = 0; i < num_tensors; ++i) {
    output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
  }

  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
  }

  // Noop flag
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }

  // Take the scaling mode of the first output tensor
  auto scaling_mode = output_tensors[0]->scaling_mode;

  // Dispatch to quantization kernel depending on data format
  switch (scaling_mode) {
    case NVTE_NVFP4_1D_SCALING: {
      NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");

      // Check tensors
      CheckNoopTensor(*noop_tensor, "cast_noop");
      CheckInputTensor(*input_tensor, "input");
      // Skip checking output tensor list
      // output list here is allowed to have empty tensor

      // Choose kernel
      int32_t rows = input_tensor->flat_first_dim();
      int32_t cols = input_tensor->flat_last_dim();
      auto dtype = input_tensor->dtype();

      NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
                 "2D quantization is not supported for group quantize.");

      // Launch NVFP4 group quantize kernel
      nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
          *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
          &quant_config_cpp, stream);
      break;
    }
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
  }
}

}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_