/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/cast.h>

#include "../common.h"
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"

namespace transformer_engine {

namespace detail {

// Parameter type for element-wise ops that take no extra arguments.
struct Empty {};

// Element-wise op used by fp8_quantize: returns the fp32 input value unchanged.
// Any scaling / amax bookkeeping is presumably handled by the launcher that
// invokes this op (see ../util/vectorized_pointwise.h) — not by this function.
__device__ inline fp32 identity(fp32 value, const Empty &) { return value; }

// Parameters for dequantize_func.
struct DequantizeParam {
  // Pointer to a single fp32 inverse-scale factor. It is dereferenced from
  // device code (dequantize_func), so it must point to device-accessible memory.
  const fp32 *scale_inv;
};

// Element-wise op used by fp8_dequantize: multiplies the input element
// (received as fp32) by the inverse scale *param.scale_inv.
__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param) {
  return value * (*(param.scale_inv));
}

}  // namespace detail

// Cast a higher-precision tensor into an FP8 tensor.
//
// input:  tensor whose dtype must NOT be an FP8 type.
// output: tensor whose dtype must be an FP8 type and whose shape must equal
//         input's shape. Its `scale` and `amax` buffers are forwarded to the
//         launcher (presumably to scale the values and record the absolute
//         maximum — see ../util/vectorized_pointwise.h to confirm).
// stream: CUDA stream forwarded to the launcher.
//
// Errors: CheckInputTensor / CheckOutputTensor / NVTE_CHECK report invalid
// tensors, dtype mismatches, or shape mismatches.
void fp8_quantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");

  NVTE_CHECK(!is_fp8_dtype(input.data.dtype), "Input must be in higher precision.");

  NVTE_CHECK(is_fp8_dtype(output->data.dtype), "Output must have FP8 type.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  // Total number of elements processed (shapes are equal, checked above).
  const size_t N = product(input.data.shape);

  // Dispatch IType over the (non-FP8) input dtype and OType over the FP8 output
  // dtype. nvec is chosen so that one vector of input elements spans 32 bytes.
  // The op is identity: the launcher receives the scale/amax pointers directly.
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
          VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr),
              reinterpret_cast<const fp32 *>(output->scale.dptr),
              reinterpret_cast<fp32 *>(output->amax.dptr), N, {},
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
}

// Cast an FP8 tensor back to a higher-precision tensor.
//
// input:  tensor whose dtype must be an FP8 type. Its `scale_inv` buffer is
//         read (one fp32 value) by detail::dequantize_func for every element.
// output: tensor whose dtype must NOT be an FP8 type and whose shape must
//         equal input's shape.
// stream: CUDA stream forwarded to the launcher.
//
// Errors: CheckInputTensor / CheckOutputTensor / NVTE_CHECK report invalid
// tensors, dtype mismatches, or shape mismatches.
void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");
  NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type.");

  NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  // Total number of elements processed (shapes are equal, checked above).
  const size_t N = product(input.data.shape);

  // Dispatch IType over the FP8 input dtype and OType over the higher-precision
  // output dtype. nvec is chosen so that one vector of output elements spans
  // 32 bytes. The scale and amax arguments are nullptr because the inverse
  // scale is applied by the op itself (presumably the launcher then skips its
  // own scaling/amax work — confirm in ../util/vectorized_pointwise.h).
  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
      input.data.dtype, IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
          output->data.dtype, OType, constexpr int nvec = 32 / sizeof(OType);
          detail::DequantizeParam p;
          p.scale_inv = reinterpret_cast<const fp32 *>(input.scale_inv.dptr);
          VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
              reinterpret_cast<const IType *>(input.data.dptr),
              reinterpret_cast<OType *>(output->data.dptr), nullptr, nullptr, N, p,
              stream););  // NOLINT(*)
  );                      // NOLINT(*)
}

}  // namespace transformer_engine

78
void nvte_fp8_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
79
  NVTE_API_CALL(nvte_fp8_quantize);
Przemek Tredak's avatar
Przemek Tredak committed
80
  using namespace transformer_engine;
81
  fp8_quantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
Przemek Tredak's avatar
Przemek Tredak committed
82
83
84
               stream);
}

// C API entry point for FP8 dequantization.
//
// Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor and
// forwards to transformer_engine::fp8_dequantize, which validates the tensors
// and launches the cast on `stream`.
void nvte_fp8_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_dequantize);
  using namespace transformer_engine;
  fp8_dequantize(*reinterpret_cast<const Tensor *>(input), reinterpret_cast<Tensor *>(output),
                 stream);
}