/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

#include <cuda_runtime_api.h>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>

#include "common/util/logging.h"

namespace transformer_engine {
namespace jax {

// Maximum rank a Shape can hold; Shape::dims is a fixed-size inline array.
constexpr int kMaxNumDim = 8;

// Fixed-capacity tensor shape: the first `num_dim` entries of `dims` are valid.
// The struct holds only an int and an inline array (no heap members), so it can be
// embedded in the descriptor structs below and copied as plain bytes.
struct Shape {
    int num_dim;
    size_t dims[kMaxNumDim];

    // Copies `shape` into this Shape.
    // Throws std::length_error (leaving *this unchanged) if `shape` has more than
    // kMaxNumDim entries. The check is unconditional rather than an assert because
    // writing past `dims` in an NDEBUG build would silently corrupt memory.
    void from_vector(const std::vector<size_t> &shape) {
        if (shape.size() > static_cast<size_t>(kMaxNumDim)) {
            throw std::length_error("Shape::from_vector: rank exceeds kMaxNumDim");
        }
        num_dim = static_cast<int>(shape.size());
        // copy_n is well-defined for an empty vector (no null-pointer memcpy).
        std::copy_n(shape.begin(), shape.size(), dims);
    }

    // Returns the first `num_dim` extents as a std::vector.
    // Throws std::length_error if `num_dim` is outside [0, kMaxNumDim], e.g. when the
    // Shape was read from a corrupted descriptor.
    std::vector<size_t> to_vector() const {
        if (num_dim < 0 || num_dim > kMaxNumDim) {
            throw std::length_error("Shape::to_vector: num_dim out of range");
        }
        return std::vector<size_t>(dims, dims + num_dim);
    }
};

// Descriptor carrying one tensor shape plus an input/output dtype pair.
// NOTE(review): field order defines the byte layout produced by the Pack function
// below; presumably the implementation copies the struct into/out of the `opaque`
// buffer as raw bytes, so do not reorder fields without updating both sides.
struct CustomCallCommonDescriptor {
    Shape shape;
    DType in_dtype;
    DType out_dtype;
};

// Builds a CustomCallCommonDescriptor from the arguments and returns it as a Python
// bytes object (presumably passed back as the `opaque` argument of a custom call).
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype);

55
56
57
58
59
60
61
62
63
64
65
66
struct CustomCallCommonWkDescriptor {
    Shape shape;
    Shape wkshape;
    DType in_dtype;
    DType out_dtype;
    DType wk_dtype;
};

pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape, DType in_dtype,
                                                 DType out_dtype, DType wk_dtype);

67
struct CustomCallNormDescriptor {
68
69
70
71
72
73
    size_t batch_size;
    size_t hidden_size;
    size_t wkspace_size;
    size_t barrier_size;
    size_t *dgamma_part_sizes;  // 2D tensor
    size_t *dbeta_part_sizes;   // 2D tensor
74
75
    DType x_dtype;
    DType w_dtype;
76
77
78
79
    DType wkspace_dtype;
    DType barrier_dtype;
    DType dgamma_part_dtype;
    DType dbeta_part_dtype;
80
    bool zero_centered_gamma;
81
    float eps;
82
    int sm_margin;
83
84
};

85
86
87
pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
                                             size_t wkspace_size, size_t barrier_size,
                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
88
89
90
91
                                             DType x_dtype, DType w_dtype, DType wkspace_dtype,
                                             DType barrier_dtype, DType dgamma_part_dtype,
                                             DType dbeta_part_dtype, bool zero_centered_gamma,
                                             float eps, int sm_margin);
92
93

struct SoftmaxDescriptor {
94
95
96
    size_t batch_size;
    size_t padding_size;
    size_t head_dim;
97
98
99
100
101
102
    size_t q_seqlen;
    size_t k_seqlen;
    DType dtype;
    float scale_factor;
};

103
104
105
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor);
106

107
struct CustomCallFusedAttnDescriptor {
108
109
    size_t input_batch;
    size_t bias_batch;
110
111
    size_t q_max_seqlen;
    size_t kv_max_seqlen;
112
    size_t attn_heads;
113
    size_t num_gqa_groups;
114
    size_t bias_heads;
115
    size_t head_dim;
116
    size_t wkspace_size;
117
118
119
120
121
    float scaling_factor;
    float dropout_probability;
    NVTE_Bias_Type bias_type;
    NVTE_Mask_Type mask_type;
    DType dtype;
122
    DType wkspace_dtype;
123
124
125
126
    bool is_training;
};

pybind11::bytes PackCustomCallFusedAttnDescriptor(
127
128
129
130
131
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t wkspace_size, float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    DType dtype, DType wkspace_dtype, bool is_training);
132

133
134
135
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
136
                                            size_t q_num_heads, size_t kv_num_heads,
137
138
139
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim);

140
141
142
143
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

144
145
146
147
148
149
150
151
152
153
154
155
void Gelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void GeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void DGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                         DType in_dtype, DType out_dtype);

void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);

156
157
158
159
160
161
162
163
164
void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);

165
166
167
168
pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps);
169
170
171
172
173
174

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

175
176
177
178
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps);
179

180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Quantize / Dequantize custom-call entry points with the
// (stream, buffers, opaque, opaque_len) custom-call signature.
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Scaled-softmax custom-call entry points. There are three variants (unmasked,
// masked, upper-triangular masked), each with a forward and a backward pass.
// NOTE(review): the `opaque` payload is presumably a packed SoftmaxDescriptor;
// confirm in the implementation.
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);

210
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
211
212
213
214
    size_t input_batch, size_t bias_batch, size_t max_seqlen,
    size_t attn_heads, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
215

216
217
void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len);
218

219
pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
220
221
222
223
    size_t input_batch, size_t bias_batch, size_t max_seqlen,
    size_t attn_heads, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
224

225
226
void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len);
227

228
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
229
230
231
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
232
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
233

234
235
void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len);
236

237
pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
238
239
240
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
241
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);
242

243
244
void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len);
245

// Generic fused-attention entry points. They take the same sizing parameters as the
// cross-attention variant: separate query/key-value maximum sequence lengths and
// `num_gqa_groups`.
//
// The Get*WorkspaceSizes functions return a Python tuple describing the workspace the
// matching entry point needs; the exact tuple contents are defined in the
// implementation.
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training);

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_