/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/cast.h>
#include "../common.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"

namespace transformer_engine {

namespace detail {

// Parameter type for element-wise ops that take no extra arguments
// (used with `identity` in fp8_quantize).
struct Empty {};

// Element-wise op for fp8_quantize: returns its fp32 argument unchanged.
// The scaling and FP8 conversion are not done here; they are left to
// VectorizedUnaryKernelLauncher (NOTE(review): confirm in
// ../util/vectorized_pointwise.h).
__device__ inline fp32 identity(fp32 value, const Empty &) { return value; }

// Parameters for dequantize_func.
struct DequantizeParam {
  // Pointer to the fp32 inverse scale. fp8_dequantize sets it from the FP8
  // input tensor's scale_inv buffer. Only the first element is read.
  const fp32 *scale_inv;
};

// Element-wise op for fp8_dequantize: multiplies the input element, already
// presented as fp32, by the inverse scale stored at *param.scale_inv.
__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param) {
  return value * (*(param.scale_inv));
}

}  // namespace detail

// Casts a higher-precision tensor to FP8.
//
// input:  tensor whose dtype must NOT be an FP8 type.
// output: tensor whose dtype MUST be an FP8 type, with the same shape as
//         input. Its `scale` buffer is passed to the launcher as a read-only
//         fp32 pointer. Its `amax` and `scale_inv` buffers are passed as
//         writable fp32 pointers (NOTE(review): presumably the launcher
//         records the absolute maximum and the inverse scale there; confirm
//         in ../util/vectorized_pointwise.h).
// stream: CUDA stream that is handed to the kernel launcher.
void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");

  NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision.");
  NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  // Total number of elements to convert.
  const size_t N = product(input.data.shape);

  // Dispatch on the (input, output) dtype pair, then run the identity op
  // element-wise through the vectorized launcher.
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output->data.dtype, OType,
          // Vector width chosen so that nvec * sizeof(IType) == 32 bytes.
          constexpr int nvec = 32 / sizeof(IType);
          VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr),
              reinterpret_cast<const fp32 *>(output->scale.dptr),
              reinterpret_cast<fp32 *>(output->amax.dptr),
              reinterpret_cast<fp32 *>(output->scale_inv.dptr), N, {},
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
}

// Converts an FP8 tensor back to a higher-precision dtype.
//
// input:  tensor whose dtype MUST be an FP8 type. Its `scale_inv` buffer is
//         read by detail::dequantize_func, which multiplies every element by
//         the first fp32 value stored there.
// output: tensor whose dtype must NOT be an FP8 type, with the same shape as
//         input. The launcher receives nullptr for its scale/amax/scale_inv
//         slots, so none of the output's auxiliary buffers are passed on.
// stream: CUDA stream that is handed to the kernel launcher.
void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");

  NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");
  NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  // Total number of elements to convert.
  const size_t N = product(input.data.shape);

  // Dispatch on the (input, output) dtype pair, then apply dequantize_func
  // element-wise through the vectorized launcher.
  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
          output->data.dtype, OType,
          // Vector width chosen so that nvec * sizeof(OType) == 32 bytes.
          constexpr int nvec = 32 / sizeof(OType);
          detail::DequantizeParam p;
          p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
          VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, nullptr, N, p,
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
}

}  // namespace transformer_engine

// C API entry point for FP8 quantization. It reinterprets the opaque
// NVTETensor handles as transformer_engine::Tensor objects and forwards to
// transformer_engine::fp8_quantize.
void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_quantize);
  using namespace transformer_engine;
  fp8_quantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
               stream);
}

// C API entry point for FP8 dequantization. It reinterprets the opaque
// NVTETensor handles as transformer_engine::Tensor objects and forwards to
// transformer_engine::fp8_dequantize.
void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_dequantize);
  using namespace transformer_engine;
  fp8_dequantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
                 stream);
}