/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cuda_runtime.h>

#include <cstdint>
#include <numeric>
#include <vector>

#include "../extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/recipe.h"
#include "transformer_engine/transformer_engine.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

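// Computes the workspace requirements of nvte_quantize_dbias for the given
// problem size without launching any kernels: the tensors carry dummy
// pointers and the workspace is left unallocated, so the call only reports
// the workspace shape/dtype. The result is returned to Python as a
// (shape, dtype) pair so the JAX side can allocate the buffer.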
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype,
                                               JAXX_Scaling_Mode scaling_mode,
                                               QuantizeLayout q_layout) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias chooses its internal impl based on what
  // pointers are allocated, e.g. whether to output with column-wise
  // data. However, we don't have access to any allocated buffers in
  // this function. We pass a dummy pointer as a workaround.
  int temp = 0;

  auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);

  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
  // Only the scale_inv pointers are inspected here, so the shapes do not matter
  if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) {
    output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
    if (is_fp8_dtype(out_dtype)) {
      output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                          std::vector<size_t>{1});
    }
  }

  if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) {
    auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape
                                                                                : output_shape;
    output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, tmp_shape);

    // Only the scale_inv pointers are inspected here, so the shapes do not matter
    if (is_fp8_dtype(out_dtype)) {
      output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                             std::vector<size_t>{1});
    }
  }

  if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
    output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
                           std::vector<size_t>{1});
    output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
                            std::vector<size_t>{1});
  }

  TensorWrapper dummy_workspace;

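  // With an unallocated workspace tensor, nvte_quantize_dbias fills in the
  // required workspace shape/dtype instead of performing the quantization.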
  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                      dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

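// XLA FFI entry point that quantizes `input` to FP8 in rowwise and/or
// columnwise layout, optionally fusing the dbias (column-wise reduction of
// the incoming gradient) computation.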
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                            Buffer_Type amax_buf, Result_Type output_buf,
                            Result_Type output_trans_buf, Result_Type scale_inv_buf,
                            Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf,
                            Result_Type dbias_buf, Result_Type workspace_buf,
                            JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
                            bool is_dbias, int64_t flatten_axis) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");

  auto *input = input_buf.untyped_data();

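  // q_layout selects which outputs are produced: rowwise data, columnwise
  // (transposed for tensor scaling) data, or both in a single pass.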
  auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);

  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  int64_t input_ndim = input_dims.size();
  if (flatten_axis < 0) flatten_axis += input_ndim;
  NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");

  auto workspace_dims = workspace_buf->dimensions();
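  // Collapse the input to a 2D view: m = product of dims before flatten_axis,
  // n = product of the remaining dims. Quantization and the optional
  // transpose operate on this 2D view.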
  auto m = product(input_dims, 0, flatten_axis);
  auto n = product(input_dims, flatten_axis, input_ndim);
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));

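  // Tensor scaling modes use a single scale factor per tensor; block-scaled
  // modes (e.g. MXFP8 1D scaling) carry a full tensor of per-block factors.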
  bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;

  if (quantize_layout == QuantizeLayout::ROWWISE ||
      quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
    output_tensor.set_rowwise_data(output, out_dtype, output_shape);

    if (is_fp8_dtype(out_dtype)) {
      if (is_tensor_scaling) {
        float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
        float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
        float *updated_amax = reinterpret_cast<float *>(updated_amax_buf->untyped_data());
        NVTE_CHECK(scale != nullptr, "scale must be provided for tensor scaling");
        NVTE_CHECK(amax == updated_amax && amax != nullptr,
                   "amax must be provided (aliasing updated_amax) for tensor scaling");
        output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
        output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
        output_tensor.set_rowwise_scale_inv(
            scale_inv_buf->untyped_data(),
            convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
            std::vector<size_t>{1});
      } else {
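        // Block scaling: scale_inv is a real tensor of per-block factors, so
        // flatten its dimensions to 2D the same way as the data.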
        output_tensor.set_rowwise_scale_inv(
            scale_inv_buf->untyped_data(),
            convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
            std::vector<size_t>{product(scale_inv_buf->dimensions(), 0, flatten_axis),
                                product(scale_inv_buf->dimensions(), flatten_axis,
                                        scale_inv_buf->dimensions().size())});
      }
    }
  }

  if (quantize_layout == QuantizeLayout::COLWISE ||
      quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
    output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
    // For 2x tensor scaling, the scale_inv buffer is shared between the rowwise and columnwise outputs
    auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;

    if (is_tensor_scaling) {
      output_tensor.set_columnwise_scale_inv(
          tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
          std::vector<size_t>{1});
    } else {
      output_tensor.set_columnwise_scale_inv(
          tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
          std::vector<size_t>{
              product(tmp_buf->dimensions(), 0, flatten_axis),
              product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())});
    }
  }

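  // For current tensor scaling the amax is computed by the quantize kernel
  // itself; the null pointer indicates that no external amax buffer exists.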
  if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
    output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
  }

  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

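  // The dbias path fuses the FP8 cast with the column reduction that produces
  // the bias gradient; otherwise a plain quantize kernel suffices.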
  if (is_dbias) {
    nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                        workspace_tensor.data(), stream);
  } else {
    nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // updated amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("q_layout")
                                  .Attr<bool>("is_dbias")
                                  .Attr<int64_t>("flatten_axis"),
                              FFI_CudaGraph_Traits);

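// XLA FFI entry point that dequantizes an FP8 tensor back to a
// higher-precision dtype using the stored scale_inv.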
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                         Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());

  auto *output = output_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  std::vector<size_t> shape(input_dims.begin(), input_dims.end());
  auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv);
  auto output_tensor = TensorWrapper(output, shape, out_dtype);

  nvte_dequantize(input_tensor.data(), output_tensor.data(), stream);
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>(),     // output
                              FFI_CudaGraph_Traits);

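// XLA FFI entry point for grouped quantization: the leading dimension of
// `inputs` is partitioned into `group_sizes` row groups, and each group is
// quantized as an independent tensor. Group buffers are packed contiguously,
// so raw byte pointers are advanced group by group below.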
266
Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scales,
                              Buffer_Type group_sizes, Result_Type outputs,
                              Result_Type colwise_outputs, Result_Type scale_invs,
                              Result_Type colwise_scale_invs, Result_Type amaxs,
                              JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
                              int64_t flatten_axis) {
  NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING,
             "Unsupported scaling mode: ", static_cast<int>(scaling_mode));

  auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(outputs->element_type());
  NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");

  auto scale_dtype = convert_ffi_datatype_to_te_dtype(scales.element_type());
  auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type());
  auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type());
  auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type());
  auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);

  auto *input_ptr = reinterpret_cast<uint8_t *>(inputs.untyped_data());
  auto *scale_ptr = reinterpret_cast<uint8_t *>(scales.untyped_data());
  auto *output_ptr = reinterpret_cast<uint8_t *>(outputs->untyped_data());
  auto *colwise_output_ptr = reinterpret_cast<uint8_t *>(colwise_outputs->untyped_data());
  auto *sinv_ptr = reinterpret_cast<uint8_t *>(scale_invs->untyped_data());
  auto *colwise_sinv_ptr = reinterpret_cast<uint8_t *>(colwise_scale_invs->untyped_data());
  auto *amax_ptr = reinterpret_cast<uint8_t *>(amaxs->untyped_data());

  bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE ||
                     quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
  bool has_colwise = quantize_layout == QuantizeLayout::COLWISE ||
                     quantize_layout == QuantizeLayout::ROWWISE_COLWISE;
  bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
  bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
  bool const is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING;

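  // Per-element byte widths used to advance the packed buffer pointers; the
  // widths of outputs that are not produced are set to 0 so that the
  // corresponding pointers stay put.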
  size_t input_dtype_bytes = te_dtype_bytes(in_dtype);
  size_t output_dtype_bytes = te_dtype_bytes(out_dtype);
  size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype);
  size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype);
  size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0;
  size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0;
  size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0;
  size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0;

  auto input_dims = inputs.dimensions();
  int64_t input_ndim = input_dims.size();
  if (flatten_axis < 0) flatten_axis += input_ndim;
  NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");

  auto m = product(input_dims, 0, flatten_axis);
  auto n = product(input_dims, flatten_axis, input_ndim);
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m * n};

  // These lists are to keep the TensorWrapper objects alive
  std::vector<TensorWrapper> input_holders;
  std::vector<TensorWrapper> output_holders;

  // These are the raw NVTETensor (void *) handles passed to nvte_multi_tensor_quantize
  std::vector<NVTETensor> input_list;
  std::vector<NVTETensor> output_list;

  size_t num_groups = group_sizes.dimensions()[0];
  size_t dim_list_bytes = group_size_dtype_bytes * num_groups;
  std::vector<int32_t> dim_list_host(num_groups);
  auto *group_size_ptr = reinterpret_cast<int32_t *>(group_sizes.untyped_data());
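  // group_sizes lives in device memory, but the per-group shapes are needed
  // on the host to construct the TensorWrappers, hence this copy and sync.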
  cudaMemcpyAsync(dim_list_host.data(), group_size_ptr, dim_list_bytes, cudaMemcpyDeviceToHost,
                  stream);
  // Note: this device-to-host copy and stream sync may break CUDA graph capture.
  cudaStreamSynchronize(stream);

  size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), size_t{0});
  NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes,
             "Unexpected group_sizes! Got ", sum_group_sizes, " (M=", m,
             ", input_dims[0]=", input_dims[0], ")");

  if (is_delayed_scaling) {
    NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size! Expected ", num_groups,
               ", got ", amaxs->dimensions()[0]);
    NVTE_CHECK(amax_dtype == DType::kFloat32 && scale_dtype == DType::kFloat32);
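    // Zero the per-group amax buffer up front; the quantize kernels reduce
    // amax into it, so stale values must not leak through.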
    cudaMemsetAsync(amax_ptr, 0, sizeof(float) * num_groups, stream);
  }

  size_t sinv_size = 0;
  size_t colwise_sinv_size = 0;
  size_t non_group_m = flatten_axis > 1 ? product(input_dims, 1, flatten_axis) : 1;
  size_t num_non_empty_groups = 0;
  size_t total_rowwise_sinv_size = 0;
  size_t total_colwise_sinv_size = 0;
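  // Build one TensorWrapper per non-empty group, advancing the packed byte
  // pointers by each group's contribution after every iteration.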
  for (size_t i = 0; i < num_groups; i++) {
    size_t m_i = dim_list_host[i] * non_group_m;
    // Skip zero-size groups, but still advance the scale pointer
    if (m_i == 0) {
      if (is_tensor_scaling) scale_ptr += scale_dtype_bytes;
      continue;
    }
    num_non_empty_groups++;
    auto shape_i = std::vector<size_t>{m_i, n};
    auto shape_trans_i = std::vector<size_t>{n, m_i};

    auto inp_i = TensorWrapper(static_cast<void *>(input_ptr), shape_i, in_dtype);
    auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode));

    if (has_rowwise) {
      out_i.set_rowwise_data(static_cast<void *>(output_ptr), out_dtype, shape_i);

      if (is_fp8_dtype(out_dtype)) {
        if (is_tensor_scaling) {
          out_i.set_scale(static_cast<void *>(scale_ptr), DType::kFloat32, std::vector<size_t>{1});
          out_i.set_amax(static_cast<void *>(amax_ptr), DType::kFloat32, std::vector<size_t>{1});
          out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype,
                                      std::vector<size_t>{1});
          sinv_size = 1;
        } else {
          const bool is_colwise = false;
          auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
          out_i.set_rowwise_scale_inv(static_cast<void *>(sinv_ptr), sinv_dtype, sinv_shape_i);
          sinv_size = product(sinv_shape_i);
        }
      }
    }

    if (has_colwise) {
      auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i;
      out_i.set_columnwise_data(static_cast<void *>(colwise_output_ptr), out_dtype, tmp_shape);
      // For 2x tensor scaling, the scale_inv buffer is shared between the rowwise and columnwise outputs
      auto &tmp_sinv_ptr = is_tensor_scaling ? sinv_ptr : colwise_sinv_ptr;

      if (is_tensor_scaling) {
        out_i.set_columnwise_scale_inv(static_cast<void *>(tmp_sinv_ptr), sinv_dtype,
                                       std::vector<size_t>{1});
        colwise_sinv_size = 1;
      } else {
        const bool is_colwise = true;
        auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise);
        out_i.set_columnwise_scale_inv(static_cast<void *>(colwise_sinv_ptr), sinv_dtype,
                                       sinv_shape_i);
        colwise_sinv_size = product(sinv_shape_i);
      }
    }

    input_holders.push_back(std::move(inp_i));
    output_holders.push_back(std::move(out_i));

    input_list.push_back(input_holders.back().data());
    output_list.push_back(output_holders.back().data());

    input_ptr += m_i * n * input_dtype_bytes;
    scale_ptr += scale_dtype_bytes;
    output_ptr += m_i * n * output_dtype_bytes;
    colwise_output_ptr += m_i * n * colwise_output_dtype_bytes;
    sinv_ptr += sinv_size * sinv_dtype_bytes;
    colwise_sinv_ptr += colwise_sinv_size * colwise_sinv_dtype_bytes;
    amax_ptr += amax_dtype_bytes;
    total_rowwise_sinv_size += sinv_size;
    total_colwise_sinv_size += colwise_sinv_size;
  }
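  // MXFP8 scale_inv layouts include alignment padding that the quantize
  // kernels may not write; zero both buffers so the padding is well defined.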
  if (is_mxfp8_scaling) {
    nvte_memset(scale_invs->untyped_data(), 0, total_rowwise_sinv_size, stream);
    nvte_memset(colwise_scale_invs->untyped_data(), 0, total_colwise_sinv_size, stream);
  }

  QuantizationConfigWrapper quant_config;
  nvte_multi_tensor_quantize(input_list.data(), output_list.data(), quant_config,
                             num_non_empty_groups, stream);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // group_sizes
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("q_layout")
                                  .Attr<int64_t>("flatten_axis"));

}  // namespace jax
}  // namespace transformer_engine