/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include <cuda.h>
yuguo's avatar
yuguo committed
8
#ifndef __HIP_PLATFORM_AMD__
9
#include <cudaTypedefs.h>
yuguo's avatar
yuguo committed
10
#endif
11
#include <cuda_runtime.h>
Przemek Tredak's avatar
Przemek Tredak committed
12
#include <transformer_engine/cast.h>
13

14
15
16
17
#include <cfloat>
#include <limits>
#include <string>

Przemek Tredak's avatar
Przemek Tredak committed
18
#include "../common.h"
19
#include "../transpose/cast_transpose.h"
Przemek Tredak's avatar
Przemek Tredak committed
20
#include "../util/vectorized_pointwise.h"
21
#include "../utils.cuh"
22
23
24
25
26
27
28
29
30
31
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"

// Quantize `input` into `output` with every fusion disabled.
//
// All three fusion flags handed to detail::quantize_helper are false (no bias
// reduction, no activation derivative, no activation). The grad, dbias,
// workspace and quantization-config arguments are all null, so only `input`,
// `output` and `stream` carry information. NOTE(review): the quantization
// scheme itself is presumably chosen by quantize_helper from `output`; that
// logic is not visible in this file.
//
// input  - tensor to quantize.
// output - destination tensor.
// stream - CUDA stream forwarded to quantize_helper.
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr NVTETensor grad = nullptr;

  // No activation functor is needed because IS_DACT/IS_ACT are false, hence `nullptr`.
  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, output, dbias,
                                                                     workspace, nullptr, stream);
}
Przemek Tredak's avatar
Przemek Tredak committed
43

44
45
46
47
// Quantize `input` into `output`, passing `noop` through as the config's noop tensor.
//
// Wraps `noop` in a locally constructed QuantizationConfig and delegates to
// nvte_quantize_v2. NOTE(review): what the noop tensor does at runtime (e.g.
// acting as a flag that lets the kernel skip work) is decided downstream in
// quantize_helper, not here. Confirm there.
//
// input  - tensor to quantize.
// output - destination tensor.
// noop   - tensor stored into QuantizationConfig::noop_tensor.
// stream - CUDA stream forwarded to nvte_quantize_v2.
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_noop);
  using namespace transformer_engine;

  // Default-construct the config and set only the noop tensor. The object lives
  // on this function's stack, so its address is valid only for the duration of
  // the nvte_quantize_v2 call below.
  QuantizationConfig quant_config;
  quant_config.noop_tensor = noop;

  nvte_quantize_v2(input, output, reinterpret_cast<NVTEQuantizationConfig>(&quant_config), stream);
}

// Quantize `input` into `output` using a caller-supplied quantization config.
//
// Identical to nvte_quantize except that `quant_config` is forwarded to
// detail::quantize_helper instead of nullptr. All fusion flags are false, and
// grad/dbias/workspace are null.
//
// input        - tensor to quantize.
// output       - destination tensor.
// quant_config - opaque config handle forwarded unchanged (may be whatever the
//                caller passes, including the address of a QuantizationConfig,
//                as nvte_quantize_noop does).
// stream       - CUDA stream forwarded to quantize_helper.
void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
                      const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_v2);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  constexpr NVTETensor grad = nullptr;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
      input, grad, output, dbias, workspace, quant_config, stream);
}

// Quantize `input` into `output` with a fused bias-gradient reduction into `dbias`.
//
// IS_DBIAS is true and no activation is applied. Note the argument order:
// quantize_helper receives a null first argument (the activation input), and
// `input` goes in the second slot, which is the same slot nvte_quantize fills
// with its null `grad`.
//
// input     - tensor to quantize (passed in quantize_helper's grad slot).
// output    - destination tensor.
// dbias     - destination for the bias reduction.
// workspace - scratch tensor forwarded to quantize_helper. NOTE(review): its
//             required size/handling is defined by quantize_helper.
// stream    - CUDA stream forwarded to quantize_helper.
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor activation_input = nullptr;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

86
87
88
89
90
91
92
93
94
95
96
// Quantize with a fused bias reduction and the dgelu activation derivative.
//
// IS_DBIAS and IS_DACT are true, and the functor passed to quantize_helper is
// dgelu<fp32, fp32>. `activation_input` is passed first and `input` second
// (the grad slot, as in nvte_quantize).
//
// input            - tensor passed in quantize_helper's grad slot.
// activation_input - tensor passed as quantize_helper's first argument.
// output           - destination tensor.
// dbias            - destination for the bias reduction.
// workspace        - scratch tensor forwarded to quantize_helper.
// stream           - CUDA stream forwarded to quantize_helper.
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

100
101
102
103
104
105
106
107
108
109
110
// Quantize with a fused bias reduction and the dsilu activation derivative.
//
// IS_DBIAS and IS_DACT are true, and the functor passed to quantize_helper is
// dsilu<fp32, fp32>. `activation_input` is passed first and `input` second
// (the grad slot, as in nvte_quantize).
//
// input            - tensor passed in quantize_helper's grad slot.
// activation_input - tensor passed as quantize_helper's first argument.
// output           - destination tensor.
// dbias            - destination for the bias reduction.
// workspace        - scratch tensor forwarded to quantize_helper.
// stream           - CUDA stream forwarded to quantize_helper.
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsilu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsilu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

// Quantize with a fused bias reduction and the drelu activation derivative.
//
// IS_DBIAS and IS_DACT are true, and the functor passed to quantize_helper is
// drelu<fp32, fp32>. `activation_input` is passed first and `input` second
// (the grad slot, as in nvte_quantize).
//
// input            - tensor passed in quantize_helper's grad slot.
// activation_input - tensor passed as quantize_helper's first argument.
// output           - destination tensor.
// dbias            - destination for the bias reduction.
// workspace        - scratch tensor forwarded to quantize_helper.
// stream           - CUDA stream forwarded to quantize_helper.
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_drelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

128
129
130
131
132
133
134
135
136
137
138
// Quantize with a fused bias reduction and the dqgelu activation derivative.
//
// IS_DBIAS and IS_DACT are true, and the functor passed to quantize_helper is
// dqgelu<fp32, fp32>. `activation_input` is passed first and `input` second
// (the grad slot, as in nvte_quantize).
//
// input            - tensor passed in quantize_helper's grad slot.
// activation_input - tensor passed as quantize_helper's first argument.
// output           - destination tensor.
// dbias            - destination for the bias reduction.
// workspace        - scratch tensor forwarded to quantize_helper.
// stream           - CUDA stream forwarded to quantize_helper.
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dqgelu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}
Przemek Tredak's avatar
Przemek Tredak committed
141

142
143
144
145
// Quantize with a fused bias reduction and the dsrelu activation derivative.
//
// IS_DBIAS and IS_DACT are true, and the functor passed to quantize_helper is
// dsrelu<fp32, fp32>. `activation_input` is passed first and `input` second
// (the grad slot, as in nvte_quantize).
//
// input            - tensor passed in quantize_helper's grad slot.
// activation_input - tensor passed as quantize_helper's first argument.
// output           - destination tensor.
// dbias            - destination for the bias reduction.
// workspace        - scratch tensor forwarded to quantize_helper.
// stream           - CUDA stream forwarded to quantize_helper.
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
      activation_input, input, output, dbias, workspace, nullptr, stream);
}

156
157
// Dequantize `input` into `output` on `stream`.
//
// Both opaque handles are reinterpreted as transformer_engine::Tensor pointers
// and handed to detail::dequantize_helper. `input` is passed by const reference
// and `output` by pointer.
//
// NOTE(review): `input` is dereferenced without a null check, so a null handle
// is undefined behavior here. Consider validating it before the cast.
//
// input  - quantized source tensor handle.
// output - destination tensor handle.
// stream - CUDA stream forwarded to dequantize_helper.
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
  using namespace transformer_engine;

  detail::dequantize_helper(*reinterpret_cast<const Tensor *>(input),
                            reinterpret_cast<Tensor *>(output), stream);
}