/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/cast.h>
#include "../common.h"
#include "../utils.cuh"
#include "../util/vectorized_pointwise.h"

namespace transformer_engine {

namespace detail {

/*! \brief Parameter type for element-wise ops that need no extra state. */
struct Empty {};

/*! \brief Element-wise op that returns its input unchanged.
 *
 *  Used as the OP of the quantize launcher below, so the op itself applies
 *  no transformation to the value.
 */
__device__ inline fp32 identity(fp32 value, const Empty&) {
  return value;
}

/*! \brief Parameters for dequantize_func. */
struct DequantizeParam {
  // Pointer dereferenced on the device; only the first element is read.
  const fp32 *scale_inv;
};

/*! \brief Element-wise dequantization op: returns value * (*param.scale_inv). */
__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param) {
  return value * (*(param.scale_inv));
}

}  // namespace detail

/*! \brief Cast a non-FP8 tensor to an FP8 tensor of the same shape.
 *
 *  \param[in]     input   Tensor to cast; must not have an FP8 dtype.
 *  \param[in,out] output  FP8 tensor with the same shape as input. Its data buffer
 *                         receives the result. Its scale and amax buffers are
 *                         handed to the kernel launcher.
 *  \param[in]     stream  CUDA stream the work is enqueued on.
 */
void fp8_quantize(const Tensor &input,
                  Tensor *output,
                  cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");

  NVTE_CHECK(!is_fp8_dtype(input.data.dtype),
             "Input must be in higher precision.");

  NVTE_CHECK(is_fp8_dtype(output->data.dtype),
             "Output must have FP8 type.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  const size_t N = product(input.data.shape);
  // The type-switch macros are assumed to bind IType/OType to the C++ element
  // types for the runtime dtypes (NOTE(review): confirm in ../common.h).
  // nvec is chosen so that one vector of the wider (input) type spans 32 bytes.
  // The op is the identity. Any scaling or amax tracking is presumably done by
  // the launcher through the scale/amax pointers passed here
  // (TODO confirm in ../util/vectorized_pointwise.h).
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(IType);
      VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
          reinterpret_cast<const IType*>(input.data.dptr),
          reinterpret_cast<OType*>(output->data.dptr),
          reinterpret_cast<const fp32*>(output->scale.dptr),
          reinterpret_cast<fp32*>(output->amax.dptr),
          N,
          {},
          stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

/*! \brief Cast an FP8 tensor back to a non-FP8 tensor of the same shape.
 *
 *  Each element is multiplied by *input.scale_inv (see detail::dequantize_func).
 *
 *  \param[in]     input   FP8 tensor to dequantize. Its scale_inv buffer supplies
 *                         the multiplier.
 *  \param[in,out] output  Non-FP8 tensor with the same shape as input. Its data
 *                         buffer receives the result.
 *  \param[in]     stream  CUDA stream the work is enqueued on.
 */
void fp8_dequantize(const Tensor &input,
                    Tensor *output,
                    cudaStream_t stream) {
  CheckInputTensor(input, "cast_input");
  CheckOutputTensor(*output, "cast_output");
  NVTE_CHECK(is_fp8_dtype(input.data.dtype),
             "Input must have FP8 type.");

  NVTE_CHECK(!is_fp8_dtype(output->data.dtype),
             "Output must be in higher precision.");
  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

  const size_t N = product(input.data.shape);
  // nvec is chosen so that one vector of the wider (output) type spans 32 bytes.
  // The scale and amax pointers are null here; presumably this tells the
  // launcher to skip its own scaling and amax update (TODO confirm).
  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.data.dtype, IType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->data.dtype, OType,
      constexpr int nvec = 32 / sizeof(OType);
      detail::DequantizeParam p;
      p.scale_inv = reinterpret_cast<const fp32*>(input.scale_inv.dptr);
      VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
          reinterpret_cast<const IType*>(input.data.dptr),
          reinterpret_cast<OType*>(output->data.dptr),
          nullptr,
          nullptr,
          N,
          p,
          stream);
    );  // NOLINT(*)
  );  // NOLINT(*)
}

}  // namespace transformer_engine

/*! \brief C API entry point for casting a non-FP8 tensor to FP8.
 *
 *  Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor
 *  and forwards to transformer_engine::fp8_quantize.
 *
 *  \param[in]     input   Handle of the tensor to cast.
 *  \param[in,out] output  Handle of the FP8 output tensor.
 *  \param[in]     stream  CUDA stream the work is enqueued on.
 */
void nvte_fp8_quantize(const NVTETensor input,
                       NVTETensor output,
                       cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_quantize);
  using namespace transformer_engine;
  fp8_quantize(*reinterpret_cast<const Tensor*>(input),
               reinterpret_cast<Tensor*>(output),
               stream);
}

/*! \brief C API entry point for casting an FP8 tensor back to higher precision.
 *
 *  Reinterprets the opaque NVTETensor handles as transformer_engine::Tensor
 *  and forwards to transformer_engine::fp8_dequantize.
 *
 *  \param[in]     input   Handle of the FP8 tensor to dequantize.
 *  \param[in,out] output  Handle of the non-FP8 output tensor.
 *  \param[in]     stream  CUDA stream the work is enqueued on.
 */
void nvte_fp8_dequantize(const NVTETensor input,
                         NVTETensor output,
                         cudaStream_t stream) {
  NVTE_API_CALL(nvte_fp8_dequantize);
  using namespace transformer_engine;
  fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
                 reinterpret_cast<Tensor*>(output),
                 stream);
}