quantization.cpp 5.33 KB
Newer Older
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include "extensions.h"
8
#include "transformer_engine/cast.h"
9
#include "xla/ffi/api/c_api.h"
10
11
12
13
14

namespace transformer_engine {
namespace jax {

// Legacy (opaque-descriptor) entry point for the TE/JAX Quantize primitive.
//
// `buffers` is read at indices 0..5 as: input, amax, scale, scale_inv,
// output, amax_out. The shape and the input/output dtypes come from the
// CustomCallCommonDescriptor unpacked from `opaque`. amax_out, scale and
// scale_inv are attached to the output tensor, and the conversion is handed
// to nvte_quantize on `stream`.
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *const src_ptr = buffers[0];
  void *const dst_ptr = buffers[4];
  float *const scale_ptr = static_cast<float *>(buffers[2]);
  float *const scale_inv_ptr = static_cast<float *>(buffers[3]);

  // These two locals keep their names: they appear verbatim in the check
  // expression below, which the check macro may include in its message.
  float *const amax = static_cast<float *>(buffers[1]);
  float *const amax_out = static_cast<float *>(buffers[5]);
  // The amax operand and the amax result must be the very same buffer.
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive.");

  const auto &descriptor = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto dims = descriptor.shape.to_vector();

  // Input and output share one shape; only the output carries scaling data.
  TensorWrapper src(src_ptr, dims, descriptor.in_dtype);
  TensorWrapper dst(dst_ptr, dims, descriptor.out_dtype, amax_out, scale_ptr, scale_inv_ptr);

  nvte_quantize(src.data(), dst.data(), stream);
}

31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
// XLA FFI entry point for the TE/JAX Quantize primitive.
//
// Element types and dimensions are taken from the FFI buffers themselves:
// the input buffer's dimensions are used as the shape of both the input and
// the output tensor. amax_out, scale and scale_inv are attached to the output
// tensor before nvte_quantize is invoked on `stream`.
//
// Returns whatever ffi_with_cuda_error_check() reports after the call.
Error_Type QuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf,
                       Result_Type amax_out_buf) {
  const auto src_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  const auto dst_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  // Raw pointers for the operands ...
  auto *src_ptr = input_buf.untyped_data();
  auto *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *scale_ptr = reinterpret_cast<float *>(scale_buf.untyped_data());
  auto *scale_inv_ptr = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  // ... and for the results.
  auto *dst_ptr = output_buf->untyped_data();
  auto *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());

  // The amax operand and the amax result must be the very same buffer.
  // (`amax` / `amax_out` keep their names because they appear verbatim in the
  // check expression, which the check macro may include in its message.)
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive.");

  // Copy the FFI dimension span into the size_t vector TensorWrapper is given.
  const auto extents = input_buf.dimensions();
  std::vector<size_t> dims(extents.begin(), extents.end());

  TensorWrapper src(src_ptr, dims, src_dtype);
  TensorWrapper dst(dst_ptr, dims, dst_dtype, amax_out, scale_ptr, scale_inv_ptr);

  nvte_quantize(src.data(), dst.data(), stream);

  return ffi_with_cuda_error_check();
}

// Defines the FFI handler symbol `QuantizeHandler`, bound to QuantizeFFI.
// The binding order below must match QuantizeFFI's parameter list: one stream
// context, four buffer arguments, then two buffer results.
// NOTE(review): FFI_CudaGraph_Traits presumably marks the handler as usable
// under CUDA graph capture — confirm against its definition.
XLA_FFI_DEFINE_HANDLER_SYMBOL(QuantizeHandler, QuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>(),     // amax_out
                              FFI_CudaGraph_Traits);

66
// Legacy (opaque-descriptor) entry point for the TE/JAX Dequantize primitive.
//
// `buffers` is read at indices 0..4 as: input, amax, scale, scale_inv,
// output. The shape and the input/output dtypes come from the
// CustomCallCommonDescriptor unpacked from `opaque`. Here the scaling data
// (amax, scale, scale_inv) is attached to the *input* tensor; the output
// tensor is built from pointer, shape and dtype only. The work is handed to
// nvte_dequantize on `stream`.
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *const src_ptr = buffers[0];
  float *const amax_ptr = static_cast<float *>(buffers[1]);
  float *const scale_ptr = static_cast<float *>(buffers[2]);
  float *const scale_inv_ptr = static_cast<float *>(buffers[3]);
  void *const dst_ptr = buffers[4];

  const auto &descriptor = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto dims = descriptor.shape.to_vector();

  // Input and output share one shape.
  TensorWrapper src(src_ptr, dims, descriptor.in_dtype, amax_ptr, scale_ptr, scale_inv_ptr);
  TensorWrapper dst(dst_ptr, dims, descriptor.out_dtype);

  nvte_dequantize(src.data(), dst.data(), stream);
}

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// XLA FFI entry point for the TE/JAX Dequantize primitive.
//
// Element types and dimensions are taken from the FFI buffers themselves:
// the input buffer's dimensions are used as the shape of both tensors. The
// scaling data (amax, scale, scale_inv) is attached to the input tensor, and
// nvte_dequantize is invoked on `stream`.
//
// Returns whatever ffi_with_cuda_error_check() reports after the call.
Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                         Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf) {
  const auto src_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  const auto dst_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  // Raw pointers for the operands ...
  auto *src_ptr = input_buf.untyped_data();
  auto *amax_ptr = reinterpret_cast<float *>(amax_buf.untyped_data());
  auto *scale_ptr = reinterpret_cast<float *>(scale_buf.untyped_data());
  auto *scale_inv_ptr = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  // ... and for the single result.
  auto *dst_ptr = output_buf->untyped_data();

  // Copy the FFI dimension span into the size_t vector TensorWrapper is given.
  const auto extents = input_buf.dimensions();
  std::vector<size_t> dims(extents.begin(), extents.end());

  TensorWrapper src(src_ptr, dims, src_dtype, amax_ptr, scale_ptr, scale_inv_ptr);
  TensorWrapper dst(dst_ptr, dims, dst_dtype);

  nvte_dequantize(src.data(), dst.data(), stream);

  return ffi_with_cuda_error_check();
}

// Defines the FFI handler symbol `DequantizeHandler`, bound to DequantizeFFI.
// The binding order below must match DequantizeFFI's parameter list: one
// stream context, four buffer arguments, then one buffer result.
// NOTE(review): FFI_CudaGraph_Traits presumably marks the handler as usable
// under CUDA graph capture — confirm against its definition.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DequantizeHandler, DequantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>(),     // output
                              FFI_CudaGraph_Traits);

113
114
}  // namespace jax
}  // namespace transformer_engine