Commit 063ef88d authored by wenjh

Merge nv main up to v2.10.0.dev0


Signed-off-by: wenjh <wenjh@sugon.com>
parents 91670b05 5624dbb4
...@@ -8,6 +8,7 @@
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
#include "../common.h"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine::detail {
...@@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
const bool pow_2_scale, const SimpleTensor &noop_tensor,
cudaStream_t stream);
void quantize_transpose_vector_blockwise_fp4(
const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv,
SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon,
const bool return_identity, const bool return_transpose, const bool pow2_scale,
const bool swizzled_scale, const bool use_stochastic_rounding,
const NVTETensor rng_state_tensor, const bool use_2d_quantization,
const SimpleTensor &noop_tensor, cudaStream_t stream);
} // namespace transformer_engine::detail
#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_
...@@ -18,6 +18,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh" #include "common/util/ptx.cuh"
#include "common/utils.cuh" #include "common/utils.cuh"
...@@ -901,6 +902,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -901,6 +902,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
NVTE_API_CALL(quantize_transpose_square_blockwise); NVTE_API_CALL(quantize_transpose_square_blockwise);
checkCuDriverContext(stream); checkCuDriverContext(stream);
if (transformer_engine::cuda::sm_arch() >= 100) {
NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
"with MXFP8, which requires using power of two scaling factors.");
}
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_rows = 1;
...
...@@ -24,6 +24,7 @@
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/cuda_runtime.h"
#include "common/utils.cuh" #include "common/utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -1480,6 +1481,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -1480,6 +1481,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise); NVTE_API_CALL(quantize_transpose_vector_blockwise);
if (transformer_engine::cuda::sm_arch() >= 100) {
NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ",
"with MXFP8, which requires using power of two scaling factors.");
}
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
...
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "curanddx.hpp"
namespace transformer_engine {
#if CUDA_VERSION >= 12080
namespace quantize_transpose_nvfp4 {
namespace {
using std::int32_t;
using std::uint32_t;
using std::uint8_t;
using transformer_engine::detail::TypeExtrema;
// Define a cuRANDDx descriptor
// Note: curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it defaults to 10.
// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used to do some internal checks, e.g.,
// whether shared memory, if needed, is sufficient for the described problem; usually not applicable here.
// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
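// Usage in the kernel below: construct `RNG rng(seed, subsequence, offset)` and draw random bits
// with `curanddx::uniform_bits{}.generate4(rng)`, which yields a uint4 (4x 32 random bits);
// get_rbits() below then hands these out one 32-bit word at a time for stochastic rounding.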
// clang-format off
/*
Step 1: Load input to shared memory
* shared memory: 128x128 elements with type=InputType (the diagram below doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shared memory: 128x128 elements with type=InputType (the diagram below doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads perform a reduction to compute the amax of each row
* 16 elements are quantized and written to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (the diagram below doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads perform a reduction to compute the amax of each column
* 16 elements are quantized and written to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr int kThreadsPerWarp = 32;
// for fp4, we use uint8_t to store 2 fp4 numbers
constexpr int kNFP4PerContainer = 2;
// Hyperparameters for performance tuning
constexpr int kTileDim = 128;
// constexpr int kScaleDim = 32;
constexpr int kNVecIn = 8; // The number of elements each LDG touches
constexpr int kNVecOut = 16; // The number of elements each STG touches
constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total
// Auto-calculated constants (do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn; // 16
constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8
// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
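// Worked example of the Step 1 mapping with the constants above (illustration only): thread 17
// (warp 0, lane 17) gets r_s = 17 / kNumThreadsLoad = 1 and c_s = (17 % kNumThreadsLoad) * 4 = 4
// SMemVec slots, i.e. it loads input elements 8..15 of tile row 1 and then advances 16 rows per
// iteration for 8 iterations -- matching the T16..T31 row in the diagram.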
// for 2D block scaling, we need to reduce amax in warp
static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = {
0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080};
// max for every group_size elements in warp
template <int group_size, int shfl_down_stride>
__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) {
for (int offset = group_size / 2; offset > 0; offset /= 2) {
val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride));
}
return val;
}
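// Example with the masks above: groupMax<4, 8>(val, 0x01010101) reduces over lanes {0, 8, 16, 24};
// two shuffles with strides 16 and then 8 leave lane 0 holding the maximum of the four values.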
template <typename ScaleType>
__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax,
const float global_encode_scale) {
float decode_scale = amax / TypeExtrema<fp4e2m1>::max;
decode_scale = decode_scale * global_encode_scale;
decode_scale = fminf(decode_scale, TypeExtrema<float>::max);
return static_cast<ScaleType>(decode_scale);
}
template <typename ScaleType>
__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale,
const float global_decode_scale) {
return fminf(1.0f / (static_cast<float>(decode_scale) * global_decode_scale),
TypeExtrema<float>::max);
}
template <typename IType, typename ScaleType>
__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) {
return static_cast<float>(input) * encode_scale;
}
__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) {
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// Clamp to the max value of float32 in case the scale overflows to infinity
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity (the scale then underflows to 0), return 1
if (global_amax == 0.f || global_encode_scale == 0.f) {
return 1.f;
}
return global_encode_scale;
}
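// Worked example (assuming the usual limits fp8e4m3 max = 448 and fp4e2m1 max = 6):
// global_amax = 2688 gives global_encode_scale = 448 * 6 / 2688 = 1.0, while global_amax = 0
// falls through to the final check and also returns 1.0.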
__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t* const rbits_arr = reinterpret_cast<uint32_t*>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
template <class ScaleType>
__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx,
uint32_t col_length) {
// This function takes indices into the scale factor matrix and returns the corresponding offset
// in the swizzled format. row_idx and col_idx are the original (unswizzled) indices into the
// scale factor matrix, and col_length is the number of columns of the scale factor matrix.
// https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
// The scale factor matrix is laid out in 512 B base blocks. Each base block consists of 128 rows
// and 4 columns, and is divided into 4 column blocks of 32 rows by 4 columns each.
// NOTE: There are not many good illustrations of the swizzled scale factor layout.
// At a high level, the swizzled scale factor matrix can be composed as:
// unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8)
// cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64
// rb_cnt = M // 128 # Assuming M is divisible by 128
// tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
// tmp = torch.permute(tmp, (0, 3, 2, 1, 4))
// swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4))
constexpr uint32_t kTotalRowsPerBaseBlock = 128;
constexpr uint32_t kRowsPerBaseBlockCol = 32;
constexpr uint32_t kColsPerBaseBlockCol = 4;
const size_t rb = row_idx / kTotalRowsPerBaseBlock;
const size_t rem = row_idx % kTotalRowsPerBaseBlock;
const size_t d4 = rem / kRowsPerBaseBlockCol;
const size_t d3 = rem % kRowsPerBaseBlockCol;
const size_t cbg = col_idx / kColsPerBaseBlockCol;
const size_t d5 = col_idx % kColsPerBaseBlockCol;
const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol);
// Row-major offset in the logical shape (rb_cnt, cbg_cnt, 32, 4, 4).
// The magic number 16 below is 4 * kColsPerBaseBlockCol: d4 = (row_idx % 128) / 32 ranges over
// [0, 4), as does d5, so the stride of d3 is 16.
return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5;
}
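// Illustration (hypothetical values): row_idx = 33, col_idx = 5, col_length = 8 give rb = 0,
// d4 = 1, d3 = 1, cbg = 1, d5 = 1 and cbg_cnt = 2, so the swizzled offset is
// ((0 * 2 + 1) * 32 + 1) * 16 + 1 * 4 + 1 = 533.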
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
uint16_t out_4x;
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t"
"}"
: "=h"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits));
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x);
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const uint32_t rbits) {
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
// NOTE: rbits unused for rn.
uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing.
asm volatile(
"{\n"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x));
return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0];
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
uint16_t dummy = 0;
return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy);
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
}
template <bool kApplyStochasticRounding>
__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const uint32_t rbits) {
if constexpr (kApplyStochasticRounding) {
return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits);
} else {
return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits);
}
}
template <bool kReturnIdentity, bool kReturnTranspose, bool kIsE8Scaling, bool kAligned,
typename CType, typename IType, typename OType, typename ScaleType, bool kSwizzledScale,
bool kApplyStochasticRounding, bool kIs2DBlockScaling>
__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel(
const IType* const input, const float* global_amax, OType* const output_c,
OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t,
const size_t row_length, const size_t num_rows, const size_t scale_stride_x,
const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y,
const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state,
const float* noop_ptr) {
constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer;
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecContainer>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
};
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
const size_t block_idx_x = blockIdx.x;
const size_t block_idx_y = blockIdx.y;
const size_t rng_sequence =
threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0};
// Index of the random number. It increments each time it is used and resets to 0 once it reaches 4.
int rnd_idx = 0;
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
// 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode.
// Instead of static_assert, return early if these invalid modes are detected.
if constexpr (kIs2DBlockScaling && kIsE8Scaling) {
return;
}
if constexpr (kIs2DBlockScaling && !kReturnIdentity) {
return;
}
// for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4
// use constexpr to define the size, when not using 2D, use minimal size 1x1
constexpr int kFP4BlockScalingSize = 16;
constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1;
constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore; // 4
constexpr int k2DBlockAmaxReduceDim =
kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1;
__shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim];
__shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim];
// Step 1: Load input to shared memory
{
constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory
const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = (c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0); // For not aligned case
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory
// for not aligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads();
const int kNumThreadsReduce = kScaleBlockDim / kNVecOut;
const float global_encode_scale =
kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]);
const float global_decode_scale = 1.0 / global_encode_scale;
// Step 2: Cast and store to output_c
if constexpr (kReturnIdentity) {
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele =
(c_g < row_length ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
(row_length - c_g) / kNFP4PerContainer)
: 0); // For not aligned case
OType* output_g =
&output_c[(r_g * row_length + c_g) / kNFP4PerContainer]; // Output address in global memory
// Each row of the tile is handled by kNumThreadsStore consecutive threads of a warp; we need the
// lane id of the first thread in each reduction group (kNumThreadsReduce threads) to do the reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
// Step 2.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax
if constexpr (kIsE8Scaling) {
#pragma unroll
for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
}
// doing shuffle sync for 2D block scaling (not applicable for E8 scaling)
if constexpr (kIs2DBlockScaling) {
// first amax shuffle sync in warp, then reduce in smem
// T0 T8 T16 T24 should do amax reduction together
constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
int tid_in_warp_x = threadIdx.x % kNumThreadsStore;
int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp;
CType amax_warp_reduced = groupMax<kNumRowsPerWarp, kNumThreadsStore>(
amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]);
// now T0 ~ T7 of each warp hold the reduced amax values
int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y;
if (tid_in_warp_y == 0) {
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]
[warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced;
}
__syncthreads();
if (data_row_idx % kFP4BlockScalingSize == 0) {
CType amax_2d = 0.0;
for (int i = 0; i < k2DBlockAmaxReduceDim; i++) {
amax_2d = fmaxf(amax_2d,
amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]);
}
amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d;
}
__syncthreads();
// every thread now knows 2D amax
amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x];
}
// Step 2.4: Compute scale
ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
// Step 2.5: Write scale_inv
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g < num_rows);
write_scale_inv &= (c_g < row_length);
}
if (write_scale_inv) {
size_t row_idx = block_idx_y * kTileDim + r_s;
size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) +
(threadIdx.x % kNumThreadsStore) / kNumThreadsReduce;
if constexpr (kSwizzledScale) {
size_t offset = scale_factor_swizzled_offset<ScaleType>(
row_idx, col_idx, DIVUP(row_length, kScaleBlockDim));
tile_scales_inv_c[offset] = scale_inv;
} else {
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) {
// Pack four FP32 values into two float2 for the 4x FP4 conversion
float2 f2_a;
float2 f2_b;
f2_a.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[0], encode_scale);
f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i].data.elt[1], encode_scale);
f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[0], encode_scale);
f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[i + 1].data.elt[1], encode_scale);
const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
// Convert to __nv_fp4x4_e2m1
__nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0];
output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1];
}
// Step 2.7: Store output_c
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory
// for not aligned case)
output_g += stride_g / kNFP4PerContainer;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
// Step 3: Transpose, cast and store to output_t
if constexpr (kReturnTranspose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory
int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory
size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Row in global memory
const size_t c_g = block_idx_y * kTileDim + r_s; // Column in global memory
const size_t stride_g =
static_cast<size_t>(c_stride) * kNVecSMem * num_rows; // Stride in global memory
const size_t num_ele = (c_g < num_rows ? min(static_cast<size_t>(kNVecOut / kNFP4PerContainer),
(num_rows - c_g) / kNFP4PerContainer)
: 0); // For not aligned case
OType* output_g =
&output_t[(r_g * num_rows + c_g) / kNFP4PerContainer]; // Output address in global memory
// Each column of the tile is handled by kNumThreadsStore consecutive threads of a warp; we need the
// lane id of the first thread in each reduction group (kNumThreadsReduce threads) to do the reduction.
const unsigned src_lane =
(threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut];
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
if constexpr (kIs2DBlockScaling) {
// TODO(zhongbo): 2D block scaling, directly read from amax_smem
int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7
constexpr int kNumColsPerWarp =
kThreadsPerWarp / kNumThreadsStore * kNVecSMem; // 8 elements
constexpr int kNumWarpsPerBlock =
kThreadsPerBlock / kThreadsPerWarp; // 8 warps per block
constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock;
int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp;
int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore;
int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x;
amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize];
} else {
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
}
}
// Step 3.3: Reduce amax
if constexpr (kIsE8Scaling) {
#pragma unroll
for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
}
// Step 3.4: Compute scale
ScaleType scale_inv = ComputeDecodeScaleFP4<ScaleType>(amax, global_encode_scale);
float encode_scale = ComputeEncodeScaleFP4<ScaleType>(scale_inv, global_decode_scale);
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g + smem_idx < row_length);
write_scale_inv &= (c_g < num_rows);
}
if (write_scale_inv) {
size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx;
size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) +
(threadIdx.x % kNumThreadsStore) / kNumThreadsReduce);
if constexpr (kSwizzledScale) {
size_t offset = scale_factor_swizzled_offset<ScaleType>(
row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim));
tile_scales_inv_t[offset] = scale_inv;
} else {
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) {
// Pack four FP32 values into two float2 for the 4x FP4 conversion
float2 f2_a;
float2 f2_b;
f2_a.x =
ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i].data.elt[smem_idx], encode_scale);
f2_a.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * i + 1].data.elt[smem_idx],
encode_scale);
f2_b.x = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1)].data.elt[smem_idx],
encode_scale);
f2_b.y = ComputeOutputFP4<IType, ScaleType>(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx],
encode_scale);
const uint32_t rbits =
kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0;
// Convert to __nv_fp4x4_e2m1
__nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x<kApplyStochasticRounding>(f2_a, f2_b, rbits);
output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0];
output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1];
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer);
} else {
if (r_g + smem_idx < row_length) {
output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0,
num_ele);
}
}
}
// Step 3.8: Update output address, column index of shared memory (and row index of global
// memory for not aligned case)
output_g += stride_g / kNFP4PerContainer;
c_s += c_stride;
if constexpr (!kAligned) {
r_g += c_stride * kNVecSMem;
}
}
}
}
} // namespace
} // namespace quantize_transpose_nvfp4
#endif // CUDA_VERSION >= 12080
namespace detail {
void quantize_transpose_vector_blockwise_fp4(
const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon,
const bool return_identity, const bool return_transpose, const bool pow2_scale,
const bool swizzled_scale, const bool use_stochastic_rounding,
const NVTETensor rng_state_tensor, const bool use_2d_quantization,
const SimpleTensor& noop_tensor, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4);
#if CUDA_VERSION >= 12080
// pow2_scale selects E8M0 (power-of-two) scaling, i.e. MXFP4-style scaling, which is not
// supported here yet; raise an error if it is requested.
NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now");
if (!return_identity && !return_transpose) {
return;
}
if (use_2d_quantization && !return_identity) {
return;
}
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
num_elements *= input.shape.at(i);
}
// Early return if the input tensor is empty
if (num_elements == 0) {
return;
}
size_t scale_stride_x = 0;
size_t scale_stride_y = 0;
if (return_identity) {
scale_stride_x = 1;
scale_stride_y = scale_inv.shape[1];
}
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (return_transpose) {
scale_t_stride_x = 1;
scale_t_stride_y = scale_inv_t.shape[1];
}
using namespace transformer_engine::quantize_transpose_nvfp4;
const size_t num_blocks_x = DIVUP(row_length, static_cast<size_t>(kTileDim));
const size_t num_blocks_y = DIVUP(num_rows, static_cast<size_t>(kTileDim));
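// Example (hypothetical shape): a 1024 x 2048 input gives num_blocks_x = 2048 / 128 = 16 and
// num_blocks_y = 1024 / 128 = 8, i.e. 128 thread blocks of kThreadsPerBlock = 256 threads,
// each quantizing one 128x128 tile.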
// noop tensor for cuda graph
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
const size_t* rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor& rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t*>(rng_state_te_tensor.data.dptr);
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
output.dtype, 2, OutputType,
dim3 grid(num_blocks_x, num_blocks_y, 1);
using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16;
constexpr bool kPow2Scale = false;
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_identity, kReturnIdentity,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
swizzled_scale, kSwizzledScale,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kApplyStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_2d_quantization, kIs2DBlockScaling,
size_t smem_bytes = kSMemSize * sizeof(InputType);
auto kernel = block_scaled_1d_cast_transpose_kernel<
kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned,
float, InputType, OutputType, ScaleType, kSwizzledScale,
kApplyStochasticRounding, kIs2DBlockScaling>;
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_bytes);
NVTE_CHECK(err == cudaSuccess,
"Failed to set dynamic shared memory size.");
} kernel<<<grid, kThreadsPerBlock, smem_bytes,
stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<const float*>(global_amax.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<ScaleType*>(scale_inv.dptr),
reinterpret_cast<ScaleType*>(scale_inv_t.dptr), row_length,
num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x,
scale_t_stride_y, kScaleBlockDim, epsilon, rng_state,
noop_ptr);) // kIs2DBlockScaling
) // kApplyStochasticRounding
) // kSwizzledScale
) // kAligned
) // kReturnTranspose
) // kReturnIdentity
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace detail
} // namespace transformer_engine
...@@ -58,7 +58,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const __grid_constant__ CUtensorMap tensor_map_output_act,
const __grid_constant__ CUtensorMap tensor_map_output_gate,
float *const amax_ptr, float *const scale_inv_ptr,
const float *const scale_ptr, const size_t rows, const size_t cols,
const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
...@@ -164,7 +165,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
...@@ -174,6 +174,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
}
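// Sketch of the clamped-SwiGLU math implemented here (GPT OSS), assuming p carries `limit` and
// `alpha` as used in the branches below: gate' = clamp(g, -limit, limit) + 1 and
// act' = u * sigmoid(alpha * u) with u = min(x, limit); the forward output is act' * gate', and
// dgate_elt records the clamp derivative (1 inside [-limit, limit], 0 outside) for the IS_DGATED path.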
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
...@@ -181,18 +187,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
if (act_elt <= p.limit) {
dact_x = s + s * (1 - s) * p.alpha * x;
} else {
dact_x = 0.0f;
}
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
float after_dact = dact_x * grad_elt * gate_elt;
float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
...@@ -200,7 +215,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
amax = fmaxf(amax, fabsf(after_dact));
amax = fmaxf(amax, fabsf(after_dgate));
} else {
const float after_act = ActOP(act_elt, p) * gate_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
amax = fmaxf(amax, fabsf(after_act));
}
...@@ -305,7 +320,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise,
e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise, const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
...@@ -481,25 +496,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float gate_elt = static_cast<float>(in_gate_sh[shmem_offset_colwise]);
float after_act_elt;
float after_gate_elt;
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
after_act_elt = dact_x * grad_elt * gate_elt;
after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
} else {
after_act_elt = ActOP(act_elt, p) * gate_elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
...@@ -603,6 +630,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_DGATED) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
...@@ -724,27 +752,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float gate_elt = static_cast<float>(in_gate.data.elt[e]);
float after_act_elt;
float after_gate_elt;
bool dgate_elt = true;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad.data.elt[e]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
after_act_elt = dact_x * grad_elt * gate_elt;
after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
after_act_rowwise[j] = after_act_elt;
after_gate_rowwise[j] = after_gate_elt;
} else {
after_act_elt = ActOP(act_elt, p) * gate_elt;
after_act_rowwise[j] = after_act_elt;
}
...@@ -833,6 +873,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
...@@ -889,7 +930,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
...@@ -956,6 +997,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
const size_t in_gate_mem = buff_size_aligned_in;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = buff_size_aligned_out;
const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
(out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
...@@ -966,8 +1008,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>
<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
#endif
...@@ -975,7 +1016,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
assert(false);
...@@ -1109,7 +1150,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::COLWISE:
...@@ -1126,7 +1167,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::BIDIMENSIONAL:
...@@ -1135,7 +1176,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
OType, true, true,
THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
true, true, THREADS_PER_CHUNK_NON_COLWISE>
<<<grid, block_size, shmem_size, stream>>>(
...@@ -1143,7 +1183,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}); // NOLINT(*)
...@@ -1152,12 +1192,9 @@
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(),
"Wrong output shape. Expected (after flattening) [", input.flat_first_dim(),
", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
"Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
...@@ -1179,7 +1216,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
output->flat_last_dim(), p, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
...@@ -1188,7 +1225,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { ...@@ -1188,7 +1225,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
template <typename ParamOP, float (*ActOP)(float, const ParamOP &), template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p,
cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input"); CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output"); CheckOutputTensor(*output, "dgated_act_output");
...@@ -1217,7 +1255,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt ...@@ -1217,7 +1255,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
reinterpret_cast<const fp32 *>(output->scale.dptr), reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr), reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(), reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
grad.flat_last_dim(), {}, stream); grad.flat_last_dim(), p, stream);
} else { } else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*) }); // NOLINT(*)
...@@ -1226,7 +1264,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt ...@@ -1226,7 +1264,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p,
cudaStream_t stream) { cudaStream_t stream) {
constexpr bool allow_empty = false; constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input"); CheckInputTensor(gated_input, "gated_input");
...@@ -1266,17 +1304,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu ...@@ -1266,17 +1304,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
if (is_delayed_tensor_scaling(output->scaling_mode)) { if (is_delayed_tensor_scaling(output->scaling_mode)) {
if (use_tma_kernels) { if (use_tma_kernels) {
cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream); cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else { } else {
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, stream); cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else { } else {
cast_gated<ParamOP, ActOP>(gated_input, output, stream); cast_gated<ParamOP, ActOP>(gated_input, output, p, stream);
} }
} }
} else if (is_mxfp_scaling(output->scaling_mode)) { } else if (is_mxfp8_scaling(output->scaling_mode)) {
if (use_tma_kernels) { if (use_tma_kernels) {
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream); cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else { } else {
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
"by 32, got input of shape ", gated_input.data.shape); "by 32, got input of shape ", gated_input.data.shape);
...@@ -1292,7 +1330,7 @@ namespace detail { ...@@ -1292,7 +1330,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &), template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)> float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
cudaStream_t stream) { ParamOP p, cudaStream_t stream) {
using namespace gated_kernels; using namespace gated_kernels;
Tensor grad_empty_tensor; Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
...@@ -1301,13 +1339,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, ...@@ -1301,13 +1339,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
if (is_supported_by_CC_100()) { if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
output_tensor, stream); output_tensor, p, stream);
} else { } else {
if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) {
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, stream); cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, p,
stream);
} else { } else {
cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, stream); cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, p, stream);
} }
} else { } else {
// MX scaling // MX scaling
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "../util/vectorized_pointwise.h" #include "../util/vectorized_pointwise.h"
#include "../utils.cuh" #include "../utils.cuh"
#include "math.h" #include "math.h"
#include "nvfp4_transpose.cuh"
#include "ptx.cuh" #include "ptx.cuh"
#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transformer_engine.h"
...@@ -110,6 +111,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -110,6 +111,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
// helps resolving bank conflicts in shmem // helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK; const int bank_group = thread_lane / THREADS_PER_BANK;
...@@ -137,8 +140,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -137,8 +140,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem); IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem); IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise); OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
...@@ -286,7 +290,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -286,7 +290,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float scaled_out = in * block_scale_inverse; const float scaled_out = in * block_scale_inverse;
const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out); out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
} }
} }
...@@ -410,10 +414,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -410,10 +414,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor // 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent = const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const size_t stage_scales_offset_X = scales_offset_X_rowwise; const int stage_scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
scales_rowwise[scale_idx] = biased_exponent; if (rowwise_scale_is_within_bounds) {
scales_rowwise[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
...@@ -441,7 +447,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -441,7 +447,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
} }
} }
...@@ -456,19 +462,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -456,19 +462,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory // Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) { if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y; const int global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X; const int global_offset_X = block_offset_X;
const size_t buff_offset = buff * BUFF_DIM; const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) { if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X, reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff_offset])); global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
} }
if constexpr (COLWISE_SCALING) { if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global( ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X, reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff_offset])); global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
} }
// Create a "bulk async-group" out of the previous bulk copy operation. // Create a "bulk async-group" out of the previous bulk copy operation.
...@@ -489,18 +495,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -489,18 +495,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Added extra 1-element padding per thread_X to reduce bank conflicts // Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem); float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const size_t shmem_thread_offset = const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll #pragma unroll
for (int w = 0; w < WAVES; ++w) { for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll #pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) { for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e; const int j = w * PACK_SIZE + e;
const size_t shmem_elt_idx = swizzled_group_offset + e; const int shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
} }
} }
...@@ -508,15 +514,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -508,15 +514,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll #pragma unroll
for (int i = 0; i < THREADS_Y; ++i) { for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32] // Add extra element offset per MXFP8 scaling block [1x32]
const size_t scaling_block = threadIdx.x / SCALE_DIM_X; const int scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias += thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
} }
} }
const size_t dbias_stride = cols; const int dbias_stride = cols;
const size_t dbias_offset_Y = blockIdx.y; const int dbias_offset_Y = blockIdx.y;
const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) { if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias; dbias_workspace[dbias_idx] = thread_partial_dbias;
...@@ -539,6 +545,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -539,6 +545,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#endif // __HIP_PLATFORM_AMD__ #endif // __HIP_PLATFORM_AMD__
} // namespace mxfp8_kernel } // namespace mxfp8_kernel
namespace nvfp4_kernel {
using namespace ptx;
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 16;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 16 = 256 / 16
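// Derivation (a sketch): 32 banks x 4 B x 8 bit = 1024 bits of bank width; 1024 / 4-bit elements = 256;
// each rowwise-scaling thread owns SCALE_DIM_X = 16 elements, hence 256 / 16 = 16 threads per full sweep.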
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
// return S_dec_b_fp8;
return static_cast<fp8e4m3>(block_amax * rcp_6f * S_enc);
}
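// Worked example of the two-stage NVFP4 scaling implemented here (values are illustrative):
//   S_enc        = 1 / S_dec_global, where S_dec_global comes from nvfp4_second_stage_scale_ptr
//   S_dec_b      = block_amax / 6               (6 = max magnitude representable in FP4 E2M1)
//   stored scale = fp8e4m3(S_dec_b * S_enc)     (one E4M3 value per 1x16 rowwise block)
//   encode scale = S_enc / stored_scale         (block_scale_inverse in the kernel below)
// E.g. with block_amax = 3.0 and S_enc = 1.0: stored scale = e4m3(0.5), encode scale = 2.0, so the
// largest element (3.0) maps to 6.0, the FP4 maximum. Dequantization applies the inverse, cf.
// dequantize_fp4_kernel below, where final_scale = scale * amax / (6 * 448).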
#define DIRECT_SCALING_FACTORS_STORE 1
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, typename OType, bool COLWISE_SCALING, size_t CHUNK_DIM_Y,
size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap tensor_map_output_colwise,
fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0,
const float *noop, float *const amax_ptr,
const float *const nvfp4_second_stage_scale_ptr, const size_t rows,
const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool ROWWISE_SCALING = true;
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE;
  static_assert(BUFF_DIM_Y >= SCALE_DIM_Y &&
                "Number of buffer rows must be greater than or equal to the size of the "
                "columnwise scaling block");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
  static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
                "Number of buffer rows must be greater than or equal to the number of rowwise "
                "processing threads in the Y dimension");
constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size
constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X;
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE;
// static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of
// // threads to process one row in a single iteration
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
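  // Geometry with the launch configuration used below (CHUNK 128x128, 128 threads per chunk):
  //   THREADS_X_ROWWISE = 128 / 16 = 8, THREADS_Y_ROWWISE = 128 / 8 = 16,
  //   STAGES = 128 / 32 = 4, ITERATIONS_ROWWISE = 32 / 16 = 2,
  //   BUFF_OUT_DIM_X = 128 * 4 / 8 = 64 bytes of packed FP4 per buffer row.
  // I.e. per stage each rowwise thread quantizes 2 rows of 16 elements, while each of the 128
  // threads owns one full 32-row column for the colwise (MXFP8) pass.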
const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * CHUNK_DIM_X;
const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int tid_Y_colwise = 0;
const int tid_X_colwise = threadIdx.x;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const int thread_offset_Y_colwise = tid_Y_colwise;
const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements
const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
const int col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols;
const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols;
// helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0);
constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0);
constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0);
constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0);
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise_data);
fp8e4m3 *out_rowwise_scales_sh =
reinterpret_cast<fp8e4m3 *>(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
e8m0_t *out_colwise_scales_sh = reinterpret_cast<e8m0_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
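  // Resulting dynamic shared-memory layout (data regions aligned to TMA_SHMEM_ALIGNMENT as computed
  // above; scale regions follow directly):
  //   [ input buffers: BUFFS_NUM x BUFF_DIM_Y x BUFF_IN_DIM_X of IType ]
  //   [ rowwise NVFP4 output data (two 4-bit values per byte)          ]
  //   [ colwise MXFP8 output data (OType), if COLWISE_SCALING          ]
  //   [ rowwise NVFP4 scales (fp8e4m3)                                 ]
  //   [ colwise MXFP8 scales (e8m0), if COLWISE_SCALING                ]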
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factor for all S_dec_b
const float S_enc =
(nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr);
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int buff = stage % BUFFS_NUM;
const int next_stage = stage + 1;
const int stage_offset_Y = stage * BUFF_DIM_Y;
const int buff_offset_in = buff * BUFF_IN_DIM;
const int buff_offset_out = buff * BUFF_OUT_DIM;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const int next_buff = next_stage % BUFFS_NUM;
const int next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const int global_offset_Y = block_offset_Y + next_stage_offset_Y;
const int global_offset_X = block_offset_X;
const int next_buff_offset = next_buff * BUFF_IN_DIM;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM_Y];
IType in_colwise_IType[SCALE_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
          // Without an activation, out-of-bounds elements read as zero, so they cannot affect the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(block_amax * Quantized_Limits<OType>::max_norm_rcp);
const int global_scales_offset_Y = scales_offset_Y_colwise + stage;
const int global_scales_offset_X = scales_offset_X_colwise;
const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
if (colwise_scale_is_within_bounds) {
scales_colwise_e8m0[scale_idx] = biased_exponent;
}
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (int it = 0; it < ITERATIONS_ROWWISE; ++it) {
const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const int shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const int shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
            // Since the TMA data-alignment requirement is 16 B (i.e. cols % 8 == 0 for BF16 elements),
            // a single check along the column direction is sufficient for the entire wave to be within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
              // Without an activation, out-of-bounds elements read as zero, so they cannot affect the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc);
#if DIRECT_SCALING_FACTORS_STORE
// Check boundaries
if (rowwise_scale_is_within_bounds) {
const int scales_offset_Y =
scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
const int scales_offset_X = scales_offset_X_rowwise;
const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X;
scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8;
}
#else
const int shmem_scales_offset_Y =
stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise;
const int shmem_scales_offset_X = tid_X_rowwise;
const int scale_idx =
shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X;
out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8;
#endif
// Compute "correct" per-block encoding scaling factor
const float block_scale_inverse =
__fdiv_rn(S_enc, static_cast<float>(S_dec_b_fp8)); // S_enc_b_fp8
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
        Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
in01 = in_IType[w].data.elt[2 * e];
in23 = in_IType[w].data.elt[2 * e + 1];
} else if constexpr (IS_CACHED_ACT_OP) {
in01.x = in_cached[w].data.elt[4 * e];
in01.y = in_cached[w].data.elt[4 * e + 1];
in23.x = in_cached[w].data.elt[4 * e + 2];
in23.y = in_cached[w].data.elt[4 * e + 3];
} else {
const int j = w * PACK_SIZE + 4 * e;
in01.x = in_compute_rowwise[j];
in01.y = in_compute_rowwise[j + 1];
in23.x = in_compute_rowwise[j + 2];
in23.y = in_compute_rowwise[j + 3];
}
fp4e2m1x4 &out_quad = reinterpret_cast<fp4e2m1x4 &>(out.data.elt[e]);
ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
__builtin_assume(block_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM;
const int buff_offset_mxfp8 = buff * BUFF_IN_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset_nvfp4]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset_mxfp8]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
#if !DIRECT_SCALING_FACTORS_STORE
// Vectorized store of scaling factors.
// Each thread stores multiple scaling factors in one store instruction.
if constexpr (ROWWISE_SCALING) {
// Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X
const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x;
const int scales_offset_X_rowwise = scales_block_offset_X_rowwise;
const int scale_idx_global =
scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise;
const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW;
if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) &&
(scales_offset_X_rowwise < (cols / SCALE_DIM_X))) {
using ScalesVec_t = Vec<fp8e4m3, NVFP4_SCALING_FACTORS_PER_CHUNK_ROW>;
const ScalesVec_t &scales =
*reinterpret_cast<ScalesVec_t *>(&out_rowwise_scales_sh[scale_idx_shmem]);
scales.store_to(&scales_rowwise_e4m3[scale_idx_global]);
}
}
#endif
float chunk_amax = 0.0f;
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
chunk_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(thread_amax, warp_id);
}
if (is_master_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, chunk_amax);
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_kernel
constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128; constexpr size_t FP8_CHUNK_DIM_X = 128;
constexpr size_t FP8_THREADS_PER_CHUNK = 128; constexpr size_t FP8_THREADS_PER_CHUNK = 128;
...@@ -903,7 +1431,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, ...@@ -903,7 +1431,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
} }
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)> template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) {
const size_t N = product(input.data.shape); const size_t N = product(input.data.shape);
const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); const bool isFullTile = (N % ELEMS_PER_BLOCK == 0);
...@@ -1192,6 +1720,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1192,6 +1720,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
#endif #endif
} }
// This kernel supports only two scaling cases:
// 1. r16c0 - Rowwise NVFP4
// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8
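// (Naming note: r16/c32 presumably refer to the scaling-block shapes — 1x16 rowwise NVFP4 blocks and
// 32x1 colwise MXFP8 blocks, i.e. SCALE_DIM_X/SCALE_DIM_Y in nvfp4_kernel; c0 means no colwise output.)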
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &)>
void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
using namespace nvfp4_kernel;
using namespace ptx;
checkCuDriverContext(stream);
NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
bool use_colwise_scaling = output->has_columnwise_data();
if (use_colwise_scaling) {
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Columnwise scaling tensor must be allocated");
}
CheckNoopTensor(*noop, "cast_noop");
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
const size_t scale_stride_rowwise = output->scale_inv.shape[1];
const size_t scale_stride_colwise =
use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;
fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast<fp8e4m3 *>(output->scale_inv.dptr);
e8m0_t *const scales_colwise_e8m0_ptr =
use_colwise_scaling ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;
const ScalingType scaling_type =
use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const nvfp4_second_stage_scale_ptr =
reinterpret_cast<const float *>(output->scale.dptr);
// Output data type is only required for the column-wise MXFP8 scaling.
  // It has no effect on the row-wise NVFP4 scaling, but defaults to E4M3 so the type-switch macros below can be instantiated
const DType output_data_type =
use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols,
nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4);
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
}
constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_nvfp4 =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out_mxfp8 =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_nvfp4_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
constexpr size_t buff_size_mxfp8_scales =
(CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);
constexpr size_t in_mem = buff_size_aligned_in;
const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;
const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;
const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
out_rowwise_scales_mem + out_colwise_scales_mem +
TMA_SHMEM_ALIGNMENT;
const size_t dshmem_size = in_mem + out_mem;
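      // Shared-memory budget, worked for the bidimensional case with BF16 input and the constants above
      // (a sketch, assuming TMA_SHMEM_ALIGNMENT = 128 B):
      //   in_mem                 = 2 * 32 * 128 * 2 B = 16384 B
      //   out_rowwise_data_mem   = 8192 / 2           =  4096 B
      //   out_colwise_data_mem   = 8192 * 1 B         =  8192 B
      //   out_rowwise_scales_mem = 128 * 128 / 16     =  1024 B
      //   out_colwise_scales_mem = 128 * 128 / 32     =   512 B
      //   dshmem_size            = 16384 + 4096 + 8192 + 1024 + 512 + 128 = 30336 B (~29.6 KiB)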
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, false, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_nvfp4_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType, OType, true, CHUNK_DIM_Y,
CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}); // NOLINT(*)
); // NOLINT(*)
}
namespace detail { namespace detail {
using Empty = transformer_engine::Empty; using Empty = transformer_engine::Empty;
...@@ -1417,13 +2080,26 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1417,13 +2080,26 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
auto dbias_tensor = convertNVTETensor(dbias); auto dbias_tensor = convertNVTETensor(dbias);
auto workspace_tensor = convertNVTETensor(workspace); auto workspace_tensor = convertNVTETensor(workspace);
const QuantizationConfig *quant_config_cpp = // Quantization config
reinterpret_cast<const QuantizationConfig *>(quant_config); QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// extract noop tensor from quant_config_cpp if it's not null // Noop flag
const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; Tensor dummy_tensor;
const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
// Dispatch to quantization kernel depending on data format
switch (output_tensor->scaling_mode) { switch (output_tensor->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: { case NVTE_DELAYED_TENSOR_SCALING: {
if (output_tensor->has_columnwise_data()) { if (output_tensor->has_columnwise_data()) {
...@@ -1435,7 +2111,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1435,7 +2111,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
NVTE_CHECK(output_tensor->has_data(), NVTE_CHECK(output_tensor->has_data(),
"Quantizing in only the columnwise direction not supported yet!"); "Quantizing in only the columnwise direction not supported yet!");
if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) {
cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
} else { } else {
cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>( cast_transpose_fused<IS_DBIAS, IS_DACT, IS_ACT, float, ParamOP, OP>(
*input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor,
...@@ -1443,51 +2119,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1443,51 +2119,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
} }
} else if (output_tensor->has_data()) { } else if (output_tensor->has_data()) {
fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>( fp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
*input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream); workspace_tensor, stream);
} }
break; break;
} }
case NVTE_MXFP8_1D_SCALING: { case NVTE_MXFP8_1D_SCALING: {
mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>( mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(
*input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream); workspace_tensor, stream);
break; break;
} }
case NVTE_NVFP4_1D_SCALING: {
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
CheckOutputTensor(*output_tensor, "output", false);
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 &&
output_tensor->has_data();
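      // The fused TMA kernel path requires BF16 input, both dimensions divisible by 32, and a
      // rowwise output; anything else falls through to the generic blockwise FP4 kernel below,
      // which also covers stochastic rounding and 2D quantization.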
// Launch NVFP4 quantize kernel
if (use_optimized_kernel) {
if (quant_config_cpp.nvfp4_2d_quantization) {
nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, true>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
} else {
nvfp4_quantize_transpose<IS_ACT, ParamOP, OP, false>(
*input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
}
} else {
auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
: output_tensor->columnwise_amax;
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for "
"2D quantization");
quantize_transpose_vector_blockwise_fp4(
/*input=*/input_tensor->data, /*global_amax=*/global_amax,
/*scale_inv=*/output_tensor->scale_inv,
/*scale_inv_t=*/output_tensor->columnwise_scale_inv,
/*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
/*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
/*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
/*swizzled_scale=*/false,
/*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
/*rng_state=*/quant_config_cpp.rng_state,
/*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
/*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
}
break;
}
case NVTE_BLOCK_SCALING_2D: { case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; float epsilon = quant_config_cpp.amax_epsilon;
quantize_transpose_square_blockwise( quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor.data, stream); /*noop_tensor=*/noop_tensor->data, stream);
break; break;
} }
case NVTE_BLOCK_SCALING_1D: { case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; float epsilon = quant_config_cpp.amax_epsilon;
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) { if (output_tensor->has_data()) {
bool rowwise_compact = quant_config_cpp bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
? quant_config_cpp->float8_block_scale_tensor_format == Float8BlockScaleTensorFormat::COMPACT);
Float8BlockScaleTensorFormat::COMPACT
: false;
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
} }
if (output_tensor->has_columnwise_data()) { if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = quant_config_cpp bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
? quant_config_cpp->float8_block_scale_tensor_format == Float8BlockScaleTensorFormat::COMPACT);
Float8BlockScaleTensorFormat::COMPACT
: false;
columnwise_option = columnwise_compact columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
...@@ -1495,7 +2210,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1495,7 +2210,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
quantize_transpose_vector_blockwise( quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor.data, stream); columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
break; break;
} }
default: default:
......
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
#include <transformer_engine/cast.h> #include <transformer_engine/cast.h>
#include <cfloat> #include <cfloat>
#include <cstddef>
#include <cstdint>
#include <limits> #include <limits>
#include "../common.h" #include "../common.h"
...@@ -28,6 +30,7 @@ ...@@ -28,6 +30,7 @@
#include "math.h" #include "math.h"
#include "ptx.cuh" #include "ptx.cuh"
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "transformer_engine/transformer_engine.h"
#include "transformer_engine/transpose.h" #include "transformer_engine/transpose.h"
namespace transformer_engine { namespace transformer_engine {
...@@ -339,6 +342,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ...@@ -339,6 +342,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
#endif #endif
} }
#if CUDA_VERSION >= 12080
template <typename OType>
__global__ void __launch_bounds__(512)
dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales,
const float *const tensor_amax, const size_t N, const size_t M,
const size_t scale_stride) {
  const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Guard the tail: the grid is rounded up with DIVUP, so the last block may have extra threads
  if (thread_idx >= N * M) {
    return;
  }
  const size_t x = thread_idx % M;
  const size_t y = thread_idx / M;
union fp4vec {
uint64_t vec;
fp4e2m1x4 small_vec[4];
};
using OVec = Vec<OType, 4>;
const uint64_t *const input_vectorized = reinterpret_cast<const uint64_t *>(input);
OVec *output_vec = reinterpret_cast<OVec *>(output);
const size_t my_index = x + y * M;
const size_t my_scale_index = x + y * scale_stride;
const size_t my_output_index = (x + y * M) * 4;
fp4vec value;
value.vec = input_vectorized[my_index];
fp8e4m3 scale = scales[my_scale_index];
float amax = *tensor_amax;
constexpr float factor_inv = 1.0 / (6.0 * 448.0);
float final_scale = static_cast<float>(scale) * amax * factor_inv;
#pragma unroll
for (int i = 0; i < 4; i++) {
float4 current = static_cast<float4>(value.small_vec[i]);
OVec out;
out.data.elt[0] = static_cast<OType>(current.x * final_scale);
out.data.elt[1] = static_cast<OType>(current.y * final_scale);
out.data.elt[2] = static_cast<OType>(current.z * final_scale);
out.data.elt[3] = static_cast<OType>(current.w * final_scale);
output_vec[my_output_index + i] = out;
}
}
#endif // CUDA_VERSION
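// Dequantization math, as implemented above (see final_scale):
//   out = fp4_value * scale_e4m3 * amax / (6 * 448)
// where 6 is the FP4 E2M1 max magnitude and 448 the FP8 E4M3 max, i.e. the per-block E4M3 scale is
// interpreted relative to the tensor-wide amax. Each uint64 load unpacks 16 FP4 values
// (4 x fp4e2m1x4), which is why the last dimension must be divisible by FP4_BLOCK_SIZE below.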
void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output");
NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type.");
NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
constexpr int FP4_BLOCK_SIZE = 16;
const size_t N = input.flat_first_dim();
const size_t M = input.flat_last_dim();
NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ",
FP4_BLOCK_SIZE, ", but got ", input.data.shape, ".");
const size_t Mread = M / FP4_BLOCK_SIZE;
const size_t total = N * Mread;
const size_t threads = 512;
const size_t blocks = DIVUP(total, threads);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->data.dtype, OType,
dequantize_fp4_kernel<<<blocks, threads, 0, stream>>>(
input.data.dptr, reinterpret_cast<OType *>(output->data.dptr),
reinterpret_cast<fp8e4m3 *>(input.scale_inv.dptr),
reinterpret_cast<float *>(input.amax.dptr), N, Mread,
input.scale_inv.shape.back());); // NOLINT(*)
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!");
#endif // CUDA_VERSION >= 12080
}
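// Illustrative only: a small numeric walkthrough of the scaling above. Each FP4
// code x is stored with a per-block E4M3 scale s and the tensor amax A, and is
// reconstructed as x * s * A / (6.0 * 448.0), the inverse of the NVFP4 encode
// scale S_enc = (448 * 6) / A used on the quantize side. For example, with
// A = 12.0 and s = 112, final_scale = 112 * 12 / 2688 = 0.5, so a stored code
// x = 4.0 dequantizes to 2.0.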
} // namespace dequantization } // namespace dequantization
namespace detail { namespace detail {
...@@ -347,17 +425,25 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) ...@@ -347,17 +425,25 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
CheckInputTensor(input, "cast_input"); CheckInputTensor(input, "cast_input");
CheckOutputTensor(*output, "cast_output"); CheckOutputTensor(*output, "cast_output");
if (is_tensor_scaling(input.scaling_mode)) { switch (input.scaling_mode) {
dequantization::fp8_dequantize(input, output, stream); case NVTE_DELAYED_TENSOR_SCALING: {
} else if (is_mxfp_scaling(input.scaling_mode)) { dequantization::fp8_dequantize(input, output, stream);
if (is_supported_by_CC_100()) { break;
dequantization::mxfp8_dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
} }
} else { case NVTE_MXFP8_1D_SCALING: {
// TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING if (is_supported_by_CC_100()) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); dequantization::mxfp8_dequantize(input, output, stream);
} else {
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
dequantization::fp4_dequantize(input, output, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
} }
} }
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#endif // __HIP_PLATFORM_AMD__ #endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h> #include <nvrtc.h>
#include "nccl.h"
#ifdef NVTE_WITH_CUBLASMP #ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h> #include <cublasmp.h>
#endif // NVTE_WITH_CUBLASMP #endif // NVTE_WITH_CUBLASMP
...@@ -147,4 +149,12 @@ ...@@ -147,4 +149,12 @@
#endif // NVTE_WITH_CUBLASMP #endif // NVTE_WITH_CUBLASMP
#define NVTE_CHECK_NCCL(expr) \
do { \
const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \
if (status_NVTE_CHECK_NCCL != ncclSuccess) { \
NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \
} \
} while (false)
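// Usage sketch (hypothetical call site, not part of this header): wrap any NCCL
// call so that a non-success result raises an NVTE error carrying the NCCL message,
// e.g.
//   int n = 0;
//   NVTE_CHECK_NCCL(ncclCommCount(comm, &n));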
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
...@@ -11,6 +11,11 @@ namespace transformer_engine { ...@@ -11,6 +11,11 @@ namespace transformer_engine {
struct Empty {}; struct Empty {};
struct ClampedSwiGLUParam {
float limit;
float alpha = 1.702f; // Default value for QuickGELU
};
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType gelu(const IType val, const Empty&) { __device__ inline OType gelu(const IType val, const Empty&) {
const float cval = val; const float cval = val;
...@@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) { ...@@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
return s * (1.f - s); return s * (1.f - s);
} }
template <typename OType, typename IType>
__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) {
const float cval = val;
Empty e = {};
return cval * sigmoid<float, float>(alpha * cval, e);
}
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) { __device__ inline OType qgelu(const IType val, const Empty& e) {
return qgelu_with_alpha<OType, IType>(val, 1.702f);
}
template <typename OType, typename IType>
__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
const float cval = val; const float cval = val;
return cval * sigmoid<float, float>(1.702f * cval, e); Empty e = {};
return alpha * cval * dsigmoid<float, float>(alpha * cval, e) +
sigmoid<float, float>(alpha * cval, e);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) { __device__ inline OType dqgelu(const IType val, const Empty& e) {
const float cval = val; return dqgelu_with_alpha<OType, IType>(val, 1.702f);
return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
sigmoid<float, float>(1.702f * cval, e);
} }
template <typename OType, typename IType> template <typename OType, typename IType>
...@@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) { ...@@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) {
return cval * sigmoid<float, float>(cval, e); return cval * sigmoid<float, float>(cval, e);
} }
template <typename OType, typename IType>
__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) {
const float cval = min(p.limit, static_cast<float>(val)); // Clamping
return qgelu_with_alpha<OType, float>(cval, p.alpha);
}
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType dsilu(const IType val, const Empty& e) { __device__ inline OType dsilu(const IType val, const Empty& e) {
const float cval = val; const float cval = val;
return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e); return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
} }
template <typename OType, typename IType>
__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) {
const bool dclamp_val = static_cast<float>(val) <= p.limit;
const float clamp_val = min(static_cast<float>(val), p.limit);
const float dsilu_val = dqgelu_with_alpha<OType, float>(clamp_val, p.alpha);
return dclamp_val ? dsilu_val : 0.0f;
}
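// For reference (restating the two helpers above, not a separate implementation):
// with c = min(x, limit),
//   clamped_silu(x)  = c * sigmoid(alpha * c)
//   clamped_dsilu(x) = alpha * c * dsigmoid(alpha * c) + sigmoid(alpha * c)   if x <= limit
//                    = 0                                                      otherwise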
template <typename OType, typename IType> template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty&) { __device__ inline OType relu(IType value, const Empty&) {
return fmaxf(value, 0.f); return fmaxf(value, 0.f);
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file nvfp4_transpose.cuh
* \brief CUDA kernels to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#if CUDA_VERSION > 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION > 12080
#include <cfloat>
#include "../common.h"
#include "../utils.cuh"
#include "curanddx.hpp"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
#if CUDA_VERSION > 12080
namespace nvfp4_transpose {
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
using namespace ptx;
using nvfp4_scale_t = fp8e4m3;
constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts)
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_NUM = 128;
constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM;
constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM;
constexpr size_t RNG_GENS_PER_THREAD =
SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers
constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;
// Should this be SCALE_DIM or BLOCK_DIM? Both are 16, so it works for both 1D and 2D
constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8
constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X;
constexpr size_t STAGES = TILES_Y * TILES_X;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = TILE_DIM_Y;
constexpr size_t BUFF_DIM_X = TILE_DIM_X;
constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;
// Input buffer (BF16)
constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
// Output buffer (NVFP4)
constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;
// Output transpose buffer (NVFP4)
constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X;
constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X;
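// Note on the (x * 4) / 8 factors above: each NVFP4 element occupies 4 bits, so a
// row of x elements packs into x / 2 bytes, and the output buffers are indexed in
// fp4e2m1x2 (byte) units.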
// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM / PACK_SIZE;
constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8
constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16
constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE;  // 32 / 16 = 2
constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM;
constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE;
static_assert(BUFF_DIM_Y >= SCALE_DIM,
              "Number of buffer rows must be greater or equal to the size of the columnwise "
              "scaling block");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE,
              "Number of buffer rows must be greater or equal to the number of rowwise "
              "processing threads in Y dimension");
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
// constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
// return S_dec_b_fp8;
// NOTE: Dividing by 6.0f is neither elegant nor efficient, but it is kept so the
// result matches the emulation code exactly.
using namespace detail;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
const float S_dec_b = block_amax / fp4_max * S_enc;
return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}
// Compute the global encode scale factor for a given global amax
__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
using namespace detail;
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
}
return global_encode_scale;
}
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
} else {
return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
}
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits);
} else {
return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
}
}
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
// Index into the 4x random words; incremented on each use and refilled/reset once it reaches 4
int rnd_idx = 0;
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
// const size_t tid_X_t = 0;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t row_base_colwise = block_offset_Y;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
// Helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
// NOTE: This matches how the emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise =
(row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// Without activations, out-of-bounds elements are zero (TMA zero-fills them), so they cannot affect the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
// Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
// Helps reduce bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since the TMA data-alignment requirement is 16B (i.e. cols % 8 == 0 for BF16 elements),
// a single check (w.r.t. the column direction) is sufficient to ensure the entire wave is within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
// Without activations, out-of-bounds elements are zero (TMA zero-fills them), so they cannot affect the amax
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// Vectorized store scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
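// The 2D variant below differs from the kernel above mainly in how block amaxes are
// obtained: each 16x16 block's amax is computed once per tile via a warp reduction
// into block_amax_matrix in shared memory, and both the rowwise and the columnwise
// (transpose) passes then reuse those shared amaxes instead of recomputing a
// per-16-element amax on the fly.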
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
// Index into the 4x random words; incremented on each use and refilled/reset once it reaches 4
int rnd_idx = 0;
// NEW: 2D Block-based scaling constants
constexpr size_t BLOCK_DIM = 16;
constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2
constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8
constexpr size_t ITERATIONS_BLOCK = 2;  // Iterations to compute the 2D block amaxes of one tile
constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
// Helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) Does not guarantee the pointer to be aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
// Compute a global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
// NOTE: This matches how the emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
const size_t warp_id = threadIdx.x / 32;
const size_t lane_id = threadIdx.x % 32;
float thread_amax = 0.0f;
const size_t block_in_warp = lane_id / BLOCKS_PER_WARP;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
__shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1];
// Helper function for warp reduction
auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float {
#pragma unroll
for (int delta = 8; delta >= 1; delta /= 2) {
float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta);
thread_amax = fmaxf(thread_amax, other_amax);
}
return thread_amax;
};
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
#pragma unroll
for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
const size_t block_in_tile_y = block_iter;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
for (int elem = 0; elem < BLOCK_DIM; elem += 2) {
const size_t elem_0_row = block_iter * BLOCK_DIM + elem;
const size_t elem_1_row = elem_0_row + 1;
const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
const size_t elem_1_col = elem_0_col;
const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col;
const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col;
IType2 val_2x;
val_2x.x = in_sh[shmem_offset_0];
val_2x.y = in_sh[shmem_offset_1];
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x);
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else {
for (int elem = 0; elem < BLOCK_DIM; ++elem) {
const size_t elem_row = block_iter * BLOCK_DIM + elem;
const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
// Bounds checking
const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows);
const bool col_out_of_bounds = (block_offset_X + elem_col >= cols);
if (!row_out_of_bounds && !col_out_of_bounds) {
const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col;
float elt = static_cast<float>(in_sh[shmem_offset]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset] = static_cast<IType>(elt);
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
}
}
// Warp reduction to get block amax
block_amax = warp_reduce_amax(thread_amax, block_in_warp);
if (lane_id == 0 || lane_id == 16) {
block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax;
}
}
// sync thread to ensure block_amax_matrix is done storing
__syncthreads();
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
// 1. Load elements (the block amax was already computed above)
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
}
} else {
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
// Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
// Helps reduce bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = tid_X_rowwise;
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
}
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// Vectorized store scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_transpose
#endif // CUDA_VERSION > 12080
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
bool use_2d_quantization>
void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
const QuantizationConfig *quant_config, cudaStream_t stream) {
#if CUDA_VERSION > 12080
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
// If the transposed output is allocated, return the transposed data. Otherwise, it is not
// necessary to return the transposed data.
// TODO(Frank): Is there a better way to do this?
bool return_transpose = output->has_columnwise_data();
using namespace nvfp4_transpose;
using namespace ptx;
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", false);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if (return_transpose) {
NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
"Transposed output must have FP4 type.");
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Transposed scaling tensor must be allocated");
}
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
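// The 32-element multiple checks below keep TMA transfers 16-byte aligned: 32 elements at
// 4 bits each occupy exactly 16 bytes in the packed FP4 outputs.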
NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
NVTE_CHECK(cols % 32 == 0,
"Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_NUM;
const size_t scale_stride = output->scale_inv.shape[1];
const size_t scale_stride_transpose =
return_transpose ? output->columnwise_scale_inv.shape[1] : 0;
nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
nvfp4_scale_t *const scales_transpose_ptr =
reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
const float *const amax_colwise_ptr =
reinterpret_cast<const float *>(output->columnwise_amax.dptr);
const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
const size_t *rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
}
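// The RNG state holds two 64-bit values, commonly a seed and an offset for the counter-based
// generator driving stochastic rounding (interpretation of the [2]-shaped int64 check above).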
using IType = bf16;
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
alignas(64) CUtensorMap tensor_map_output_transpose{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
4);
if (return_transpose) {
create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_data_mem = buff_size_aligned_out;
constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
constexpr size_t out_scales_transpose_mem = buff_size_scales;
constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;
constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;
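// Dynamic shared-memory budget (sketch): multi-buffered input staging in IType, the packed
// 4-bit rowwise output, the packed 4-bit transposed output, the transposed E4M3 scales, plus
// one TMA_SHMEM_ALIGNMENT of slack so each region can be placed on a TMA-aligned boundary.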
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
auto kernel = nvfp4_transpose_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
if constexpr (use_2d_quantization) {
kernel = nvfp4_transpose_kernel_2D<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
}
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
scale_stride, scale_stride_transpose, rng_state);
}););
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION > 12080
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION >= 12080
namespace transformer_engine { namespace transformer_engine {
namespace ptx { namespace ptx {
...@@ -125,9 +129,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { ...@@ -125,9 +129,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
return __int_as_float(biased_exp << FP32_MANTISSA_BITS); return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
} }
#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \
((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM103_ALL)))
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ #if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
uint16_t out; uint16_t out;
asm volatile( asm volatile(
"{\n" "{\n"
...@@ -230,18 +238,86 @@ struct alignas(2 * sizeof(T)) FPx2 { ...@@ -230,18 +238,86 @@ struct alignas(2 * sizeof(T)) FPx2 {
T y; T y;
}; };
template <typename T>
struct FPx4 {
T x1;
T x2;
T x3;
T x4;
};
template <typename T>
struct Type2x {};
template <>
struct Type2x<float> {
using type = float2;
};
template <>
struct Type2x<bf16> {
using type = __nv_bfloat162;
};
template <>
struct Type2x<fp16> {
using type = __half2;
};
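// Type2x maps a scalar type to its native 2-wide CUDA vector type
// (float -> float2, bf16 -> __nv_bfloat162, fp16 -> __half2), e.g.
//   typename Type2x<bf16>::type pair;  // __nv_bfloat162 holding two packed bf16 values
// which lets paired loads and SIMD-style conversions be written generically.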
using floatx2 = FPx2<float>; using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>; using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>; using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>; using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>; using fp8e5m2x2 = FPx2<fp8e5m2>;
using floatx4 = FPx4<float>;
using bf16x4 = FPx4<bf16>;
using fp16x4 = FPx4<fp16>;
using fp8e4m3x4 = FPx4<fp8e4m3>;
using fp8e5m2x4 = FPx4<fp8e5m2>;
static_assert(sizeof(floatx2) == 8); static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4); static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4); static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2);
#if CUDA_VERSION >= 12080
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
static_assert(sizeof(fp4e2m1x2) == 1);
static_assert(sizeof(fp4e2m1x4) == 2);
#endif // CUDA_VERSION >= 12080
// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 was introduced in PTX ISA version 8.6.
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on the following architectures:
// sm_100a
// sm_101a
// sm_120a
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// and the converted values are packed in the destination operand d such that the value
// converted from input a is stored in the upper 4 bits of d and the value converted
// from input b is stored in the lower 4 bits of d.
// SIMD like "Fused" cast + multiplication (x4)
#if CUDA_VERSION >= 12080
template <typename Tx2>
__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23,
const float scale) {
const float x0 = static_cast<float>(in01.x) * scale;
const float x1 = static_cast<float>(in01.y) * scale;
const float x2 = static_cast<float>(in23.x) * scale;
const float x3 = static_cast<float>(in23.y) * scale;
out = fp4e2m1x4(make_float4(x0, x1, x2, x3));
}
#endif // CUDA_VERSION >= 12080
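// Illustrative use of mul_cvt_4x (a sketch; variable names are hypothetical): scale four
// consecutive bf16 elements by one FP32 inverse scale and pack them into two bytes of E2M1:
//   bf16x2 lo = {x0, x1}, hi = {x2, x3};
//   fp4e2m1x4 packed;
//   mul_cvt_4x(packed, lo, hi, block_scale_inverse);
// The conversion goes through FP32 and the __nv_fp4x4_e2m1(float4) constructor rather than
// emitting the cvt.rn.satfinite.e2m1x2 instruction directly.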
// SIMD like "Fused" cast + multiplication (x2) // SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
const floatx2 &scale) { const floatx2 &scale) {
...@@ -377,7 +453,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const ...@@ -377,7 +453,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
"r"(reinterpret_cast<const uint32_t &>(p2))); "r"(reinterpret_cast<const uint32_t &>(p2)));
} }
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx } // namespace ptx
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
.value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \
.value("kInt8", transformer_engine::DType::kInt8); \ .value("kInt8", transformer_engine::DType::kInt8); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \ pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
...@@ -41,6 +42,10 @@ ...@@ -41,6 +42,10 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local()) \
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \ pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
#include "../common.h" #include "../common.h"
#include "../utils.cuh" #include "../utils.cuh"
#include "math.h"
namespace transformer_engine { namespace transformer_engine {
/* \brief Helper class that enables storing multiple values of type DType /* \brief Helper class that enables storing multiple values of type DType
...@@ -345,7 +345,7 @@ template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typen ...@@ -345,7 +345,7 @@ template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typen
typename OutputType> typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output,
const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N,
const Param params, cudaStream_t stream) { const Param &params, cudaStream_t stream) {
if (N != 0) { if (N != 0) {
auto align = CheckAlignment(N, nvec, input, output); auto align = CheckAlignment(N, nvec, input, output);
...@@ -379,7 +379,7 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename In ...@@ -379,7 +379,7 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename In
typename InputTypeGrad, typename OutputType> typename InputTypeGrad, typename OutputType>
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
OutputType *output, const fp32 *scale, fp32 *amax, OutputType *output, const fp32 *scale, fp32 *amax,
fp32 *scale_inv, const size_t N, const Param params, fp32 *scale_inv, const size_t N, const Param &params,
cudaStream_t stream) { cudaStream_t stream) {
if (N != 0) { if (N != 0) {
auto align = CheckAlignment(N, nvec, input, grad, output); auto align = CheckAlignment(N, nvec, input, grad, output);
...@@ -438,7 +438,13 @@ __launch_bounds__(unary_kernel_threads) __global__ ...@@ -438,7 +438,13 @@ __launch_bounds__(unary_kernel_threads) __global__
#pragma unroll #pragma unroll
for (int i = 0; i < nvec; ++i) { for (int i = 0; i < nvec; ++i) {
const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]); const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]); ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// Clamp the gated value and add 1 at the end
ComputeType limit = p.limit;
val2 = std::min(std::max(-limit, val2), limit) + 1;
}
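// In effect (forward sketch of the clamped SwiGLU variant): temp = Activation(val, p) *
// (clamp(val2, -limit, limit) + 1), i.e. only the gate input is clamped and shifted by one.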
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2); ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
if (requires_amax) { if (requires_amax) {
__builtin_assume(max >= 0); __builtin_assume(max >= 0);
...@@ -539,10 +545,18 @@ __launch_bounds__(unary_kernel_threads) __global__ ...@@ -539,10 +545,18 @@ __launch_bounds__(unary_kernel_threads) __global__
for (int i = 0; i < nvec; ++i) { for (int i = 0; i < nvec; ++i) {
const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]); const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]);
const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]); const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]); ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
bool dgate_in = true;
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
const ComputeType limit = p.limit;
dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp
gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
}
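// Backward sketch: with g = clamp(gate_in, -limit, limit) + 1, dg/d(gate_in) is 1 inside
// [-limit, limit] and 0 where the clamp saturates, which is exactly what dgate_in records;
// after_dgate is therefore zeroed for saturated gate inputs.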
ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
ComputeType after_dgate = grad_val * Activation(gelu_in, p); ComputeType after_dgate = dgate_in ? grad_val * Activation(gelu_in, p) : 0.0f;
if (requires_amax) { if (requires_amax) {
__builtin_assume(max >= 0); __builtin_assume(max >= 0);
......
...@@ -49,6 +49,26 @@ constexpr uint32_t THREADS_PER_WARP = 32; ...@@ -49,6 +49,26 @@ constexpr uint32_t THREADS_PER_WARP = 32;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Device-side error
#define NVTE_DEVICE_ERROR(message) \
do { \
printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
__func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \
(message)); \
assert(0); \
} while (false)
// Device-side error on thread 0
#define NVTE_DEVICE_THREAD0_ERROR(message) \
do { \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \
threadIdx.y == 0 && threadIdx.z == 0) { \
NVTE_DEVICE_ERROR(message); \
} \
} while (false)
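// Illustrative usage (hypothetical call sites): report an unexpected condition from every
// thread, or once per grid from thread 0 of block 0:
//   if (idx >= N) { NVTE_DEVICE_ERROR("index out of range"); }
//   if (cols % 32 != 0) { NVTE_DEVICE_THREAD0_ERROR("cols must be a multiple of 32"); }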
////////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__) #if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y}; return {a.x + b.x, a.y + b.y};
......
...@@ -19,7 +19,7 @@ from transformer_engine.common.recipe import Format ...@@ -19,7 +19,7 @@ from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.fp8 import _default_sf_compute from transformer_engine.pytorch.quantization import _default_sf_compute
def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None): def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None):
......
...@@ -290,10 +290,16 @@ class LogFp8TensorStats(BaseLogTensorStats): ...@@ -290,10 +290,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
for stat in config["stats"]: for stat in config["stats"]:
self.check_if_stat_is_supported(stat, recipe_name) self.check_if_stat_is_supported(stat, recipe_name)
start_step = config.get("start_step", None)
end_step = config.get("end_step", None)
start_end_list = config.get("start_end_list", None)
if start_end_list is not None:
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
options = ( options = (
config.get("start_step", None), start_step,
config.get("end_step", None), end_step,
config.get("start_end_list", None), start_end_list,
"fp8", "fp8",
) )
......
...@@ -15,8 +15,8 @@ import nvdlfw_inspect.api as debug_api ...@@ -15,8 +15,8 @@ import nvdlfw_inspect.api as debug_api
from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
...@@ -123,17 +123,23 @@ class LogTensorStats(BaseLogTensorStats): ...@@ -123,17 +123,23 @@ class LogTensorStats(BaseLogTensorStats):
"""API call used to collect the data about the tensor before process_tensor()/quantization.""" """API call used to collect the data about the tensor before process_tensor()/quantization."""
assert ( assert (
type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase] type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage]
and tensor.dtype != torch.uint8 and tensor.dtype != torch.uint8
), ( ), (
f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using" f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using"
" log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors." " log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors."
) )
start_step = config.get("start_step", None)
end_step = config.get("end_step", None)
start_end_list = config.get("start_end_list", None)
if start_end_list is not None:
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
options = ( options = (
config.get("start_step", None), start_step,
config.get("end_step", None), end_step,
config.get("start_end_list", None), start_end_list,
) )
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
......
...@@ -172,11 +172,19 @@ class StatsBuffers: ...@@ -172,11 +172,19 @@ class StatsBuffers:
if self.at_least_one_layer_fed: if self.at_least_one_layer_fed:
return True return True
iteration = TEDebugState.get_iteration() iteration = TEDebugState.get_iteration()
for _, next_iter in self.layers_to_next_iter.items(): layers_to_remove = []
for layer_name, next_iter in self.layers_to_next_iter.items():
# When next_iter is None the feature will no longer run.
if next_iter is None:
layers_to_remove.append(layer_name)
continue
# Note that a layer may not run for many iterations, # Note that a layer may not run for many iterations,
# in this case we keep synchronizing every step until we get any information from it. # in this case we keep synchronizing every step until we get any information from it.
if iteration >= next_iter: if iteration >= next_iter:
return True return True
for layer_name in layers_to_remove:
self.layers_to_next_iter.pop(layer_name, None)
return False return False
def reset(self): def reset(self):
......
...@@ -18,7 +18,7 @@ from transformer_engine.common.recipe import Recipe ...@@ -18,7 +18,7 @@ from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.quantized_tensor import ( from transformer_engine.pytorch.tensor.quantized_tensor import (
QuantizedTensor, QuantizedTensor,
Quantizer, Quantizer,
QuantizedTensorBase, QuantizedTensorStorage,
prepare_for_saving, prepare_for_saving,
restore_from_saved, restore_from_saved,
) )
...@@ -557,7 +557,7 @@ class DebugQuantizer(Quantizer): ...@@ -557,7 +557,7 @@ class DebugQuantizer(Quantizer):
self._update_parent_quantizer_usage() self._update_parent_quantizer_usage()
class DebugQuantizedTensor(QuantizedTensorBase): class DebugQuantizedTensor(QuantizedTensorStorage):
""" """
Class containing quantized tensors after debug. Depending on configuration Class containing quantized tensors after debug. Depending on configuration
it can contain one or two different objects. These objects can be accessed by the method it can contain one or two different objects. These objects can be accessed by the method
......
...@@ -34,7 +34,7 @@ load_framework_extension("jax") ...@@ -34,7 +34,7 @@ load_framework_extension("jax")
from . import flax from . import flax
from . import quantize from . import quantize
from .quantize import fp8_autocast, update_collections, get_delayed_scaling from .quantize import autocast, fp8_autocast, update_collections
from .quantize import NVTE_FP8_COLLECTION_NAME from .quantize import NVTE_FP8_COLLECTION_NAME
from .sharding import MeshResource from .sharding import MeshResource
...@@ -45,9 +45,9 @@ from ..common.utils import DeprecatedEnum ...@@ -45,9 +45,9 @@ from ..common.utils import DeprecatedEnum
__all__ = [ __all__ = [
"NVTE_FP8_COLLECTION_NAME", "NVTE_FP8_COLLECTION_NAME",
"autocast",
"fp8_autocast", "fp8_autocast",
"update_collections", "update_collections",
"get_delayed_scaling",
"MeshResource", "MeshResource",
"flax", "flax",
"quantize", "quantize",
......