ptx.cuh 65 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
7
8
9
10
11
12
13
14
15
16
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file ptx.cuh
 *  \brief BW PTX
 */

#ifndef TRANSFORMER_ENGINE_PTX_CUH_
#define TRANSFORMER_ENGINE_PTX_CUH_

#include <cuda.h>
#include <cuda_runtime.h>

17
18
19
20
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif  // CUDA_VERSION >= 12080

21
22
#include "common/utils.cuh"

23
namespace transformer_engine {
24

25
26
namespace ptx {

27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/*! \brief Tag describing a feature that exists only on one exact architecture.
 *
 *  \tparam N Architecture number; the comparable identifier is `N * 10`.
 */
template <int N>
struct ArchSpecific {
  constexpr static int id = N * 10;

  /*! \brief Check whether the architecture being compiled for is the one named by this tag.
   *
   *  \tparam CurrentArch       Architecture value of the current compilation pass.
   *  \tparam ArchSpecificId    Arch-specific target value of the compilation (0 if none).
   *  \tparam FamilySpecificId  Family-specific target value (not used by this tag).
   *
   *  \return true when CurrentArch equals `id`, false otherwise. When it matches but
   *          the compilation is not arch-specific, compilation fails with a static_assert.
   */
  template <int CurrentArch, int ArchSpecificId, int FamilySpecificId>
  constexpr static bool compatible() {
    constexpr bool is_this_arch = (CurrentArch == id);
    if constexpr (is_this_arch) {
      // Only checked on a match, so unrelated architectures never trip the assertion.
      static_assert(ArchSpecificId == CurrentArch,
                    "Compiled for the generic architecture, while utilizing arch-specific "
                    "features. Please compile for smXXXa architecture instead of smXXX "
                    "architecture.");
    }
    return is_this_arch;
  }
};

/*! \brief Tag describing a feature shared by a whole architecture family.
 *
 *  \tparam N Architecture number; the comparable identifier is `N * 10`.
 *            Two architectures are in the same family when their identifiers
 *            agree after integer division by 100.
 */
template <int N>
struct FamilySpecific {
  constexpr static int id = N * 10;

  /*! \brief Check whether the architecture being compiled for is in this tag's family.
   *
   *  \tparam CurrentArch       Architecture value of the current compilation pass.
   *  \tparam ArchSpecificId    Arch-specific target value (not used by this tag).
   *  \tparam FamilySpecificId  Family-specific target value of the compilation (0 if none).
   *
   *  \return true when CurrentArch is in the same family as `id`, false otherwise.
   *          When it matches but the compilation is not family-specific, compilation
   *          fails with a static_assert.
   */
  template <int CurrentArch, int ArchSpecificId, int FamilySpecificId>
  constexpr static bool compatible() {
    constexpr bool same_family = ((CurrentArch / 100) == (id / 100));
    if constexpr (same_family) {
      // Only checked on a match, so unrelated families never trip the assertion.
      static_assert(FamilySpecificId == CurrentArch,
                    "Compiled for the generic architecture, while utilizing family-specific "
                    "features. Please compile for smXXXf architecture instead of smXXX "
                    "architecture.");
    }
    return same_family;
  }
};

/*! \brief Check whether any tag in a list accepts the current architecture.
 *
 *  Tags are tried in order; each must provide a static member template
 *  `compatible<Arch, ArchSpecificId, FamilySpecificId>()`. Evaluation stops at the
 *  first tag that accepts, so later tags are not instantiated after a match.
 *
 *  \tparam Arch              Architecture value of the current compilation pass.
 *  \tparam ArchSpecificId    Arch-specific target value (0 if none).
 *  \tparam FamilySpecificId  Family-specific target value (0 if none).
 *  \tparam T                 First tag to try.
 *  \tparam U                 Remaining tags.
 *
 *  \return true if at least one tag reports compatibility, false otherwise.
 */
template <int Arch, int ArchSpecificId, int FamilySpecificId, class T, class... U>
constexpr bool is_supported_arch() {
  constexpr bool head_matches = T::template compatible<Arch, ArchSpecificId, FamilySpecificId>();
  if constexpr (head_matches) {
    return true;
  } else if constexpr (sizeof...(U) == 0) {
    // No tags left to try.
    return false;
  } else {
    return is_supported_arch<Arch, ArchSpecificId, FamilySpecificId, U...>();
  }
}

wenjh's avatar
wenjh committed
74
75
76
77
78
79
80
// On the HIP/AMD platform, define __CUDA_ARCH_HAS_FEATURE__ as a plain comparison:
// the feature tokens SM100_ALL, SM101_ALL and SM120_ALL are each paired with a
// minimum __CUDA_ARCH__ value.
// NOTE(review): the SM90_ALL token queried later in this file has no case here — confirm
// that this is intended.
#ifdef __HIP_PLATFORM_AMD__
#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
    ((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
     (__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
     (__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
#endif

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// For CUDA versions older than 12.9, derive __CUDA_ARCH_SPECIFIC__ and
// __CUDA_ARCH_FAMILY_SPECIFIC__ from __CUDA_ARCH_HAS_FEATURE__ queries. Both macros
// receive the same value (900, 1000, 1010 or 1200) for the matching SMxx_ALL feature.
// NOTE(review): presumably newer toolkits provide these macros themselves — confirm.
#if CUDA_VERSION < 12090
#if __CUDA_ARCH_HAS_FEATURE__(SM90_ALL)
#define __CUDA_ARCH_SPECIFIC__ 900
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 900
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM100_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1000
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1000
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM101_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1010
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1010
#endif
#if __CUDA_ARCH_HAS_FEATURE__(SM120_ALL)
#define __CUDA_ARCH_SPECIFIC__ 1200
#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1200
#endif
#endif

wenjh's avatar
wenjh committed
100

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Helper macros that each expand to one `constexpr int` declaration. They capture
// the corresponding compiler-provided macro, or 0 when that macro is not defined.
// They are meant to be expanded inside the lambda body of NVTE_CUDA_ARCH_MATCHES.

// Declares `current_arch` from __CUDA_ARCH__.
#ifdef __CUDA_ARCH__
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = __CUDA_ARCH__;
#else
#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = 0;
#endif

// Declares `ArchSpecific` from __CUDA_ARCH_SPECIFIC__.
#ifdef __CUDA_ARCH_SPECIFIC__
#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = __CUDA_ARCH_SPECIFIC__;
#else
#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = 0;
#endif

// Declares `FamilySpecific` from __CUDA_ARCH_FAMILY_SPECIFIC__.
#ifdef __CUDA_ARCH_FAMILY_SPECIFIC__
#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = __CUDA_ARCH_FAMILY_SPECIFIC__;
#else
#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = 0;
#endif

// Expands to an immediately-invoked lambda that returns
// is_supported_arch<current_arch, ArchSpecific, FamilySpecific, tags...>().
// The arguments are the ArchSpecific<N> / FamilySpecific<N> tags to test.
// Note that the expansion already ends with a semicolon.
#define NVTE_CUDA_ARCH_MATCHES(...)                                                               \
  [&] {                                                                                           \
    __NVTE_CURRENT_ARCH__                                                                         \
    __NVTE_ARCH_SPECIFIC__                                                                        \
    __NVTE_ARCH_FAMILY_SPECIFIC__                                                                 \
    return transformer_engine::ptx::is_supported_arch<current_arch, ArchSpecific, FamilySpecific, \
                                                      __VA_ARGS__>();                             \
  }();

// True when compiling for a family-specific target in the family of 100, 110 or 120.
#define ARCH_BLACKWELL_FAMILY                                                \
  NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>, ptx::FamilySpecific<110>, \
                         ptx::FamilySpecific<120>)
// True when compiling for the arch-specific target 100 or 103.
#define ARCH_HAS_STOCHASTIC_ROUNDING \
  NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>)
133
134
135

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
// Issues `mbarrier.init` on the barrier object at `mbar` with the given `count`.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("mbarrier_init is only supported on SM 10.0+.");
#else
  // The instruction takes a 32-bit shared-state-space address.
  const uint32_t barrier_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(barrier_addr), "r"(count) : "memory");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
// Issues `mbarrier.inval` on the barrier object at `mbar`.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("mbarrier_invalid is only supported on SM 10.0+.");
#else
  // The instruction takes a 32-bit shared-state-space address.
  const uint32_t barrier_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(barrier_addr) : "memory");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
// Issues `mbarrier.arrive` on the barrier object at `mbar`; the state value produced
// by the instruction is discarded (`_` sink operand).
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("mbarrier_arrive is only supported on SM 10.0+.");
#else
  // The instruction takes a 32-bit shared-state-space address.
  const uint32_t barrier_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(barrier_addr) : "memory");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive
// Issues `mbarrier.arrive.expect_tx` on the barrier object at `mbar`, passing
// `tx_count` as the expected transaction count; the returned state is discarded.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 10.0+.");
#else
  // The instruction takes a 32-bit shared-state-space address.
  const uint32_t barrier_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
               :
               : "r"(barrier_addr), "r"(tx_count)
               : "memory");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

// Issues `fence.mbarrier_init.release.cluster`.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // The "memory" clobber stops the compiler from moving ordinary loads/stores
  // across the fence; `volatile` alone only orders it against other volatile asm.
  asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
#else
  NVTE_DEVICE_ERROR("fence_mbarrier_init_release_cluster is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// global -> shared::cta (the destination state space used by the instruction below)
//
// Starts an asynchronous bulk copy of `size` bytes from `src_global_ptr` to
// `dst_shmem`, with completion reported to the mbarrier at `mbar`.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared(
    uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_global_to_shared is only supported on SM 10.0+.");
#else
  // Both shared-memory operands are passed as 32-bit shared-state-space addresses.
  const uint32_t dst_addr = static_cast<uint32_t>(__cvta_generic_to_shared(dst_shmem));
  const uint32_t barrier_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  // Triggers the async copy, i.e. the thread continues until wait() on the mbarrier.
  // Barrier condition:
  // - the leader must arrive (i.e. 1 thread as set by the caller)
  // - the TMA hardware subtracts bytes from the expect_tx counter, which must reach zero
  asm volatile(
      "cp.async.bulk.shared::cta.global"
      ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];"
      :
      : "r"(dst_addr), "l"(src_global_ptr), "r"(size), "r"(barrier_addr)
      : "memory");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// global -> shared::cluster
// Starts an asynchronous 2D tensor tile copy described by the tensor map at
// `tensor_map_ptr`, at coordinates {offset_x, offset_y}, into `dst_shmem`, with
// completion reported to the mbarrier at `mbar`.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
    uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x,
    const uint32_t offset_y, uint64_t *mbar) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_global_to_shared is only supported on SM 10.0+.");
#else
  // Both shared-memory operands are passed as 32-bit shared-state-space addresses.
  const uint32_t dst_addr = static_cast<uint32_t>(__cvta_generic_to_shared(dst_shmem));
  const uint32_t barrier_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  // Triggers the async copy, i.e. the thread continues until wait() on the mbarrier.
  // Barrier condition:
  // - the leader must arrive (i.e. 1 thread as set by the caller)
  // - the TMA hardware subtracts bytes from the expect_tx counter, which must reach zero
  asm volatile(
      "cp.async.bulk.tensor.2d.shared::cluster.global.tile"
      ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];"
      :
      : "r"(dst_addr), "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(barrier_addr)
      : "memory");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

226
// Performs one `mbarrier.try_wait.parity` attempt on the barrier at the
// shared-state-space address `mbar_ptr` for the given `parity`.
// Returns the predicate produced by the instruction. When not compiled for
// SM 10.0 or newer, reports a device error and returns true.
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // The predicate result is materialized as 0/1 in a 32-bit register via selp.
  uint32_t wait_done;
  asm volatile(
      "{\n\t .reg .pred P_OUT; \n\t"
      "mbarrier.try_wait.parity.shared::cta.b64  P_OUT, [%1], %2; \n\t"
      "selp.b32 %0, 1, 0, P_OUT; \n"
      "}"
      : "=r"(wait_done)
      : "r"(mbar_ptr), "r"(parity)
      : "memory");
  return wait_done != 0;
#else
  NVTE_DEVICE_ERROR("mbarrier_try_wait_parity is only supported on SM 10.0+.");
  return true;
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Spins on mbarrier_try_wait_parity() until it reports completion for the
// barrier at `mbar` and the given `parity`.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 10.0+.");
#else
  const uint32_t barrier_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  bool wait_done = false;
  do {
    wait_done = mbarrier_try_wait_parity(barrier_addr, parity);
  } while (!wait_done);
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}
253

254
255
256
257
258
259
260
261
262
263
264
265
266
267
// Number of mantissa bits in an FP32 value; used below as the shift that moves
// an 8-bit biased exponent into / out of the FP32 exponent field.
constexpr uint32_t FP32_MANTISSA_BITS = 23;
// Exponent bias of FP32.
constexpr uint32_t FP32_EXPONENT_BIAS = 127;

// Builds the float whose exponent field is (254 - biased_exp), i.e. the reciprocal
// of the power of two encoded by `biased_exp`. A zero `biased_exp` is special-cased
// and yields 1.
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
  if (biased_exp == 0) {
    return 1;
  }
  // 127 - (biased_exp - 127)
  return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);
}

// Builds the float whose bit pattern has `biased_exp` in the exponent field and
// zeros in the sign and mantissa fields.
__device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
  const int float_bits = biased_exp << FP32_MANTISSA_BITS;
  return __int_as_float(float_bits);
}

// Converts `val` to an 8-bit biased exponent (e8m0), rounding the exponent up.
//
// - On the ROCm platform the conversion is not supported: a device error is reported
//   and 0 is returned so that every control path of this non-void function returns a
//   value (same pattern as mbarrier_try_wait_parity).
// - When ARCH_BLACKWELL_FAMILY holds, the `cvt.rp.satfinite.ue8m0x2.f32` instruction
//   is used and the low byte of its 16-bit result is returned.
// - Otherwise a software path is used: NaN -> 0xFF, Inf -> 0xFE, 0 -> 0x00, and for
//   other inputs the FP32 exponent field, incremented when the mantissa is non-zero
//   (except when already 0xFE, or for exponent 0 with mantissa <= 0x400000).
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#ifdef __HIP_PLATFORM_AMD__
  NVTE_DEVICE_ERROR("float_to_e8m0 is not supported on rocm platform.");
  return 0;
#else
  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
  if constexpr (is_blackwell) {
    uint16_t out;
    asm volatile(
        "{\n"
        "cvt.rp.satfinite.ue8m0x2.f32  %0, 0.0, %1;\n"
        "}"
        : "=h"(out)
        : "f"(val));
    return *reinterpret_cast<e8m0_t *>(&out);
  } else {
    // TODO: nan/inf needs to be set for any value
    // of nan/inf in input not just amax.
    if (isnan(val)) {
      return 0xFF;
    }
    if (isinf(val)) {
      return 0xFE;
    }
    if (val == 0.0f) {
      return 0x00;
    }
    // Read the raw FP32 bits through the intrinsic instead of pointer type-punning.
    const uint32_t val_u32 = __float_as_uint(val);
    // Narrowing to 8 bits keeps the exponent field and drops the sign bit.
    e8m0_t exponent = static_cast<e8m0_t>(val_u32 >> FP32_MANTISSA_BITS);
    const uint32_t mantissa = val_u32 & 0x7FFFFF;
    // Round up exponent and deal with satfinite.
    if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
      ++exponent;
    }
    return exponent;
  }
#endif
}

305
306
307
308
309
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
// Starts an asynchronous bulk copy of `size` bytes from `src_shmem` to
// `dst_global_ptr`, tracked by the bulk async-group mechanism (`.bulk_group`).
// Reports a device error when not compiled for SM 9.0 or newer.
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr,
                                                                         const uint64_t *src_shmem,
                                                                         const uint32_t size) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 900)
  NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_shared_to_global is only supported on SM 9.0+.");
#else
  // The source is passed as a 32-bit shared-state-space address.
  const uint32_t src_addr = static_cast<uint32_t>(__cvta_generic_to_shared(src_shmem));
  asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;"
               :
               : "l"(dst_global_ptr), "r"(src_addr), "r"(size)
               : "memory");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 900)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
// Starts an asynchronous 2D tensor copy from `src_shmem` to the location described
// by the tensor map at `tensor_map_ptr` and the coordinates {offset_x, offset_y},
// tracked by the bulk async-group mechanism (`.bulk_group`).
// Reports a device error when not compiled for SM 9.0 or newer.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
    const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y,
    uint64_t *src_shmem) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 900)
  NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_shared_to_global is only supported on SM 9.0+.");
#else
  // The source is passed as a 32-bit shared-state-space address.
  const uint32_t src_addr = static_cast<uint32_t>(__cvta_generic_to_shared(src_shmem));
  asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];"
               :
               : "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(src_addr)
               : "memory");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 900)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
// Issues `cp.async.bulk.wait_group 0`, i.e. waits with a pending-group limit of 0.
// Reports a device error when not compiled for SM 9.0 or newer.
__device__ __forceinline__ void cp_async_bulk_wait_group() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // The "memory" clobber keeps the compiler from moving ordinary memory accesses
  // across the wait; `volatile` alone does not order them.
  asm volatile("cp.async.bulk.wait_group 0;" ::: "memory");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
// Issues `cp.async.bulk.wait_group.read W`.
//
// \tparam W Immediate operand of the instruction (the pending-group limit).
//
// The primary template previously hard-coded the operand to 0 regardless of W;
// it now forwards W as an immediate ("n" constraint). Explicit specializations
// declared after this template take precedence for their values of W.
// Reports a device error when not compiled for SM 9.0 or newer.
template <size_t W>
__device__ __forceinline__ void cp_async_bulk_wait_group_read() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // The "memory" clobber keeps the compiler from moving ordinary memory accesses
  // across the wait.
  asm volatile("cp.async.bulk.wait_group.read %0;" ::"n"(static_cast<int>(W)) : "memory");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

// Specialization issuing `cp.async.bulk.wait_group.read 0`.
// Reports a device error when not compiled for SM 9.0 or newer.
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // "memory" clobber: do not let the compiler move memory accesses across the wait.
  asm volatile("cp.async.bulk.wait_group.read 0;" ::: "memory");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Specialization issuing `cp.async.bulk.wait_group.read 1`.
// Reports a device error when not compiled for SM 9.0 or newer.
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // "memory" clobber: do not let the compiler move memory accesses across the wait.
  asm volatile("cp.async.bulk.wait_group.read 1;" ::: "memory");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Specialization issuing `cp.async.bulk.wait_group.read 2`.
// Reports a device error when not compiled for SM 9.0 or newer.
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // "memory" clobber: do not let the compiler move memory accesses across the wait.
  asm volatile("cp.async.bulk.wait_group.read 2;" ::: "memory");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
// Specialization issuing `cp.async.bulk.wait_group.read 4`.
// Reports a device error when not compiled for SM 9.0 or newer.
template <>
__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // "memory" clobber: do not let the compiler move memory accesses across the wait.
  asm volatile("cp.async.bulk.wait_group.read 4;" ::: "memory");
#else
  NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

388
389
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
// Issues `cp.async.bulk.commit_group`.
// Reports a device error when not compiled for SM 9.0 or newer.
__device__ __forceinline__ void cp_async_bulk_commit_group() {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 900)
  NVTE_DEVICE_ERROR("cp_async_bulk_commit_group is only supported on SM 9.0+.");
#else
  asm volatile("cp.async.bulk.commit_group;");
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 900)
}

397
// Proxy fence (bi-directional):
398
399
400
401
402
403
404
// Issues `fence.proxy.async`.
// Reports a device error when not compiled for SM 9.0 or newer.
__device__ __forceinline__ void fence_proxy_async() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // The "memory" clobber stops the compiler from moving ordinary loads/stores
  // across the fence; `volatile` alone only orders it against other volatile asm.
  asm volatile("fence.proxy.async;" ::: "memory");
#else
  NVTE_DEVICE_ERROR("fence_proxy_async is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
405

406
// Issues `fence.proxy.async.shared::cta`.
// Reports a device error when not compiled for SM 9.0 or newer.
__device__ __forceinline__ void fence_proxy_async_shared_cta() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // The "memory" clobber stops the compiler from moving ordinary loads/stores
  // across the fence; `volatile` alone only orders it against other volatile asm.
  asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
#else
  NVTE_DEVICE_ERROR("fence_proxy_async_shared_cta is only supported on SM 9.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}

wenjh's avatar
wenjh committed
414
415
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)

416
417
418
419
420
421
// Pair of two values of type T. The struct is aligned to 2 * sizeof(T), so a pair
// can be reinterpreted as one integer of that size (as the inline-PTX helpers in
// this file do).
template <typename T>
struct alignas(2 * sizeof(T)) FPx2 {
  T x;
  T y;
};

422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
// Group of four values of type T. Unlike FPx2, no explicit alignment is requested.
template <typename T>
struct FPx4 {
  T x1;
  T x2;
  T x3;
  T x4;
};

// Maps a scalar type T to a two-element vector type via the nested `type` alias.
// The primary template is intentionally empty: only the specializations below
// provide `type`.
template <typename T>
struct Type2x {};

// float -> float2
template <>
struct Type2x<float> {
  using type = float2;
};

// bf16 -> __nv_bfloat162
template <>
struct Type2x<bf16> {
  using type = __nv_bfloat162;
};

// fp16 -> __half2
template <>
struct Type2x<fp16> {
  using type = __half2;
};

448
449
450
451
452
453
// Two-element packed aliases for each supported element type.
using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>;

454
455
456
457
458
459
// Four-element aliases for each supported element type.
using floatx4 = FPx4<float>;
using bf16x4 = FPx4<bf16>;
using fp16x4 = FPx4<fp16>;
using fp8e4m3x4 = FPx4<fp8e4m3>;
using fp8e5m2x4 = FPx4<fp8e5m2>;

460
461
462
463
464
465
// The x2 types must have exactly the size of two packed elements (no padding),
// because they are reinterpreted as 64/32/16-bit integers when passed to inline PTX.
static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);

466
#if FP4_TYPE_SUPPORTED
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
// Short aliases for the FP4 (e2m1) scalar, x2 and x4 types; the packed x2 and x4
// types must occupy 1 and 2 bytes respectively.
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
static_assert(sizeof(fp4e2m1x2) == 1);
static_assert(sizeof(fp4e2m1x4) == 2);

// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// and the converted values are packed in the destination operand d such that the value
// converted from input a is stored in the upper 4 bits of d and the value converted
// from input b is stored in the lower 4 bits of d.

// SIMD like "Fused" cast + multiplication (x4)
// Multiplies four input values by `scale` in FP32 and converts the result to a
// packed FP4 x4 value.
//
// \tparam Tx2  Pair type exposing members `x` and `y` convertible to float.
// \param out   Receives the converted values.
// \param in01  First and second input elements.
// \param in23  Third and fourth input elements.
// \param scale Multiplier applied to every element before conversion.
template <typename Tx2>
__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23,
                                           const float scale) {
  const float4 scaled =
      make_float4(static_cast<float>(in01.x) * scale, static_cast<float>(in01.y) * scale,
                  static_cast<float>(in23.x) * scale, static_cast<float>(in23.y) * scale);
  out = fp4e2m1x4(scaled);
}
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674

// Scales four BF16 values and converts them to a packed FP4 x4 value with
// stochastic rounding (`cvt.rs`).
//
// \param in_4x  Four 16-bit BF16 values packed into one 64-bit word.
// \param scale  Pair of FP32 multipliers, passed to `mul.f32x2` as one 64-bit operand.
// \param rbits  32-bit operand supplied to the `cvt.rs` instruction.
// \return The 16-bit conversion result reinterpreted as fp4e2m1x4. When
//         ARCH_HAS_STOCHASTIC_ROUNDING is false, a device error is reported and the
//         zero-initialized bits are returned.
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
    const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
  uint16_t out_4x = 0;
  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
  if constexpr (has_rs) {
    // PTX sequence: unpack the four BF16 values, widen each to FP32, multiply the
    // two pairs by `scale`, then convert all four values in one instruction.
    asm volatile(
        "{\n"
        ".reg.b64 v01; \n\t"
        ".reg.b64 v23; \n\t"
        ".reg.b16 v0_bf16; \n\t"
        ".reg.b16 v1_bf16; \n\t"
        ".reg.b16 v2_bf16; \n\t"
        ".reg.b16 v3_bf16; \n\t"
        ".reg.b32 v0; \n\t"
        ".reg.b32 v1; \n\t"
        ".reg.b32 v2; \n\t"
        ".reg.b32 v3; \n\t"
        "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
        "cvt.f32.bf16 v0, v0_bf16; \n\t"
        "cvt.f32.bf16 v1, v1_bf16; \n\t"
        "cvt.f32.bf16 v2, v2_bf16; \n\t"
        "cvt.f32.bf16 v3, v3_bf16; \n\t"
        "mov.b64 v01, {v0, v1}; \n\t"
        "mov.b64 v23, {v2, v3}; \n\t"
        "mul.f32x2 v01, v01, %2; \n\t"  // mind the shuffled elements order
        "mul.f32x2 v23, v23, %2; \n\t"  // mind the shuffled elements order
        "mov.b64 {v1, v0}, v01; \n\t"
        "mov.b64 {v3, v2}, v23; \n\t"
        "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t"  // mind the shuffled elements order
        "}"
        : "=h"(out_4x)
        : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
  } else {
    NVTE_DEVICE_ERROR(
        "FP4 cvt PTX instructions are architecture-specific. "
        "Try recompiling with sm_XXXa instead of sm_XXX.");
  }
  return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}

// Scales four BF16 values and converts them to a packed FP4 x4 value with
// round-to-nearest (`cvt.rn`).
//
// \param in_4x  Four 16-bit BF16 values packed into one 64-bit word.
// \param scale  Pair of FP32 multipliers, passed to `mul.f32x2` as one 64-bit operand.
// \param rbits  Unused; kept so the signature matches the stochastic-rounding variant.
// \return The first fp4e2m1x4-sized element of the 32-bit result container. When
//         ARCH_BLACKWELL_FAMILY is false, a device error is reported and the
//         zero-initialized bits are returned.
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
                                                                    const float2 scale,
                                                                    const uint32_t rbits) {
  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
  uint32_t out_4x = 0;  // Only need 16 bit. Using 32 bit container for packing.
  if constexpr (is_blackwell) {
    // NOTE: rbits unused for rn.
    // PTX sequence: unpack and widen the BF16 values, multiply the two pairs by
    // `scale`, convert each pair to one byte (f0, f1), then pack the bytes as
    // {f0, f1, f0, f1} into the 32-bit output.
    asm volatile(
        "{\n"
        ".reg.b64 v01; \n\t"
        ".reg.b64 v23; \n\t"
        ".reg.b16 v0_bf16; \n\t"
        ".reg.b16 v1_bf16; \n\t"
        ".reg.b16 v2_bf16; \n\t"
        ".reg.b16 v3_bf16; \n\t"
        ".reg.b32 v0; \n\t"
        ".reg.b32 v1; \n\t"
        ".reg.b32 v2; \n\t"
        ".reg.b32 v3; \n\t"
        ".reg.b8 f0; \n\t"
        ".reg.b8 f1; \n\t"
        "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
        "cvt.f32.bf16 v0, v0_bf16; \n\t"
        "cvt.f32.bf16 v1, v1_bf16; \n\t"
        "cvt.f32.bf16 v2, v2_bf16; \n\t"
        "cvt.f32.bf16 v3, v3_bf16; \n\t"
        "mov.b64 v01, {v0, v1}; \n\t"
        "mov.b64 v23, {v2, v3}; \n\t"
        "mul.f32x2 v01, v01, %2; \n\t"  // mind the shuffled elements order
        "mul.f32x2 v23, v23, %2; \n\t"  // mind the shuffled elements order
        "mov.b64 {v1, v0}, v01; \n\t"
        "mov.b64 {v3, v2}, v23; \n\t"
        "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
        "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
        "mov.b32 %0, {f0, f1, f0, f1};\n\t"
        "}"
        : "=r"(out_4x)
        : "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
  } else {
    NVTE_DEVICE_ERROR(
        "FP4 cvt PTX instructions are architecture-specific. "
        "Try recompiling with sm_XXXa instead of sm_XXX.");
  }
  return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}

// Compile-time dispatcher for the BF16 -> FP4 x4 scaled conversion.
//
// \tparam USE_STOCHASTIC_ROUNDING  Selects the stochastic-rounding variant when true,
//                                  the round-to-nearest variant otherwise.
// All arguments are forwarded unchanged to the selected variant.
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
                                                            const float2 scale,
                                                            const uint32_t rbits) {
  if constexpr (!USE_STOCHASTIC_ROUNDING) {
    return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
  } else {
    return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
  }
}

// Scales four FP32 values and converts them to a packed FP4 x4 value with
// stochastic rounding (`cvt.rs`).
//
// \param in01   First and second input values.
// \param in23   Third and fourth input values.
// \param scale  Pair of FP32 multipliers, passed to `mul.f32x2` as one 64-bit operand.
// \param rbits  32-bit operand supplied to the `cvt.rs` instruction.
// \return The 16-bit conversion result reinterpreted as fp4e2m1x4. When
//         ARCH_HAS_STOCHASTIC_ROUNDING is false, a device error is reported and the
//         zero-initialized bits are returned.
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
    const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
  uint16_t out_4x = 0;
  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
  if constexpr (has_rs) {
    // PTX sequence: multiply the two input pairs by `scale`, then convert all four
    // values in one instruction. Each float2 is passed as a 64-bit integer operand.
    asm volatile(
        "{\n"
        ".reg.b64 v01; \n\t"
        ".reg.b64 v23; \n\t"
        ".reg.b32 v0; \n\t"
        ".reg.b32 v1; \n\t"
        ".reg.b32 v2; \n\t"
        ".reg.b32 v3; \n\t"
        "mov.b64 {v0, v1} , %1; \n\t"
        "mov.b64 {v2, v3} , %2; \n\t"
        "mov.b64 v01, {v0, v1}; \n\t"
        "mov.b64 v23, {v2, v3}; \n\t"
        "mul.f32x2 v01, v01, %3; \n\t"  // mind the shuffled elements order
        "mul.f32x2 v23, v23, %3; \n\t"  // mind the shuffled elements order
        "mov.b64 {v1, v0}, v01; \n\t"
        "mov.b64 {v3, v2}, v23; \n\t"
        "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t"  // mind the shuffled elements order
        "}"
        : "=h"(out_4x)
        : "l"(reinterpret_cast<const uint64_t &>(in01)),
          "l"(reinterpret_cast<const uint64_t &>(in23)),
          "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
  } else {
    NVTE_DEVICE_ERROR(
        "FP4 cvt PTX instructions are architecture-specific. "
        "Try recompiling with sm_XXXa instead of sm_XXX.");
  }
  return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}

// Scales four FP32 values and converts them to a packed FP4 x4 value with
// round-to-nearest (`cvt.rn`).
//
// \param in01   First and second input values.
// \param in23   Third and fourth input values.
// \param scale  Pair of FP32 multipliers, passed to `mul.f32x2` as one 64-bit operand.
// \param rbits  Unused; kept so the signature matches the stochastic-rounding variant.
// \return The first fp4e2m1x4-sized element of the 32-bit result container. When
//         ARCH_BLACKWELL_FAMILY is false, a device error is reported and the
//         zero-initialized bits are returned.
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
                                                                    const float2 in23,
                                                                    const float2 scale,
                                                                    const uint32_t rbits) {
  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
  uint32_t out_4x = 0;  // Only need 16 bit. Using 32 bit container for packing.
  if constexpr (is_blackwell) {
    // NOTE: rbits unused for rn.
    // PTX sequence: multiply the two input pairs by `scale`, convert each pair to
    // one byte (f0, f1), then pack the bytes as {f0, f1, f0, f1} into the output.
    asm volatile(
        "{\n"
        ".reg.b64 v01; \n\t"
        ".reg.b64 v23; \n\t"
        ".reg.b32 v0; \n\t"
        ".reg.b32 v1; \n\t"
        ".reg.b32 v2; \n\t"
        ".reg.b32 v3; \n\t"
        ".reg.b8 f0; \n\t"
        ".reg.b8 f1; \n\t"
        "mov.b64 {v0, v1} , %1; \n\t"
        "mov.b64 {v2, v3} , %2; \n\t"
        "mov.b64 v01, {v0, v1}; \n\t"
        "mov.b64 v23, {v2, v3}; \n\t"
        "mul.f32x2 v01, v01, %3; \n\t"  // mind the shuffled elements order
        "mul.f32x2 v23, v23, %3; \n\t"  // mind the shuffled elements order
        "mov.b64 {v1, v0}, v01; \n\t"
        "mov.b64 {v3, v2}, v23; \n\t"
        "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
        "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
        "mov.b32 %0, {f0, f1, f0, f1};\n\t"
        "}"
        : "=r"(out_4x)
        : "l"(reinterpret_cast<const uint64_t &>(in01)),
          "l"(reinterpret_cast<const uint64_t &>(in23)),
          "l"(reinterpret_cast<const uint64_t &>(scale)));
  } else {
    NVTE_DEVICE_ERROR(
        "FP4 cvt PTX instructions are architecture-specific. "
        "Try recompiling with sm_XXXa instead of sm_XXX.");
  }
  return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}

// Compile-time dispatcher for the FP32 -> FP4 x4 scaled conversion.
//
// \tparam USE_STOCHASTIC_ROUNDING  Selects the stochastic-rounding variant when true,
//                                  the round-to-nearest variant otherwise.
// All arguments are forwarded unchanged to the selected variant.
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
                                                            const float2 scale,
                                                            const uint32_t rbits) {
  if constexpr (!USE_STOCHASTIC_ROUNDING) {
    return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
  } else {
    return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits);
  }
}
#endif  // FP4_TYPE_SUPPORTED
675

676
677
678
// SIMD like "Fused" cast + multiplication (x2)
// Multiplies an FP32 pair by a scale pair (`mul.f32x2`) and converts the products
// to a packed FP8 E4M3 pair with round-to-nearest and finite saturation.
//
// \param out    Receives the 16-bit packed result.
// \param in     Input pair, passed to PTX as one 64-bit operand.
// \param scale  Scale pair, passed to PTX as one 64-bit operand.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
                                           const floatx2 &scale) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#else
  // View the packed operands as plain integers of matching width.
  uint16_t &out_bits = reinterpret_cast<uint16_t &>(out);
  const uint64_t &in_bits = reinterpret_cast<const uint64_t &>(in);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  asm volatile(
      "{\n"
      ".reg.b64 val_pair; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      "mul.f32x2 val_pair, %1, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair; \n\t"
      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(out_bits)
      : "l"(in_bits), "l"(scale_bits));
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

// Multiplies an FP32 pair by a scale pair (`mul.f32x2`) and converts the products
// to a packed FP8 E5M2 pair with round-to-nearest and finite saturation.
//
// \param out    Receives the 16-bit packed result.
// \param in     Input pair, passed to PTX as one 64-bit operand.
// \param scale  Scale pair, passed to PTX as one 64-bit operand.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
                                           const floatx2 &scale) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#else
  // View the packed operands as plain integers of matching width.
  uint16_t &out_bits = reinterpret_cast<uint16_t &>(out);
  const uint64_t &in_bits = reinterpret_cast<const uint64_t &>(in);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  asm volatile(
      "{\n"
      ".reg.b64 val_pair; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      "mul.f32x2 val_pair, %1, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair; \n\t"
      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(out_bits)
      : "l"(in_bits), "l"(scale_bits));
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

// Widens a BF16 pair to FP32, multiplies it by a scale pair (`mul.f32x2`) and
// converts the products to a packed FP8 E4M3 pair with round-to-nearest and
// finite saturation.
//
// \param out    Receives the 16-bit packed result.
// \param in     BF16 input pair, passed to PTX as one 32-bit operand.
// \param scale  FP32 scale pair, passed to PTX as one 64-bit operand.
// Reports a device error when not compiled for SM 10.0 or newer.
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) {
#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
  NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#else
  // View the packed operands as plain integers of matching width.
  uint16_t &out_bits = reinterpret_cast<uint16_t &>(out);
  const uint32_t &in_bits = reinterpret_cast<const uint32_t &>(in);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  asm volatile(
      "{\n"
      ".reg.b64 val_pair_before; \n\t"
      ".reg.b64 val_pair_after; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      ".reg.b16 val1_bf16; \n\t"
      ".reg.b16 val2_bf16; \n\t"
      "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
      "cvt.f32.bf16 val1, val1_bf16; \n\t"
      "cvt.f32.bf16 val2, val2_bf16; \n\t"
      "mov.b64 val_pair_before, {val1,val2}; \n\t"
      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
      "mov.b64 {val2,val1}, val_pair_after; \n\t"
      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(out_bits)
      : "r"(in_bits), "l"(scale_bits));
#endif  // !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000)
}

// Scales two BF16 values by two FP32 factors and converts the products to a
// packed FP8 E5M2 pair.
//
// out   - receives the two converted values, written through a 16-bit view.
// in    - the two BF16 inputs, handed to PTX as one 32-bit operand.
// scale - the two FP32 multipliers, handed to PTX as one 64-bit operand.
//
// The PTX body is only compiled when __CUDA_ARCH__ >= 1000; every other
// target reports NVTE_DEVICE_ERROR instead.
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  asm volatile(
      "{\n"
      ".reg.b64 val_pair_before; \n\t"
      ".reg.b64 val_pair_after; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      ".reg.b16 val1_bf16; \n\t"
      ".reg.b16 val2_bf16; \n\t"
      // Split the packed input and widen each half to FP32.
      "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
      "cvt.f32.bf16 val1, val1_bf16; \n\t"
      "cvt.f32.bf16 val2, val2_bf16; \n\t"
      // Re-pack as an FP32 pair so both lanes are scaled by one instruction.
      "mov.b64 val_pair_before, {val1,val2}; \n\t"
      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
      // Unpack with the names swapped, so the conversion is fed (second, first).
      "mov.b64 {val2,val1}, val_pair_after; \n\t"
      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
  NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales two FP16 values by two FP32 factors and converts the products to a
// packed FP8 E4M3 pair.
//
// out   - receives the two converted values, written through a 16-bit view.
// in    - the two FP16 inputs, handed to PTX as one 32-bit operand.
// scale - the two FP32 multipliers, handed to PTX as one 64-bit operand.
//
// The PTX body is only compiled when __CUDA_ARCH__ >= 1000; every other
// target reports NVTE_DEVICE_ERROR instead.
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  asm volatile(
      "{\n"
      ".reg.b64 val_pair_before; \n\t"
      ".reg.b64 val_pair_after; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      ".reg.b16 val1_fp16; \n\t"
      ".reg.b16 val2_fp16; \n\t"
      // Split the packed input and widen each half to FP32.
      "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
      "cvt.f32.f16 val1, val1_fp16; \n\t"
      "cvt.f32.f16 val2, val2_fp16; \n\t"
      // Re-pack as an FP32 pair so both lanes are scaled by one instruction.
      "mov.b64 val_pair_before, {val1,val2}; \n\t"
      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
      // Unpack with the names swapped, so the conversion is fed (second, first).
      "mov.b64 {val2,val1}, val_pair_after; \n\t"
      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
  NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales two FP16 values by two FP32 factors and converts the products to a
// packed FP8 E5M2 pair.
//
// out   - receives the two converted values, written through a 16-bit view.
// in    - the two FP16 inputs, handed to PTX as one 32-bit operand.
// scale - the two FP32 multipliers, handed to PTX as one 64-bit operand.
//
// The PTX body is only compiled when __CUDA_ARCH__ >= 1000; every other
// target reports NVTE_DEVICE_ERROR instead.
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  asm volatile(
      "{\n"
      ".reg.b64 val_pair_before; \n\t"
      ".reg.b64 val_pair_after; \n\t"
      ".reg.b32 val1; \n\t"
      ".reg.b32 val2; \n\t"
      ".reg.b16 val1_fp16; \n\t"
      ".reg.b16 val2_fp16; \n\t"
      // Split the packed input and widen each half to FP32.
      "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
      "cvt.f32.f16 val1, val1_fp16; \n\t"
      "cvt.f32.f16 val2, val2_fp16; \n\t"
      // Re-pack as an FP32 pair so both lanes are scaled by one instruction.
      "mov.b64 val_pair_before, {val1,val2}; \n\t"
      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
      // Unpack with the names swapped, so the conversion is fed (second, first).
      "mov.b64 {val2,val1}, val_pair_after; \n\t"
      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
      "}"
      : "=h"(reinterpret_cast<uint16_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in)),
        "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
  NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Lane-wise `max.xorsign.abs` of two packed BF16 pairs.
//
// dst    - receives the packed result, written through a 32-bit view.
// p1, p2 - the packed BF16 pairs to compare, read through 32-bit views.
//
// NOTE(review): per the PTX ISA, `.xorsign.abs` takes the magnitude from the
// larger absolute value and the sign from the XOR of the input signs, so a
// lane of dst may be negative - confirm callers take the absolute value.
//
// The PTX body is only compiled when __CUDA_ARCH__ >= 890; every other target
// reports NVTE_DEVICE_ERROR instead.
__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;"
               : "=r"(reinterpret_cast<uint32_t &>(dst))
               : "r"(reinterpret_cast<const uint32_t &>(p1)),
                 "r"(reinterpret_cast<const uint32_t &>(p2)));
#else
  NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
}

// Lane-wise `max.xorsign.abs` of two packed FP16 pairs.
//
// dst    - receives the packed result, written through a 32-bit view.
// p1, p2 - the packed FP16 pairs to compare, read through 32-bit views.
//
// NOTE(review): per the PTX ISA, `.xorsign.abs` takes the magnitude from the
// larger absolute value and the sign from the XOR of the input signs, so a
// lane of dst may be negative - confirm callers take the absolute value.
//
// The PTX body is only compiled when __CUDA_ARCH__ >= 890; every other target
// reports NVTE_DEVICE_ERROR instead.
__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
  asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;"
               : "=r"(reinterpret_cast<uint32_t &>(dst))
               : "r"(reinterpret_cast<const uint32_t &>(p1)),
                 "r"(reinterpret_cast<const uint32_t &>(p2)));
#else
  NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
}

843
844
845
846
847
848
849
850
851
852
853
854
855
856
// Runs `elect.sync` over the lanes named by `mask` and reports whether the
// calling lane was the one chosen.
//
// mask - member mask handed to `elect.sync`; defaults to all 32 bits set.
//
// Returns 1 when the election predicate is set for this lane, 0 otherwise.
// On targets below SM 10.0 the function reports NVTE_DEVICE_ERROR and
// returns 0.
__device__ __forceinline__ int32_t elect_one_sync(uint32_t mask = 0xFFFFFFFFu) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  int32_t pred = 0;
  asm volatile(
      "{\n\t"
      // The scratch predicate uses a plain identifier: in an inline-asm
      // template `%` introduces an operand reference, so a `%`-prefixed
      // register name would have to be spelled with `%%`.
      ".reg .pred p_elected; \n"
      // `_` discards the first result of elect.sync; only the predicate is kept.
      "elect.sync _|p_elected, %1; \n"
      "selp.b32 %0, 1, 0, p_elected; \n"
      "\n\t}"
      : "=r"(pred)
      : "r"(mask));
  return pred;
#else
  NVTE_DEVICE_ERROR("elect_one_sync is only supported on SM 10.0+.");
  return 0;
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Issues `bar.sync` on the barrier numbered `barrier_id` (default 1), passing
// `num_threads` as the instruction's thread-count operand.
__device__ __forceinline__ void numbered_barrier_sync(uint32_t num_threads,
                                                      uint32_t barrier_id = 1u) {
  asm volatile("bar.sync %0, %1;\n"
               : /* no outputs */
               : "r"(barrier_id), "r"(num_threads));
}

// Mixed-precision fused multiply-add: out = a * b + c, where `a` and `b` are
// raw 16-bit FP16 patterns and `c` / `out` are FP32 (`fma.rn.f32.f16`).
// `c` defaults to 0.0f. Targets below SM 10.0 report NVTE_DEVICE_ERROR.
__device__ __forceinline__ void fma_f32_f16(float &out, uint16_t const &a, uint16_t const &b,
                                            float const &c = 0.0f) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  asm volatile("fma.rn.f32.f16 %0, %1, %2, %3;"
               : "=f"(out)
               : "h"(a), "h"(b), "f"(c)
               : "memory");
#else
  NVTE_DEVICE_ERROR("fma_f32_f16 is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Mixed-precision fused multiply-add: out = a * b + c, where `a` and `b` are
// raw 16-bit BF16 patterns and `c` / `out` are FP32 (`fma.rn.f32.bf16`).
// `c` defaults to 0.0f. Targets below SM 10.0 report NVTE_DEVICE_ERROR.
__device__ __forceinline__ void fma_f32_bf16(float &out, uint16_t const &a, uint16_t const &b,
                                             float const &c = 0.0f) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  asm volatile("fma.rn.f32.bf16 %0, %1, %2, %3;"
               : "=f"(out)
               : "h"(a), "h"(b), "f"(c)
               : "memory");
#else
  NVTE_DEVICE_ERROR("fma_f32_bf16 is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Reduces |in| to its maximum across the lanes selected by the hard-coded
// 0xFFFFFFFF member mask and writes the result to `out`.
//
// Two code paths are selected at compile time:
//  * when NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>) holds, a single
//    `redux.sync.max.abs.f32` does the whole job;
//  * otherwise the absolute value is taken first and its raw 32-bit pattern is
//    reduced with `redux.sync.max.u32`, the result being stored through a
//    32-bit integer view of `out`.
// NOTE(review): the integer path presumably relies on non-negative, non-NaN
// floats ordering the same as their bit patterns - confirm NaN handling is
// acceptable for callers.
__device__ __forceinline__ void reduce_sync_max_abs_f32(float &out, float const &in) {
  constexpr bool is_sm_100f = NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>);
  if constexpr (is_sm_100f) {
    asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in));
  } else {
    asm volatile(
        "{\n\t"
        ".reg.b32 val;\n"
        "abs.f32 val, %1;\n"
        "redux.sync.max.u32 %0, val, 0xFFFFFFFF;\n"
        "}\n\t"
        : "=r"(reinterpret_cast<uint32_t &>(out))
        : "f"(in));
  }
}

// Returns `max.xorsign.abs.bf16` of the two BF16 arguments.
//
// NOTE(review): per the PTX ISA, `.xorsign.abs` takes the magnitude from the
// larger absolute value and the sign from the XOR of the input signs, so the
// result may be negative - confirm callers take the absolute value.
//
// On targets below SM 10.0 the function reports NVTE_DEVICE_ERROR and returns
// a value converted from 0.f.
__device__ __forceinline__ bf16 get_amax(bf16 a, bf16 b) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  bf16 r;
  // The operands are passed as 16-bit integer views of the BF16 storage.
  asm volatile("max.xorsign.abs.bf16 %0, %1, %2;"
               : "=h"(*reinterpret_cast<int16_t *>(&r))
               : "h"(*reinterpret_cast<int16_t *>(&a)), "h"(*reinterpret_cast<int16_t *>(&b)));
  return r;
#else
  NVTE_DEVICE_ERROR("get_amax is only supported on SM 10.0+.");
  return 0.f;
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Returns `max.xorsign.abs.f16` of the two FP16 arguments.
//
// NOTE(review): per the PTX ISA, `.xorsign.abs` takes the magnitude from the
// larger absolute value and the sign from the XOR of the input signs, so the
// result may be negative - confirm callers take the absolute value.
//
// On targets below SM 10.0 the function reports NVTE_DEVICE_ERROR and returns
// a value converted from 0.f.
__device__ __forceinline__ fp16 get_amax(fp16 a, fp16 b) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  fp16 r;
  // The operands are passed as 16-bit integer views of the FP16 storage.
  asm volatile("max.xorsign.abs.f16 %0, %1, %2;"
               : "=h"(*reinterpret_cast<int16_t *>(&r))
               : "h"(*reinterpret_cast<int16_t *>(&a)), "h"(*reinterpret_cast<int16_t *>(&b)));
  return r;
#else
  NVTE_DEVICE_ERROR("get_amax is only supported on SM 10.0+.");
  return 0.f;
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four BF16 values by one FP32 scale pair and converts the products to
// a packed FP8 E4M3 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four BF16 inputs, read as two packed 32-bit pairs.
// scale - one FP32 pair; the same pair multiplies both input pairs.
//
// With _LOOSE_PRECISION defined all four values are converted by one `cvt.rs`
// whose extra 32-bit operand is the constant 0x80008000; otherwise two
// `cvt.rn` conversions are used. Targets below SM 10.0 report
// NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in,
                                           const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const ptx::bf16x2 *pairs = reinterpret_cast<const ptx::bf16x2 *>(&in);
  const uint32_t &first_pair = reinterpret_cast<const uint32_t &>(pairs[0]);
  const uint32_t &second_pair = reinterpret_cast<const uint32_t &>(pairs[1]);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  uint32_t &packed_out = reinterpret_cast<uint32_t &>(out);
  asm volatile(
      "{\n\t"
      // Scratch registers for the whole sequence, declared up front.
      ".reg .b32 val1, val2, val3, val4;\n\t"
      ".reg .b64 val_1_2, val_3_4, zeros;\n\t"
      // Widen each BF16 to FP32 by placing its 16 bits in the upper half of a
      // zeroed 32-bit word.
      "prmt.b32 val2, 0x0, %1, 0x7632;\n\t"
      "prmt.b32 val1, 0x0, %1, 0x5410;\n\t"
      "prmt.b32 val4, 0x0, %2, 0x7632;\n\t"
      "prmt.b32 val3, 0x0, %2, 0x5410;\n\t"
      "mov.b64 val_1_2, {val1, val2};\n\t"
      "mov.b64 val_3_4, {val3, val4};\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0 on each packed pair.
      "fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
      "fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;\n\t"
      "mov.b64 {val1, val2}, val_1_2;\n\t"
      "mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
      "cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
      ".reg .b16 r1, r2;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(packed_out)
      : "r"(first_pair), "r"(second_pair), "l"(scale_bits), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four BF16 values by four FP32 factors and converts the products to a
// packed FP8 E4M3 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four BF16 inputs, read as two packed 32-bit pairs.
// scale - four FP32 multipliers, read as two packed 64-bit pairs; the first
//         pair scales the first two inputs, the second pair the last two.
//
// Asm operands: %0 out, %1/%2 input pairs, %3/%4 scale pairs, %5 the 32-bit
// constant 0x80008000 consumed by the _LOOSE_PRECISION `cvt.rs` conversion.
// Targets below SM 10.0 report NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in, const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  ptx::bf16x2 const *in2 = reinterpret_cast<ptx::bf16x2 const *>(&in);
  ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
  asm volatile(
      "{\n\t"
      ".reg.b32 val1;\n\t"
      ".reg.b32 val2;\n\t"
      ".reg.b32 val3;\n\t"
      ".reg.b32 val4;\n\t"
      // Widen each BF16 to FP32 by placing its 16 bits in the upper half of a
      // zeroed 32-bit word.
      "prmt.b32 val2, 0x0, %1, 0x7632;\n\t"
      "prmt.b32 val1, 0x0, %1, 0x5410;\n\t"
      "prmt.b32 val4, 0x0, %2, 0x7632;\n\t"
      "prmt.b32 val3, 0x0, %2, 0x5410;\n\t"
      ".reg.b64 val_1_2;\n\t"
      ".reg.b64 val_3_4;\n\t"
      "mov.b64 val_1_2, {val1, val2};\n\t"
      "mov.b64 val_3_4, {val3, val4};\n\t"
      ".reg.b64 zeros;\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0, each pair with its own scale pair.
      "fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
      "fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;\n\t"
      "mov.b64 {val1, val2}, val_1_2;\n\t"
      "mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
      // The last cvt.rs operand must be the 32-bit constant (%5); %4 is the
      // 64-bit second scale pair in this overload.
      "cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %5;\n\t"
#else
      ".reg.b16 r1;\n\t"
      ".reg.b16 r2;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(reinterpret_cast<uint32_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in2[0])),
        "r"(reinterpret_cast<const uint32_t &>(in2[1])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[0])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four BF16 values by one FP32 scale pair and converts the products to
// a packed FP8 E5M2 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four BF16 inputs, read as two packed 32-bit pairs.
// scale - one FP32 pair; the same pair multiplies both input pairs.
//
// With _LOOSE_PRECISION defined all four values are converted by one `cvt.rs`
// whose extra 32-bit operand is the constant 0x80008000; otherwise two
// `cvt.rn` conversions are used. Targets below SM 10.0 report
// NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const bf16x4 &in,
                                           const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const ptx::bf16x2 *pairs = reinterpret_cast<const ptx::bf16x2 *>(&in);
  const uint32_t &first_pair = reinterpret_cast<const uint32_t &>(pairs[0]);
  const uint32_t &second_pair = reinterpret_cast<const uint32_t &>(pairs[1]);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  uint32_t &packed_out = reinterpret_cast<uint32_t &>(out);
  asm volatile(
      "{\n\t"
      // Scratch registers for the whole sequence, declared up front.
      ".reg .b32 val1, val2, val3, val4;\n\t"
      ".reg .b64 val_1_2, val_3_4, zeros;\n\t"
      // Widen each BF16 to FP32 by placing its 16 bits in the upper half of a
      // zeroed 32-bit word.
      "prmt.b32 val2, 0x0, %1, 0x7632;\n\t"
      "prmt.b32 val1, 0x0, %1, 0x5410;\n\t"
      "prmt.b32 val4, 0x0, %2, 0x7632;\n\t"
      "prmt.b32 val3, 0x0, %2, 0x5410;\n\t"
      "mov.b64 val_1_2, {val1, val2};\n\t"
      "mov.b64 val_3_4, {val3, val4};\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0 on each packed pair.
      "fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
      "fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;\n\t"
      "mov.b64 {val1, val2}, val_1_2;\n\t"
      "mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
      "cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
      ".reg .b16 r1, r2;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(packed_out)
      : "r"(first_pair), "r"(second_pair), "l"(scale_bits), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four BF16 values by four FP32 factors and converts the products to a
// packed FP8 E5M2 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four BF16 inputs, read as two packed 32-bit pairs.
// scale - four FP32 multipliers, read as two packed 64-bit pairs; the first
//         pair scales the first two inputs, the second pair the last two.
//
// Asm operands: %0 out, %1/%2 input pairs, %3/%4 scale pairs, %5 the 32-bit
// constant 0x80008000 consumed by the _LOOSE_PRECISION `cvt.rs` conversion.
// Targets below SM 10.0 report NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const bf16x4 &in, const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  ptx::bf16x2 const *in2 = reinterpret_cast<ptx::bf16x2 const *>(&in);
  ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
  asm volatile(
      "{\n\t"
      ".reg.b32 val1;\n\t"
      ".reg.b32 val2;\n\t"
      ".reg.b32 val3;\n\t"
      ".reg.b32 val4;\n\t"
      // Widen each BF16 to FP32 by placing its 16 bits in the upper half of a
      // zeroed 32-bit word.
      "prmt.b32 val2, 0x0, %1, 0x7632;\n\t"
      "prmt.b32 val1, 0x0, %1, 0x5410;\n\t"
      "prmt.b32 val4, 0x0, %2, 0x7632;\n\t"
      "prmt.b32 val3, 0x0, %2, 0x5410;\n\t"
      ".reg.b64 val_1_2;\n\t"
      ".reg.b64 val_3_4;\n\t"
      "mov.b64 val_1_2, {val1, val2};\n\t"
      "mov.b64 val_3_4, {val3, val4};\n\t"
      ".reg.b64 zeros;\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0, each pair with its own scale pair.
      "fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
      "fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;\n\t"
      "mov.b64 {val1, val2}, val_1_2;\n\t"
      "mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
      // The last cvt.rs operand must be the 32-bit constant (%5); %4 is the
      // 64-bit second scale pair in this overload.
      "cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %5;\n\t"
#else
      ".reg.b16 r1;\n\t"
      ".reg.b16 r2;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(reinterpret_cast<uint32_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in2[0])),
        "r"(reinterpret_cast<const uint32_t &>(in2[1])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[0])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four FP16 values by one FP32 scale pair and converts the products to
// a packed FP8 E4M3 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four FP16 inputs, read as two packed 32-bit pairs.
// scale - one FP32 pair; the same pair multiplies both input pairs.
//
// With _LOOSE_PRECISION defined all four values are converted by one `cvt.rs`
// whose extra 32-bit operand is the constant 0x80008000; otherwise two
// `cvt.rn` conversions are used. Targets below SM 10.0 report
// NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const fp16x4 &in,
                                           const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const ptx::fp16x2 *pairs = reinterpret_cast<const ptx::fp16x2 *>(&in);
  const uint32_t &first_pair = reinterpret_cast<const uint32_t &>(pairs[0]);
  const uint32_t &second_pair = reinterpret_cast<const uint32_t &>(pairs[1]);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  uint32_t &packed_out = reinterpret_cast<uint32_t &>(out);
  asm volatile(
      "{\n\t"
      // Scratch registers for the whole sequence, declared up front.
      ".reg .b16 val1_f16, val2_f16, val3_f16, val4_f16;\n\t"
      ".reg .b32 val1, val2, val3, val4;\n\t"
      ".reg .b64 val_1_2, val_3_4, zeros;\n\t"
      // Split both packed inputs and widen every half to FP32.
      "mov.b32 {val1_f16, val2_f16}, %1;\n\t"
      "mov.b32 {val3_f16, val4_f16}, %2;\n\t"
      "cvt.f32.f16 val1, val1_f16;\n\t"
      "cvt.f32.f16 val2, val2_f16;\n\t"
      "cvt.f32.f16 val3, val3_f16;\n\t"
      "cvt.f32.f16 val4, val4_f16;\n\t"
      "mov.b64 val_1_2, {val1, val2};\n\t"
      "mov.b64 val_3_4, {val3, val4};\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0 on each packed pair.
      "fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
      "fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;\n\t"
      "mov.b64 {val1, val2}, val_1_2;\n\t"
      "mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
      "cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
      ".reg .b16 r1, r2;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(packed_out)
      : "r"(first_pair), "r"(second_pair), "l"(scale_bits), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four FP16 values by four FP32 factors and converts the products to a
// packed FP8 E4M3 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four FP16 inputs, read as two packed 32-bit pairs.
// scale - four FP32 multipliers, read as two packed 64-bit pairs; the first
//         pair scales the first two inputs, the second pair the last two.
//
// Asm operands: %0 out, %1/%2 input pairs, %3/%4 scale pairs, %5 the 32-bit
// constant 0x80008000 consumed by the _LOOSE_PRECISION `cvt.rs` conversion.
// Targets below SM 10.0 report NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const fp16x4 &in, const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  ptx::fp16x2 const *in2 = reinterpret_cast<ptx::fp16x2 const *>(&in);
  ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
  asm volatile(
      "{\n\t"
      ".reg.b16 val1_f16;\n\t"
      ".reg.b16 val2_f16;\n\t"
      ".reg.b16 val3_f16;\n\t"
      ".reg.b16 val4_f16;\n\t"
      // Split both packed inputs into their 16-bit halves.
      "mov.b32 {val1_f16, val2_f16}, %1;\n\t"
      "mov.b32 {val3_f16, val4_f16}, %2;\n\t"
      ".reg.b32 val1;\n\t"
      ".reg.b32 val2;\n\t"
      ".reg.b32 val3;\n\t"
      ".reg.b32 val4;\n\t"
      "cvt.f32.f16 val1, val1_f16;\n\t"
      "cvt.f32.f16 val2, val2_f16;\n\t"
      "cvt.f32.f16 val3, val3_f16;\n\t"
      "cvt.f32.f16 val4, val4_f16;\n\t"
      ".reg.b64 val_1_2;\n\t"
      ".reg.b64 val_3_4;\n\t"
      "mov.b64 val_1_2, {val1, val2};\n\t"
      "mov.b64 val_3_4, {val3, val4};\n\t"
      ".reg.b64 zeros;\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0, each pair with its own scale pair.
      "fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
      "fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;\n\t"
      "mov.b64 {val1, val2}, val_1_2;\n\t"
      "mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
      // The last cvt.rs operand must be the 32-bit constant (%5); %4 is the
      // 64-bit second scale pair in this overload.
      "cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %5;\n\t"
#else
      ".reg.b16 r1;\n\t"
      ".reg.b16 r2;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(reinterpret_cast<uint32_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in2[0])),
        "r"(reinterpret_cast<const uint32_t &>(in2[1])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[0])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four FP16 values by one FP32 scale pair and converts the products to
// a packed FP8 E5M2 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four FP16 inputs, read as two packed 32-bit pairs.
// scale - one FP32 pair; the same pair multiplies both input pairs.
//
// With _LOOSE_PRECISION defined all four values are converted by one `cvt.rs`
// whose extra 32-bit operand is the constant 0x80008000; otherwise two
// `cvt.rn` conversions are used. Targets below SM 10.0 report
// NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const fp16x4 &in,
                                           const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const ptx::fp16x2 *pairs = reinterpret_cast<const ptx::fp16x2 *>(&in);
  const uint32_t &first_pair = reinterpret_cast<const uint32_t &>(pairs[0]);
  const uint32_t &second_pair = reinterpret_cast<const uint32_t &>(pairs[1]);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  uint32_t &packed_out = reinterpret_cast<uint32_t &>(out);
  asm volatile(
      "{\n\t"
      // Scratch registers for the whole sequence, declared up front.
      ".reg .b16 val1_f16, val2_f16, val3_f16, val4_f16;\n\t"
      ".reg .b32 val1, val2, val3, val4;\n\t"
      ".reg .b64 val_1_2, val_3_4, zeros;\n\t"
      // Split both packed inputs and widen every half to FP32.
      "mov.b32 {val1_f16, val2_f16}, %1;\n\t"
      "mov.b32 {val3_f16, val4_f16}, %2;\n\t"
      "cvt.f32.f16 val1, val1_f16;\n\t"
      "cvt.f32.f16 val2, val2_f16;\n\t"
      "cvt.f32.f16 val3, val3_f16;\n\t"
      "cvt.f32.f16 val4, val4_f16;\n\t"
      "mov.b64 val_1_2, {val1, val2};\n\t"
      "mov.b64 val_3_4, {val3, val4};\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0 on each packed pair.
      "fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
      "fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;\n\t"
      "mov.b64 {val1, val2}, val_1_2;\n\t"
      "mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
      "cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
      ".reg .b16 r1, r2;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(packed_out)
      : "r"(first_pair), "r"(second_pair), "l"(scale_bits), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four FP16 values by four FP32 factors and converts the products to a
// packed FP8 E5M2 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four FP16 inputs, read as two packed 32-bit pairs.
// scale - four FP32 multipliers, read as two packed 64-bit pairs; the first
//         pair scales the first two inputs, the second pair the last two.
//
// Asm operands: %0 out, %1/%2 input pairs, %3/%4 scale pairs, %5 the 32-bit
// constant 0x80008000 consumed by the _LOOSE_PRECISION `cvt.rs` conversion.
// Targets below SM 10.0 report NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const fp16x4 &in, const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  ptx::fp16x2 const *in2 = reinterpret_cast<ptx::fp16x2 const *>(&in);
  ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
  asm volatile(
      "{\n\t"
      ".reg.b16 val1_f16;\n\t"
      ".reg.b16 val2_f16;\n\t"
      ".reg.b16 val3_f16;\n\t"
      ".reg.b16 val4_f16;\n\t"
      // Split both packed inputs into their 16-bit halves.
      "mov.b32 {val1_f16, val2_f16}, %1;\n\t"
      "mov.b32 {val3_f16, val4_f16}, %2;\n\t"
      ".reg.b32 val1;\n\t"
      ".reg.b32 val2;\n\t"
      ".reg.b32 val3;\n\t"
      ".reg.b32 val4;\n\t"
      "cvt.f32.f16 val1, val1_f16;\n\t"
      "cvt.f32.f16 val2, val2_f16;\n\t"
      "cvt.f32.f16 val3, val3_f16;\n\t"
      "cvt.f32.f16 val4, val4_f16;\n\t"
      ".reg.b64 val_1_2;\n\t"
      ".reg.b64 val_3_4;\n\t"
      "mov.b64 val_1_2, {val1, val2};\n\t"
      "mov.b64 val_3_4, {val3, val4};\n\t"
      ".reg.b64 zeros;\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0, each pair with its own scale pair.
      "fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
      "fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;\n\t"
      "mov.b64 {val1, val2}, val_1_2;\n\t"
      "mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
      // The last cvt.rs operand must be the 32-bit constant (%5); %4 is the
      // 64-bit second scale pair in this overload.
      "cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %5;\n\t"
#else
      ".reg.b16 r1;\n\t"
      ".reg.b16 r2;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(reinterpret_cast<uint32_t &>(out))
      : "r"(reinterpret_cast<const uint32_t &>(in2[0])),
        "r"(reinterpret_cast<const uint32_t &>(in2[1])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[0])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four FP32 values by one FP32 scale pair and converts the products to
// a packed FP8 E5M2 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four FP32 inputs, read as two packed 64-bit pairs.
// scale - one FP32 pair; the same pair multiplies both input pairs.
//
// With _LOOSE_PRECISION defined all four values are converted by one `cvt.rs`
// whose extra 32-bit operand is the constant 0x80008000; otherwise two
// `cvt.rn` conversions are used. Targets below SM 10.0 report
// NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, floatx4 const &in,
                                           const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const ptx::floatx2 *pairs = reinterpret_cast<const ptx::floatx2 *>(&in);
  const uint64_t &first_pair = reinterpret_cast<const uint64_t &>(pairs[0]);
  const uint64_t &second_pair = reinterpret_cast<const uint64_t &>(pairs[1]);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  uint32_t &packed_out = reinterpret_cast<uint32_t &>(out);
  asm volatile(
      "{\n\t"
      // Scratch registers for the whole sequence, declared up front.
      ".reg .b64 zeros, re1, re2;\n\t"
      ".reg .b32 val1, val2, val3, val4;\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0 on each packed pair.
      "fma.rn.f32x2 re1, %1, %3, zeros;\n\t"
      "fma.rn.f32x2 re2, %2, %3, zeros;\n\t"
      "mov.b64 {val1, val2}, re1;\n\t"
      "mov.b64 {val3, val4}, re2;\n\t"
#if (defined _LOOSE_PRECISION)
      "cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
      ".reg .b16 r1, r2;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(packed_out)
      : "l"(first_pair), "l"(second_pair), "l"(scale_bits), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four FP32 values by four FP32 factors and converts the products to a
// packed FP8 E5M2 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four FP32 inputs, read as two packed 64-bit pairs.
// scale - four FP32 multipliers, read as two packed 64-bit pairs; the first
//         pair scales the first two inputs, the second pair the last two.
//
// Asm operands: %0 out, %1/%2 input pairs, %3/%4 scale pairs, %5 the 32-bit
// constant 0x80008000 consumed by the _LOOSE_PRECISION `cvt.rs` conversion.
// Targets below SM 10.0 report NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, floatx4 const &in,
                                           const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  ptx::floatx2 const *in2 = reinterpret_cast<ptx::floatx2 const *>(&in);
  ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
  asm volatile(
      "{\n\t"
      ".reg.b64 zeros;\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      ".reg.b64 re1;\n\t"
      ".reg.b64 re2;\n\t"
      // value * scale + 0, each pair with its own scale pair.
      "fma.rn.f32x2 re1, %1, %3, zeros;\n\t"
      "fma.rn.f32x2 re2, %2, %4, zeros;\n\t"
      ".reg.b32 val1;\n\t"
      ".reg.b32 val2;\n\t"
      ".reg.b32 val3;\n\t"
      ".reg.b32 val4;\n\t"
      "mov.b64 {val1, val2}, re1;\n\t"
      "mov.b64 {val3, val4}, re2;\n\t"
#if (defined _LOOSE_PRECISION)
      // The last cvt.rs operand must be the 32-bit constant (%5); %4 is the
      // 64-bit second scale pair in this overload.
      "cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %5;\n\t"
#else
      ".reg.b16 r1;\n\t"
      ".reg.b16 r2;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(reinterpret_cast<uint32_t &>(out))
      : "l"(reinterpret_cast<uint64_t const &>(in2[0])),
        "l"(reinterpret_cast<uint64_t const &>(in2[1])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[0])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four FP32 values by one FP32 scale pair and converts the products to
// a packed FP8 E4M3 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four FP32 inputs, read as two packed 64-bit pairs.
// scale - one FP32 pair; the same pair multiplies both input pairs.
//
// With _LOOSE_PRECISION defined all four values are converted by one `cvt.rs`
// whose extra 32-bit operand is the constant 0x80008000; otherwise two
// `cvt.rn` conversions are used. Targets below SM 10.0 report
// NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, floatx4 const &in,
                                           const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const ptx::floatx2 *pairs = reinterpret_cast<const ptx::floatx2 *>(&in);
  const uint64_t &first_pair = reinterpret_cast<const uint64_t &>(pairs[0]);
  const uint64_t &second_pair = reinterpret_cast<const uint64_t &>(pairs[1]);
  const uint64_t &scale_bits = reinterpret_cast<const uint64_t &>(scale);
  uint32_t &packed_out = reinterpret_cast<uint32_t &>(out);
  asm volatile(
      "{\n\t"
      // Scratch registers for the whole sequence, declared up front.
      ".reg .b64 zeros, re1, re2;\n\t"
      ".reg .b32 val1, val2, val3, val4;\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      // value * scale + 0 on each packed pair.
      "fma.rn.f32x2 re1, %1, %3, zeros;\n\t"
      "fma.rn.f32x2 re2, %2, %3, zeros;\n\t"
      "mov.b64 {val1, val2}, re1;\n\t"
      "mov.b64 {val3, val4}, re2;\n\t"
#if (defined _LOOSE_PRECISION)
      "cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
      ".reg .b16 r1, r2;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(packed_out)
      : "l"(first_pair), "l"(second_pair), "l"(scale_bits), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Scales four FP32 values by four FP32 factors and converts the products to a
// packed FP8 E4M3 quad.
//
// out   - receives the four converted values, written through a 32-bit view.
// in    - four FP32 inputs, read as two packed 64-bit pairs.
// scale - four FP32 multipliers, read as two packed 64-bit pairs; the first
//         pair scales the first two inputs, the second pair the last two.
//
// Asm operands: %0 out, %1/%2 input pairs, %3/%4 scale pairs, %5 the 32-bit
// constant 0x80008000 consumed by the _LOOSE_PRECISION `cvt.rs` conversion.
// Targets below SM 10.0 report NVTE_DEVICE_ERROR.
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, floatx4 const &in,
                                           const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  ptx::floatx2 const *in2 = reinterpret_cast<ptx::floatx2 const *>(&in);
  ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
  asm volatile(
      "{\n\t"
      ".reg.b64 zeros;\n\t"
      "mov.b64 zeros, {0x0, 0x0};\n\t"
      ".reg.b64 re1;\n\t"
      ".reg.b64 re2;\n\t"
      // value * scale + 0, each pair with its own scale pair.
      "fma.rn.f32x2 re1, %1, %3, zeros;\n\t"
      "fma.rn.f32x2 re2, %2, %4, zeros;\n\t"
      ".reg.b32 val1;\n\t"
      ".reg.b32 val2;\n\t"
      ".reg.b32 val3;\n\t"
      ".reg.b32 val4;\n\t"
      "mov.b64 {val1, val2}, re1;\n\t"
      "mov.b64 {val3, val4}, re2;\n\t"
#if (defined _LOOSE_PRECISION)
      // The last cvt.rs operand must be the 32-bit constant (%5); %4 is the
      // 64-bit second scale pair in this overload.
      "cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %5;\n\t"
#else
      ".reg.b16 r1;\n\t"
      ".reg.b16 r2;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
      "cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
      "mov.b32 %0, {r1, r2};\n\t"
#endif
      "}\n\t"
      : "=r"(reinterpret_cast<uint32_t &>(out))
      : "l"(reinterpret_cast<uint64_t const &>(in2[0])),
        "l"(reinterpret_cast<uint64_t const &>(in2[1])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[0])),
        "l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
  NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

// Writes max(|p1|, |p2|, |p3|) to `dst`.
//
// dst        - receives the result; its previous value is not used.
// p1, p2, p3 - the three FP32 values whose largest magnitude is wanted.
//
// With CUDA_VERSION >= 12090 a single three-input `max.abs.f32` is emitted.
// Older toolkits chain two two-input `max.xorsign.abs.f32` instructions; that
// form carries the XOR of the input signs in its result, so a final `abs.f32`
// clears the sign and both paths agree. The chain works in a scratch register
// so `dst` is written only after every input has been read.
// NOTE(review): the first path is selected by toolkit version alone, with no
// __CUDA_ARCH__ check - confirm every architecture this is instantiated for
// accepts the three-input form.
__device__ __forceinline__ void abs_max_2x(float &dst, const float &p1, const float &p2,
                                           const float &p3) {
#if (defined CUDA_VERSION) && (CUDA_VERSION >= 12090)
  asm volatile("max.abs.f32 %0, %1, %2, %3;" : "=f"(dst) : "f"(p1), "f"(p2), "f"(p3));
#else
  asm volatile(
      "{\n\t"
      ".reg .f32 amax;\n\t"
      "max.xorsign.abs.f32 amax, %2, %3;\n\t"
      "max.xorsign.abs.f32 amax, amax, %1;\n\t"
      "abs.f32 %0, amax;\n\t"
      "}\n\t"
      : "=f"(dst)
      : "f"(p1), "f"(p2), "f"(p3));
#endif
}

// Widens a packed pair of fp16 values to a pair of floats.
//
// The input is read as one 32-bit register, split into two 16-bit halves, and each half is
// converted with `cvt.f32.f16`. The first unpacked half goes to `out.x`, the second to `out.y`.
// No __CUDA_ARCH__ guard is applied to this overload.
__device__ __forceinline__ ptx::floatx2 up_cast(const ptx::fp16x2 &in) {
  ptx::floatx2 out;
  // Operands: %0 = out.x, %1 = out.y, %2 = the two fp16 values as a 32-bit word.
  asm volatile(
      "{\n\t"
      ".reg.b16 f16_1;\n\t"
      ".reg.b16 f16_2;\n\t"
      "mov.b32 {f16_1, f16_2}, %2;\n\t"
      "cvt.f32.f16 %0, f16_1;\n\t"
      "cvt.f32.f16 %1, f16_2;\n\t"
      "}\n\t"
      : "=f"(out.x), "=f"(out.y)
      : "r"(reinterpret_cast<int32_t const &>(in)));
  return out;
}
// Widens four packed fp16 values to four floats.
//
// The input is read as one 64-bit register, split into four 16-bit pieces, and each piece is
// converted with `cvt.f32.f16`. The unpacked pieces map in order to out.x1, out.x2, out.x3
// and out.x4. No __CUDA_ARCH__ guard is applied to this overload.
__device__ __forceinline__ floatx4 up_cast(const fp16x4 &in) {
  floatx4 out;
  // Operands: %0..%3 = out.x1..out.x4, %4 = the four fp16 values as a 64-bit word.
  asm volatile(
      "{\n\t"
      ".reg.b16 f16_1;\n\t"
      ".reg.b16 f16_2;\n\t"
      ".reg.b16 f16_3;\n\t"
      ".reg.b16 f16_4;\n\t"
      "mov.b64 {f16_1, f16_2, f16_3, f16_4}, %4;\n\t"
      "cvt.f32.f16 %0, f16_1;\n\t"
      "cvt.f32.f16 %1, f16_2;\n\t"
      "cvt.f32.f16 %2, f16_3;\n\t"
      "cvt.f32.f16 %3, f16_4;\n\t"
      "}\n\t"
      : "=f"(out.x1), "=f"(out.x2), "=f"(out.x3), "=f"(out.x4)
      : "l"(reinterpret_cast<int64_t const &>(in)));
  return out;
}

// Widens a packed pair of bf16 values to a pair of floats using byte permutes only.
//
// Each `prmt.b32` builds a 32-bit word whose upper two bytes come from one 16-bit half of
// the input and whose lower two bytes come from the constant 0x0. Selector 0x5410 picks the
// input's lower half (written to out.x) and selector 0x7632 its upper half (written to
// out.y). The results are stored through integer reinterpretations of the float outputs.
// No __CUDA_ARCH__ guard is applied to this overload.
__device__ __forceinline__ ptx::floatx2 up_cast(const ptx::bf16x2 &in) {
  ptx::floatx2 out;
  // Operands: %0 = bits of out.x, %1 = bits of out.y, %2 = the two bf16 values as 32 bits.
  asm volatile(
      "{\n\t"
      "prmt.b32 %1, 0x0, %2, 0x7632;\n\t"
      "prmt.b32 %0, 0x0, %2, 0x5410;\n\t"
      "}\n\t"
      : "=r"(reinterpret_cast<int32_t &>(out.x)), "=r"(reinterpret_cast<int32_t &>(out.y))
      : "r"(reinterpret_cast<int32_t const &>(in)));
  return out;
}

// Widens four packed bf16 values to four floats using byte permutes only.
//
// The input is accessed as two 32-bit words. For each word, `prmt.b32` with selector 0x5410
// places the word's lower 16-bit half into the upper half of a zero-filled result, and
// selector 0x7632 does the same for the upper half. Word 0 produces out.x1 / out.x2 and
// word 1 produces out.x3 / out.x4. The results are stored through integer reinterpretations
// of the float outputs. No __CUDA_ARCH__ guard is applied to this overload.
__device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) {
  floatx4 out;
  // View the four bf16 values as two 32-bit words.
  int32_t const *in2 = reinterpret_cast<int32_t const *>(&in);
  // Operands: %0..%3 = bits of out.x1..out.x4, %4 = in2[0], %5 = in2[1].
  asm volatile(
      "{\n\t"
      "prmt.b32 %1, 0x0, %4, 0x7632;\n\t"
      "prmt.b32 %0, 0x0, %4, 0x5410;\n\t"
      "prmt.b32 %3, 0x0, %5, 0x7632;\n\t"
      "prmt.b32 %2, 0x0, %5, 0x5410;\n\t"
      "}\n\t"
      : "=r"(reinterpret_cast<int32_t &>(out.x1)), "=r"(reinterpret_cast<int32_t &>(out.x2)),
        "=r"(reinterpret_cast<int32_t &>(out.x3)), "=r"(reinterpret_cast<int32_t &>(out.x4))
      : "r"(in2[0]), "r"(in2[1]));
  return out;
}
#endif
}  // namespace ptx

namespace {

/*!
 * \brief Initializes an array of mbarriers and synchronizes the thread block.
 *
 * The master thread initializes each of the \p num_barriers barriers with an expected
 * arrival count of \p THREADS_PER_BLOCK and then issues
 * ptx::fence_proxy_async_shared_cta(). Afterwards every thread reaches __syncthreads(),
 * so this function must be called by all threads of the block.
 *
 * \tparam num_barriers       Number of barriers in \p mbar.
 * \tparam THREADS_PER_BLOCK  Arrival count each barrier is initialized with.
 * \param[in,out] mbar              Array of at least \p num_barriers barrier objects
 *                                  (presumably in shared memory - confirm against callers).
 * \param[in]     is_master_thread  True for the single thread that performs the initialization.
 *
 * Only available when compiling device code for SM 10.0+; otherwise NVTE_DEVICE_ERROR is raised.
 */
template <int num_barriers, int THREADS_PER_BLOCK>
__forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate.
#pragma unroll
    for (int iter = 0; iter < num_barriers; ++iter) {
      ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK);
    }
    ptx::fence_proxy_async_shared_cta();
  }
  // Syncthreads so initialized barrier is visible to all threads.
  __syncthreads();
#else
  NVTE_DEVICE_ERROR("initialize_barriers is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

/*!
 * \brief Invalidates an array of mbarriers.
 *
 * The master thread calls ptx::mbarrier_invalid() on each of the \p num_barriers barriers;
 * all other threads do nothing. No block-level synchronization is performed here.
 *
 * \tparam num_barriers  Number of barriers in \p mbar.
 * \param[in,out] mbar              Array of at least \p num_barriers barrier objects.
 * \param[in]     is_master_thread  True for the single thread that performs the invalidation.
 *
 * Only available when compiling device code for SM 10.0+; otherwise NVTE_DEVICE_ERROR is raised.
 */
template <int num_barriers>
__forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  // Destroy barrier. This invalidates the memory region of the barrier. If
  // further computations were to take place in the kernel, this allows the
  // memory location of the shared memory barrier to be reused.
  if (is_master_thread) {
#pragma unroll
    for (int iter = 0; iter < num_barriers; ++iter) {
      ptx::mbarrier_invalid(&mbar[iter]);
    }
  }
#else
  NVTE_DEVICE_ERROR("destroy_barriers is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

/*!
 * \brief Starts a 1D bulk global-to-shared copy and arrives on its barrier.
 *
 * The master thread issues ptx::cp_async_bulk_tensor_1d_global_to_shared() for
 * \p num_bytes bytes and then arrives on \p barrier while announcing \p num_bytes expected
 * transaction bytes. Every other thread only arrives on \p barrier. The function does not
 * wait for the copy to complete.
 *
 * \param[out]    dst               Destination of the copy.
 * \param[in]     src               Source of the copy.
 * \param[in]     num_bytes         Size of the copy in bytes; also the expected-tx count.
 * \param[in,out] barrier           Barrier tracking the copy.
 * \param[in]     is_master_thread  True for the single thread that issues the copy.
 *
 * Only available when compiling device code for SM 10.0+; otherwise NVTE_DEVICE_ERROR is raised.
 */
__forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src,
                                                  const size_t num_bytes, uint64_t *barrier,
                                                  const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initiate bulk tensor copy
    ptx::cp_async_bulk_tensor_1d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
                                                  reinterpret_cast<const uint64_t *>(src),
                                                  num_bytes, barrier);

    // Arrive on the barrier and tell how many bytes are expected to come in.
    ptx::mbarrier_arrive_expect_tx(barrier, num_bytes);
  } else {
    // Other threads just arrive
    ptx::mbarrier_arrive(barrier);
  }
#else
  NVTE_DEVICE_ERROR("copy_1d_to_shared is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

/*!
 * \brief Starts a 2D bulk tensor global-to-shared copy and arrives on its barrier.
 *
 * The master thread issues ptx::cp_async_bulk_tensor_2d_global_to_shared() with the given
 * chunk coordinates and then arrives on \p barrier while announcing \p num_bytes expected
 * transaction bytes. Every other thread only arrives on \p barrier. The function does not
 * wait for the copy to complete.
 *
 * \param[out]    dst               Destination of the copy.
 * \param[in]     src               Source handle passed through to the PTX wrapper
 *                                  (presumably a tensor map - confirm against callers).
 * \param[in]     chunk_X           First chunk coordinate forwarded to the copy.
 * \param[in]     chunk_Y           Second chunk coordinate forwarded to the copy.
 * \param[in]     num_bytes         Number of bytes the barrier should expect.
 * \param[in,out] barrier           Barrier tracking the copy.
 * \param[in]     is_master_thread  True for the single thread that issues the copy.
 *
 * Only available when compiling device code for SM 10.0+; otherwise NVTE_DEVICE_ERROR is raised.
 */
__forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X,
                                                  const size_t chunk_Y, const size_t num_bytes,
                                                  uint64_t *barrier, const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initiate bulk tensor copy
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
                                                  reinterpret_cast<const uint64_t *>(src), chunk_X,
                                                  chunk_Y, barrier);

    // Arrive on the barrier and tell how many bytes are expected to come in.
    ptx::mbarrier_arrive_expect_tx(barrier, num_bytes);
  } else {
    // Other threads just arrive
    ptx::mbarrier_arrive(barrier);
  }
#else
  NVTE_DEVICE_ERROR("copy_2d_to_shared is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

/*!
 * \brief Starts two 2D bulk tensor global-to-shared copies tracked by one barrier.
 *
 * The master thread issues two ptx::cp_async_bulk_tensor_2d_global_to_shared() calls
 * (dst/src and dst2/src2) and then arrives on \p barrier while announcing
 * `2 * num_bytes` expected transaction bytes, i.e. \p num_bytes is the size of ONE copy and
 * both copies are taken to be equally sized. Every other thread only arrives on \p barrier.
 * The function does not wait for the copies to complete.
 *
 * \param[out]    dst, dst2           Destinations of the first / second copy.
 * \param[in]     src, src2           Source handles of the first / second copy.
 * \param[in]     chunk_X1, chunk_Y1  Chunk coordinates of the first copy.
 * \param[in]     chunk_X2, chunk_Y2  Chunk coordinates of the second copy.
 * \param[in]     num_bytes           Byte count of a single copy.
 * \param[in,out] barrier             Barrier tracking both copies.
 * \param[in]     is_master_thread    True for the single thread that issues the copies.
 *
 * Only available when compiling device code for SM 10.0+; otherwise NVTE_DEVICE_ERROR is raised.
 */
__forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src,
                                                    const size_t chunk_X1, const size_t chunk_Y1,
                                                    void *dst2, const void *src2,
                                                    const size_t chunk_X2, const size_t chunk_Y2,
                                                    const size_t num_bytes, uint64_t *barrier,
                                                    const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initiate bulk tensor copy
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
                                                  reinterpret_cast<const uint64_t *>(src), chunk_X1,
                                                  chunk_Y1, barrier);

    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst2),
                                                  reinterpret_cast<const uint64_t *>(src2),
                                                  chunk_X2, chunk_Y2, barrier);

    // Arrive on the barrier and tell how many bytes are expected to come in.
    ptx::mbarrier_arrive_expect_tx(barrier, 2 * num_bytes);
  } else {
    // Other threads just arrive
    ptx::mbarrier_arrive(barrier);
  }
#else
  NVTE_DEVICE_ERROR("copy_2d_to_sharedx2 is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

/*!
 * \brief Starts three 2D bulk tensor global-to-shared copies tracked by one barrier.
 *
 * The master thread issues three ptx::cp_async_bulk_tensor_2d_global_to_shared() calls
 * (dst/src, dst2/src2 and dst3/src3) and then arrives on \p barrier while announcing
 * `3 * num_bytes` expected transaction bytes, i.e. \p num_bytes is the size of ONE copy and
 * all three copies are taken to be equally sized. Every other thread only arrives on
 * \p barrier. The function does not wait for the copies to complete.
 *
 * \param[out]    dst, dst2, dst3     Destinations of the first / second / third copy.
 * \param[in]     src, src2, src3     Source handles of the first / second / third copy.
 * \param[in]     chunk_X1, chunk_Y1  Chunk coordinates of the first copy.
 * \param[in]     chunk_X2, chunk_Y2  Chunk coordinates of the second copy.
 * \param[in]     chunk_X3, chunk_Y3  Chunk coordinates of the third copy.
 * \param[in]     num_bytes           Byte count of a single copy.
 * \param[in,out] barrier             Barrier tracking all three copies.
 * \param[in]     is_master_thread    True for the single thread that issues the copies.
 *
 * Only available when compiling device code for SM 10.0+; otherwise NVTE_DEVICE_ERROR is raised.
 */
__forceinline__ __device__ void copy_2d_to_sharedx3(
    void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, void *dst2,
    const void *src2, const size_t chunk_X2, const size_t chunk_Y2, void *dst3, const void *src3,
    const size_t chunk_X3, const size_t chunk_Y3, const size_t num_bytes, uint64_t *barrier,
    const bool is_master_thread) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  if (is_master_thread) {
    // Initiate bulk tensor copy
    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst),
                                                  reinterpret_cast<const uint64_t *>(src), chunk_X1,
                                                  chunk_Y1, barrier);

    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst2),
                                                  reinterpret_cast<const uint64_t *>(src2),
                                                  chunk_X2, chunk_Y2, barrier);

    ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast<uint64_t *>(dst3),
                                                  reinterpret_cast<const uint64_t *>(src3),
                                                  chunk_X3, chunk_Y3, barrier);

    // Arrive on the barrier and tell how many bytes are expected to come in.
    ptx::mbarrier_arrive_expect_tx(barrier, 3 * num_bytes);
  } else {
    // Other threads just arrive
    ptx::mbarrier_arrive(barrier);
  }
#else
  NVTE_DEVICE_ERROR("copy_2d_to_sharedx3 is only supported on SM 10.0+.");
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

}  // namespace
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_PTX_CUH_