/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

#include <cuda_runtime_api.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>
#include "common/util/logging.h"

namespace transformer_engine {
namespace jax {

constexpr int kMaxNumDim = 8;

struct Shape {
    int num_dim;
    size_t dims[kMaxNumDim];

    void from_vector(const std::vector<size_t> &shape) {
        num_dim = shape.size();
        assert(num_dim <= kMaxNumDim);
        std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
    }

    std::vector<size_t> to_vector() const {
        assert(num_dim <= kMaxNumDim);
        std::vector<size_t> shape(num_dim);
        std::memcpy(shape.data(), dims, num_dim * sizeof(size_t));
        return shape;
    }
};
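
// Minimal usage sketch (illustrative): Shape is kept as a fixed-size POD so the
// descriptor structs below stay trivially copyable when packed into the
// custom-call opaque bytes. A round trip looks like:
//
//   Shape s;
//   s.from_vector({batch_size, hidden_size});    // placeholder values; at most kMaxNumDim dims
//   std::vector<size_t> dims = s.to_vector();    // == {batch_size, hidden_size}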

enum class NVTE_Activation_Enum {
  GELU,
  GEGLU,
  SILU,
  SWIGLU,
};

size_t get_activation_len(NVTE_Activation_Enum act_enum);
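
// Note: the descriptors below carry the activation as a plain `size_t act_enum`,
// presumably so the structs stay trivially copyable for byte-wise packing, with
// the value cast back to NVTE_Activation_Enum where it is consumed.
// get_activation_len() is expected to report how many input partitions the
// activation uses (e.g. 2 for the gated GEGLU/SWIGLU variants); see the .cpp
// for the authoritative mapping.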

struct CustomCallCommonDescriptor {
    Shape shape;
    DType in_dtype;
    DType out_dtype;
    size_t act_enum;
};
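
// The Pack* helpers below serialize a descriptor into pybind11::bytes on the
// Python side; XLA then hands those bytes back as the custom call's `opaque`
// string, where the kernel entry points declared further down reinterpret them.
// A minimal sketch of the consuming side (the real implementation may go
// through a checked unpack helper rather than a bare cast):
//
//   void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
//       assert(opaque_len == sizeof(CustomCallCommonDescriptor));
//       const auto &desc = *reinterpret_cast<const CustomCallCommonDescriptor *>(opaque);
//       // launch using desc.shape, desc.in_dtype, desc.out_dtype, desc.act_enum ...
//   }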

pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum = 0);

struct CustomCallCommonWkDescriptor {
    Shape shape;
    Shape wkshape;
    DType in_dtype;
    DType out_dtype;
    DType wk_dtype;
    size_t act_enum;
};

pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape,
                                                 DType in_dtype, DType out_dtype, DType wk_dtype,
                                                 size_t act_enum = 0);

struct CustomCallNormDescriptor {
    size_t batch_size;
    size_t hidden_size;
    size_t wkspace_size;
    size_t barrier_size;
    Shape dgamma_part_shape;
    Shape dbeta_part_shape;
    DType x_dtype;
    DType w_dtype;
    DType wkspace_dtype;
    DType barrier_dtype;
    DType dgamma_part_dtype;
    DType dbeta_part_dtype;
    bool zero_centered_gamma;
    float eps;
    int sm_margin;
};

pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);

struct SoftmaxDescriptor {
    size_t batch_size;
    size_t padding_size;
    size_t head_dim;
    size_t q_seqlen;
    size_t k_seqlen;
    DType dtype;
    float scale_factor;
};

pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor);

struct CustomCallFusedAttnDescriptor {
    size_t input_batch;
    size_t bias_batch;
    size_t q_max_seqlen;
    size_t kv_max_seqlen;
    size_t attn_heads;
    size_t num_gqa_groups;
    size_t bias_heads;
    size_t head_dim;
    size_t wkspace_size;
    float scaling_factor;
    float dropout_probability;
    NVTE_Bias_Type bias_type;
    NVTE_Mask_Type mask_type;
    NVTE_QKV_Layout qkv_layout;
    DType dtype;
    DType wkspace_dtype;
    bool is_training;
};

pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
    bool is_training);
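
// Queries which fused-attention backend (if any) can service the given problem
// configuration; presumably a thin wrapper over nvte_get_fused_attn_backend,
// exposed to Python via pybind11 so the JAX frontend can fall back to an
// unfused attention path when no backend applies.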

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim);
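
// All kernel entry points below follow the XLA custom-call convention:
// `buffers` holds device pointers to the operands followed by the results, in
// the order registered on the Python side, while `opaque`/`opaque_len` carry
// one of the packed descriptors defined above.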

void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype);

void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);

pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype);

void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                        size_t opaque_len);

void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);

pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps);

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps);

void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);
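
// The workspace-size queries below are expected to return the workspace shape
// and dtype as a pybind11 tuple so the JAX side can pre-allocate the buffer
// consumed by the matching FusedAttnForward/FusedAttnBackward call; see the
// .cpp for the exact tuple layout.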

pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_