/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_

#include <optional>

#include "common.h"
#include "common/common.h"

/***************************************************************************************************
 * Permutation
 **************************************************************************************************/

// Mixture-of-experts permutation entry points. Declarations only; the definitions live in a
// separate translation unit. NOTE(review): semantics are inferred from the names only -- confirm
// against the implementation.
//
// moe_permute_fwd returns two tensors plus a vector of tensors; `workspace` is taken by value.
// The remaining functions take a `row_id_map` tensor and, except moe_unpermute_bwd, explicit
// `num_tokens` / `topK` counts.
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
    at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
    int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);

at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
                           at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
                           int64_t topK);

at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
                             at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
                             int64_t topK);

// Takes both the backward input and the forward input; returns a pair of tensors.
std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
                                                     const transformer_engine::DType dtype,
                                                     at::Tensor row_id_map, at::Tensor prob);

/***************************************************************************************************
 * Attention
 **************************************************************************************************/

39
40
41
42
43
44
45
46
NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
                                               const transformer_engine::DType kv_dtype,
                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                               NVTE_Mask_Type attn_mask_type, float p_dropout,
                                               size_t num_attn_heads, size_t num_gqa_groups,
                                               size_t max_seqlen_q, size_t max_seqlen_kv,
                                               size_t head_dim_qk, size_t head_dim_v,
                                               int64_t window_size_left, int64_t window_size_right);
cyanguwa's avatar
cyanguwa committed
47
48

std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
49
50
    size_t max_seqlen, bool is_training, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
51
52
    const std::vector<int64_t> window_size, const at::Tensor cu_seqlens, const at::Tensor QKV,
    const transformer_engine::DType qkv_type, const c10::optional<at::Tensor> cu_seqlens_padded,
53
54
55
56
57
58
59
    const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
    const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
    const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
    const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
    c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
    const int amax_O_offset, const c10::optional<at::Tensor> Bias,
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
60
61

std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
62
    size_t max_seqlen, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout,
63
64
65
66
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
    bool deterministic, const at::Tensor cu_seqlens, const at::Tensor QKV, const at::Tensor O,
    const at::Tensor dO, const transformer_engine::DType qkv_type,
    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
67
    const c10::optional<at::Tensor> cu_seqlens_padded, const c10::optional<at::Tensor> descale_QKV,
68
69
70
71
72
    const c10::optional<at::Tensor> descale_S, const c10::optional<at::Tensor> descale_O,
    const c10::optional<at::Tensor> descale_dO, const c10::optional<at::Tensor> descale_dP,
    const c10::optional<at::Tensor> scale_S, const c10::optional<at::Tensor> scale_dP,
    const c10::optional<at::Tensor> scale_dQKV, c10::optional<at::Tensor> amax_dP,
    c10::optional<at::Tensor> amax_dQKV);
cyanguwa's avatar
cyanguwa committed
73
74

std::vector<at::Tensor> fused_attn_fwd_kvpacked(
75
76
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
77
78
79
    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
    const at::Tensor KV, const transformer_engine::DType qkv_type,
80
81
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
82
83
84
85
86
87
88
    const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
    const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
    const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
    const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
    c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
    const int amax_O_offset, const c10::optional<at::Tensor> Bias,
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
89
90

std::vector<at::Tensor> fused_attn_bwd_kvpacked(
91
92
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
93
94
95
96
    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
    const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor KV, const at::Tensor O,
    const at::Tensor dO, const transformer_engine::DType qkv_type,
    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
97
98
99
100
101
102
103
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
    const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
    const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
    const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
    const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
    c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
104
105

std::vector<at::Tensor> fused_attn_fwd(
106
107
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
108
109
110
111
    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const at::Tensor Q,
    const at::Tensor K, const at::Tensor V, const transformer_engine::DType qkv_type,
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
112
    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
113
114
115
116
117
118
119
    const c10::optional<at::Tensor> descale_QKV, const int descale_QKV_offset,
    const c10::optional<at::Tensor> descale_S, const int descale_S_offset,
    const c10::optional<at::Tensor> scale_S, const int scale_S_offset,
    const c10::optional<at::Tensor> scale_O, const int scale_O_offset,
    c10::optional<at::Tensor> amax_S, const int amax_S_offset, c10::optional<at::Tensor> amax_O,
    const int amax_O_offset, const c10::optional<at::Tensor> Bias,
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
120
121

std::vector<at::Tensor> fused_attn_bwd(
122
123
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
124
125
126
127
    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
    const at::Tensor cu_seqlens_kv, const at::Tensor Q, const at::Tensor K, const at::Tensor V,
    const at::Tensor O, const at::Tensor dO, const transformer_engine::DType qkv_type,
    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
128
129
130
131
132
133
134
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded,
    const c10::optional<at::Tensor> descale_QKV, const c10::optional<at::Tensor> descale_S,
    const c10::optional<at::Tensor> descale_O, const c10::optional<at::Tensor> descale_dO,
    const c10::optional<at::Tensor> descale_dP, const c10::optional<at::Tensor> scale_S,
    const c10::optional<at::Tensor> scale_dP, const c10::optional<at::Tensor> scale_dQKV,
    c10::optional<at::Tensor> amax_dP, c10::optional<at::Tensor> amax_dQKV);
Przemek Tredak's avatar
Przemek Tredak committed
135

136
137
138
// Declarations only. fa_prepare_fwd takes one combined tensor (`qkvi`); fa_prepare_bwd takes
// q, k and v separately. Both return a single tensor. NOTE(review): presumably layout-preparation
// helpers for a flash-attention path -- confirm against the implementation.
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

/***************************************************************************************************
 * GEMM
 **************************************************************************************************/

143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
void te_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
             bool transa, at::Tensor B, at::Tensor B_scale_inverse,
             transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale,
             transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
             transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad,
             at::Tensor workspace, size_t workspaceSize, bool accumulate,
             bool use_split_accumulator, int math_sm_count);

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
                    bool transa, at::Tensor B, at::Tensor B_scale_inverse,
                    transformer_engine::DType B_type, bool transb, at::Tensor D, at::Tensor D_scale,
                    transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
                    transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad,
                    at::Tensor workspace, size_t workspaceSize, bool accumulate,
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, at::Tensor counter);
Przemek Tredak's avatar
Przemek Tredak committed
159

160
161
162
163
164
165
166
167
168
169
void te_grouped_gemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
                     transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
                     at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
                     bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
                     transformer_engine::DType D_type, at::Tensor D_amax,
                     std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
                     std::vector<at::Tensor> pre_gelu_out, bool grad,
                     std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
                     bool use_split_accumulator, int math_sm_count);

170
171
172
173
174
175
176
177
178
179
void te_grouped_gemm_single_output(
    std::vector<at::Tensor> A, std::vector<at::Tensor> A_scale_inverse, int A_offset,
    transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
    at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type, bool transb,
    std::vector<int64_t> m_splits, at::Tensor D, int D_offset, at::Tensor D_scale,
    transformer_engine::DType D_type, at::Tensor D_amax, std::vector<at::Tensor> bias,
    transformer_engine::DType bias_type, std::vector<at::Tensor> pre_gelu_out, bool grad,
    std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, int math_sm_count);

180
181
182
183
/***************************************************************************************************
 * Transpose
 **************************************************************************************************/

184
185
186
187
188
189
190
191
void fused_cast_transpose(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                          at::Tensor input_cast, at::Tensor input_transpose,
                          transformer_engine::DType otype);

void fused_cast_transpose_noop(at::Tensor input, at::Tensor noop, at::Tensor scale, at::Tensor amax,
                               at::Tensor scale_inv, at::Tensor input_cast,
                               at::Tensor input_transpose, transformer_engine::DType otype,
                               int scale_offset = 0, int amax_offset = 0, int scale_inv_offset = 0);
Przemek Tredak's avatar
Przemek Tredak committed
192

193
194
195
196
197
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output, at::Tensor scale,
                                                   at::Tensor amax, at::Tensor scale_inv,
                                                   transformer_engine::DType otype,
                                                   int scale_offset = 0, int amax_offset = 0,
                                                   int scale_inv_offset = 0);
Przemek Tredak's avatar
Przemek Tredak committed
198

199
200
std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output, at::Tensor scale,
                                                  at::Tensor amax, at::Tensor scale_inv,
201
202
                                                  transformer_engine::DType otype,
                                                  transformer_engine::DType grad_bias_type,
203
204
                                                  int scale_offset = 0, int amax_offset = 0,
                                                  int scale_inv_offset = 0);
205

Przemek Tredak's avatar
Przemek Tredak committed
206
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
207
208
                                                         at::Tensor gelu_input, at::Tensor scale,
                                                         at::Tensor amax, at::Tensor scale_inv,
209
                                                         transformer_engine::DType otype,
210
211
                                                         int scale_offset = 0, int amax_offset = 0,
                                                         int scale_inv_offset = 0);
Przemek Tredak's avatar
Przemek Tredak committed
212

213
214
215
216
217
218
void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
                                  at::Tensor grad_input_transpose, at::Tensor scale,
                                  at::Tensor amax, at::Tensor scale_inv,
                                  transformer_engine::DType otype, int scale_offset = 0,
                                  int amax_offset = 0, int scale_inv_offset = 0);

Tim Moon's avatar
Tim Moon committed
219
220
221
222
223
224
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                std::vector<at::Tensor> scale_list,
                                std::vector<at::Tensor> cast_output_list,
                                std::vector<at::Tensor> transposed_output_list,
                                std::vector<at::Tensor> amax_output_list,
                                std::vector<at::Tensor> scale_inv_output_list,
225
                                transformer_engine::DType otype);
Tim Moon's avatar
Tim Moon committed
226

227
228
229
230
231
std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>> fused_multi_cast_transpose_alloc(
    std::vector<at::Tensor> input_list, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
    std::vector<int> scale_indices, std::vector<int> amax_indices,
    std::vector<int> scale_inv_indices, transformer_engine::DType otype);

232
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype);
Tim Moon's avatar
Tim Moon committed
233

234
void fp8_transpose_noalloc(at::Tensor input, at::Tensor output, transformer_engine::DType otype);
Przemek Tredak's avatar
Przemek Tredak committed
235

236
237
void fp8_transpose_noalloc_noop(at::Tensor input, at::Tensor output, at::Tensor noop,
                                transformer_engine::DType otype);
238

239
240
241
242
/***************************************************************************************************
 * Activations
 **************************************************************************************************/

// Activation entry points (declarations only).
// Forward functions take `input`, scale/amax/scale_inv tensors and an output dtype, and return a
// new tensor. The d-prefixed functions take `grad` and `input` tensors plus an output dtype.
at::Tensor gelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                transformer_engine::DType otype);

at::Tensor relu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                transformer_engine::DType otype);

at::Tensor geglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                 transformer_engine::DType otype);

at::Tensor reglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                 transformer_engine::DType otype);

at::Tensor swiglu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                  transformer_engine::DType otype);

at::Tensor qgelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                 transformer_engine::DType otype);

at::Tensor srelu(at::Tensor input, at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                 transformer_engine::DType otype);

at::Tensor dgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype);

at::Tensor drelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype);

at::Tensor dgeglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype);

at::Tensor dreglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype);

at::Tensor dswiglu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype);

at::Tensor dqgelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype);

at::Tensor dsrelu(at::Tensor grad, at::Tensor input, transformer_engine::DType otype);

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/
Przemek Tredak's avatar
Przemek Tredak committed
281

282
283
284
285
286
287
288
289
290
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                      const at::Tensor &mu, const at::Tensor &rsigma,
                                      const at::Tensor &gamma, const int sm_margin,
                                      const bool zero_centered_gamma);

std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight,
                                          const at::Tensor &bias, float eps, at::Tensor scale,
                                          at::Tensor amax, at::Tensor scale_inv,
                                          transformer_engine::DType otype, const int sm_margin,
291
                                          const bool zero_centered_gamma,
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
                                          const int scale_offset = 0, const int amax_offset = 0,
                                          const int scale_inv_offset = 0);

std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(
    const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias, float eps,
    at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv,
    transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma,
    const int scale_offset = 0, const int amax_offset = 0, const int scale_inv_offset = 0);

at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight,
                                 const at::Tensor &bias, float eps, at::Tensor scale,
                                 at::Tensor amax, at::Tensor scale_inv,
                                 transformer_engine::DType otype, const int sm_margin,
                                 const bool zero_centered_gamma, const int scale_offset = 0,
                                 const int amax_offset = 0, const int scale_inv_offset = 0);

std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input, const at::Tensor &weight,
                                      const at::Tensor &bias, float eps, const int sm_margin,
                                      const bool zero_centered_gamma);

std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
                                              const at::Tensor &bias, at::Tensor ln_out, float eps,
                                              const int sm_margin, const bool zero_centered_gamma);

at::Tensor layernorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight,
                             const at::Tensor &bias, float eps, const int sm_margin,
                             const bool zero_centered_gamma);
Przemek Tredak's avatar
Przemek Tredak committed
319

320
321
322
323
/***************************************************************************************************
 * RMSNorm
 **************************************************************************************************/

324
325
326
std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                    const at::Tensor &rsigma, const at::Tensor &gamma,
                                    const int sm_margin, const bool zero_centered_gamma);
327

328
329
330
331
332
333
std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input, const at::Tensor &weight,
                                        float eps, at::Tensor scale, at::Tensor amax,
                                        at::Tensor scale_inv, transformer_engine::DType otype,
                                        const int sm_margin, const bool zero_centered_gamma,
                                        const int scale_offset = 0, const int amax_offset = 0,
                                        const int scale_inv_offset = 0);
334

335
336
337
338
339
std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(
    const at::Tensor &input, const at::Tensor &weight, float eps, at::Tensor scale,
    at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype,
    const int sm_margin, const bool zero_centered_gamma, const int scale_offset = 0,
    const int amax_offset = 0, const int scale_inv_offset = 0);
Przemek Tredak's avatar
Przemek Tredak committed
340

341
342
343
344
345
at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input, const at::Tensor &weight, float eps,
                               at::Tensor scale, at::Tensor amax, at::Tensor scale_inv,
                               transformer_engine::DType otype, const int sm_margin,
                               const bool zero_centered_gamma, const int scale_offset = 0,
                               const int amax_offset = 0, const int scale_inv_offset = 0);
Przemek Tredak's avatar
Przemek Tredak committed
346

347
348
std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input, const at::Tensor &weight, float eps,
                                    const int sm_margin, const bool zero_centered_gamma);
349

350
351
352
std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input, const at::Tensor &weight,
                                            at::Tensor ln_out, float eps, const int sm_margin,
                                            const bool zero_centered_gamma);
353

354
355
at::Tensor rmsnorm_fwd_inf(const at::Tensor &input, const at::Tensor &weight, float eps,
                           const int sm_margin, const bool zero_centered_gamma);
356

357
/***************************************************************************************************
358
 * Cast
359
 **************************************************************************************************/
360

361
at::Tensor cast_to_fp8(const at::Tensor &input, const at::Tensor &scale, at::Tensor amax,
362
363
364
                       at::Tensor scale_inv, transformer_engine::DType otype,
                       const int scale_offset = 0, const int amax_offset = 0,
                       const int scale_inv_offset = 0);
365

366
void cast_to_fp8_noalloc(const at::Tensor &input, const at::Tensor &scale, at::Tensor output,
367
368
369
                         at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype,
                         const int scale_offset = 0, const int amax_offset = 0,
                         const int scale_inv_offset = 0);
370

371
at::Tensor cast_from_fp8(const at::Tensor &input, const at::Tensor &scale_inv,
372
373
                         transformer_engine::DType itype, transformer_engine::DType otype,
                         const int scale_inv_offset = 0);
374

375
376
377
/***************************************************************************************************
 * Softmax
 **************************************************************************************************/
378

379
at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);
380

381
382
at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                   float scale_factor);
383

384
at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor);
385

386
387
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                          float scale_factor);
388

389
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor);
390
391
392

at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
393
                                                       float scale_factor);
394

395
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor);
396
397
398

at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
                                                         at::Tensor softmax_results_,
399
                                                         float scale_factor);
400

401
402
403
404
/***************************************************************************************************
 * FP8 recipe
 **************************************************************************************************/

405
406
407
408
409
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                 std::vector<at::Tensor> amax_histories,
                                                 std::vector<at::Tensor> scales,
                                                 std::vector<at::Tensor> scale_invs,
                                                 const std::string &amax_compute_algo,
410
                                                 transformer_engine::DType fp8_dtype, float margin);
411

412
413
414
415
/***************************************************************************************************
 * Rotary positional embedding
 **************************************************************************************************/

416
417
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
                              const bool transpose_output_memory);
418

419
420
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
                               const bool transpose_output_memory);
421

422
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
423
                                  const at::Tensor &freqs, const int cp_size, const int cp_rank);
424

425
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
426
                                   const at::Tensor &freqs, const int cp_size, const int cp_rank);
427
428

/***************************************************************************************************
 * Miscellaneous
 **************************************************************************************************/

// Library version queries (declarations only); both return a size_t.
size_t get_cublasLt_version();

size_t get_cudnn_version();

/***************************************************************************************************
 * Support THD format for Context Parallel
 **************************************************************************************************/

440
441
442
443
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
                                int half_idx);

void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
444
                                    const at::Tensor &cu_seqlens, bool lse_packed);
445
446

at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
447
                                    bool lse_packed, int second_half_lse_seqlen);
448
449
450

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
                        const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
451
                        bool only_second_half, bool lse_packed);
452
453
454
455

void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
                         const at::Tensor &cu_seqlens, const std::string &first_half,
                         const std::string &second_half);
456

457
458
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                       int world_size, int rank);
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474

/***************************************************************************************************
 * multi_tensor_* kernels
 **************************************************************************************************/

// Common arguments of the multi_tensor_* launchers below:
//   chunk_size   - number of elements handled per kernel chunk.
//   noop_flag    - device flag tensor (presumably makes the kernel a no-op / records overflow;
//                  TODO confirm).
//   tensor_lists - parallel lists of tensors processed together.

// Scales the tensors in `tensor_lists` by `scale`.
void multi_tensor_scale_cuda(int chunk_size,
                             at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
                             float scale);

// Computes L2 norms of the tensors in `tensor_lists`. Returns a pair of tensors, presumably the
// global norm and the per-tensor norms (the latter when `per_tensor_python` is true; TODO confirm).
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

// Same as multi_tensor_l2norm_cuda, but the inputs are first multiplied by `inv_scale`
// (presumably an inverse loss scale; TODO confirm).
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale,
    at::optional<bool> per_tensor_python);

475
using transformer_engine::DType;
476
477
478
479
480
481
// Fused multi-tensor Adam step.
//   tensor_lists    - parallel lists of optimizer tensors (presumably grads, params, exp_avg,
//                     exp_avg_sq; TODO confirm order in the .cu).
//   lr, beta1, beta2, epsilon, weight_decay - Adam hyperparameters.
//   step            - current optimizer step (used for bias correction).
//   mode            - weight-decay variant selector (presumably L2 vs. decoupled/AdamW; confirm).
//   bias_correction - nonzero enables bias correction (assumed boolean-as-int; confirm).
// Top-level `const` on by-value parameters is omitted: it is not part of the function type, so
// this declaration still matches a definition that spells it.
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists, float lr,
                            float beta1, float beta2, float epsilon, int step, int mode,
                            int bias_correction, float weight_decay);

// FP8 variant of multi_tensor_adam_cuda. `fp8_dtype` selects the FP8 format of the tensors that
// are stored in FP8 (presumably the parameters; TODO confirm). The other arguments have the same
// meaning as in multi_tensor_adam_cuda. Top-level `const` on by-value parameters is omitted
// because it does not affect the function type.
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                std::vector<std::vector<at::Tensor>> tensor_lists, float lr,
                                float beta1, float beta2, float epsilon, int step, int mode,
                                int bias_correction, float weight_decay, DType fp8_dtype);

// "Capturable" Adam step: `lr` and `step` are passed as tensors rather than host scalars
// (presumably so they can live on device, e.g. under CUDA graph capture; TODO confirm).
// `inv_scale` is presumably a gradient inverse scale. The remaining arguments follow
// multi_tensor_adam_cuda.
void multi_tensor_adam_capturable_cuda(int chunk_size,
                                       at::Tensor noop_flag,
                                       std::vector<std::vector<at::Tensor>> tensor_lists,
                                       at::Tensor lr,
                                       float beta1,
                                       float beta2,
                                       float epsilon,
                                       at::Tensor step,
                                       int mode,
                                       int bias_correction,
                                       float weight_decay,
                                       at::Tensor inv_scale);

// Capturable Adam step that, judging by the name, also maintains master copies of the
// parameters (TODO confirm the extra list in `tensor_lists`). The arguments mirror
// multi_tensor_adam_capturable_cuda.
void multi_tensor_adam_capturable_master_cuda(int chunk_size,
                                              at::Tensor noop_flag,
                                              std::vector<std::vector<at::Tensor>> tensor_lists,
                                              at::Tensor lr,
                                              float beta1,
                                              float beta2,
                                              float epsilon,
                                              at::Tensor step,
                                              int mode,
                                              int bias_correction,
                                              float weight_decay,
                                              at::Tensor inv_scale);

// Fused multi-tensor SGD step with weight decay `wd`, `momentum`, `dampening`, learning rate
// `lr`, and optional Nesterov momentum.
//   first_run         - presumably initializes the momentum buffers on the first step (confirm).
//   wd_after_momentum - applies weight decay after the momentum update instead of before.
//   scale             - presumably a multiplier applied to the gradients (TODO confirm).
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
                           float momentum, float dampening, float lr, bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

/***************************************************************************************************
 * padding
 **************************************************************************************************/

// Copies `input` into `output` group by group, where group i has `input_row_list[i]` rows in the
// input and is presumably padded to `padded_input_row_list[i]` rows in the output. Both lists
// are assumed to have the same length (TODO confirm the padding value and layout in the .cu).
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> input_row_list,
                             std::vector<size_t> padded_input_row_list);

/***************************************************************************************************
 * Comm+GEMM Overlap Wrappers
 **************************************************************************************************/

/*
** Bootstrap helper that exposes c10d process groups to the userbuffers communicator.
** It keeps raw c10d::ProcessGroup pointers keyed by string (ownership is not expressed here;
** presumably non-owning) plus this process's rank layout. Every layout field starts at -1.
*/
class CommOverlapHelper : torch::CustomClassHolder {
 private:
  bool initialized = false;
  bool backend_is_nccl = false;
  std::map<std::string, c10d::ProcessGroup *> pgs;

 public:
  // Rank/size layout; presumably world, intra-node and inter-node respectively.
  int myrank{-1};
  int numranks{-1};
  int mylocal{-1};
  int numlocal{-1};
  int mynode{-1};
  int numnodes{-1};

  CommOverlapHelper();

  // Registers the world group plus optional intra-node / inter-node groups.
  CommOverlapHelper(c10d::ProcessGroup *world_group,
                    std::optional<c10d::ProcessGroup *> intra_node_group,
                    std::optional<c10d::ProcessGroup *> inter_node_group);

  ~CommOverlapHelper();

  // All-gather callback: `localbytes` from `localdata` into `globaldata` (`globalbytes` total)
  // over the group selected by `comm`.
  void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                    ExtComm comm);

  // Barrier callback over the group selected by `comm`.
  void ub_barrier(ExtComm comm);
};

/*
** PyTorch-facing wrapper around transformer_engine::CommOverlapBase (bulk and split GEMM +
** communication overlap). The two tensor members presumably alias the userbuffer memory and
** its atomic-GEMM counters (NOTE(review): confirm in the .cpp).
*/
class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
 private:
  torch::Tensor _ubuf_torch;
  torch::Tensor _ubuf_counter;

 public:
  /*
  ** `buffer_shape`/`buffer_dtype` describe the communication buffer and `helper` provides the
  ** process-group bootstrap. The remaining arguments tune the overlap pipeline; their defaults
  ** are 3 splits, comm_cga_size 2, and 16 communication SMs.
  */
  CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
              CommOverlapHelper *helper, int tp_size, int num_splits = 3,
              int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
              int num_comm_sm = 16, bool set_sm_margin = true, bool atomic_gemm = false);

  /*
  ** Hands a float32 scale-inverse pointer to the base class.
  ** Inputs are validated with TORCH_CHECK instead of assert: assert is compiled out under
  ** NDEBUG, which would let an empty or non-float32 tensor be reinterpreted as float*.
  */
  void set_ubuf_scale_inv(torch::Tensor scale_inv) {
    TORCH_CHECK(scale_inv.numel() > 0, "CommOverlap::set_ubuf_scale_inv: scale_inv is empty");
    TORCH_CHECK(scale_inv.scalar_type() == torch::kFloat32,
                "CommOverlap::set_ubuf_scale_inv: expected a float32 scale_inv, got ",
                scale_inv.scalar_type());
    transformer_engine::CommOverlapBase::set_ubuf_scale_inv(scale_inv.data_ptr<float>());
  }

  void copy_input_to_ubuf(torch::Tensor input, int comm_type);

  torch::Tensor get_ubuf_output(int comm_type);

  /*
  ** Bulk GEMM + COMM
  ** This function assumes the communication input is pre-copied to _ubuf
  */
  std::vector<at::Tensor> bulk_overlap(
      at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
      transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
      int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
      at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
      transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
      size_t workspaceSize, bool accumulate, bool use_split_accumulator,
      transformer_engine::CommOverlapType comm_type, at::Tensor rs_output);

  /*
  ** Atomic FPROP GEMM + ReduceScatter
  */
  void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
                              transformer_engine::DType A_type, bool transa, at::Tensor B,
                              at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
                              transformer_engine::DType B_type, bool transb, at::Tensor D,
                              at::Tensor D_scale, transformer_engine::DType D_type,
                              at::Tensor D_amax, at::Tensor bias,
                              transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
                              bool grad, at::Tensor workspace, size_t workspaceSize,
                              bool accumulate, bool use_split_accumulator, bool gemm_overlap,
                              at::Tensor rs_output);

  /*
  ** Split FPROP GEMM + ReduceScatter
  */
  void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
                        transformer_engine::DType A_type, bool transa, at::Tensor B,
                        at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
                        transformer_engine::DType B_type, bool transb, at::Tensor D,
                        at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
                        at::Tensor bias, transformer_engine::DType bias_type,
                        at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
                        size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                        bool gemm_overlap, at::Tensor rs_output);
};  // CommOverlap

/*
** PyTorch-facing wrapper around transformer_engine::CommOverlapP2PBase (ring / point-to-point
** GEMM + communication overlap). The two tensor members presumably alias the userbuffer memory
** and its atomic-GEMM counters (NOTE(review): confirm in the .cpp).
*/
class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
 private:
  torch::Tensor _ubuf_torch;
  torch::Tensor _ubuf_counter;

 public:
  /*
  ** `comm_type` fixes the collective this instance overlaps. Defaults: comm_cga_size 2,
  ** 3 communication SMs, use_ce = true, aggregate = false.
  */
  CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
                 CommOverlapHelper *helper, int tp_size,
                 transformer_engine::CommOverlapType comm_type,
                 int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
                 int num_comm_sm = 3, bool set_sm_margin = true, bool atomic_gemm = false,
                 bool use_ce = true, bool aggregate = false);

  /*
  ** Hands a float32 scale-inverse pointer to the base class.
  ** Inputs are validated with TORCH_CHECK instead of assert: assert is compiled out under
  ** NDEBUG, which would let an empty or non-float32 tensor be reinterpreted as float*.
  */
  void set_ubuf_scale_inv(torch::Tensor scale_inv) {
    TORCH_CHECK(scale_inv.numel() > 0,
                "CommOverlapP2P::set_ubuf_scale_inv: scale_inv is empty");
    TORCH_CHECK(scale_inv.scalar_type() == torch::kFloat32,
                "CommOverlapP2P::set_ubuf_scale_inv: expected a float32 scale_inv, got ",
                scale_inv.scalar_type());
    transformer_engine::CommOverlapP2PBase::set_ubuf_scale_inv(scale_inv.data_ptr<float>());
  }

  void copy_input_to_ubuf(torch::Tensor input, bool chunk);

  torch::Tensor get_ubuf_output(int comm_type);

  /*
  ** Split AllGather + AtomicGEMM using P2P communication
  ** This function assumes input B is pre-copied to _ubufs[rank_id]. This is needed so that
  ** the AG outputs of every rank end up in contiguous memory after all ring exchange phases.
  */
  void atomic_gemm_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
                              transformer_engine::DType A_type, bool transa, at::Tensor B,
                              at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
                              transformer_engine::DType B_type, bool transb, at::Tensor D,
                              at::Tensor D_scale, transformer_engine::DType D_type,
                              at::Tensor D_amax, at::Tensor bias,
                              transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
                              bool grad, at::Tensor workspace, size_t workspaceSize,
                              bool accumulate, bool use_split_accumulator, at::Tensor B_copy);

  /*
  ** Split AllGather + GEMM using P2P communication
  ** This function assumes input B is pre-copied to _ubufs[rank_id]. This is needed so that
  ** the AG outputs of every rank end up in contiguous memory after all ring exchange phases.
  */
  void split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
                        transformer_engine::DType A_type, bool transa, at::Tensor B,
                        at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
                        transformer_engine::DType B_type, bool transb, at::Tensor D,
                        at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
                        at::Tensor bias, transformer_engine::DType bias_type,
                        at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
                        size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                        at::Tensor B_copy);

  /*
  ** Atomic GEMM + ReduceScatter using P2P communication
  */
  void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
                              transformer_engine::DType A_type, bool transa, at::Tensor B,
                              at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
                              transformer_engine::DType B_type, bool transb, at::Tensor D,
                              at::Tensor D_scale, transformer_engine::DType D_type,
                              at::Tensor D_amax, at::Tensor bias,
                              transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
                              bool grad, at::Tensor workspace, size_t workspaceSize,
                              bool accumulate, bool use_split_accumulator, at::Tensor rs_output);

  /*
  ** Split GEMM + ReduceScatter using P2P communication
  */
  void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
                        transformer_engine::DType A_type, bool transa, at::Tensor B,
                        at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
                        transformer_engine::DType B_type, bool transb, at::Tensor D,
                        at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
                        at::Tensor bias, transformer_engine::DType bias_type,
                        at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
                        size_t workspaceSize, bool accumulate, bool use_split_accumulator,
                        at::Tensor rs_output);
};  // CommOverlapP2P

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_