cast.cu
/*************************************************************************
 * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/cast.h>
#include "../common.h"
#include "../utils.cuh"
#include "../util/vectorized_pointwise.h"

namespace transformer_engine {

namespace detail {

// No extra parameters are needed for plain quantization.
struct Empty {};

// Identity elementwise op: for FP8 quantization the per-element transform is a
// no-op; the vectorized kernel launcher applies the scale and the cast to FP8.
__device__ inline fp32 identity(fp32 value, const Empty&) {
  return value;
}

// Dequantization parameter: pointer to the single-element inverted scale.
struct DequantizeParam {
  const fp32 *scale_inv;
};

// Elementwise dequantize op: multiply each value by the inverted scale.
__device__ inline fp32 dequantize_func(fp32 value, const DequantizeParam &param) {
  return value * (*(param.scale_inv));
}

}  // namespace detail

// Quantize a higher-precision tensor to FP8. After the arguments are
// validated, a vectorized cast kernel is launched that applies the scale and
// updates amax and the inverted scale.
void fp8_quantize(const Tensor &input,
                  const Tensor &scale,
                  Tensor *output,
                  Tensor *amax,
                  Tensor *scale_inv,
                  cudaStream_t stream) {
    NVTE_CHECK(input.dtype != DType::kFloat8E4M3 &&
               input.dtype != DType::kFloat8E5M2,
               "Input must be in higher precision.");
    NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");

    NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
    NVTE_CHECK(output->dtype == DType::kFloat8E4M3 ||
               output->dtype == DType::kFloat8E5M2,
               "Output must have FP8 type.");
    NVTE_CHECK(output->shape == input.shape, "Input and output shapes need to match.");

    NVTE_CHECK(scale.dptr != nullptr, "Scale is not allocated.");
    NVTE_CHECK(scale.dtype == DType::kFloat32, "Scale must have FP32 type.");
    NVTE_CHECK(scale.shape == std::vector<size_t>{ 1 }, "Scale must have 1 element.");

    NVTE_CHECK(amax->dptr != nullptr, "AMAX is not allocated.");
    NVTE_CHECK(amax->dtype == DType::kFloat32, "AMAX must have FP32 type.");
    NVTE_CHECK(amax->shape == std::vector<size_t>{ 1 }, "AMAX must have 1 element.");

    NVTE_CHECK(scale_inv->dptr != nullptr, "Inverted scale is not allocated.");
    NVTE_CHECK(scale_inv->dtype == DType::kFloat32, "Inverted scale must have FP32 type.");
    NVTE_CHECK(scale_inv->shape == std::vector<size_t>{ 1 }, "Inverted scale must have 1 element.");

    const size_t N = product(input.shape);
    TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.dtype, IType,
        TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output->dtype, OType,
          constexpr int nvec = 32 / sizeof(IType);
          VectorizedUnaryKernelLauncher<nvec, detail::Empty, detail::identity>(
                    reinterpret_cast<const IType*>(input.dptr),
                    reinterpret_cast<OType*>(output->dptr),
                    reinterpret_cast<const fp32*>(scale.dptr),
                    reinterpret_cast<fp32*>(scale_inv->dptr),
                    reinterpret_cast<fp32*>(amax->dptr),
                    N,
                    {},
                    stream);
        );  // NOLINT(*)
    );  // NOLINT(*)
}

// Dequantize an FP8 tensor back to higher precision by multiplying each
// element by the stored inverted scale.
void fp8_dequantize(const Tensor &input,
                    const Tensor &scale_inv,
                    Tensor *output,
                    cudaStream_t stream) {
    NVTE_CHECK(input.dtype == DType::kFloat8E4M3 ||
               input.dtype == DType::kFloat8E5M2,
               "Input must have FP8 type.");
    NVTE_CHECK(input.dptr != nullptr, "Input is not allocated.");

    NVTE_CHECK(output->dptr != nullptr, "Output is not allocated.");
    NVTE_CHECK(output->dtype != DType::kFloat8E4M3 &&
               output->dtype != DType::kFloat8E5M2,
               "Output must be in higher precision.");
    NVTE_CHECK(output->shape == input.shape, "Input and output shapes need to match.");

    NVTE_CHECK(scale_inv.dptr != nullptr, "Inverted scale is not allocated.");
    NVTE_CHECK(scale_inv.dtype == DType::kFloat32, "Inverted scale must have FP32 type.");
    NVTE_CHECK(scale_inv.shape == std::vector<size_t>{ 1 }, "Inverted scale must have 1 element.");

    const size_t N = product(input.shape);
    TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(input.dtype, IType,
        TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(output->dtype, OType,
          constexpr int nvec = 32 / sizeof(OType);
          detail::DequantizeParam p;
          p.scale_inv = reinterpret_cast<const fp32*>(scale_inv.dptr);
          VectorizedUnaryKernelLauncher<nvec, detail::DequantizeParam, detail::dequantize_func>(
                    reinterpret_cast<const IType*>(input.dptr),
                    reinterpret_cast<OType*>(output->dptr),
                    nullptr,
                    nullptr,
                    nullptr,
                    N,
                    p,
                    stream);
        );  // NOLINT(*)
    );  // NOLINT(*)
}

}  // namespace transformer_engine

// C API entry point: unwrap the opaque NVTETensor handles into internal
// Tensor structs and dispatch to fp8_quantize.
void nvte_fp8_quantize(const NVTETensor input,
                       const NVTETensor scale,
                       NVTETensor output,
                       NVTETensor amax,
                       NVTETensor scale_inv,
                       cudaStream_t stream) {
  using namespace transformer_engine;
  fp8_quantize(*reinterpret_cast<const Tensor*>(input),
               *reinterpret_cast<const Tensor*>(scale),
               reinterpret_cast<Tensor*>(output),
               reinterpret_cast<Tensor*>(amax),
               reinterpret_cast<Tensor*>(scale_inv),
               stream);
}

// C API entry point for dequantization; mirrors nvte_fp8_quantize.
void nvte_fp8_dequantize(const NVTETensor input,
                         const NVTETensor scale_inv,
                         NVTETensor output,
                         cudaStream_t stream) {
  using namespace transformer_engine;
  fp8_dequantize(*reinterpret_cast<const Tensor*>(input),
                 *reinterpret_cast<const Tensor*>(scale_inv),
                 reinterpret_cast<Tensor*>(output),
                 stream);
}
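
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the library). It shows how the
// internal fp8_quantize() above could be driven. Only the Tensor fields used
// in this file (dptr, dtype, shape) are assumed; the allocation and setup
// below are hypothetical.
//
//   using namespace transformer_engine;
//
//   const size_t rows = 1024, cols = 1024;
//   void *in_d, *out_d, *scale_d, *amax_d, *scale_inv_d;
//   cudaMalloc(&in_d,        rows * cols * sizeof(float));  // FP32 input
//   cudaMalloc(&out_d,       rows * cols);                  // 1 byte per FP8 value
//   cudaMalloc(&scale_d,     sizeof(float));                // single-element scale
//   cudaMalloc(&amax_d,      sizeof(float));                // single-element amax
//   cudaMalloc(&scale_inv_d, sizeof(float));                // single-element 1/scale
//   // scale_d must hold the caller-chosen scaling factor before the call.
//
//   Tensor input, scale, output, amax, scale_inv;
//   input.dptr  = in_d;         input.dtype  = DType::kFloat32;
//   input.shape = {rows, cols};
//   output.dptr = out_d;        output.dtype = DType::kFloat8E4M3;
//   output.shape = {rows, cols};
//   scale.dptr     = scale_d;     scale.dtype     = DType::kFloat32;  scale.shape     = {1};
//   amax.dptr      = amax_d;      amax.dtype      = DType::kFloat32;  amax.shape      = {1};
//   scale_inv.dptr = scale_inv_d; scale_inv.dtype = DType::kFloat32;  scale_inv.shape = {1};
//
//   fp8_quantize(input, scale, &output, &amax, &scale_inv, /*stream=*/0);
//
//   // Dequantizing back to FP32 reuses the inverted scale written above:
//   //   fp8_dequantize(output, scale_inv, &input, /*stream=*/0);
// ---------------------------------------------------------------------------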