/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

#include <cuda_runtime_api.h>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <transformer_engine/activation.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>

#include "common/util/logging.h"

namespace transformer_engine {
namespace jax {

constexpr int kMaxNumDim = 8;

29
30
size_t get_activation_len(NVTE_Activation_Type activation_enum);

// Fixed-capacity POD shape descriptor. Kept trivially-copyable (raw array,
// no std::vector) so it can be serialized into the opaque byte blob passed
// to XLA custom calls.
struct Shape {
    int num_dim;               // number of valid entries in dims
    size_t dims[kMaxNumDim];   // dims[0 .. num_dim-1] hold the dimension sizes

    // Copies the dimensions out of a vector. shape.size() must not exceed
    // kMaxNumDim; this is only checked in debug builds (assert).
    void from_vector(const std::vector<size_t> &shape) {
        // Explicit cast documents the intentional size_t -> int narrowing.
        num_dim = static_cast<int>(shape.size());
        assert(num_dim <= kMaxNumDim);
        std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
    }

    // Reconstructs the shape as a vector of num_dim entries.
    std::vector<size_t> to_vector() const {
        assert(num_dim <= kMaxNumDim);
        std::vector<size_t> shape(num_dim);
        std::memcpy(shape.data(), dims, num_dim * sizeof(size_t));
        return shape;
    }
};

// Activation kinds selectable by the act_enum field of the custom-call
// descriptors below; GEGLU/SWIGLU are the gated variants.
enum class NVTE_Activation_Enum {
  GELU,
  GEGLU,
  SILU,
  SWIGLU,
};

// Overload of get_activation_len for the file-local activation enum.
// NOTE(review): this duplicates the NVTE_Activation_Type overload declared
// above — presumably a leftover from a rename; confirm which one call sites
// actually use before removing either.
size_t get_activation_len(NVTE_Activation_Enum act_enum);

// Opaque descriptor shared by simple elementwise custom calls.
struct CustomCallCommonDescriptor {
    Shape shape;      // logical shape of the input/output tensor
    DType in_dtype;
    DType out_dtype;
    size_t act_enum;  // activation selector; 0 when the op has no activation
};

// Serializes a CustomCallCommonDescriptor into the opaque byte string that
// XLA hands back to the custom-call implementation.
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum = 0);

// Like CustomCallCommonDescriptor, but for ops that additionally need a
// workspace buffer.
struct CustomCallCommonWkDescriptor {
    Shape shape;      // logical shape of the input/output tensor
    Shape wkshape;    // shape of the workspace buffer
    DType in_dtype;
    DType out_dtype;
    DType wk_dtype;   // workspace element type
    size_t act_enum;  // activation selector; 0 when the op has no activation
};

// Serializes a CustomCallCommonWkDescriptor into opaque custom-call bytes.
pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape,
                                                 DType in_dtype, DType out_dtype, DType wk_dtype,
                                                 size_t act_enum = 0);

// Descriptor for the LayerNorm/RMSNorm forward and backward custom calls.
struct CustomCallNormDescriptor {
    size_t batch_size;
    size_t hidden_size;
    size_t wkspace_size;       // size of the workspace buffer
    size_t barrier_size;       // size of the multi-CTA barrier buffer
    Shape dgamma_part_shape;   // partial dgamma reduction buffer (backward)
    Shape dbeta_part_shape;    // partial dbeta reduction buffer (backward)
    DType x_dtype;             // input tensor dtype
    DType w_dtype;             // gamma/beta dtype
    DType wkspace_dtype;
    DType barrier_dtype;
    DType dgamma_part_dtype;
    DType dbeta_part_dtype;
    bool zero_centered_gamma;  // whether gamma is stored zero-centered
    float eps;                 // numerical-stability epsilon
    int sm_margin;             // SMs to leave free (e.g. for comm overlap)
};

// Serializes a CustomCallNormDescriptor into opaque custom-call bytes.
// Parameters mirror the struct fields one-to-one.
pybind11::bytes PackCustomCallNormDescriptor(
    size_t batch_size, size_t hidden_size, size_t wkspace_size, size_t barrier_size,
    const std::vector<size_t> &dgamma_part_shape, const std::vector<size_t> &dbeta_part_shape,
    DType x_dtype, DType w_dtype, DType wkspace_dtype, DType barrier_dtype, DType dgamma_part_dtype,
    DType dbeta_part_dtype, bool zero_centered_gamma, float eps, int sm_margin);

// Descriptor for the fused scaled-softmax custom calls.
struct SoftmaxDescriptor {
    size_t batch_size;
    size_t padding_size;
    size_t head_dim;
    size_t q_seqlen;
    size_t k_seqlen;
    DType dtype;
    float scale_factor;  // scaling applied to logits before the softmax
};

// Serializes a SoftmaxDescriptor into opaque custom-call bytes.
pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor);

struct CustomCallFusedAttnDescriptor {
121
122
    size_t input_batch;
    size_t bias_batch;
123
124
    size_t q_max_seqlen;
    size_t kv_max_seqlen;
125
    size_t attn_heads;
126
    size_t num_gqa_groups;
127
    size_t bias_heads;
128
    size_t head_dim;
129
    size_t wkspace_size;
130
131
132
133
    float scaling_factor;
    float dropout_probability;
    NVTE_Bias_Type bias_type;
    NVTE_Mask_Type mask_type;
134
    NVTE_QKV_Layout qkv_layout;
135
    DType dtype;
136
    DType wkspace_dtype;
137
138
139
140
    bool is_training;
};

// Serializes a CustomCallFusedAttnDescriptor into opaque custom-call bytes.
// NOTE(review): second parameter renamed batch_size -> bias_batch so the
// declaration matches the struct field and the workspace-size queries below
// (declaration-only change; no effect on callers or the definition).
pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype,
    bool is_training);

// Returns which fused-attention backend (if any) supports the given problem
// configuration. NOTE(review): presumably forwards to the nvte backend
// selection query — confirm in modules.cpp.
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim);

void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

158
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
159

160
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
161

162
void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
163

164
pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
165
166
                                                         DType in_dtype, DType out_dtype);

167
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
168
169
                             size_t opaque_len);

pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                         DType in_dtype, DType out_dtype);

void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);

176
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
177
178
                             size_t opaque_len);

pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps);
183
184
185
186
187
188

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

189
190
191
192
pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps);
// LayerNorm backward.
void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// RMSNorm forward.
void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// RMSNorm forward with FP8 output.
void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// RMSNorm backward.
void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Quantize to FP8.
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Dequantize from FP8.
void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

// Scaled softmax, forward and backward.
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);

// Scaled softmax with an explicit mask, forward and backward.
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

// Scaled softmax with an upper-triangular (causal) mask, forward and backward.
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);

224
pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
225
226
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
227
228
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
229
230
231
232

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
233
234
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
235
236
    float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training);
237
238
239

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_