/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "common/common.h"
#include "common/util/logging.h"
#include "extensions/ffi.h"
#include "extensions/misc.h"
#include "transformer_engine/activation.h"
#include "utils.h"

namespace transformer_engine {
namespace jax {

// Phuong: These 3 functions need to stay in the header file for compilation purposes.
// 1.
// Reports whether `type` is one of the two FP8 encodings (E4M3 or E5M2); every other
// DType value yields false.
inline bool use_fp8(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return true;
    default:
      return false;
  }
}

// 2.
// Serializes `descriptor` into a Python `bytes` object by copying its raw object
// representation (exactly sizeof(T) bytes).
//
// T must be trivially copyable: only then is a byte-wise copy a faithful representation of
// the value. This is enforced at compile time.
//
// \param descriptor  Object whose bytes are copied; it is not modified.
// \return            A pybind11::bytes holding sizeof(T) bytes.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
  static_assert(std::is_trivially_copyable<T>::value,
                "PackOpaque requires a trivially copyable descriptor type");
  // Build the bytes object straight from the pointer/length pair; no intermediate
  // std::string copy is needed.
  return pybind11::bytes(reinterpret_cast<const char *>(&descriptor), sizeof(T));
}
// 3.
// Reinterprets an opaque byte buffer as a `const T *` after validating it.
//
// No copy is made: the returned pointer aliases `opaque`, so it is only usable while the
// caller's buffer is alive.
// NOTE(review): the cast assumes `opaque` is suitably aligned for T -- confirm that callers
// guarantee this.
//
// \param opaque      Start of the serialized object; must not be null.
// \param opaque_len  Length of the buffer in bytes; must equal sizeof(T).
// \return            `opaque` viewed as a pointer to const T.
// \throws std::runtime_error if `opaque` is null or `opaque_len != sizeof(T)`.
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
  static_assert(std::is_trivially_copyable<T>::value,
                "UnpackOpaque requires a trivially copyable descriptor type");
  if (opaque == nullptr) {
    // Fail loudly here instead of handing the caller a null descriptor to dereference.
    throw std::runtime_error("Opaque object pointer is null");
  }
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  return reinterpret_cast<const T *>(opaque);
}

// Packing

struct CustomCallCommonDescriptor {
57
58
59
60
  Shape shape;
  DType in_dtype;
  DType out_dtype;
  size_t act_enum;
61
62
63
};

// Returns a Python `bytes` object built from a shape, input/output dtypes, and an optional
// activation selector (`act_enum`, defaulting to 0). The parameters mirror the members of
// CustomCallCommonDescriptor in order.
// NOTE(review): presumably fills that struct and serializes it -- the definition is not in
// this header.
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum = 0);
// Variant of the common descriptor that additionally carries a second shape (`wkshape`)
// and dtype (`wk_dtype`).
// NOTE(review): "wk" presumably stands for workspace; member order is presumably part of a
// byte-level layout shared with the packing code -- confirm before reordering.
struct CustomCallCommonWkDescriptor {
  Shape shape;
  Shape wkshape;
  DType in_dtype;
  DType out_dtype;
  DType wk_dtype;
  size_t act_enum;  // PackCustomCallCommonWkDescriptor defaults this argument to 0
};

// Returns a Python `bytes` object built from two shapes, three dtypes, and an optional
// activation selector (`act_enum`, defaulting to 0). The parameters mirror the members of
// CustomCallCommonWkDescriptor in order.
// NOTE(review): presumably fills that struct and serializes it -- the definition is not in
// this header.
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype,
                                                 size_t act_enum = 0);
// Descriptor for the normalization calls: problem sizes, workspace/barrier sizes, the
// shapes and dtypes of the partial dgamma/dbeta buffers, and the scalar options.
// NOTE(review): member order is presumably part of a byte-level layout shared with
// PackCustomCallNormDescriptor (whose parameters follow the same order) -- confirm before
// reordering.
struct CustomCallNormDescriptor {
  size_t batch_size;
  size_t hidden_size;
  size_t wkspace_size;
  size_t barrier_size;
  Shape dgamma_part_shape;
  Shape dbeta_part_shape;
  DType x_dtype;
  DType w_dtype;
  DType wkspace_dtype;
  DType barrier_dtype;
  DType dgamma_part_dtype;
  DType dbeta_part_dtype;
  bool zero_centered_gamma;
  float eps;
  int sm_margin;
};

// Returns a Python `bytes` object built from the normalization parameters. The parameter
// list mirrors the members of CustomCallNormDescriptor one-for-one and in the same order;
// the two partial-gradient shapes are taken as std::vector<size_t> here.
// NOTE(review): presumably fills a CustomCallNormDescriptor and serializes it -- the
// definition is not in this header.
pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);
// Descriptor for the softmax calls: five size parameters, the element dtype, and a scale
// factor.
// NOTE(review): member order is presumably part of a byte-level layout shared with
// PackCustomCallSoftmaxDescriptor (whose parameters follow the same order) -- confirm
// before reordering.
struct SoftmaxDescriptor {
  size_t batch_size;
  size_t padding_size;
  size_t head_dim;
  size_t q_seqlen;
  size_t k_seqlen;
  DType dtype;
  float scale_factor;
};

// Returns a Python `bytes` object built from the softmax parameters. The parameter list
// mirrors the members of SoftmaxDescriptor one-for-one and in the same order.
// NOTE(review): presumably fills a SoftmaxDescriptor and serializes it -- the definition is
// not in this header.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor);
// Descriptor for the fused-attention calls: batch/sequence/head sizes, workspace size,
// scalar options, the bias/mask/layout enums, dtypes, mode flags, and the attention window
// bounds.
// NOTE(review): member order is presumably part of a byte-level layout shared with
// PackCustomCallFusedAttnDescriptor (whose parameters follow the same order) -- confirm
// before reordering.
struct CustomCallFusedAttnDescriptor {
  size_t input_batch;
  size_t bias_batch;
  size_t q_max_seqlen;
  size_t kv_max_seqlen;
  size_t attn_heads;
  size_t num_gqa_groups;
  size_t bias_heads;
  size_t head_dim;
  size_t max_segments_per_seq;
  size_t wkspace_size;
  float scaling_factor;
  float dropout_probability;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  NVTE_QKV_Layout qkv_layout;
  DType dtype;
  DType wkspace_dtype;
  bool is_training;
  bool deterministic;
  int64_t window_size_left;
  int64_t window_size_right;
};

// Returns a Python `bytes` object built from the fused-attention parameters. The parameter
// list mirrors the members of CustomCallFusedAttnDescriptor one-for-one and in the same
// order; the second parameter is named `bias_batch` to match the corresponding struct
// member and the GetFusedAttn*WorkspaceSizes declarations.
// NOTE(review): presumably fills a CustomCallFusedAttnDescriptor and serializes it -- the
// definition is not in this header.
pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
    bool deterministic, int64_t window_size_left, int64_t window_size_right);

// Transpose

void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

154
155
XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);

156
157
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

158
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
159
                                                    DType in_dtype, DType out_dtype);
160

161
162
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);

163
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Activation

size_t get_activation_len(NVTE_Activation_Type activation_enum);

169
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
170

171
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
172

173
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
174

175
176
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);

177
178
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);

179
180
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);

181
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
182
                                                        DType in_dtype, DType out_dtype);
183

184
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
185
                              size_t opaque_len);
186

187
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
188
                              size_t opaque_len);

// Normalization

pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
195
                                                  float eps, int sm_margin);
196
197
198
199
200
201

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

202
203
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler);

204
205
206
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
207
                                                   float eps, int sm_margin);
208

209
210
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

211
212
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);

213
214
215
216
217
218
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Quantization

void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

223
224
XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);

225
226
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Softmax

// Softmax entry points. Each receives a CUDA stream, an array of buffer pointers, and an
// opaque byte string with its length. The six declarations are forward/backward pairs for
// three variants: plain scaled, scaled masked, and scaled upper-triangular masked.
// NOTE(review): `opaque` is presumably a packed SoftmaxDescriptor -- confirm in the
// definitions.
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);

// Attention

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
254
255
                                            size_t head_dim, int64_t window_size_left,
                                            int64_t window_size_right);
256

257
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
258
259
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
260
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
261
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
262
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);
263
264
265

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

266
267
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

268
pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
269
270
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
271
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
272
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
273
274
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);
275
276
277

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_