/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_

#include <optional>

#include "common.h"

/***************************************************************************************************
 * Permutation
 **************************************************************************************************/

std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
    at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
    int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);

at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
                           at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
                           int64_t topK);

at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
                             at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
                             int64_t topK);

std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
                                                     const transformer_engine::DType dtype,
                                                     at::Tensor row_id_map, at::Tensor prob);
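
// Note (inferred from the signatures, not an additional API): moe_permute_fwd gathers tokens into
// expert order and returns a row_id_map; moe_unpermute_fwd uses that map, optionally weighting
// rows by `prob`, to restore the original token order, and the *_bwd variants route gradients
// back through the inverse mapping.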

/***************************************************************************************************
 * Attention
 **************************************************************************************************/

NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
                                               const transformer_engine::DType kv_dtype,
                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                               NVTE_Mask_Type attn_mask_type, float p_dropout,
                                               size_t num_attn_heads, size_t num_gqa_groups,
                                               size_t max_seqlen_q, size_t max_seqlen_kv,
                                               size_t head_dim_qk, size_t head_dim_v,
                                               int64_t window_size_left, int64_t window_size_right);
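
// Note (inferred usage): callers typically query this first and fall back to an unfused attention
// path when no fused backend is available; the result determines which cuDNN implementation the
// fused_attn_fwd/fused_attn_bwd entry points below dispatch to.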

std::vector<py::object> fused_attn_fwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
    py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);

std::vector<py::object> fused_attn_bwd(
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
    py::handle dp_quantizer, py::handle dqkv_quantizer);

at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

/***************************************************************************************************
 * GEMM
 **************************************************************************************************/

using MaybeTensor = std::optional<at::Tensor>;

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
                    at::Tensor B_scale_inverse, transformer_engine::DType B_type,
                    std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
                    at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
                    at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
                    bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, at::Tensor counter);

std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
    transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
    bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, int math_sm_count);
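
// Note (inferred from the signatures): this runs one GEMM per group over the A/B handle lists;
// m_splits appears to give the per-group row split used to carve up the (optionally single)
// output tensor, as in grouped linear layers that split tokens across experts.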

/***************************************************************************************************
 * Transpose
 **************************************************************************************************/

std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
                                             std::optional<std::vector<py::handle>> output_list,
                                             std::vector<py::handle> quantizer_list,
                                             transformer_engine::DType otype);

at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                         std::optional<at::Tensor> output = std::nullopt);

namespace transformer_engine::pytorch {

/***************************************************************************************************
 * Activations
 **************************************************************************************************/
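
// Note: every forward activation below has a matching d<name>(grad, input, quantizer) backward
// counterpart; the quantizer handle (None on the Python side for no quantization) lets the result
// be returned directly as a quantized tensor instead of a plain torch.Tensor.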

py::object gelu(const at::Tensor &input, py::handle quantizer);

py::object relu(const at::Tensor &input, py::handle quantizer);

py::object geglu(const at::Tensor &input, py::handle quantizer);

py::object qgeglu(const at::Tensor &input, py::handle quantizer);

py::object reglu(const at::Tensor &input, py::handle quantizer);

py::object swiglu(const at::Tensor &input, py::handle quantizer);

py::object qgelu(const at::Tensor &input, py::handle quantizer);

py::object srelu(const at::Tensor &input, py::handle quantizer);

py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

}  // namespace transformer_engine::pytorch

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/

std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                      const at::Tensor &mu, const at::Tensor &rsigma,
                                      const at::Tensor &gamma, const int sm_margin,
                                      const bool zero_centered_gamma);

std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
                                      float eps, py::object ln_out, py::handle quantizer,
                                      transformer_engine::DType out_dtype, const int sm_margin,
                                      const bool zero_centered_gamma);

/***************************************************************************************************
 * RMSNorm
 **************************************************************************************************/

std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                    const at::Tensor &rsigma, const at::Tensor &gamma,
                                    const int sm_margin, const bool zero_centered_gamma);

std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
                                    py::object ln_out, py::handle quantizer,
                                    transformer_engine::DType otype, const int sm_margin,
                                    const bool zero_centered_gamma);
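
// Note (applies to the LayerNorm and RMSNorm entry points above): zero_centered_gamma selects the
// convention where the weight is stored as an offset from 1 (the kernel adds 1 to gamma before
// scaling), and sm_margin reserves SMs so the norm kernels can co-run with communication.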

/***************************************************************************************************
 * Cast
 **************************************************************************************************/

namespace transformer_engine::pytorch {

py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
                    std::optional<at::Tensor> noop);

py::object dequantize(const py::handle &input, transformer_engine::DType otype);

std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);

std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D,
                             py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                             DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                             at::Tensor workspace, size_t workspaceSize, bool accumulate,
                             bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
                             std::optional<CommOverlapType> comm_type = std::nullopt,
                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
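
// Note (inferred usage): with the trailing arguments at their defaults this is a plain cuBLASLt
// GEMM; supplying a CommOverlapCore and comm_type (plus extra_output for some modes) routes the
// same call through the comm+GEMM overlap objects declared at the end of this header.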

/***************************************************************************************************
 * Cast fusions
 **************************************************************************************************/

std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);

std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);

}  // namespace transformer_engine::pytorch

/***************************************************************************************************
 * Softmax
 **************************************************************************************************/

at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);

at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                   float scale_factor);

at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor);

at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                          float scale_factor);

at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor);

at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
                                                       float scale_factor);

at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor);

at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
                                                         at::Tensor softmax_results_,
                                                         float scale_factor);

/***************************************************************************************************
 * FP8 recipe
 **************************************************************************************************/
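
// Note (inferred behavior): this performs the delayed-scaling recipe update; after the amax
// reduction each scale is recomputed from its amax history (roughly fp8_max / amax, backed off
// by 2^margin) according to amax_compute_algo (e.g. "max" or "most_recent").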

void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                 std::vector<at::Tensor> amax_histories,
                                                 std::vector<at::Tensor> scales,
                                                 const std::string &amax_compute_algo,
                                                 transformer_engine::DType fp8_dtype, float margin);

/***************************************************************************************************
 * Rotary positional embedding
 **************************************************************************************************/

at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
                              const bool transpose_output_memory);

at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
                               const bool transpose_output_memory);

at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
                                  const at::Tensor &freqs, const int cp_size, const int cp_rank);

at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
                                   const at::Tensor &freqs, const int cp_size, const int cp_rank);

/***************************************************************************************************
 * Miscellaneous
 **************************************************************************************************/

size_t get_cublasLt_version();

size_t get_cudnn_version();

/***************************************************************************************************
 * Support THD format for Context Parallel
 **************************************************************************************************/
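
// Note: these helpers operate on the THD (total tokens, heads, hidden dim) packed layout used by
// context-parallel attention, where each sequence is split into chunks across CP ranks; they read
// or correct per-step halves of the output, softmax LSE, and gradients when combining the partial
// attention results computed on each rank.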

at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
                                int half_idx);

void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
                                    const at::Tensor &cu_seqlens, bool lse_packed);

at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
                                    bool lse_packed, int second_half_lse_seqlen);

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
                        const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
                        bool only_second_half, bool lse_packed);

void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
                         const at::Tensor &cu_seqlens, const std::string &first_half,
                         const std::string &second_half);

at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                       int world_size, int rank);

/***************************************************************************************************
 * multi_tensor_* kernels
 **************************************************************************************************/

void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists, float scale);

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale, at::optional<bool> per_tensor_python);
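
// Note: these follow the Apex-style multi_tensor_apply pattern: tensor_lists groups per-tensor
// operands (grads, params, optimizer states, ...) into fixed-size chunks processed by a single
// fused kernel launch, and a nonzero noop_flag makes the kernels skip their work (useful for
// gradient-scaler overflow handling).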

using transformer_engine::DType;
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                            const float beta1, const float beta2, const float epsilon,
                            const int step, const int mode, const int bias_correction,
                            const float weight_decay);

void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
                                            std::vector<std::vector<at::Tensor>> tensor_lists,
                                            const float lr, const float beta1, const float beta2,
                                            const float epsilon, const int step, const int mode,
                                            const int bias_correction, const float weight_decay);

void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                                const float beta1, const float beta2, const float epsilon,
                                const int step, const int mode, const int bias_correction,
                                const float weight_decay, DType fp8_dtype);

void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
                                       std::vector<std::vector<at::Tensor>> tensor_lists,
                                       at::Tensor lr, const float beta1, const float beta2,
                                       const float epsilon, at::Tensor step, const int mode,
                                       const int bias_correction, const float weight_decay,
                                       at::Tensor inv_scale);

void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
                                              std::vector<std::vector<at::Tensor>> tensor_lists,
                                              at::Tensor lr, const float beta1, const float beta2,
                                              const float epsilon, at::Tensor step, const int mode,
                                              const int bias_correction, const float weight_decay,
                                              at::Tensor inv_scale);

void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
                           float momentum, float dampening, float lr, bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);

/***************************************************************************************************
 * padding
 **************************************************************************************************/

void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> input_row_list,
                             std::vector<size_t> padded_input_row_list);

/***************************************************************************************************
 * swizzle
 **************************************************************************************************/

void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans);

at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);

at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv);

/***************************************************************************************************
 * Comm+GEMM Overlap Wrappers
 **************************************************************************************************/

class CommOverlapHelper : torch::CustomClassHolder {
 private:
  bool initialized{false};
  bool backend_is_nccl{false};
  std::map<std::string, c10d::ProcessGroup *> pgs;

 public:
  int myrank = -1;
  int numranks = -1;
  int mylocal = -1;
  int numlocal = -1;
  int mynode = -1;
  int numnodes = -1;

  CommOverlapHelper();

  CommOverlapHelper(c10d::ProcessGroup *world_group,
                    std::optional<c10d::ProcessGroup *> intra_node_group,
                    std::optional<c10d::ProcessGroup *> inter_node_group);

  ~CommOverlapHelper();

  void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                    ExtComm comm);

  void ub_barrier(ExtComm comm);
};
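
// Note: CommOverlapHelper adapts torch.distributed process groups into the bootstrap collectives
// (allgather/barrier) that the userbuffers-based CommOverlap/CommOverlapP2P objects below use
// when setting up their shared communication buffers.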

class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
 public:
  CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
              CommOverlapHelper *helper, int tp_size, int num_splits = 3,
              int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
              int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
              bool set_sm_margin = true, bool atomic_gemm = false,
              bool rs_overlap_first_gemm = false);

  ~CommOverlap() {}

  void set_buffer_params(py::handle quantizer);

  void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);

  py::object get_buffer(py::handle quantizer, bool local_chunk = false,
                        std::optional<const std::vector<int64_t>> shape = std::nullopt);

};  // CommOverlap

class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
 public:
  CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
                 CommOverlapHelper *helper, int tp_size,
                 transformer_engine::CommOverlapType comm_type,
                 int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
                 int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3,
                 bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
                 bool aggregate = false);

  ~CommOverlapP2P() {}

  void set_buffer_params(py::handle quantizer);

  void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);

  py::object get_buffer(py::handle quantizer, bool local_chunk = false,
                        std::optional<const std::vector<int64_t>> shape = std::nullopt);

};  // CommOverlapP2P

#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_