Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
@@ -101,7 +101,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
-local_input.data.elt[j2] = input[row * row_length + col + j2];
+local_input.data.elt[j2] = input[static_cast<size_t>(row) * row_length + col + j2];
}
}
}
......
@@ -112,14 +112,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
-output[row * row_length + col + j2] = local_output.data.elt[j2];
+output[static_cast<size_t>(row) * row_length + col + j2] = local_output.data.elt[j2];
}
}
} else if (row < padded_num_rows) {
// padding
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
-output[row * row_length + col + j2] = local_zero;
+output[static_cast<size_t>(row) * row_length + col + j2] = local_zero;
}
}
}
......
@@ -185,7 +185,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
-local_input.data.elt[j2] = input[row * row_length + col + j2];
+local_input.data.elt[j2] = input[static_cast<size_t>(row) * row_length + col + j2];
}
}
}
......
@@ -196,7 +196,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
if (row < num_rows) {
for (int j2 = 0; j2 < nvec; ++j2) {
if (col + j2 < row_length) {
-output[row * row_length + col + j2] = local_output.data.elt[j2];
+output[static_cast<size_t>(row) * row_length + col + j2] = local_output.data.elt[j2];
}
}
}
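The change in all of these hunks is the same: the row index is promoted to size_t before the multiply, so the flattened offset is computed in 64-bit arithmetic. With plain int operands, row * row_length wraps once the offset passes INT_MAX, which large padded activations can hit. A minimal host-side sketch of the failure mode, with illustrative sizes that are not taken from the kernels:

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative sizes only (not taken from the kernels above).
  const int row = 70000, row_length = 40000, col = 0;
  // 70000 * 40000 = 2,800,000,000 > INT_MAX: the int product overflows
  // (undefined behavior for signed int; in practice it wraps).
  const long long bad = row * row_length + col;
  // Promoting one operand first keeps the whole expression in 64 bits.
  const size_t good = static_cast<size_t>(row) * row_length + col;
  std::printf("32-bit product: %lld, 64-bit product: %zu\n", bad, good);
  return 0;
}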
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
@@ -840,8 +840,688 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
}
__device__ __forceinline__ int32_t elect_one_sync(uint32_t mask = 0xFFFFFFFFu) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
int32_t pred = 0;
asm volatile(
"{\n\t"
".reg .pred %%px; \n"
"elect.sync _|%%px, %1; \n"
"selp.b32 %0, 1, 0, %%px; \n"
"\n\t}"
: "=r"(pred)
: "r"(mask));
return pred;
#else
NVTE_DEVICE_ERROR("elect_one_sync is only supported on SM 10.0+.");
return 0;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
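elect.sync deterministically elects one active lane per warp, and the wrapper returns 1 only in that lane. A hypothetical usage sketch (the kernel name and counter are illustrative, and an SM 10.0+ build is assumed):

__global__ void warp_leader_example(int *per_warp_counter) {
  // Exactly one active lane per warp takes the branch.
  if (ptx::elect_one_sync()) {
    atomicAdd(per_warp_counter, 1);  // runs once per warp, not once per thread
  }
  __syncwarp();  // re-converge before code that needs the full warp
}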
__device__ __forceinline__ void numbered_barrier_sync(uint32_t num_threads,
uint32_t barrier_id = 1u) {
asm volatile("bar.sync %0, %1;\n" ::"r"(barrier_id), "r"(num_threads));
}
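bar.sync with an explicit thread count synchronizes only the named barrier's participants rather than the whole block, which is how producer/consumer warp groups avoid a full __syncthreads(). A hypothetical sketch, assuming the first 256 threads form one group (the count must be a multiple of the warp size, and whole warps must participate):

__global__ void split_sync_example() {
  if (threadIdx.x < 256) {
    // ... work for the consumer group ...
    ptx::numbered_barrier_sync(256, 1);  // only these 256 threads wait here
  }
  // Threads 256 and up (e.g. a producer group) never touch barrier 1.
}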
__device__ __forceinline__ void fma_f32_f16(float &out, uint16_t const &a, uint16_t const &b,
float const &c = 0.0f) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile("fma.rn.f32.f16 %0, %1, %2, %3;" : "=f"(out) : "h"(a), "h"(b), "f"(c) : "memory");
#else
NVTE_DEVICE_ERROR("fma_f32_f16 is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void fma_f32_bf16(float &out, uint16_t const &a, uint16_t const &b,
float const &c = 0.0f) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile("fma.rn.f32.bf16 %0, %1, %2, %3;" : "=f"(out) : "h"(a), "h"(b), "f"(c) : "memory");
#else
NVTE_DEVICE_ERROR("fma_f32_bf16 is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
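Both FMA wrappers multiply two half-precision bit patterns and accumulate directly into fp32 in a single instruction, skipping an explicit up-convert. A hypothetical fp32-accumulated bf16 dot product built on the wrapper (inputs are raw uint16_t bf16 bit patterns, as the signature expects; SM 10.0+ assumed):

__device__ float dot_bf16(const uint16_t *a, const uint16_t *b, int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    ptx::fma_f32_bf16(acc, a[i], b[i], acc);  // acc = f32(a[i]) * f32(b[i]) + acc
  }
  return acc;
}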
__device__ __forceinline__ void reduce_sync_max_abs_f32(float &out, float const &in) {
constexpr bool is_sm_100f = NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>);
if constexpr (is_sm_100f) {
asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" : "=f"(out) : "f"(in));
} else {
asm volatile(
"{\n\t"
".reg.b32 val;\n"
"abs.f32 val, %1;\n"
"redux.sync.max.u32 %0, val, 0xFFFFFFFF;\n"
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "f"(in));
}
}
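On an exact sm_100-family match this is a single redux.sync.max.abs.f32; the fallback first takes |x| and then reduces the raw bits with redux.sync.max.u32, which is sound because non-negative IEEE-754 floats order identically to their unsigned bit patterns. A hypothetical warp-amax helper (all 32 lanes must be active and participate):

__device__ float warp_amax(float x) {
  float m;
  ptx::reduce_sync_max_abs_f32(m, x);  // every lane gets max(|x|) across the warp
  return m;
}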
__device__ __forceinline__ bf16 get_amax(bf16 a, bf16 b) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
bf16 r;
asm volatile("max.xorsign.abs.bf16 %0, %1, %2;"
: "=h"(*reinterpret_cast<int16_t *>(&r))
: "h"(*reinterpret_cast<int16_t *>(&a)), "h"(*reinterpret_cast<int16_t *>(&b)));
return r;
#else
NVTE_DEVICE_ERROR("get_amax is only supported on SM 10.0+.");
return 0.f;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ fp16 get_amax(fp16 a, fp16 b) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
fp16 r;
asm volatile("max.xorsign.abs.f16 %0, %1, %2;"
: "=h"(*reinterpret_cast<int16_t *>(&r))
: "h"(*reinterpret_cast<int16_t *>(&a)), "h"(*reinterpret_cast<int16_t *>(&b)));
return r;
#else
NVTE_DEVICE_ERROR("get_amax is only supported on SM 10.0+.");
return 0.f;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in,
const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::bf16x2 const *in2 = reinterpret_cast<ptx::bf16x2 const *>(&in);
asm volatile(
"{\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"prmt.b32 val2, 0x0, %1, 0x7632;\n\t"
"prmt.b32 val1, 0x0, %1, 0x5410;\n\t"
"prmt.b32 val4, 0x0, %2, 0x7632;\n\t"
"prmt.b32 val3, 0x0, %2, 0x5410;\n\t"
".reg.b64 val_1_2;\n\t"
".reg.b64 val_3_4;\n\t"
"mov.b64 val_1_2, {val1, val2};\n\t"
"mov.b64 val_3_4, {val3, val4};\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
"fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;\n\t"
"mov.b64 {val1, val2}, val_1_2;\n\t"
"mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in2[0])),
"r"(reinterpret_cast<const uint32_t &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const bf16x4 &in, const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::bf16x2 const *in2 = reinterpret_cast<ptx::bf16x2 const *>(&in);
ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
asm volatile(
"{\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"prmt.b32 val2, 0x0, %1, 0x7632;\n\t"
"prmt.b32 val1, 0x0, %1, 0x5410;\n\t"
"prmt.b32 val4, 0x0, %2, 0x7632;\n\t"
"prmt.b32 val3, 0x0, %2, 0x5410;\n\t"
".reg.b64 val_1_2;\n\t"
".reg.b64 val_3_4;\n\t"
"mov.b64 val_1_2, {val1, val2};\n\t"
"mov.b64 val_3_4, {val3, val4};\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
"fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;\n\t"
"mov.b64 {val1, val2}, val_1_2;\n\t"
"mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in2[0])),
"r"(reinterpret_cast<const uint32_t &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale2[0])),
"l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
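Each mul_cvt_4x overload widens four 16-bit values to fp32 (for bf16, prmt simply moves the bits into the high halves), scales them with packed fma.rn.f32x2, and packs the results to FP8, using the four-wide cvt.rs path when _LOOSE_PRECISION is defined and paired cvt.rn otherwise. The floatx2 overloads broadcast one scale pair to both halves; the floatx4 overloads scale each f32x2 pair independently. A hypothetical per-tensor quantization helper over the floatx2 overload:

__device__ fp8e4m3x4 quantize_4x_e4m3(const bf16x4 &frag, float scale) {
  ptx::floatx2 s;
  s.x = scale;  // broadcast the single per-tensor scale to both lanes
  s.y = scale;
  fp8e4m3x4 out;
  ptx::mul_cvt_4x(out, frag, s);  // out[i] = fp8_e4m3(f32(frag[i]) * scale)
  return out;
}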
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const bf16x4 &in,
const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::bf16x2 const *in2 = reinterpret_cast<ptx::bf16x2 const *>(&in);
asm volatile(
"{\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"prmt.b32 val2, 0x0, %1, 0x7632;\n\t"
"prmt.b32 val1, 0x0, %1, 0x5410;\n\t"
"prmt.b32 val4, 0x0, %2, 0x7632;\n\t"
"prmt.b32 val3, 0x0, %2, 0x5410;\n\t"
".reg.b64 val_1_2;\n\t"
".reg.b64 val_3_4;\n\t"
"mov.b64 val_1_2, {val1, val2};\n\t"
"mov.b64 val_3_4, {val3, val4};\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
"fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;\n\t"
"mov.b64 {val1, val2}, val_1_2;\n\t"
"mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in2[0])),
"r"(reinterpret_cast<const uint32_t &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const bf16x4 &in, const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::bf16x2 const *in2 = reinterpret_cast<ptx::bf16x2 const *>(&in);
ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
asm volatile(
"{\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"prmt.b32 val2, 0x0, %1, 0x7632;\n\t"
"prmt.b32 val1, 0x0, %1, 0x5410;\n\t"
"prmt.b32 val4, 0x0, %2, 0x7632;\n\t"
"prmt.b32 val3, 0x0, %2, 0x5410;\n\t"
".reg.b64 val_1_2;\n\t"
".reg.b64 val_3_4;\n\t"
"mov.b64 val_1_2, {val1, val2};\n\t"
"mov.b64 val_3_4, {val3, val4};\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
"fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;\n\t"
"mov.b64 {val1, val2}, val_1_2;\n\t"
"mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in2[0])),
"r"(reinterpret_cast<const uint32_t &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale2[0])),
"l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const fp16x4 &in,
const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::fp16x2 const *in2 = reinterpret_cast<ptx::fp16x2 const *>(&in);
asm volatile(
"{\n\t"
".reg.b16 val1_f16;\n\t"
".reg.b16 val2_f16;\n\t"
".reg.b16 val3_f16;\n\t"
".reg.b16 val4_f16;\n\t"
"mov.b32 {val1_f16, val2_f16}, %1;\n\t"
"mov.b32 {val3_f16, val4_f16}, %2;\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"cvt.f32.f16 val1, val1_f16;\n\t"
"cvt.f32.f16 val2, val2_f16;\n\t"
"cvt.f32.f16 val3, val3_f16;\n\t"
"cvt.f32.f16 val4, val4_f16;\n\t"
".reg.b64 val_1_2;\n\t"
".reg.b64 val_3_4;\n\t"
"mov.b64 val_1_2, {val1, val2};\n\t"
"mov.b64 val_3_4, {val3, val4};\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
"fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;\n\t"
"mov.b64 {val1, val2}, val_1_2;\n\t"
"mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in2[0])),
"r"(reinterpret_cast<const uint32_t &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, const fp16x4 &in, const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::fp16x2 const *in2 = reinterpret_cast<ptx::fp16x2 const *>(&in);
ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
asm volatile(
"{\n\t"
".reg.b16 val1_f16;\n\t"
".reg.b16 val2_f16;\n\t"
".reg.b16 val3_f16;\n\t"
".reg.b16 val4_f16;\n\t"
"mov.b32 {val1_f16, val2_f16}, %1;\n\t"
"mov.b32 {val3_f16, val4_f16}, %2;\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"cvt.f32.f16 val1, val1_f16;\n\t"
"cvt.f32.f16 val2, val2_f16;\n\t"
"cvt.f32.f16 val3, val3_f16;\n\t"
"cvt.f32.f16 val4, val4_f16;\n\t"
".reg.b64 val_1_2;\n\t"
".reg.b64 val_3_4;\n\t"
"mov.b64 val_1_2, {val1, val2};\n\t"
"mov.b64 val_3_4, {val3, val4};\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
"fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;\n\t"
"mov.b64 {val1, val2}, val_1_2;\n\t"
"mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in2[0])),
"r"(reinterpret_cast<const uint32_t &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale2[0])),
"l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const fp16x4 &in,
const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::fp16x2 const *in2 = reinterpret_cast<ptx::fp16x2 const *>(&in);
asm volatile(
"{\n\t"
".reg.b16 val1_f16;\n\t"
".reg.b16 val2_f16;\n\t"
".reg.b16 val3_f16;\n\t"
".reg.b16 val4_f16;\n\t"
"mov.b32 {val1_f16, val2_f16}, %1;\n\t"
"mov.b32 {val3_f16, val4_f16}, %2;\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"cvt.f32.f16 val1, val1_f16;\n\t"
"cvt.f32.f16 val2, val2_f16;\n\t"
"cvt.f32.f16 val3, val3_f16;\n\t"
"cvt.f32.f16 val4, val4_f16;\n\t"
".reg.b64 val_1_2;\n\t"
".reg.b64 val_3_4;\n\t"
"mov.b64 val_1_2, {val1, val2};\n\t"
"mov.b64 val_3_4, {val3, val4};\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
"fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;\n\t"
"mov.b64 {val1, val2}, val_1_2;\n\t"
"mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in2[0])),
"r"(reinterpret_cast<const uint32_t &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, const fp16x4 &in, const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::fp16x2 const *in2 = reinterpret_cast<ptx::fp16x2 const *>(&in);
ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
asm volatile(
"{\n\t"
".reg.b16 val1_f16;\n\t"
".reg.b16 val2_f16;\n\t"
".reg.b16 val3_f16;\n\t"
".reg.b16 val4_f16;\n\t"
"mov.b32 {val1_f16, val2_f16}, %1;\n\t"
"mov.b32 {val3_f16, val4_f16}, %2;\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"cvt.f32.f16 val1, val1_f16;\n\t"
"cvt.f32.f16 val2, val2_f16;\n\t"
"cvt.f32.f16 val3, val3_f16;\n\t"
"cvt.f32.f16 val4, val4_f16;\n\t"
".reg.b64 val_1_2;\n\t"
".reg.b64 val_3_4;\n\t"
"mov.b64 val_1_2, {val1, val2};\n\t"
"mov.b64 val_3_4, {val3, val4};\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;\n\t"
"fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;\n\t"
"mov.b64 {val1, val2}, val_1_2;\n\t"
"mov.b64 {val3, val4}, val_3_4;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in2[0])),
"r"(reinterpret_cast<const uint32_t &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale2[0])),
"l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, floatx4 const &in,
const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::floatx2 const *in2 = reinterpret_cast<ptx::floatx2 const *>(&in);
asm volatile(
"{\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
".reg.b64 re1;\n\t"
".reg.b64 re2;\n\t"
"fma.rn.f32x2 re1, %1, %3, zeros;\n\t"
"fma.rn.f32x2 re2, %2, %3, zeros;\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"mov.b64 {val1, val2}, re1;\n\t"
"mov.b64 {val3, val4}, re2;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "l"(reinterpret_cast<uint64_t const &>(in2[0])),
"l"(reinterpret_cast<uint64_t const &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e5m2x4 &out, floatx4 const &in,
const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::floatx2 const *in2 = reinterpret_cast<ptx::floatx2 const *>(&in);
ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
asm volatile(
"{\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
".reg.b64 re1;\n\t"
".reg.b64 re2;\n\t"
"fma.rn.f32x2 re1, %1, %3, zeros;\n\t"
"fma.rn.f32x2 re2, %2, %4, zeros;\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"mov.b64 {val1, val2}, re1;\n\t"
"mov.b64 {val3, val4}, re2;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "l"(reinterpret_cast<uint64_t const &>(in2[0])),
"l"(reinterpret_cast<uint64_t const &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale2[0])),
"l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, floatx4 const &in,
const ptx::floatx2 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::floatx2 const *in2 = reinterpret_cast<ptx::floatx2 const *>(&in);
asm volatile(
"{\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
".reg.b64 re1;\n\t"
".reg.b64 re2;\n\t"
"fma.rn.f32x2 re1, %1, %3, zeros;\n\t"
"fma.rn.f32x2 re2, %2, %3, zeros;\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"mov.b64 {val1, val2}, re1;\n\t"
"mov.b64 {val3, val4}, re2;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "l"(reinterpret_cast<uint64_t const &>(in2[0])),
"l"(reinterpret_cast<uint64_t const &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mul_cvt_4x(fp8e4m3x4 &out, floatx4 const &in,
const floatx4 &scale) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx::floatx2 const *in2 = reinterpret_cast<ptx::floatx2 const *>(&in);
ptx::floatx2 const *scale2 = reinterpret_cast<ptx::floatx2 const *>(&scale);
asm volatile(
"{\n\t"
".reg.b64 zeros;\n\t"
"mov.b64 zeros, {0x0, 0x0};\n\t"
".reg.b64 re1;\n\t"
".reg.b64 re2;\n\t"
"fma.rn.f32x2 re1, %1, %3, zeros;\n\t"
"fma.rn.f32x2 re2, %2, %4, zeros;\n\t"
".reg.b32 val1;\n\t"
".reg.b32 val2;\n\t"
".reg.b32 val3;\n\t"
".reg.b32 val4;\n\t"
"mov.b64 {val1, val2}, re1;\n\t"
"mov.b64 {val3, val4}, re2;\n\t"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;\n\t"
#else
".reg.b16 r1;\n\t"
".reg.b16 r2;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;\n\t"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;\n\t"
"mov.b32 %0, {r1, r2};\n\t"
#endif
"}\n\t"
: "=r"(reinterpret_cast<uint32_t &>(out))
: "l"(reinterpret_cast<uint64_t const &>(in2[0])),
"l"(reinterpret_cast<uint64_t const &>(in2[1])),
"l"(reinterpret_cast<const uint64_t &>(scale2[0])),
"l"(reinterpret_cast<const uint64_t &>(scale2[1])), "r"(0x80008000));
#else
NVTE_DEVICE_ERROR("mul_cvt_4x is only supported on SM 10.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void abs_max_2x(float &dst, const float &p1, const float &p2,
const float &p3) {
#if (defined CUDA_VERSION) && (CUDA_VERSION >= 12090)
asm volatile("max.abs.f32 %0, %1, %2, %3;" : "=f"(dst) : "f"(p1), "f"(p2), "f"(p3));
#else
asm volatile(
"max.xorsign.abs.f32 %0, %2, %3;"
"max.xorsign.abs.f32 %0, %0, %1;"
: "+f"(dst)
: "f"(p1), "f"(p2), "f"(p3));
#endif
}
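On CUDA 12.9+ this lowers to the three-input max.abs.f32; the older fallback chains two max.xorsign.abs ops, so the magnitude is max(|p1|, |p2|, |p3|) but the sign bit can come out set. A hypothetical wrapper for callers that need a guaranteed non-negative amax:

__device__ float amax3(float a, float b, float c) {
  float m;
  ptx::abs_max_2x(m, a, b, c);
  return fabsf(m);  // normalize the sign on the fallback path
}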
__device__ __forceinline__ ptx::floatx2 up_cast(const ptx::fp16x2 &in) {
ptx::floatx2 out;
asm volatile(
"{\n\t"
".reg.b16 f16_1;\n\t"
".reg.b16 f16_2;\n\t"
"mov.b32 {f16_1, f16_2}, %2;\n\t"
"cvt.f32.f16 %0, f16_1;\n\t"
"cvt.f32.f16 %1, f16_2;\n\t"
"}\n\t"
: "=f"(out.x), "=f"(out.y)
: "r"(reinterpret_cast<int32_t const &>(in)));
return out;
}
__device__ __forceinline__ floatx4 up_cast(const fp16x4 &in) {
floatx4 out;
asm volatile(
"{\n\t"
".reg.b16 f16_1;\n\t"
".reg.b16 f16_2;\n\t"
".reg.b16 f16_3;\n\t"
".reg.b16 f16_4;\n\t"
"mov.b64 {f16_1, f16_2, f16_3, f16_4}, %4;\n\t"
"cvt.f32.f16 %0, f16_1;\n\t"
"cvt.f32.f16 %1, f16_2;\n\t"
"cvt.f32.f16 %2, f16_3;\n\t"
"cvt.f32.f16 %3, f16_4;\n\t"
"}\n\t"
: "=f"(out.x1), "=f"(out.x2), "=f"(out.x3), "=f"(out.x4)
: "l"(reinterpret_cast<int64_t const &>(in)));
return out;
}
__device__ __forceinline__ ptx::floatx2 up_cast(const ptx::bf16x2 &in) {
ptx::floatx2 out;
asm volatile(
"{\n\t"
"prmt.b32 %1, 0x0, %2, 0x7632;\n\t"
"prmt.b32 %0, 0x0, %2, 0x5410;\n\t"
"}\n\t"
: "=r"(reinterpret_cast<int32_t &>(out.x)), "=r"(reinterpret_cast<int32_t &>(out.y))
: "r"(reinterpret_cast<int32_t const &>(in)));
return out;
}
__device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) {
floatx4 out;
int32_t const *in2 = reinterpret_cast<int32_t const *>(&in);
asm volatile(
"{\n\t"
"prmt.b32 %1, 0x0, %4, 0x7632;\n\t"
"prmt.b32 %0, 0x0, %4, 0x5410;\n\t"
"prmt.b32 %3, 0x0, %5, 0x7632;\n\t"
"prmt.b32 %2, 0x0, %5, 0x5410;\n\t"
"}\n\t"
: "=r"(reinterpret_cast<int32_t &>(out.x1)), "=r"(reinterpret_cast<int32_t &>(out.x2)),
"=r"(reinterpret_cast<int32_t &>(out.x3)), "=r"(reinterpret_cast<int32_t &>(out.x4))
: "r"(in2[0]), "r"(in2[1]));
return out;
}
#endif
} // namespace ptx
namespace {
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
@@ -88,7 +88,8 @@
pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>( \
m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \
.value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT) \
.value("INVALID", transformer_engine::Float8BlockScaleTensorFormat::INVALID); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
pybind11::module_local()) \
.value("RS", transformer_engine::CommOverlapType::RS) \
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
-* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""The utilities for Transformer Engine"""
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......