// quantization.cpp
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "jax/csrc/extensions.h"
#include "transformer_engine/cast.h"

namespace transformer_engine {
namespace jax {

// Quantize an input buffer by calling nvte_fp8_quantize, passing `stream` through.
//
// Buffer layout (indices into `buffers`):
//   [0] input      - data of dtype desc.in_dtype, shape desc.shape
//   [1] amax       - float*; must be the very same pointer as [5]
//   [2] scale      - float*, attached to the output tensor
//   [3] scale_inv  - float*, attached to the output tensor
//   [4] output     - data of dtype desc.out_dtype, shape desc.shape
//   [5] amax_out   - float*, attached to the output tensor
//
// `opaque`/`opaque_len` carry a serialized CustomCallCommonDescriptor that
// supplies the shape shared by input and output plus their two dtypes.
// Fails via NVTE_CHECK if buffers[1] and buffers[5] are not the same pointer.
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *amax = reinterpret_cast<float *>(buffers[1]);
  auto *scale = reinterpret_cast<float *>(buffers[2]);
  auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *output = buffers[4];
  auto *amax_out = reinterpret_cast<float *>(buffers[5]);

  // Only amax_out is handed to the output tensor below, so the amax input is
  // required to alias it. NOTE(review): presumably the caller binds these as an
  // input/output alias so the amax updated through amax_out is the one the
  // caller sees — confirm against the Python-side primitive definition.
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX Quantize primitive.");

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto shape = desc.shape.to_vector();

  // Input and output share one shape; they differ only in dtype, and the
  // scaling metadata (amax/scale/scale_inv) is attached to the output side.
  auto input_tensor = TensorWrapper(input, shape, desc.in_dtype);
  auto output_tensor = TensorWrapper(output, shape, desc.out_dtype, amax_out, scale, scale_inv);

  nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
}

// Dequantize an input buffer by calling nvte_fp8_dequantize, passing `stream` through.
//
// Buffer layout (indices into `buffers`):
//   [0] input      - data of dtype desc.in_dtype, shape desc.shape
//   [1] amax       - float*, attached to the input tensor
//   [2] scale      - float*, attached to the input tensor
//   [3] scale_inv  - float*, attached to the input tensor
//   [4] output     - data of dtype desc.out_dtype, shape desc.shape
//
// `opaque`/`opaque_len` carry a serialized CustomCallCommonDescriptor that
// supplies the shape shared by input and output plus their two dtypes.
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *amax = reinterpret_cast<float *>(buffers[1]);
  auto *scale = reinterpret_cast<float *>(buffers[2]);
  auto *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *output = buffers[4];

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto shape = desc.shape.to_vector();

  // The scaling metadata travels with the input tensor (presumably the
  // quantized side — confirm); the output tensor carries only data, shape
  // and dtype.
  auto input_tensor = TensorWrapper(input, shape, desc.in_dtype, amax, scale, scale_inv);
  auto output_tensor = TensorWrapper(output, shape, desc.out_dtype);

  nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
}

}  // namespace jax
}  // namespace transformer_engine