cast.cu 5.57 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
9
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
Przemek Tredak's avatar
Przemek Tredak committed
10
#include <transformer_engine/cast.h>
11

12
13
14
15
#include <cfloat>
#include <limits>
#include <string>

Przemek Tredak's avatar
Przemek Tredak committed
16
#include "../common.h"
17
#include "../transpose/cast_transpose.h"
Przemek Tredak's avatar
Przemek Tredak committed
18
#include "../util/vectorized_pointwise.h"
19
#include "../utils.cuh"
20
21
22
23
24
25
26
27
28
29
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"

void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize);
  using namespace transformer_engine;
Przemek Tredak's avatar
Przemek Tredak committed
30

31
32
33
34
35
36
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr const NVTETensor grad = nullptr;
Przemek Tredak's avatar
Przemek Tredak committed
37

38
39
40
  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, nullptr, output,
                                                                     dbias, workspace, stream);
}
Przemek Tredak's avatar
Przemek Tredak committed
41

42
43
44
45
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_noop);
  using namespace transformer_engine;
Przemek Tredak's avatar
Przemek Tredak committed
46

47
48
49
50
51
52
  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr const NVTETensor grad = nullptr;
Przemek Tredak's avatar
Przemek Tredak committed
53

54
55
56
57
58
59
60
61
  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, noop, output,
                                                                     dbias, workspace, stream);
}

void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias);
  using namespace transformer_engine;
Przemek Tredak's avatar
Przemek Tredak committed
62

63
64
65
66
67
68
69
  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr const NVTETensor activation_input = nullptr;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
Przemek Tredak's avatar
Przemek Tredak committed
70
71
}

72
73
74
75
76
77
78
79
80
81
82
83
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
Przemek Tredak's avatar
Przemek Tredak committed
84
85
}

86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsilu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
}

void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_drelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
Przemek Tredak's avatar
Przemek Tredak committed
112
113
}

114
115
116
117
118
119
120
121
122
123
124
125
126
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
}
Przemek Tredak's avatar
Przemek Tredak committed
127

128
129
130
131
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
Przemek Tredak's avatar
Przemek Tredak committed
132
  using namespace transformer_engine;
133
134
135
136
137
138
139

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
Przemek Tredak's avatar
Przemek Tredak committed
140
141
}

142
143
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
Przemek Tredak's avatar
Przemek Tredak committed
144
  using namespace transformer_engine;
145
146
  detail::dequantize_helper(*reinterpret_cast<const Tensor *>(input),
                            reinterpret_cast<Tensor *>(output), stream);
Przemek Tredak's avatar
Przemek Tredak committed
147
}