cast.cu 5.61 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
#include <cuda.h>
yuguo's avatar
yuguo committed
8
#ifndef __HIP_PLATFORM_AMD__
9
#include <cudaTypedefs.h>
yuguo's avatar
yuguo committed
10
#endif
11
#include <cuda_runtime.h>
Przemek Tredak's avatar
Przemek Tredak committed
12
#include <transformer_engine/cast.h>
13

14
15
16
17
#include <cfloat>
#include <limits>
#include <string>

Przemek Tredak's avatar
Przemek Tredak committed
18
#include "../common.h"
19
#include "../transpose/cast_transpose.h"
Przemek Tredak's avatar
Przemek Tredak committed
20
#include "../util/vectorized_pointwise.h"
21
#include "../utils.cuh"
22
23
24
25
26
27
28
29
30
31
#include "cast_kernels.cuh"
#include "dequantize_kernels.cuh"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transpose.h"

// Quantizes `input` into `output` with every fusion disabled: no dbias, no
// activation derivative, no forward activation (IS_DBIAS / IS_DACT / IS_ACT are
// all false and no activation functor is supplied, OP = nullptr).
//
// `input` goes into detail::quantize_helper's first tensor slot. The gradient,
// noop, dbias and workspace slots are all null, and `stream` is forwarded as-is.
// NOTE(review): the target format is presumably taken from `output`; confirm in
// detail::quantize_helper.
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  // `constexpr` already makes the object const; plain cast has no gradient input.
  constexpr NVTETensor grad = nullptr;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, nullptr, output,
                                                                     dbias, workspace, stream);
}
Przemek Tredak's avatar
Przemek Tredak committed
43

44
45
46
47
// Same as nvte_quantize (no dbias, no activation derivative, no forward
// activation), but `noop` is forwarded to detail::quantize_helper's noop slot
// instead of nullptr.
//
// NOTE(review): `noop` is presumably a flag tensor that lets the quantization be
// skipped at run time; confirm the exact semantics in detail::quantize_helper.
void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop,
                        cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_noop);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = false;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  constexpr NVTETensor dbias = nullptr;
  constexpr NVTETensor workspace = nullptr;
  // `constexpr` already makes the object const; plain cast has no gradient input.
  constexpr NVTETensor grad = nullptr;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(input, grad, noop, output,
                                                                     dbias, workspace, stream);
}

// Quantizes `input` into `output` with the dbias fusion enabled (IS_DBIAS = true)
// and no activation derivative or forward activation (OP = nullptr).
//
// Unlike nvte_quantize, `input` is placed in detail::quantize_helper's second
// (gradient) tensor slot and the first (activation-input) slot is null. `dbias`,
// `workspace` and `stream` are forwarded unchanged; noop is null.
// NOTE(review): workspace sizing rules live in quantize_helper and are not visible
// here; confirm there.
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
                         NVTETensor workspace, cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = false;
  constexpr bool IS_ACT = false;
  // No activation derivative is applied, so there is no activation input.
  constexpr NVTETensor activation_input = nullptr;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, nullptr>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
}

74
75
76
77
78
79
80
81
82
83
84
85
// Fused backward quantization with dbias and the GELU derivative enabled
// (IS_DBIAS = IS_DACT = true, IS_ACT = false, OP = dgelu<fp32, fp32>).
//
// `activation_input` fills detail::quantize_helper's first tensor slot and `input`
// its second (gradient) slot; noop is null. `output`, `dbias`, `workspace` and
// `stream` are forwarded unchanged.
void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dgelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dgelu<fp32, fp32>>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
}

88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
// Backward fused quantization using the SiLU derivative (dsilu<fp32, fp32>) with
// the dbias fusion on and no forward activation.
//
// Tensor slot order handed to detail::quantize_helper: activation_input first,
// then the incoming `input`, then a null noop handle, followed by the outputs
// (`output`, `dbias`), the scratch `workspace` and the CUDA `stream`.
void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsilu);
  using namespace transformer_engine;

  constexpr bool kFuseDbias = true;          // bias-gradient fusion on
  constexpr bool kFuseDact = true;           // apply activation derivative
  constexpr bool kFuseAct = false;           // no forward activation
  constexpr NVTETensor kNoNoop = nullptr;    // caller supplies no noop tensor

  detail::quantize_helper<kFuseDbias, kFuseDact, kFuseAct, Empty, dsilu<fp32, fp32>>(
      activation_input, input, kNoNoop, output, dbias, workspace, stream);
}

// Fused backward quantization with dbias and the ReLU derivative enabled
// (IS_DBIAS = IS_DACT = true, IS_ACT = false, OP = drelu<fp32, fp32>).
//
// `activation_input` fills detail::quantize_helper's first tensor slot and `input`
// its second (gradient) slot; noop is null. `output`, `dbias`, `workspace` and
// `stream` are forwarded unchanged.
void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input,
                               NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                               cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_drelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, drelu<fp32, fp32>>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
}

116
117
118
119
120
121
122
123
124
125
126
127
128
// Backward fused quantization using the quick-GELU derivative (dqgelu<fp32, fp32>)
// with the dbias fusion on and no forward activation.
//
// Tensor slot order handed to detail::quantize_helper: activation_input first,
// then the incoming `input`, then a null noop handle, followed by the outputs
// (`output`, `dbias`), the scratch `workspace` and the CUDA `stream`.
void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dqgelu);
  using namespace transformer_engine;

  constexpr bool kFuseDbias = true;          // bias-gradient fusion on
  constexpr bool kFuseDact = true;           // apply activation derivative
  constexpr bool kFuseAct = false;           // no forward activation
  constexpr NVTETensor kNoNoop = nullptr;    // caller supplies no noop tensor

  detail::quantize_helper<kFuseDbias, kFuseDact, kFuseAct, Empty, dqgelu<fp32, fp32>>(
      activation_input, input, kNoNoop, output, dbias, workspace, stream);
}
Przemek Tredak's avatar
Przemek Tredak committed
129

130
131
132
133
// Fused backward quantization with dbias and the squared-ReLU derivative enabled
// (IS_DBIAS = IS_DACT = true, IS_ACT = false, OP = dsrelu<fp32, fp32>).
//
// `activation_input` fills detail::quantize_helper's first tensor slot and `input`
// its second (gradient) slot; noop is null. `output`, `dbias`, `workspace` and
// `stream` are forwarded unchanged.
void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input,
                                NVTETensor output, NVTETensor dbias, NVTETensor workspace,
                                cudaStream_t stream) {
  NVTE_API_CALL(nvte_quantize_dbias_dsrelu);
  using namespace transformer_engine;

  constexpr bool IS_DBIAS = true;
  constexpr bool IS_DACT = true;
  constexpr bool IS_ACT = false;

  detail::quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, dsrelu<fp32, fp32>>(
      activation_input, input, nullptr, output, dbias, workspace, stream);
}

144
145
// Dequantizes `input` into `output`, issuing the work on `stream`.
//
// Both opaque NVTETensor handles are reinterpreted as transformer_engine::Tensor
// and forwarded to detail::dequantize_helper: `input` by const reference and
// `output` by pointer.
// NOTE(review): `input` is dereferenced here without a null check; a null handle
// is undefined behavior unless callers guarantee otherwise.
void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_dequantize);
  using namespace transformer_engine;

  detail::dequantize_helper(*reinterpret_cast<const Tensor *>(input),
                            reinterpret_cast<Tensor *>(output), stream);
}