quantization.cpp 11.3 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
6
#include <cuda_runtime.h>
7

8
#include "../extensions.h"
9
#include "transformer_engine/cast.h"
10
#include "transformer_engine/recipe.h"
11
#include "xla/ffi/api/c_api.h"
12
13
14
15

namespace transformer_engine {
namespace jax {

16
/*
 * Reports the workspace that nvte_quantize_dbias asks for when quantizing a
 * (batch_size, hidden_size) input.
 *
 * No real buffers exist at this point, so every tensor is wrapped around the address of a
 * local dummy integer. nvte_quantize_dbias is then called with an empty workspace wrapper
 * and a null stream, and whatever shape/dtype ends up on that wrapper is handed back.
 * NOTE(review): presumably the call only fills in the workspace description in this
 * situation and never dereferences the dummy address - confirm against the TE API.
 *
 * \param batch_size    Leading dimension of the 2D input.
 * \param hidden_size   Trailing dimension of the 2D input; also the dbias length.
 * \param in_dtype      Element type used for the input and dbias tensors.
 * \param out_dtype     Element type used for the quantized output(s).
 * \param scaling_mode  JAX-side scaling mode, mapped through get_nvte_scaling_mode().
 * \param q_layout      Which outputs (row-wise, column-wise or both) are requested.
 * \return A one-element tuple holding the pair (workspace shape, workspace dtype).
 */
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype,
                                               JAXX_Scaling_Mode scaling_mode,
                                               QuantizeLayout q_layout) {
  std::vector<size_t> matrix_shape{batch_size, hidden_size};
  std::vector<size_t> matrix_t_shape{hidden_size, batch_size};
  std::vector<size_t> bias_shape{hidden_size};
  // Only the pointers are inspected for scale/amax/scale_inv, so one dummy shape serves all.
  std::vector<size_t> scalar_shape{1};

  // Evil hack to specify the TE implementation:
  // nvte_quantize_dbias picks its internal implementation based on which pointers are
  // set (e.g. whether column-wise data is requested). There are no allocated buffers
  // available here, so a dummy address stands in for all of them.
  int placeholder = 0;
  void *const fake_ptr = reinterpret_cast<void *>(&placeholder);

  const bool fp8_output = is_fp8_dtype(out_dtype);
  const bool delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING;
  const bool both_layouts = q_layout == QuantizeLayout::ROWWISE_COLWISE;
  const bool want_rowwise = both_layouts || q_layout == QuantizeLayout::ROWWISE;
  const bool want_colwise = both_layouts || q_layout == QuantizeLayout::COLWISE;

  TensorWrapper src_tensor(fake_ptr, matrix_shape, in_dtype);
  TensorWrapper bias_grad_tensor(fake_ptr, bias_shape, in_dtype);
  TensorWrapper dst_tensor(get_nvte_scaling_mode(scaling_mode));

  if (want_rowwise) {
    dst_tensor.set_rowwise_data(fake_ptr, out_dtype, matrix_shape);
    if (fp8_output) {
      dst_tensor.set_rowwise_scale_inv(fake_ptr, DType::kFloat32, scalar_shape);
    }
  }

  if (want_colwise) {
    // Delayed tensor scaling describes the column-wise output with the transposed shape.
    auto &colwise_shape = delayed_scaling ? matrix_t_shape : matrix_shape;
    dst_tensor.set_columnwise_data(fake_ptr, out_dtype, colwise_shape);
    if (fp8_output) {
      dst_tensor.set_columnwise_scale_inv(fake_ptr, DType::kFloat32, scalar_shape);
    }
  }

  if (fp8_output && delayed_scaling) {
    dst_tensor.set_amax(fake_ptr, DType::kFloat32, scalar_shape);
    dst_tensor.set_scale(fake_ptr, DType::kFloat32, scalar_shape);
  }

  // Left empty on purpose; its shape and dtype are read back after the call.
  TensorWrapper workspace_query;
  nvte_quantize_dbias(src_tensor.data(), dst_tensor.data(), bias_grad_tensor.data(),
                      workspace_query.data(), nullptr);

  auto workspace_shape = MakeShapeVector(workspace_query.shape());
  return pybind11::make_tuple(std::make_pair(workspace_shape, workspace_query.dtype()));
}

73
74
/*
 * XLA FFI entry point that quantizes `input_buf` to FP8 and, when `is_dbias` is true,
 * goes through nvte_quantize_dbias (which also receives the dbias and workspace tensors);
 * otherwise plain nvte_quantize is used.
 *
 * The input is viewed as a 2D (m, n) matrix: dimensions before `flatten_axis` are collapsed
 * into m and the remaining ones into n (both via product()).
 *
 * \param stream                 CUDA stream passed to the NVTE calls.
 * \param input_buf              Tensor to quantize.
 * \param scale_buf              Scale; read only for the tensor-scaling modes.
 * \param output_buf             Row-wise quantized output.
 * \param output_trans_buf       Column-wise quantized output.
 * \param scale_inv_buf          Row-wise scale_inv; also used for the column-wise output
 *                               in the tensor-scaling modes.
 * \param colwise_scale_inv_buf  Column-wise scale_inv for the other scaling modes.
 * \param amax_buf               amax output; zeroed on `stream` in the tensor-scaling modes.
 * \param dbias_buf              dbias output (used when `is_dbias` is true).
 * \param workspace_buf          Workspace for nvte_quantize_dbias.
 * \param scaling_mode           JAX-side scaling mode.
 * \param quantize_layout_enum   Integer value of a QuantizeLayout.
 * \param is_dbias               Selects nvte_quantize_dbias over nvte_quantize.
 * \param flatten_axis           Axis at which the input is split into (m, n); negative
 *                               values count from the end. Must end up in (0, ndim).
 * \return Result of ffi_with_cuda_error_check().
 */
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                            Result_Type output_buf, Result_Type output_trans_buf,
                            Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                            Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf,
                            JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum,
                            bool is_dbias, int64_t flatten_axis) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  // Everything below relies on an FP8 output, so scale_inv buffers are attached
  // unconditionally further down instead of re-testing the dtype.
  NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");

  auto *input = input_buf.untyped_data();

  auto const quantize_layout = static_cast<QuantizeLayout>(quantize_layout_enum);

  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  int64_t input_ndim = input_dims.size();
  // Normalise a negative axis; axis 0 is rejected so that both m and n cover at least one dim.
  if (flatten_axis < 0) flatten_axis += input_ndim;
  NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!");

  auto workspace_dims = workspace_buf->dimensions();
  auto m = product(input_dims, 0, flatten_axis);
  auto n = product(input_dims, flatten_axis, input_ndim);
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()};

  // Collapses a scale_inv buffer's dimensions to 2D around flatten_axis, the same way the
  // input is collapsed. Used for the scaling modes that are not tensor scaling.
  auto scale_inv_shape_2d = [flatten_axis](auto &buf) {
    return std::vector<size_t>{
        product(buf->dimensions(), 0, flatten_axis),
        product(buf->dimensions(), flatten_axis, buf->dimensions().size())};
  };

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));

  bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;

  if (quantize_layout == QuantizeLayout::ROWWISE ||
      quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
    output_tensor.set_rowwise_data(output, out_dtype, output_shape);

    if (is_tensor_scaling) {
      float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
      float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
      // This branch serves both delayed and current tensor scaling, so the messages
      // must not single out one of them.
      NVTE_CHECK(scale != nullptr, "scale must be provided for tensor scaling");
      NVTE_CHECK(amax != nullptr, "amax must be provided for tensor scaling");
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      // Reset amax on the stream before the quantize call is issued.
      nvte_memset(amax, 0, sizeof(float), stream);
      output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          std::vector<size_t>{1});
    } else {
      output_tensor.set_rowwise_scale_inv(
          scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
          scale_inv_shape_2d(scale_inv_buf));
    }
  }

  if (quantize_layout == QuantizeLayout::COLWISE ||
      quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
    // NOTE(review): the data shape is keyed on DELAYED_TENSOR_SCALING only, whereas the
    // scale_inv handling below is keyed on is_tensor_scaling (delayed OR current). Confirm
    // whether CURRENT_TENSOR_SCALING should also use the transposed shape here.
    auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
                          ? output_trans_shape
                          : output_shape;
    output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);
    // For 2x tensor scaling, the scale_inv buffer is shared between rowwise and columnwise
    auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;

    if (is_tensor_scaling) {
      output_tensor.set_columnwise_scale_inv(
          tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
          std::vector<size_t>{1});
    } else {
      output_tensor.set_columnwise_scale_inv(
          tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
          scale_inv_shape_2d(tmp_buf));
    }
  }

  if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
    // Detach amax from the output wrapper for current scaling (the buffer itself was
    // already zeroed above when a row-wise output is requested).
    // NOTE(review): presumably this keeps the quantize kernel from updating amax - confirm.
    output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
  }

  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  if (is_dbias) {
    nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                        workspace_tensor.data(), stream);
  } else {
    nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
  }

  return ffi_with_cuda_error_check();
}

180
// Registers DBiasQuantizeFFI as the XLA FFI handler symbol DBiasQuantizeHandler.
// The binding order below must match the parameter order of DBiasQuantizeFFI.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  // Operands
                                  .Arg<Buffer_Type>()  // input_buf
                                  .Arg<Buffer_Type>()  // scale_buf
                                  // Results
                                  .Ret<Buffer_Type>()  // output_buf (row-wise output)
                                  .Ret<Buffer_Type>()  // output_trans_buf (column-wise output)
                                  .Ret<Buffer_Type>()  // scale_inv_buf
                                  .Ret<Buffer_Type>()  // colwise_scale_inv_buf
                                  .Ret<Buffer_Type>()  // amax_buf
                                  .Ret<Buffer_Type>()  // dbias_buf
                                  .Ret<Buffer_Type>()  // workspace_buf
                                  // Attributes
                                  .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                  .Attr<int64_t>("q_layout")  // quantize_layout_enum
                                  .Attr<bool>("is_dbias")
                                  .Attr<int64_t>("flatten_axis"),
                              FFI_CudaGraph_Traits);

198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
/*
 * XLA FFI entry point that dequantizes `input_buf` into `output_buf` via nvte_dequantize.
 *
 * Both tensors are wrapped with the input buffer's dimensions. The amax, scale and
 * scale_inv buffers are attached to the input wrapper as float pointers.
 *
 * \param stream         CUDA stream passed to nvte_dequantize.
 * \param input_buf      Quantized input tensor.
 * \param amax_buf       amax attached to the input tensor.
 * \param scale_buf      scale attached to the input tensor.
 * \param scale_inv_buf  scale_inv attached to the input tensor.
 * \param output_buf     Destination for the dequantized values.
 * \return Result of ffi_with_cuda_error_check().
 */
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                         Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) {
  // The output wrapper reuses the input's dimensions.
  auto src_dims = input_buf.dimensions();
  std::vector<size_t> tensor_shape(src_dims.begin(), src_dims.end());

  auto src_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto dst_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  float *amax_ptr = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale_ptr = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv_ptr = reinterpret_cast<float *>(scale_inv_buf.untyped_data());

  TensorWrapper quantized(input_buf.untyped_data(), tensor_shape, src_dtype, amax_ptr, scale_ptr,
                          scale_inv_ptr);
  TensorWrapper dequantized(output_buf->untyped_data(), tensor_shape, dst_dtype);

  nvte_dequantize(quantized.data(), dequantized.data(), stream);

  return ffi_with_cuda_error_check();
}

// Registers DequantizeFFI as the XLA FFI handler symbol DequantizeHandler.
// The binding order below must match the parameter order of DequantizeFFI.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  // Operands
                                  .Arg<Buffer_Type>()  // input_buf
                                  .Arg<Buffer_Type>()  // amax_buf
                                  .Arg<Buffer_Type>()  // scale_buf
                                  .Arg<Buffer_Type>()  // scale_inv_buf
                                  // Results
                                  .Ret<Buffer_Type>(),  // output_buf
                              FFI_CudaGraph_Traits);

229
230
}  // namespace jax
}  // namespace transformer_engine