/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
#define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

#include <cuda_runtime_api.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <transformer_engine/fused_attn.h>
#include <transformer_engine/transformer_engine.h>
#include "common/util/logging.h"

namespace transformer_engine {
namespace jax {

constexpr int kMaxNumDim = 8;

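// Trivially copyable tensor shape (up to kMaxNumDim dimensions) so that it can
// be embedded by value in the descriptor structs below, which are serialized
// to the opaque byte strings passed through XLA custom calls.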
struct Shape {
    int num_dim;
    size_t dims[kMaxNumDim];

    void from_vector(const std::vector<size_t> &shape) {
        num_dim = shape.size();
        assert(num_dim <= kMaxNumDim);
        std::memcpy(dims, shape.data(), num_dim * sizeof(size_t));
    }

    std::vector<size_t> to_vector() const {
        assert(num_dim <= kMaxNumDim);
        std::vector<size_t> shape(num_dim);
        std::memcpy(shape.data(), dims, num_dim * sizeof(size_t));
        return shape;
    }
};

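// Minimal descriptor shared by simple elementwise/transpose-style custom
// calls: the operand shape plus the input and output dtypes.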
struct CustomCallCommonDescriptor {
    Shape shape;
    DType in_dtype;
    DType out_dtype;
};

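// Serializes a CustomCallCommonDescriptor into opaque bytes. A usage sketch
// (values chosen for illustration only):
//   pybind11::bytes opaque = PackCustomCallCommonDescriptor(
//       {batch_size, hidden_size}, DType::kBFloat16, DType::kFloat8E4M3);
// The returned bytes are forwarded verbatim as the `opaque` argument of the
// corresponding custom call.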
pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype);

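// Descriptor for the LayerNorm/RMSNorm custom calls. The *_part_sizes pointers
// describe the 2D partial-reduction buffers the backward pass uses while
// accumulating dgamma/dbeta.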
struct CustomCallNormDescriptor {
    size_t batch_size;
    size_t hidden_size;
    size_t wkspace_size;
    size_t barrier_size;
    size_t *dgamma_part_sizes;  // 2D tensor
    size_t *dbeta_part_sizes;   // 2D tensor
    DType x_dtype;
    DType w_dtype;
    DType wkspace_dtype;
    DType barrier_dtype;
    DType dgamma_part_dtype;
    DType dbeta_part_dtype;
    bool zero_centered_gamma;
    float eps;
    int sm_margin;
};

pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
                                             size_t wkspace_size, size_t barrier_size,
                                             size_t *dgamma_part_sizes, size_t *dbeta_part_sizes,
                                             DType x_dtype, DType w_dtype,
                                             DType wkspace_dtype, DType barrier_dtype,
                                             DType dgamma_part_dtype, DType dbeta_part_dtype,
                                             bool zero_centered_gamma, float eps, int sm_margin);

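// Descriptor for the fused scaled-softmax custom calls; scale_factor is
// applied to the logits before the softmax is taken.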
struct SoftmaxDescriptor {
    size_t batch_size;
    size_t padding_size;
    size_t head_dim;
    size_t q_seqlen;
    size_t k_seqlen;
    DType dtype;
    float scale_factor;
};

pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen, size_t k_seqlen,
                                                DType dtype, float scale_factor);

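// Descriptor for the fused-attention custom calls. num_gqa_groups is the
// number of KV heads; it equals num_heads for standard multi-head attention.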
struct CustomCallFusedAttnDescriptor {
    size_t batch_size;
    size_t q_max_seqlen;
    size_t kv_max_seqlen;
    size_t num_heads;
    size_t num_gqa_groups;
    size_t head_dim;
    size_t wkspace_size;
    float scaling_factor;
    float dropout_probability;
    NVTE_Bias_Type bias_type;
    NVTE_Mask_Type mask_type;
    DType dtype;
    DType wkspace_dtype;
    bool is_training;
};

pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t num_heads, size_t num_gqa_groups, size_t head_dim, size_t wkspace_size,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    DType dtype, DType wkspace_dtype, bool is_training);

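// Queries which cuDNN fused-attention backend (if any) supports the given
// problem configuration; see transformer_engine/fused_attn.h for the
// NVTE_Fused_Attn_Backend values.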
NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim);

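// XLA custom-call entry points. `buffers` carries the device pointers for all
// inputs followed by all outputs, and `opaque`/`opaque_len` carry the packed
// descriptor produced by the Pack* helpers above.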
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void DGatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void DGatedGeluCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                             size_t opaque_len);

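// Queries the workspace/barrier sizes for the norm forward kernels so the XLA
// side can allocate those buffers before launching.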
pybind11::tuple GetLayerNormForwardWorkspaceSizes(
    size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, DType out_dtype,
    bool is_layer_norm, bool zero_centered_gamma, float eps
);

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);

pybind11::tuple GetLayerNormBackwardWorkspaceSizes(
    size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, bool is_layer_norm,
    bool zero_centered_gamma, float eps
);

void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

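// Fused softmax kernels: plain scaled softmax, scaled softmax with an explicit
// mask, and scaled softmax with an implicit upper-triangular (causal) mask.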
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);

void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);

void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);

void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           std::size_t opaque_len);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            std::size_t opaque_len);

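// Self-attention fused-attention path (shared Q/KV sequence length):
// workspace-size queries plus the forward/backward entry points.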
pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes(
    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);

void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len);

pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes(
    size_t batch_size, size_t max_seqlen, size_t num_heads, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);

void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len);

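// Cross-attention fused-attention path: q_max_seqlen and kv_max_seqlen may
// differ, and num_gqa_groups sets the number of KV heads.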
pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes(
    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t num_heads, size_t num_gqa_groups, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);

void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len);

pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes(
    size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t num_heads, size_t num_gqa_groups, size_t head_dim,
    float scaling_factor, float dropout_probability,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training
);

void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                            size_t opaque_len);

}  // namespace jax
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_