/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
9
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_

10
11
#include <optional>

Przemek Tredak's avatar
Przemek Tredak committed
12
13
#include "common.h"

14
15
namespace transformer_engine::pytorch {

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
/***************************************************************************************************
 * Router fusion
 **************************************************************************************************/

std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_topk_with_score_function_fwd(
    at::Tensor logits, int topk, bool use_pre_softmax, c10::optional<int> num_groups,
    c10::optional<int> group_topk, c10::optional<float> scaling_factor, std::string score_function,
    c10::optional<at::Tensor> expert_bias);

at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts,
                                              at::Tensor routing_map,
                                              at::Tensor intermediate_output, at::Tensor grad_probs,
                                              int topk, bool use_pre_softmax,
                                              c10::optional<float> scaling_factor,
                                              std::string score_function);

std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_score_for_moe_aux_loss_fwd(
    at::Tensor logits, int topk, std::string score_function);

at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts,
                                            at::Tensor intermediate_output, at::Tensor grad_probs,
                                            int topk, std::string score_function);

std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
                                                          at::Tensor tokens_per_expert,
41
42
43
                                                          int total_num_tokens, int num_experts,
                                                          int num_rows, int num_cols, int topk,
                                                          float coeff);
44

45
46
at::Tensor fused_moe_aux_loss_bwd(at::Tensor Const_buf, at::Tensor tokens_per_expert, int num_rows,
                                  int num_cols, at::Tensor grad_aux_loss);
47

48
49
50
51
52
/***************************************************************************************************
 * Permutation
 **************************************************************************************************/

std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
53
54
    at::Tensor input, const DType dtype, at::Tensor indices, int64_t num_out_tokens,
    std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);
55

56
57
at::Tensor moe_permute_bwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
                           at::Tensor prob, int64_t num_tokens, int64_t topK);
58

59
60
at::Tensor moe_unpermute_fwd(at::Tensor input, const DType dtype, at::Tensor row_id_map,
                             at::Tensor prob, int64_t num_tokens, int64_t topK);
61
62

std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
63
64
                                                     const DType dtype, at::Tensor row_id_map,
                                                     at::Tensor prob);
65

66
67
68
69
/***************************************************************************************************
 * Attention
 **************************************************************************************************/

70
71
72
73
74
NVTE_Fused_Attn_Backend get_fused_attn_backend(
    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
cyanguwa's avatar
cyanguwa committed
75

76
std::vector<py::object> fused_attn_fwd(
77
78
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
79
    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
80
81
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
82
83
84
85
86
    const std::optional<at::Tensor> cu_seqlens_q_padded,
    const std::optional<at::Tensor> cu_seqlens_kv_padded,
    const std::optional<at::Tensor> page_table_k, const std::optional<at::Tensor> page_table_v,
    py::handle s_quantizer, py::handle o_quantizer, const std::optional<at::Tensor> Bias,
    const std::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
87

88
std::vector<py::object> fused_attn_bwd(
89
90
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
91
    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
92
    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
93
94
    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type,
    const std::vector<at::Tensor> Aux_CTX_Tensors,
95
96
    const std::optional<at::Tensor> cu_seqlens_q_padded,
    const std::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
97
    py::handle dp_quantizer, py::handle dqkv_quantizer);
Przemek Tredak's avatar
Przemek Tredak committed
98

99
100
101
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

102
103
at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, int max_seq_len);
at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t);
104
105
106
107
void copy_to_kv_cache(at::Tensor new_k, at::Tensor new_v, at::Tensor k_cache, at::Tensor v_cache,
                      at::Tensor page_table, at::Tensor cu_new_lens, at::Tensor cu_cached_lens,
                      NVTE_QKV_Format kv_format, int b, int max_ctx_len, int max_seq_len,
                      int max_pages_per_seq, bool is_non_paged);
108

109
110
111
112
/***************************************************************************************************
 * GEMM
 **************************************************************************************************/

113
using MaybeTensor = std::optional<at::Tensor>;
114

115
116
117
118
119
120
121
122
123
std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D,
                             py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                             DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                             at::Tensor workspace, size_t workspaceSize, bool accumulate,
                             bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
                             std::optional<CommOverlapType> comm_type = std::nullopt,
                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type,
124
                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
125
126
127
128
                    at::Tensor B_scale_inverse, DType B_type, std::vector<int64_t> B_scaling_mode,
                    bool transb, at::Tensor D, at::Tensor D_scale, DType D_type, at::Tensor D_amax,
                    at::Tensor bias, DType bias_type, at::Tensor pre_gelu_out, bool grad,
                    at::Tensor workspace, size_t workspaceSize, bool accumulate,
129
130
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, at::Tensor counter);
Przemek Tredak's avatar
Przemek Tredak committed
131

132
133
std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
134
135
136
137
    std::optional<std::vector<at::Tensor>> D, DType D_type, std::vector<int64_t> m_splits,
    std::vector<at::Tensor> bias, DType bias_type, bool single_output,
    std::vector<at::Tensor> pre_gelu_out, bool grad, std::vector<at::Tensor> workspace,
    size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count);
138

139
140
141
142
/***************************************************************************************************
 * Transpose
 **************************************************************************************************/

143
at::Tensor fp8_transpose(at::Tensor input, DType otype,
144
145
                         std::optional<at::Tensor> output = std::nullopt);

146
147
at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = std::nullopt);

148
149
150
151
/***************************************************************************************************
 * Activations
 **************************************************************************************************/

// Activation forwards: each takes the input tensor and a quantizer handle and returns a
// py::object (presumably quantized according to `quantizer` — confirm).

py::object gelu(const at::Tensor &input, py::handle quantizer);

py::object relu(const at::Tensor &input, py::handle quantizer);

py::object geglu(const at::Tensor &input, py::handle quantizer);

py::object qgeglu(const at::Tensor &input, py::handle quantizer);

py::object reglu(const at::Tensor &input, py::handle quantizer);

py::object swiglu(const at::Tensor &input, py::handle quantizer);

py::object qgelu(const at::Tensor &input, py::handle quantizer);

py::object srelu(const at::Tensor &input, py::handle quantizer);

// Activation backwards (d-prefixed): each additionally takes the incoming gradient `grad`.

py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/
Przemek Tredak's avatar
Przemek Tredak committed
187

188
std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
189
190
191
192
                                      const at::Tensor &mu, const at::Tensor &rsigma,
                                      const at::Tensor &gamma, const int sm_margin,
                                      const bool zero_centered_gamma);

193
194
std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
                                      float eps, py::object ln_out, py::handle quantizer,
195
                                      DType out_dtype, const int sm_margin,
196
197
                                      const bool zero_centered_gamma);

198
199
200
201
/***************************************************************************************************
 * RMSNorm
 **************************************************************************************************/

202
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
203
204
                                    const at::Tensor &rsigma, const at::Tensor &gamma,
                                    const int sm_margin, const bool zero_centered_gamma);
205

206
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
207
208
                                    py::object ln_out, py::handle quantizer, DType otype,
                                    const int sm_margin, const bool zero_centered_gamma);
209
210
211
212

/***************************************************************************************************
 * Cast
 **************************************************************************************************/
213

214
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
215
                    std::optional<at::Tensor> noop_flag);
216

217
py::object dequantize(const py::handle &input, DType otype);
218

219
220
221
222
223
224
225
std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
                                              std::vector<py::handle> quantizer_list);

std::vector<py::object> split_quantize(const at::Tensor &tensor,
                                       const std::vector<int> &split_sections,
                                       std::vector<py::handle> quantizer_list);

226
/***************************************************************************************************
227
 * Bias gradient fusions
228
 **************************************************************************************************/
229

230
231
std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);

232
233
234
235
236
237
238
239
240
241
242
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);
243

244
245
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);
246

247
248
249
/***************************************************************************************************
 * Softmax
 **************************************************************************************************/
250

251
at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);
252

253
254
at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                   float scale_factor);
255

256
at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor);
257

258
259
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                          float scale_factor);
260

261
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor);
262
263
264

at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
265
                                                       float scale_factor);
266

267
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor);
268
269
270

at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
                                                         at::Tensor softmax_results_,
271
                                                         float scale_factor);
272

273
274
275
276
/***************************************************************************************************
 * FP8 recipe
 **************************************************************************************************/

277
278
void compute_amax(const at::Tensor &tensor, at::Tensor &amax);

279
280
281
282
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                 std::vector<at::Tensor> amax_histories,
                                                 std::vector<at::Tensor> scales,
                                                 const std::string &amax_compute_algo,
283
                                                 DType fp8_dtype, float margin);
284

285
286
287
288
289
290
291
// Note that the start_offset is the logical offset along the tensor dimension.
// The offset in bytes is start_offset * sizeof(tensor.dtype)
void fp8_block_scaling_compute_partial_amax(const at::Tensor &tensor, at::Tensor amax, size_t h,
                                            size_t w, size_t start_offset, size_t block_len);

void fp8_block_scaling_partial_cast(const at::Tensor &inp, at::Tensor out, const at::Tensor &scale,
                                    size_t h, size_t w, size_t start_offset, size_t block_len,
292
                                    const DType out_dtype);
293

294
295
296
297
/***************************************************************************************************
 * Rotary positional embedding
 **************************************************************************************************/

298
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
Sudhakar Singh's avatar
Sudhakar Singh committed
299
                              const std::optional<at::Tensor> start_positions,
300
                              const NVTE_QKV_Format qkv_format, const bool interleaved,
301
                              const std::optional<at::Tensor> cu_seqlens, const int cp_size,
302
                              const int cp_rank);
303

304
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
305
                               const NVTE_QKV_Format qkv_format, const bool interleaved,
306
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
307
                               const int cp_rank);
308
309

/***************************************************************************************************
 * Miscellaneous
 **************************************************************************************************/

// Library version queries, returned as size_t.
size_t get_cublasLt_version();

size_t get_cudnn_version();

/***************************************************************************************************
 * Support THD format for Context Parallel
 **************************************************************************************************/

321
322
323
324
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
                                int half_idx);

void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
325
                                    const at::Tensor &cu_seqlens, bool lse_packed);
326
327

at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
328
                                    bool lse_packed, int second_half_lse_seqlen);
329
330
331

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
                        const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
332
                        bool only_second_half, bool lse_packed);
333
334
335
336

void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
                         const at::Tensor &cu_seqlens, const std::string &first_half,
                         const std::string &second_half);
337

338
339
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                       int world_size, int rank);
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361

/***************************************************************************************************
 * multi_tensor_* kernels
 **************************************************************************************************/

// Kernels that operate on several lists of tensors at once. Each processes `tensor_lists` in
// chunks of `chunk_size` and receives a `noop_flag` tensor.
// NOTE(review): the per-index meaning of `tensor_lists` depends on the kernel — confirm in the
// implementations.

void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists, float scale);

// Returns two tensors; `per_tensor_python` is optional.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale, at::optional<bool> per_tensor_python);

// Adam optimizer step variants. The plain forms take `lr` and `step` as host scalars.
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                            const float beta1, const float beta2, const float epsilon,
                            const int step, const int mode, const int bias_correction,
                            const float weight_decay);

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
                                            std::vector<std::vector<at::Tensor>> tensor_lists,
                                            const float lr, const float beta1, const float beta2,
                                            const float epsilon, const int step, const int mode,
                                            const int bias_correction, const float weight_decay);

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                                const float beta1, const float beta2, const float epsilon,
                                const int step, const int mode, const int bias_correction,
                                const float weight_decay, DType fp8_dtype);

// The *_capturable variants take `lr` and `step` as tensors and also take `inv_scale`.
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
                                       std::vector<std::vector<at::Tensor>> tensor_lists,
                                       at::Tensor lr, const float beta1, const float beta2,
                                       const float epsilon, at::Tensor step, const int mode,
                                       const int bias_correction, const float weight_decay,
                                       at::Tensor inv_scale);

void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
                                              std::vector<std::vector<at::Tensor>> tensor_lists,
                                              at::Tensor lr, const float beta1, const float beta2,
                                              const float epsilon, at::Tensor step, const int mode,
                                              const int bias_correction, const float weight_decay,
                                              at::Tensor inv_scale);

// SGD optimizer step.
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
                           float momentum, float dampening, float lr, bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

void multi_tensor_compute_scale_and_scale_inv_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    float max_fp8, bool force_pow_2_scales, float epsilon);

/***************************************************************************************************
 * padding
 **************************************************************************************************/

void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> input_row_list,
                             std::vector<size_t> padded_input_row_list);

405
406
407
void fused_multi_row_unpadding(at::Tensor input, at::Tensor output,
                               std::vector<size_t> input_row_list,
                               std::vector<size_t> unpadded_input_row_list);
408
409
410
411
412
413
/***************************************************************************************************
 * NVSHMEM APIs
 **************************************************************************************************/

void init_nvshmem_backend(c10d::ProcessGroup *process_group);

414
at::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);
415

416
void nvshmem_send_on_current_stream(at::Tensor src, at::Tensor dst, int peer, at::Tensor signal);
417

418
void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_kind);
419
420
421

void nvshmem_finalize();

422
}  // namespace transformer_engine::pytorch
423

424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
/***************************************************************************************************
 * Comm+GEMM Overlap Wrappers
 **************************************************************************************************/

class CommOverlapHelper : torch::CustomClassHolder {
 private:
  bool initialized{false};
  bool backend_is_nccl{false};
  std::map<std::string, c10d::ProcessGroup *> pgs;

 public:
  int myrank = -1;
  int numranks = -1;
  int mylocal = -1;
  int numlocal = -1;
  int mynode = -1;
  int numnodes = -1;

  CommOverlapHelper();

  CommOverlapHelper(c10d::ProcessGroup *world_group,
445
                    std::optional<c10d::ProcessGroup *> intra_node_group);
446
447
448
449
450
451
452
453
454
455
456
457
458
459

  ~CommOverlapHelper();

  void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                    ExtComm comm);

  void ub_barrier(ExtComm comm);
};

// PyTorch-facing wrapper around transformer_engine::CommOverlapBase. Constructor defaults:
// num_splits = 3, num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, comm_cga_size = 2,
// priorities 0, num_comm_sm = 16, set_sm_margin = true, atomic_gemm = false,
// rs_overlap_first_gemm = false.
class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
 public:
  CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
              CommOverlapHelper *helper, int tp_size, int num_splits = 3,
              int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
              int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
              bool set_sm_margin = true, bool atomic_gemm = false,
              bool rs_overlap_first_gemm = false);

  // No extra cleanup beyond the base class and members.
  ~CommOverlap() = default;

  // Copies `input` into the communication buffer; `local_chunk` presumably selects this rank's
  // chunk rather than the whole buffer — confirm.
  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);

  // Returns the communication buffer (or its local chunk) as a tensor, optionally with `shape`.
  at::Tensor get_buffer(bool local_chunk = false,
                        std::optional<std::vector<int64_t>> shape = std::nullopt);

  at::Stream get_communication_stream();
};  // CommOverlap

// PyTorch-facing wrapper around transformer_engine::CommOverlapP2PBase. `comm_type` is required.
// Constructor defaults: num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, comm_cga_size = 2,
// priorities 0, num_comm_sm = 3, set_sm_margin = true, atomic_gemm = false, use_ce = true,
// aggregate = false.
class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
 public:
  CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
                 CommOverlapHelper *helper, int tp_size,
                 transformer_engine::CommOverlapType comm_type,
                 int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
                 int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3,
                 bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
                 bool aggregate = false);

  // No extra cleanup beyond the base class and members.
  ~CommOverlapP2P() = default;

  // Copies `input` into the communication buffer; `local_chunk` presumably selects this rank's
  // chunk rather than the whole buffer — confirm.
  void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);

  // Returns the communication buffer (or its local chunk) as a tensor, optionally with `shape`.
  at::Tensor get_buffer(bool local_chunk = false,
                        std::optional<std::vector<int64_t>> shape = std::nullopt);

  at::Stream get_communication_stream();
};  // CommOverlapP2P

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_