extensions.h 25.9 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "common.h"
Tim Moon's avatar
Tim Moon committed
8
#include "common/common.h"
Przemek Tredak's avatar
Przemek Tredak committed
9

10
11
12
13
/***************************************************************************************************
 * Attention
 **************************************************************************************************/

14
15
16
17
18
19
NVTE_Fused_Attn_Backend get_fused_attn_backend(
                const transformer_engine::DType q_dtype,
                const transformer_engine::DType kv_dtype,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
20
21
22
23
                float p_dropout,
                size_t num_attn_heads, size_t num_gqa_groups,
                size_t max_seqlen_q, size_t max_seqlen_kv,
                size_t head_dim);
cyanguwa's avatar
cyanguwa committed
24
25

std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
26
                size_t max_seqlen, bool is_training,
27
28
29
30
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
cyanguwa's avatar
cyanguwa committed
31
32
33
34
35
36
37
38
39
                const at::Tensor cu_seqlens,
                const at::Tensor QKV,
                const transformer_engine::DType qkv_type,
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_O,
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
40
41
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
42
43

std::vector<at::Tensor> fused_attn_bwd_qkvpacked(
44
                size_t max_seqlen, float attn_scale,
45
46
47
48
                float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
cyanguwa's avatar
cyanguwa committed
49
50
51
52
53
54
55
56
57
58
59
60
61
62
                const at::Tensor cu_seqlens,
                const at::Tensor QKV,
                const at::Tensor O,
                const at::Tensor dO,
                const transformer_engine::DType qkv_type,
                const std::vector<at::Tensor> Aux_CTX_Tensors,
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> descale_S,
                const c10::optional<at::Tensor> descale_O,
                const c10::optional<at::Tensor> descale_dO,
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_dP,
                const c10::optional<at::Tensor> scale_dQKV,
                c10::optional<at::Tensor> amax_dP,
63
                c10::optional<at::Tensor> amax_dQKV);
cyanguwa's avatar
cyanguwa committed
64
65

std::vector<at::Tensor> fused_attn_fwd_kvpacked(
66
                size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
67
68
69
70
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
cyanguwa's avatar
cyanguwa committed
71
72
73
74
75
76
77
78
79
80
81
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
                const at::Tensor KV,
                const transformer_engine::DType qkv_type,
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_O,
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
82
83
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread);
cyanguwa's avatar
cyanguwa committed
84
85

std::vector<at::Tensor> fused_attn_bwd_kvpacked(
86
87
                size_t max_seqlen_q, size_t max_seqlen_kv,
                float attn_scale, float p_dropout, bool set_zero,
88
89
90
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
cyanguwa's avatar
cyanguwa committed
91
92
93
94
95
96
97
98
99
100
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
                const at::Tensor KV,
                const at::Tensor O,
                const at::Tensor dO,
                const transformer_engine::DType qkv_type,
                const std::vector<at::Tensor> Aux_CTX_Tensors,
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> descale_S,
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
                const c10::optional<at::Tensor> descale_O,
                const c10::optional<at::Tensor> descale_dO,
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_dP,
                const c10::optional<at::Tensor> scale_dQKV,
                c10::optional<at::Tensor> amax_dP,
                c10::optional<at::Tensor> amax_dQKV);

std::vector<at::Tensor> fused_attn_fwd(
                size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
                const at::Tensor K,
                const at::Tensor V,
                const transformer_engine::DType qkv_type,
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_O,
                c10::optional<at::Tensor> amax_S,
                c10::optional<at::Tensor> amax_O,
                const c10::optional<at::Tensor> Bias,
                const c10::optional<at::Generator> rng_gen,
                size_t rng_elts_per_thread);

std::vector<at::Tensor> fused_attn_bwd(
                size_t max_seqlen_q, size_t max_seqlen_kv,
                float attn_scale, float p_dropout, bool set_zero,
                NVTE_QKV_Layout qkv_layout,
                NVTE_Bias_Type bias_type,
                NVTE_Mask_Type attn_mask_type,
                const at::Tensor cu_seqlens_q,
                const at::Tensor cu_seqlens_kv,
                const at::Tensor Q,
                const at::Tensor K,
                const at::Tensor V,
                const at::Tensor O,
                const at::Tensor dO,
                const transformer_engine::DType qkv_type,
                const std::vector<at::Tensor> Aux_CTX_Tensors,
                const c10::optional<at::Tensor> descale_QKV,
                const c10::optional<at::Tensor> descale_S,
cyanguwa's avatar
cyanguwa committed
147
148
149
150
151
152
                const c10::optional<at::Tensor> descale_O,
                const c10::optional<at::Tensor> descale_dO,
                const c10::optional<at::Tensor> scale_S,
                const c10::optional<at::Tensor> scale_dP,
                const c10::optional<at::Tensor> scale_dQKV,
                c10::optional<at::Tensor> amax_dP,
153
                c10::optional<at::Tensor> amax_dQKV);
Przemek Tredak's avatar
Przemek Tredak committed
154

155
156
157
at::Tensor fa_prepare_fwd(at::Tensor qkvi);
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v);

158
159
160
161
/***************************************************************************************************
 * GEMM
 **************************************************************************************************/

Przemek Tredak's avatar
Przemek Tredak committed
162
163
164
165
166
167
168
169
170
void te_gemm(at::Tensor A,
             at::Tensor A_scale_inverse,
             transformer_engine::DType A_type,
             bool transa,
             at::Tensor B,
             at::Tensor B_scale_inverse,
             transformer_engine::DType B_type,
             bool transb,
             at::Tensor D,
171
             at::Tensor D_scale,
Przemek Tredak's avatar
Przemek Tredak committed
172
             transformer_engine::DType D_type,
173
             at::Tensor D_amax,
Przemek Tredak's avatar
Przemek Tredak committed
174
             at::Tensor bias,
175
             transformer_engine::DType bias_type,
Przemek Tredak's avatar
Przemek Tredak committed
176
177
178
179
180
             at::Tensor pre_gelu_out,
             bool grad,
             at::Tensor workspace,
             size_t workspaceSize,
             bool accumulate,
181
182
             bool use_split_accumulator,
             int math_sm_count
Przemek Tredak's avatar
Przemek Tredak committed
183
184
);

185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
void te_atomic_gemm(at::Tensor A,
                    at::Tensor A_scale_inverse,
                    transformer_engine::DType A_type,
                    bool transa,
                    at::Tensor B,
                    at::Tensor B_scale_inverse,
                    transformer_engine::DType B_type,
                    bool transb,
                    at::Tensor D,
                    at::Tensor D_scale,
                    transformer_engine::DType D_type,
                    at::Tensor D_amax,
                    at::Tensor bias,
                    transformer_engine::DType bias_type,
                    at::Tensor pre_gelu_out,
                    bool grad,
                    at::Tensor workspace,
                    size_t workspaceSize,
                    bool accumulate,
                    bool use_split_accumulator,
                    int math_sm_count,
                    int m_split,
                    int n_split,
                    bool gemm_producer,
                    at::Tensor counter
);
Przemek Tredak's avatar
Przemek Tredak committed
211

212
213
214
215
/***************************************************************************************************
 * Transpose
 **************************************************************************************************/

Przemek Tredak's avatar
Przemek Tredak committed
216
217
218
219
220
221
222
223
224
225
void fused_cast_transpose(at::Tensor input,
                          at::Tensor scale,
                          at::Tensor amax,
                          at::Tensor scale_inv,
                          at::Tensor input_cast,
                          at::Tensor input_transpose,
                          transformer_engine::DType otype
);


226
227
228
229
230
231
232
233
234
235
236
void fused_cast_transpose_noop(at::Tensor input,
                               at::Tensor noop,
                               at::Tensor scale,
                               at::Tensor amax,
                               at::Tensor scale_inv,
                               at::Tensor input_cast,
                               at::Tensor input_transpose,
                               transformer_engine::DType otype
);


Przemek Tredak's avatar
Przemek Tredak committed
237
238
239
240
241
242
243
244
std::vector<at::Tensor> fused_cast_transpose_bgrad(at::Tensor grad_output,
                                                   at::Tensor scale,
                                                   at::Tensor amax,
                                                   at::Tensor scale_inv,
                                                   transformer_engine::DType otype
);


245
246
247
248
249
250
251
252
253
std::vector<at::Tensor> fused_fp8_transpose_bgrad(at::Tensor grad_output,
                                              at::Tensor scale,
                                              at::Tensor amax,
                                              at::Tensor scale_inv,
                                              transformer_engine::DType otype,
                                              transformer_engine::DType grad_bias_type
);


Przemek Tredak's avatar
Przemek Tredak committed
254
255
256
257
258
259
260
261
262
std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
                                                         at::Tensor gelu_input,
                                                         at::Tensor scale,
                                                         at::Tensor amax,
                                                         at::Tensor scale_inv,
                                                         transformer_engine::DType otype
);


Tim Moon's avatar
Tim Moon committed
263
264
265
266
267
268
269
270
271
272
void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
                                std::vector<at::Tensor> scale_list,
                                std::vector<at::Tensor> cast_output_list,
                                std::vector<at::Tensor> transposed_output_list,
                                std::vector<at::Tensor> amax_output_list,
                                std::vector<at::Tensor> scale_inv_output_list,
                                transformer_engine::DType otype
);


Przemek Tredak's avatar
Przemek Tredak committed
273
274
275
276
at::Tensor fp8_transpose(at::Tensor input,
                         transformer_engine::DType otype
);

277
278
279
280
281
282
283
284
285
286
287
void fp8_transpose_noalloc(at::Tensor input,
                           at::Tensor output,
                           transformer_engine::DType otype
);

void fp8_transpose_noalloc_noop(at::Tensor input,
                                at::Tensor output,
                                at::Tensor noop,
                                transformer_engine::DType otype
);

288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
/***************************************************************************************************
 * Activations
 **************************************************************************************************/

// Forward activations. Each takes `input`, the FP8 scaling-state tensors `scale` / `amax` /
// `scale_inv`, and the output dtype `otype`, and returns a tensor. Per their names, geglu /
// reglu / swiglu are gated variants and qgelu is presumably the "quick" GeLU approximation.
// TODO(review): confirm against the definitions.
at::Tensor gelu(at::Tensor input,
                at::Tensor scale,
                at::Tensor amax,
                at::Tensor scale_inv,
                transformer_engine::DType otype
);

at::Tensor relu(at::Tensor input,
                at::Tensor scale,
                at::Tensor amax,
                at::Tensor scale_inv,
                transformer_engine::DType otype
);

at::Tensor geglu(at::Tensor input,
                 at::Tensor scale,
                 at::Tensor amax,
                 at::Tensor scale_inv,
                 transformer_engine::DType otype
);

at::Tensor reglu(at::Tensor input,
                 at::Tensor scale,
                 at::Tensor amax,
                 at::Tensor scale_inv,
                 transformer_engine::DType otype
);

at::Tensor swiglu(at::Tensor input,
                  at::Tensor scale,
                  at::Tensor amax,
                  at::Tensor scale_inv,
                  transformer_engine::DType otype
);

at::Tensor qgelu(at::Tensor input,
                 at::Tensor scale,
                 at::Tensor amax,
                 at::Tensor scale_inv,
                 transformer_engine::DType otype
);

// Activation backward passes. Each takes the incoming gradient `grad`, the forward `input` and
// the output dtype `otype`, and returns a tensor (presumably the gradient w.r.t. `input`).
// Unlike the forward calls, these take no FP8 scaling-state tensors.
at::Tensor dgelu(at::Tensor grad,
                 at::Tensor input,
                 transformer_engine::DType otype
);

at::Tensor drelu(at::Tensor grad,
                 at::Tensor input,
                 transformer_engine::DType otype
);

at::Tensor dgeglu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
);

at::Tensor dreglu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
);

at::Tensor dswiglu(at::Tensor grad,
                   at::Tensor input,
                   transformer_engine::DType otype
);

at::Tensor dqgelu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
);

/***************************************************************************************************
 * LayerNorm
 **************************************************************************************************/
Przemek Tredak's avatar
Przemek Tredak committed
367
368
369
370
371

std::vector<at::Tensor> layernorm_bwd(const at::Tensor &dz,
                                      const at::Tensor &x,
                                      const at::Tensor &mu,
                                      const at::Tensor &rsigma,
372
                                      const at::Tensor &gamma,
373
374
                                      const int sm_margin,
                                      const bool zero_centered_gamma
Przemek Tredak's avatar
Przemek Tredak committed
375
376
377
378
379
380
381
382
383
384
);


std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
                                          const at::Tensor &weight,
                                          const at::Tensor &bias,
                                          float eps,
                                          at::Tensor scale,
                                          at::Tensor amax,
                                          at::Tensor scale_inv,
385
                                          transformer_engine::DType otype,
386
387
                                          const int sm_margin,
                                          const bool zero_centered_gamma
Przemek Tredak's avatar
Przemek Tredak committed
388
389
);

390
391
392
393
394
395
396
397
398
399
400
401
402
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
                                                  const at::Tensor &weight,
                                                  const at::Tensor &bias,
                                                  float eps,
                                                  at::Tensor scale,
                                                  at::Tensor ln_out,
                                                  at::Tensor amax,
                                                  at::Tensor scale_inv,
                                                  transformer_engine::DType otype,
                                                  const int sm_margin,
                                                  const bool zero_centered_gamma
);

403
404
405
406
407
408
409
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
                                 const at::Tensor &weight,
                                 const at::Tensor &bias,
                                 float eps,
                                 at::Tensor scale,
                                 at::Tensor amax,
                                 at::Tensor scale_inv,
410
411
                                 transformer_engine::DType otype,
                                 const bool zero_centered_gamma
412
);
Przemek Tredak's avatar
Przemek Tredak committed
413
414
415
416

std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
                                      const at::Tensor &weight,
                                      const at::Tensor &bias,
417
                                      float eps,
418
419
                                      const int sm_margin,
                                      const bool zero_centered_gamma
Przemek Tredak's avatar
Przemek Tredak committed
420
421
);

422
423
424
425
426
427
428
429
430
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input,
                                      const at::Tensor &weight,
                                      const at::Tensor &bias,
                                      at::Tensor ln_out,
                                      float eps,
                                      const int sm_margin,
                                      const bool zero_centered_gamma
);

431
432
433
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
                             const at::Tensor &weight,
                             const at::Tensor &bias,
434
435
                             float eps,
                             const bool zero_centered_gamma
436
);
Przemek Tredak's avatar
Przemek Tredak committed
437

438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
/***************************************************************************************************
 * RMSNorm
 **************************************************************************************************/

// RMSNorm entry points. None of them takes a `bias` or a `mu` tensor; the backward pass only
// needs `rsigma`. `sm_margin` presumably leaves SMs free for other work, and
// `zero_centered_gamma` presumably selects an alternative weight parameterisation.
// TODO(review): confirm both against the definitions.

// Backward: consumes the upstream gradient `dz`, the forward input `x`, the saved `rsigma` and
// the weight `gamma`, and yields a vector of tensors (order to be confirmed).
std::vector<at::Tensor> rmsnorm_bwd(
    const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma,
    const int sm_margin, const bool zero_centered_gamma);

// Forward with the result produced as `otype` under the FP8 scaling state
// `scale` / `amax` / `scale_inv`. Yields a vector of tensors.
std::vector<at::Tensor> rmsnorm_fwd_fp8(
    const at::Tensor &input, const at::Tensor &weight, float eps,
    at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype,
    const int sm_margin, const bool zero_centered_gamma);

// rmsnorm_fwd_fp8 plus a caller-supplied `ln_out`, which sits between `scale` and `amax`.
std::vector<at::Tensor> rmsnorm_fwd_fp8_noalloc(
    const at::Tensor &input, const at::Tensor &weight, float eps,
    at::Tensor scale, at::Tensor ln_out, at::Tensor amax, at::Tensor scale_inv,
    transformer_engine::DType otype, const int sm_margin, const bool zero_centered_gamma);

// FP8 forward with a single-tensor result and no `sm_margin` (presumably for inference).
at::Tensor rmsnorm_fwd_fp8_inf(
    const at::Tensor &input, const at::Tensor &weight, float eps,
    at::Tensor scale, at::Tensor amax, at::Tensor scale_inv, transformer_engine::DType otype,
    const bool zero_centered_gamma);

// Forward without FP8 scaling arguments. Yields a vector of tensors.
std::vector<at::Tensor> rmsnorm_fwd(
    const at::Tensor &input, const at::Tensor &weight, float eps,
    const int sm_margin, const bool zero_centered_gamma);

// rmsnorm_fwd plus a caller-supplied `ln_out`. Here `ln_out` precedes `eps`, unlike in
// rmsnorm_fwd_fp8_noalloc.
std::vector<at::Tensor> rmsnorm_fwd_noalloc(
    const at::Tensor &input, const at::Tensor &weight, at::Tensor ln_out, float eps,
    const int sm_margin, const bool zero_centered_gamma);

// Single-tensor forward with no FP8 arguments and no `sm_margin`.
at::Tensor rmsnorm_fwd_inf(
    const at::Tensor &input, const at::Tensor &weight, float eps,
    const bool zero_centered_gamma);

/***************************************************************************************************
 * Cast
 **************************************************************************************************/

Przemek Tredak's avatar
Przemek Tredak committed
509
510
511
512
513
514
515
516
at::Tensor cast_to_fp8(const at::Tensor &input,
                       const at::Tensor &scale,
                       at::Tensor amax,
                       at::Tensor scale_inv,
                       transformer_engine::DType otype
);


517
518
519
520
521
522
523
524
525
void cast_to_fp8_noalloc(const at::Tensor &input,
                         const at::Tensor &scale,
                         at::Tensor output,
                         at::Tensor amax,
                         at::Tensor scale_inv,
                         transformer_engine::DType otype
);


Przemek Tredak's avatar
Przemek Tredak committed
526
527
528
529
530
at::Tensor cast_from_fp8(const at::Tensor &input,
                         const at::Tensor &scale_inv,
                         transformer_engine::DType itype,
                         transformer_engine::DType otype
);
531

532
533
534
/***************************************************************************************************
 * Softmax
 **************************************************************************************************/

// Scaled-softmax forward/backward pairs.
// - Each forward takes `input` and `scale_factor` (plus `mask` for the masked variant) and
//   returns a tensor.
// - Each backward takes the incoming gradient and the forward's `softmax_results_`, and returns
//   a tensor (presumably the input gradient).
// - scaled_masked_softmax_backward takes no `mask` argument.
// - The upper_triang and aligned_causal variants take no explicit mask. They presumably apply a
//   causal mask implicitly. TODO(review): confirm against the definitions.

at::Tensor scaled_softmax_forward(at::Tensor input,
                                  float scale_factor
);


at::Tensor scaled_softmax_backward(at::Tensor output_grad_,
                                   at::Tensor softmax_results_,
                                   float scale_factor
);


at::Tensor scaled_masked_softmax_forward(at::Tensor input,
                                         at::Tensor mask,
                                         float scale_factor
);


at::Tensor scaled_masked_softmax_backward(at::Tensor output_grad_,
                                          at::Tensor softmax_results_,
                                          float scale_factor
);


at::Tensor scaled_upper_triang_masked_softmax_forward(at::Tensor input,
                                                      float scale_factor
);


at::Tensor scaled_upper_triang_masked_softmax_backward(at::Tensor output_grads_,
                                                       at::Tensor softmax_results_,
                                                       float scale_factor
);


at::Tensor scaled_aligned_causal_masked_softmax_forward(at::Tensor input,
                                                        float scale_factor
);


at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads_,
                                                         at::Tensor softmax_results_,
                                                         float scale_factor
);

/***************************************************************************************************
 * FP8 recipe
 **************************************************************************************************/

584
585
586
587
588
589
590
void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                 std::vector<at::Tensor> amax_histories,
                                                 std::vector<at::Tensor> scales,
                                                 std::vector<at::Tensor> scale_invs,
                                                 const std::string &amax_compute_algo,
                                                 transformer_engine::DType fp8_dtype,
                                                 float margin);
591

592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
/***************************************************************************************************
 * Rotary positional embedding
 **************************************************************************************************/

// Rotary positional embedding kernels; every entry point returns an at::Tensor.
// - The plain pair takes `freqs` and a `transpose_output_memory` flag, which presumably
//   selects the memory layout of the result.
// - The `_thd` pair takes `cu_seqlens` instead, presumably for packed variable-length
//   sequences.
// TODO(review): confirm the expected layouts against the definitions.
at::Tensor fused_rope_forward(
    const at::Tensor &input, const at::Tensor &freqs, const bool transpose_output_memory);

at::Tensor fused_rope_backward(
    const at::Tensor &output_grads, const at::Tensor &freqs, const bool transpose_output_memory);

at::Tensor fused_rope_thd_forward(
    const at::Tensor &input, const at::Tensor &cu_seqlens, const at::Tensor &freqs);

at::Tensor fused_rope_thd_backward(
    const at::Tensor &output_grads, const at::Tensor &cu_seqlens, const at::Tensor &freqs);

/***************************************************************************************************
 * Miscellaneous
 **************************************************************************************************/

// Version number of the cuBLASLt library (per the name). The encoding is defined by the
// implementation.
size_t get_cublasLt_version();

// Version number of the cuDNN library (per the name). The encoding is defined by the
// implementation.
size_t get_cudnn_version();

// Presumably reports whether userbuffer-based communication support is available.
bool userbuf_comm_available();

// Takes no arguments and returns nothing. Its purpose is not visible from this header.
void placeholder();