Commit 063ef88d authored by wenjh's avatar wenjh
Browse files

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
......@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#include "../common.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::detail {
......@@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
const bool pow_2_scale, const SimpleTensor &noop_tensor,
cudaStream_t stream);
void quantize_transpose_vector_blockwise_fp4(
const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
const bool return_identity, const bool return_transpose, const bool pow2_scale,
const bool swizzled_scale, const bool use_stochastic_rounding,
const NVTETensor rng_state_tensor, const bool use_2d_quantization,
const SimpleTensor &noop_tensor, cudaStream_t stream);
} // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
......@@ -18,6 +18,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
......@@ -901,6 +902,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
NVTE_API_CALL(quantize_transpose_square_blockwise);
checkCuDriverContext(stream);
if (transformer_engine::cuda::sm_arch() >= 100) {
NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
"with MXFP8, which requires using power of two scaling factors.");
}
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_rows = 1;
......
......@@ -24,6 +24,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/cuda_runtime.h"
#include "common/utils.cuh"
namespace transformer_engine {
......@@ -1480,6 +1481,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
if (transformer_engine::cuda::sm_arch() >= 100) {
NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
"with MXFP8, which requires using power of two scaling factors.");
}
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
namespace transformer_engine {
#if CUDA_VERSION >= 12080
namespace quantize_transpose_nvfp4 {
namespace {
using std::int32_t;
using std::uint32_t;
using std::uint8_t;
using transformer_engine::detail::TypeExtrema;
// Define a cuRANDDx descriptor
// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g.,
// if shared memory, if needed, is enough for the described problem, usually not applicable.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
// clang-format off
/*
Step 1: Load input to shared memory
* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do reduction and calculate the amax of each row
* 16 elements are quantized and write to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do reduction and calculate the amax of each column
* 16 elements are quantized and write to output_c at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr int kThreadsPerWarp = 32;
// for fp4, we use uint8_t to store 2 fp4 numbers
constexpr int kNFP4PerContainer = 2;
// Hyperparameters for performance tuning
constexpr int kTileDim = 128;
// constexpr int kScaleDim = 32;
constexpr int kNVecIn = 8; // The number of elements each LDG touches
constexpr int kNVecOut = 16; // The number of elements each STG touches
constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total
// Auto-calculated constants, do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn; // 16
constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8
// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
// for 2D block scaling, we need to reduce amax in warp
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080};
// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
for (int offset = group_size / 2; offset > 0; offset /= 2) {
val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride));
}
return val;
}
template <typename ScaleType>
__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax,
const float global_encode_scale) {
float decode_scale = amax / TypeExtrema<fp4e2m1>::max;
decode_scale = decode_scale * global_encode_scale;
decode_scale = fminf(decode_scale, TypeExtrema<float>::max);
return static_cast<ScaleType>(decode_scale);
}
template <typename ScaleType>
__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale,
const float global_decode_scale) {
return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
TypeExtrema<float>::max);
}
template <typename IType, typename ScaleType>
__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) {
return static_cast<float>(input) * encode_scale;
}
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.f || global_encode_scale == 0.f) {
return 1.f;
}
return global_encode_scale;
}
__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t* const rbits_arr = reinterpret_cast<uint32_t*>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
template <class ScaleType>
__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx,
uint32_t col_length) {
// This function takes in indices from the scale factor matrix and returns an offset in the
// swizzled format. row_idx, col_idx are original indices from the scale factor matrix (unswizzled
// index). col_length is the column length of the scale factor matrix. tile_scales_inv is the
// pointer to the scale factor matrix.
// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
// For any scale factor matrix, it's 512B base block. Each base block consists of 128 rows and 4
// columns. Base block is divided into 4 column blocks, each column block has 32 rows and 4
// columns.
// NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix.
// To think in high level, the swizzled scale factor matrix could be composed as:
// unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8)
// cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64
// rb_cnt = M // 128 # Assuming M is divisible by 128
// tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
// tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
// swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4))
constexpr uint32_t kTotalRowsPerBaseBlock = 128;
constexpr uint32_t kRowsPerBaseBlockCol = 32;
constexpr uint32_t kColsPerBaseBlockCol = 4;
const size_t rb = row_idx / kTotalRowsPerBaseBlock;
const size_t rem = row_idx % kTotalRowsPerBaseBlock;
const size_t d4 = rem / kRowsPerBaseBlockCol;
const size_t d3 = rem % kRowsPerBaseBlockCol;
const size_t cbg = col_idx / kColsPerBaseBlockCol;
const size_t d5 = col_idx % kColsPerBaseBlockCol;
const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol);
// row-major offset in the logical shape
// (rb_cnt , cbg_cnt , 32 , 4 , 4)
// Magic number 16 below comes from the fact we have kColsPerBaseBlockCol = 4, and d4 ([0-128] /
// 32 = [0-4])
return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5;
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
uint16_t out_4x;
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t"
"}"
: "=h"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
// NOTE: rbits unused for rn.
uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing.
asm volatile(
"{\n"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
template <bool kApplyStochasticRounding>
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const uint32_t rbits) {
if constexpr (kApplyStochasticRounding) {
return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits);
} else {
return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits);
}
}
template <bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, bool kAligned,
typename CType, typename IType, typename OType, typename ScaleType, bool kSwizzledScale,
bool kApplyStochasticRounding, bool kIs2DBlockScaling>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
const IType* const input, const float* global_amax, OType* const output_c,
OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t,
const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y,
const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state,
const float* noop_ptr) {
constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecContainer>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
};
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
const size_t block_idx_x = blockIdx.x;
const size_t block_idx_y = blockIdx.y;
const size_t rng_sequence =
threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
// 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode.
// Instead of static_assert, return early if these invalid modes are detected.
if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
return;
}
if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
return;
}
// for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4
// use constexpr to define the size, when not using 2D, use minimal size 1x1
constexpr int kFP4BlockScalingSize = 16;
constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1;
constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore; // 4
constexpr int k2DBlockAmaxReduceDim =
kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1;
__shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim];
__shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim];
// Step 1: Load input to shared memory
{
constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory
const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = (c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0); // For not aligned case
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory
// for not aligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads();
const int kNumThreadsReduce = kScaleBlockDim / kNVecOut;
const float global_encode_scale =
kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]);
const float global_decode_scale = 1.0 / global_encode_scale;
// Step 2: Cast and store to output_c
if constexpr (kReturnIdentity) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele =
(c_g < row_length ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
(row_length - c_g) / kNFP4PerContainer)
: 0); // For not aligned case
OType* output_g =
&output_c[(r_g * row_length + c_g) / kNFP4PerContainer]; // Output address in global memory
// Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
// the first thread to do the reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
// Step 2.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax
if constexpr (kIsE8Scaling) {
#pragma unroll
for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
}
// doing shuffle sync for 2D block scaling (not applicable for E8 scaling)
if constexpr (kIs2DBlockScaling) {
// first amax shuffle sync in warp, then reduce in smem
// T0 T8 T16 T24 should do amax reduction together
constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
// now T0 ~ T8 in each warp has the reduced amax values
int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
if (tid_in_warp_y == 0) {
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
[warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
}
__syncthreads();
if (data_row_idx % kFP4BlockScalingSize == 0) {
CType amax_2d = 0.0;
for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
amax_2d = fmaxf(amax_2d,
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
}
amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
}
__syncthreads();
// every thread now knows 2D amax
amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x];
}
// Step 2.4: Compute scale
ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
// Step 2.5: Write scale_inv
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g < num_rows);
write_scale_inv &= (c_g < row_length);
}
if (write_scale_inv) {
size_t row_idx = block_idx_y * kTileDim + r_s;
size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) +
(threadIdx.x % kNumThreadsStore) / kNumThreadsReduce;
if constexpr (kSwizzledScale) {
size_t offset = scale_factor_swizzled_offset<ScaleType>(
row_idx, col_idx, DIVUP(row_length, kScaleBlockDim));
tile_scales_inv_c[offset] = scale_inv;
} else {
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) {
// Pack two elements into __nv_bfloat162
float2 f2_a;
float2 f2_b;
f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[0], encode_scale);
f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[1], encode_scale);
f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[0], encode_scale);
f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[1], encode_scale);
const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
// Convert to __nv_fp4x4_e2m1
__nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0];
output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1];
}
// Step 2.7: Store output_c
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory
// for not aligned case)
output_g += stride_g / kNFP4PerContainer;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
// Step 3: Transpose, cast and store to output_t
if constexpr (kReturnTranspose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory
int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory
size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Row in global memory
const size_t c_g = block_idx_y * kTileDim + r_s; // Column in global memory
const size_t stride_g =
static_cast<size_t>(c_stride) * kNVecSMem * num_rows; // Stride in global memory
const size_t num_ele = (c_g < num_rows ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
(num_rows - c_g) / kNFP4PerContainer)
: 0); // For not aligned case
OType* output_g =
&output_t[(r_g * num_rows + c_g) / kNFP4PerContainer]; // Output address in global memory
// Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of
// the first thread to do the reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut];
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
if constexpr (kIs2DBlockScaling) {
// TODO(zhongbo): 2D block scaling, directly read from amax_smem
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
constexpr int kNumColsPerWarp =
kThreadsPerWarp / kNumThreadsStore * kNVecSMem; // 8 elements
constexpr int kNumWarpsPerBlock =
kThreadsPerBlock / kThreadsPerWarp; // 8 warps per block
constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock;
int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp;
int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore;
int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x;
amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize];
} else {
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
}
}
// Step 3.3: Reduce amax
if constexpr (kIsE8Scaling) {
#pragma unroll
for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
}
// Step 3.4: Compute scale
ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g + smem_idx < row_length);
write_scale_inv &= (c_g < num_rows);
}
if (write_scale_inv) {
size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx;
size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) +
(threadIdx.x % kNumThreadsStore) / kNumThreadsReduce);
if constexpr (kSwizzledScale) {
size_t offset = scale_factor_swizzled_offset<ScaleType>(
row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim));
tile_scales_inv_t[offset] = scale_inv;
} else {
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) {
// Pack two elements into __nv_bfloat162
float2 f2_a;
float2 f2_b;
f2_a.x =
ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i].data.elt[smem_idx], encode_scale);
f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i + 1].data.elt[smem_idx],
encode_scale);
f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1)].data.elt[smem_idx],
encode_scale);
f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx],
encode_scale);
const uint32_t rbits =
kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
// Convert to __nv_fp4x4_e2m1
__nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0];
output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1];
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer);
} else {
if (r_g + smem_idx < row_length) {
output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0,
num_ele);
}
}
}
// Step 3.8: Update output address, column index of shared memory (and row index of global
// memory for not aligned case)
output_g += stride_g / kNFP4PerContainer;
c_s += c_stride;
if constexpr (!kAligned) {
r_g += c_stride * kNVecSMem;
}
}
}
}
} // namespace
} // namespace quantize_transpose_nvfp4
#endif // CUDA_VERSION >= 12080
namespace detail {
void quantize_transpose_vector_blockwise_fp4(
const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon,
const bool return_identity, const bool return_transpose, const bool pow2_scale,
const bool swizzled_scale, const bool use_stochastic_rounding,
const NVTETensor rng_state_tensor, const bool use_2d_quantization,
const SimpleTensor& noop_tensor, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4);
#if CUDA_VERSION >= 12080
// pow 2 scale is for MXFP4 since it's using E8M0 scaling
// raise error if pow2_scale is true
NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now");
if (!return_identity && !return_transpose) {
return;
}
if (use_2d_quantization && !return_identity) {
return;
}
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
num_elements *= input.shape.at(i);
}
// Early return if the input tensor is empty
if (num_elements == 0) {
return;
}
size_t scale_stride_x = 0;
size_t scale_stride_y = 0;
if (return_identity) {
scale_stride_x = 1;
scale_stride_y = scale_inv.shape[1];
}
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (return_transpose) {
scale_t_stride_x = 1;
scale_t_stride_y = scale_inv_t.shape[1];
}
using namespace transformer_engine::quantize_transpose_nvfp4;
const size_t num_blocks_x = DIVUP(row_length, static_cast<size_t>(kTileDim));
const size_t num_blocks_y = DIVUP(num_rows, static_cast<size_t>(kTileDim));
// noop tensor for cuda graph
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
const size_t* rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor& rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t*>(rng_state_te_tensor.data.dptr);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
output.dtype, 2, OutputType,
dim3 grid(num_blocks_x, num_blocks_y, 1);
using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16;
constexpr bool kPow2Scale = false;
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity, kReturnIdentity,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
swizzled_scale, kSwizzledScale,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kApplyStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_2d_quantization, kIs2DBlockScaling,
size_t smem_bytes = kSMemSize * sizeof(InputType);
auto kernel = block_scaled_1d_cast_transpose_kernel<
kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned,
float, InputType, OutputType, ScaleType, kSwizzledScale,
kApplyStochasticRounding, kIs2DBlockScaling>;
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_bytes);
NVTE_CHECK(err == cudaSuccess,
"Failed to set dynamic shared memory size.");
} kernel<<<grid, kThreadsPerBlock, smem_bytes,
stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<const float*>(global_amax.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<ScaleType*>(scale_inv.dptr),
reinterpret_cast<ScaleType*>(scale_inv_t.dptr), row_length,
num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x,
scale_t_stride_y, kScaleBlockDim, epsilon, rng_state,
noop_ptr);) // kIs2DBlockScaling
) // kApplyStochasticRounding
) // kSwizzledScale
) // kAligned
) // kReturnTranspose
) // kReturnIdentity
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace detail
} // namespace transformer_engine
......@@ -58,7 +58,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const __grid_constant__ CUtensorMap tensor_map_output_act,
const __grid_constant__ CUtensorMap tensor_map_output_gate,
float *const amax_ptr, float *const scale_inv_ptr,
const float *const scale_ptr, const size_t rows, const size_t cols) {
const float *const scale_ptr, const size_t rows, const size_t cols,
const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
......@@ -164,7 +165,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
......@@ -174,6 +174,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
......@@ -181,18 +187,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
if (act_elt <= p.limit) {
dact_x = s + s * (1 - s) * p.alpha * x;
} else {
dact_x = 0.0f;
}
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
float after_dact = dact_x * grad_elt * gate_elt;
float after_dgate = act_x * grad_elt;
float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
......@@ -200,7 +215,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
amax = fmaxf(amax, fabsf(after_dact));
amax = fmaxf(amax, fabsf(after_dgate));
} else {
const float after_act = ActOP(act_elt, {}) * gate_elt;
const float after_act = ActOP(act_elt, p) * gate_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
amax = fmaxf(amax, fabsf(after_act));
}
......@@ -305,7 +320,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise,
e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
const size_t scale_stride_colwise, const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
......@@ -481,25 +496,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float gate_elt = static_cast<float>(in_gate_sh[shmem_offset_colwise]);
float after_act_elt;
float after_gate_elt;
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
after_act_elt = dact_x * grad_elt * gate_elt;
after_gate_elt = act_x * grad_elt;
after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
} else {
after_act_elt = ActOP(act_elt, {}) * gate_elt;
after_act_elt = ActOP(act_elt, p) * gate_elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
......@@ -603,6 +630,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
......@@ -724,27 +752,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float gate_elt = static_cast<float>(in_gate.data.elt[e]);
float after_act_elt;
float after_gate_elt;
bool dgate_elt = true;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad.data.elt[e]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
after_act_elt = dact_x * grad_elt * gate_elt;
after_gate_elt = act_x * grad_elt;
after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
after_act_rowwise[j] = after_act_elt;
after_gate_rowwise[j] = after_gate_elt;
} else {
after_act_elt = ActOP(act_elt, {}) * gate_elt;
after_act_elt = ActOP(act_elt, p) * gate_elt;
after_act_rowwise[j] = after_act_elt;
}
......@@ -833,6 +873,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
......@@ -889,7 +930,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
......@@ -956,6 +997,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
......@@ -966,8 +1008,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>
<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
#endif
......@@ -975,7 +1016,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
......@@ -1109,7 +1150,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::COLWISE:
......@@ -1126,7 +1167,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::BIDIMENSIONAL:
......@@ -1135,7 +1176,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
OType, true, true,
THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
true, true, THREADS_PER_CHUNK_NON_COLWISE>
<<<grid, block_size, shmem_size, stream>>>(
......@@ -1143,7 +1183,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}); // NOLINT(*)
......@@ -1152,12 +1192,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(),
"Wrong output shape. Expected (after flattening) [", input.flat_first_dim(),
", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
"Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
......@@ -1179,7 +1216,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
output->flat_last_dim(), {}, stream);
output->flat_last_dim(), p, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
......@@ -1188,7 +1225,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output");
......@@ -1217,7 +1255,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
grad.flat_last_dim(), {}, stream);
grad.flat_last_dim(), p, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
......@@ -1226,7 +1264,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input");
......@@ -1266,17 +1304,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
if (is_delayed_tensor_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
cast_gated<ParamOP, ActOP>(gated_input, output, stream);
cast_gated<ParamOP, ActOP>(gated_input, output, p, stream);
}
}
} else if (is_mxfp_scaling(output->scaling_mode)) {
} else if (is_mxfp8_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
"by 32, got input of shape ", gated_input.data.shape);
......@@ -1292,7 +1330,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
cudaStream_t stream) {
ParamOP p, cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
......@@ -1301,13 +1339,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
output_tensor, stream);
output_tensor, p, stream);
} else {
if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) {
if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, stream);
cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, p,
stream);
} else {
cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, stream);
cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, p, stream);
}
} else {
// MX scaling
......
......@@ -25,6 +25,7 @@
#include "../util/vectorized_pointwise.h"
#include "../utils.cuh"
#include "math.h"
#include "nvfp4_transpose.cuh"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
......@@ -110,6 +111,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
......@@ -137,8 +140,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
......@@ -286,7 +290,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float scaled_out = in * block_scale_inverse;
const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
......@@ -410,10 +414,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const size_t stage_scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
if (rowwise_scale_is_within_bounds) {
scales_rowwise[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
......@@ -441,7 +447,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_sh[shmem_offset_rowwise]);
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
......@@ -456,19 +462,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t buff_offset = buff * BUFF_DIM;
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff_offset]));
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff_offset]));
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
......@@ -489,18 +495,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const size_t shmem_thread_offset =
const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
const size_t shmem_elt_idx = swizzled_group_offset + e;
const int shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
}
}
......@@ -508,15 +514,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32]
const size_t scaling_block = threadIdx.x / SCALE_DIM_X;
const int scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
}
}
const size_t dbias_stride = cols;
const size_t dbias_offset_Y = blockIdx.y;
const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const int dbias_stride = cols;
const int dbias_offset_Y = blockIdx.y;
const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
......@@ -539,6 +545,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif // __HIP_PLATFORM_AMD__
} // namespace mxfp8_kernel
namespace nvfp4_kernel {
using namespace ptx;
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 16;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
// return S_dec_b_fp8;
return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}
#define DIRECT_SCALING_FACTORS_STORE 1
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_colwise,
fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0,
const float *noop, float *const amax_ptr,
const float *const nvfp4_second_stage_scale_ptr, const size_t rows,
const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool ROWWISE_SCALING = true;
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE;
static_assert(BUFF_DIM_Y >= SCALE_DIM_Y &&
"Number of buffer rows must be greater or equal to the size of the columwise "
"scaling block\0");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
"Number of buffer rows must be greater or equal to the number of rowwise "
"processing threads in Y dimension\0");
constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size
constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X;
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE;
// static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of
// // threads to process one row in a single iteration
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
fp8e4m3 *out_rowwise_scales_sh =
reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc =
(nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr);
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const int buff_offset_in = buff * BUFF_IN_DIM;
const int buff_offset_out = buff * BUFF_OUT_DIM;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_IN_DIM;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM_Y];
IType in_colwise_IType[SCALE_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// If no activation, elt is 0 so we can safely do this
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
if (colwise_scale_is_within_bounds) {
scales_colwise_e8m0[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (int it = 0; it < ITERATIONS_ROWWISE; ++it) {
const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const int shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const int shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
// only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// If no activation, elt is 0 so we can safely do this
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc);
#if DIRECT_SCALING_FACTORS_STORE
// Check boundaries
if (rowwise_scale_is_within_bounds) {
const int scales_offset_Y =
scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
const int scales_offset_X = scales_offset_X_rowwise;
const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X;
scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8;
}
#else
const int shmem_scales_offset_Y =
stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise;
const int shmem_scales_offset_X = tid_X_rowwise;
const int scale_idx =
shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X;
out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8;
#endif
// Compute "correct" per-block encoding scaling factor
const float block_scale_inverse =
__fdiv_rn(S_enc, static_cast<float>(S_dec_b_fp8)); // S_enc_b_fp8
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out; // Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in01 = in_IType[w].data.elt[2 * e];
in23 = in_IType[w].data.elt[2 * e + 1];
} else if constexpr (IS_CACHED_ACT_OP) {
in01.x = in_cached[w].data.elt[4 * e];
in01.y = in_cached[w].data.elt[4 * e + 1];
in23.x = in_cached[w].data.elt[4 * e + 2];
in23.y = in_cached[w].data.elt[4 * e + 3];
} else {
const int j = w * PACK_SIZE + 4 * e;
in01.x = in_compute_rowwise[j];
in01.y = in_compute_rowwise[j + 1];
in23.x = in_compute_rowwise[j + 2];
in23.y = in_compute_rowwise[j + 3];
}
fp4e2m1x4 &out_quad = reinterpret_cast<fp4e2m1x4 &>(out.data.elt[e]);
ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
__builtin_assume(block_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM;
const int buff_offset_mxfp8 = buff * BUFF_IN_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset_nvfp4]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset_mxfp8]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
#if !DIRECT_SCALING_FACTORS_STORE
// Vectorized store of scaling factors.
// Each thread stores multiple scaling factors in one store instruction.
if constexpr (ROWWISE_SCALING) {
// Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise;
const int scale_idx_global =
scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise;
const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) &&
(scales_offset_X_rowwise < (cols / SCALE_DIM_X))) {
using ScalesVec_t = Vec<fp8e4m3, NVFP4_SCALING_FACTORS_PER_CHUNK_ROW>;
const ScalesVec_t &scales =
*reinterpret_cast<ScalesVec_t *>(&out_rowwise_scales_sh[scale_idx_shmem]);
scales.store_to(&scales_rowwise_e4m3[scale_idx_global]);
}
}
#endif
float chunk_amax = 0.0f;
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
chunk_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(thread_amax, warp_id);
}
if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, chunk_amax);
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_kernel
constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128;
......@@ -903,7 +1431,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
const size_t N = product(input.data.shape);
const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
......@@ -1192,6 +1720,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
#endif
}
// This kernel supports only two scaling cases:
// 1. r16c0 - Rowwise NVFP4
// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
using namespace nvfp4_kernel;
using namespace ptx;
checkCuDriverContext(stream);
NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
bool use_colwise_scaling = output->has_columnwise_data();
if (use_colwise_scaling) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Columnwise scaling tensor must be allocated");
}
CheckNoopTensor(*noop, "cast_noop");
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
const size_t scale_stride_rowwise = output->scale_inv.shape[1];
const size_t scale_stride_colwise =
use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;
fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
e8m0_t *const scales_colwise_e8m0_ptr =
use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const ScalingType scaling_type =
use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const nvfp4_second_stage_scale_ptr =
reinterpret_cast<const float *>(output->scale.dptr);
// Output data type is only required for the column-wise MXFP8 scaling.
// It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work
const DType output_data_type =
use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4);
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
}
constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);
constexpr size_t in_mem = buff_size_aligned_in;
const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;
const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;
const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
out_rowwise_scales_mem + out_colwise_scales_mem +
TMA_SHMEM_ALIGNMENT;
const size_t dshmem_size = in_mem + out_mem;
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}); // NOLINT(*)
); // NOLINT(*)
}
namespace detail {
using Empty = transformer_engine::Empty;
......@@ -1417,13 +2080,26 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
auto dbias_tensor = convertNVTETensor(dbias);
auto workspace_tensor = convertNVTETensor(workspace);
const QuantizationConfig *quant_config_cpp =
reinterpret_cast<const QuantizationConfig *>(quant_config);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr;
const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor();
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
if (output_tensor->has_columnwise_data()) {
......@@ -1435,7 +2111,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
cast_transpose(*input_tensor, noop_tensor, output_tensor, stream);
cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
} else {
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>(
*input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor,
......@@ -1443,51 +2119,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
}
} else if (output_tensor->has_data()) {
fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
*input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor,
*input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
*input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor,
*input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
break;
}
case NVTE_NVFP4_1D_SCALING: {
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 &&
output_tensor->has_data();
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, true>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, false>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for "
"2D quantization");
quantize_transpose_vector_blockwise_fp4(
/*input=*/input_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor.data, stream);
/*noop_tensor=*/noop_tensor->data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f;
bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = quant_config_cpp
? quant_config_cpp->float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT
: false;
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = quant_config_cpp
? quant_config_cpp->float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT
: false;
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
......@@ -1495,7 +2210,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor.data, stream);
columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break;
}
default:
......
......@@ -19,6 +19,8 @@
#include <transformer_engine/cast.h>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits>
#include "../common.h"
......@@ -28,6 +30,7 @@
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/activation.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h"
namespace transformer_engine {
......@@ -339,6 +342,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
NVTE_CHECK_CUDA(cudaGetLastError());
#endif
}
#if CUDA_VERSION >= 12080
template <typename OType>
__global__ void __launch_bounds__(512)
dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
const float *const tensor_amax, const size_t N, const size_t M,
const size_t scale_stride) {
const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const size_t x = thread_idx % M;
const size_t y = thread_idx / M;
union fp4vec {
uint64_t vec;
fp4e2m1x4 small_vec[4];
};
using OVec = Vec<OType, 4>;
const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
OVec *output_vec = reinterpret_cast<OVec *>(output);
const size_t my_index = x + y * M;
const size_t my_scale_index = x + y * scale_stride;
const size_t my_output_index = (x + y * M) * 4;
fp4vec value;
value.vec = input_vectorized[my_index];
fp8e4m3 scale = scales[my_scale_index];
float amax = *tensor_amax;
constexpr float factor_inv = 1.0 / (6.0 * 448.0);
float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
for (int i = 0; i < 4; i++) {
float4 current = static_cast<float4>(value.small_vec[i]);
OVec out;
out.data.elt[0] = static_cast<OType>(current.x * final_scale);
out.data.elt[1] = static_cast<OType>(current.y * final_scale);
out.data.elt[2] = static_cast<OType>(current.z * final_scale);
out.data.elt[3] = static_cast<OType>(current.w * final_scale);
output_vec[my_output_index + i] = out;
}
}
#endif // CUDA_VERSION
void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output");
NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
const size_t Mread = M / FP4_BLOCK_SIZE;
const size_t total = N * Mread;
const size_t threads = 512;
const size_t blocks = DIVUP(total, threads);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
reinterpret_cast<float *>(input.amax.dptr), N, Mread,
input.scale_inv.shape.back());); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif // CUDA_VERSION >= 12080
}
} // namespace dequantization
namespace detail {
......@@ -347,16 +425,24 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output");
if (is_tensor_scaling(input.scaling_mode)) {
switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
dequantization::fp8_dequantize(input, output, stream);
} else if (is_mxfp_scaling(input.scaling_mode)) {
break;
}
case NVTE_MXFP8_1D_SCALING: {
if (is_supported_by_CC_100()) {
dequantization::mxfp8_dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
} else {
// TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
break;
}
case NVTE_NVFP4_1D_SCALING: {
dequantization::fp4_dequantize(input, output, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
......
......@@ -23,6 +23,8 @@
#endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h>
#include "nccl.h"
#ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h>
#endif // NVTE_WITH_CUBLASMP
......@@ -147,4 +149,12 @@
#endif // NVTE_WITH_CUBLASMP
#define NVTE_CHECK_NCCL(expr) \
do { \
const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \
if (status_NVTE_CHECK_NCCL != ncclSuccess) { \
NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \
} \
} while (false)
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
......@@ -11,6 +11,11 @@ namespace transformer_engine {
struct Empty {};
struct ClampedSwiGLUParam {
float limit;
float alpha = 1.702f; // Default value for QuickGELU
};
template <typename OType, typename IType>
__device__ inline OType gelu(const IType val, const Empty&) {
const float cval = val;
......@@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
return s * (1.f - s);
}
template <typename OType, typename IType>
__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) {
const float cval = val;
Empty e = {};
return cval * sigmoid<float, float>(alpha * cval, e);
}
template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
return qgelu_with_alpha<OType, IType>(val, 1.702f);
}
template <typename OType, typename IType>
__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
const float cval = val;
return cval * sigmoid<float, float>(1.702f * cval, e);
Empty e = {};
return alpha * cval * dsigmoid<float, float>(alpha * cval, e) +
sigmoid<float, float>(alpha * cval, e);
}
template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
const float cval = val;
return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
sigmoid<float, float>(1.702f * cval, e);
return dqgelu_with_alpha<OType, IType>(val, 1.702f);
}
template <typename OType, typename IType>
......@@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) {
return cval * sigmoid<float, float>(cval, e);
}
template <typename OType, typename IType>
__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) {
const float cval = min(p.limit, static_cast<float>(val)); // Clamping
return qgelu_with_alpha<OType, float>(cval, p.alpha);
}
template <typename OType, typename IType>
__device__ inline OType dsilu(const IType val, const Empty& e) {
const float cval = val;
return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
}
template <typename OType, typename IType>
__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) {
const bool dclamp_val = static_cast<float>(val) <= p.limit;
const float clamp_val = min(static_cast<float>(val), p.limit);
const float dsilu_val = dqgelu_with_alpha<OType, float>(clamp_val, p.alpha);
return dclamp_val ? dsilu_val : 0.0f;
}
template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty&) {
return fmaxf(value, 0.f);
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file nvfp4_transpose.cuh
* \brief CUDA kernels to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#if CUDA_VERSION > 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080
#include <cfloat>
#include "../common.h"
#include "../utils.cuh"
#include "curanddx.hpp"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
#if CUDA_VERSION > 12080
namespace nvfp4_transpose {
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
using namespace ptx;
using nvfp4_scale_t = fp8e4m3;
constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts)
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_NUM = 128;
constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM;
constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM;
constexpr size_t RNG_GENS_PER_THREAD =
SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers
constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;
// SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D
constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8
constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X;
constexpr size_t STAGES = TILES_Y * TILES_X;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = TILE_DIM_Y;
constexpr size_t BUFF_DIM_X = TILE_DIM_X;
constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;
// Input buffer (BF16)
constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
// Output buffer (NVFP4)
constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;
// Output transpose buffer (NVFP4)
constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X;
constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X;
// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM / PACK_SIZE;
constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8
constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16
constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2
constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM;
constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE;
static_assert(BUFF_DIM_Y >= SCALE_DIM &&
"Number of buffer rows must be greater or equal to the size of the columwise "
"scaling block\0");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
"Number of buffer rows must be greater or equal to the number of rowwise "
"processing threads in Y dimension\0");
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
// constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
// return S_dec_b_fp8;
// NOTE: Divide by 6.0f is not elegant and not efficient.
// However, this is part of the emulation code to ensure exact match.
using namespace detail;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
const float S_dec_b = block_amax / fp4_max * S_enc;
return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}
// Compute the global encode scale factor for a given global amax
__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
using namespace detail;
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
}
return global_encode_scale;
}
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
} else {
return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
}
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits);
} else {
return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
}
}
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
// const size_t tid_X_t = 0;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t row_base_colwise = block_offset_Y;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
// Helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
// NOTE: This is to match with how emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise =
(row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// If no activation, elt is 0 so we can safely do this
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
// Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
// Helps reducing bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
// only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// If no activation, elt is 0 so we can safely do this
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// Vectorized store scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
int rnd_idx =
0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
// NEW: 2D Block-based scaling constants
constexpr size_t BLOCK_DIM = 16;
constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2
constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8
constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile
constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
// Helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
// NOTE: This is to match with how emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
const size_t warp_id = threadIdx.x / 32;
const size_t lane_id = threadIdx.x % 32;
float thread_amax = 0.0f;
const size_t block_in_warp = lane_id / BLOCKS_PER_WARP;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
__shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1];
// Helper function for warp reduction
auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float {
#pragma unroll
for (int delta = 8; delta >= 1; delta /= 2) {
float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta);
thread_amax = fmaxf(thread_amax, other_amax);
}
return thread_amax;
};
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
#pragma unroll
for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
const size_t block_in_tile_y = block_iter;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
for (int elem = 0; elem < BLOCK_DIM; elem += 2) {
const size_t elem_0_row = block_iter * BLOCK_DIM + elem;
const size_t elem_1_row = elem_0_row + 1;
const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
const size_t elem_1_col = elem_0_col;
const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col;
const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col;
IType2 val_2x;
val_2x.x = in_sh[shmem_offset_0];
val_2x.y = in_sh[shmem_offset_1];
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x);
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else {
for (int elem = 0; elem < BLOCK_DIM; ++elem) {
const size_t elem_row = block_iter * BLOCK_DIM + elem;
const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
// Bounds checking
const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows);
const bool col_out_of_bounds = (block_offset_X + elem_col >= cols);
if (!row_out_of_bounds && !col_out_of_bounds) {
const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col;
float elt = static_cast<float>(in_sh[shmem_offset]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset] = static_cast<IType>(elt);
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
}
}
// Warp reduction to get block amax
block_amax = warp_reduce_amax(thread_amax, block_in_warp);
if (lane_id == 0 || lane_id == 16) {
block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax;
}
}
// sync thread to ensure block_amax_matrix is done storing
__syncthreads();
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
// 3. Scale elements
// Load data in
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
}
} else {
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
// // Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
// Helps reducing bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = tid_X_rowwise;
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
}
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// Vectorized store scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_transpose
#endif // CUDA_VERSION > 12080
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
bool use_2d_quantization>
void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
const QuantizationConfig *quant_config, cudaStream_t stream) {
#if CUDA_VERSION > 12080
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
// If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to
// return the transposed data.
// TODO(Frank): Is there a better way to do this?
bool return_transpose = output->has_columnwise_data();
using namespace nvfp4_transpose;
using namespace ptx;
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", false);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if (return_transpose) {
NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
"Transposed output must have FP4 type.");
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Transposed scaling tensor must be allocated");
}
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
NVTE_CHECK(cols % 32 == 0,
"Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_NUM;
const size_t scale_stride = output->scale_inv.shape[1];
const size_t scale_stride_transpose =
return_transpose ? output->columnwise_scale_inv.shape[1] : 0;
nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
nvfp4_scale_t *const scales_transpose_ptr =
reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
const float *const amax_colwise_ptr =
reinterpret_cast<const float *>(output->columnwise_amax.dptr);
const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
const size_t *rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
}
using IType = bf16;
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
alignas(64) CUtensorMap tensor_map_output_transpose{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
4);
if (return_transpose) {
create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_data_mem = buff_size_aligned_out;
constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
constexpr size_t out_scales_transpose_mem = buff_size_scales;
constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;
constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
auto kernel = nvfp4_transpose_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
if constexpr (use_2d_quantization) {
kernel = nvfp4_transpose_kernel_2D<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
}
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
scale_stride, scale_stride_transpose, rng_state);
}););
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION > 12080
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
......@@ -14,6 +14,10 @@
#include <cuda.h>
#include <cuda_runtime.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION >= 12080
namespace transformer_engine {
namespace ptx {
......@@ -125,9 +129,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
}
#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \
((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM103_ALL)))
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
uint16_t out;
asm volatile(
"{\n"
......@@ -230,18 +238,86 @@ struct alignas(2 * sizeof(T)) FPx2 {
T y;
};
template <typename T>
struct FPx4 {
T x1;
T x2;
T x3;
T x4;
};
template <typename T>
struct Type2x {};
template <>
struct Type2x<float> {
using type = float2;
};
template <>
struct Type2x<bf16> {
using type = __nv_bfloat162;
};
template <>
struct Type2x<fp16> {
using type = __half2;
};
using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>;
using floatx4 = FPx4<float>;
using bf16x4 = FPx4<bf16>;
using fp16x4 = FPx4<fp16>;
using fp8e4m3x4 = FPx4<fp8e4m3>;
using fp8e5m2x4 = FPx4<fp8e5m2>;
static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);
#if CUDA_VERSION >= 12080
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
static_assert(sizeof(fp4e2m1x2) == 1);
static_assert(sizeof(fp4e2m1x4) == 2);
#endif // CUDA_VERSION >= 12080
// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6.
// vt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on following architectures:
// sm_100a
// sm_101a
// sm_120a
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// and the converted values are packed in the destination operand d such that the value
// converted from input a is stored in the upper 4 bits of d and the value converted
// from input b is stored in the lower 4 bits of d.
// SIMD like "Fused" cast + multiplication (x4)
#if CUDA_VERSION >= 12080
template <typename Tx2>
__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23,
const float scale) {
const float x0 = static_cast<float>(in01.x) * scale;
const float x1 = static_cast<float>(in01.y) * scale;
const float x2 = static_cast<float>(in23.x) * scale;
const float x3 = static_cast<float>(in23.y) * scale;
out = fp4e2m1x4(make_float4(x0, x1, x2, x3));
}
#endif // CUDA_VERSION >= 12080
// SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
const floatx2 &scale) {
......@@ -377,7 +453,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
"r"(reinterpret_cast<const uint32_t &>(p2)));
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx
......
......@@ -27,6 +27,7 @@
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \
.value("kInt8", transformer_engine::DType::kInt8); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
......@@ -41,6 +42,10 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local()) \
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
......
......@@ -11,7 +11,7 @@
#include "../common.h"
#include "../utils.cuh"
#include "math.h"
namespace transformer_engine {
/* \brief Helper class that enables storing multiple values of type DType
......@@ -345,7 +345,7 @@ template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typen
typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output,
const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N,
const Param params, cudaStream_t stream) {
const Param &params, cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, output);
......@@ -379,7 +379,7 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename In
typename InputTypeGrad, typename OutputType>
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
OutputType *output, const fp32 *scale, fp32 *amax,
fp32 *scale_inv, const size_t N, const Param params,
fp32 *scale_inv, const size_t N, const Param &params,
cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, grad, output);
......@@ -438,7 +438,13 @@ __launch_bounds__(unary_kernel_threads) __global__
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// Clamp the gated value and add 1 at the end
ComputeType limit = p.limit;
val2 = std::min(std::max(-limit, val2), limit) + 1;
}
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
if (requires_amax) {
__builtin_assume(max >= 0);
......@@ -539,10 +545,18 @@ __launch_bounds__(unary_kernel_threads) __global__
for (int i = 0; i < nvec; ++i) {
const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]);
const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
bool dgate_in = true;
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
const ComputeType limit = p.limit;
dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp
gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
}
ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
ComputeType after_dgate = grad_val * Activation(gelu_in, p);
ComputeType after_dgate = dgate_in ? grad_val * Activation(gelu_in, p) : 0.0f;
if (requires_amax) {
__builtin_assume(max >= 0);
......
......@@ -49,6 +49,26 @@ constexpr uint32_t THREADS_PER_WARP = 32;
////////////////////////////////////////////////////////////////////////////////////////////////////
// Device-side error
#define NVTE_DEVICE_ERROR(message) \
do { \
printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
__func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \
(message)); \
assert(0); \
} while (false)
// Device-side error on thread 0
#define NVTE_DEVICE_THREAD0_ERROR(message) \
do { \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \
threadIdx.y == 0 && threadIdx.z == 0) { \
NVTE_DEVICE_ERROR(message); \
} \
} while (false)
////////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y};
......
......@@ -19,7 +19,7 @@ from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.quantization import _default_sf_compute
def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
......
......@@ -290,10 +290,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
for stat in config["stats"]:
self.check_if_stat_is_supported(stat, recipe_name)
start_step = config.get("start_step", None)
end_step = config.get("end_step", None)
start_end_list = config.get("start_end_list", None)
if start_end_list is not None:
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
options = (
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
start_step,
end_step,
start_end_list,
"fp8",
)
......
......@@ -15,8 +15,8 @@ import nvdlfw_inspect.api as debug_api
from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage
from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
......@@ -123,17 +123,23 @@ class LogTensorStats(BaseLogTensorStats):
"""API call used to collect the data about the tensor before process_tensor()/quantization."""
assert (
type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase]
type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage]
and tensor.dtype != torch.uint8
), (
f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using"
" log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors."
)
start_step = config.get("start_step", None)
end_step = config.get("end_step", None)
start_end_list = config.get("start_end_list", None)
if start_end_list is not None:
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
options = (
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
start_step,
end_step,
start_end_list,
)
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
......
......@@ -172,11 +172,19 @@ class StatsBuffers:
if self.at_least_one_layer_fed:
return True
iteration = TEDebugState.get_iteration()
for _, next_iter in self.layers_to_next_iter.items():
layers_to_remove = []
for layer_name, next_iter in self.layers_to_next_iter.items():
# When next_iter is None the feature will no longer run.
if next_iter is None:
layers_to_remove.append(layer_name)
continue
# Note that layer can be not run for many iterations,
# in this case we will synchronize until every step until we get any information from it.
if iteration >= next_iter:
return True
for layer_name in layers_to_remove:
self.layers_to_next_iter.pop(layer_name, None)
return False
def reset(self):
......
......@@ -18,7 +18,7 @@ from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor,
Quantizer,
QuantizedTensorBase,
QuantizedTensorStorage,
prepare_for_saving,
restore_from_saved,
)
......@@ -557,7 +557,7 @@ class DebugQuantizer(Quantizer):
self._update_parent_quantizer_usage()
class DebugQuantizedTensor(QuantizedTensorBase):
class DebugQuantizedTensor(QuantizedTensorStorage):
"""
Class containing quantized tensors after debug. Depending on configuration
it can contain one or two different objects. These objects can be accessed by the method
......
......@@ -34,7 +34,7 @@ load_framework_extension("jax")
from . import flax
from . import quantize
from .quantize import fp8_autocast, update_collections, get_delayed_scaling
from .quantize import autocast, fp8_autocast, update_collections
from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource
......@@ -45,9 +45,9 @@ from ..common.utils import DeprecatedEnum
__all__ = [
"NVTE_FP8_COLLECTION_NAME",
"autocast",
"fp8_autocast",
"update_collections",
"get_delayed_scaling",
"MeshResource",
"flax",
"quantize",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment