/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cuda_runtime.h>

#include <vector>

#include "extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/c_api.h"

namespace transformer_engine {
namespace jax {

// Asks nvte_quantize_dbias which workspace it needs for a 2-D input of shape
// (batch_size, hidden_size) quantized to `out_dtype` with both row-wise and
// column-wise (transposed, (hidden_size, batch_size)) outputs.
//
// Returns a pybind11 tuple holding a single (workspace_shape, workspace_dtype) pair.
pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias chooses its internal impl based on what
  // pointers are allocated, e.g. whether to output with column-wise
  // data. However, we don't have access to any allocated buffers in
  // this function. We pass a dummy pointer as a workaround.
  int temp = 0;
  void *dummy_ptr = &temp;  // non-null placeholder; int* converts to void* implicitly

  auto input_tensor = TensorWrapper(dummy_ptr, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(dummy_ptr, output_shape, out_dtype);
  output_tensor.set_columnwise_data(dummy_ptr, out_dtype, output_trans_shape);
  auto dbias_tensor = TensorWrapper(dummy_ptr, dbias_shape, in_dtype);

  // Workspace without a data pointer. Presumably nvte_quantize_dbias then only
  // reports the required workspace shape/dtype instead of launching work.
  // NOTE(review): confirm against the TE cast.h documentation.
  TensorWrapper dummy_workspace;

  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                      dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

// Views an FFI buffer shape of rank >= 1 as the 2-D shape
// {product of all leading dims, last dim}. `what` names the buffer in the
// error message produced for a rank-0 shape.
template <typename Dims>
static std::vector<size_t> CollapseToMatrixShape(const Dims &dims, const char *what) {
  // For rank 0, `dims.size() - 1` would wrap around and `dims.back()` would be undefined.
  NVTE_CHECK(dims.size() >= 1, what, " must have rank >= 1.");
  return std::vector<size_t>{static_cast<size_t>(product(dims, 0, dims.size() - 1)),
                             static_cast<size_t>(dims.back())};
}

// XLA FFI handler: quantizes `input_buf` to an FP8 output and, when `is_dbias`
// is set, also computes dbias via nvte_quantize_dbias (otherwise nvte_quantize).
//
//   input_buf           : input, viewed as 2-D (m, n) = (prod(leading dims), last dim).
//   scale_buf           : float32 scale; required for NVTE_DELAYED_TENSOR_SCALING.
//   output_buf          : row-wise data, shape (m, n); written for ROWWISE / ROWWISE_COLWISE.
//   output_trans_buf    : column-wise data, shape (n, m); written for COLWISE / ROWWISE_COLWISE.
//   scale_inv_buf       : row-wise scale-inverse; also used column-wise under delayed scaling.
//   trans_scale_inv_buf : column-wise scale-inverse for the other scaling modes.
//   amax_out_buf        : float32 amax; zeroed on `stream` and attached to the output
//                         (delayed scaling only).
//   dbias_buf           : dbias, shape (n,), in the input dtype.
//   workspace_buf       : scratch buffer passed to nvte_quantize_dbias.
//   scaling_mode_enum   : an NVTEScalingMode value.
//   quantize_axis_enum  : a QuantizeAxis value selecting row-wise and/or column-wise output.
//   is_dbias            : whether to call nvte_quantize_dbias instead of nvte_quantize.
//
// Fails an NVTE_CHECK if the output dtype is not FP8, a required delayed-scaling
// buffer is null, a shape has rank 0, or zeroing amax fails.
Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                            Result_Type output_buf, Result_Type output_trans_buf,
                            Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
                            Result_Type amax_out_buf, Result_Type dbias_buf,
                            Result_Type workspace_buf, int64_t scaling_mode_enum,
                            int64_t quantize_axis_enum, bool is_dbias) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization.");

  auto *input = input_buf.untyped_data();

  auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
  auto const quantize_axis = static_cast<QuantizeAxis>(quantize_axis_enum);

  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  // Flatten the input to 2-D: (m, n) = (prod(leading dims), last dim).
  auto input_dims = input_buf.dimensions();
  auto input_shape = CollapseToMatrixShape(input_dims, "Input");
  const size_t m = input_shape[0];
  const size_t n = input_shape[1];
  auto output_shape = input_shape;
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};

  auto workspace_dims = workspace_buf->dimensions();
  std::vector<size_t> workspace_shape{workspace_dims.begin(), workspace_dims.end()};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(scaling_mode);

  const bool want_rowwise =
      quantize_axis == QuantizeAxis::ROWWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE;
  const bool want_colwise =
      quantize_axis == QuantizeAxis::COLWISE || quantize_axis == QuantizeAxis::ROWWISE_COLWISE;

  if (want_rowwise) {
    output_tensor.set_rowwise_data(output, out_dtype, output_shape);
    output_tensor.set_rowwise_scale_inv(
        scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
        CollapseToMatrixShape(scale_inv_buf->dimensions(), "Row-wise scale_inv"));
  }

  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
    float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
    float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
    NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
    NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
    output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
    // Reset amax on `stream` so it is zero before the quantize call below is enqueued.
    const cudaError_t memset_status = cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
    NVTE_CHECK(memset_status == cudaSuccess,
               "Failed to zero amax: ", cudaGetErrorString(memset_status));
    output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
  }

  if (want_colwise) {
    output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
    // For 2x delayed scaling, the scale_inv buffer is shared between rowwise and
    // columnwise scaling; other modes use the dedicated column-wise buffer.
    auto &colwise_scale_inv_buf =
        (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
    output_tensor.set_columnwise_scale_inv(
        colwise_scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
        CollapseToMatrixShape(colwise_scale_inv_buf->dimensions(), "Column-wise scale_inv"));
  }

  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  if (is_dbias) {
    nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                        workspace_tensor.data(), stream);
  } else {
    nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
  }

  return ffi_with_cuda_error_check();
}

// Registers DBiasQuantizeFFI as an XLA FFI handler. The binding order mirrors
// the handler's parameter list: stream context, 2 input buffers, 7 result
// buffers, then the 3 attributes.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<int64_t>("scaling_mode")  // -> scaling_mode_enum
                                  .Attr<int64_t>("q_axis")        // -> quantize_axis_enum
                                  .Attr<bool>("is_dbias"),
                              FFI_CudaGraph_Traits);

// XLA FFI handler: dequantizes `input_buf` into `output_buf` via nvte_dequantize.
//
//   input_buf     : quantized input; its full N-D shape is used as-is.
//   amax_buf      : float32 amax attached to the input tensor.
//   scale_buf     : float32 scale attached to the input tensor.
//   scale_inv_buf : float32 scale-inverse attached to the input tensor.
//   output_buf    : result with the same shape as the input, in the output buffer's dtype.
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                         Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  auto *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());

  auto *output = output_buf->untyped_data();

  // Unlike quantization, no 2-D flattening: input and output share the input's N-D shape.
  auto input_dims = input_buf.dimensions();
  std::vector<size_t> shape(input_dims.begin(), input_dims.end());
  auto input_tensor = TensorWrapper(input, shape, in_dtype, amax, scale, scale_inv);
  auto output_tensor = TensorWrapper(output, shape, out_dtype);

  nvte_dequantize(input_tensor.data(), output_tensor.data(), stream);

  return ffi_with_cuda_error_check();
}

// XLA FFI registration for DequantizeFFI. Bound in order: the CUDA stream
// context, four input buffers (quantized data, amax, scale, scale_inv) and
// the single output buffer; registered with FFI_CudaGraph_Traits.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    DequantizeHandler, DequantizeFFI,
    FFI::Bind()
        .Ctx<FFI_Stream_Type>()  // -> stream
        .Arg<Buffer_Type>()      // -> input_buf (quantized data)
        .Arg<Buffer_Type>()      // -> amax_buf
        .Arg<Buffer_Type>()      // -> scale_buf
        .Arg<Buffer_Type>()      // -> scale_inv_buf
        .Ret<Buffer_Type>(),     // -> output_buf (dequantized data)
    FFI_CudaGraph_Traits);

}  // namespace jax
}  // namespace transformer_engine