/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cudnn.h>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <transformer_engine/activation.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>

#include "common/common.h"
#include "common/util/logging.h"
#include "utils.h"

namespace transformer_engine {
namespace jax {

constexpr int kMaxNumDim = 8;

39

40
// TODO: Rename Shape to ???
41
42
43
44
struct Shape {
    int num_dim;
    size_t dims[kMaxNumDim];

45
    void from_vector(const std::vector<size_t> &shape);
46

47
    std::vector<size_t> to_vector() const;
48
49
};

50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
// Phuong: These 3 functions need to stay in the header file for compilation purpose
// 1.
inline bool use_fp8(DType type) {
  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
}
// 2.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
    auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
    return pybind11::bytes(str);
}
// 3.
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
    if (opaque_len != sizeof(T)) {
        throw std::runtime_error("Invalid opaque object size");
    }
    return reinterpret_cast<const T *>(opaque);
}

std::vector<size_t> MakeShapeVector(NVTEShape shape);
71

72
// Packing
73

74
75
76
77
struct CustomCallCommonDescriptor {
    Shape shape;
    DType in_dtype;
    DType out_dtype;
78
    size_t act_enum;
79
80
81
};

pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
82
                                               DType out_dtype, size_t act_enum = 0);
83

84
85
86
87
88
89
struct CustomCallCommonWkDescriptor {
    Shape shape;
    Shape wkshape;
    DType in_dtype;
    DType out_dtype;
    DType wk_dtype;
90
    size_t act_enum;
91
92
93
};

pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
94
95
96
                                                 const std::vector<size_t> &wkshape,
                                                 DType in_dtype, DType out_dtype, DType wk_dtype,
                                                 size_t act_enum = 0);
97

98
struct CustomCallNormDescriptor {
99
100
101
102
    size_t batch_size;
    size_t hidden_size;
    size_t wkspace_size;
    size_t barrier_size;
103
104
    Shape dgamma_part_shape;
    Shape dbeta_part_shape;
105
106
    DType x_dtype;
    DType w_dtype;
107
108
109
110
    DType wkspace_dtype;
    DType barrier_dtype;
    DType dgamma_part_dtype;
    DType dbeta_part_dtype;
111
    bool zero_centered_gamma;
112
    float eps;
113
    int sm_margin;
114
115
};

116
117
118
119
120
pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);
121
122

struct SoftmaxDescriptor {
123
124
125
    size_t batch_size;
    size_t padding_size;
    size_t head_dim;
126
127
128
129
130
131
    size_t q_seqlen;
    size_t k_seqlen;
    DType dtype;
    float scale_factor;
};

132
133
134
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor);
135

136
struct CustomCallFusedAttnDescriptor {
137
138
    size_t input_batch;
    size_t bias_batch;
139
140
    size_t q_max_seqlen;
    size_t kv_max_seqlen;
141
    size_t attn_heads;
142
    size_t num_gqa_groups;
143
    size_t bias_heads;
144
    size_t head_dim;
145
    size_t wkspace_size;
146
147
148
149
    float scaling_factor;
    float dropout_probability;
    NVTE_Bias_Type bias_type;
    NVTE_Mask_Type mask_type;
150
    NVTE_QKV_Layout qkv_layout;
151
    DType dtype;
152
    DType wkspace_dtype;
153
154
155
156
    bool is_training;
};

pybind11::bytes PackCustomCallFusedAttnDescriptor(
157
    size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
158
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
159
160
161
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
    bool is_training);
162

163
// Transpose
164

165
166
167
168
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

169
170
171
172
173
174
175
176
177
178
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                         DType in_dtype, DType out_dtype);

void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);

// Activation

size_t get_activation_len(NVTE_Activation_Type activation_enum);

179
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
180

181
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
182

183
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
184

185
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
186
187
                                                         DType in_dtype, DType out_dtype);

188
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
189
190
                             size_t opaque_len);

191
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
192
193
                             size_t opaque_len);

194
195
// Normalization

196
197
198
199
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps);
200
201
202
203
204
205

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

206
207
208
209
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps);
210

211
212
213
214
215
216
217
218
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

219
220
// Quantization

221
222
223
224
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

225
226
// Softmax

227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);

245
246
247
248
249
250
251
252
253
// Attention

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim);

254
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
255
256
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
257
258
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
259
260
261
262

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
263
264
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
265
266
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
267
268
269

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

270
271
272
273
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_