/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_

10
11
#include <optional>

Przemek Tredak's avatar
Przemek Tredak committed
12
13
#include "common.h"

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
/***************************************************************************************************
 * Permutation
 **************************************************************************************************/

// MoE token permutation (forward). Presumably reorders rows of `input` according to
// `indices`; returns a tuple of (permuted output, row-id map, workspace tensors).
// TODO(review): confirm exact tuple contents against the CUDA implementation.
std::tuple<at::Tensor, at::Tensor, std::vector<at::Tensor>> moe_permute_fwd(
    at::Tensor input, const transformer_engine::DType dtype, at::Tensor indices,
    int64_t num_out_tokens, std::vector<at::Tensor> workspace, int64_t max_expanded_token_num);

// Backward of moe_permute_fwd; consumes the `row_id_map` produced by the forward pass.
at::Tensor moe_permute_bwd(at::Tensor input, const transformer_engine::DType dtype,
                           at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
                           int64_t topK);

// MoE token un-permutation (forward) — inverse mapping of the permutation, weighted by `prob`.
at::Tensor moe_unpermute_fwd(at::Tensor input, const transformer_engine::DType dtype,
                             at::Tensor row_id_map, at::Tensor prob, int64_t num_tokens,
                             int64_t topK);

// Backward of moe_unpermute_fwd; returns gradients w.r.t. the forward input and probabilities.
std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::Tensor input_fwd,
                                                     const transformer_engine::DType dtype,
                                                     at::Tensor row_id_map, at::Tensor prob);

34
35
36
37
/***************************************************************************************************
 * Attention
 **************************************************************************************************/

38
39
40
41
42
43
44
45
// Selects which fused-attention backend supports the given problem configuration
// (dtypes, layout, bias/mask type, dropout, head counts, sequence lengths, head dims,
// and sliding-window bounds).
NVTE_Fused_Attn_Backend get_fused_attn_backend(const transformer_engine::DType q_dtype,
                                               const transformer_engine::DType kv_dtype,
                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                               NVTE_Mask_Type attn_mask_type, float p_dropout,
                                               size_t num_attn_heads, size_t num_gqa_groups,
                                               size_t max_seqlen_q, size_t max_seqlen_kv,
                                               size_t head_dim_qk, size_t head_dim_v,
                                               int64_t window_size_left, int64_t window_size_right);
cyanguwa's avatar
cyanguwa committed
46

47
std::vector<py::object> fused_attn_fwd(
48
49
    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
    bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
50
    NVTE_Mask_Type attn_mask_type, const std::vector<int64_t> window_size,
51
52
    const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q,
    const py::handle K, const py::handle V, const at::ScalarType fake_dtype,
53
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
54
55
    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
    py::handle o_quantizer, const c10::optional<at::Tensor> Bias,
56
    const c10::optional<at::Generator> rng_gen, size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
57

58
std::vector<py::object> fused_attn_bwd(
59
60
    size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
61
    const std::vector<int64_t> window_size, bool deterministic, const at::Tensor cu_seqlens_q,
62
63
    const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V,
    const py::handle O, const py::handle dO, const at::ScalarType fake_dtype,
64
    const transformer_engine::DType dqkv_type, const std::vector<at::Tensor> Aux_CTX_Tensors,
65
    const c10::optional<at::Tensor> cu_seqlens_q_padded,
66
67
    const c10::optional<at::Tensor> cu_seqlens_kv_padded, py::handle s_quantizer,
    py::handle dp_quantizer, py::handle dqkv_quantizer);
Przemek Tredak's avatar
Przemek Tredak committed
68

69
70
71
// Layout-preparation helpers for flash attention: fa_prepare_fwd takes a packed QKV
// tensor; fa_prepare_bwd takes separate q/k/v gradients. Exact transformation is
// defined in the implementation — presumably a re-pack between layouts; confirm there.
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

72
73
74
75
/***************************************************************************************************
 * GEMM
 **************************************************************************************************/

76
using MaybeTensor = std::optional<at::Tensor>;
77
78

void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, transformer_engine::DType A_type,
79
80
81
82
83
84
                    std::vector<int64_t> A_scaling_mode, bool transa, at::Tensor B,
                    at::Tensor B_scale_inverse, transformer_engine::DType B_type,
                    std::vector<int64_t> B_scaling_mode, bool transb, at::Tensor D,
                    at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
                    at::Tensor bias, transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
                    bool grad, at::Tensor workspace, size_t workspaceSize, bool accumulate,
85
86
                    bool use_split_accumulator, int math_sm_count, int m_split, int n_split,
                    bool gemm_producer, at::Tensor counter);
Przemek Tredak's avatar
Przemek Tredak committed
87

88
89
90
91
92
93
// Grouped GEMM over lists of A/B operands split along M by `m_splits`. When
// `single_output` is set the results are presumably written into one contiguous
// buffer; otherwise per-group outputs are returned in `D`.
std::optional<std::vector<at::Tensor>> te_general_grouped_gemm(
    std::vector<py::handle> A, bool transa, std::vector<py::handle> B, bool transb,
    std::optional<std::vector<at::Tensor>> D, transformer_engine::DType D_type,
    std::vector<int64_t> m_splits, std::vector<at::Tensor> bias,
    transformer_engine::DType bias_type, bool single_output, std::vector<at::Tensor> pre_gelu_out,
    bool grad, std::vector<at::Tensor> workspace, size_t workspaceSize, bool accumulate,
    bool use_split_accumulator, int math_sm_count);

96
97
98
99
/***************************************************************************************************
 * Transpose
 **************************************************************************************************/

100
101
102
103
104
105
106
107
108
// Quantizes a list of input tensors in one fused call, one quantizer per input.
// If `output_list` is provided, results are written into those objects.
std::vector<py::object> fused_multi_quantize(std::vector<py::handle> input_list,
                                             std::optional<std::vector<py::handle>> output_list,
                                             std::vector<py::handle> quantizer_list,
                                             transformer_engine::DType otype);

// Transposes an FP8 tensor; writes into `output` when given, otherwise allocates.
at::Tensor fp8_transpose(at::Tensor input, transformer_engine::DType otype,
                         std::optional<at::Tensor> output = std::nullopt);

namespace transformer_engine::pytorch {

/***************************************************************************************************
 * Activations
 **************************************************************************************************/

// Forward activations: apply the named activation to `input`, with the result
// presumably quantized according to `quantizer` (a Python-side quantizer object).
py::object gelu(const at::Tensor &input, py::handle quantizer);

py::object relu(const at::Tensor &input, py::handle quantizer);

py::object geglu(const at::Tensor &input, py::handle quantizer);

py::object qgeglu(const at::Tensor &input, py::handle quantizer);

py::object reglu(const at::Tensor &input, py::handle quantizer);

py::object swiglu(const at::Tensor &input, py::handle quantizer);

py::object qgelu(const at::Tensor &input, py::handle quantizer);

py::object srelu(const at::Tensor &input, py::handle quantizer);

// Backward activations: given the upstream gradient `grad` and the forward-pass
// `input`, compute the activation gradient (same quantizer convention as above).
py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

}  // namespace transformer_engine::pytorch
147

148
149
150
/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/
Przemek Tredak's avatar
Przemek Tredak committed
151

152
std::vector<py::object> layernorm_bwd(const at::Tensor &dz, const at::Tensor &x,
153
154
155
156
                                      const at::Tensor &mu, const at::Tensor &rsigma,
                                      const at::Tensor &gamma, const int sm_margin,
                                      const bool zero_centered_gamma);

157
158
159
std::vector<py::object> layernorm_fwd(py::handle input, py::handle weight, MaybeTensor bias,
                                      float eps, py::object ln_out, py::handle quantizer,
                                      transformer_engine::DType out_dtype, const int sm_margin,
160
161
                                      const bool zero_centered_gamma);

162
163
164
165
/***************************************************************************************************
 * RMSNorm
 **************************************************************************************************/

166
std::vector<py::object> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
167
168
                                    const at::Tensor &rsigma, const at::Tensor &gamma,
                                    const int sm_margin, const bool zero_centered_gamma);
169

170
171
172
173
174
175
176
177
std::vector<py::object> rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps,
                                    py::object ln_out, py::handle quantizer,
                                    transformer_engine::DType otype, const int sm_margin,
                                    const bool zero_centered_gamma);

/***************************************************************************************************
 * Cast
 **************************************************************************************************/
178

179
namespace transformer_engine::pytorch {
180

181
182
183
184
185
186
187
188
189
190
191
192
193
194
py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output,
                    std::optional<at::Tensor> noop);

py::object dequantize(const py::handle &input, transformer_engine::DType otype);

std::vector<py::object> bgrad_quantize(const at::Tensor &input, py::handle py_quantizer);

std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool transb, py::object D,
                             py::handle quantizer, std::optional<DType> out_dtype, MaybeTensor bias,
                             DType bias_type, bool gelu, MaybeTensor gelu_in, bool grad,
                             at::Tensor workspace, size_t workspaceSize, bool accumulate,
                             bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr,
                             std::optional<CommOverlapType> comm_type = std::nullopt,
                             MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false);
195

196
/***************************************************************************************************
197
 * Cast fusions
198
 **************************************************************************************************/
199

200
201
202
203
204
205
206
207
208
209
210
std::vector<py::object> dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                    py::handle quantizer);

std::vector<py::object> dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);
211

212
213
std::vector<py::object> dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input,
                                     py::handle quantizer);
214

215
}  // namespace transformer_engine::pytorch
216

217
218
219
/***************************************************************************************************
 * Softmax
 **************************************************************************************************/
220

221
at::Tensor scaled_softmax_forward(at::Tensor input, float scale_factor);
222

223
224
at::Tensor scaled_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                   float scale_factor);
225

226
at::Tensor scaled_masked_softmax_forward(at::Tensor input, at::Tensor mask, float scale_factor);
227

228
229
at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_, at::Tensor softmax_results_,
                                          float scale_factor);
230

231
at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input, float scale_factor);
232
233
234

at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
235
                                                       float scale_factor);
236

237
at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input, float scale_factor);
238
239
240

at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
                                                         at::Tensor softmax_results_,
241
                                                         float scale_factor);
242

243
244
245
246
/***************************************************************************************************
 * FP8 recipe
 **************************************************************************************************/

247
248
249
250
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                 std::vector<at::Tensor> amax_histories,
                                                 std::vector<at::Tensor> scales,
                                                 const std::string &amax_compute_algo,
251
                                                 transformer_engine::DType fp8_dtype, float margin);
252

253
254
255
256
/***************************************************************************************************
 * Rotary positional embedding
 **************************************************************************************************/

257
258
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
                              const bool transpose_output_memory);
259

260
261
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
                               const bool transpose_output_memory);
262

263
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
264
                                  const at::Tensor &freqs, const int cp_size, const int cp_rank);
265

266
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
267
                                   const at::Tensor &freqs, const int cp_size, const int cp_rank);
268
269

/***************************************************************************************************
 * Miscellaneous
 **************************************************************************************************/

// Runtime version of the linked cuBLASLt library.
size_t get_cublasLt_version();

// Runtime version of the linked cuDNN library.
size_t get_cudnn_version();

277
278
279
280
/***************************************************************************************************
 * Support THD format for Context Parallel
 **************************************************************************************************/

281
282
283
284
at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_seqlens,
                                int half_idx);

void thd_second_half_lse_correction(at::Tensor lse, const at::Tensor &lse_per_step,
285
                                    const at::Tensor &cu_seqlens, bool lse_packed);
286
287

at::Tensor thd_read_second_half_lse(const at::Tensor &lse, const at::Tensor &cu_seqlens,
288
                                    bool lse_packed, int second_half_lse_seqlen);
289
290
291

void thd_out_correction(at::Tensor out, const at::Tensor &out_per_step, const at::Tensor &lse,
                        const at::Tensor &lse_per_step, const at::Tensor &cu_seqlens,
292
                        bool only_second_half, bool lse_packed);
293
294
295
296

void thd_grad_correction(at::Tensor grad, const at::Tensor &grad_per_step,
                         const at::Tensor &cu_seqlens, const std::string &first_half,
                         const std::string &second_half);
297

298
299
at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_tokens,
                                       int world_size, int rank);
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315

/***************************************************************************************************
 * multi_tensor_* kernels
 **************************************************************************************************/

// Scales every tensor in `tensor_lists` by `scale`; `noop_flag` is a device-side
// skip flag shared by all multi_tensor_* kernels.
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists, float scale);

// Computes the global L2 norm across all tensors (and per-tensor norms when
// `per_tensor_python` is set).
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

// Same as multi_tensor_l2norm_cuda but unscales by `inv_scale` first (for
// gradient-unscaling under loss scaling).
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale, at::optional<bool> per_tensor_python);

316
using transformer_engine::DType;
317
318
319
320
321
322
void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag,
                            std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                            const float beta1, const float beta2, const float epsilon,
                            const int step, const int mode, const int bias_correction,
                            const float weight_decay);

323
324
325
326
327
328
void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag,
                                            std::vector<std::vector<at::Tensor>> tensor_lists,
                                            const float lr, const float beta1, const float beta2,
                                            const float epsilon, const int step, const int mode,
                                            const int bias_correction, const float weight_decay);

329
330
331
332
333
334
void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag,
                                std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
                                const float beta1, const float beta2, const float epsilon,
                                const int step, const int mode, const int bias_correction,
                                const float weight_decay, DType fp8_dtype);

// CUDA-graph-capturable Adam: lr/step/inv_scale are device tensors so the step can
// be replayed inside a captured graph without host intervention.
void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag,
                                       std::vector<std::vector<at::Tensor>> tensor_lists,
                                       at::Tensor lr, const float beta1, const float beta2,
                                       const float epsilon, at::Tensor step, const int mode,
                                       const int bias_correction, const float weight_decay,
                                       at::Tensor inv_scale);

// Capturable Adam that additionally maintains master weights.
void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag,
                                              std::vector<std::vector<at::Tensor>> tensor_lists,
                                              at::Tensor lr, const float beta1, const float beta2,
                                              const float epsilon, at::Tensor step, const int mode,
                                              const int bias_correction, const float weight_decay,
                                              at::Tensor inv_scale);

// Fused multi-tensor SGD with momentum/dampening/Nesterov options and loss `scale`.
void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                           std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
                           float momentum, float dampening, float lr, bool nesterov, bool first_run,
                           bool wd_after_momentum, float scale);
353

354
355
356
357
358
359
360
361
/***************************************************************************************************
 * padding
 **************************************************************************************************/

// Pads multiple row-groups of `input` into `output` in one fused kernel:
// input_row_list[i] rows are padded up to padded_input_row_list[i] rows.
void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> input_row_list,
                             std::vector<size_t> padded_input_row_list);

362
363
364
365
366
367
368
369
370
371
/***************************************************************************************************
 * swizzle
 **************************************************************************************************/

// Swizzles the scaling factors of `input` in place for GEMM consumption;
// `trans` selects the transposed (columnwise) layout.
void swizzle_scaling_factors(transformer_engine::TensorWrapper &input, bool trans);

// Returns a rowwise-swizzled copy of `scale_inv` associated with `input`.
at::Tensor rowwise_swizzle(at::Tensor input, at::Tensor scale_inv);

// Returns a columnwise-swizzled copy of `scale_inv` associated with `input`.
at::Tensor columnwise_swizzle(at::Tensor input, at::Tensor scale_inv);

372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
/***************************************************************************************************
 * Comm+GEMM Overlap Wrappers
 **************************************************************************************************/

// Bridges torch.distributed process groups to the userbuffers (UB) bootstrap
// callbacks used by the comm+GEMM overlap machinery.
class CommOverlapHelper : torch::CustomClassHolder {
 private:
  bool initialized{false};       // set once process groups have been registered
  bool backend_is_nccl{false};   // true when the world group's backend is NCCL
  std::map<std::string, c10d::ProcessGroup *> pgs;  // named process groups

 public:
  // Rank/topology info; -1 until initialized by the process-group constructor.
  int myrank = -1;
  int numranks = -1;
  int mylocal = -1;
  int numlocal = -1;
  int mynode = -1;
  int numnodes = -1;

  CommOverlapHelper();

  // Initializes from the world group and an optional intra-node subgroup.
  CommOverlapHelper(c10d::ProcessGroup *world_group,
                    std::optional<c10d::ProcessGroup *> intra_node_group);

  ~CommOverlapHelper();

  // All-gather bootstrap callback: gathers `localbytes` from each rank of `comm`
  // into `globaldata`.
  void ub_allgather(void *globaldata, size_t globalbytes, void *localdata, size_t localbytes,
                    ExtComm comm);

  // Barrier bootstrap callback over `comm`.
  void ub_barrier(ExtComm comm);
};

// Collective (ring-exchange) comm+GEMM overlap wrapper exposed to Python.
class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
 public:
  CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
              CommOverlapHelper *helper, int tp_size, int num_splits = 3,
              int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
              int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 16,
              bool set_sm_margin = true, bool atomic_gemm = false,
              bool rs_overlap_first_gemm = false);

  ~CommOverlap() {}

  // Configures the communication buffer's quantization parameters from `quantizer`.
  void set_buffer_params(py::handle quantizer);

  // Copies `input` into the communication buffer (only this rank's chunk when
  // `local_chunk` is set).
  void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);

  // Returns (a view of) the communication buffer, optionally reshaped to `shape`.
  py::object get_buffer(py::handle quantizer, bool local_chunk = false,
                        std::optional<const std::vector<int64_t>> shape = std::nullopt);

};  // CommOverlap

// Point-to-point (P2P exchange) comm+GEMM overlap wrapper exposed to Python.
class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
 public:
  CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
                 CommOverlapHelper *helper, int tp_size,
                 transformer_engine::CommOverlapType comm_type,
                 int num_max_streams = NVTE_COMM_OVERLAP_MAX_STREAMS, int comm_cga_size = 2,
                 int gemm_priority = 0, int comm_priority = 0, int num_comm_sm = 3,
                 bool set_sm_margin = true, bool atomic_gemm = false, bool use_ce = true,
                 bool aggregate = false);

  ~CommOverlapP2P() {}

  // Configures the communication buffer's quantization parameters from `quantizer`.
  void set_buffer_params(py::handle quantizer);

  // Copies `input` into the communication buffer (only this rank's chunk when
  // `local_chunk` is set).
  void copy_into_buffer(py::handle input, py::handle quantizer, bool local_chunk = false);

  // Returns (a view of) the communication buffer, optionally reshaped to `shape`.
  py::object get_buffer(py::handle quantizer, bool local_chunk = false,
                        std::optional<const std::vector<int64_t>> shape = std::nullopt);

};  // CommOverlapP2P

444
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_