/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_

#include <map>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "common.h"

// Forward declarations of the Comm+GEMM overlap wrapper classes. They are defined at the bottom of
// this header, at global scope (outside transformer_engine::pytorch), and are referenced by the
// declarations inside the namespace below.
class CommOverlapHelper;
class CommOverlap;
class CommOverlapP2P;

namespace transformer_engine::pytorch {

20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
/***************************************************************************************************
 * Router fusion
 **************************************************************************************************/

std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
    at::Tensor logits, int topk, bool use_pre_softmax, c10::optional<int> num_groups,
    c10::optional<int> group_topk, c10::optional<float> scaling_factor, std::string score_function,
    c10::optional<at::Tensor> expert_bias);

at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts,
                                              at::Tensor routing_map,
                                              at::Tensor intermediate_output, at::Tensor grad_probs,
                                              int topk, bool use_pre_softmax,
                                              c10::optional<float> scaling_factor,
                                              std::string score_function);

std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
    at::Tensor logits, int topk, std::string score_function);

at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
                                            at::Tensor intermediate_output, at::Tensor grad_probs,
                                            int topk, std::string score_function);

std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
                                                          at::Tensor tokens_per_expert,
45
46
47
                                                          int total_num_tokens, int num_experts,
                                                          int num_rows, int num_cols, int topk,
                                                          float coeff);
48

49
50
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows,
                                  int num_cols, at::Tensor grad_aux_loss);
51

52
53
54
55
56
/***************************************************************************************************
 * Permutation
 **************************************************************************************************/

std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
57
58
    at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
    std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);
59

60
61
at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
                           at::Tensor prob, int64_t num_tokens, int64_t topK);
62

63
64
at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
                             at::Tensor prob, int64_t num_tokens, int64_t topK);
65
66

std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
67
68
                                                     const DType dtype, at::Tensor row_id_map,
                                                     at::Tensor prob);
69

70
71
72
73
/***************************************************************************************************
 * Attention
 **************************************************************************************************/

74
75
76
77
78
NVTE_Fused_Attn_Backend get_fused_attn_backend(
    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
cyanguwa's avatar
cyanguwa committed
79

80
std::vector<py::object> fused_attn_fwd(
81
82
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
83
    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
84
85
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
86
87
88
89
90
    const std::optional<at::Tensor> cu_seqlens_q_padded,
    const std::optional<at::Tensor> cu_seqlens_kv_padded,
    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
91

92
std::vector<py::object> fused_attn_bwd(
93
94
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
95
    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
96
    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
97
98
    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
    const std::vector<at::Tensor> Aux_CTX_Tensors,
99
100
    const std::optional<at::Tensor> cu_seqlens_q_padded,
    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
101
    py::handle dp_quantizer, py::handle dqkv_quantizer);
Przemek Tredak's avatar
Przemek Tredak committed
102

103
104
105
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

106
107
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);
108
109
110
111
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
                      at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
                      NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len,
                      int max_pages_per_seq, bool is_non_paged);
112

113
114
115
116
/***************************************************************************************************
 * GEMM
 **************************************************************************************************/

117
using MaybeTensor = std::optional<at::Tensor>;
118

119
120
121
122
123
124
std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D,
                             py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                             DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                             at::Tensor workspace, size_t workspaceSize, bool accumulate,
                             bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
                             std::optional<CommOverlapType> comm_type = std::nullopt,
Jan Bielak's avatar
Jan Bielak committed
125
126
                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false,
                             float alpha = 1.0f, std::optional<float> beta = std::nullopt);
127
128

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
129
                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
130
131
132
133
                    at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
                    bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax,
                    at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad,
                    at::Tensor workspace, size_t workspaceSize, bool accumulate,
134
135
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, at::Tensor counter);
Przemek Tredak's avatar
Przemek Tredak committed
136

137
138
std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
139
140
141
142
    std::optional<std::vector<at::Tensor>> D, DType D_type, std::vector<int64_t> m_splits,
    std::vector<at::Tensor> bias, DType bias_type, bool single_output,
    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);
143

#ifdef __HIP_PLATFORM_AMD__
// ROCm-only batched GEMM entry points. Parameters follow `gemm` above (without `alpha`/`beta`),
// plus `batch_count`.
std::vector<py::object> generic_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::object D, int batch_count,
    py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias, DType bias_type,
    bool gelu, MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize,
    bool accumulate, bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
    std::optional<CommOverlapType> comm_type = std::nullopt,
    MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

// Int8 variant of generic_batchgemm that additionally takes the `A_scales`/`B_scales` handles.
std::vector<py::object> tensorwise_int8_batchgemm(
    py::handle A, bool transa, py::handle B, bool transb, py::handle A_scales,
    py::handle B_scales, py::object D, int batch_count, py::handle quantizer,
    std::optional<DType> out_dtype, MaybeTensor bias, DType bias_type, bool gelu,
    MaybeTensor gelu_in, bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
    std::optional<CommOverlapType> comm_type = std::nullopt,
    MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

// Batched GEMM over tensor lists with per-operand offsets into the scale-inverse tensors.
void te_batchgemm(std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int A_offset,
                  transformer_engine::DType A_type, bool transa, std::vector<at::Tensor> B,
                  at::Tensor B_scale_inverse, int B_offset, transformer_engine::DType B_type,
                  bool transb, std::vector<at::Tensor> D, int D_offset, at::Tensor D_scale,
                  transformer_engine::DType D_type, at::Tensor D_amax,
                  std::vector<at::Tensor> bias, transformer_engine::DType bias_type,
                  std::vector<at::Tensor> pre_gelu_out, bool grad,
                  std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
                  bool use_split_accumulator, int math_sm_count);

// Variant of te_batchgemm taking int64_t in place of the enum/bool/int arguments (the `_ts`
// suffix presumably means TorchScript -- confirm) and returning a vector of tensors.
std::vector<at::Tensor> te_batchgemm_ts(
    std::vector<at::Tensor> A, at::Tensor A_scale_inverse, int64_t A_offset, int64_t A_type,
    int64_t transa, std::vector<at::Tensor> B, at::Tensor B_scale_inverse, int64_t B_offset,
    int64_t B_type, int64_t transb, std::vector<at::Tensor> D, int64_t D_offset, at::Tensor D_scale,
    int64_t D_type, at::Tensor D_amax, std::vector<at::Tensor> bias, int64_t bias_type,
    std::vector<at::Tensor> pre_gelu_out, int64_t grad, std::vector<at::Tensor> workspace,
    int64_t workspaceSize, int64_t accumulate, int64_t use_split_accumulator);
#endif  // __HIP_PLATFORM_AMD__

/***************************************************************************************************
 * Transpose
 **************************************************************************************************/

184
at::Tensor fp8_transpose(at::Tensor input, DType otype,
185
186
                         std::optional<at::Tensor> output = std::nullopt);

187
188
at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = std::nullopt);

189
190
191
192
/***************************************************************************************************
 * Activations
 **************************************************************************************************/

// Every forward activation takes `input` and a `quantizer` handle and returns a py::object.
// Each backward variant (prefixed `d`) additionally takes the incoming `grad`.

/* GELU and variants */
py::object gelu(const at::Tensor &input, py::handle quantizer);

py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object geglu(const at::Tensor &input, py::handle quantizer);

py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object qgelu(const at::Tensor &input, py::handle quantizer);

py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object qgeglu(const at::Tensor &input, py::handle quantizer);

py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/* ReLU and variants */
py::object relu(const at::Tensor &input, py::handle quantizer);

py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object reglu(const at::Tensor &input, py::handle quantizer);

py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object srelu(const at::Tensor &input, py::handle quantizer);

py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object sreglu(const at::Tensor &input, py::handle quantizer);

py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/* SiLU and variants */
py::object silu(const at::Tensor &input, py::handle quantizer);

py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object swiglu(const at::Tensor &input, py::handle quantizer);

py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/
Przemek Tredak's avatar
Przemek Tredak committed
239

240
std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
241
242
243
244
                                      const at::Tensor &mu, const at::Tensor &rsigma,
                                      const at::Tensor &gamma, const int sm_margin,
                                      const bool zero_centered_gamma);

245
246
std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
                                      float eps, py::object ln_out, py::handle quantizer,
247
                                      DType out_dtype, const int sm_margin,
248
249
                                      const bool zero_centered_gamma);

250
251
252
253
/***************************************************************************************************
 * RMSNorm
 **************************************************************************************************/

254
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
255
256
                                    const at::Tensor &rsigma, const at::Tensor &gamma,
                                    const int sm_margin, const bool zero_centered_gamma);
257

258
259
260
261
262
std::vector<py::object> rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x,
                                        const at::Tensor &add, const at::Tensor &rsigma,
                                        const at::Tensor &gamma, const int sm_margin,
                                        const bool zero_centered_gamma);

263
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
264
265
                                    py::object ln_out, py::handle quantizer, DType otype,
                                    const int sm_margin, const bool zero_centered_gamma);
266
267
268
269

/***************************************************************************************************
 * Cast
 **************************************************************************************************/
270

271
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
272
                    std::optional<at::Tensor> noop_flag);
273

274
py::object dequantize(const py::handle &input, DType otype);
275

276
277
278
279
280
281
282
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
                                              std::vector<py::handle> quantizer_list);

std::vector<py::object> split_quantize(const at::Tensor &tensor,
                                       const std::vector<int> &split_sections,
                                       std::vector<py::handle> quantizer_list);

283
/***************************************************************************************************
284
 * Bias gradient fusions
285
 **************************************************************************************************/
286

287
288
std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);

289
290
291
292
293
294
295
296
297
298
299
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);
300

301
302
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);
303

vasunvidia's avatar
vasunvidia committed
304
305
306
307
308
309
310
311
312
313
314
/***************************************************************************************************
 * Dropout
 **************************************************************************************************/

// Dropout forward; `out` is an optional tensor argument (defaults to std::nullopt).
std::vector<py::object> dropout_fwd(const py::handle &input, const float dropout_probability,
                                    std::optional<at::Tensor> out = std::nullopt);

// Dropout backward from `grad_output` and the forward `mask`; `grad_input` is an optional tensor
// argument (defaults to std::nullopt).
py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask,
                       const float dropout_probability,
                       std::optional<at::Tensor> grad_input = std::nullopt);

/***************************************************************************************************
 * Softmax
 **************************************************************************************************/
318

319
at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);
320

321
322
at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                   float scale_factor);
323

324
at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor);
325

326
327
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                          float scale_factor);
328

329
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor);
330
331
332

at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
333
                                                       float scale_factor);
334

335
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor);
336
337
338

at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
                                                         at::Tensor softmax_results_,
339
                                                         float scale_factor);
340

341
342
343
344
/***************************************************************************************************
 * FP8 recipe
 **************************************************************************************************/

345
346
void compute_amax(const at::Tensor &tensor, at::Tensor &amax);

347
348
void compute_channel_colwise_amax(const at::Tensor &tensor, at::Tensor &amax, at::Tensor &fp8_scale);

349
350
351
352
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                 std::vector<at::Tensor> amax_histories,
                                                 std::vector<at::Tensor> scales,
                                                 const std::string &amax_compute_algo,
353
                                                 DType fp8_dtype, float margin);
354

355
356
357
358
359
360
361
// Note that the start_offset is the logical offset along the tensor dimension.
// The offset in bytes is start_offset * sizeof(tensor.dtype)
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                            size_t w, size_t start_offset, size_t block_len);

void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                    size_t h, size_t w, size_t start_offset, size_t block_len,
362
                                    const DType out_dtype);
363

364
365
366
367
/***************************************************************************************************
 * Rotary positional embedding
 **************************************************************************************************/

368
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
Sudhakar Singh's avatar
Sudhakar Singh committed
369
                              const std::optional<at::Tensor> start_positions,
370
                              const NVTE_QKV_Format qkv_format, const bool interleaved,
371
                              const std::optional<at::Tensor> cu_seqlens, const int cp_size,
372
                              const int cp_rank);
373

374
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
375
                               const NVTE_QKV_Format qkv_format, const bool interleaved,
376
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
377
                               const int cp_rank);
378

379
380
381
382
383
384
385
386
387
388
389
390
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
    const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
    const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank);

at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
                                   const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
                                   const at::Tensor &k_freqs,
                                   const std::vector<int> &qkv_split_arg_list,
                                   const NVTE_QKV_Format qkv_format, const bool interleaved,
                                   const int cp_size, const int cp_rank);

/***************************************************************************************************
 * Miscellaneous
 **************************************************************************************************/

// Library version queries, each returned as a size_t.
size_t get_cublasLt_version();

size_t get_cudnn_version();

/***************************************************************************************************
 * Support THD format for Context Parallel
 **************************************************************************************************/

403
404
405
406
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
                                int half_idx);

void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
407
                                    const at::Tensor &cu_seqlens, bool lse_packed);
408
409

at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
410
                                    bool lse_packed, int second_half_lse_seqlen);
411
412
413

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
                        const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
414
                        bool only_second_half, bool lse_packed);
415
416
417
418

void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
                         const at::Tensor &cu_seqlens, const std::string &first_half,
                         const std::string &second_half);
419

420
421
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                       int world_size, int rank);
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443

/***************************************************************************************************
 * multi_tensor_* kernels
 **************************************************************************************************/

// All multi_tensor_* entry points take a `chunk_size`, a `noop_flag` tensor and a list of tensor
// lists to operate on.

void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists, float scale);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale, at::optional<bool> per_tensor_python);

void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                            const float beta1, const float beta2, const float epsilon,
                            const int step, const int mode, const int bias_correction,
                            const float weight_decay);

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
                                            std::vector<std::vector<at::Tensor>> tensor_lists,
                                            const float lr, const float beta1, const float beta2,
                                            const float epsilon, const int step, const int mode,
                                            const int bias_correction, const float weight_decay);

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                                const float beta1, const float beta2, const float epsilon,
                                const int step, const int mode, const int bias_correction,
                                const float weight_decay, DType fp8_dtype);

// "Capturable" Adam variants take `lr` and `step` as tensors rather than scalars, plus an
// `inv_scale` tensor.
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
                                       std::vector<std::vector<at::Tensor>> tensor_lists,
                                       at::Tensor lr, const float beta1, const float beta2,
                                       const float epsilon, at::Tensor step, const int mode,
                                       const int bias_correction, const float weight_decay,
                                       at::Tensor inv_scale);

void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
                                              std::vector<std::vector<at::Tensor>> tensor_lists,
                                              at::Tensor lr, const float beta1, const float beta2,
                                              const float epsilon, at::Tensor step, const int mode,
                                              const int bias_correction, const float weight_decay,
                                              at::Tensor inv_scale);

void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
                           float momentum, float dampening, float lr, bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

void multi_tensor_compute_scale_and_scale_inv_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    float max_fp8, bool force_pow_2_scales, float epsilon);

/***************************************************************************************************
 * padding
 **************************************************************************************************/

void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> input_row_list,
                             std::vector<size_t> padded_input_row_list);

487
488
489
void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
                               std::vector<size_t> input_row_list,
                               std::vector<size_t> unpadded_input_row_list);
490
/***************************************************************************************************
491
 * NVSHMEM APIs
492
493
 **************************************************************************************************/

494
495
void init_nvshmem_backend(c10d::ProcessGroup *process_group);

496
at::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
497

498
void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);
499

500
void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
501
502
503

void nvshmem_finalize();

504
505
506
507
508
509
510
/***************************************************************************************************
 * Comm+GEMM Overlap Wrappers
 **************************************************************************************************/

// Bulk all-gather overlap for a GEMM run outside the communicator, using the given send and
// receive streams. `CommOverlap` here is the global-scope class declared at the bottom of this
// header.
void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream,
                                        at::Stream recv_stream);

}  // namespace transformer_engine::pytorch
512

513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
/***************************************************************************************************
 * Comm+GEMM Overlap Wrappers
 **************************************************************************************************/

class CommOverlapHelper : torch::CustomClassHolder {
 private:
  bool initialized{false};
  bool backend_is_nccl{false};
  std::map<std::string, c10d::ProcessGroup *> pgs;

 public:
  int myrank = -1;
  int numranks = -1;
  int mylocal = -1;
  int numlocal = -1;
  int mynode = -1;
  int numnodes = -1;

  CommOverlapHelper();

  CommOverlapHelper(c10d::ProcessGroup *world_group,
534
                    std::optional<c10d::ProcessGroup *> intra_node_group);
535
536
537
538
539
540
541
542
543
544
545
546
547
548

  ~CommOverlapHelper();

  void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                    ExtComm comm);

  void ub_barrier(ExtComm comm);
};

// PyTorch-facing wrapper around transformer_engine::CommOverlapBase. Defaults: 3 splits,
// NVTE_COMM_OVERLAP_MAX_STREAMS streams, CGA size 2, priorities 0, 16 comm SMs, SM margin on,
// atomic GEMM off, rs_overlap_first_gemm off.
class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
 public:
  CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
              CommOverlapHelper *helper, int tp_size, int num_splits = 3,
              int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
              int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
              bool set_sm_margin = true, bool atomic_gemm = false,
              bool rs_overlap_first_gemm = false);

  ~CommOverlap() {}

  // Copies `input` into the communication buffer (the local chunk when `local_chunk` is true).
  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);

  // Returns the communication buffer as a tensor, optionally viewed with `shape`.
  at::Tensor get_buffer(bool local_chunk = false,
                        std::optional<std::vector<int64_t>> shape = std::nullopt);

  // Returns a pair of streams used for communication.
  std::pair<at::Stream, at::Stream> get_communication_stream();
};  // CommOverlap

// PyTorch-facing wrapper around transformer_engine::CommOverlapP2PBase. `comm_type` is required.
// Defaults: NVTE_COMM_OVERLAP_MAX_STREAMS streams, CGA size 2, priorities 0, 3 comm SMs,
// SM margin on, atomic GEMM off, use_ce on, aggregate off.
class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
 public:
  CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
                 CommOverlapHelper *helper, int tp_size,
                 transformer_engine::CommOverlapType comm_type,
                 int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
                 int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3,
                 bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
                 bool aggregate = false);

  ~CommOverlapP2P() {}

  // Copies `input` into the communication buffer (the local chunk when `local_chunk` is true).
  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);

  // Returns the communication buffer as a tensor, optionally viewed with `shape`.
  at::Tensor get_buffer(bool local_chunk = false,
                        std::optional<std::vector<int64_t>> shape = std::nullopt);

  // Returns a pair of streams used for communication.
  std::pair<at::Stream, at::Stream> get_communication_stream();
};  // CommOverlapP2P

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_