/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "common.h"
#include "common/common.h"

/***************************************************************************************************
 * Attention
 **************************************************************************************************/

// Returns an NVTE_Fused_Attn_Backend for the given configuration: Q/KV dtypes,
// QKV layout, bias and mask types, dropout (`p_dropout`), head counts, maximum
// sequence lengths and head dimension.
// Declaration only; the selection logic lives in the definition, which is not
// part of this header.
NVTE_Fused_Attn_Backend get_fused_attn_backend(
                const transformer_engine::DType q_dtype,
                const transformer_engine::DType kv_dtype,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                float p_dropout,
                size_t num_attn_heads, size_t num_gqa_groups,
                size_t max_seqlen_q, size_t max_seqlen_kv,
                size_t head_dim);

std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
26
                size_t max_seqlen, bool is_training,
27
28
29
30
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
cyanguwa's avatar
cyanguwa committed
31
32
33
                const at::Tensor cu_seqlens,
                const at::Tensor QKV,
                const transformer_engine::DType qkv_type,
34
35
36
37
                const c10::optional<at::Tensor> seq_offsets_q,
                const c10::optional<at::Tensor> seq_offsets_k,
                const c10::optional<at::Tensor> seq_offsets_v,
                const c10::optional<at::Tensor> seq_offsets_o,
cyanguwa's avatar
cyanguwa committed
38
                const c10::optional<at::Tensor> descale_QKV,
39
                const c10::optional<at::Tensor> descale_S,
cyanguwa's avatar
cyanguwa committed
40
41
42
43
44
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_O,
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
45
46
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
47
48

std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
49
                size_t max_seqlen, float attn_scale,
50
51
52
53
                float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
cyanguwa's avatar
cyanguwa committed
54
55
56
57
58
                const at::Tensor cu_seqlens,
                const at::Tensor QKV,
                const at::Tensor O,
                const at::Tensor dO,
                const transformer_engine::DType qkv_type,
59
                const transformer_engine::DType dqkv_type,
cyanguwa's avatar
cyanguwa committed
60
                const std::vector<at::Tensor> Aux_CTX_Tensors,
61
62
63
64
                const c10::optional<at::Tensor> seq_offsets_q,
                const c10::optional<at::Tensor> seq_offsets_k,
                const c10::optional<at::Tensor> seq_offsets_v,
                const c10::optional<at::Tensor> seq_offsets_o,
cyanguwa's avatar
cyanguwa committed
65
66
67
68
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> descale_S,
                const c10::optional<at::Tensor> descale_O,
                const c10::optional<at::Tensor> descale_dO,
69
                const c10::optional<at::Tensor> descale_dP,
cyanguwa's avatar
cyanguwa committed
70
71
72
73
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_dP,
                const c10::optional<at::Tensor> scale_dQKV,
                c10::optional<at::Tensor> amax_dP,
74
                c10::optional<at::Tensor> amax_dQKV);
cyanguwa's avatar
cyanguwa committed
75
76

std::vector<at::Tensor> fused_attn_fwd_kvpacked(
77
                size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
78
79
80
81
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
cyanguwa's avatar
cyanguwa committed
82
83
84
85
86
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
                const at::Tensor KV,
                const transformer_engine::DType qkv_type,
87
88
89
90
                const c10::optional<at::Tensor> seq_offsets_q,
                const c10::optional<at::Tensor> seq_offsets_k,
                const c10::optional<at::Tensor> seq_offsets_v,
                const c10::optional<at::Tensor> seq_offsets_o,
cyanguwa's avatar
cyanguwa committed
91
                const c10::optional<at::Tensor> descale_QKV,
92
                const c10::optional<at::Tensor> descale_S,
cyanguwa's avatar
cyanguwa committed
93
94
95
96
97
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_O,
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
98
99
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
100
101

std::vector<at::Tensor> fused_attn_bwd_kvpacked(
102
103
                size_t max_seqlen_q, size_t max_seqlen_kv,
                float attn_scale, float p_dropout, bool set_zero,
104
105
106
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
cyanguwa's avatar
cyanguwa committed
107
108
109
110
111
112
113
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
                const at::Tensor KV,
                const at::Tensor O,
                const at::Tensor dO,
                const transformer_engine::DType qkv_type,
114
                const transformer_engine::DType dqkv_type,
cyanguwa's avatar
cyanguwa committed
115
                const std::vector<at::Tensor> Aux_CTX_Tensors,
116
117
118
119
                const c10::optional<at::Tensor> seq_offsets_q,
                const c10::optional<at::Tensor> seq_offsets_k,
                const c10::optional<at::Tensor> seq_offsets_v,
                const c10::optional<at::Tensor> seq_offsets_o,
cyanguwa's avatar
cyanguwa committed
120
121
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> descale_S,
122
123
                const c10::optional<at::Tensor> descale_O,
                const c10::optional<at::Tensor> descale_dO,
124
                const c10::optional<at::Tensor> descale_dP,
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_dP,
                const c10::optional<at::Tensor> scale_dQKV,
                c10::optional<at::Tensor> amax_dP,
                c10::optional<at::Tensor> amax_dQKV);

// Fused-attention forward entry point with separate `Q`, `K` and `V` tensors
// (no packing), with cu_seqlens_q / cu_seqlens_kv and per-side maximum
// sequence lengths.
// NOTE(review): the meaning and order of the returned tensors are defined by
// the implementation, which is not visible here.
std::vector<at::Tensor> fused_attn_fwd(
                size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
                const at::Tensor K,
                const at::Tensor V,
                const transformer_engine::DType qkv_type,
                const c10::optional<at::Tensor> seq_offsets_q,
                const c10::optional<at::Tensor> seq_offsets_k,
                const c10::optional<at::Tensor> seq_offsets_v,
                const c10::optional<at::Tensor> seq_offsets_o,
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> descale_S,
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_O,
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread);

// Fused-attention backward entry point for separate `Q`, `K` and `V` tensors.
// Takes the forward output `O`, the incoming gradient `dO` and
// `Aux_CTX_Tensors`, plus optional FP8 descale/scale/amax tensors.
// NOTE(review): Aux_CTX_Tensors presumably come from fused_attn_fwd; confirm
// against the implementation.
std::vector<at::Tensor> fused_attn_bwd(
                size_t max_seqlen_q, size_t max_seqlen_kv,
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
                const at::Tensor K,
                const at::Tensor V,
                const at::Tensor O,
                const at::Tensor dO,
                const transformer_engine::DType qkv_type,
                const transformer_engine::DType dqkv_type,
                const std::vector<at::Tensor> Aux_CTX_Tensors,
                const c10::optional<at::Tensor> seq_offsets_q,
                const c10::optional<at::Tensor> seq_offsets_k,
                const c10::optional<at::Tensor> seq_offsets_v,
                const c10::optional<at::Tensor> seq_offsets_o,
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> descale_S,
                const c10::optional<at::Tensor> descale_O,
                const c10::optional<at::Tensor> descale_dO,
                const c10::optional<at::Tensor> descale_dP,
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_dP,
                const c10::optional<at::Tensor> scale_dQKV,
                c10::optional<at::Tensor> amax_dP,
                c10::optional<at::Tensor> amax_dQKV);

// Input-preparation helpers (presumably for flash attention, given the `fa_`
// prefix). Both return a single at::Tensor.
// NOTE(review): the exact tensor transforms are defined elsewhere; confirm in
// the implementation.
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q,
                          at::Tensor k,
                          at::Tensor v);

/***************************************************************************************************
 * GEMM
 **************************************************************************************************/

// GEMM entry point over operands A and B with output D. Each operand carries a
// scale tensor and a DType; `transa`/`transb` are per-operand transpose flags.
// Also takes an optional-style bias/pre_gelu_out pair, a workspace of
// `workspaceSize` bytes, and accumulation and SM-count controls.
// NOTE(review): the exact semantics of grad/accumulate/use_split_accumulator
// and the units of workspaceSize are defined by the implementation, not here.
void te_gemm(at::Tensor A,
             at::Tensor A_scale_inverse,
             transformer_engine::DType A_type,
             bool transa,
             at::Tensor B,
             at::Tensor B_scale_inverse,
             transformer_engine::DType B_type,
             bool transb,
             at::Tensor D,
             at::Tensor D_scale,
             transformer_engine::DType D_type,
             at::Tensor D_amax,
             at::Tensor bias,
             transformer_engine::DType bias_type,
             at::Tensor pre_gelu_out,
             bool grad,
             at::Tensor workspace,
             size_t workspaceSize,
             bool accumulate,
             bool use_split_accumulator,
             int math_sm_count
);

// Atomic variant of te_gemm. It takes the same leading parameters, then
// m_split/n_split, a `gemm_producer` flag and a `counter` tensor.
// NOTE(review): how `counter` and the split parameters coordinate work is not
// visible in this header; confirm in the implementation.
void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse,
                    transformer_engine::DType A_type, bool transa,
                    at::Tensor B, at::Tensor B_scale_inverse,
                    transformer_engine::DType B_type, bool transb,
                    at::Tensor D, at::Tensor D_scale,
                    transformer_engine::DType D_type, at::Tensor D_amax,
                    at::Tensor bias, transformer_engine::DType bias_type,
                    at::Tensor pre_gelu_out, bool grad,
                    at::Tensor workspace, size_t workspaceSize,
                    bool accumulate, bool use_split_accumulator,
                    int math_sm_count,
                    int m_split, int n_split,
                    bool gemm_producer, at::Tensor counter);

/***************************************************************************************************
 * Transpose
 **************************************************************************************************/

void fused_cast_transpose(at::Tensor input,
                          at::Tensor scale,
                          at::Tensor amax,
                          at::Tensor scale_inv,
                          at::Tensor input_cast,
                          at::Tensor input_transpose,
                          transformer_engine::DType otype
);


259
260
261
262
263
264
265
266
267
268
269
void fused_cast_transpose_noop(at::Tensor input,
                               at::Tensor noop,
                               at::Tensor scale,
                               at::Tensor amax,
                               at::Tensor scale_inv,
                               at::Tensor input_cast,
                               at::Tensor input_transpose,
                               transformer_engine::DType otype
);


Przemek Tredak's avatar
Przemek Tredak committed
270
271
272
273
274
275
276
277
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
                                                   at::Tensor scale,
                                                   at::Tensor amax,
                                                   at::Tensor scale_inv,
                                                   transformer_engine::DType otype
);


278
279
280
281
282
283
284
285
286
std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
                                              at::Tensor scale,
                                              at::Tensor amax,
                                              at::Tensor scale_inv,
                                              transformer_engine::DType otype,
                                              transformer_engine::DType grad_bias_type
);


Przemek Tredak's avatar
Przemek Tredak committed
287
288
289
290
291
292
293
294
295
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
                                                         at::Tensor gelu_input,
                                                         at::Tensor scale,
                                                         at::Tensor amax,
                                                         at::Tensor scale_inv,
                                                         transformer_engine::DType otype
);


Tim Moon's avatar
Tim Moon committed
296
297
298
299
300
301
302
303
304
305
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                std::vector<at::Tensor> scale_list,
                                std::vector<at::Tensor> cast_output_list,
                                std::vector<at::Tensor> transposed_output_list,
                                std::vector<at::Tensor> amax_output_list,
                                std::vector<at::Tensor> scale_inv_output_list,
                                transformer_engine::DType otype
);


Przemek Tredak's avatar
Przemek Tredak committed
306
307
308
309
at::Tensor fp8_transpose(at::Tensor input,
                         transformer_engine::DType otype
);

310
311
312
313
314
315
316
317
318
319
320
void fp8_transpose_noalloc(at::Tensor input,
                           at::Tensor output,
                           transformer_engine::DType otype
);

void fp8_transpose_noalloc_noop(at::Tensor input,
                                at::Tensor output,
                                at::Tensor noop,
                                transformer_engine::DType otype
);

321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
/***************************************************************************************************
 * Activations
 **************************************************************************************************/

// Forward activation functions. Every function here has the same signature:
// an input tensor, scale/amax/scale_inv tensors and an output DType, and each
// returns an at::Tensor.
// NOTE(review): the scale/amax handling (presumably FP8 output quantization)
// lives in the definitions.
at::Tensor gelu(at::Tensor input,
                at::Tensor scale,
                at::Tensor amax,
                at::Tensor scale_inv,
                transformer_engine::DType otype
);

at::Tensor relu(at::Tensor input,
                at::Tensor scale,
                at::Tensor amax,
                at::Tensor scale_inv,
                transformer_engine::DType otype
);

at::Tensor geglu(at::Tensor input,
                 at::Tensor scale,
                 at::Tensor amax,
                 at::Tensor scale_inv,
                 transformer_engine::DType otype
);

at::Tensor reglu(at::Tensor input,
                 at::Tensor scale,
                 at::Tensor amax,
                 at::Tensor scale_inv,
                 transformer_engine::DType otype
);

at::Tensor swiglu(at::Tensor input,
                  at::Tensor scale,
                  at::Tensor amax,
                  at::Tensor scale_inv,
                  transformer_engine::DType otype
);

at::Tensor qgelu(at::Tensor input,
                 at::Tensor scale,
                 at::Tensor amax,
                 at::Tensor scale_inv,
                 transformer_engine::DType otype
);

at::Tensor srelu(at::Tensor input,
                 at::Tensor scale,
                 at::Tensor amax,
                 at::Tensor scale_inv,
                 transformer_engine::DType otype
);

// Activation backward functions. Every function here has the same signature:
// (grad, input, otype) -> at::Tensor.
at::Tensor dgelu(at::Tensor grad,
                 at::Tensor input,
                 transformer_engine::DType otype
);

at::Tensor drelu(at::Tensor grad,
                 at::Tensor input,
                 transformer_engine::DType otype
);

at::Tensor dgeglu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
);

at::Tensor dreglu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
);

at::Tensor dswiglu(at::Tensor grad,
                   at::Tensor input,
                   transformer_engine::DType otype
);

at::Tensor dqgelu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
);

at::Tensor dsrelu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
);

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/

// LayerNorm backward. Takes the incoming gradient `dz`, the forward input `x`,
// the saved `mu`/`rsigma` tensors and `gamma`, and returns a
// std::vector<at::Tensor>.
// NOTE(review): the order of the returned gradients is defined by the
// implementation; confirm before relying on it.
std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                      const at::Tensor &x,
                                      const at::Tensor &mu,
                                      const at::Tensor &rsigma,
                                      const at::Tensor &gamma,
                                      const int sm_margin,
                                      const bool zero_centered_gamma
);


// LayerNorm forward variants that take scale/amax/scale_inv tensors and an
// output DType. The `_noalloc` variant additionally takes a caller-provided
// `ln_out` tensor. The `_inf` variant returns a single at::Tensor instead of a
// vector.
std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                          const at::Tensor &weight,
                                          const at::Tensor &bias,
                                          float eps,
                                          at::Tensor scale,
                                          at::Tensor amax,
                                          at::Tensor scale_inv,
                                          transformer_engine::DType otype,
                                          const int sm_margin,
                                          const bool zero_centered_gamma
);

std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
                                                  const at::Tensor &weight,
                                                  const at::Tensor &bias,
                                                  float eps,
                                                  at::Tensor scale,
                                                  at::Tensor ln_out,
                                                  at::Tensor amax,
                                                  at::Tensor scale_inv,
                                                  transformer_engine::DType otype,
                                                  const int sm_margin,
                                                  const bool zero_centered_gamma
);

at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
                                 const at::Tensor &weight,
                                 const at::Tensor &bias,
                                 float eps,
                                 at::Tensor scale,
                                 at::Tensor amax,
                                 at::Tensor scale_inv,
                                 transformer_engine::DType otype,
                                 const int sm_margin,
                                 const bool zero_centered_gamma
);

// LayerNorm forward variants without the scale/amax/output-DType arguments.
// The `_noalloc` variant takes a caller-provided `ln_out` tensor. The `_inf`
// variant returns a single at::Tensor instead of a vector.
std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                      const at::Tensor &weight,
                                      const at::Tensor &bias,
                                      float eps,
                                      const int sm_margin,
                                      const bool zero_centered_gamma
);

std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input,
                                              const at::Tensor &weight,
                                              const at::Tensor &bias,
                                              at::Tensor ln_out,
                                              float eps,
                                              const int sm_margin,
                                              const bool zero_centered_gamma
);

at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                             const at::Tensor &weight,
                             const at::Tensor &bias,
                             float eps,
                             const int sm_margin,
                             const bool zero_centered_gamma
);

/***************************************************************************************************
 * RMSNorm
 **************************************************************************************************/

// RMSNorm backward. Takes the incoming gradient `dz`, the forward input `x`,
// the saved `rsigma` and `gamma`, and returns a std::vector<at::Tensor>.
// Unlike layernorm_bwd, it takes no `mu` argument.
std::vector<at::Tensor> rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x,
                                    const at::Tensor &rsigma, const at::Tensor &gamma,
                                    const int sm_margin, const bool zero_centered_gamma);


// RMSNorm forward variants that take scale/amax/scale_inv tensors and an
// output DType. Unlike the LayerNorm counterparts, they take no `bias`. The
// `_noalloc` variant takes a caller-provided `ln_out` tensor. The `_inf`
// variant returns a single at::Tensor.
std::vector<at::Tensor> rmsnorm_fwd_fp8(const at::Tensor &input,
                                        const at::Tensor &weight,
                                        float eps,
                                        at::Tensor scale,
                                        at::Tensor amax,
                                        at::Tensor scale_inv,
                                        transformer_engine::DType otype,
                                        const int sm_margin,
                                        const bool zero_centered_gamma
);

std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(const at::Tensor &input,
                                                const at::Tensor &weight,
                                                float eps,
                                                at::Tensor scale,
                                                at::Tensor ln_out,
                                                at::Tensor amax,
                                                at::Tensor scale_inv,
                                                transformer_engine::DType otype,
                                                const int sm_margin,
                                                const bool zero_centered_gamma
);

at::Tensor rmsnorm_fwd_fp8_inf(const at::Tensor &input,
                               const at::Tensor &weight,
                               float eps,
                               at::Tensor scale,
                               at::Tensor amax,
                               at::Tensor scale_inv,
                               transformer_engine::DType otype,
                               const int sm_margin,
                               const bool zero_centered_gamma
);

// RMSNorm forward variants without the scale/amax/output-DType arguments.
// The `_noalloc` variant takes a caller-provided `ln_out` tensor. The `_inf`
// variant returns a single at::Tensor instead of a vector.
std::vector<at::Tensor> rmsnorm_fwd(const at::Tensor &input,
                                    const at::Tensor &weight,
                                    float eps,
                                    const int sm_margin,
                                    const bool zero_centered_gamma
);

std::vector<at::Tensor> rmsnorm_fwd_noalloc(const at::Tensor &input,
                                            const at::Tensor &weight,
                                            at::Tensor ln_out,
                                            float eps,
                                            const int sm_margin,
                                            const bool zero_centered_gamma
);

at::Tensor rmsnorm_fwd_inf(const at::Tensor &input,
                           const at::Tensor &weight,
                           float eps,
                           const int sm_margin,
                           const bool zero_centered_gamma
);

/***************************************************************************************************
 * Cast
 **************************************************************************************************/

at::Tensor cast_to_fp8(const at::Tensor &input,
                       const at::Tensor &scale,
                       at::Tensor amax,
                       at::Tensor scale_inv,
                       transformer_engine::DType otype
);


566
567
568
569
570
571
572
573
574
void cast_to_fp8_noalloc(const at::Tensor &input,
                         const at::Tensor &scale,
                         at::Tensor output,
                         at::Tensor amax,
                         at::Tensor scale_inv,
                         transformer_engine::DType otype
);


Przemek Tredak's avatar
Przemek Tredak committed
575
576
577
578
579
at::Tensor cast_from_fp8(const at::Tensor &input,
                         const at::Tensor &scale_inv,
                         transformer_engine::DType itype,
                         transformer_engine::DType otype
);

/***************************************************************************************************
 * Softmax
 **************************************************************************************************/

// Scaled softmax forward/backward pairs. Each forward takes the input (plus a
// `mask` for the masked variant) and a `scale_factor`. Each backward takes the
// output gradient, the saved softmax results and the same `scale_factor`.
// NOTE(review): the masking/causal semantics of each variant are defined by
// the implementation; only the names indicate them here.
at::Tensor scaled_softmax_forward(at::Tensor input,
                                  float scale_factor
);


at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
                                   at::Tensor softmax_results_,
                                   float scale_factor
);


at::Tensor scaled_masked_softmax_forward(at::Tensor input,
                                         at::Tensor mask,
                                         float scale_factor
);


at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
                                          at::Tensor softmax_results_,
                                          float scale_factor
);


at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
                                                      float scale_factor
);


at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
                                                       float scale_factor
);


at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input,
                                                        float scale_factor
);


at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
                                                         at::Tensor softmax_results_,
                                                         float scale_factor
);

/***************************************************************************************************
 * FP8 recipe
 **************************************************************************************************/

// FP8 amax-history / scale update, run after the amax reduction (declaration only; the
// definition is in another translation unit). It returns void, so its result must come from
// side effects -- presumably the tensors in `amax_histories`, `scales` and `scale_invs` are
// updated in place, with `amax_compute_algo`, `fp8_dtype` and `margin` steering how the new
// scales are derived. The accepted `amax_compute_algo` strings are not visible here.
void fused_amax_and_scale_update_after_reduction(
    const at::Tensor &amax_reduction_buffer,
    std::vector<at::Tensor> amax_histories,
    std::vector<at::Tensor> scales,
    std::vector<at::Tensor> scale_invs,
    const std::string &amax_compute_algo,
    transformer_engine::DType fp8_dtype,
    float margin);
640

641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
/***************************************************************************************************
 * Rotary positional embedding
 **************************************************************************************************/

// Fused rotary positional embedding, forward pass (declaration only).
// `freqs` presumably holds the rotation frequencies/angles, and `transpose_output_memory`
// presumably selects the memory layout of the returned tensor; both readings come from the
// names only -- TODO confirm against the implementation.
at::Tensor fused_rope_forward(
    const at::Tensor &input, const at::Tensor &freqs, const bool transpose_output_memory);

// Fused rotary positional embedding, backward pass (declaration only). Same parameter list
// as the forward declaration, with the upstream gradient in place of the input; the result
// is presumably the gradient w.r.t. the forward input -- TODO confirm.
at::Tensor fused_rope_backward(
    const at::Tensor &output_grads, const at::Tensor &freqs, const bool transpose_output_memory);

// Fused rotary positional embedding for the "thd" layout, forward pass (declaration only).
// Instead of a layout flag it takes `cu_seqlens`, presumably cumulative sequence lengths that
// delimit packed variable-length sequences inside `input` -- TODO confirm.
at::Tensor fused_rope_thd_forward(
    const at::Tensor &input, const at::Tensor &cu_seqlens, const at::Tensor &freqs);

// Fused rotary positional embedding for the "thd" layout, backward pass (declaration only).
// Mirrors the thd forward declaration with the upstream gradient in place of the input;
// the result is presumably the gradient w.r.t. the forward input -- TODO confirm.
at::Tensor fused_rope_thd_backward(
    const at::Tensor &output_grads, const at::Tensor &cu_seqlens, const at::Tensor &freqs);

/***************************************************************************************************
 * Miscellaneous
 **************************************************************************************************/

669
670
// Returns the cuBLASLt version as a size_t (declaration only; defined elsewhere).
// NOTE(review): whether this is the runtime or compile-time version, and how the number is
// encoded, is not visible from this header -- confirm against the implementation.
size_t get_cublasLt_version();

671
672
// Returns the cuDNN version as a size_t (declaration only; defined elsewhere).
// NOTE(review): runtime vs. compile-time version and the numeric encoding are not visible
// from this header -- confirm against the implementation.
size_t get_cudnn_version();

673
674
675
// Returns a bool, presumably reporting whether userbuffers-based communication support is
// available in this build -- TODO confirm (declaration only; definition not visible here).
bool userbuf_comm_available();

// Takes no arguments and returns nothing. Its purpose is not visible from this header.
// NOTE(review): confirm whether it is still referenced (e.g. by the Python bindings) or can
// be removed.
void placeholder();
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717


/***************************************************************************************************
 * Support THD format for Context Parallel
 **************************************************************************************************/

// Context-parallel helper for the THD format (declaration only). Returns a tensor
// presumably holding the half of each sequence in `tensor` selected by `half_idx`, with
// sequence boundaries taken from `cu_seqlens` -- the valid `half_idx` values are not visible
// here (TODO confirm).
at::Tensor thd_read_half_tensor(
    const at::Tensor &tensor, const at::Tensor &cu_seqlens, int half_idx);

// Context-parallel THD helper (declaration only). Returns void and takes `lse` as a
// non-const tensor, so it presumably updates `lse` in place using `lse_per_step` for the
// second half of each sequence -- "lse" is presumably log-sum-exp; TODO confirm.
void thd_second_half_lse_correction(
    at::Tensor lse,
    const at::Tensor &lse_per_step,
    const at::Tensor &cu_seqlens,
    int total_tokens);

// Context-parallel THD helper (declaration only). Returns a new tensor, presumably the
// second-half portion of `lse` for each sequence delimited by `cu_seqlens`; how
// `total_tokens` is used is not visible here -- TODO confirm.
at::Tensor thd_read_second_half_lse(
    const at::Tensor &lse, const at::Tensor &cu_seqlens, int total_tokens);

// Context-parallel THD helper (declaration only). Returns void and takes `out` as a
// non-const tensor, so it presumably accumulates `out_per_step` into `out` in place, weighted
// via `lse` / `lse_per_step`; `only_second_half` presumably restricts the update to the second
// half of each sequence. All of this is inferred from names -- TODO confirm.
void thd_out_correction(
    at::Tensor out,
    const at::Tensor &out_per_step,
    const at::Tensor &lse,
    const at::Tensor &lse_per_step,
    const at::Tensor &cu_seqlens,
    bool only_second_half);

// Context-parallel THD helper (declaration only). Returns void and takes `grad` as a
// non-const tensor, so it presumably folds `grad_per_step` into `grad` in place. `first_half`
// and `second_half` are strings selecting per-half behaviour; the accepted values are defined
// by the implementation, not visible here -- TODO confirm.
void thd_grad_correction(
    at::Tensor grad,
    const at::Tensor &grad_per_step,
    const at::Tensor &cu_seqlens,
    const std::string &first_half,
    const std::string &second_half);

// Context-parallel THD helper (declaration only). Returns a tensor, presumably of token
// indices owned by `rank` out of `world_size` ranks, given sequence boundaries in
// `cu_seqlens` and `total_tokens` -- index dtype and partitioning scheme not visible here.
at::Tensor thd_get_partitioned_indices(
    const at::Tensor &cu_seqlens,
    int total_tokens,
    int world_size,
    int rank);
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758


/***************************************************************************************************
 * multi_tensor_* kernels
 **************************************************************************************************/

// Multi-tensor scale kernel launcher (declaration only). Processes `tensor_lists` in chunks
// of `chunk_size`; presumably writes `scale * input` into an output list, and `noop_flag` is
// presumably a device flag that can turn the launch into a no-op -- TODO confirm.
void multi_tensor_scale_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float scale);

// Multi-tensor L2-norm kernel launcher (declaration only). Returns a pair of tensors,
// presumably the overall norm and per-tensor norms (the latter when `per_tensor_python` is
// set) -- TODO confirm what is returned when the optional is empty or false.
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::optional<bool> per_tensor_python);

// Same parameter list as multi_tensor_l2norm_cuda plus an `inv_scale` tensor (declaration
// only); presumably the inputs are multiplied by `inv_scale` before the norm is taken --
// TODO confirm.
std::tuple<at::Tensor, at::Tensor> multi_tensor_unscale_l2norm_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor inv_scale,
    at::optional<bool> per_tensor_python);

// Multi-tensor Adam step launcher (declaration only). Hyper-parameters (`lr`, `beta1`,
// `beta2`, `epsilon`, `weight_decay`) and the step count are plain host values here, unlike
// the capturable variants. The meanings of the integer `mode` and `bias_correction` flags
// are defined by the implementation, not visible here -- TODO confirm.
void multi_tensor_adam_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int mode,
    const int bias_correction,
    const float weight_decay);

// "Capturable" multi-tensor Adam step launcher (declaration only). Unlike
// multi_tensor_adam_cuda, `lr` and `step` are tensors and an `inv_scale` tensor is added;
// presumably so the values can change between launches without host-side arguments
// changing (e.g. under CUDA graph capture) -- TODO confirm.
void multi_tensor_adam_capturable_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor step,
    const int mode,
    const int bias_correction,
    const float weight_decay,
    at::Tensor inv_scale);

// Capturable multi-tensor Adam variant with the same parameter list as
// multi_tensor_adam_capturable_cuda (declaration only). The "master" suffix presumably means
// `tensor_lists` also carries master copies of the weights that get updated -- the expected
// list layout is defined by the implementation, not visible here (TODO confirm).
void multi_tensor_adam_capturable_master_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor step,
    const int mode,
    const int bias_correction,
    const float weight_decay,
    at::Tensor inv_scale);

// Multi-tensor SGD step launcher (declaration only). Takes weight decay, momentum,
// dampening, learning rate, a Nesterov switch, a `first_run` flag, a switch for applying
// weight decay after momentum, and a `scale` factor; how `first_run` and `scale` affect the
// update is defined by the implementation, not visible here -- TODO confirm.
void multi_tensor_sgd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale);