Commit 53fa872c authored by wenjh

Merge branch 'nv_release_v2.8' into release_v2.8

parents 27ddce40 40c69e75
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
namespace transformer_engine {
namespace {
constexpr int kThreadsPerWarp = 32;
constexpr float k16x16HadamardScale = 0.25f;
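// The 16x16 Hadamard matrix below follows the Sylvester construction,
// H16[r][c] = (-1)^popcount(r & c), normalized by 1/sqrt(16) = 0.25f so the
// transform is orthonormal. Host-side reference (illustrative sketch only):
//   float h16(int r, int c) { return 0.25f * ((__builtin_popcount(r & c) & 1) ? -1.0f : 1.0f); }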
template <bool kTranspose>
__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr) {
auto smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(addr));
if constexpr (kTranspose) {
asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
} else {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "r"(smem_addr));
}
}
template <bool kTranspose>
__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3,
void* addr, uint32_t stride) {
if constexpr (kTranspose) {
asm volatile(
"wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
} else {
asm volatile(
"wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 "
"{%0,%1,%2,%3}, [%4], %5;\n"
: "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3)
: "l"(addr), "r"(stride));
}
}
template <bool kTranspose>
__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1,
uint32_t& a2, uint32_t& a3, void* addr,
uint32_t stride) {
if constexpr (kTranspose) {
asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
} else {
asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n"
:
: "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride));
}
}
__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) {
asm volatile(
"movmatrix.sync.aligned.m8n8.trans.b16 "
"%0, %1;\n\t"
: "=r"(a0)
: "r"(a0));
}
__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) {
__nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16);
float f_a = __bfloat162float(bf16x2.x);
float f_b = __bfloat162float(bf16x2.y);
asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b));
float_dst = fabsf(float_dst);
}
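// N.B. max.xorsign.abs.f32 computes max(|a|, |b|) with the result's sign set
// to sign(a) XOR sign(b); the trailing fabsf strips that sign, so the function
// returns the plain absolute maximum of the two packed BF16 halves.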
template <bool kCalculateAmax>
__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc(
uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1,
uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3,
uint32_t& amax_result) {
uint32_t zero = 0;
uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7;
asm volatile(
"wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n"
"{%0, %1, %2, %3, %4, %5, %6, %7}, \n"
"{%8, %9, %10, %11}, \n"
"{%12, %13, %14, %15}, \n"
"{%16, %17, %18, %19, %20, %21, %22, %23};\n\t"
: "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6),
"=r"(temp7)
: "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero),
"r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4));
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6));
if constexpr (kCalculateAmax) {
uint32_t max_even;
uint32_t max_odd;
// Reduction tree computing amax(abs(result)) into the bf16x2 register outparam.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3));
// N.B. The mma is called at most once per thread for identity and transpose
// respectively, so we don't have to accumulate into amax_result and can
// store into it directly.
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(amax_result)
: "r"(max_even), "r"(max_odd));
}
}
template <bool kReturnIdentity, bool kReturnTransposed, bool kInverseHadamardIdentity,
bool kInverseHadamardTransposed>
__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i,
uint16_t random_sign_mask,
uint32_t* had_frag_t,
uint16_t random_sign_mask_t) {
int32_t tid = threadIdx.x % 32; // Local tid
float temp_i[2];
float temp_t[2];
#pragma unroll
for (int i = 0; i < 2; i++) {
// i is the vertical fragment index.
// For a 16x16 matrix fragment, 4 threads fill a fragment of 8 BF16 vals.
uint32_t r = i * 8 + tid / 4;
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int k = 0; k < 2; k++) {
// k is the column position [0, 1] within a pair of BF16s stored together in 32 bits.
// j is the column-fragment index selecting between even and odd fragments;
// incrementing j advances 8 columns by switching fragments.
uint32_t c = j * 8 + k + tid % 4 * 2;
// 1 -> -1.0f, 0 -> 1.0f
int32_t base_sign = __popc(r & c);
if constexpr (kReturnIdentity) {
int32_t sign_i;
// Because tensor cores want the dot-product dimension contiguous,
// the regular (non-inverse) Hadamard swaps the signs of columns and
// the inverse swaps the signs of rows. In a simple reference,
// x.reshape(-1, 16) @ sign @ H16, this would be the opposite, but
// (sign @ H16) is transposed in this fragment.
if constexpr (kInverseHadamardIdentity) {
sign_i = ((random_sign_mask >> r) ^ base_sign);
} else {
sign_i = ((random_sign_mask >> c) ^ base_sign);
}
temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31));
}
if constexpr (kReturnTransposed) {
int32_t sign_t;
if constexpr (kInverseHadamardTransposed) {
sign_t = ((random_sign_mask_t >> r) ^ base_sign);
} else {
sign_t = ((random_sign_mask_t >> c) ^ base_sign);
}
temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31));
}
}
if constexpr (kReturnIdentity) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_i[i * 2 + j])
: "f"(temp_i[1]), "f"(temp_i[0]));
}
if constexpr (kReturnTransposed) {
asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t"
: "=r"(had_frag_t[i * 2 + j])
: "f"(temp_t[1]), "f"(temp_t[0]));
}
}
}
}
__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx,
uint32_t gmem_col_idx) {
uint32_t smem_row_idx = gmem_row_idx;
uint32_t xor_factor = (smem_row_idx * 2) % 8;
uint32_t smem_col_idx = gmem_col_idx ^ xor_factor;
return smem_row_idx * 8 + smem_col_idx;
}
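// Worked example: gmem (row=3, col=5) gives xor_factor = (3 * 2) % 8 = 6 and
// smem col = 5 ^ 6 = 3, so the returned linear index is 3 * 8 + 3 = 27. This
// mirrors the CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B mode requested for the TMA
// tensor map in the host code below.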
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment
int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);
int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;
if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}
if (kReturnTransposedAmax) {
// TODO(Frank): This is not efficient, since we could directly load the
// matrix in transposed layout.
if (!kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_t_reg)
: "r"(local_amax_t_reg), "r"(temp_amax_t_reg));
}
if (kReturnPreRhtAmax) {
if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[1]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[2])
: "r"(a_frag[2]), "r"(a_frag[3]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[2]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_pre_rht_amax_reg)
: "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
}
}
template <int kN>
__device__ __host__ constexpr int NextPowerOf2() {
static_assert(kN > 0, "kN must be > 0");
// Round up to the next power of 2 by counting leading zeros.
return 1 << (32 - __builtin_clz(kN - 1));
}
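// Example: NextPowerOf2<5>() = 1 << (32 - __builtin_clz(4)) = 1 << 3 = 8;
// powers of two map to themselves, e.g. NextPowerOf2<8>() = 8.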
template <int kNumWarps, bool kReturnPreRhtAmax, bool kReturnIdentityAmax,
bool kReturnTransposedAmax>
__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax,
const float transpose_amax, float* staging_for_pre_rht,
float* staging_for_identity, float* staging_for_transpose,
float* output_pre_rht_amax_ptr,
float* output_identity_amax_ptr,
float* output_transpose_amax_ptr, const int warpid) {
// intra-warp reduction
constexpr int kWarpSize = 32;
int local_rank = threadIdx.x % 32;
float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max<kWarpSize>(pre_rht_amax) : 0.0f;
float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max<kWarpSize>(identity_amax) : 0.0f;
float warp_transpose_amax =
kReturnTransposedAmax ? warp_reduce_max<kWarpSize>(transpose_amax) : 0.0f;
// inter-warp reduction
if (threadIdx.x % 32 == 0) {
if (kReturnPreRhtAmax) {
staging_for_pre_rht[warpid] = warp_pre_rht_amax;
}
if (kReturnIdentityAmax) {
staging_for_identity[warpid] = warp_identity_amax;
}
if (kReturnTransposedAmax) {
staging_for_transpose[warpid] = warp_transpose_amax;
}
}
__syncthreads();
constexpr int kNumWarpsPow2 = NextPowerOf2<kNumWarps>();
if (warpid == 0) {
if (kReturnIdentityAmax) {
float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f;
identity_accum = warp_reduce_max<kNumWarpsPow2>(identity_accum);
if (local_rank == 0) {
atomicMaxFloat(output_identity_amax_ptr, identity_accum);
}
}
}
if (warpid == 1) {
if (kReturnTransposedAmax) {
float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f;
transpose_accum = warp_reduce_max<kNumWarpsPow2>(transpose_accum);
if (local_rank == 0) {
atomicMaxFloat(output_transpose_amax_ptr, transpose_accum);
}
}
}
if (warpid == 2) {
if (kReturnPreRhtAmax) {
float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f;
pre_rht_accum = warp_reduce_max<kNumWarpsPow2>(pre_rht_accum);
if (local_rank == 0) {
atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum);
}
}
}
}
__launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr,
float* __restrict__ output_identity_amax_ptr,
float* __restrict__ output_transpose_amax_ptr) {
if (output_pre_rht_amax_ptr != nullptr) {
*output_pre_rht_amax_ptr = 0;
}
if (output_identity_amax_ptr != nullptr) {
*output_identity_amax_ptr = 0;
}
if (output_transpose_amax_ptr != nullptr) {
*output_transpose_amax_ptr = 0;
}
}
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X, int BUFF_DIM_Y,
int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y, bool kReturnPreRhtAmax,
bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input,
float* __restrict__ output_pre_rht_amax_ptr,
float* __restrict__ output_identity_amax_ptr,
float* __restrict__ output_transpose_amax_ptr,
uint16_t random_sign_mask, uint16_t random_sign_mask_t,
uint64_t num_rows, uint64_t row_length) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0);
static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0);
constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X;
constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp;
const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X;
extern __shared__ __align__(128) char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding.
// __align__(128) does not guarantee the pointer is aligned!
uint8_t* dshmem = reinterpret_cast<uint8_t*>((base_shmem_ptr + 127) & ~127ULL);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
IType* in_sh_0 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_sh_1 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_shs[2] = {in_sh_0, in_sh_1};
constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
uint64_t* mbar = reinterpret_cast<uint64_t*>(dshmem);
dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y);
float* max_staging_identity = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_transpose = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_pre_rht = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
initialize_barriers<STAGES_X * STAGES_Y, THREADS_PER_CHUNK * THREADS_PER_Y>(mbar,
is_master_thread);
copy_2d_to_shared(in_shs[0], reinterpret_cast<const void*>(&tensor_map_input),
input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0],
is_master_thread);
uint32_t had_frag_i[4];
uint32_t had_frag_t[4];
get_hadamard_matrix_fragment<kReturnIdentityAmax, kReturnTransposedAmax, false, false>(
had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t);
float local_pre_rht_amax = 0.0;
float local_amax = 0.0;
float local_amax_t = 0.0;
uint32_t local_pre_rht_amax_reg = *reinterpret_cast<uint32_t*>(&local_pre_rht_amax);
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);
for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
const int next_stage = stage + 1;
const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1;
const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y;
if (next_stage < STAGES_X * STAGES_Y) {
const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y;
const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X;
copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong
reinterpret_cast<const void*>(&tensor_map_input), input_global_offset_X,
input_global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
const size_t compute_stage_x_num =
BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp));
const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y);
const size_t in_row_stride = BUFF_DIM_X;
IType* in_sh_ptr = in_shs[stage % 2];
#pragma unroll
for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) {
const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y +
threadIdx.y * kHadamardDimension);
const int in_row_offset = row_idx_offset * in_row_stride;
#pragma unroll
for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) {
ComputeKernel<IType, kHadamardDimension, BUFF_DIM_Y, BUFF_DIM_X, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>(
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}
// Ensure all threads have finished their computation before new data overwrites the
// shared memory.
__syncthreads();
}
}
}
const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp;
if constexpr (kReturnPreRhtAmax) {
unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax);
}
if constexpr (kReturnIdentityAmax) {
unpack_max_of_packed_bf16(local_amax_reg, local_amax);
}
if constexpr (kReturnTransposedAmax) {
unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
}
ReduceMax<kNumWarps, kReturnPreRhtAmax, kReturnIdentityAmax, kReturnTransposedAmax>(
local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity,
max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr,
output_transpose_amax_ptr, warpid);
destroy_barriers<STAGES_X * STAGES_Y>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <typename T, int kHadamardDimension, bool kComputeIdentity, bool kComputeTransposed,
bool kReturnIdentity, bool kReturnTransposed, bool kUpdateIdentityAmax,
bool kUpdateTransposeAmax, bool kOutputTrueTransposed>
__global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restrict__ output,
T* __restrict__ output_t, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, uint64_t num_input_rows,
uint64_t num_input_cols, float* __restrict__ amax,
float* __restrict__ amax_t, bool inverse_hadamard) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported.");
// The whole threadblock will share the same smem.
extern __shared__ __align__(16) T smem[];
// Each group of 32 threads processes one 16x16 matrix; the threadblock forms a (y, z) grid of 16x16 tiles.
// If y = 4 and z = 4, each threadblock processes a 4x4 grid of 16x16 matrices.
int32_t tid = threadIdx.x;
int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z;
int32_t local_bx = threadIdx.y;
int32_t local_by = threadIdx.z;
// Define the register fragments
uint32_t a_frag[4]; // A matrix fragment
uint32_t b_frag_i[4]; // Transposed Hadamard matrix fragment, used for A @ B(col major)
uint32_t b_frag_t[4]; // Hadamard matrix fragment, used for A.T @ B.T(col major)
uint32_t c_frag[4]; // Result fragment
// row and col for each thread. The 32 threads of a warp cooperatively load
// one 16x16 tile from global memory to shared memory in 16-byte (uint4) chunks.
uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4));
uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4));
uint32_t smem_index = tid;
uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension;
uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension;
bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows);
if (!load) {
// Out of bounds; return early. No intra-warp divergence, since the whole warp
// returns early together.
return;
}
uint64_t global_offset = input_start_col + input_start_row * num_input_cols;
uint64_t global_offset_t =
kOutputTrueTransposed ? (input_start_row + input_start_col * num_input_rows) : global_offset;
T* base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id;
uint32_t* smem_b32 = reinterpret_cast<uint32_t*>(base_smem);
uint4* smem_b128 = reinterpret_cast<uint4*>(base_smem);
// Asynchronously load the data from global memory to shared memory.
const uint4* input_b128 = reinterpret_cast<const uint4*>(input + global_offset);
// Each 16x16 chunk is divided into four 8x8 matrices; we load the 8x8 chunks
// consecutively into smem so we can leverage ldmatrix m8n8.x4 to read the data
// in the tensor-core swizzled format.
__pipeline_memcpy_async(&smem_b128[smem_index],
&input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col],
sizeof(uint4));
__pipeline_commit(); // Commit the memcpy. Wait when we are in the computation.
if (inverse_hadamard) {
get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
/*kInverseHadamardIdentity=*/true,
/*kInverseHadamardTransposed=*/true>(b_frag_i, random_sign_mask,
b_frag_t, random_sign_mask_t);
} else {
get_hadamard_matrix_fragment<kComputeIdentity, kComputeTransposed,
/*kInverseHadamardIdentity=*/false,
/*kInverseHadamardTransposed=*/false>(
b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t);
}
float local_amax = 0.0;
float local_amax_t = 0.0;
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);
__pipeline_wait_prior(0);
__syncwarp(); // ensure all lanes finished their cp.async before reading smem
// Load the A to a_frag.
if constexpr (kComputeIdentity) {
load_matrix_16x16_from_shared<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32,
kHadamardDimension);
// 16x16 @ 16x16 leveraging all threads in the warp.
mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg);
// Store the result to global memory in non-transposed order.
if constexpr (kReturnIdentity) {
uint4* output_b128 = reinterpret_cast<uint4*>(output + global_offset);
store_matrix_16x16_to_global<false>(c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_b128,
num_input_cols);
}
}
if constexpr (kComputeTransposed) {
if (kComputeIdentity) {
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
} else {
load_matrix_16x16_from_shared<true>(a_frag[0],
a_frag[2], // NOTE: intentional index swapping
a_frag[1], // NOTE: intentional index swapping
a_frag[3], smem_b32, kHadamardDimension);
}
mma_m16_n16_k16_b16_b16_b16_noacc<kUpdateTransposeAmax>(
a_frag[0],
// The 2,1 register order matches what the movmatrix path above produces,
// so loading the fragments directly in 2,1 order keeps both paths
// consistent with the movmatrix instruction.
a_frag[2], // NOTE: intentional index swapping for transpose purpose.
a_frag[1], // NOTE: intentional index swapping for transpose purpose.
a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1],
c_frag[2], c_frag[3], local_amax_t_reg);
// Store the result to global memory (in a true transposed layout if kOutputTrueTransposed).
if constexpr (kReturnTransposed) {
uint4* output_t_b128 = reinterpret_cast<uint4*>(output_t + global_offset_t);
store_matrix_16x16_to_global<!kOutputTrueTransposed>(
c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128,
kOutputTrueTransposed ? num_input_rows : num_input_cols);
}
}
if constexpr (kUpdateIdentityAmax) {
unpack_max_of_packed_bf16(local_amax_reg, local_amax);
local_amax = warp_reduce_max<kThreadsPerWarp>(local_amax);
// Broadcast the amax to all threads in the warp from lane 0.
constexpr int lane_zero = 0;
local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero);
// atomic CAS to output memory.
if (tid % kThreadsPerWarp == 0) {
atomicMaxFloat(amax, local_amax);
}
}
if constexpr (kUpdateTransposeAmax) {
unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
local_amax_t = warp_reduce_max<kThreadsPerWarp>(local_amax_t);
// Broadcast the amax to all threads in the warp from lane 0.
constexpr int lane_zero = 0;
local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero);
// atomic CAS to output memory.
if (tid % kThreadsPerWarp == 0) {
atomicMaxFloat(amax_t, local_amax_t);
}
}
#else
NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+.");
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
}
} // namespace
void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform);
// Check tensors
// NOTE (frsun): This is non-intuitive: we write the result of the
// transposed RHT to the rowwise output.
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be a simple tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor must be simple tensor, but scaling mode is ",
to_string(output_.scaling_mode), ".");
const SimpleTensor& input = input_.data;
SimpleTensor output;
SimpleTensor& output_t = output_.data;
// Check requested outputs
const bool return_identity = output.dptr != nullptr;
const bool return_transposed = output_t.dptr != nullptr;
if (!return_identity && !return_transposed) { // Nothing to do/ill-defined behavior.
return;
}
checkCuDriverContext(stream);
const size_t ndim = input.shape.size();
const size_t row_length = input.shape[ndim - 1];
size_t num_rows = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
num_rows *= input.shape[i];
}
using IType = bf16;
constexpr int kHadamardDimension = 16;
NVTE_CHECK(row_length % kHadamardDimension == 0,
"row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(num_rows % kHadamardDimension == 0,
"num_rows must be divisible by hadamard_dimension.");
constexpr uint64_t kThreadBlockX = 4;
// A value of 4 is used on Hopper; 8 is used on Blackwell for extra memory bandwidth.
constexpr uint64_t kThreadBlockY = 4;
uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY;
// The shared memory number of bytes required for **the whole threadblock**.
size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM;
dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY);
dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX),
DIVUP(num_rows / kHadamardDimension, kThreadBlockY));
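// Illustrative example: a 1024 x 4096 BF16 input yields grid =
// (DIVUP(4096 / 16, 4), DIVUP(1024 / 16, 4)) = (64, 16) blocks of 16 warps
// each, with every warp transforming one 16x16 tile.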
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transposed, kReturnTransposed,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity, kReturnIdentity,
auto kernel =
HadamardTransformKernel<IType, kHadamardDimension, kReturnIdentity, kReturnTransposed,
kReturnIdentity, kReturnTransposed, false, false, true>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes);
kernel<<<grid, block, shmem_bytes, stream>>>(
reinterpret_cast<const IType*>(input.dptr), reinterpret_cast<IType*>(output.dptr),
reinterpret_cast<IType*>(output_t.dptr), random_sign_mask, random_sign_mask_t,
num_rows, row_length, nullptr, nullptr, false);););
NVTE_CHECK_CUDA(cudaGetLastError());
}
// Applies the 16x16 Hadamard transform to the input and input.T, and then
// computes the absolute max value of each result.
void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform_amax);
#if CUDA_VERSION >= 12080
// Check input tensor
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be a simple tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must be at least a 2D tensor.");
const SimpleTensor& input = input_.data;
// Check amax tensors
SimpleTensor& output_pre_rht_amax = output_.amax;
SimpleTensor output_identity_amax;
SimpleTensor& output_transpose_amax = output_.columnwise_amax;
// Check requested outputs
const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr;
const bool return_identity_amax = output_identity_amax.dptr != nullptr;
const bool return_transposed_amax = output_transpose_amax.dptr != nullptr;
if (!return_identity_amax && !return_transposed_amax &&
!return_pre_rht_amax) { // Nothing to do/ill-defined behavior.
return;
}
// Zero out amaxes if needed
ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast<float*>(output_pre_rht_amax.dptr),
reinterpret_cast<float*>(output_identity_amax.dptr),
reinterpret_cast<float*>(output_transpose_amax.dptr));
NVTE_CHECK_CUDA(cudaGetLastError());
checkCuDriverContext(stream);
using IType = bf16;
const size_t ndim = input.shape.size();
const size_t row_length = input.shape[ndim - 1];
size_t num_rows = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
num_rows *= input.shape[i];
}
constexpr int kHadamardDimension = 16;
NVTE_CHECK(row_length % kHadamardDimension == 0,
"row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(num_rows % kHadamardDimension == 0,
"num_rows must be divisible by hadamard_dimension.");
constexpr uint64_t kChunkBlockXSmall = 128;
constexpr uint64_t kChunkBlockYSmall = 128;
constexpr uint64_t kBuffDimX = 64;
constexpr uint64_t kBuffDimY = 64;
alignas(64) CUtensorMap tensor_map_input{};
create_2D_tensor_map(
/*tensorMap=*/tensor_map_input,
/*tensor=*/input,
/*globalY=*/num_rows,
/*globalX=*/row_length,
/*shmemY=*/kBuffDimY,
/*shmemX=*/kBuffDimX,
/*stride_elems=*/row_length,
/*offset_elems=*/0,
/*type_num_bits=*/sizeof(IType) * 8,
/*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B);
constexpr uint64_t kThreadBlockX = 4;
constexpr uint64_t kThreadBlockY = 1;
constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY;
dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall));
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transposed_amax, kReturnTransposedAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity_amax, kReturnIdentityAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_pre_rht_amax, kReturnPreRhtAmax,
// *2 for ping-pong
size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType);
size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) *
(kChunkBlockYSmall / kBuffDimY);
size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3;
// Add padding in case shmem ptr is not aligned to 128 bytes.
shmem_bytes = (shmem_bytes + 128);
auto kernel = HadamardAmaxTmaKernel<
IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY,
kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
shmem_bytes);
kernel<<<grid, block, shmem_bytes, stream>>>(
tensor_map_input, reinterpret_cast<float*>(output_pre_rht_amax.dptr),
reinterpret_cast<float*>(output_identity_amax.dptr),
reinterpret_cast<float*>(output_transpose_amax.dptr), random_sign_mask,
random_sign_mask_t, num_rows, row_length);)));
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace transformer_engine
void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform);
using namespace transformer_engine;
hadamard_transform(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
static_cast<uint16_t>(random_sign_mask),
static_cast<uint16_t>(random_sign_mask_t), stream);
}
void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform_amax);
using namespace transformer_engine;
hadamard_transform_amax(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
static_cast<uint16_t>(random_sign_mask),
static_cast<uint16_t>(random_sign_mask_t), stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
#include "cutlass/arch/barrier.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/helper_cuda.hpp"
#include "cutlass/util/print_error.hpp"
// clang-format off
namespace transformer_engine {
namespace detail {
namespace {
// Define a cuRANDDx descriptor.
// Note: curanddx::PhiloxRounds<N> selects N rounds of philox4_32; if the operator is not specified, it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to do some internal checks, e.g.,
// whether shared memory, if needed, is sufficient for the described problem; usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread());
using namespace cute;
using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor
// calculate the global encode scale factor for a given global amax.
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
constexpr float kFP8E4M3Max = 448.0f;
constexpr float kFP4E2M1Max = 6.0f;
// If scale is infinity, return max value of float32
float global_encode_scale = cutlass::minimum_with_nan_propagation<float>{}(
kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits<float>::max());
// If global amax is 0 or infinity, return 1
return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale;
}
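// Worked example (illustrative): global_amax = 1344.0f gives
// global_encode_scale = 448 * 6 / 1344 = 2.0f; degenerate amaxes (0 or inf)
// fall back to a scale of 1.0f.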
template <class ElementA,
class ElementB,
class ASmemLayout,
class BSmemLayout>
struct SharedStorage {
static constexpr int AccumulatorPipelineStageCount = 16;
using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
using AccumulatorPipeline = cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
MainloopPipelineStageCount,
Shape<_1,_1,_1>,
AtomThrShapeMNK>;
using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;
alignas(16) AccumulatorPipelineStorage accumulator;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) cute::uint64_t tma_barrier[1];
uint32_t tmem_base_ptr;
struct TensorStorage : cute::aligned_struct<128, _1> {
cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
} tensors;
};
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8>
StochasticNumericConverterBase(cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
asm volatile( \
"{\n" \
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \
"}" \
: "=h"(output_ptr[0]),
"=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]),
"f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]),
"r"(rbits[0]), "r"(rbits[1]));
#else
NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return output;
}
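// Note: the cvt.rs (stochastic rounding) variants above consume the 32-bit
// random words in rbits as rounding entropy, narrowing four FP32 inputs into
// four packed E2M1 (FP4) values per 16-bit output register; .satfinite clamps
// out-of-range inputs to the largest finite FP4 magnitude.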
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16>
StochasticNumericConverter(cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const *rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 16>;
result_type output;
cutlass::Array<cutlass::float_e2m1_t, 8> *result_ptr = reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&output);
cutlass::Array<float, 8> const *source_ptr = reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
cutlass::Array<uint32_t, 2> const *rbits_ptr = reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(rbits);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; i++) {
result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
}
return output;
}
template <class MShape, class NShape, class KShape, class ClusterTileShape,
class TA, class AStride, class ASmemLayout, class TmaLoadA,
class TB, class BStride, class BSmemLayout, class TmaLoadB,
class TC, class CStride, class CSmemLayout,
class TSFC,
class TiledMMA,
bool kEnableStochasticRounding = false>
__global__ static
void
rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile,
TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a,
TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b,
TC * C, CStride dC, CSmemLayout ,
TSFC * SFC,
TiledMMA mma,
float const* global_amax,
const size_t* rng_state)
{
using namespace cute;
using X = Underscore;
using ElementAccumulator = float;
static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{});
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
static constexpr uint32_t kTmaTransactionBytes =
cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v<TA>);
static constexpr int kTmaRhtTensorTransactionBytes =
cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v<TB>);
static constexpr int AccumulatorPipelineStageCount = 16;
static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline = cutlass::PipelineTmaUmmaAsync<
MainloopPipelineStageCount,
Shape<_1,_1,_1>,
AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;
using TmemAllocator = cute::TMEM::Allocator1Sm;
static constexpr int VectorSize = 16;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
// Preconditions
CUTE_STATIC_ASSERT(is_static<ASmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<BSmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<CSmemLayout>::value);
// Represent the full tensors
Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N));
Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16));
Tensor mC = make_tensor(cute::subbyte_iterator<TC>(C), make_shape(M,N), dC); // (M,N)
auto sfc_shape = make_shape(
M,
make_shape( make_shape(Int<16>{}, _4{}), N / 64 )
);
auto sfc_stride = make_stride(
N / 16,
make_stride( make_stride(_0{}, _1{}), _4{} )
);
auto sfc_layout = make_layout(sfc_shape, sfc_stride);
Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout);
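// The SFC layout encodes one scale factor (TSFC element) per 16-element block
// of C: the inner (16, 4) mode with stride (0, 1) maps all 16 values of a
// block to the same scale-factor element, groups of four blocks (64 columns)
// advance by 4, and rows advance by N/16, i.e. an M x (N/16) row-major
// scale tensor.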
auto cluster_shape = Shape< _1, _1, _1>{};
// Get the appropriate blocks for this Cluster
dim3 cluster_coord_in_grid = cluster_id_in_grid();
// Total number of k-tiles
const int K_TILE_MAX = min(N, K) / 64;
uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile);
uint32_t tiles_in_n = (N + 64 - 1) / 64;
uint32_t linear_tile_idx = blockIdx.x;
uint32_t tile_idx_m = linear_tile_idx % tiles_in_m;
uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
auto mainloop_tiler = Shape<_128,_16,_64>{};
auto epilogue_tiler = Shape<_128,_64,_64>{};
Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{});
Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k)
Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N)
Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N)
// Allocate SMEM
extern __shared__ char shared_memory[];
using SharedStorage = SharedStorage<TA, TB, ASmemLayout, BSmemLayout>;
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(shared_memory);
Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE)
Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE)
//
// MMA: Define C accumulators and A/B partitioning
//
int block_rank_in_cluster = cute::block_rank_in_cluster();
ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx
Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k)
auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, ElementAccumulator,
128, 64,
UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1,_1>>{});
ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster);
using TiledMmaEpilogue = decltype(mma_epilogue);
Tensor tCgA = thr_mma.partition_A(gA_mk);
// Allocate "fragments" -- these are actually umma smem descriptors
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE)
auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{}));
auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler));
auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma,
Int<AccumulatorPipelineStageCount>{}));
auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue,
Int<AccumulatorPipelineStageCount / 4>{}));
TmemAllocator tmem_allocator{};
cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
Layout cta_layout_mnk = make_layout(cluster_shape);
Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster);
auto [tAgA, tAsA] = tma_partition(tma_load_a,
get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
group_modes<0,3>(tCsA), group_modes<0,3>(tCgA));
auto [tBgB, tBsB] = tma_partition(tma_load_b,
get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
group_modes<0,3>(tCsB), group_modes<0,3>(tCgB));
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
int warp_idx = cutlass::canonical_warp_idx_sync();
bool is_mma_warp = (warp_idx == 0);
bool is_dma_warp = (warp_idx == 1);
bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7);
if (is_epilogue_warp && elect_one_sync()) {
cute::prefetch(raw_pointer_cast(global_amax));
}
typename MainloopPipeline::Params mainloop_pipeline_params;
if (is_dma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (is_mma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp;
mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes;
mainloop_pipeline_params.initializing_warp = 0;
MainloopPipeline mainloop_pipeline(shared_storage.mainloop,
mainloop_pipeline_params,
cluster_shape,
cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
MainloopPipelineState mainloop_pipe_consumer_state;
MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>();
using AccumulatorPipeline = cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / 4, AtomThrShapeMNK>;
using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState;
AccumulatorPipelineState accumulator_pipe_consumer_state;
AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state<AccumulatorPipeline>();
typename AccumulatorPipeline::Params accumulator_pipeline_params;
if (is_mma_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer;
}
if (is_epilogue_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer;
}
// Only one producer thread arrives on this barrier.
accumulator_pipeline_params.producer_arv_count = 1;
accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128;
accumulator_pipeline_params.initializing_warp = 1;
AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator,
accumulator_pipeline_params,
cluster_shape,
cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
if (warp_idx == 2 && elect_one_sync()) {
cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
}
__syncthreads();
using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
if (is_dma_warp) {
if (elect_one_sync()) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes);
copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0));
}
cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
do {
bool is_first_wave = linear_tile_idx == blockIdx.x;
uint32_t skip_wait = is_first_wave;
auto tAgA_mk = tAgA(_,tile_idx_m,_);
int k_tile = 0;
auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
CUTE_NO_UNROLL
while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) {
int k_tile_idx_n = tile_idx_n + k_tile;
++k_tile;
skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount);
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token);
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
++mainloop_pipe_producer_state;
barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait);
if (cute::elect_one_sync()) {
copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage));
}
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
} else if (is_mma_warp) {
mma.accumulate_ = UMMA::ScaleOut::Zero;
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
tmem_allocation_result_barrier.arrive();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_mma.data() = tmem_base_ptr;
do {
uint32_t skip_wait = K_TILE_MAX <= 0;
auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
CUTE_NO_UNROLL
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; )
{
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
int read_stage = mainloop_pipe_consumer_state.index();
auto tCrA_mk = tCrA(_,_,_,read_stage);
auto tCrB_nk = tCrB(_,_,0,0);
CUTE_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block)
{
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
CUTE_UNROLL
for (int i = 0; i < 4; i++) {
auto accumulators = bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i);
gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators);
}
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state);
++accumulator_pipe_producer_state;
}
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
++mainloop_pipe_consumer_state;
++k_tile;
skip_wait = k_tile >= K_TILE_MAX;
barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state);
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
tmem_allocator.release_allocation_lock();
accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
} else if (is_epilogue_warp) {
const float global_amax_val = *global_amax;
static constexpr int FragmentSize = 256 / sizeof_bits_v<TC>;
tmem_allocation_result_barrier.arrive_and_wait();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_epilogue.data() = tmem_base_ptr;
int thread_idx = threadIdx.x % 128;
Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N)
auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{}));
auto tiled_r2g = make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TC>{}, tiled_t2r);
auto thr_t2r = tiled_t2r.get_slice(thread_idx);
auto thr_r2g = tiled_r2g.get_slice(thread_idx);
// NVFP4 non-E8 recipe constants and global scales
static constexpr float fp4_max = 6.0f;
const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val);
const float global_decode_scale = 1.0f / global_encode_scale;
auto sfd_converter = cutlass::NumericConverter<TSFC, float>{};
do {
for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) {
Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,tile_idx_n+k_tile);
Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,tile_idx_n+k_tile);
accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index());
Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tTR_rAcc = make_tensor<ElementAccumulator>(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDrC = make_tensor<TC>(shape(tDgC));
Tensor tTR_rAcc_frag = recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
Tensor tDrC_frag = recast<cutlass::Array<TC, FragmentSize>>(coalesce(tDrC));
Tensor src = thr_r2g.retile_S(tDrC);
Tensor dst = thr_r2g.retile_D(tDgC);
Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout(
make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}),
make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{})
));
Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC));
Tensor tDrSFC = make_tensor<TSFC>(shape(tDgSFC));
static constexpr int NumVecs = size(tDgC) / VectorSize;
Tensor tC_rRowSFD_frg = recast<cutlass::Array<TSFC, NumVecs>>(tDrSFC);
cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>, true> amax_reduction;
cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
cutlass::Array<ElementAccumulator, NumVecs> pvscales;
// TMEM_LOAD
copy(tiled_t2r, tDtC, tTR_rAcc);
cutlass::arch::fence_view_async_tmem_load();
accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
++accumulator_pipe_consumer_state;
// Round-trip the accumulators FP32 -> BF16 -> FP32 so results match BF16 precision.
auto convert_accum_to_bf16 = cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator, FragmentSize>{};
auto convert_bf16_to_accum = cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t, FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
auto compute_frgs = reinterpret_cast<cutlass::Array< ElementAccumulator, VectorSize> *>(tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array< TC, VectorSize> *>(tDrC_frag.data());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
}
pvscales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(pvscales, global_encode_scale);
auto pvscales_cvted = cutlass::NumericArrayConverter<TSFC, ElementAccumulator, NumVecs>{}(pvscales);
tC_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFC, NumVecs>{}(tC_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(qpvscale_ups, global_decode_scale);
auto acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(1.0, qpvscale_scaled);
// Initialize RNG for tile
const size_t rng_sequence
= thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = uint4{0, 0, 0, 0};
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
// auto acc_scale = acc_scales[v];
if constexpr (kEnableStochasticRounding) {
random_uint4 = dist.generate4(rng);
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v],
acc_scale
),
reinterpret_cast<cutlass::Array<uint32_t, 4>*>(&random_uint4));
} else {
output_frgs[v] = cutlass::NumericArrayConverter<TC, ElementAccumulator, VectorSize>{}(cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(compute_frgs[v], acc_scale));
}
}
copy(tiled_r2g, src, dst);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC);
}
linear_tile_idx += gridDim.x;
tile_idx_m = linear_tile_idx % tiles_in_m;
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX;
} while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n);
}
}
// this function computes RHT-GEMM for
// A: m x n: col-major
// B: 16 x 16: row-major
// C: m x n: row-major
// SFC: m x (n/16): row-major
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false>
void
rht_gemm_ntt_w_sfc(int m, int n,
TA const* A,
TB const* B,
TC * C,
TSFC * SFC,
float const* global_amax,
const size_t* rng_state,
uint32_t sm_count,
cudaStream_t stream,
int k_tile_size = 2048)
{
using namespace cute;
// Define shapes (dynamic)
auto M = static_cast<int>(m);
auto N = static_cast<int>(n);
// Define strides (mixed)
auto dA = make_stride(Int<1>{}, m); // (dM,dK)
auto dB = make_stride(Int<1>{}, 16); // (dN,dK)
auto dC = make_stride(n, Int<1>{}); // (dM,dN)
auto cga_shape = Shape< _1, _1, _1>{};
auto cga_tile_shape = Shape<_128,_16,_16>{};
auto cluster_tile_mainloop = Shape<_128,_16,_64>{};
// Construct the MMA
auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS<TA, TB, float,
128, 16,
UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1,_1>>{});
// MMA in CGA layout. TODO: generalize synchronization?
// Assert that the TiledMMA uses all CTAs in the CGA.
CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma));
CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma)));
// Determine the A and B shapes
auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape)));
using TiledMma = decltype(mma);
using AtomThrID = typename TiledMma::AtomThrID;
using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{}))));
using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{}))));
using SmemShape_K = decltype(cute::get<2>(cga_tile_shape));
using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>());
auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop)));
using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());
// Define the smem layouts (static)
// Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory
constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes
constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB);
constexpr int kReservedBytes = 256; // Reserve for barriers and other uses
constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
auto sP = Int<kMaxStages>{}; // SMEM pipelines
auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE)
auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE)
auto sC = Layout<_1>{}; // XXX Dummy
// Create GMEM tensors
Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N)
Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16)
// Create the TiledCopy
auto tma_load_a = make_tma_copy_A_sm100(
SM90_TMA_LOAD{},
tensorA,
sA(_,_,_,0),
cluster_tile_mainloop,
mma);
auto tma_load_b = make_tma_copy_B_sm100(
SM90_TMA_LOAD{},
tensorB,
sB(_,_,_,0),
cga_tile_shape,
mma);
// Assert checks on tile sizes -- no predication
NVTE_CHECK(M % size<0>(cga_tile_shape) == 0,
"Inner dimension must be divisible by ", static_cast<size_t>(size<0>(cga_tile_shape)), " but got ", M, ".");
NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0,
"Outer dimension must be divisible by ", 4 * static_cast<size_t>(size<1>(cga_tile_shape)),
" but got ", N, ".");
uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size));
tiles = (tiles < sm_count) ? tiles : sm_count;
dim3 dimBlock(256);
dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape));
dim3 dimGrid(tiles, 1, 1);
int smem_size = sizeof(SharedStorage<TA, TB, decltype(sA), decltype(sB)>);
auto* kernel_ptr = &rht_gemm_device<
decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape),
TA, decltype(dA), decltype(sA), decltype(tma_load_a),
TB, decltype(dB), decltype(sB), decltype(tma_load_b),
TC, decltype(dC), decltype(sC),
TSFC,
decltype(mma),
kEnableStochasticRounding>;
cudaError_t status = cudaFuncSetAttribute(*kernel_ptr,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (status != cudaSuccess) {
std::cerr << "Error: Failed to set shared memory size." << std::endl;
return;
}
(*kernel_ptr)
<<< dimGrid, dimBlock, smem_size, stream >>>
(M, N, k_tile_size, cga_tile_shape,
A, dA, sA, tma_load_a,
B, dB, sB, tma_load_b,
C, dC, sC,
SFC,
mma, global_amax,
rng_state);
}
// This function wraps rht_gemm_ntt_w_sfc
// to transpose the input tensor A.
template <typename TA, typename TB, typename TC, typename TSFC, bool kEnableStochasticRounding = false>
void
rht_gemm_ttt_wrapper(int m, int n,
TA const* A,
TB const* B,
TC * C,
TSFC * SFC,
float const* global_amax,
const size_t* rng_state,
uint32_t sm_count,
cudaStream_t stream,
int k_tile_size = 1024)
{
// In addition to transposing the input tensor A,
// we also need to reshape m, n to best
// utilize as many SMs as possible while keeping
// a relatively large contiguous dimension.
// For example, after swapping m and n for transpose purposes,
// the input / output tensor shapes for RHT-GEMM are:
// A: n x m: col-major
// B: 16 x 16: row-major
// C: n x m: row-major
// SFC: n x (m/16): row-major
rht_gemm_ntt_w_sfc<TA, TB, TC, TSFC, kEnableStochasticRounding>(
n, m,
A, B, C,
SFC, global_amax,
rng_state,
sm_count, stream,
k_tile_size);
}
} // namespace
} // namespace detail
// clang-format on
void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_,
const Tensor &hadamard_matrix_,
QuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise);
// Check input and output tensors
NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Input tensor must be a plain BF16 tensor, but scaling mode is ",
to_string(input_.scaling_mode), ".");
NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16,
"Input tensor must be a BF16 tensor, but dtype is ", to_string(input_.dtype()), ".");
NVTE_CHECK(input_.dim() >= 2, "Input must have at least 2 dimensions.");
const SimpleTensor &input = input_.data;
SimpleTensor &global_amax = output_.amax;
SimpleTensor &output_t = output_.data;
SimpleTensor &scale_inv_t = output_.scale_inv;
// Stochastic rounding config
const bool use_stochastic_rounding = quant_config.stochastic_rounding;
const size_t *rng_state = nullptr;
if (quant_config.rng_state != nullptr) {
Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state);
NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}
// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
using TC = cutlass::float_e2m1_t;
using TSFC = cutlass::float_ue4m3_t;
checkCuDriverContext(stream);
// Check Hadamard matrix
constexpr int kHadamardDimension = 16;
NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Hadamard matrix must be a plain BF16 tensor, but scaling mode is ",
to_string(hadamard_matrix_.scaling_mode), ".");
NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16,
"Hadamard matrix must be a BF16 tensor, but dtype is ",
to_string(hadamard_matrix_.dtype()), ".");
const SimpleTensor &hadamard_matrix = hadamard_matrix_.data;
NVTE_CHECK(
(hadamard_matrix_.shape() == std::vector<size_t>{kHadamardDimension, kHadamardDimension}),
"Hadamard matrix must have shape=",
std::vector<size_t>{kHadamardDimension, kHadamardDimension},
", but got shape=", hadamard_matrix_.shape(), ".");
const size_t hadamard_dimension = hadamard_matrix.shape[0];
const size_t ndim = input.shape.size();
const size_t n = input.shape[ndim - 1];
size_t m = 1;
for (size_t i = 0; i < ndim - 1; ++i) {
m *= input.shape[i];
}
auto sm_count = transformer_engine::cuda::sm_count();
NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension.");
NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension.");
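// Shape-specific k_tile_size heuristics: the values below appear to be tuned
// per (m, n) for common model dimensions, with a conservative default for
// small problem sizes.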
int k_tile_size = 1024;
if (m == 8192 && n == 5120) {
k_tile_size = 512;
} else if (m == 8192 && n == 10240) {
k_tile_size = 1024;
} else if (m == 8192 && n == 2560) {
k_tile_size = 1280;
} else if (m == 8192 && n == 11328) {
k_tile_size = 1024;
} else if (m == 8192 && n == 512) {
k_tile_size = 256;
} else if (m == 8192 && n == 3584) {
k_tile_size = 512;
} else if (m == 11328 && n == 8192) {
k_tile_size = 1024;
} else if (m == 5120 && n == 8192) {
k_tile_size = 512;
} else if (m == 10240 && n == 8192) {
k_tile_size = 1024;
} else if (m == 2560 && n == 8192) {
k_tile_size = 1280;
} else if (m == 512 && n == 8192) {
k_tile_size = 256;
} else if (m == 3584 && n == 8192) {
k_tile_size = 512;
} else if (m < 1024 || n < 1024) {
k_tile_size = 512;
}
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kUseStochasticRounding,
detail::rht_gemm_ttt_wrapper<TA, TB, TC, TSFC, kUseStochasticRounding>(
/*m=*/m,
/*n=*/n,
/*A=*/reinterpret_cast<TA const *>(input.dptr),
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*C=*/reinterpret_cast<TC *>(output_t.dptr),
/*SFC=*/reinterpret_cast<TSFC *>(scale_inv_t.dptr),
/*global_amax=*/reinterpret_cast<float const *>(global_amax.dptr),
/*rng_state=*/rng_state,
/*sm_count=*/sm_count,
/*stream=*/stream,
/*k_tile_size=*/k_tile_size););
}
} // namespace transformer_engine
void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output,
const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_hadamard_transform_cast_fusion_columnwise);
using namespace transformer_engine;
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
hadamard_transform_cast_fusion_columnwise(
*convertNVTETensorCheck(input), *convertNVTETensorCheck(output),
*convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream);
}
......@@ -67,6 +67,11 @@ class CommOverlapCore {
std::vector<cudaStream_t> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event;
private:
void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size,
int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin,
bool use_ce, bool atomic_gemm);
public:
CommOverlapCore() {} // dummy constructor for exposing type to Python
......@@ -78,17 +83,26 @@ class CommOverlapCore {
virtual ~CommOverlapCore();
void *get_ubuf_dptr() { return _ubuf.dptr(); }
void set_ubuf_scale_inv(float *scale_inv) {
_ubuf_scale_inv = scale_inv;
_ubuf_scale_inv_initialized = true;
}
virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
bool rowwise = true) {
NVTE_ERROR("Operation is not implemented.");
}
TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset,
const std::vector<size_t> &shape);
int get_tp_size() { return _tp_size; }
bool is_atomic_gemm() { return _atomic_gemm; }
bool is_p2p_overlap() { return _is_p2p; }
......@@ -150,6 +164,10 @@ class CommOverlapBase : public CommOverlapCore {
cudaStream_t _stream_comm;
cudaEvent_t _start_d2dcopy;
private:
void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
bool rs_overlap_first_gemm);
public:
CommOverlapBase() {} // dummy constructor for exposing type to Python
......@@ -228,6 +246,10 @@ class CommOverlapP2PBase : public CommOverlapCore {
cudaStream_t _stream_recv;
cudaEvent_t _stop_send, _stop_recv;
private:
void initialize(const std::vector<size_t> &buffer_shape, DType buffer_dtype,
CommOverlapType comm_type, bool aggregate);
public:
CommOverlapP2PBase() {} // dummy constructor for exposing type to Python
......@@ -241,6 +263,9 @@ class CommOverlapP2PBase : public CommOverlapCore {
virtual ~CommOverlapP2PBase();
void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk,
bool rowwise = true) override;
TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id);
void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb,
......
......@@ -124,6 +124,24 @@ enum NVTE_Mask_Type {
NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5,
};
/*! \enum NVTE_Softmax_Type
* \brief Attention softmax types as described in
* Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3).
* For a given attention score S = Q*K^T, different softmax types perform different operations on S,
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter in shape [H].
*/
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
NVTE_VANILLA_SOFTMAX = 0,
/*! Off-by-one softmax */
NVTE_OFF_BY_ONE_SOFTMAX = 1,
/*! Learnable softmax */
NVTE_LEARNABLE_SOFTMAX = 2,
};
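/* Illustrative sketch (not part of this header; assumes <math.h> and omits
 * max-subtraction for brevity): reference math for the softmax variants above
 * on one row of attention scores S[0..n). `alpha` stands in for the learnable
 * per-head offset. */
static inline void softmax_row_reference(const float *S, float *P, int n,
                                         enum NVTE_Softmax_Type softmax_type, float alpha) {
  float denom = 0.0f;
  for (int i = 0; i < n; ++i) denom += expf(S[i]);
  if (softmax_type == NVTE_OFF_BY_ONE_SOFTMAX) denom += 1.0f;        /* attention-sink term */
  if (softmax_type == NVTE_LEARNABLE_SOFTMAX) denom += expf(alpha);  /* learnable sink */
  for (int i = 0; i < n; ++i) P[i] = expf(S[i]) / denom;
}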
/*! \enum NVTE_Fused_Attn_Backend
* \brief Fused attention backends
*/
......@@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] qkv_layout The layout of Tensors Q, K, V.
* \param[in] bias_type The attention bias type.
* \param[in] attn_mask_type The attention mask type.
* \param[in] softmax_type The attention softmax type.
* \param[in] dropout The dropout probability.
* \param[in] num_attn_heads The number of heads in Q.
* \param[in] num_gqa_groups The number of heads in K, V.
......@@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right);
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
*
* \param[in] QKV The QKV tensor in packed format, H3D or 3HD.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
const NVTETensor rng_state, size_t max_seqlen, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
......@@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* e.g. M, ZInv, rng_state.
* \param[out] dQKV The gradient of the QKV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1].
* \param[in] max_seqlen Max sequence length used for computing,
......@@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV,
NVTETensor dBias, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, size_t max_seqlen,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool deterministic, NVTETensor workspace, cudaStream_t stream);
......@@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] Q The Q tensor, in HD layouts.
* \param[in] KV The KV tensor, in 2HD or H2D layouts.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset,
NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q,
size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
......@@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[out] dQ The gradient of the Q tensor.
* \param[out] dKV The gradient of the KV tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
......@@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] qkv_layout QKV tensor's layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked(
void nvte_fused_attn_bwd_kvpacked(
const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO,
const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ,
NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] K The K tensor.
* \param[in] V The V tensor.
* \param[in] Bias The Bias tensor.
* \param[in] SoftmaxOffset The SoftmaxOffset tensor.
* \param[in,out] S The S tensor.
* \param[out] O The output O tensor.
* \param[out] Aux_CTX_Tensors Auxiliary output tensors when training,
......@@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream);
NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......@@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[out] dK The gradient of the K tensor.
* \param[out] dV The gradient of the V tensor.
* \param[out] dBias The gradient of the Bias tensor.
* \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor.
* \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1].
* \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1].
* \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1].
......@@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
* \param[in] qkv_layout QKV tensors' layout.
* \param[in] bias_type Bias type.
* \param[in] attn_mask_type Attention mask type.
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
......@@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP,
const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK,
NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q,
const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded,
NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q,
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, NVTETensor workspace,
cudaStream_t stream);
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
......
......@@ -15,9 +15,76 @@
#ifdef __cplusplus
extern "C" {
#endif
#endif // __cplusplus
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
/*! \brief Configuration for matrix multiplication. */
typedef void *NVTEMatmulConfig;
/*! \enum NVTEMatmulConfigAttribute
* \brief Type of option for matrix multiplication.
*/
enum NVTEMatmulConfigAttribute {
/*! Bias tensor
*
* If provided, the bias tensor is applied in the GEMM epilogue.
*/
kNVTEMatmulConfigBiasTensor = 0,
/*! Bias gradient tensor
*
* If provided, the bias gradient tensor will be filled in the GEMM epilogue.
*/
kNVTEMatmulConfigDBiasTensor = 1,
/*! Whether to compute GELU in GEMM epilogue. */
kNVTEMatmulConfigWithGELUEpilogue = 2,
/*! Whether to compute GELU backward in GEMM epilogue. */
kNVTEMatmulConfigWithDGELUEpilogue = 3,
/*! Auxiliary tensor for GEMM epilogue.
*
* For GELU, this will be filled with the GELU input. For GELU
* backward, this is expected to already be filled with the GELU
* input.
*/
kNVTEMatmulConfigEpilogueAuxTensor = 4,
/*! Whether to use split accumulator for FP8 GEMM. */
kNVTEMatmulConfigUseSplitAccumulator = 5,
/*! Number of streaming multiprocessors to use in GEMM kernel. */
kNVTEMatmulConfigSMCount = 6,
kNVTEMatmulConfigNumAttributes
};
/*! \brief Create a matrix multiplication configuration. */
NVTEMatmulConfig nvte_create_matmul_config();
/*! \brief Query an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
void *buf, size_t size_in_bytes, size_t *size_written);
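/* Illustrative query sketch: the attribute values here are fixed-size, so a
 * single call with a correctly sized buffer suffices; `size_written` reports
 * how many bytes were produced. `example_query_sm_count` is a hypothetical
 * helper, not part of this API. */
static inline int example_query_sm_count(NVTEMatmulConfig config) {
  int sm_count = 0;
  size_t size_written = 0;
  nvte_get_matmul_config_attribute(config, kNVTEMatmulConfigSMCount, &sm_count,
                                   sizeof(int), &size_written);
  return sm_count;
}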
/*! \brief Set an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[in] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes);
/*! \brief Destroy a matrix multiplication configuration. */
void nvte_destroy_matmul_config(NVTEMatmulConfig config);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
*
* Computes:
* - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors
......@@ -44,8 +111,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons
NVTETensor workspace, bool accumulate, bool use_split_accumulator,
int math_sm_count, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations.
*
* Computes:
* - `D = alpha * op(A) * op(B) + beta * C`
*
* \param[in] transa Whether to transpose A matrix.
* \param[in] transb Whether to transpose B matrix.
* \param[in] alpha Scaling factor applied to matmul output.
* \param[in] A A matrix.
* \param[in] B B matrix.
* \param[in] beta Scaling factor applied to C matrix.
* \param[in] C C matrix.
* \param[out] D Output matrix.
* \param[in] workspace Workspace tensor.
* \param[in] config Additional configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A,
const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D,
NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream, bool nvte_use_hipblaslt = 0, bool nvte_use_rocblas = 0, int compute_stream_offset = 0);
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations,
* allowing for using a scaling factor for the GEMM result and the accumulation input
* allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated)
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
*
* Computes:
* - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors
......@@ -133,9 +223,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream to wait on.
*/
void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor* workspace,
void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms,
bool transa, bool transb, bool grad, NVTETensor *workspace,
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
......@@ -165,7 +255,9 @@ void nvte_cublas_batchgemm_v3(const NVTETensor A, const NVTETensor B, const NVTE
#ifdef __cplusplus
} // extern "C"
#endif
#endif // __cplusplus
#ifdef __cplusplus
/*! \namespace transformer_engine
*/
......@@ -183,6 +275,89 @@ constexpr int num_batchgemm_streams = 1;
void nvte_cublas_handle_init();
/*! \struct MatmulConfigWrapper
* \brief C++ wrapper for NVTEMatmulConfig.
*/
class MatmulConfigWrapper {
public:
MatmulConfigWrapper() : config_{nvte_create_matmul_config()} {}
MatmulConfigWrapper(const MatmulConfigWrapper &) = delete;
MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete;
MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_matmul_config(config_);
}
config_ = other.config_;
other.config_ = nullptr;
return *this;
}
~MatmulConfigWrapper() {
if (config_ != nullptr) {
nvte_destroy_matmul_config(config_);
config_ = nullptr;
}
}
/*! \brief Get the underlying NVTEMatmulConfig.
*
* \return NVTEMatmulConfig held by this MatmulConfigWrapper.
*/
operator NVTEMatmulConfig() const noexcept { return config_; }
/*! \brief Set bias tensor. */
void set_bias_tensor(NVTETensor bias_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigBiasTensor, &bias_tensor,
sizeof(NVTETensor));
}
/*! \brief Set bias gradient tensor. */
void set_dbias_tensor(NVTETensor dbias_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigDBiasTensor, &dbias_tensor,
sizeof(NVTETensor));
}
/*! \brief Set whether to compute GELU in GEMM epilogue. */
void set_with_gelu_epilogue(bool with_gelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue,
&with_gelu_epilogue, sizeof(bool));
}
/*! \brief Set whether to compute GELU backward in GEMM epilogue. */
void set_with_dgelu_epilogue(bool with_dgelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue,
&with_dgelu_epilogue, sizeof(bool));
}
/*! \brief Set auxiliary tensor for GEMM epilogue. */
void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigEpilogueAuxTensor,
&epilogue_aux_tensor, sizeof(NVTETensor));
}
/*! \brief Set whether to use split accumulator for FP8 GEMM. */
void set_use_split_accumulator(bool use_split_accumulator) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator,
&use_split_accumulator, sizeof(bool));
}
/*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */
void set_sm_count(int sm_count) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int));
}
private:
/*! \brief Wrapped NVTEMatmulConfig. */
NVTEMatmulConfig config_ = nullptr;
};
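/* Minimal usage sketch (hypothetical helper; tensor handles `a`, `b`, `c`,
 * `d`, `bias`, and `workspace` are assumed to be created elsewhere):
 * configure the epilogue through the wrapper, then dispatch the v2 GEMM. */
inline void example_gemm_v2(NVTETensor a, NVTETensor b, NVTETensor c, NVTETensor d,
                            NVTETensor bias, NVTETensor workspace, cudaStream_t stream) {
  MatmulConfigWrapper config;
  config.set_bias_tensor(bias);            // fused bias in the epilogue
  config.set_use_split_accumulator(true);  // split accumulator for FP8 GEMMs
  const float alpha = 1.0f, beta = 0.0f;   // D = alpha * op(A) * op(B) + beta * C
  nvte_cublas_gemm_v2(/*transa=*/0, /*transb=*/0, &alpha, a, b, &beta, c, d,
                      workspace, config, stream);
}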
} // namespace transformer_engine
#endif // __cplusplus
#endif // TRANSFORMER_ENGINE_GEMM_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file hadamard_transform.h
* \brief Functions for Hadamard transforms.
*/
#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
#include "transformer_engine.h"
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Perform a randomized Hadamard transform on the input tensor.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] random_sign_mask 16-bit sign mask.
* \param[in] random_sign_mask_t 16-bit sign mask.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream);
/*! \brief Perform the absolute maximum reduction on the input tensor with/without
* randomized hadamard transform. The rowwise result is the absolute maximum
* of the input tensor. The columnwise result is the absolute maximum of the
* input tensor transposed and applied randomized hadamard transformation.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] random_sign_mask 16-bit sign mask.
* \param[in] random_sign_mask_t 16-bit sign mask.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream);
/*! \brief Perform the columnwise hadamard transform cast fusion.
*
* This function is experimental and the API is not stable.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] output Output tensor.
* \param[in] hadamard_matrix Hadamard matrix.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output,
const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
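/* Illustrative call (hypothetical tensors created elsewhere; the
 * implementation checks that `input` is BF16 and `hadamard_matrix` is a
 * 16x16 BF16 tensor):
 *
 *   nvte_hadamard_transform_cast_fusion_columnwise(bf16_input, nvfp4_output,
 *                                                  hadamard_16x16,
 *                                                  quant_config, stream);
 */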
#ifdef __cplusplus
} // extern "C"
#endif
#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_
......@@ -124,6 +124,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream);
void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -73,6 +73,7 @@ enum NVTETensorParam {
kNVTEAmax = 3, /*!< Amax tensor */
kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
kNVTENumTensorParams
};
......@@ -95,10 +96,9 @@ enum NVTEScalingMode {
*/
NVTE_BLOCK_SCALING_1D = 2,
NVTE_BLOCK_SCALING_2D = 3,
/*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
*/
NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
/*! Single scale per block of 16 elements consecutive in either
* rowwise or columnwise direction */
NVTE_NVFP4_1D_SCALING = 4,
NVTE_INVALID_SCALING = 100
};
......@@ -337,6 +337,12 @@ enum NVTEQuantizationConfigAttribute {
* likely be refactored away in the future.
*/
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
/*! RNG state (NVTETensor with 2 elements - seed and offset) */
kNVTEQuantizationConfigRNGState = 4,
/*! Whether to use 2D block scaling for NVFP4 */
kNVTEQuantizationConfigNVFP42DQuantization = 5,
/*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding = 6,
kNVTEQuantizationConfigNumAttributes
};
......@@ -458,6 +464,15 @@ inline bool is_fp4_dtype(const DType t) {
#endif
}
/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
*
* Return true if TE datatype is high precision
* \param[in] t TE datatype of interest
*/
inline bool is_high_precision_dtype(const DType t) {
return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
}
/*! \struct TensorWrapper
* \brief C++ wrapper for the NVTETensor class.
*/
......@@ -593,6 +608,11 @@ class TensorWrapper {
return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape);
}
template <typename ShapeType>
TensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept {
return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape);
}
// Parameter getters
NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept {
......@@ -617,6 +637,10 @@ class TensorWrapper {
return get_parameter(kNVTEColumnwiseScaleInv);
}
NVTEBasicTensor get_columnwise_amax() const noexcept {
return get_parameter(kNVTEColumnwiseAmax);
}
/*! \brief Get an underlying NVTETensor.
*
* \return NVTETensor held by this TensorWrapper.
......@@ -865,6 +889,24 @@ class QuantizationConfigWrapper {
&format, sizeof(Float8BlockScaleTensorFormat));
}
/*! \brief Set stochastic rounding state */
void set_rng_state(NVTETensor rng_state) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigRNGState, &rng_state,
sizeof(NVTETensor));
}
/*! \brief Set whether to use 2D block scaling for NVFP4 */
void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization,
&nvfp4_2d_quantization, sizeof(bool));
}
/*! \brief Set whether to use stochastic rounding */
void set_stochastic_rounding(bool stochastic_rounding) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding,
&stochastic_rounding, sizeof(bool));
}
private:
/*! \brief Wrapped NVTEQuantizationConfig. */
NVTEQuantizationConfig config_ = nullptr;
......
......@@ -28,7 +28,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp_scaling(z->scaling_mode)) {
!is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
......@@ -65,11 +65,11 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
bool is_aligned = true;
#ifdef USE_ROCM
NVTE_CHECK(
!is_mxfp_scaling(z->scaling_mode),
!is_mxfp8_scaling(z->scaling_mode),
"Cudnn backend is need by block scaling mode for normalization! Not surpported in rocm yet.");
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
#else
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
#endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
......
......@@ -24,7 +24,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp_scaling(z->scaling_mode)) {
!is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
......@@ -51,11 +51,11 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
bool is_aligned = true;
#ifdef USE_ROCM
NVTE_CHECK(
!is_mxfp_scaling(z->scaling_mode),
!is_mxfp8_scaling(z->scaling_mode),
"Cudnn backend is need by mxfp scaling mode for normalization! Not surpported in rocm yet.");
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
#else
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode);
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
#endif
if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) {
......
......@@ -4,7 +4,6 @@
"""This module provides predefined FP8 recipes."""
from __future__ import annotations
import warnings
import os
from enum import Enum
from typing import Literal, Optional, Union, Callable, NamedTuple
......@@ -23,9 +22,12 @@ class _FormatHelper(NamedTuple):
class Format(Enum):
"""
Supported FP8 formats.
Supported FP4 formats.
Values
------
E2M1 :
All FP4 tensors are in e2m1 format
E4M3 :
All FP8 tensors are in e4m3 format
E5M2 :
......@@ -35,6 +37,7 @@ class Format(Enum):
FP8 tensors in the backward pass are in e5m2 format
"""
E2M1 = _FormatHelper(max_fwd=6, max_bwd=6)
E4M3 = _FormatHelper(max_fwd=448, max_bwd=448)
E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344)
HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd)
......@@ -42,9 +45,13 @@ class Format(Enum):
@dataclass(frozen=True)
class MMParams:
"""for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator)
apply split accumulator or not, turning it on will increase accuracy but impact gemm performance,
so only turn it on for certain gemms
"""Matrix multiplication options.
Parameters
----------
use_split_accumulator : bool, default = `True`
Use FP8 fast accumulation on Hopper or Ada. For more details,
see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul.
"""
use_split_accumulator: bool = True
......@@ -55,10 +62,24 @@ class QParams:
"""Quantization parameters.
power_2_scale: use power of 2 scale parameter
amax_epsilon: optional minimum value of abs max
random_hadamard_transform: whether to use random Hadamard transform
stochastic_rounding: whether to use stochastic rounding
fp4_2d_quantization: whether to use 2D block scaling for FP4
"""
power_2_scale: bool = False
amax_epsilon: float = 0.0
random_hadamard_transform: bool = False
stochastic_rounding: bool = False
fp4_2d_quantization: bool = False
def __repr__(self) -> str:
return (
f"Qparams(\npower_2_scale={self.power_2_scale},\n"
f"amax_epsilon={self.amax_epsilon},\n"
f"random_hadamard_transform={self.random_hadamard_transform},\n"
f"stochastic_rounding={self.stochastic_rounding},\n"
f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
)
class Recipe:
......@@ -66,6 +87,10 @@ class Recipe:
Base recipe class.
"""
def nvfp4(self):
"""Whether the given recipe is NVFP4 1D block scaling."""
return isinstance(self, NVFP4BlockScaling)
def mxfp8(self):
"""Whether the given recipe is MXFP8 block scaling."""
return isinstance(self, MXFP8BlockScaling)
......@@ -184,6 +209,7 @@ class DelayedScaling(Recipe):
f"margin={self.margin}, "
f"format={str(self.fp8_format).split('.')[1]}, "
f"amax_history_len={self.amax_history_len}, "
f"reduce_amax={self.reduce_amax}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
......@@ -201,10 +227,11 @@ class Float8CurrentScaling(Recipe):
pass.
"""
use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1"
fp8_format: Format = Format.HYBRID
fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0)
fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
fp8_quant_fwd_weight = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
fp8_quant_bwd_grad = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0)
fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False)
fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True)
fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True)
......@@ -213,9 +240,6 @@ class Float8CurrentScaling(Recipe):
def __post_init__(self) -> None:
assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported."
assert (
not self.fp8_dpa and not self.fp8_mha
), "FP8 attention is not supported for Float8CurrentScaling."
def __repr__(self) -> str:
return (
......@@ -351,3 +375,84 @@ class Float8BlockScaling(Recipe):
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}"
)
@dataclass()
class NVFP4BlockScaling(Recipe):
"""
Use the NVFP4 scaling strategy.
This is a 2-level block scaling strategy. In level 1, each group of
16 consecutive values is scaled together using their own scaling
factor. The type of the scaling factor is E4M3 (4 bits of exponent,
3 bits of mantissa). In level 2, a global per tensor FP32 scaling
factor is used to scale the entire tensor.
Since the scaling happens in a particular direction (either rowwise
or columnwise), the quantized tensor and its transpose are not
numerically equivalent in this recipe. Therefore, when Transformer
Engine needs both the tensor and its transpose (e.g. to compute both
the forward and backward pass), both versions are computed during
quantization from the high precision input to avoid double
quantization errors.
Parameters
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
fp8_format : {Format.E4M3}, default = Format.E4M3
FP8 data type. Only E4M3 is supported.
fp8_dpa: bool, default = `False`
FP8 dot product attention. Not yet supported.
fp8_mha: bool, default = `False`
FP8 multi-head attention. Not yet supported.
"""
# Configuration envvars
disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1"
disable_stochastic_rounding: bool = (
os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1"
)
disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1"
fp4_format: Format = Format.E2M1
fp8_format: Format = Format.E4M3
# Not applying quantization to attention for now
fp8_dpa: bool = False
fp8_mha: bool = False
def __post_init__(self) -> None:
assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling"
assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling"
# Quantization params
# Note: RHT is currently only applied to column-wise usage so that
# it can be used for wgrad GEMM.
self.fp4_quant_fwd_inp = QParams(
random_hadamard_transform=not self.disable_rht,
stochastic_rounding=False,
fp4_2d_quantization=False,
)
self.fp4_quant_fwd_weight = QParams(
random_hadamard_transform=False,
stochastic_rounding=False,
fp4_2d_quantization=not self.disable_2d_quantization,
)
self.fp4_quant_bwd_grad = QParams(
random_hadamard_transform=not self.disable_rht,
stochastic_rounding=not self.disable_stochastic_rounding,
fp4_2d_quantization=False,
)
def __repr__(self) -> str:
return (
f"recipe_type={self.__class__.__name__}, "
f"fp4_format={str(self.fp4_format).split('.')[1]}, "
f"fp8_format={str(self.fp8_format).split('.')[1]}, "
f"fp8_dpa={self.fp8_dpa}, "
f"fp8_mha={self.fp8_mha}, "
f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, "
f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, "
f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, "
)
......@@ -27,6 +27,13 @@ namespace {
constexpr int amax_kernel_threads = 512;
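// Zero out the amax value on-device unless the no-op flag is set. A
// one-thread kernel is used instead of cudaMemsetAsync so the `noop_ptr`
// flag can be honored on the device.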
__launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
*amax_ptr = 0;
}
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
void amax_kernel(const InputType *input, float *amax, const size_t N,
......@@ -131,7 +138,8 @@ template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr,
cudaStream_t stream) {
// Zero out amax so we can update with atomic max
NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));
zero_amax_kernel<<<1, 1, 0, stream>>>(amax, noop_ptr);
NVTE_CHECK_CUDA(cudaGetLastError());
// Return immediately if tensor is empty
if (N == 0) {
......@@ -216,15 +224,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
// Check output tensor
NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
auto &output = *convertNVTETensorCheck(output_);
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING ||
output.scaling_mode == NVTE_NVFP4_1D_SCALING,
"Output tensor for amax computation must be FP8 tensor with per-tensor scaling or "
"NVFP4 1D scaling, "
"but got scaling_mode=",
to_string(output.scaling_mode));
NVTE_CHECK(output.amax.numel() == 1,
"Output tensor for amax computation has invalid amax tensor "
"(expected 1 entry, got shape=",
output.amax.shape, ")");
NVTE_CHECK(output.amax.dptr != nullptr,
NVTE_CHECK(output.amax.dptr != nullptr || output.columnwise_amax.dptr != nullptr,
"Output tensor for amax computation has amax tensor without data");
NVTE_CHECK(output.amax.dtype == DType::kFloat32,
"Output tensor for amax computation has invalid amax tensor "
......@@ -243,11 +253,12 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt
}
// Compute amax
float *amax_ptr = reinterpret_cast<float *>(
(output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
noop_ptr, stream);); // NOLINT(*)
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel<nvec>(
reinterpret_cast<const IType *>(input.data.dptr), amax_ptr, input.data.numel(), noop_ptr,
stream);); // NOLINT(*)
}
} // anonymous namespace
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include <cassert>
#include "../common.h"
#include "../utils.cuh"
namespace transformer_engine {
namespace nvfp4_recipe {
// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0;
constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0);
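// factor = (6.0 * 448.0)^2, i.e. the E2M1 and E4M3 maxima squared across the
// two operands. Example: with alpha_in = 1 and amax_A = amax_B = 1,
// alpha_out = 1 / 7225344 ~= 1.384e-7.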
// Kernel to compute alpha_out = alpha_in * amax_A * amax_B / factor
__global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A,
const float *amax_B, float *alpha_out) {
// factor_inv is defined in the enclosing namespace
*alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv;
}
} // namespace nvfp4_recipe
} // namespace transformer_engine
void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out,
cudaStream_t stream) {
NVTE_API_CALL(nvte_nvfp4_compute_per_tensor_scale);
using namespace transformer_engine;
auto *tA = convertNVTETensor(inpA);
auto *tB = convertNVTETensor(inpB);
auto *tOut = convertNVTETensor(alpha_out);
void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr;
void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr;
void *alpha_ptr = tOut->data.dptr;
// check for not null pointers
NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null");
NVTE_CHECK(amax_B_ptr != nullptr, "amax_B_ptr is null");
NVTE_CHECK(alpha_ptr != nullptr, "alpha_ptr is null");
nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>(
alpha_in, reinterpret_cast<const float *>(amax_A_ptr),
reinterpret_cast<const float *>(amax_B_ptr), reinterpret_cast<float *>(alpha_ptr));
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -18,7 +18,9 @@
namespace transformer_engine {
namespace {
constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32;
constexpr int MXFP8_BLOCK_SIZE = 32;
constexpr int NVFP4_BLOCK_SIZE = 16;
constexpr __device__ __host__ int TB_DIM = 32;
constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16;
constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4;
......@@ -314,8 +316,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
const int original_K = kernel_args.original_k_list[tensor_id];
constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE;
constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4;
// Get block index in grid. Emulate 2D grid.
const int num_tiles_k = K / SF_TILE_DIM_K;
......@@ -332,9 +332,13 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
} // namespace
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) {
NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + ".");
}
NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING ||
input->scaling_mode == NVTE_BLOCK_SCALING_1D ||
input->scaling_mode == NVTE_BLOCK_SCALING_2D ||
input->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
"Input tensor has invalid dtype (", to_string(input->dtype()), ").");
// Do nothing if tensor is empty
if (input->data.numel() == 0) {
......@@ -345,13 +349,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
CheckInputTensor(*output, "scaling_factor_output");
auto& scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
"Unsupported scaling mode for swizzling.");
bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
  // 1D block scaling, row-wise or column-wise
  if (scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING) {
int m, k;
if (input->has_data()) {
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else {
if (nvfp4) {
m = input->columnwise_scale_inv.shape[0];
k = input->columnwise_scale_inv.shape[1];
} else {
m = input->columnwise_scale_inv.shape[1];
k = input->columnwise_scale_inv.shape[0];
}
}
constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
......@@ -375,16 +391,35 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int num_tiles_m = m / SF_TILE_DIM_M;
int num_tiles_k = k / SF_TILE_DIM_K;
  // For NVFP4, the scale inverse for transposed data needs rowwise swizzle.
const bool rowwise_swizzle = input->has_data() || nvfp4;
const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4;
dim3 block_size(TB_DIM, TB_DIM);
if (rowwise_swizzle) {
int vec_load_size = (num_tiles_k - 1) % 4 + 1;
      /* there is no int3 type, and int4/int2 loads would be misaligned */
if (vec_load_size == 3) vec_load_size = 1;
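    /* Sketch of the vec_load_size logic above: (num_tiles_k - 1) % 4 + 1 equals num_tiles_k % 4,
       with multiples of 4 mapped to 4. A remainder of 2 means num_tiles_k is even, so int2 loads
       stay aligned; remainders 1 and 3 (3 is remapped above) fall back to scalar int loads.
       E.g. num_tiles_k = 8 -> int4, 10 -> int2, 7 -> 3 -> int. */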
int n_tiles_in_tb = TB_DIM * vec_load_size;
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
int original_M, original_K;
void *input_scale_inv_ptr, *output_scale_inv_ptr;
if (!nvfp4 || input->has_data()) {
int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / block_scale_size;
input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr;
} else {
original_M = input->flat_last_dim();
original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
}
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
......@@ -392,21 +427,21 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
case 2:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
case 1:
cudaFuncSetAttribute((const void *)swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>,
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size);
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
#else
case 4:
......@@ -415,7 +450,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
case 2:
NVTE_CHECK_CUDA(
......@@ -423,7 +458,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
case 1:
NVTE_CHECK_CUDA(
......@@ -431,16 +466,15 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_row_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(
input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K);
break;
#endif
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
if (columnwise_swizzle) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1;
    if (vec_load_size == 3) vec_load_size = 1; /* no int3 type; int4/int2 would be misaligned */
int n_tiles_in_tb = TB_DIM * vec_load_size;
......@@ -448,6 +482,9 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_last_dim();
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
    // NVFP4 only needs the rowwise swizzle, so it should never reach the columnwise path.
    NVTE_CHECK(!nvfp4, "NVFP4 only needs rowwise swizzle and should not reach the columnwise path.");
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
case 4:
......@@ -481,8 +518,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int4, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, k,
original_M, original_K);
break;
case 2:
NVTE_CHECK_CUDA(
......@@ -490,8 +527,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int2, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, k,
original_M, original_K);
break;
case 1:
NVTE_CHECK_CUDA(
......@@ -499,20 +536,14 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size));
swizzle_col_scaling_kernel<int, SF_TILE_DIM_M, SF_TILE_DIM_K>
<<<num_blocks, block_size, slm_size, stream>>>(input->columnwise_scale_inv.dptr,
output->columnwise_scale_inv.dptr, m, k,
original_M, original_K);
break;
#endif
default:
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
// 2D block scaling
} else {
NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans.");
}
NVTE_CHECK_CUDA(cudaGetLastError());
......@@ -650,6 +681,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args,
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
// TODO(nvfp4): Add NVFP4 support.
void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
std::vector<Tensor*>& output, cudaStream_t stream) {
auto num_tensors = input.size();
......@@ -776,7 +809,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
 * WIP (Phuong):
 * - Optimize for shared memory bank conflicts
 * - Add swizzle for 2D block scaling.
 */
void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swizzle_scaling_factors);
using namespace transformer_engine;
......
......@@ -11,6 +11,7 @@
#include <cstring>
#include <iostream>
#include <mutex>
#include <utility>
#include "common.h"
#include "common/util/cuda_runtime.h"
......@@ -67,8 +68,8 @@ std::string to_string(const NVTEScalingMode &mode) {
return "NVTE_DELAYED_TENSOR_SCALING";
case NVTE_MXFP8_1D_SCALING:
return "NVTE_MXFP8_1D_SCALING";
case NVTE_NVFP4_1D_SCALING:
return "NVTE_NVFP4_1D_SCALING";
case NVTE_INVALID_SCALING:
return "NVTE_INVALID_SCALING";
}
......@@ -98,12 +99,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
t.columnwise_scale_inv.shape, ")");
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
// Need (4, 128) alignment even for e8 scaling factor
auto block_alignment = std::vector<size_t>{128ul, 4ul};
size_t expected_x, expected_y, alignment;
const size_t block_size_rowwise = 32;
const size_t block_size_colwise = 32;
if (t.has_data()) {
......@@ -114,6 +114,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
expected_y =
DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(block_size_rowwise)), alignment) *
alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ",
......@@ -126,11 +127,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
alignment;
alignment = block_alignment[0];
expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast<size_t>(1)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
t.columnwise_scale_inv.shape, ")");
}
} else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) {
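    // Worked example (assuming a 256x1024 FP4 tensor): rows pad to a multiple of 128 -> 256;
    // 1024 / 16 = 64 scales per row pad to a multiple of 4 -> 64; expected rowwise shape {256, 64}.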
if (t.has_data()) {
const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128);
const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4);
const auto &expected = std::vector<size_t>{expected_y, expected_x};
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ",
t.scale_inv.shape, ")");
}
if (t.has_columnwise_data()) {
const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128);
const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4);
const auto &expected = std::vector<size_t>{expected_y, expected_x};
NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
t.columnwise_scale_inv.shape, ")");
}
}
}
}
......@@ -158,6 +177,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
"(expected Float32 or Byte, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else if (is_fp4_dtype(type)) {
// TODO(ksivaman): Fix this to check for amaxes and other details.
// For now only needed for swizzle.
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
"_scale_inverse has invalid dtype "
"(expected DType::kFloat8E4M3, got ",
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ",
name,
"_columnwise_scale_inverse has invalid dtype "
"(expected DType::kFloat8E4M3, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
......@@ -199,10 +238,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
"(expected Float32 or Float8E8M0, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else if (is_fp4_dtype(type)) {
// FP4 output needs to have the scale_inv
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
"_scale_inverse has invalid dtype "
"(expected Float8E4M3, got ",
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ",
name,
"_columnwise_scale_inverse has invalid dtype "
"(expected Float8E4M3, got ",
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
    // Note: amax is supported for non-FP8 output since it can be fused into the computation
    // and later used for quantization, with no need to compute it separately.
    // Unfused quantization with level-2 NVFP4 scaling produces high-precision tensors with amax.
// NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
    NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 output ", name);
......@@ -507,6 +565,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
case kNVTEColumnwiseScaleInv:
t->columnwise_scale_inv = *param;
break;
case kNVTEColumnwiseAmax:
t->columnwise_amax = *param;
break;
default:
NVTE_ERROR("Unknown tensor parameter!");
}
......@@ -530,6 +591,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
return t.scale_inv;
case kNVTEColumnwiseScaleInv:
return t.columnwise_scale_inv;
case kNVTEColumnwiseAmax:
return t.columnwise_amax;
default:
NVTE_ERROR("Unknown tensor parameter!");
}
......@@ -645,6 +708,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
break;
case kNVTEQuantizationConfigRNGState:
std::memcpy(&config_.rng_state, buf, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size);
break;
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(&config_.stochastic_rounding, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
}
......
......@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#include "../common.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::detail {
......@@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
const bool pow_2_scale, const SimpleTensor &noop_tensor,
cudaStream_t stream);
void quantize_transpose_vector_blockwise_fp4(
const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
const bool return_identity, const bool return_transpose, const bool pow2_scale,
const bool swizzled_scale, const bool use_stochastic_rounding,
const NVTETensor rng_state_tensor, const bool use_2d_quantization,
const SimpleTensor &noop_tensor, cudaStream_t stream);
} // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
namespace transformer_engine {
#if CUDA_VERSION >= 12080
namespace quantize_transpose_nvfp4 {
namespace {
using std::int32_t;
using std::uint32_t;
using std::uint8_t;
using transformer_engine::detail::TypeExtrema;
// Define a cuRANDDx descriptor
// Note: curanddx::PhiloxRounds<4> means 4 rounds of philox4_32; if the operator is not specified, it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to perform some internal
// checks, e.g., whether shared memory (if needed) is sufficient for the described problem; usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
// clang-format off
/*
Step 1: Load input to shared memory
  * shared memory: 128x128 elements with type=InputType (the diagram below doesn't account for padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
  * shared memory: 128x128 elements with type=InputType (the diagram below doesn't account for padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
      * 16 elements are quantized and written to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
  * shared memory: 128x128 elements with type=InputType (the diagram below doesn't account for padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
      * 16 elements are quantized and written to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr int kThreadsPerWarp = 32;
// for fp4, we use uint8_t to store 2 fp4 numbers
constexpr int kNFP4PerContainer = 2;
// Hyperparameters for performance tuning
constexpr int kTileDim = 128;
// constexpr int kScaleDim = 32;
constexpr int kNVecIn = 8; // The number of elements each LDG touches
constexpr int kNVecOut = 16; // The number of elements each STG touches
constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total
// Auto-calculated constants (do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
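// The +1 column above pads each shared-memory row, presumably to stagger bank accesses during
// the transpose step (the standard bank-conflict-avoidance padding; the diagrams above note
// that they do not account for this padding).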
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn; // 16
constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8
// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
// for 2D block scaling, we need to reduce amax in warp
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080};
// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
for (int offset = group_size / 2; offset > 0; offset /= 2) {
val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride));
}
return val;
}
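// Example of groupMax with the masks above: groupMax<4, 8> using
// WARP_REDUCE_AMAX_GROUP_MASKS[0] = 0x01010101 reduces over lanes {0, 8, 16, 24}: the first
// iteration (offset 2 * 8) folds lanes 16/24 into lanes 0/8, the second (offset 1 * 8) folds
// lane 8 into lane 0, leaving the group's max in lane 0.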
template <typename ScaleType>
__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax,
const float global_encode_scale) {
float decode_scale = amax / TypeExtrema<fp4e2m1>::max;
decode_scale = decode_scale * global_encode_scale;
decode_scale = fminf(decode_scale, TypeExtrema<float>::max);
return static_cast<ScaleType>(decode_scale);
}
template <typename ScaleType>
__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale,
const float global_decode_scale) {
return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
TypeExtrema<float>::max);
}
template <typename IType, typename ScaleType>
__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) {
return static_cast<float>(input) * encode_scale;
}
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.f || global_encode_scale == 0.f) {
return 1.f;
}
return global_encode_scale;
}
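// Two-level NVFP4 scaling tying the helpers above together: the per-tensor encode scale is
// S_enc = (fp8_max * fp4_max) / global_amax = (448 * 6) / global_amax; each 16-element block
// stores an E4M3 decode scale S_dec_b = (block_amax / 6) * S_enc, and elements are multiplied
// by the block encode scale 1 / (S_dec_b / S_enc) before rounding to E2M1.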
__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) {
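  // Hands out one 32-bit random word per call from the caller's uint4 buffer, regenerating a
  // fresh batch of four Philox words once all four have been consumed (rnd_idx reaches 4).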
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t* const rbits_arr = reinterpret_cast<uint32_t*>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
template <class ScaleType>
__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx,
uint32_t col_length) {
  // This function takes indices into the scale factor matrix and returns the corresponding offset
  // in the swizzled format. row_idx and col_idx are the original (unswizzled) indices into the
  // scale factor matrix; col_length is its number of columns.
// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
  // Any scale factor matrix is organized in 512B base blocks. Each base block consists of 128 rows
  // and 4 columns and is divided into 4 column blocks, each with 32 rows and 4 columns.
// NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix.
// To think in high level, the swizzled scale factor matrix could be composed as:
// unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8)
// cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64
// rb_cnt = M // 128 # Assuming M is divisible by 128
// tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
// tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
// swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4))
constexpr uint32_t kTotalRowsPerBaseBlock = 128;
constexpr uint32_t kRowsPerBaseBlockCol = 32;
constexpr uint32_t kColsPerBaseBlockCol = 4;
const size_t rb = row_idx / kTotalRowsPerBaseBlock;
const size_t rem = row_idx % kTotalRowsPerBaseBlock;
const size_t d4 = rem / kRowsPerBaseBlockCol;
const size_t d3 = rem % kRowsPerBaseBlockCol;
const size_t cbg = col_idx / kColsPerBaseBlockCol;
const size_t d5 = col_idx % kColsPerBaseBlockCol;
const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol);
// row-major offset in the logical shape
// (rb_cnt , cbg_cnt , 32 , 4 , 4)
  // The magic number 16 below is the product of d4's range (128 / 32 = 4) and
  // kColsPerBaseBlockCol (4), i.e. the innermost (4, 4) pair of the logical shape above.
return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5;
}
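// Worked example (hypothetical values): row_idx = 130, col_idx = 5, col_length = 8 gives
// rb = 1, d4 = 0, d3 = 2, cbg = 1, d5 = 1, cbg_cnt = 2, so the swizzled offset is
// ((1 * 2 + 1) * 32 + 2) * 16 + 0 * 4 + 1 = 1569.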
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
uint16_t out_4x;
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t"
"}"
: "=h"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
// NOTE: rbits unused for rn.
uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing.
asm volatile(
"{\n"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
template <bool kApplyStochasticRounding>
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const uint32_t rbits) {
if constexpr (kApplyStochasticRounding) {
return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits);
} else {
return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits);
}
}
template <bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, bool kAligned,
typename CType, typename IType, typename OType, typename ScaleType, bool kSwizzledScale,
bool kApplyStochasticRounding, bool kIs2DBlockScaling>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
const IType* const input, const float* global_amax, OType* const output_c,
OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t,
const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y,
const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state,
const float* noop_ptr) {
constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecContainer>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
};
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
const size_t block_idx_x = blockIdx.x;
const size_t block_idx_y = blockIdx.y;
const size_t rng_sequence =
threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  // Index into random_uint4; it increments on each use and is reset to 0 (with a refill) once it
  // reaches 4.
  int rnd_idx = 0;
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
  // 2D block scaling is not supported with E8 (MXFP4) scaling or in colwise-only mode.
  // Instead of a static_assert, return early if either invalid mode is detected.
if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
return;
}
if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
return;
}
  // For a 128x128 tile, 2D block scaling produces 8x8 amax values for NVFP4 (4x4 for 2D MXFP4).
  // The sizes are defined as constexpr; when 2D scaling is off, a minimal 1x1 placeholder is used.
constexpr int kFP4BlockScalingSize = 16;
constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1;
constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore; // 4
constexpr int k2DBlockAmaxReduceDim =
kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1;
__shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim];
__shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim];
// Step 1: Load input to shared memory
{
constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory
const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = (c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0); // For not aligned case
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory
// for not aligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads();
const int kNumThreadsReduce = kScaleBlockDim / kNVecOut;
const float global_encode_scale =
kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]);
const float global_decode_scale = 1.0 / global_encode_scale;
// Step 2: Cast and store to output_c
if constexpr (kReturnIdentity) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele =
(c_g < row_length ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
(row_length - c_g) / kNFP4PerContainer)
: 0); // For not aligned case
OType* output_g =
&output_c[(r_g * row_length + c_g) / kNFP4PerContainer]; // Output address in global memory
    // Each group of kNumThreadsReduce consecutive lanes in a warp reduces one scaling block; find
    // the lane id of the first thread in the group, which performs the reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
// Step 2.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax
if constexpr (kIsE8Scaling) {
#pragma unroll
for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
}
// doing shuffle sync for 2D block scaling (not applicable for E8 scaling)
if constexpr (kIs2DBlockScaling) {
// first amax shuffle sync in warp, then reduce in smem
// T0 T8 T16 T24 should do amax reduction together
constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
        // now T0 ~ T7 in each warp hold the reduced amax values
int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
if (tid_in_warp_y == 0) {
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
[warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
}
__syncthreads();
if (data_row_idx % kFP4BlockScalingSize == 0) {
CType amax_2d = 0.0;
for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
amax_2d = fmaxf(amax_2d,
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
}
amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
}
__syncthreads();
// every thread now knows 2D amax
amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x];
}
// Step 2.4: Compute scale
ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
// Step 2.5: Write scale_inv
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g < num_rows);
write_scale_inv &= (c_g < row_length);
}
if (write_scale_inv) {
size_t row_idx = block_idx_y * kTileDim + r_s;
size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) +
(threadIdx.x % kNumThreadsStore) / kNumThreadsReduce;
if constexpr (kSwizzledScale) {
size_t offset = scale_factor_swizzled_offset<ScaleType>(
row_idx, col_idx, DIVUP(row_length, kScaleBlockDim));
tile_scales_inv_c[offset] = scale_inv;
} else {
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) {
// Pack two elements into __nv_bfloat162
float2 f2_a;
float2 f2_b;
f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[0], encode_scale);
f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[1], encode_scale);
f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[0], encode_scale);
f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[1], encode_scale);
const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
// Convert to __nv_fp4x4_e2m1
__nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0];
output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1];
}
// Step 2.7: Store output_c
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory
// for not aligned case)
output_g += stride_g / kNFP4PerContainer;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
// Step 3: Transpose, cast and store to output_t
if constexpr (kReturnTranspose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory
int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory
size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Row in global memory
const size_t c_g = block_idx_y * kTileDim + r_s; // Column in global memory
const size_t stride_g =
static_cast<size_t>(c_stride) * kNVecSMem * num_rows; // Stride in global memory
const size_t num_ele = (c_g < num_rows ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
(num_rows - c_g) / kNFP4PerContainer)
: 0); // For not aligned case
OType* output_g =
&output_t[(r_g * num_rows + c_g) / kNFP4PerContainer]; // Output address in global memory
    // Each group of kNumThreadsReduce consecutive lanes in a warp reduces one scaling block; find
    // the lane id of the first thread in the group, which performs the reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut];
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
if constexpr (kIs2DBlockScaling) {
// TODO(zhongbo): 2D block scaling, directly read from amax_smem
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
constexpr int kNumColsPerWarp =
kThreadsPerWarp / kNumThreadsStore * kNVecSMem; // 8 elements
constexpr int kNumWarpsPerBlock =
kThreadsPerBlock / kThreadsPerWarp; // 8 warps per block
constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock;
int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp;
int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore;
int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x;
amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize];
} else {
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
}
}
// Step 3.3: Reduce amax
if constexpr (kIsE8Scaling) {
#pragma unroll
for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
}
// Step 3.4: Compute scale
ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g + smem_idx < row_length);
write_scale_inv &= (c_g < num_rows);
}
if (write_scale_inv) {
size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx;
size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) +
(threadIdx.x % kNumThreadsStore) / kNumThreadsReduce);
if constexpr (kSwizzledScale) {
size_t offset = scale_factor_swizzled_offset<ScaleType>(
row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim));
tile_scales_inv_t[offset] = scale_inv;
} else {
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) {
// Pack two elements into __nv_bfloat162
float2 f2_a;
float2 f2_b;
f2_a.x =
ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i].data.elt[smem_idx], encode_scale);
f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i + 1].data.elt[smem_idx],
encode_scale);
f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1)].data.elt[smem_idx],
encode_scale);
f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx],
encode_scale);
const uint32_t rbits =
kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
// Convert to __nv_fp4x4_e2m1
__nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0];
output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1];
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer);
} else {
if (r_g + smem_idx < row_length) {
output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0,
num_ele);
}
}
}
// Step 3.8: Update output address, column index of shared memory (and row index of global
// memory for not aligned case)
output_g += stride_g / kNFP4PerContainer;
c_s += c_stride;
if constexpr (!kAligned) {
r_g += c_stride * kNVecSMem;
}
}
}
}
} // namespace
} // namespace quantize_transpose_nvfp4
#endif // CUDA_VERSION >= 12080
namespace detail {
void quantize_transpose_vector_blockwise_fp4(
const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon,
const bool return_identity, const bool return_transpose, const bool pow2_scale,
const bool swizzled_scale, const bool use_stochastic_rounding,
const NVTETensor rng_state_tensor, const bool use_2d_quantization,
const SimpleTensor& noop_tensor, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4);
#if CUDA_VERSION >= 12080
  // pow2_scale corresponds to MXFP4 (which uses E8M0 scaling) and is not supported here yet;
  // raise an error if it is requested.
  NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now");
if (!return_identity && !return_transpose) {
return;
}
if (use_2d_quantization && !return_identity) {
return;
}
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
num_elements *= input.shape.at(i);
}
// Early return if the input tensor is empty
if (num_elements == 0) {
return;
}
size_t scale_stride_x = 0;
size_t scale_stride_y = 0;
if (return_identity) {
scale_stride_x = 1;
scale_stride_y = scale_inv.shape[1];
}
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (return_transpose) {
scale_t_stride_x = 1;
scale_t_stride_y = scale_inv_t.shape[1];
}
using namespace transformer_engine::quantize_transpose_nvfp4;
const size_t num_blocks_x = DIVUP(row_length, static_cast<size_t>(kTileDim));
const size_t num_blocks_y = DIVUP(num_rows, static_cast<size_t>(kTileDim));
// noop tensor for cuda graph
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
const size_t* rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor& rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t*>(rng_state_te_tensor.data.dptr);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
output.dtype, 2, OutputType,
dim3 grid(num_blocks_x, num_blocks_y, 1);
using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16;
constexpr bool kPow2Scale = false;
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity, kReturnIdentity,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
swizzled_scale, kSwizzledScale,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kApplyStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_2d_quantization, kIs2DBlockScaling,
size_t smem_bytes = kSMemSize * sizeof(InputType);
auto kernel = block_scaled_1d_cast_transpose_kernel<
kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned,
float, InputType, OutputType, ScaleType, kSwizzledScale,
kApplyStochasticRounding, kIs2DBlockScaling>;
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_bytes);
NVTE_CHECK(err == cudaSuccess,
"Failed to set dynamic shared memory size.");
} kernel<<<grid, kThreadsPerBlock, smem_bytes,
stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<const float*>(global_amax.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<ScaleType*>(scale_inv.dptr),
reinterpret_cast<ScaleType*>(scale_inv_t.dptr), row_length,
num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x,
scale_t_stride_y, kScaleBlockDim, epsilon, rng_state,
noop_ptr);) // kIs2DBlockScaling
) // kApplyStochasticRounding
) // kSwizzledScale
) // kAligned
) // kReturnTranspose
) // kReturnIdentity
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace detail
} // namespace transformer_engine
......@@ -603,6 +603,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
......@@ -833,6 +834,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
......@@ -956,6 +958,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
......@@ -1274,7 +1277,7 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
cast_gated<ParamOP, ActOP>(gated_input, output, stream);
}
}
} else if (is_mxfp8_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
} else {
......
......@@ -25,6 +25,7 @@
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "nvfp4_transpose.cuh"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
......@@ -110,6 +111,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
......@@ -137,8 +140,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
......@@ -286,7 +290,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float scaled_out = in * block_scale_inverse;
const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
......@@ -410,10 +414,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
if (rowwise_scale_is_within_bounds) {
scales_rowwise[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
......@@ -441,7 +447,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
......@@ -456,19 +462,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
......@@ -489,18 +495,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
const int shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
}
}
#pragma unroll
for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32]
      const int scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
}
}
    const int dbias_stride = cols;
    const int dbias_offset_Y = blockIdx.y;
    const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
    const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
#endif // __HIP_PLATFORM_AMD__
} // namespace mxfp8_kernel
namespace nvfp4_kernel {
using namespace ptx;
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 16;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 16 = 256 / 16
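// Sanity checks mirroring the arithmetic above (a sketch added for illustration;
// not part of the original kernel contract)
static_assert(TOTAL_BANKS_WIDTH == 256, "32 banks * 4 B * 8 bit, counted in 4-bit elements");
static_assert(THREADS_PER_BANK == 16, "each rowwise thread covers SCALE_DIM_X = 16 elements");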
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
constexpr float rcp_6f = 1.0f / 6.0f;
  // Equivalent to: S_dec_b = block_amax / 6.0f, then E4M3-quantize S_dec_b * S_enc
return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}
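// Host-side reference of the same two-level scaling (a sketch for illustration;
// nvfp4_decoding_scale_ref is a hypothetical helper, not part of this file).
// The per-block decode scale is block_amax / 6 (6.0 is the largest E2M1 magnitude),
// folded with the global encode scale S_enc before the E4M3 round-trip.
inline float nvfp4_decoding_scale_ref(float block_amax, float second_stage_scale) {
  const float S_enc = (second_stage_scale == 0.0f) ? 1.0f : 1.0f / second_stage_scale;
  const float S_dec_b = block_amax * (1.0f / 6.0f);
  return static_cast<float>(static_cast<fp8e4m3>(S_dec_b * S_enc));
}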
#define DIRECT_SCALING_FACTORS_STORE 1
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_colwise,
fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0,
const float *noop, float *const amax_ptr,
const float *const nvfp4_second_stage_scale_ptr, const size_t rows,
const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool ROWWISE_SCALING = true;
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE;
  static_assert(BUFF_DIM_Y >= SCALE_DIM_Y,
                "Number of buffer rows must be greater or equal to the size of the columnwise "
                "scaling block");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
  static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE,
                "Number of buffer rows must be greater or equal to the number of rowwise "
                "processing threads in Y dimension");
constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X;
  constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8;  // Each output byte packs two 4-bit elements
constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X;
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE;
// static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of
// // threads to process one row in a single iteration
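  // Worked geometry example (a sketch): with CHUNK_DIM_X = 128 and SCALE_DIM_X = 16,
  // THREADS_X_ROWWISE = 8; with THREADS_PER_CHUNK = 128, THREADS_Y_ROWWISE = 128 / 8 = 16
  // and ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE = 32 / 16 = 2.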
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
  const int thread_offset_X_colwise = tid_X_colwise;  // One column per thread
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols;
  // Helps resolve shared memory bank conflicts
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding;
  // __align__(128) does not guarantee that the pointer is aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
fp8e4m3 *out_rowwise_scales_sh =
reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
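  // Resulting dynamic SHMEM layout (descriptive sketch):
  //   [ in_sh | out_rowwise_data_sh | out_colwise_data_sh | out_rowwise_scales_sh | out_colwise_scales_sh ]
  // Data segments are padded to TMA_SHMEM_ALIGNMENT; segments of disabled scaling
  // directions have zero size.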
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc =
(nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr);
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const int buff_offset_in = buff * BUFF_IN_DIM;
const int buff_offset_out = buff * BUFF_OUT_DIM;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_IN_DIM;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM_Y];
IType in_colwise_IType[SCALE_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
          // Without activations, out-of-bounds elements are zero-filled, so they cannot affect the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
if (colwise_scale_is_within_bounds) {
scales_colwise_e8m0[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
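      // E8M0 holds only a biased power-of-two exponent, so the inverse scale is an exact
      // power of two: exp2f_rcp(e) evaluates 2^-(e - 127), assuming the usual 127 bias.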
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (int it = 0; it < ITERATIONS_ROWWISE; ++it) {
const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const int shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const int shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          // TMA requires 16B data alignment (i.e. cols % 8 == 0 for BF16 elements), so a single
          // column-direction check is sufficient to ensure the entire wave is within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
              // Without activations, out-of-bounds elements are zero-filled, so they cannot affect the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc);
#if DIRECT_SCALING_FACTORS_STORE
// Check boundaries
if (rowwise_scale_is_within_bounds) {
const int scales_offset_Y =
scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
const int scales_offset_X = scales_offset_X_rowwise;
const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X;
scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8;
}
#else
const int shmem_scales_offset_Y =
stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise;
const int shmem_scales_offset_X = tid_X_rowwise;
const int scale_idx =
shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X;
out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8;
#endif
// Compute "correct" per-block encoding scaling factor
const float block_scale_inverse =
__fdiv_rn(S_enc, static_cast<float>(S_dec_b_fp8)); // S_enc_b_fp8
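          // Note: dividing by the rounded E4M3 value (rather than the exact block_amax / 6)
          // keeps the encode scale consistent with the decode scale consumers read back.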
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
          Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in01 = in_IType[w].data.elt[2 * e];
in23 = in_IType[w].data.elt[2 * e + 1];
} else if constexpr (IS_CACHED_ACT_OP) {
in01.x = in_cached[w].data.elt[4 * e];
in01.y = in_cached[w].data.elt[4 * e + 1];
in23.x = in_cached[w].data.elt[4 * e + 2];
in23.y = in_cached[w].data.elt[4 * e + 3];
} else {
const int j = w * PACK_SIZE + 4 * e;
in01.x = in_compute_rowwise[j];
in01.y = in_compute_rowwise[j + 1];
in23.x = in_compute_rowwise[j + 2];
in23.y = in_compute_rowwise[j + 3];
}
fp4e2m1x4 &out_quad = reinterpret_cast<fp4e2m1x4 &>(out.data.elt[e]);
ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
__builtin_assume(block_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM;
const int buff_offset_mxfp8 = buff * BUFF_IN_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset_nvfp4]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset_mxfp8]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
#if !DIRECT_SCALING_FACTORS_STORE
// Vectorized store of scaling factors.
// Each thread stores multiple scaling factors in one store instruction.
if constexpr (ROWWISE_SCALING) {
// Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise;
const int scale_idx_global =
scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise;
const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) &&
(scales_offset_X_rowwise < (cols / SCALE_DIM_X))) {
using ScalesVec_t = Vec<fp8e4m3, NVFP4_SCALING_FACTORS_PER_CHUNK_ROW>;
const ScalesVec_t &scales =
*reinterpret_cast<ScalesVec_t *>(&out_rowwise_scales_sh[scale_idx_shmem]);
scales.store_to(&scales_rowwise_e4m3[scale_idx_global]);
}
}
#endif
float chunk_amax = 0.0f;
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
chunk_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(thread_amax, warp_id);
}
if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, chunk_amax);
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_kernel
constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
const size_t N = product(input.data.shape);
const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
#endif
}
// This kernel supports only two scaling cases:
// 1. r16c0 - Rowwise NVFP4
// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
using namespace nvfp4_kernel;
using namespace ptx;
checkCuDriverContext(stream);
NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
bool use_colwise_scaling = output->has_columnwise_data();
if (use_colwise_scaling) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Columnwise scaling tensor must be allocated");
}
CheckNoopTensor(*noop, "cast_noop");
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
const size_t scale_stride_rowwise = output->scale_inv.shape[1];
const size_t scale_stride_colwise =
use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;
fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
e8m0_t *const scales_colwise_e8m0_ptr =
use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const ScalingType scaling_type =
use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const nvfp4_second_stage_scale_ptr =
reinterpret_cast<const float *>(output->scale.dptr);
// Output data type is only required for the column-wise MXFP8 scaling.
// It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work
const DType output_data_type =
use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4);
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
}
constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);
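            // The literals 16 and 32 above correspond to nvfp4_kernel::SCALE_DIM_X and
            // nvfp4_kernel::SCALE_DIM_Y, respectively.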
constexpr size_t in_mem = buff_size_aligned_in;
const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;
const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;
const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
out_rowwise_scales_mem + out_colwise_scales_mem +
TMA_SHMEM_ALIGNMENT;
const size_t dshmem_size = in_mem + out_mem;
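            // Worked size example (a sketch), assuming IType = bf16, OType = fp8e4m3,
            // both scaling directions, and TMA_SHMEM_ALIGNMENT = 128:
            //   in:           2 * 32 * 128 * 2 B = 16384 B
            //   nvfp4 data:   2 * 32 * 128 / 2 B =  4096 B
            //   mxfp8 data:   2 * 32 * 128 * 1 B =  8192 B
            //   nvfp4 scales: 128 * 128 / 16 B   =  1024 B
            //   mxfp8 scales: 128 * 128 / 32 B   =   512 B
            // plus the 128 B alignment pad: dshmem_size = 30336 B.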
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}); // NOLINT(*)
); // NOLINT(*)
}
namespace detail {
using Empty = transformer_engine::Empty;
auto dbias_tensor = convertNVTETensor(dbias);
auto workspace_tensor = convertNVTETensor(workspace);
  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast<const QuantizationConfig *>(quant_config);
  }
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (output_tensor->has_columnwise_data()) {
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
          cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>(
*input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor,
}
} else if (output_tensor->has_data()) {
fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
            *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
          *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 &&
output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, true>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, false>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for "
"2D quantization");
quantize_transpose_vector_blockwise_fp4(
/*input=*/input_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
          /*noop_tensor=*/noop_tensor->data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                   Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
default:
#include <transformer_engine/cast.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../common.h"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
); // NOLINT(*)
); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#endif  // __HIP_PLATFORM_AMD__
}
#if CUDA_VERSION >= 12080
template <typename OType>
__global__ void __launch_bounds__(512)
dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
const float *const tensor_amax, const size_t N, const size_t M,
const size_t scale_stride) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
union fp4vec {
uint64_t vec;
fp4e2m1x4 small_vec[4];
};
using OVec = Vec<OType, 4>;
const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
OVec *output_vec = reinterpret_cast<OVec *>(output);
const size_t my_index = x + y * M;
const size_t my_scale_index = x + y * scale_stride;
const size_t my_output_index = (x + y * M) * 4;
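  // Each uint64_t load carries 16 packed FP4 values (exactly one 16-element scaling
  // block), which are written back as four 4-element output vectors.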
fp4vec value;
value.vec = input_vectorized[my_index];
fp8e4m3 scale = scales[my_scale_index];
float amax = *tensor_amax;
  constexpr float factor_inv = 1.0f / (6.0f * 448.0f);
float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
for (int i = 0; i < 4; i++) {
float4 current = static_cast<float4>(value.small_vec[i]);
OVec out;
out.data.elt[0] = static_cast<OType>(current.x * final_scale);
out.data.elt[1] = static_cast<OType>(current.y * final_scale);
out.data.elt[2] = static_cast<OType>(current.z * final_scale);
out.data.elt[3] = static_cast<OType>(current.w * final_scale);
output_vec[my_output_index + i] = out;
}
}
#endif // CUDA_VERSION
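// Scalar reference for one element of the kernel above (a sketch for illustration;
// decode_fp4_ref is a hypothetical helper, not part of this file). 6.0 is the largest
// |E2M1| value and 448.0 the largest |E4M3| value, matching factor_inv in the kernel.
inline float decode_fp4_ref(float fp4_value, float block_scale_e4m3, float tensor_amax) {
  return fp4_value * block_scale_e4m3 * tensor_amax * (1.0f / (6.0f * 448.0f));
}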
void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output");
NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
const size_t Mread = M / FP4_BLOCK_SIZE;
const size_t total = N * Mread;
const size_t threads = 512;
const size_t blocks = DIVUP(total, threads);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
reinterpret_cast<float *>(input.amax.dptr), N, Mread,
input.scale_inv.shape.back());); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif // CUDA_VERSION >= 12080
}
} // namespace dequantization
namespace detail {
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
  switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
dequantization::fp8_dequantize(input, output, stream);
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (is_supported_by_CC_100()) {
dequantization::mxfp8_dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
dequantization::fp4_dequantize(input, output, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}