/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

10
11
12
13
14
15
16
17
18
19
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "common/common.h"
#include "common/util/logging.h"
#include "utils.h"

32
33
34
35
36
namespace transformer_engine {
namespace jax {

constexpr int kMaxNumDim = 8;

// TODO: Rename Shape to a more specific name (replacement name not yet decided).
// Fixed-capacity shape record: a dimension count plus inline storage for up to
// kMaxNumDim extents. Having no heap-owned members lets it be embedded by value
// in the descriptor structs declared later in this header.
struct Shape {
  int num_dim;              // number of entries of `dims` in use
  size_t dims[kMaxNumDim];  // dimension extents, stored inline

  // Populates this Shape from `shape`. Defined out of line.
  // NOTE(review): behaviour when shape.size() > kMaxNumDim is not visible here - confirm.
  void from_vector(const std::vector<size_t> &shape);

  // Returns the stored dimensions as a std::vector. Defined out of line.
  std::vector<size_t> to_vector() const;
};

// Phuong: These 3 functions need to stay in the header file for compilation purposes.
// 1.
// Reports whether `type` is one of the two FP8 encodings (E4M3 or E5M2).
inline bool use_fp8(DType type) {
  switch (type) {
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return true;
    default:
      return false;
  }
}
// 2.
// Serialises `descriptor` into a Python bytes object by copying its object
// representation (sizeof(T) bytes, including any padding) verbatim.
//
// T must be trivially copyable: the bytes are later reinterpreted as a T by
// UnpackOpaque, which is only meaningful for such types.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
  static_assert(std::is_trivially_copyable<T>::value,
                "PackOpaque requires a trivially copyable descriptor type");
  const char *raw = reinterpret_cast<const char *>(&descriptor);
  return pybind11::bytes(std::string(raw, sizeof(T)));
}
// 3.
// Reinterprets an opaque byte buffer as a descriptor of type T.
//
// Parameters:
//   opaque      - pointer to the serialised descriptor bytes.
//   opaque_len  - number of bytes available at `opaque`; must equal sizeof(T).
// Returns a pointer aliasing `opaque` (no copy is made), so the result is only
// valid while the caller keeps the buffer alive.
// Throws std::runtime_error if the length does not match sizeof(T) or if
// `opaque` is null.
// NOTE(review): the buffer is assumed to be suitably aligned for T - confirm
// against the code that supplies `opaque`.
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
  static_assert(std::is_trivially_copyable<T>::value,
                "UnpackOpaque requires a trivially copyable descriptor type");
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  // A null buffer of the "right" length would otherwise be dereferenced by the caller.
  if (opaque == nullptr) {
    throw std::runtime_error("Invalid opaque object: null pointer");
  }
  return reinterpret_cast<const T *>(opaque);
}

// Declaration only; the definition is not part of this header.
// NOTE(review): presumably copies the extents of `shape` into a std::vector - confirm.
std::vector<size_t> MakeShapeVector(NVTEShape shape);
66

// Packing

69
struct CustomCallCommonDescriptor {
70
71
72
73
  Shape shape;
  DType in_dtype;
  DType out_dtype;
  size_t act_enum;
74
75
76
};

// Returns a bytes object for the given shape/dtype/activation arguments.
// `act_enum` defaults to 0 when omitted.
// NOTE(review): presumably fills a CustomCallCommonDescriptor and serialises it
// with PackOpaque - the definition is not in this header, confirm there.
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum = 0);
78

79
// Plain-data descriptor like CustomCallCommonDescriptor, extended with a second
// shape (`wkshape`) and a third dtype (`wk_dtype`). Its fields correspond to the
// parameters of PackCustomCallCommonWkDescriptor.
// NOTE(review): "wk" presumably stands for a workspace buffer - confirm.
struct CustomCallCommonWkDescriptor {
  Shape shape;
  Shape wkshape;
  DType in_dtype;
  DType out_dtype;
  DType wk_dtype;
  size_t act_enum;  // NOTE(review): presumably an NVTE_Activation_Type value - confirm
};

// Returns a bytes object for the given shapes, dtypes and activation selector.
// `act_enum` defaults to 0 when omitted.
// NOTE(review): presumably fills a CustomCallCommonWkDescriptor and serialises
// it with PackOpaque - the definition is not in this header, confirm there.
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype,
                                                 size_t act_enum = 0);
92

93
// Plain-data descriptor for the normalization custom calls. Its fields match,
// in order, the parameters of PackCustomCallNormDescriptor.
// NOTE(review): the meaning of individual fields (e.g. which buffers the
// "wkspace"/"barrier"/"*_part" entries describe) is defined by the
// implementation, which is not visible in this header - confirm there.
struct CustomCallNormDescriptor {
  size_t batch_size;
  size_t hidden_size;
  size_t wkspace_size;
  size_t barrier_size;
  Shape dgamma_part_shape;
  Shape dbeta_part_shape;
  DType x_dtype;
  DType w_dtype;
  DType wkspace_dtype;
  DType barrier_dtype;
  DType dgamma_part_dtype;
  DType dbeta_part_dtype;
  bool zero_centered_gamma;
  float eps;
  int sm_margin;
};

111
112
113
114
115
// Returns a bytes object for the given normalization arguments. The parameter
// list matches the fields of CustomCallNormDescriptor one-to-one, in order.
// NOTE(review): presumably fills that struct and serialises it with PackOpaque -
// the definition is not in this header, confirm there.
pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);
116
117

// Plain-data descriptor for the softmax custom calls. Its fields match, in
// order, the parameters of PackCustomCallSoftmaxDescriptor.
struct SoftmaxDescriptor {
  size_t batch_size;
  size_t padding_size;
  size_t head_dim;
  size_t q_seqlen;
  size_t k_seqlen;
  DType dtype;
  float scale_factor;
};

127
128
129
// Returns a bytes object for the given softmax arguments. The parameter list
// matches the fields of SoftmaxDescriptor one-to-one, in order.
// NOTE(review): presumably fills that struct and serialises it with PackOpaque -
// the definition is not in this header, confirm there.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor);
130

131
// Plain-data descriptor for the fused-attention custom calls. Its fields
// correspond, in order, to the parameters of PackCustomCallFusedAttnDescriptor.
// NOTE(review): field semantics (e.g. how `input_batch` and `bias_batch`
// differ) are defined by the implementation, not visible here - confirm there.
struct CustomCallFusedAttnDescriptor {
  size_t input_batch;
  size_t bias_batch;
  size_t q_max_seqlen;
  size_t kv_max_seqlen;
  size_t attn_heads;
  size_t num_gqa_groups;
  size_t bias_heads;
  size_t head_dim;
  size_t max_segments_per_seq;
  size_t wkspace_size;
  float scaling_factor;
  float dropout_probability;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  NVTE_QKV_Layout qkv_layout;
  DType dtype;
  DType wkspace_dtype;
  bool is_training;
  bool deterministic;
};

// Returns a bytes object for the given fused-attention arguments. The parameter
// list lines up, position by position, with the fields of
// CustomCallFusedAttnDescriptor.
//
// The second parameter is named `bias_batch` to agree with the struct field in
// the same position and with the Get*WorkspaceSizes declarations in this
// header. Parameter names in a declaration do not affect callers.
// NOTE(review): confirm the out-of-line definition treats it as the bias batch
// and serialises the struct with PackOpaque.
pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
    bool deterministic);
160

161
// Transpose
162

163
164
165
166
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

167
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
168
                                                    DType in_dtype, DType out_dtype);
169

170
void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
171
172
173
174
175

// Activation

size_t get_activation_len(NVTE_Activation_Type activation_enum);

176
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
177

178
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
179

180
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
181

182
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
183
                                                        DType in_dtype, DType out_dtype);
184

185
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
186
                              size_t opaque_len);
187

188
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
189
                              size_t opaque_len);
190

191
192
// Normalization

193
194
195
196
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps);
197
198
199
200
201
202

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

203
204
205
206
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps);
207

208
209
210
211
212
213
214
215
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

216
217
// Quantization

218
219
220
221
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

222
223
// Softmax

224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);

242
243
244
245
246
247
248
249
250
// Attention

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim);

251
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
252
253
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
254
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
255
256
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq);
257
258
259
260

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
261
262
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
263
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
264
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
265
    bool deterministic, size_t max_segments_per_seq);
266
267
268

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

269
270
271
272
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_