/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/common.h"
#include "common/util/logging.h"

#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "transformer_engine/activation.h"

#include "utils.h"

namespace transformer_engine {
namespace jax {

36
37
// Phuong: These 3 functions need to stay in the header file for compilation purpose
// 1.
38
inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; }
39
40
41
// 2.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
42
43
  auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
  return pybind11::bytes(str);
44
45
46
47
}
// 3. Reinterprets custom-call opaque data as a descriptor of type T.
//    Throws std::runtime_error when the payload length does not match sizeof(T).
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
  if (sizeof(T) == opaque_len) {
    return reinterpret_cast<const T *>(opaque);
  }
  throw std::runtime_error("Invalid opaque object size");
}

// Packing

// Metadata shared by simple elementwise custom calls: tensor shape, I/O dtypes,
// and an activation selector.
struct CustomCallCommonDescriptor {
  Shape shape;      // logical shape of the operand/result tensor
  DType in_dtype;   // input tensor dtype
  DType out_dtype;  // output tensor dtype
  size_t act_enum;  // activation selector (presumably an NVTE_Activation_Type value; 0 default — confirm at call sites)
};

// Builds the opaque byte payload for CustomCallCommonDescriptor (via PackOpaque).
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum = 0);

// Descriptor for custom calls that additionally carry a workspace buffer
// alongside the tensor I/O.
struct CustomCallCommonWkDescriptor {
  Shape shape;      // logical shape of the operand/result tensor
  Shape wkshape;    // shape of the workspace buffer
  DType in_dtype;   // input tensor dtype
  DType out_dtype;  // output tensor dtype
  DType wk_dtype;   // workspace buffer dtype
  size_t act_enum;  // activation selector (0 default)
};

// Builds the opaque byte payload for CustomCallCommonWkDescriptor.
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype,
                                                 size_t act_enum = 0);

// Descriptor for LayerNorm/RMSNorm custom calls.
struct CustomCallNormDescriptor {
  size_t batch_size;         // number of rows to normalize
  size_t hidden_size;        // per-row width
  size_t wkspace_size;       // workspace size (units per the norm implementation — confirm there)
  DType x_dtype;             // input dtype
  DType w_dtype;             // weight (gamma/beta) dtype
  DType wkspace_dtype;       // workspace dtype
  bool zero_centered_gamma;  // zero-centered-gamma convention flag (see TE norm kernels)
  float eps;                 // numerical-stability epsilon
  int sm_margin;             // number of SMs to leave unused by the norm kernel
};

// Builds the opaque byte payload for CustomCallNormDescriptor.
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
                                             size_t wkspace_size, DType x_dtype, DType w_dtype,
                                             DType wkspace_dtype, bool zero_centered_gamma,
                                             float eps, int sm_margin);

// Descriptor for the fused scaled-softmax custom calls declared below.
struct SoftmaxDescriptor {
  size_t batch_size;   // number of attention batches
  size_t padding_size; // padding length (used by the masked variants — confirm in impl)
  size_t head_dim;     // number of attention heads per batch
  size_t q_seqlen;     // query sequence length
  size_t k_seqlen;     // key sequence length
  DType dtype;         // element dtype of the softmax input/output
  float scale_factor;  // scaling applied to the softmax input
};

// Builds the opaque byte payload for SoftmaxDescriptor.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor);

111
struct CustomCallFusedAttnDescriptor {
112
113
114
115
116
117
118
119
  size_t input_batch;
  size_t bias_batch;
  size_t q_max_seqlen;
  size_t kv_max_seqlen;
  size_t attn_heads;
  size_t num_gqa_groups;
  size_t bias_heads;
  size_t head_dim;
120
  size_t max_segments_per_seq;
121
122
123
124
125
126
127
128
129
  size_t wkspace_size;
  float scaling_factor;
  float dropout_probability;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  NVTE_QKV_Layout qkv_layout;
  DType dtype;
  DType wkspace_dtype;
  bool is_training;
130
  bool deterministic;
131
132
  int64_t window_size_left;
  int64_t window_size_right;
133
134
135
};

pybind11::bytes PackCustomCallFusedAttnDescriptor(
136
    size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
137
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
138
139
    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
140
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
141
    bool deterministic, int64_t window_size_left, int64_t window_size_right);
142

// Transpose

// Custom-call entry points take the raw XLA buffer list plus a packed opaque
// descriptor; the *Handler symbols are the XLA FFI counterparts.
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);

// Workspace-size query for the fused dbias + cast + transpose op (exposed to
// Python so the caller can preallocate buffers).
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype);

void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasCastTransposeHandler);

// Activation

// Length associated with an activation kind (presumably the number of input
// chunks, e.g. 2 for gated activations — confirm in the definition).
size_t get_activation_len(NVTE_Activation_Type activation_enum);

void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);

void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);

void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);

// Workspace-size query for the fused dact + dbias + cast + transpose op.
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype);

void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler);

void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler);

// Normalization

// Workspace-size query shared by LayerNorm and RMSNorm forward
// (is_layer_norm selects between the two).
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps, int sm_margin);

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler);

// Workspace-size query for the backward pass (LayerNorm or RMSNorm).
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps, int sm_margin);

void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);

// Quantization

// FP8 cast (quantize) and cast-back (dequantize) custom calls.
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);

// Softmax

// Fused scaled-softmax variants (plain / masked / upper-triangular-masked).
// The opaque payload is presumably a packed SoftmaxDescriptor — confirm in the
// implementations.
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Attention

// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);

// Selects the cuDNN fused-attention backend for the given problem configuration.
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim, int64_t window_size_left,
                                            int64_t window_size_right);

// Workspace-size query for the fused-attention forward pass.
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

// Workspace-size query for the fused-attention backward pass (adds the
// `deterministic` flag relative to the forward query).
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_