Commit 53fa872c authored by wenjh

Merge branch 'nv_release_v2.8' into release_v2.8

parents 27ddce40 40c69e75
@@ -23,6 +23,8 @@
#endif // __HIP_PLATFORM_AMD__
#include <nvrtc.h>
#include "nccl.h"
#ifdef NVTE_WITH_CUBLASMP
#include <cublasmp.h>
#endif // NVTE_WITH_CUBLASMP
@@ -147,4 +149,12 @@
#endif // NVTE_WITH_CUBLASMP
#define NVTE_CHECK_NCCL(expr) \
do { \
const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \
if (status_NVTE_CHECK_NCCL != ncclSuccess) { \
NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \
} \
} while (false)
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_
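For context, the new macro wraps NCCL calls at their call sites so that a failing status aborts with a readable message; a minimal usage sketch (the buffer, communicator, and stream names are hypothetical):

// Hypothetical call site:
NVTE_CHECK_NCCL(ncclAllReduce(sendbuf, recvbuf, count, ncclFloat, ncclSum, comm, stream));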
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file nvfp4_transpose.cuh
* \brief CUDA kernels to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION >= 12080
#include <cfloat>
#include "../common.h"
#include "../utils.cuh"
#include "curanddx.hpp"
#include "math.h"
#include "ptx.cuh"
#include "transformer_engine/transformer_engine.h"
namespace transformer_engine {
#if CUDA_VERSION >= 12080
namespace nvfp4_transpose {
using RNG = decltype(curanddx::Generator<curanddx::philox4_32>() + curanddx::PhiloxRounds<10>() +
curanddx::SM<800>() + curanddx::Thread());
using namespace ptx;
using nvfp4_scale_t = fp8e4m3;
constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts)
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_NUM = 128;
constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM;
constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM;
constexpr size_t RNG_GENS_PER_THREAD =
SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers
constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;
// Should this be SCALE_DIM or BLOCK_DIM? Both are 16, so it works for both 1D and 2D
constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8
constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X;
constexpr size_t STAGES = TILES_Y * TILES_X;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = TILE_DIM_Y;
constexpr size_t BUFF_DIM_X = TILE_DIM_X;
constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;
// Input buffer (BF16)
constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
// Output buffer (NVFP4)
constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;
// Output transpose buffer (NVFP4)
constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X;
constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X;
// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM / PACK_SIZE;
constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8
constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16
constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE;  // 32 / 16 = 2
constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM;
constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE;
static_assert(BUFF_DIM_Y >= SCALE_DIM,
              "Number of buffer rows must be greater than or equal to the size of the columnwise "
              "scaling block");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE,
              "Number of buffer rows must be greater than or equal to the number of rowwise "
              "processing threads in Y dimension");
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM;  // 16 = 256 / 16
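// Illustrative compile-time restatement of the derived constants above (an added sketch):
static_assert(SCALES_PER_THREAD == 16, "2 * (128 * 128) / 16 / 128");
static_assert(RNG_GENS_PER_THREAD == 4, "16 scales per thread / 4 rbits per generate4() call");
static_assert(THREADS_PER_BANK == 16, "256 nibbles across 32 banks / 16-element blocks");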
// Compute per-block E4M3 encoding/decoding scaling factor
__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax,
const float S_enc) {
// constexpr float rcp_6f = 1.0f / 6.0f;
// const float S_dec_b = block_amax * rcp_6f;
// const nvfp4_scale_t S_dec_b_fp8 = static_cast<nvfp4_scale_t>(S_dec_b * S_enc);
// return S_dec_b_fp8;
// NOTE: Dividing by 6.0f is neither elegant nor efficient.
// However, it is part of the emulation code to ensure an exact match.
using namespace detail;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
const float S_dec_b = block_amax / fp4_max * S_enc;
return static_cast<nvfp4_scale_t>(fminf(S_dec_b, TypeExtrema<float>::max));
}
// Compute the global encode scale factor for a given global amax
__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) {
using namespace detail;
constexpr float fp8_max = TypeExtrema<fp8e4m3>::max; // 448.0f;
constexpr float fp4_max = TypeExtrema<fp4e2m1>::max; // 6.0f;
float global_encode_scale = fp8_max * fp4_max / global_amax;
// If scale is infinity, return max value of float32
global_encode_scale = fminf(global_encode_scale, TypeExtrema<float>::max);
// If global amax is 0 or infinity, return 1
if (global_amax == 0.0f || global_encode_scale == 0.0f) {
return 1.0f;
}
return global_encode_scale;
}
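// Worked example of the two-level scaling above (illustrative numbers, not from the source):
//   global_amax = 112  =>  S_enc = 448 * 6 / 112 = 24,  S_dec = 1 / 24
//   block_amax  = 3    =>  S_dec_b = (3 / 6) * 24 = 12  (stored as E4M3)
//   per-element multiplier = 1 / (S_dec_b * S_dec) = 24 / 12 = 2,
//   so the block maximum 3.0 maps to 6.0, the FP4 (E2M1) maximum.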
__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) {
if (rnd_idx == 4) {
rnd_idx = 0;
curanddx::uniform_bits dist;
random_uint4 = dist.generate4(rng);
}
// Treat uint4 as an array of 4x uint32_t elements for indexing
const uint32_t *const rbits_arr = reinterpret_cast<uint32_t *>(&random_uint4);
const uint32_t rbits = rbits_arr[rnd_idx++];
return rbits;
}
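// Usage sketch for get_rbits (illustrative, not part of the original): one 32-bit
// draw is consumed per 4-element FP4 conversion; the pool refills after four uses.
__device__ __forceinline__ void get_rbits_usage_example(RNG &rng) {
  curanddx::uniform_bits dist;
  uint4 pool = dist.generate4(rng);
  int idx = 0;
  const uint32_t r0 = get_rbits(rng, pool, idx);  // consumes pool.x
  const uint32_t r1 = get_rbits(rng, pool, idx);  // consumes pool.y
  (void)r0;
  (void)r1;
}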
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(
const uint64_t in_4x, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b16 v0_bf16; \n\t"
".reg.b16 v1_bf16; \n\t"
".reg.b16 v2_bf16; \n\t"
".reg.b16 v3_bf16; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(in_4x), "l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits);
} else {
return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits);
}
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(
const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) {
uint16_t out_4x = 0;
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order
"}"
: "=h"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)), "r"(rbits));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return *reinterpret_cast<fp4e2m1x4 *>(&out_4x);
}
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01,
const float2 in23,
const float2 scale,
const uint32_t rbits) {
// NOTE: rbits unused for rn.
uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing.
#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
asm volatile(
"{\n"
".reg.b64 v01; \n\t"
".reg.b64 v23; \n\t"
".reg.b32 v0; \n\t"
".reg.b32 v1; \n\t"
".reg.b32 v2; \n\t"
".reg.b32 v3; \n\t"
".reg.b8 f0; \n\t"
".reg.b8 f1; \n\t"
"mov.b64 {v0, v1} , %1; \n\t"
"mov.b64 {v2, v3} , %2; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order
"mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"mov.b32 %0, {f0, f1, f0, f1};\n\t"
"}"
: "=r"(out_4x)
: "l"(reinterpret_cast<const uint64_t &>(in01)),
"l"(reinterpret_cast<const uint64_t &>(in23)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
#else
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL
return reinterpret_cast<fp4e2m1x4 *>(&out_4x)[0];
}
template <bool USE_STOCHASTIC_ROUNDING>
__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23,
const float2 scale,
const uint32_t rbits) {
if constexpr (USE_STOCHASTIC_ROUNDING) {
return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits);
} else {
return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
}
}
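// Scalar reference for the round-to-nearest variants above: an illustrative sketch
// using <cuda_fp4.h> (assumed available for CUDA 12.8+), not the path the kernels take.
__device__ __forceinline__ uint8_t mul_cvt_fp32_to_fp4_2x_ref(const float lo, const float hi,
                                                              const float scale) {
  // rn + satfinite conversion of two scaled floats into one packed E2M1 byte
  const __nv_fp4x2_e2m1 packed(make_float2(lo * scale, hi * scale));
  return reinterpret_cast<const uint8_t &>(packed);
}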
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  int rnd_idx = 0;  // Index into the 4x random numbers; incremented on each use, reset to 0 after every four uses
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
// const size_t tid_X_t = 0;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
const size_t row_base_colwise = block_offset_Y;
const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;
const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
// Helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding.
  // __align__(128) does not guarantee that the pointer is aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
  // Compute global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This is to match how the emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
float thread_amax = 0.0f;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
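  // Pipeline note: with BUFFS_NUM == 2 the stage loop below ping-pongs between two SHMEM
  // buffers -- while stage s is processed out of buffer s % 2, the TMA engine prefetches
  // stage s + 1 into the other buffer.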
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = 0.0f;
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType block_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
}
block_amax = static_cast<float>(block_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_colwise =
(row_base_colwise + stage_offset_Y + i >= rows);
const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
            // Without activations, out-of-bounds elements are zero (TMA zero-fill), so including them in the amax is safe
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
// Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
// Helps reducing bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;
block_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
            // Since the TMA data-alignment requirement is 16B (i.e., cols % 8 == 0 for BF16 elements),
            // a single check in the column direction is sufficient to ensure the entire wave is within bounds
if (!out_of_bounds) {
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e],
in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
block_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
if constexpr (COMPUTE_ACTIVATIONS) {
const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
const bool swizzled_col_out_of_bounds =
(block_offset_X + swizzled_thread_idx >= cols);
const bool out_of_bounds =
(row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
if (!out_of_bounds) {
block_amax = fmaxf(block_amax, fabsf(elt));
}
} else {
              // Without activations, out-of-bounds elements are zero (TMA zero-fill), so including them in the amax is safe
block_amax = fmaxf(block_amax, fabsf(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// Vectorized store scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM)
nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t,
nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr,
const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows,
const size_t cols, const size_t scale_stride,
const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
(!COMPUTE_ACTIVATIONS) && (!std::is_same_v<IType, float>);
using IType2 = typename ptx::FPx2<IType>;
if constexpr (!COMPUTE_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG rng(rng_seed, rng_sequence, rng_offset);
curanddx::uniform_bits dist;
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0};
  int rnd_idx = 0;  // Index into the 4x random numbers; incremented on each use, reset to 0 after every four uses
// NEW: 2D Block-based scaling constants
constexpr size_t BLOCK_DIM = 16;
constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2
constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8
constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile
constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;
const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;
const size_t chunk_rows = rows - block_offset_Y;
const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const size_t tid_X_colwise = threadIdx.x;
const size_t tid_Y_t = tid_X_colwise;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
const size_t thread_offset_X_colwise = tid_X_colwise;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
const size_t scales_offset_X_t = scales_block_offset_X_t;
const size_t SFs_per_row = cols / SCALE_DIM;
const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;
// Helps resolving bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
constexpr size_t out_mem_rowwise_scales = 0;
extern __shared__ char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding.
  // __align__(128) does not guarantee that the pointer is aligned!
uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
IType *in_sh = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *out_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *out_t_data_sh = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
const bool is_master_thread = (threadIdx.x == 0);
  // Compute global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
? 1.0f
: compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This is to match how the emulation code was written.
const float S_dec_rowwise = 1.0 / S_enc_rowwise;
const float S_enc_colwise = (amax_colwise_ptr == nullptr)
? S_enc_rowwise
: compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
const float S_dec_colwise = 1.0 / S_enc_colwise;
const size_t warp_id = threadIdx.x / 32;
const size_t lane_id = threadIdx.x % 32;
float thread_amax = 0.0f;
const size_t block_in_warp = lane_id / BLOCKS_PER_WARP;
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
__shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1];
// Helper function for warp reduction
auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float {
#pragma unroll
for (int delta = 8; delta >= 1; delta /= 2) {
float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta);
thread_amax = fmaxf(thread_amax, other_amax);
}
return thread_amax;
};
initialize_barriers<STAGES, THREADS_NUM>(mbar, is_master_thread);
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], is_master_thread);
#pragma unroll
for (size_t stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
const size_t buff_offset_in = buff * BUFF_IN_SIZE;
const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
float block_amax = 0.0f;
#pragma unroll
for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
const size_t block_in_tile_y = block_iter;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
for (int elem = 0; elem < BLOCK_DIM; elem += 2) {
const size_t elem_0_row = block_iter * BLOCK_DIM + elem;
const size_t elem_1_row = elem_0_row + 1;
const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
const size_t elem_1_col = elem_0_col;
const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col;
const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col;
IType2 val_2x;
val_2x.x = in_sh[shmem_offset_0];
val_2x.y = in_sh[shmem_offset_1];
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x);
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else {
for (int elem = 0; elem < BLOCK_DIM; ++elem) {
const size_t elem_row = block_iter * BLOCK_DIM + elem;
const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id;
// Bounds checking
const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows);
const bool col_out_of_bounds = (block_offset_X + elem_col >= cols);
if (!row_out_of_bounds && !col_out_of_bounds) {
const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col;
float elt = static_cast<float>(in_sh[shmem_offset]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset] = static_cast<IType>(elt);
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
}
}
}
// Warp reduction to get block amax
block_amax = warp_reduce_amax(thread_amax, block_in_warp);
if (lane_id == 0 || lane_id == 16) {
block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax;
}
}
    // Sync threads to ensure block_amax_matrix writes have completed
__syncthreads();
// COLWISE scaling
if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM;
const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
const size_t in_thread_offset_X = thread_offset_X_colwise;
const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;
const size_t shmem_offset_base_colwise_in =
buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
const size_t shmem_offset_base_colwise_out_t =
buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_colwise[SCALE_DIM];
IType in_colwise_IType[SCALE_DIM];
        // 1. Read/Compute elements
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
}
} else {
for (int i = 0; i < SCALE_DIM; ++i) {
const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
in_compute_colwise[i] = elt;
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_colwise);
        // Store scaling factors through SHMEM
const size_t scale_idx_sh =
tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
fp4e2m1x4 regs[SCALE_DIM / 4];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_colwise_IType[4 * e]);
regs[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(elts, block_scale_inverse_2x,
rbits);
} else {
const float2 in01 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e]);
const float2 in23 = *reinterpret_cast<float2 *>(&in_compute_colwise[4 * e + 2]);
regs[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const int group = thread_lane / 16;
uint32_t val[2];
uint32_t *regs_4x = reinterpret_cast<uint32_t *>(regs);
// Helps reducing bank conflicts
switch (group) {
case 0:
val[0] = regs_4x[0];
val[1] = regs_4x[1];
break;
case 1:
val[0] = regs_4x[1];
val[1] = regs_4x[0];
break;
}
uint32_t *out_t_data_sh_as_uint32_t =
reinterpret_cast<uint32_t *>(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2;
out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2;
}
}
// ROWWISE scaling
{
const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
const size_t block_in_tile_y = it;
const size_t block_in_tile_x = tid_X_rowwise;
const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
const size_t shmem_offset_base_rowwise_in =
buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
const size_t shmem_offset_base_rowwise_out =
buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;
block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x];
float in_compute_rowwise[SCALE_DIM];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find NVFP4-block AMAX
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
}
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const size_t j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (COMPUTE_ACTIVATIONS) {
elt = OP(elt, {});
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E4M3 scaling factor
const nvfp4_scale_t S_dec_b_fp8 =
compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
// Check boundaries
const size_t scales_offset_Y =
scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
const size_t scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;
// const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
const bool rowwise_scale_is_within_bounds_Y =
(stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
scales_ptr[scale_idx_global] = S_dec_b_fp8;
}
// Compute "correct" per-block encoding scaling factor
constexpr float float_max = detail::TypeExtrema<float>::max;
const float block_scale_inverse = fminf(
1.0f / (static_cast<float>(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8
const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<fp4e2m1x4, PACK_SIZE / 4> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 4; ++e) {
const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
IType2 in01;
IType2 in23;
if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_IType[w].data.elt[2 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else if constexpr (IS_CACHED_ACT_OP) {
const uint64_t elts = *reinterpret_cast<uint64_t *>(&in_cached[w].data.elt[4 * e]);
out.data.elt[e] = mul_cvt_bf16_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
elts, block_scale_inverse_2x, rbits);
} else {
const int j = w * PACK_SIZE + 4 * e;
const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
out.data.elt[e] = mul_cvt_fp32_to_fp4_4x<USE_STOCHASTIC_ROUNDING>(
in01, in23, block_scale_inverse_2x, rbits);
}
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
out.store_to(&out_data_sh[shmem_offset_rowwise]);
}
}
}
__builtin_assume(thread_amax >= 0);
thread_amax = fmaxf(thread_amax, block_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t global_offset_Y_t = block_offset_Y_t;
const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X, global_offset_Y,
reinterpret_cast<uint64_t *>(&out_data_sh[buff_offset_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_t,
global_offset_Y_t, reinterpret_cast<uint64_t *>(&out_t_data_sh[buff_offset_out_t]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
} // end of stages
// Vectorized store scaling factors through SHMEM
if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(&out_colwise_scales_sh[scale_idx_sh]);
const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
const size_t count = // number of scales in Y dimension of this chunk
(chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast<uintptr_t>(dst) % vec_bytes == 0)) {
// Fast path: vectorized store when destination is properly aligned
scales_vec.store_to(dst);
} else {
// Safe path: element-wise store for tails or unaligned destinations
scales_vec.store_to_elts(dst, 0, count);
}
}
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace nvfp4_transpose
#endif // CUDA_VERSION >= 12080
// Compile-time flag to choose kernel variant
#ifndef USE_2D_NVFP4_KERNEL
#define USE_2D_NVFP4_KERNEL 0
#endif
template <bool COMPUTE_ACTIVATIONS, typename ParamOP, float (*OP)(float, const ParamOP &),
bool use_2d_quantization>
void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
const QuantizationConfig *quant_config, cudaStream_t stream) {
#if CUDA_VERSION >= 12080
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
  // If the transposed output is allocated, return the transposed data. Otherwise, it is not
  // necessary to return the transposed data.
// TODO(Frank): Is there a better way to do this?
bool return_transpose = output->has_columnwise_data();
using namespace nvfp4_transpose;
using namespace ptx;
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", false);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if (return_transpose) {
NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
"Transposed output must have FP4 type.");
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Transposed scaling tensor must be allocated");
}
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
NVTE_CHECK(cols % 32 == 0,
"Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_NUM;
const size_t scale_stride = output->scale_inv.shape[1];
const size_t scale_stride_transpose =
return_transpose ? output->columnwise_scale_inv.shape[1] : 0;
nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
nvfp4_scale_t *const scales_transpose_ptr =
reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
const float *const amax_colwise_ptr =
reinterpret_cast<const float *>(output->columnwise_amax.dptr);
const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
const size_t *rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
}
using IType = bf16;
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
alignas(64) CUtensorMap tensor_map_output_transpose{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
4);
if (return_transpose) {
create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t);
constexpr size_t in_mem = buff_size_aligned_in;
constexpr size_t out_data_mem = buff_size_aligned_out;
constexpr size_t out_data_transpose_mem = buff_size_aligned_out;
constexpr size_t out_scales_transpose_mem = buff_size_scales;
constexpr size_t out_mem = out_data_mem + out_data_transpose_mem;
constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT;
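// Worked example of the shared-memory budget above (a sketch, assuming
// TMA_SHMEM_ALIGNMENT == 128 bytes):
//   buff_elems_total      = 2 * 32 * 128        = 8192 elements
//   buff_size_aligned_in  = 8192 * sizeof(bf16) = 16384 bytes
//   buff_size_aligned_out = (8192 * 4) / 8      =  4096 bytes (two FP4 values per byte)
//   buff_size_scales      = (128 * 128) / 16    =  1024 bytes (one e4m3 scale per 16 elements)
//   dshmem_size           = 16384 + (4096 + 4096) + 1024 + 128 = 25728 bytes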
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
auto kernel = nvfp4_transpose_kernel<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
if constexpr (use_2d_quantization) {
kernel = nvfp4_transpose_kernel_2D<COMPUTE_ACTIVATIONS, ParamOP, OP, IType,
USE_STOCHASTIC_ROUNDING, RETURN_TRANSPOSE>;
}
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
scale_stride, scale_stride_transpose, rng_state);
}););
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // CUDA_VERSION > 12080
}
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif // CUDA_VERSION >= 12080
namespace transformer_engine { namespace transformer_engine {
namespace ptx { namespace ptx {
...@@ -125,9 +129,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { ...@@ -125,9 +129,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
return __int_as_float(biased_exp << FP32_MANTISSA_BITS); return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
} }
#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \
((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM103_ALL)))
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { __device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ #if CUDA_ARCH_HAS_FEATURE_SM10X_ALL
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
uint16_t out; uint16_t out;
asm volatile( asm volatile(
"{\n" "{\n"
...@@ -230,18 +238,86 @@ struct alignas(2 * sizeof(T)) FPx2 { ...@@ -230,18 +238,86 @@ struct alignas(2 * sizeof(T)) FPx2 {
T y; T y;
}; };
template <typename T>
struct FPx4 {
T x1;
T x2;
T x3;
T x4;
};
template <typename T>
struct Type2x {};
template <>
struct Type2x<float> {
using type = float2;
};
template <>
struct Type2x<bf16> {
using type = __nv_bfloat162;
};
template <>
struct Type2x<fp16> {
using type = __half2;
};
using floatx2 = FPx2<float>; using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>; using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>; using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>; using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>; using fp8e5m2x2 = FPx2<fp8e5m2>;
using floatx4 = FPx4<float>;
using bf16x4 = FPx4<bf16>;
using fp16x4 = FPx4<fp16>;
using fp8e4m3x4 = FPx4<fp8e4m3>;
using fp8e5m2x4 = FPx4<fp8e5m2>;
static_assert(sizeof(floatx2) == 8); static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4); static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4); static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2);
#if CUDA_VERSION >= 12080
using fp4e2m1 = __nv_fp4_e2m1;
using fp4e2m1x2 = __nv_fp4x2_e2m1;
using fp4e2m1x4 = __nv_fp4x4_e2m1;
static_assert(sizeof(fp4e2m1x2) == 1);
static_assert(sizeof(fp4e2m1x4) == 2);
#endif // CUDA_VERSION >= 12080
// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6.
// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on the following architectures:
// sm_100a
// sm_101a
// sm_120a
// When converting to .e2m1x2 data formats, the destination operand d has .b8 type.
// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format,
// and the converted values are packed in the destination operand d such that the value
// converted from input a is stored in the upper 4 bits of d and the value converted
// from input b is stored in the lower 4 bits of d.
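// A minimal packing sketch using the cuda_fp4.h types (assumes CUDA 12.8+ headers;
// the nibble order inside the packed byte follows the PTX rule described above):
//   float2 ab = make_float2(1.5f, -0.75f);
//   __nv_fp4x2_e2m1 packed(ab);                  // one byte holding two e2m1 values
//   __half2 back = static_cast<__half2>(packed); // round-trips through e2m1 precision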
// SIMD like "Fused" cast + multiplication (x4)
#if CUDA_VERSION >= 12080
template <typename Tx2>
__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23,
const float scale) {
const float x0 = static_cast<float>(in01.x) * scale;
const float x1 = static_cast<float>(in01.y) * scale;
const float x2 = static_cast<float>(in23.x) * scale;
const float x3 = static_cast<float>(in23.y) * scale;
out = fp4e2m1x4(make_float4(x0, x1, x2, x3));
}
#endif // CUDA_VERSION >= 12080
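// Usage sketch for mul_cvt_4x (hypothetical values): scale two bf16 pairs by a
// per-block decoding factor and pack the four results into two bytes of NVFP4.
//   bf16x2 in01{__float2bfloat16(1.0f), __float2bfloat16(-2.0f)};
//   bf16x2 in23{__float2bfloat16(0.5f), __float2bfloat16(4.0f)};
//   fp4e2m1x4 packed;
//   mul_cvt_4x(packed, in01, in23, /*scale=*/0.25f); // 4 elements -> 16 bits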
// SIMD like "Fused" cast + multiplication (x2) // SIMD like "Fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
const floatx2 &scale) { const floatx2 &scale) {
...@@ -377,7 +453,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const ...@@ -377,7 +453,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
"r"(reinterpret_cast<const uint32_t &>(p2))); "r"(reinterpret_cast<const uint32_t &>(p2)));
} }
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx } // namespace ptx
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
.value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
.value("kInt8", transformer_engine::DType::kInt8); \ .value("kInt8", transformer_engine::DType::kInt8); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \ pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
...@@ -41,6 +42,10 @@ ...@@ -41,6 +42,10 @@
.value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
.value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
pybind11::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", pybind11::module_local()) \
.value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \
.value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \
.value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \
pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \ pybind11::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", pybind11::module_local()) \
.value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \
.value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \
......
...@@ -49,6 +49,26 @@ constexpr uint32_t THREADS_PER_WARP = 32; ...@@ -49,6 +49,26 @@ constexpr uint32_t THREADS_PER_WARP = 32;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
// Device-side error
#define NVTE_DEVICE_ERROR(message) \
do { \
printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \
__func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \
(message)); \
assert(0); \
} while (false)
// Device-side error on thread 0
#define NVTE_DEVICE_THREAD0_ERROR(message) \
do { \
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \
threadIdx.y == 0 && threadIdx.z == 0) { \
NVTE_DEVICE_ERROR(message); \
} \
} while (false)
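// Usage sketch (hypothetical kernel): report a fatal condition with file/line and
// thread/block coordinates. The THREAD0 variant limits the printf to thread (0,0,0)
// of block (0,0,0) so a grid-wide failure does not flood the log.
//   __global__ void validate(const int *n) {
//     if (*n % 32 != 0) {
//       NVTE_DEVICE_THREAD0_ERROR("n must be a multiple of 32");
//     }
//   }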
////////////////////////////////////////////////////////////////////////////////////////////////////
#if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__) #if !defined(USE_HIPBLASLT) && !defined(__HIPCC_RTC__)
inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*)
return {a.x + b.x, a.y + b.y}; return {a.x + b.x, a.y + b.y};
......
...@@ -5,11 +5,10 @@ ...@@ -5,11 +5,10 @@
from typing import Sequence, Union, Callable, Optional, Tuple from typing import Sequence, Union, Callable, Optional, Tuple
import operator import operator
from functools import reduce, partial from functools import reduce, partial
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
...@@ -27,7 +26,7 @@ from .misc import ( ...@@ -27,7 +26,7 @@ from .misc import (
should_apply_1x_fused_dbias_war_for_arch_l_100, should_apply_1x_fused_dbias_war_for_arch_l_100,
NamedSharding, NamedSharding,
) )
from .quantization import _jax_dbias, _quantize_dbias_impl from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
...@@ -37,10 +36,6 @@ from ..quantize import ( ...@@ -37,10 +36,6 @@ from ..quantize import (
ScalingMode, ScalingMode,
) )
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"] __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]
...@@ -415,27 +410,28 @@ class ActLuPrimitive(BasePrimitive): ...@@ -415,27 +410,28 @@ class ActLuPrimitive(BasePrimitive):
result_types, result_types,
): ):
del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types
prefix = "ActLuPrimitive_" prefix = "ActLu_"
x_rank = len(value_types[0].shape) input_shape = value_types[0].shape
output_shape = input_shape[:-2] + input_shape[-1:]
# Here we pass the output shape so that the scales are propagated correctly
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 output_shape, unique_var=prefix + "x", flatten_axis=-1
) )
x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) x_axes = scale_rules.input_spec
out = (*x_axes[:-2], x_axes[-1]) # Correct input spec with act dim
scale_inv = scale_rules.rowwise_rule x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:]
out = scale_rules.input_spec
colwise_out = (prefix + "out_colwise",) colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "scale_inv_colwise",) colwise_scale_inv = (prefix + "scale_inv_colwise",)
if is_2x: if is_2x:
colwise_scale_inv = scale_rules.colwise_rule colwise_scale_inv = scale_rules.colwise_rule
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple( colwise_out = multidim_transpose(out, transpose_axis=-1)
multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2)
)
else: else:
colwise_out = out colwise_out = out
colwise_scale_inv = scale_rules.colwise_rule
# amax is always a unit tensor.
amax = (prefix + "amax",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
...@@ -443,7 +439,8 @@ class ActLuPrimitive(BasePrimitive): ...@@ -443,7 +439,8 @@ class ActLuPrimitive(BasePrimitive):
x_axes, x_axes,
("…1",), ("…1",),
), ),
(out, colwise_out, scale_inv, colwise_scale_inv, amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax),
**scale_rules.factor_sizes,
) )
...@@ -888,26 +885,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): ...@@ -888,26 +885,30 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive):
result_types, result_types,
): ):
del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types
prefix = "BaseDActLuDBiasQuantizePrimitive_" prefix = "DActLuDBias_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
dz_axes = (*x_axes[:-2], x_axes[-1]) dz_axes = (*x_axes[:-2], x_axes[-1])
out = x_axes out = x_axes
colwise_out = (prefix + "out_colwise",) colwise_out = (prefix + "out_colwise",)
colwise_scale_inv = (prefix + "scale_inv_colwise",)
if is_2x: if is_2x:
if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value:
colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2))
else: else:
colwise_out = out colwise_out = out
colwise_scale_inv = scale_rules.colwise_rule
dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",)
amax = (prefix + "amax",) amax = (prefix + "amax",)
return SdyShardingRule( return SdyShardingRule(
(dz_axes, x_axes, ("…2",)), (dz_axes, x_axes, ("…2",)),
(out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
) )
...@@ -984,6 +985,7 @@ def act_lu( ...@@ -984,6 +985,7 @@ def act_lu(
x: jnp.ndarray, x: jnp.ndarray,
activation_type: Sequence[Union[str, Callable]], activation_type: Sequence[Union[str, Callable]],
quantizer: Optional[Quantizer] = None, quantizer: Optional[Quantizer] = None,
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Union[jnp.ndarray, ScaledTensor]: ) -> Union[jnp.ndarray, ScaledTensor]:
"""Activation with optional quantization. """Activation with optional quantization.
...@@ -992,6 +994,7 @@ def act_lu( ...@@ -992,6 +994,7 @@ def act_lu(
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply. activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicates the scope over which the amax is computed. Only effective with current scaling. Default is AmaxScope.LOCAL.
Returns: Returns:
If quantizer is None: If quantizer is None:
...@@ -1049,7 +1052,13 @@ def act_lu( ...@@ -1049,7 +1052,13 @@ def act_lu(
activation_type=activation_type, activation_type=activation_type,
quantizer=None, quantizer=None,
) )
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) out, _ = _quantize_dbias_impl(
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
)
return out return out
if isinstance(quantizer, DelayedScaleQuantizer): if isinstance(quantizer, DelayedScaleQuantizer):
......
...@@ -8,11 +8,10 @@ import warnings ...@@ -8,11 +8,10 @@ import warnings
from dataclasses import dataclass, replace from dataclasses import dataclass, replace
from functools import partial, reduce from functools import partial, reduce
from typing import Optional, Tuple from typing import Optional, Tuple
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes, lax from jax import dtypes, lax, ffi
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule
...@@ -49,12 +48,6 @@ from ..sharding import ( ...@@ -49,12 +48,6 @@ from ..sharding import (
) )
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [ __all__ = [
"FusedAttnHelper", "FusedAttnHelper",
"fused_attn_fwd", "fused_attn_fwd",
......
...@@ -7,22 +7,16 @@ import re ...@@ -7,22 +7,16 @@ import re
import warnings import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from functools import partial from functools import partial
from packaging import version
from jax.extend import core from jax.extend import core
from jax.interpreters import xla, mlir from jax.interpreters import xla, mlir
from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching from jax._src.interpreters import batching
from jax._src import dispatch from jax._src import dispatch
from jax import ffi
import jax
import transformer_engine_jax import transformer_engine_jax
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
class BasePrimitive(metaclass=ABCMeta): class BasePrimitive(metaclass=ABCMeta):
""" """
...@@ -179,7 +173,7 @@ class BasePrimitive(metaclass=ABCMeta): ...@@ -179,7 +173,7 @@ class BasePrimitive(metaclass=ABCMeta):
_primitive_registry = {} _primitive_registry = {}
def register_primitive(cls): def register_primitive(cls, outer_only=False):
""" """
Register a JAX primitive and add it to the internal registry. Register a JAX primitive and add it to the internal registry.
""" """
...@@ -192,13 +186,14 @@ def register_primitive(cls): ...@@ -192,13 +186,14 @@ def register_primitive(cls):
def name_of_wrapper_p(): def name_of_wrapper_p():
return cls.name + "_wrapper" return cls.name + "_wrapper"
inner_p = core.Primitive(cls.name) if not outer_only:
dispatch.prim_requires_devices_during_lowering.add(inner_p) inner_p = core.Primitive(cls.name)
inner_p.multiple_results = cls.multiple_results dispatch.prim_requires_devices_during_lowering.add(inner_p)
inner_p.def_impl(partial(xla.apply_primitive, inner_p)) inner_p.multiple_results = cls.multiple_results
inner_p.def_abstract_eval(cls.abstract) inner_p.def_impl(partial(xla.apply_primitive, inner_p))
mlir.register_lowering(inner_p, cls.lowering, platform="cuda") inner_p.def_abstract_eval(cls.abstract)
cls.inner_primitive = inner_p mlir.register_lowering(inner_p, cls.lowering, platform="cuda")
cls.inner_primitive = inner_p
outer_p = core.Primitive(name_of_wrapper_p()) outer_p = core.Primitive(name_of_wrapper_p())
dispatch.prim_requires_devices_during_lowering.add(outer_p) dispatch.prim_requires_devices_during_lowering.add(outer_p)
......
...@@ -6,8 +6,10 @@ ...@@ -6,8 +6,10 @@
import math import math
import operator import operator
from collections.abc import Iterable from collections.abc import Iterable
from typing import Tuple, Sequence, Union from dataclasses import dataclass
from functools import partial, reduce from functools import partial, reduce
from typing import Tuple, Sequence, Union
from enum import Enum
import warnings import warnings
import jax import jax
...@@ -16,8 +18,13 @@ from jax import dtypes ...@@ -16,8 +18,13 @@ from jax import dtypes
from jax.sharding import NamedSharding, PartitionSpec from jax.sharding import NamedSharding, PartitionSpec
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule
import transformer_engine_jax as tex from transformer_engine_jax import (
from transformer_engine_jax import get_num_compute_streams get_num_compute_streams,
JAXX_Collective_Op,
get_device_compute_capability,
initialize_cgemm_communicator,
get_cgemm_num_max_streams,
)
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .quantization import grouped_quantize from .quantization import grouped_quantize
...@@ -37,11 +44,19 @@ from ..quantize import ( ...@@ -37,11 +44,19 @@ from ..quantize import (
is_fp8_gemm_with_all_layouts_supported, is_fp8_gemm_with_all_layouts_supported,
apply_padding_to_scale_inv, apply_padding_to_scale_inv,
) )
from ..sharding import global_mesh_resource from .misc import get_padded_spec, is_all_reduce_in_float32
from .misc import get_padded_spec from ..sharding import (
global_mesh_resource,
tpsp_axis_size,
dp_or_fsdp_axis_size,
)
__all__ = [ __all__ = [
"CollectiveOp",
"CollectiveOpSet",
"collective_gemm_bootstrap",
"noop_collective_op_set",
"gemm", "gemm",
"grouped_gemm", "grouped_gemm",
"gemm_uses_jax_dot", "gemm_uses_jax_dot",
...@@ -56,7 +71,7 @@ num_cublas_streams = get_num_compute_streams() ...@@ -56,7 +71,7 @@ num_cublas_streams = get_num_compute_streams()
def get_cublas_workspace_size_bytes() -> int: def get_cublas_workspace_size_bytes() -> int:
"""Return 32 MiB on Hopper or newer, 4 MiB for all other architectures.""" """Return 32 MiB on Hopper or newer, 4 MiB for all other architectures."""
if tex.get_device_compute_capability(0) >= 90: if get_device_compute_capability(0) >= 90:
return 33_554_432 return 33_554_432
return 4_194_304 return 4_194_304
...@@ -152,6 +167,161 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ ...@@ -152,6 +167,161 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_
return lhs_q, rhs_q return lhs_q, rhs_q
def collective_gemm_bootstrap(
num_total_devices,
num_devices_per_process,
process_id,
tensor_parallel_size,
num_max_streams=3,
compute_stream_priority=0,
communication_stream_priority=0,
num_sm_for_communication=2,
use_ce=True,
aggregate_all_gather=False,
):
"""Initialize NCCL communicators for Collective GEMM operations.
This function sets up the distributed communication infrastructure needed for
tensor parallel collective GEMM operations. It supports two main scenarios:
1. **Multi-device per process**: TP domain = single process
- Each process manages multiple GPUs (num_devices_per_process > 1)
- TP group consists of GPUs within the same process
- Example: 2 processes × 4 GPUs each = 8 total ranks, tp_size=4
2. **Single device per process**: TP domain spans multiple processes
- Each process manages one GPU (num_devices_per_process = 1)
- TP group spans across multiple processes
- Example: 8 processes × 1 GPU each = 8 total ranks, tp_size=4
Args:
num_total_devices (int): Total number of ranks across all processes.
Must be divisible by num_devices_per_process.
num_devices_per_process (int): Number of GPUs per process.
- For multi-device: equals tp_size (e.g., 4 GPUs per process)
- For single-device: equals 1 (1 GPU per process)
process_id (int): Process identifier (0-based).
Must be in range [0, num_total_devices // num_devices_per_process).
tensor_parallel_size (int): Size of tensor parallel groups.
Must divide num_total_devices evenly.
num_max_streams (int, optional): Maximum number of CUDA streams for overlap.
Higher values enable more parallelism but use more GPU resources. Default: 3.
compute_stream_priority (int, optional): Priority for GEMM computation streams.
Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0.
communication_stream_priority (int, optional): Priority for NCCL communication streams.
Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0.
num_sm_for_communication (int, optional): Number of streaming multiprocessors
reserved for communication operations. Default: 2.
use_ce (bool, optional): Enable CUDA copy engines for memory transfers.
Can improve performance by offloading memory operations. Default: True.
aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations
into larger ones for better efficiency. Default: False.
Raises:
AssertionError: If num_total_devices is not divisible by num_devices_per_process,
or if process_id is out of valid range.
AssertionError: If num_devices_per_process is not 1 (temporary restriction: only a single device per process is currently supported).
RuntimeError: If NCCL initialization fails or if configuration
is invalid (e.g., insufficient GPUs).
Example:
# Basic initialization (single device per process)
collective_gemm_bootstrap(
num_total_devices=8,
num_devices_per_process=1,
process_id=0,
tensor_parallel_size=4
)
# Advanced configuration with custom performance settings
collective_gemm_bootstrap(
num_total_devices=8,
num_devices_per_process=1,
process_id=0,
tensor_parallel_size=4,
num_max_streams=5, # More parallelism
compute_stream_priority=1, # Lower compute priority
communication_stream_priority=0, # Higher comm priority
num_sm_for_communication=4, # More SMs for communication
use_ce=True, # Enable copy engines
aggregate_all_gather=True # Aggregate small operations
)
Note:
This function must be called after JAX distributed initialization
and before any collective GEMM operations. Each process should call
this function with its own unique process_id.
"""
assert (
num_devices_per_process == 1 and jax.local_device_count() == 1
), "Only single device per process is supported at the moment!"
assert num_total_devices % num_devices_per_process == 0, (
f"Invalid num_total_devices={num_total_devices},"
f" num_devices_per_process={num_devices_per_process}"
)
assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}"
initialize_cgemm_communicator(
num_total_devices,
num_devices_per_process,
process_id,
tensor_parallel_size,
num_max_streams,
compute_stream_priority,
communication_stream_priority,
num_sm_for_communication,
use_ce,
aggregate_all_gather,
)
class CollectiveOp(Enum):
"Enum for Collective Type in Collective GEMM"
NONE = JAXX_Collective_Op.NONE
ALL_GATHER = JAXX_Collective_Op.ALL_GATHER
REDUCE_SCATTER = JAXX_Collective_Op.REDUCE_SCATTER
@property
def is_all_gather(self) -> bool:
"""Check if AllGather"""
return self == CollectiveOp.ALL_GATHER
@property
def is_reduce_scatter(self) -> bool:
"""Check if ReduceScatter"""
return self == CollectiveOp.REDUCE_SCATTER
@property
def is_none(self) -> bool:
"""Check if None"""
return self == CollectiveOp.NONE
@dataclass(frozen=True)
class CollectiveOpSet:
"""
A pair of complementary CollectiveOp configurations for the forward and backward passes through dense layers.
"""
forward: CollectiveOp
backward: CollectiveOp
@staticmethod
def create(forward_collective_op: CollectiveOp):
"""Create a set of CollectiveOp for forward and backward passes"""
if forward_collective_op.is_all_gather:
backward_collective_op = CollectiveOp.REDUCE_SCATTER
elif forward_collective_op.is_reduce_scatter:
backward_collective_op = CollectiveOp.ALL_GATHER
else:
backward_collective_op = CollectiveOp.NONE
return CollectiveOpSet(forward=forward_collective_op, backward=backward_collective_op)
noop_collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.NONE)
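# A minimal usage sketch: a layer only declares its forward collective, and the
# complementary backward collective (AG <-> RS) is derived automatically.
#
#     collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER)
#     assert collective_op_set.forward.is_all_gather
#     assert collective_op_set.backward.is_reduce_scatter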
@partial(jax.jit, static_argnums=(1, 2)) @partial(jax.jit, static_argnums=(1, 2))
def swizzled_scale(scale_inv, flatten_axis, is_colwise): def swizzled_scale(scale_inv, flatten_axis, is_colwise):
"Swizzle scale_inv via JAX transpose ops" "Swizzle scale_inv via JAX transpose ops"
...@@ -174,7 +344,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -174,7 +344,7 @@ class GemmPrimitive(BasePrimitive):
name = "te_gemm_ffi" name = "te_gemm_ffi"
multiple_results = True multiple_results = True
impl_static_args = (6, 7, 8, 9, 10, 11, 12) impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
inner_primitive = None inner_primitive = None
outer_primitive = None outer_primitive = None
...@@ -193,8 +363,12 @@ class GemmPrimitive(BasePrimitive): ...@@ -193,8 +363,12 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
): ):
del use_split_accumulator del use_split_accumulator, transpose_batch_sequence
def _dims_are_consecutive(dims): def _dims_are_consecutive(dims):
if len(dims) <= 1: if len(dims) <= 1:
...@@ -238,7 +412,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -238,7 +412,7 @@ class GemmPrimitive(BasePrimitive):
), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands."
if ( if (
scaling_mode != ScalingMode.MXFP8_1D_SCALING scaling_mode != ScalingMode.MXFP8_1D_SCALING
and not tex.is_non_nt_fp8_gemm_supported() and not is_fp8_gemm_with_all_layouts_supported()
): ):
assert not lhs_is_transposed and rhs_is_transposed, ( assert not lhs_is_transposed and rhs_is_transposed, (
"cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) "
...@@ -263,6 +437,19 @@ class GemmPrimitive(BasePrimitive): ...@@ -263,6 +437,19 @@ class GemmPrimitive(BasePrimitive):
out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape)
output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
# Adjust output shape for comm+GEMM overlap
if not collective_op.is_none and not is_outer: # Inner abstract
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
overlap_out_shape = list(out_shape)
if collective_op.is_all_gather:
overlap_out_shape[1] *= tpsp_axis_size()
else: # RS
overlap_out_shape[sequence_dim] = (
overlap_out_shape[sequence_dim] // tpsp_axis_size()
)
assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}"
output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype)
# Validate bias # Validate bias
bias_shape = (0,) bias_shape = (0,)
bias_dtype = out_dtype bias_dtype = out_dtype
...@@ -302,9 +489,12 @@ class GemmPrimitive(BasePrimitive): ...@@ -302,9 +489,12 @@ class GemmPrimitive(BasePrimitive):
pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype)
# Declare cuBLAS workspace # Declare cuBLAS workspace
workspace_size = get_cublas_workspace_size_bytes()
if not collective_op.is_none:
workspace_size *= get_cgemm_num_max_streams()
# cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not
# necessarily 256 bytes aligned, we add some padding to ensure alignment. # necessarily 256 bytes aligned, we add some padding to ensure alignment.
workspace_size = get_cublas_workspace_size_bytes() + 256 workspace_size += 256
workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8)
return output, bias_grad, pre_gelu_out, workspace return output, bias_grad, pre_gelu_out, workspace
...@@ -330,8 +520,12 @@ class GemmPrimitive(BasePrimitive): ...@@ -330,8 +520,12 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
): ):
del out_dtype del out_dtype, transpose_batch_sequence, sequence_dim, is_outer
lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_aval, _, rhs_aval, *_ = ctx.avals_in
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims)
...@@ -350,6 +544,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -350,6 +544,7 @@ class GemmPrimitive(BasePrimitive):
"fuse_gelu": fuse_gelu, "fuse_gelu": fuse_gelu,
"grad": grad, "grad": grad,
"use_split_accumulator": use_split_accumulator, "use_split_accumulator": use_split_accumulator,
"collective_op": int(collective_op.value),
} }
operand_output_aliases = {} operand_output_aliases = {}
...@@ -378,6 +573,10 @@ class GemmPrimitive(BasePrimitive): ...@@ -378,6 +573,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
): ):
if scaling_mode.is_1d_block_scaling(): if scaling_mode.is_1d_block_scaling():
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
...@@ -396,7 +595,34 @@ class GemmPrimitive(BasePrimitive): ...@@ -396,7 +595,34 @@ class GemmPrimitive(BasePrimitive):
lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed)
rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed)
outputs = GemmPrimitive.inner_primitive.bind( # Reorder LHS blocks so that the CGEMM reduce-scatter emits the correct output layout
if (
collective_op.is_reduce_scatter
and not transpose_batch_sequence
and not is_outer
and lhs.shape[0] != 1
):
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
original_shape = lhs.shape
assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, (
f"Original_shape[0]={original_shape[0]} is not divisible by"
f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}"
)
assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, (
f"Original_shape[1]={original_shape[1]} is not divisible by"
f" tpsp_axis_size()={tpsp_axis_size()}"
)
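# Illustration (assumed sizes): with dp_or_fsdp_axis_size()=2, tpsp_axis_size()=2
# and lhs of shape (4, 8, H), the (batch, seq) grid is viewed as (2, 2, 2, 4, H)
# and permuted so the TPSP axis becomes outermost; flattening back to (4, 8, H)
# lines the rows up with the per-rank chunks that the reduce-scatter emits.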
reshaped = lhs.reshape(
dp_or_fsdp_axis_size(),
original_shape[0] // dp_or_fsdp_axis_size(),
tpsp_axis_size(),
original_shape[1] // tpsp_axis_size(),
*original_shape[2:],
)
reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim))
lhs = reordered.reshape(original_shape)
(output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind(
lhs, lhs,
lhs_scale_inv, lhs_scale_inv,
rhs, rhs,
...@@ -410,8 +636,39 @@ class GemmPrimitive(BasePrimitive): ...@@ -410,8 +636,39 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
collective_op=collective_op,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=sequence_dim,
is_outer=is_outer,
) )
return outputs[:-1] # discard workspace array # Reorder output blocks for the CGEMM all-gather
if (
collective_op.is_all_gather
and not transpose_batch_sequence
and not is_outer
and output.shape[0] != 1
):
assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
original_shape = output.shape
assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, (
f"Original_shape[0]={original_shape[0]} is not divisible by"
f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}"
)
assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, (
f"Original_shape[1]={original_shape[1]} is not divisible by"
f" tpsp_axis_size()={tpsp_axis_size()}"
)
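# Illustration (assumed sizes): the overlapped all-gather emits the gathered
# sequence in rank order, i.e. as (tpsp, dp, B/dp, S/tpsp, ...); the transpose
# below moves the TPSP axis back inside the sequence dimension so the flattened
# result matches the logical (batch, seq, ...) layout.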
reshaped = output.reshape(
tpsp_axis_size(),
dp_or_fsdp_axis_size(),
original_shape[0] // dp_or_fsdp_axis_size(),
original_shape[1] // tpsp_axis_size(),
*original_shape[2:],
)
reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim))
output = reordered.reshape(original_shape)
return [output, bias_grad, pre_gelu_out]
@staticmethod @staticmethod
def outer_impl( def outer_impl(
...@@ -428,6 +685,10 @@ class GemmPrimitive(BasePrimitive): ...@@ -428,6 +685,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
): ):
return GemmPrimitive.impl( return GemmPrimitive.impl(
lhs, lhs,
...@@ -443,6 +704,10 @@ class GemmPrimitive(BasePrimitive): ...@@ -443,6 +704,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
) )
@staticmethod @staticmethod
...@@ -456,7 +721,12 @@ class GemmPrimitive(BasePrimitive): ...@@ -456,7 +721,12 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
collective_op,
transpose_batch_sequence,
sequence_dim,
is_outer,
): ):
del transpose_batch_sequence, sequence_dim, is_outer
assert GemmPrimitive.outer_primitive is not None assert GemmPrimitive.outer_primitive is not None
lhs_bdims, _, rhs_bdims, *_ = batch_dims lhs_bdims, _, rhs_bdims, *_ = batch_dims
...@@ -484,6 +754,10 @@ class GemmPrimitive(BasePrimitive): ...@@ -484,6 +754,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
collective_op=collective_op,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=sequence_dim,
is_outer=is_outer,
), ),
(out_bdims, bias_bdims, pre_gelu_bdims), (out_bdims, bias_bdims, pre_gelu_bdims),
) )
...@@ -492,6 +766,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -492,6 +766,8 @@ class GemmPrimitive(BasePrimitive):
def _parse_operand_output_specs( def _parse_operand_output_specs(
arg_infos, arg_infos,
contracting_dims, contracting_dims,
transpose_batch_sequence,
collective_op,
): ):
lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos)
...@@ -499,14 +775,12 @@ class GemmPrimitive(BasePrimitive): ...@@ -499,14 +775,12 @@ class GemmPrimitive(BasePrimitive):
# Ensure that tensor sequence parallelism is not used via setting tp_resource # Ensure that tensor sequence parallelism is not used via setting tp_resource
if gsr.tp_resource is not None: if gsr.tp_resource is not None:
for i in range(len(lhs_specs) - 1): if gsr.tp_resource in lhs_specs:
if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource: warnings.warn(
warnings.warn( "Tensor sequence parallelism is detected as tp_resource='{gsr.tp_resource}'"
"Tensor sequence parallelism is detected as" " appears in lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource"
f" tp_resource='{gsr.tp_resource}' appears twice consecutively in" " for tensor sequence parallelism to avoid potential issues."
f" lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource for" )
" tensor sequence parallelism to avoid potential issues."
)
lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs))
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims)
...@@ -528,10 +802,43 @@ class GemmPrimitive(BasePrimitive): ...@@ -528,10 +802,43 @@ class GemmPrimitive(BasePrimitive):
assert reduce_spec is None, "Multiple reduce dimension is detected!" assert reduce_spec is None, "Multiple reduce dimension is detected!"
reduce_spec = l reduce_spec = l
sequence_dim = None
# Find sequence dimension in lhs_specs if tensor sequence parallel is enabled
# Collective GEMM AG is applied only to x or dY, so they are always the LHS and carry the sequence dim
if collective_op.is_all_gather:
try:
tpsp_idx = lhs_specs.index(gsr.tpsp_resource)
except ValueError as exc:
raise ValueError(
f"tpsp_resource '{gsr.tpsp_resource}' is not found in lhs_specs: {lhs_specs}."
" Please check your sharding configuration."
) from exc
sequence_dim = tpsp_idx
assert (sequence_dim == 1) ^ transpose_batch_sequence, (
"CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)"
" or (sequence_dim=0 and transpose_batch_sequence=True). Received:"
f" sequence_dim={sequence_dim},"
f" transpose_batch_sequence={transpose_batch_sequence}."
)
elif collective_op.is_reduce_scatter:
assert reduce_spec == gsr.tpsp_resource, (
"Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got"
f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}"
)
sequence_dim = int(not transpose_batch_sequence)
if reduce_spec is not None: if reduce_spec is not None:
# Other non-reduce cdims (if exists) need to be unsharded # Other non-reduce cdims (if exists) need to be unsharded
lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs)
rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) # Gather the sequence dim here only when the AG is not overlapped with the GEMM
if collective_op.is_all_gather:
rhs_cspecs = tuple(
s if s in (reduce_spec, gsr.tpsp_resource) else None for s in rhs_cspecs
)
else:
rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs)
# Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden
# No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim. # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim.
...@@ -551,13 +858,31 @@ class GemmPrimitive(BasePrimitive): ...@@ -551,13 +858,31 @@ class GemmPrimitive(BasePrimitive):
for spec in rhs_non_cspecs for spec in rhs_non_cspecs
) )
# Non-contracting dims of LHS to be gathered along the SP axis. # Gather the LHS sequence dim here only when the AG is not overlapped
# Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for if not collective_op.is_all_gather:
# dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. # Non-contracting dims of LHS to be gathered along the SP axis.
lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for
# dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet.
lhs_non_cspecs = tuple(
None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs
)
out_specs = lhs_non_cspecs + rhs_non_cspecs out_specs = lhs_non_cspecs + rhs_non_cspecs
# Set the output sequence-dim spec: gathered (None) for AG, TPSP-sharded for RS
if collective_op.is_all_gather:
assert sequence_dim < len(
lhs_non_cspecs
), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}"
out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :]
elif collective_op.is_reduce_scatter:
assert sequence_dim < len(
lhs_non_cspecs
), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}"
out_specs = (
out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :]
)
# specs = merge(cspecs, non_cspecs) # specs = merge(cspecs, non_cspecs)
lhs_specs, rhs_specs = map( lhs_specs, rhs_specs = map(
lambda cdims, cspecs, non_cspecs: ( lambda cdims, cspecs, non_cspecs: (
...@@ -572,10 +897,14 @@ class GemmPrimitive(BasePrimitive): ...@@ -572,10 +897,14 @@ class GemmPrimitive(BasePrimitive):
bias_specs = tuple(list(rhs_non_cspecs).copy()) bias_specs = tuple(list(rhs_non_cspecs).copy())
gelu_specs = tuple(list(out_specs).copy()) gelu_specs = tuple(list(out_specs).copy())
if not collective_op.is_none:
assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}"
return ( return (
(lhs_specs, rhs_specs, bias_specs, gelu_specs), (lhs_specs, rhs_specs, bias_specs, gelu_specs),
(out_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs),
reduce_spec, reduce_spec,
sequence_dim,
) )
@staticmethod @staticmethod
...@@ -587,6 +916,10 @@ class GemmPrimitive(BasePrimitive): ...@@ -587,6 +916,10 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
...@@ -595,11 +928,16 @@ class GemmPrimitive(BasePrimitive): ...@@ -595,11 +928,16 @@ class GemmPrimitive(BasePrimitive):
out_dtype, out_dtype,
scaling_mode, scaling_mode,
grad, grad,
use_split_accumulator,
result_infos,
is_outer,
sequence_dim,
) )
del use_split_accumulator, result_infos
(_, (out_specs, dbias_specs, pre_gelu_specs), _) = ( (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = (
GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) GemmPrimitive._parse_operand_output_specs(
arg_infos, contracting_dims, transpose_batch_sequence, collective_op
)
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
...@@ -624,20 +962,29 @@ class GemmPrimitive(BasePrimitive): ...@@ -624,20 +962,29 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
mesh, mesh,
arg_infos, arg_infos,
result_infos, result_infos,
): ):
del result_infos del result_infos, is_outer, sequence_dim
( (
(lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs),
(out_specs, dbias_specs, pre_gelu_specs), (out_specs, dbias_specs, pre_gelu_specs),
reduce_spec, reduce_spec,
) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) inferred_sequence_dim,
) = GemmPrimitive._parse_operand_output_specs(
arg_infos,
contracting_dims,
transpose_batch_sequence,
collective_op,
)
# Assemble argument shardings # Block scale inverses match their operands, but tensor scale inverses are unsharded.
# NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded.
none_sharding = NamedSharding(mesh, PartitionSpec(None)) none_sharding = NamedSharding(mesh, PartitionSpec(None))
lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs))
rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs))
...@@ -686,11 +1033,19 @@ class GemmPrimitive(BasePrimitive): ...@@ -686,11 +1033,19 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=inferred_sequence_dim,
is_outer=False,
collective_op=collective_op,
) )
# All-Reduce GEMM output if reduce_spec is not None and not collective_op.is_reduce_scatter:
if reduce_spec is not None: if is_all_reduce_in_float32(): # For unittest only
outputs[0] = jax.lax.psum(outputs[0], reduce_spec) outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype(
out_dtype
)
else:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
return outputs return outputs
...@@ -705,14 +1060,24 @@ class GemmPrimitive(BasePrimitive): ...@@ -705,14 +1060,24 @@ class GemmPrimitive(BasePrimitive):
fuse_gelu, fuse_gelu,
grad, grad,
use_split_accumulator, use_split_accumulator,
transpose_batch_sequence,
sequence_dim,
is_outer,
collective_op,
mesh, mesh,
operand_types, operand_types,
result_types, result_types,
): ):
del out_dtype, grad, use_split_accumulator del out_dtype, grad, use_split_accumulator
del mesh, result_types del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer
prefix = "GemmPrimitive_" if not collective_op.is_none:
raise NotImplementedError(
"CollectiveGEMM with Shardy propagation is not supported yet! Please turn off"
" Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false"
)
prefix = "Gemm_"
warnings.warn( warnings.warn(
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
...@@ -746,13 +1111,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -746,13 +1111,8 @@ class GemmPrimitive(BasePrimitive):
lhs_scale_specs = ("…1",) lhs_scale_specs = ("…1",)
rhs_scale_specs = ("…2",) rhs_scale_specs = ("…2",)
if scaling_mode.is_1d_block_scaling(): if scaling_mode.is_1d_block_scaling():
# Shardy rules for MXFP8 scales cannot be related to the operands because of the lhs_scale_specs = lhs_specs
# global-unpadding and local-padding workflow. This can potentially insert expensive rhs_scale_specs = rhs_specs
# re-shards in the partition call later if the scales are not already sharded correctly.
lhs_scale_specs, rhs_scale_specs = map(
lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs),
(lhs_specs, rhs_specs),
)
lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims)
rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
...@@ -797,6 +1157,8 @@ def _te_gemm( ...@@ -797,6 +1157,8 @@ def _te_gemm(
fuse_gelu: bool = False, fuse_gelu: bool = False,
grad: bool = False, grad: bool = False,
use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
...@@ -805,6 +1167,7 @@ def _te_gemm( ...@@ -805,6 +1167,7 @@ def _te_gemm(
lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) lhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) rhs_scale_inv = jnp.empty(0, dtype=jnp.float32)
scaling_mode = ScalingMode.NO_SCALING scaling_mode = ScalingMode.NO_SCALING
lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims)
lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims)
...@@ -864,6 +1227,10 @@ def _te_gemm( ...@@ -864,6 +1227,10 @@ def _te_gemm(
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=-1,
is_outer=True,
collective_op=collective_op,
) )
...@@ -1181,6 +1548,8 @@ def gemm( ...@@ -1181,6 +1548,8 @@ def gemm(
contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)),
lhs_quantizer: Quantizer = None, lhs_quantizer: Quantizer = None,
rhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None,
transpose_batch_sequence: bool = False,
collective_op: CollectiveOp = CollectiveOp.NONE,
**kwargs, **kwargs,
) -> Tuple[jnp.ndarray, ...]: ) -> Tuple[jnp.ndarray, ...]:
r"""General matrix multiplication with optional quantization. r"""General matrix multiplication with optional quantization.
...@@ -1214,8 +1583,11 @@ def gemm( ...@@ -1214,8 +1583,11 @@ def gemm(
TE's custom call to cuBLAS GEMM. TE's custom call to cuBLAS GEMM.
use_split_accumulator: bool, default = True use_split_accumulator: bool, default = True
Enable promoting some intermediate sums to higher precision when accumulating the result in Enable promoting some intermediate sums to higher precision when accumulating the result in
the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed.
supported with TE's custom call to cuBLAS GEMM. transpose_batch_sequence: bool, default = False
Whether the inputs are laid out sequence-first, i.e. with the batch and sequence dimensions transposed relative to (batch, sequence, ...).
collective_op: CollectiveOp, default = CollectiveOp.NONE
Collective operation (all-gather or reduce-scatter) to overlap with the GEMM.
Returns Returns
------- -------
...@@ -1259,6 +1631,7 @@ def gemm( ...@@ -1259,6 +1631,7 @@ def gemm(
"`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS "
"GEMM primitive is disabled." "GEMM primitive is disabled."
) )
assert collective_op.is_none, "JAX GEMM does not support collective GEMM"
return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer)
outputs = _te_gemm( outputs = _te_gemm(
...@@ -1267,6 +1640,8 @@ def gemm( ...@@ -1267,6 +1640,8 @@ def gemm(
lhs_quantizer=lhs_quantizer, lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer, rhs_quantizer=rhs_quantizer,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
transpose_batch_sequence=transpose_batch_sequence,
collective_op=collective_op,
**kwargs, **kwargs,
) )
......
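# End-to-end sketch of the new collective GEMM path (hypothetical mesh and shapes;
# the bootstrap must run after jax.distributed.initialize and before any CGEMM call):
#
#     collective_gemm_bootstrap(num_total_devices=8, num_devices_per_process=1,
#                               process_id=jax.process_index(), tensor_parallel_size=4)
#     out = gemm(x, w, contracting_dims=((-1,), (0,)),
#                collective_op=CollectiveOp.ALL_GATHER)  # AG overlapped with the GEMM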
...@@ -293,3 +293,11 @@ class NamedSharding(jax.sharding.NamedSharding): ...@@ -293,3 +293,11 @@ class NamedSharding(jax.sharding.NamedSharding):
Create a new NamedSharding with the same mesh and spec but with a new description. Create a new NamedSharding with the same mesh and spec but with a new description.
""" """
return NamedSharding(self.mesh, self.spec, desc=desc) return NamedSharding(self.mesh, self.spec, desc=desc)
@functools.lru_cache(maxsize=1)
def is_all_reduce_in_float32():
"""
Check whether the GEMM all-reduce should run in float32 (controlled by NVTE_JAX_ALL_REDUCE_IN_FP32; intended for unit tests)
"""
return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1"
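# Usage note: the value is read once and cached (lru_cache), so the environment
# variable must be set before the first call, e.g.
#     NVTE_JAX_ALL_REDUCE_IN_FP32=1 python run_tests.py   # hypothetical invocation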
...@@ -7,11 +7,10 @@ import warnings ...@@ -7,11 +7,10 @@ import warnings
import operator import operator
from functools import partial, cache, reduce from functools import partial, cache, reduce
from typing import Optional, Union from typing import Optional, Union
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule
from jax.interpreters.mlir import ir from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
...@@ -28,7 +27,7 @@ from .misc import ( ...@@ -28,7 +27,7 @@ from .misc import (
NamedSharding, NamedSharding,
get_cudnn_version, get_cudnn_version,
) )
from .quantization import _quantize_dbias_impl from .quantization import _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
...@@ -38,11 +37,6 @@ from ..quantize import ( ...@@ -38,11 +37,6 @@ from ..quantize import (
ScalingMode, ScalingMode,
) )
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [ __all__ = [
"layernorm_fwd", "layernorm_fwd",
...@@ -587,9 +581,9 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -587,9 +581,9 @@ class NormFwdPrimitive(BasePrimitive):
result_types, result_types,
) )
prefix = "NormFwdPrimitive_" prefix = "NormFwd_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1
) )
x_axes = scale_rules.input_spec x_axes = scale_rules.input_spec
...@@ -610,6 +604,7 @@ class NormFwdPrimitive(BasePrimitive): ...@@ -610,6 +604,7 @@ class NormFwdPrimitive(BasePrimitive):
mu, mu,
rsigma, rsigma,
), ),
**scale_rules.factor_sizes,
) )
...@@ -886,6 +881,7 @@ def layernorm_fwd( ...@@ -886,6 +881,7 @@ def layernorm_fwd(
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]: ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]:
"""Layer normalization forward pass with optional quantization. """Layer normalization forward pass with optional quantization.
...@@ -899,6 +895,7 @@ def layernorm_fwd( ...@@ -899,6 +895,7 @@ def layernorm_fwd(
zero_centered_gamma: If True, gamma is zero-centered. zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability. epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicates the scope over which the amax calculation runs. This only takes effect with current-scaling. Defaults to AmaxScope.LOCAL.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -958,7 +955,13 @@ def layernorm_fwd( ...@@ -958,7 +955,13 @@ def layernorm_fwd(
epsilon=epsilon, epsilon=epsilon,
quantizer=None, quantizer=None,
) )
out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) out, _ = _quantize_dbias_impl(
out,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
)
return out, mu, rsigma return out, mu, rsigma
is_2x2x = quantizer.is_2x2x() is_2x2x = quantizer.is_2x2x()
...@@ -1088,6 +1091,7 @@ def rmsnorm_fwd( ...@@ -1088,6 +1091,7 @@ def rmsnorm_fwd(
zero_centered_gamma: bool, zero_centered_gamma: bool,
epsilon: float, epsilon: float,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]: ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]:
"""Root mean square normalization forward pass with optional quantization. """Root mean square normalization forward pass with optional quantization.
...@@ -1099,6 +1103,7 @@ def rmsnorm_fwd( ...@@ -1099,6 +1103,7 @@ def rmsnorm_fwd(
zero_centered_gamma: If True, gamma is zero-centered. zero_centered_gamma: If True, gamma is zero-centered.
epsilon: Small constant for numerical stability. epsilon: Small constant for numerical stability.
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicates the scope over which the amax calculation runs. This only takes effect with current-scaling. Defaults to AmaxScope.LOCAL.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -1159,7 +1164,11 @@ def rmsnorm_fwd( ...@@ -1159,7 +1164,11 @@ def rmsnorm_fwd(
quantizer=None, quantizer=None,
) )
out, _ = _quantize_dbias_impl( out, _ = _quantize_dbias_impl(
out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype out.data,
is_dbias=False,
quantizer=quantizer,
dq_dtype=x.dtype,
amax_scope=amax_scope,
) )
return out, rsigma return out, rsigma
...@@ -1284,6 +1293,7 @@ def normalization_fwd( ...@@ -1284,6 +1293,7 @@ def normalization_fwd(
epsilon: float, epsilon: float,
norm_type: str, norm_type: str,
quantizer: Optional[Quantizer], quantizer: Optional[Quantizer],
amax_scope: AmaxScope = AmaxScope.LOCAL,
): ):
"""Common wrapper for normalization forward pass. """Common wrapper for normalization forward pass.
...@@ -1300,6 +1310,7 @@ def normalization_fwd( ...@@ -1300,6 +1310,7 @@ def normalization_fwd(
- 'layernorm': Layer normalization - 'layernorm': Layer normalization
- 'rmsnorm': Root mean square normalization - 'rmsnorm': Root mean square normalization
quantizer: Optional quantizer for FP8 quantization of the output. quantizer: Optional quantizer for FP8 quantization of the output.
amax_scope: Indicates the scope over which the amax calculation runs. This only takes effect with current-scaling. Defaults to AmaxScope.LOCAL.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -1317,12 +1328,27 @@ def normalization_fwd( ...@@ -1317,12 +1328,27 @@ def normalization_fwd(
zero_centered_gamma is not supported if norm_type is 'rmsnorm'. zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
""" """
if norm_type == "layernorm": if norm_type == "layernorm":
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) output, mu, rsigma = layernorm_fwd(
x,
gamma,
beta,
zero_centered_gamma,
epsilon,
quantizer,
amax_scope=amax_scope,
)
elif norm_type == "rmsnorm": elif norm_type == "rmsnorm":
assert ( assert (
not zero_centered_gamma not zero_centered_gamma
), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer) output, rsigma = rmsnorm_fwd(
x,
gamma,
zero_centered_gamma,
epsilon,
quantizer,
amax_scope=amax_scope,
)
mu = None mu = None
else: else:
raise ValueError(f"{norm_type=} is not supported.") raise ValueError(f"{norm_type=} is not supported.")
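# --- Illustrative usage (not part of the diff) ------------------------------
# Minimal sketch of the extended normalization_fwd() wrapper above. Import
# paths, the quantizer value, and the (output, mu, rsigma) return convention
# are assumptions; the keyword arguments come from this diff.
import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions import normalization_fwd, AmaxScope  # assumed paths

x = jnp.ones((8, 1024), dtype=jnp.bfloat16)
gamma = jnp.ones((1024,), dtype=jnp.bfloat16)

output, mu, rsigma = normalization_fwd(
    x,
    gamma,
    None,  # beta is unused for rmsnorm (assumption)
    False,  # zero_centered_gamma must be False for rmsnorm
    1e-5,
    "rmsnorm",
    quantizer=None,  # amax_scope only takes effect with a current-scaling quantizer
    amax_scope=AmaxScope.TPSP,
)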
......
...@@ -6,11 +6,12 @@ import operator ...@@ -6,11 +6,12 @@ import operator
from functools import reduce from functools import reduce
from typing import Tuple, Optional, Union from typing import Tuple, Optional, Union
import math import math
from packaging import version from enum import Enum
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes, ffi
from jax.experimental.custom_partitioning import SdyShardingRule from jax.experimental.custom_partitioning import SdyShardingRule
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
...@@ -27,7 +28,12 @@ from .misc import ( ...@@ -27,7 +28,12 @@ from .misc import (
get_min_device_compute_capability, get_min_device_compute_capability,
NamedSharding, NamedSharding,
) )
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import (
all_reduce_max_along_all_axes_except_PP,
all_reduce_sum_along_dp_fsdp,
global_mesh_resource,
lax_paral_op,
)
from ..quantize import ( from ..quantize import (
ScaledTensor2x, ScaledTensor2x,
ScaledTensor, ScaledTensor,
...@@ -41,11 +47,6 @@ from ..quantize import ( ...@@ -41,11 +47,6 @@ from ..quantize import (
NoScaleTensor, NoScaleTensor,
) )
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"]
...@@ -494,9 +495,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -494,9 +495,9 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
): ):
del out_dtype, scale_dtype, is_outer, mesh, result_types del out_dtype, scale_dtype, is_outer, mesh, result_types
prefix = "BaseDBiasQuantizePrimitive_" prefix = "DBiasQuantize_"
scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules(
len(value_types[0].shape), value_types[0].shape,
unique_var=prefix + "x", unique_var=prefix + "x",
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
) )
...@@ -518,6 +519,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): ...@@ -518,6 +519,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive):
return SdyShardingRule( return SdyShardingRule(
(x_axes, ("…1",), amax), (x_axes, ("…1",), amax),
(out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias),
**scale_rules.factor_sizes,
) )
...@@ -532,6 +534,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive): ...@@ -532,6 +534,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive):
"""Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS."""
class AmaxScope(Enum):
"""
Scope over which the amax (absolute-maximum) reduction runs for current-scaling quantization
"""
LOCAL = 1
TPSP = 2
FSDP = 3
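# LOCAL: amax is taken over the local shard only.
# TPSP: the local amax is additionally reduced (pmax) across the TP and
#       TP/SP mesh resources (see AmaxCalculationPrimitive.partition below).
# FSDP: the local amax is additionally reduced (pmax) across the FSDP mesh
#       resource.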
class AmaxCalculationPrimitive(BasePrimitive):
"""
Amax Calculation Primitive with custom_partitioning
"""
name = "jax_local_amax"
multiple_results = False
impl_static_args = (1,) # amax_scope
inner_primitive = None
outer_primitive = None
@staticmethod
def abstract(
x_aval,
*,
amax_scope,
):
"""
amax calculation abstract
"""
del amax_scope
dtype = dtypes.canonicalize_dtype(x_aval.dtype)
assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
return out_aval
@staticmethod
def impl(
x,
amax_scope,
):
"""
amax calculation implementation
"""
del amax_scope
amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
return amax
@staticmethod
def infer_sharding_from_operands(
amax_scope,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation infer_sharding_from_operands
"""
del (amax_scope, arg_infos, result_infos) # Unused.
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculationPrimitive.out_sharding",
)
return amax_sharding
@staticmethod
def partition(
amax_scope,
mesh,
arg_infos,
result_infos,
):
"""
amax calculation partition
"""
del result_infos
amax_sharding = NamedSharding(
mesh,
PartitionSpec(None),
desc="AmaxCalculationPrimitive.out_sharding",
)
def sharded_impl(x):
amax = AmaxCalculationPrimitive.impl(
x,
amax_scope=amax_scope,
)
if amax_scope is AmaxScope.TPSP: # Run AR across TP/SP
gmesh = global_mesh_resource()
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh)
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh)
if amax_scope is AmaxScope.FSDP: # Run AR across FSDP
gmesh = global_mesh_resource()
amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh)
return amax
arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
return mesh, sharded_impl, amax_sharding, arg_shardings
@staticmethod
def shardy_sharding_rule(amax_scope, mesh, value_types, result_types):
"""
amax calculation shardy_sharding_rule
"""
del amax_scope, mesh, result_types
prefix = "AmaxCal"
input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
output_spec = (f"{prefix}_amax",)
return SdyShardingRule((input_spec,), (output_spec,))
register_primitive(AmaxCalculationPrimitive, outer_only=True)
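# --- Semantics sketch (not part of the diff) --------------------------------
# What the primitive computes, expressed in plain JAX for illustration: a
# per-shard amax, optionally widened with a pmax across a mesh axis. The axis
# name is an assumption; the real primitive resolves it from the global mesh
# resource in partition() above.
import jax
import jax.numpy as jnp

def reference_amax(x, axis_name=None):
    """Per-shard amax in float32, optionally reduced across a mesh axis."""
    amax = jnp.amax(jnp.abs(x)).astype(jnp.float32).reshape((1,))
    if axis_name is not None:  # e.g. "tp" for AmaxScope.TPSP
        amax = jax.lax.pmax(amax, axis_name)
    return amax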
def _jax_quantize( def _jax_quantize(
x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1
): ):
...@@ -578,6 +700,7 @@ def _quantize_dbias_impl( ...@@ -578,6 +700,7 @@ def _quantize_dbias_impl(
is_dbias: bool = False, is_dbias: bool = False,
dq_dtype: Optional[jnp.dtype] = None, dq_dtype: Optional[jnp.dtype] = None,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
""" """
Cast wrapper Cast wrapper
...@@ -634,7 +757,10 @@ def _quantize_dbias_impl( ...@@ -634,7 +757,10 @@ def _quantize_dbias_impl(
# until the tensor is dequantized (e.g. in the GEMM). # until the tensor is dequantized (e.g. in the GEMM).
amax = x.amax amax = x.amax
if amax is None: if amax is None:
amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) amax = AmaxCalculationPrimitive.outer_primitive.bind(
x.data,
amax_scope=amax_scope,
)
scale = compute_scale_from_amax(amax, quantizer.q_dtype) scale = compute_scale_from_amax(amax, quantizer.q_dtype)
elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
scale = quantizer.scale scale = quantizer.scale
...@@ -706,6 +832,7 @@ def quantize( ...@@ -706,6 +832,7 @@ def quantize(
x: Union[jnp.ndarray, NoScaleTensor], x: Union[jnp.ndarray, NoScaleTensor],
quantizer: Quantizer, quantizer: Quantizer,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Tuple[ScaledTensor]: ) -> Tuple[ScaledTensor]:
"""Quantize input tensor according to the quantizer. """Quantize input tensor according to the quantizer.
...@@ -716,6 +843,7 @@ def quantize( ...@@ -716,6 +843,7 @@ def quantize(
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1. Defaults to -1.
amax_scope: Indicates the scope over which the amax calculation runs. This only takes effect with current-scaling. Defaults to AmaxScope.LOCAL.
Returns: Returns:
A ScaledTensor containing the quantized input tensor. A ScaledTensor containing the quantized input tensor.
...@@ -724,6 +852,7 @@ def quantize( ...@@ -724,6 +852,7 @@ def quantize(
x, x,
quantizer=quantizer, quantizer=quantizer,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
amax_scope=amax_scope,
) )
return out return out
...@@ -733,6 +862,7 @@ def quantize_dbias( ...@@ -733,6 +862,7 @@ def quantize_dbias(
quantizer: Quantizer, quantizer: Quantizer,
is_dbias: bool = True, is_dbias: bool = True,
flatten_axis: int = -1, flatten_axis: int = -1,
amax_scope: AmaxScope = AmaxScope.LOCAL,
) -> Tuple[ScaledTensor2x, jnp.ndarray]: ) -> Tuple[ScaledTensor2x, jnp.ndarray]:
"""Quantize input tensor and compute bias gradient. """Quantize input tensor and compute bias gradient.
...@@ -743,6 +873,8 @@ def quantize_dbias( ...@@ -743,6 +873,8 @@ def quantize_dbias(
is_dbias: If True, compute bias gradient. Defaults to True. is_dbias: If True, compute bias gradient. Defaults to True.
flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization.
Defaults to -1. Defaults to -1.
amax_scope: Indicates the scope over which the amax calculation runs. This only takes effect with current-scaling. Defaults to AmaxScope.LOCAL.
Returns: Returns:
A tuple containing: A tuple containing:
...@@ -756,6 +888,7 @@ def quantize_dbias( ...@@ -756,6 +888,7 @@ def quantize_dbias(
quantizer=quantizer, quantizer=quantizer,
is_dbias=is_dbias, is_dbias=is_dbias,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
amax_scope=amax_scope,
) )
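# --- Illustrative usage (not part of the diff) ------------------------------
# With a current-scaling quantizer, amax_scope widens the amax reduction so
# every shard derives the same quantization scale. A sketch, with the
# quantizer construction assumed:
#
#   scaled = quantize(x, quantizer=cs_quantizer, amax_scope=AmaxScope.TPSP)
#   scaled_2x, dbias = quantize_dbias(
#       grad, quantizer=cs_quantizer, amax_scope=AmaxScope.TPSP
#   )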
......
...@@ -6,22 +6,16 @@ from abc import abstractmethod ...@@ -6,22 +6,16 @@ from abc import abstractmethod
from functools import partial, reduce from functools import partial, reduce
import operator import operator
import warnings import warnings
from packaging import version
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax import dtypes from jax import dtypes, ffi
from jax.sharding import PartitionSpec, NamedSharding from jax.sharding import PartitionSpec, NamedSharding
from .base import BasePrimitive, register_primitive from .base import BasePrimitive, register_primitive
from .misc import get_padded_spec, check_valid_batch_dims from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType from ..softmax import SoftmaxType
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi # pylint: disable=ungrouped-imports
else:
from jax.extend import ffi # pylint: disable=ungrouped-imports
__all__ = [ __all__ = [
"scaled_softmax_fwd", "scaled_softmax_fwd",
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include <cudnn.h> #include <cudnn.h>
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <transformer_engine/comm_gemm_overlap.h>
#include <transformer_engine/normalization.h> #include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h> #include <transformer_engine/transformer_engine.h>
...@@ -32,9 +33,6 @@ ...@@ -32,9 +33,6 @@
#include "transformer_engine/activation.h" #include "transformer_engine/activation.h"
#include "transformer_engine/multi_stream.h" #include "transformer_engine/multi_stream.h"
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
namespace transformer_engine { namespace transformer_engine {
namespace jax { namespace jax {
...@@ -43,16 +41,20 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D ...@@ -43,16 +41,20 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D
// Activation // Activation
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler);
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x); JAXX_Scaling_Mode scaling_mode, bool is_2x);
// Normalization // Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardInitializeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);
pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype,
...@@ -121,6 +123,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -121,6 +123,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
// GEMM // GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler);
// Grouped GEMM // Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
...@@ -134,4 +137,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); ...@@ -134,4 +137,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode);
XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op);
#endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_
...@@ -148,6 +148,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, ...@@ -148,6 +148,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
.Attr<bool>("is_2x"), .Attr<bool>("is_2x"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
Result_Type amax_buf, int64_t act_enum,
JAXX_Scaling_Mode scaling_mode, bool is_2x_int) {
return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf,
colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf,
act_enum, scaling_mode, is_2x_int);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // scale
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Attr<int64_t>("act_enum")
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<bool>("is_2x"));
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
DType in_dtype, DType out_dtype, DType in_dtype, DType out_dtype,
JAXX_Scaling_Mode scaling_mode, bool is_2x) { JAXX_Scaling_Mode scaling_mode, bool is_2x) {
...@@ -410,5 +434,39 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI ...@@ -410,5 +434,39 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI
.Attr<bool>("is_2x") .Attr<bool>("is_2x")
.Attr<bool>("is_dbias"), .Attr<bool>("is_dbias"),
FFI_CudaGraph_Traits); FFI_CudaGraph_Traits);
Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf,
Buffer_Type act_input_buf, Buffer_Type scale_buf,
Result_Type output_buf, Result_Type colwise_output_buf,
Result_Type scale_inv_buf,
Result_Type colwise_scale_inv_buf, Result_Type amax_buf,
Result_Type dbias_buf, Result_Type workspace_buf,
JAXX_Scaling_Mode scaling_mode, int64_t act_enum,
bool is_2x, bool is_dbias) {
return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf,
act_input_buf, scale_buf, output_buf, colwise_output_buf,
scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf,
workspace_buf, scaling_mode, act_enum, is_2x, is_dbias);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler,
DActLuDBiasQuantizeInitializeFFI,
FFI::Bind<FFI_Initialize>()
.Ctx<FFI_Stream_Type>() // stream
.Arg<Buffer_Type>() // input
.Arg<Buffer_Type>() // act input
.Arg<Buffer_Type>() // scale
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // colwise output
.Ret<Buffer_Type>() // scale_inv
.Ret<Buffer_Type>() // scale_inv colwise
.Ret<Buffer_Type>() // amax
.Ret<Buffer_Type>() // dbias
.Ret<Buffer_Type>() // wkspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("act_enum")
.Attr<bool>("is_2x")
.Attr<bool>("is_dbias"));
} // namespace jax } // namespace jax
} // namespace transformer_engine } // namespace transformer_engine
...@@ -18,10 +18,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy ...@@ -18,10 +18,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
size_t q_max_seqlen, size_t kv_max_seqlen, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t qk_head_dim, size_t v_head_dim, size_t qk_head_dim, size_t v_head_dim,
int64_t window_size_left, int64_t window_size_right) { int64_t window_size_left, int64_t window_size_right) {
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
return backend; return backend;
} }
...@@ -146,6 +147,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -146,6 +147,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64); auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector<size_t>{2}, DType::kInt64);
auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
auto dummy_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
NVTETensorPack aux_output_tensors; NVTETensorPack aux_output_tensors;
nvte_tensor_pack_create(&aux_output_tensors); nvte_tensor_pack_create(&aux_output_tensors);
...@@ -172,28 +176,30 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( ...@@ -172,28 +176,30 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen");
nvte_fused_attn_fwd_qkvpacked( nvte_fused_attn_fwd_qkvpacked(
qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_right, query_workspace_tensor.data(), nullptr); window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, query_workspace_tensor.data(), nullptr); window_size_right, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(),
ragged_offset_tensor.data(), dummy_page_table_tensor.data(),
dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen,
kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type,
mask_type, softmax_type, window_size_left, window_size_right,
query_workspace_tensor.data(), nullptr);
} else { } else {
NVTE_ERROR("Unsupported QKVLayout."); NVTE_ERROR("Unsupported QKVLayout.");
} }
...@@ -262,10 +268,15 @@ static void FusedAttnForwardImpl( ...@@ -262,10 +268,15 @@ static void FusedAttnForwardImpl(
/* Prepare RNG state */ /* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64); auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto dummy_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
/* Auxiliary tensors (to be propagated to the backward pass later) */ /* Auxiliary tensors (to be propagated to the backward pass later) */
...@@ -280,12 +291,12 @@ static void FusedAttnForwardImpl( ...@@ -280,12 +291,12 @@ static void FusedAttnForwardImpl(
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim};
auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype);
nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), nvte_fused_attn_fwd_qkvpacked(
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(),
q_seq_offsets_tensor.data(), rng_state_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_max_seqlen, is_training, scaling_factor, dropout_probability, q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training,
qkv_layout, bias_type, mask_type, window_size_left, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type,
window_size_right, workspace_tensor.data(), stream); window_size_left, window_size_right, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape = auto kv_shape =
...@@ -293,12 +304,13 @@ static void FusedAttnForwardImpl( ...@@ -293,12 +304,13 @@ static void FusedAttnForwardImpl(
auto q_tensor = TensorWrapper(q, q_shape, dtype); auto q_tensor = TensorWrapper(q, q_shape, dtype);
auto kv_tensor = TensorWrapper(k, kv_shape, dtype); auto kv_tensor = TensorWrapper(k, kv_shape, dtype);
nvte_fused_attn_fwd_kvpacked( nvte_fused_attn_fwd_kvpacked(
q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(),
&aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(),
is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout,
window_size_left, window_size_right, workspace_tensor.data(), stream); bias_type, mask_type, softmax_type, window_size_left, window_size_right,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
...@@ -307,12 +319,13 @@ static void FusedAttnForwardImpl( ...@@ -307,12 +319,13 @@ static void FusedAttnForwardImpl(
auto k_tensor = TensorWrapper(k, k_shape, dtype); auto k_tensor = TensorWrapper(k, k_shape, dtype);
auto v_tensor = TensorWrapper(v, v_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype);
nvte_fused_attn_fwd( nvte_fused_attn_fwd(
q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(),
o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors,
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(),
q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor,
bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
window_size_right, workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
...@@ -444,6 +457,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -444,6 +457,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
// For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0
min_num_segments = input_batch * max_segments_per_seq; min_num_segments = input_batch * max_segments_per_seq;
} }
auto dummy_d_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) {
// the last one is the largest which will be the returned workspace size // the last one is the largest which will be the returned workspace size
auto q_cu_seqlens_tensor = auto q_cu_seqlens_tensor =
...@@ -453,37 +469,38 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( ...@@ -453,37 +469,38 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor = auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32); TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), nvte_fused_attn_bwd_qkvpacked(
s_tensor.data(), // not used for F16 qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), s_tensor.data(), // not used for F16
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
bias_type, mask_type, window_size_left, window_size_right, dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
deterministic, query_workspace_tensor.data(), nullptr); qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked( nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
window_size_left, window_size_right, deterministic, query_workspace_tensor.data(), dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
nullptr); window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(),
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability,
window_size_left, window_size_right, deterministic, qkv_layout, bias_type, mask_type, softmax_type, window_size_left,
query_workspace_tensor.data(), nullptr); window_size_right, deterministic, query_workspace_tensor.data(), nullptr);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
...@@ -515,14 +532,17 @@ static void FusedAttnBackwardImpl( ...@@ -515,14 +532,17 @@ static void FusedAttnBackwardImpl(
/* Output tensors */ /* Output tensors */
auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16 auto s_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, dtype); // not used in F16
auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype);
auto dummy_d_softmax_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kFloat32);
NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX;
/* Auxiliary tensors (propagated from the forward pass) */ /* Auxiliary tensors (propagated from the forward pass) */
NVTETensorPack aux_input_tensors; NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors); nvte_tensor_pack_create(&aux_input_tensors);
auto backend = nvte_get_fused_attn_backend( auto backend = nvte_get_fused_attn_backend(
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups,
kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias); softmax_aux, rng_state, bias);
...@@ -540,10 +560,11 @@ static void FusedAttnBackwardImpl( ...@@ -540,10 +560,11 @@ static void FusedAttnBackwardImpl(
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor,
bias_type, mask_type, window_size_left, window_size_right, dropout_probability, qkv_layout, bias_type, mask_type,
deterministic, workspace_tensor.data(), stream); softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto kv_shape = auto kv_shape =
...@@ -562,10 +583,11 @@ static void FusedAttnBackwardImpl( ...@@ -562,10 +583,11 @@ static void FusedAttnBackwardImpl(
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(),
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
deterministic, workspace_tensor.data(), stream); mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, qk_head_dim};
auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_shape = std::vector<size_t>{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim};
...@@ -586,11 +608,12 @@ static void FusedAttnBackwardImpl( ...@@ -586,11 +608,12 @@ static void FusedAttnBackwardImpl(
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16
&aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(),
dbias_tensor.data(), q_cu_seqlens_tensor.data(), dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen,
dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type,
window_size_right, deterministic, workspace_tensor.data(), stream); mask_type, softmax_type, window_size_left, window_size_right, deterministic,
workspace_tensor.data(), stream);
} else { } else {
NVTE_ERROR("Unsupported qkv_layout."); NVTE_ERROR("Unsupported qkv_layout.");
} }
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "cgemm_helper.h"
#include "common/util/system.h"
#include "nccl.h"
namespace transformer_engine {
namespace jax {
ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &id_type) {
ncclUniqueId unique_id;
int tp_domain_id = get_tp_domain_id();
bool is_tp_leader = (get_local_device_id_within_tp_domain() == 0);
pid_t pgid = getpgid(0);
std::string base_path = getenv<std::string>("NVTE_JAX_NCCL_FILE_PATH", "/tmp");
std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_pgid_" + std::to_string(pgid) +
"_" + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) +
"_domain_" + std::to_string(tp_domain_id) + ".bin";
if (is_tp_leader) {
NVTE_CHECK_NCCL(ncclGetUniqueId(&unique_id));
// Write the ID to a temporary file
std::ofstream file(id_file, std::ios::binary);
NVTE_CHECK(file.is_open(), "Failed to create NCCL unique ID file: ", id_file);
file.write(reinterpret_cast<const char *>(&unique_id), sizeof(ncclUniqueId));
file.close();
} else {
// Wait for the ID file to be created and read it
int attempts = 0;
const int max_attempts = 100;
while (attempts < max_attempts) {
std::ifstream file(id_file, std::ios::binary);
if (file.is_open()) {
file.read(reinterpret_cast<char *>(&unique_id), sizeof(ncclUniqueId));
if (file.gcount() == sizeof(ncclUniqueId)) {
file.close();
break;
}
file.close();
}
std::this_thread::sleep_for(std::chrono::milliseconds(100));
attempts++;
}
NVTE_CHECK(attempts < max_attempts,
"Timeout waiting for " + id_type + " NCCL unique ID file from leader: ", id_file);
}
if (is_tp_leader) {
_nccl_id_file_name.push_back(id_file);
}
return unique_id;
}
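// Rendezvous summary: the TP-domain leader (local rank 0 within the domain)
// generates the NCCL unique ID and publishes it through a file under
// NVTE_JAX_NCCL_FILE_PATH (default /tmp); followers poll that file every
// 100 ms for up to 100 attempts (~10 s) before timing out. Files written by
// the leader are recorded in _nccl_id_file_name and removed in the
// destructor.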
void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size) {
// Validate inputs
NVTE_CHECK(num_devices_per_process == 1,
"num_devices_per_process must be == 1, got num_devices_per_process=",
num_devices_per_process);
NVTE_CHECK(num_total_devices >= 1,
"num_total_devices must be >= 1, got num_total_devices=", num_total_devices);
NVTE_CHECK(
num_total_devices % num_devices_per_process == 0,
"num_total_devices must be divisible by num_devices_per_process, got num_total_devices=",
num_total_devices, ", num_devices_per_process=", num_devices_per_process);
// Validate TP size
NVTE_CHECK(tp_size > 0, "tp_size must be > 0, got tp_size=", tp_size);
NVTE_CHECK(num_total_devices % tp_size == 0,
"num_total_devices must be divisible by tp_size, got num_total_devices=",
num_total_devices, ", tp_size=", tp_size);
auto &handler = get(false);
handler.num_total_devices = num_total_devices;
handler.num_devices_per_process = num_devices_per_process;
handler.process_id = process_id;
handler.num_processes = num_total_devices / num_devices_per_process;
handler.tp_size = tp_size;
handler.tp_num_domains = num_total_devices / tp_size;
// Initialize vectors with the correct size
handler.local_device_ids_within_process.resize(num_devices_per_process);
handler.local_device_ids_within_tp_domain.resize(num_devices_per_process);
handler.tp_domain_ids.resize(num_devices_per_process);
handler.global_device_ids.resize(num_devices_per_process);
handler.tp_comms.resize(num_devices_per_process);
NVTE_CHECK(0 <= process_id && process_id < handler.num_processes,
"Invalid process_id=", process_id, ", which is out of range [0, ",
handler.num_processes, ")");
// Initialize local devices and calculate their global device IDs and TP topology
for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) {
// Use the device that JAX has already assigned to this process
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
handler.local_device_ids_within_process[local_idx] = current_device;
handler.global_device_ids[local_idx] = process_id * num_devices_per_process + local_idx;
// Calculate TP-related values for this device
int global_device_id = handler.global_device_ids[local_idx];
if (num_devices_per_process == tp_size) {
// Scenario 1: Multi-device per process - TP domain = single process
handler.local_device_ids_within_tp_domain[local_idx] = local_idx;
handler.tp_domain_ids[local_idx] = process_id;
} else {
// Scenario 2: Single device per process - TP domain spans multiple processes
handler.local_device_ids_within_tp_domain[local_idx] = global_device_id % tp_size;
handler.tp_domain_ids[local_idx] = global_device_id / tp_size;
}
}
ncclUniqueId tp_id = handler.coordinate_nccl_unique_id("tp");
NVTE_CHECK_NCCL(ncclGroupStart());
for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) {
NVTE_CHECK_CUDA(cudaSetDevice(handler.local_device_ids_within_process[local_idx]));
int tp_local_rank = handler.local_device_ids_within_tp_domain[local_idx];
NVTE_CHECK_NCCL(
ncclCommInitRank(&handler.tp_comms[local_idx], handler.tp_size, tp_id, tp_local_rank));
}
NVTE_CHECK_NCCL(ncclGroupEnd());
// Allocate device memory for barrier operations
NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int)));
handler._initialize = true;
// Bootstrap UB via creating a dummy CommOverlapP2PBase object
std::vector<size_t> buffer_shape{1, 1};
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32,
JAXX_Collective_Op::ALL_GATHER);
}
void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size, int num_max_streams, int gemm_priority,
int comm_priority, int num_comm_sm, bool use_ce,
bool aggregate_ag) {
auto &config = CgemmConfig::get(false);
config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag);
auto &handler = CommunicatorHandler::get(false);
handler.init(num_total_devices, num_devices_per_process, process_id, tp_size);
}
int GetCgemmNumMaxStreams() {
auto &config = CgemmConfig::get();
return config.num_max_streams;
}
CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector<size_t> buffer_shape,
DType dtype,
JAXX_Collective_Op collective_op) {
auto &comm_handler = CommunicatorHandler::get();
auto &cgemm_config = CgemmConfig::get();
int device_idx = comm_handler.get_local_device_idx_for_current_device();
int64_t plan_id = 0;
hash_combine(plan_id, buffer_shape[0], buffer_shape[1], static_cast<size_t>(dtype),
static_cast<int>(collective_op), comm_handler.tp_size, cgemm_config.num_max_streams,
cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm,
cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx);
auto it = plan_map.find(plan_id);
if (it != plan_map.end()) {
return it->second.get();
}
if (comm_handler.num_devices_per_process == comm_handler.tp_size) {
// Multi-device per process
} else if (comm_handler.num_devices_per_process == 1) {
// Single device per process
NVTE_CHECK(comm_handler.num_total_devices % comm_handler.tp_size == 0,
"For single device per process, num_total_devices must be divisible by tp_size, "
"got num_total_devices=",
comm_handler.num_total_devices, ", tp_size=", comm_handler.tp_size);
} else {
NVTE_ERROR("Unsupported TP configuration: num_devices_per_process=",
comm_handler.num_devices_per_process, ", tp_size=", comm_handler.tp_size,
". Supported scenarios: "
"(1) num_devices_per_process == tp_size (multi-device per process), "
"(2) num_devices_per_process == 1 (single device per process)");
}
std::unique_ptr<CommOverlapCore> executor;
executor = std::make_unique<CommOverlapP2PBase>(
buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices,
comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size,
comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size,
comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op),
cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority,
cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/,
cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag);
CommOverlapCore *executor_ptr = executor.get();
plan_map[plan_id] = std::move(executor);
return executor_ptr;
}
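// Plan caching: executors are memoized in plan_map, keyed by a hash of the
// buffer shape, dtype, collective op, TP size, stream/priority/SM settings,
// and the local device index, so repeated collective GEMMs with the same
// configuration reuse a single CommOverlapP2PBase instance.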
void CommunicatorHandler::nccl_device_barrier_impl(ExtComm) {
NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using barrier");
int device_idx = get_local_device_idx_for_current_device();
ncclComm_t tp_comm = tp_comms[device_idx];
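// A one-element in-place all-reduce over the TP communicator acts as the
// barrier: it cannot complete until every rank in the domain has entered it.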
NVTE_CHECK_NCCL(
ncclAllReduce(_device_barrier, _device_barrier, 1, ncclInt, ncclSum, tp_comm, nullptr));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
}
void CommunicatorHandler::nccl_allgather_impl(void *output_buf, size_t output_bytes,
void *input_buf, size_t input_bytes, ExtComm) {
NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using allgather");
int device_idx = get_local_device_idx_for_current_device();
ncclComm_t tp_comm = tp_comms[device_idx];
size_t expected_output_bytes = input_bytes * tp_size;
NVTE_CHECK(output_bytes == expected_output_bytes, "TP allgather buffer size mismatch: expected ",
expected_output_bytes, ", got ", output_bytes);
NVTE_CHECK_NCCL(ncclAllGather(input_buf, output_buf, input_bytes, ncclChar, tp_comm, nullptr));
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
}
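// Worked example (hypothetical sizes): with tp_size=4 and input_bytes=1024,
// each rank contributes 1 KiB and output_bytes must be 4096; rank r's chunk
// lands at byte offset r * input_bytes of output_buf, per NCCL all-gather
// semantics.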
CommunicatorHandler::CommunicatorHandler() : _device_barrier(nullptr) {
allgather_func = [this](void *output_buf, size_t output_bytes, void *input_buf,
size_t input_bytes, ExtComm comm) {
this->nccl_allgather_impl(output_buf, output_bytes, input_buf, input_bytes, comm);
};
barrier_func = [this](ExtComm comm) { this->nccl_device_barrier_impl(comm); };
}
CommunicatorHandler::~CommunicatorHandler() {
if (_initialize && !tp_comms.empty()) {
for (auto &comm : tp_comms) {
if (comm != nullptr) {
ncclCommDestroy(comm);
}
}
}
if (_device_barrier) cudaFree(_device_barrier);
for (const auto &file_path : _nccl_id_file_name) {
std::remove(file_path.c_str());
}
}
} // namespace jax
} // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
#define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
#include <unistd.h>
#include <chrono>
#include <cstdio>
#include <fstream>
#include <functional>
#include <memory>
#include <thread>
#include <unordered_map>
#include "../extensions.h"
#include "common/comm_gemm_overlap/userbuffers/userbuffers.h"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
#include "transformer_engine/comm_gemm_overlap.h"
namespace transformer_engine {
namespace jax {
// Configuration singleton for CGEMM parameters
class CgemmConfig {
public:
int num_max_streams;
int gemm_priority;
int comm_priority;
int num_comm_sm;
bool use_ce;
bool aggregate_ag;
static void init(int _num_max_streams, int _gemm_priority, int _comm_priority, int _num_comm_sm,
bool _use_ce, bool _aggregate_ag) {
auto &config = get(false);
config._initialized = true;
config.num_max_streams = _num_max_streams;
config.gemm_priority = _gemm_priority;
config.comm_priority = _comm_priority;
config.num_comm_sm = _num_comm_sm;
config.use_ce = _use_ce;
config.aggregate_ag = _aggregate_ag;
}
static CgemmConfig &get(bool is_initialized = true) {
static thread_local CgemmConfig instance;
NVTE_CHECK(instance._initialized == is_initialized,
"CgemmConfig initialization state mismatch: _initialized=", instance._initialized,
", expected is_initialized=", is_initialized);
return instance;
}
CgemmConfig(const CgemmConfig &) = delete;
CgemmConfig &operator=(const CgemmConfig &) = delete;
private:
CgemmConfig() = default;
~CgemmConfig() = default;
bool _initialized = false;
};
// Forward declaration
class CollectiveGemmPlanRegistry;
// NCCL communicator handler for collective GEMM operations
// Supports both single-process single-device AND single-process multi-device setups
// Two scenarios:
// 1. Single process multiple devices: TP domain = process (num_devices_per_process == tp_size)
// 2. Single process single device: TP domain spans processes (num_devices_per_process == 1)
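// Example (hypothetical): num_total_devices=8, tp_size=4, one device per
// process gives tp_num_domains=2. Assuming contiguous rank grouping,
// processes 0-3 form TP domain 0 and processes 4-7 form TP domain 1, each
// process holding local rank process_id % tp_size within its domain.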
class CommunicatorHandler {
public:
int num_total_devices = -1;
int num_devices_per_process = -1;
int process_id = -1;
int num_processes = -1;
int tp_size = -1;
int tp_num_domains = -1;
std::vector<int> local_device_ids_within_tp_domain;
std::vector<int> tp_domain_ids;
std::vector<ncclComm_t> tp_comms;
std::vector<int> local_device_ids_within_process;
std::vector<int> global_device_ids;
int get_global_rank() const {
int device_idx = get_local_device_idx_for_current_device();
return global_device_ids[device_idx];
}
void nccl_device_barrier_impl(ExtComm);
void nccl_allgather_impl(void *output_buf, size_t output_bytes, void *input_buf,
size_t input_bytes, ExtComm);
ncclComm_t get_comm_for_current_device() const {
int device_idx = get_local_device_idx_for_current_device();
return tp_comms[device_idx];
}
int get_local_device_idx_for_current_device() const {
int current_device;
NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
for (int i = 0; i < num_devices_per_process; i++) {
if (local_device_ids_within_process[i] == current_device) {
return i;
}
}
NVTE_ERROR("Current CUDA device ", current_device,
" not found in local_device_ids_within_process");
}
int get_local_device_id_within_tp_domain() const {
int device_idx = get_local_device_idx_for_current_device();
return local_device_ids_within_tp_domain[device_idx];
}
int get_tp_domain_id() const {
int device_idx = get_local_device_idx_for_current_device();
return tp_domain_ids[device_idx];
}
int get_tp_num_domains() const { return tp_num_domains; }
static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size);
private:
ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type);
public:
static CommunicatorHandler &get(bool is_initialized = true) {
static CommunicatorHandler instance;
NVTE_CHECK(instance._initialize == is_initialized,
"CommunicatorHandler._initialize=", instance._initialize,
", is_initialized=", is_initialized);
return instance;
}
ExtAllgatherOp allgather_func;
ExtBarrierOp barrier_func;
CommunicatorHandler(const CommunicatorHandler &) = delete;
CommunicatorHandler &operator=(const CommunicatorHandler &) = delete;
private:
CommunicatorHandler();
~CommunicatorHandler();
bool _initialize = false;
int *_device_barrier = nullptr;
std::vector<std::string> _nccl_id_file_name;
};
// Plan registry for caching collective GEMM executors
class CollectiveGemmPlanRegistry {
public:
static CollectiveGemmPlanRegistry &getInstance() {
static thread_local CollectiveGemmPlanRegistry instance;
return instance;
}
CommOverlapCore *get_executor(std::vector<size_t> buffer_shape, DType dtype,
JAXX_Collective_Op collective_op);
private:
CollectiveGemmPlanRegistry() {}
CollectiveGemmPlanRegistry(const CollectiveGemmPlanRegistry &) = delete;
CollectiveGemmPlanRegistry &operator=(const CollectiveGemmPlanRegistry &) = delete;
std::unordered_map<int64_t, std::unique_ptr<CommOverlapCore>> plan_map;
};
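// Usage sketch (hypothetical shapes): fetch-or-build a cached all-gather
// executor for a bf16 communication buffer; repeated calls with identical
// parameters on the same device return the same cached CommOverlapCore:
//
//   auto *exec = CollectiveGemmPlanRegistry::getInstance().get_executor(
//       {8192, 4096}, DType::kBFloat16, JAXX_Collective_Op::ALL_GATHER);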
// Function declarations
void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id,
int tp_size, int num_max_streams, int gemm_priority,
int comm_priority, int num_comm_sm, bool use_ce,
bool aggregate_ag);
int GetCgemmNumMaxStreams();
} // namespace jax
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_
...@@ -24,6 +24,7 @@ using FFI_Stream_Type = xla::ffi::PlatformStream<cudaStream_t>;
using Dictionary = xla::ffi::Dictionary;
constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare;
constexpr auto FFI_Initialize = xla::ffi::ExecutionStage::kInitialize;
constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible};
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type);
...@@ -106,5 +107,19 @@ inline static size_t te_dtype_bytes(const DType& type) {
}
}
template <typename... Args>
Error_Type wrapInStreamCapture(std::function<Error_Type(cudaStream_t, Args...)> func,
cudaStream_t stream, Args... args) {
cudaGraph_t graph{};
NVTE_CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
Error_Type error = func(stream, std::forward<Args>(args)...);
NVTE_CHECK_CUDA(cudaStreamEndCapture(stream, &graph));
NVTE_CHECK_CUDA(cudaGraphDestroy(graph));
return error;
}
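// Note: the captured graph is destroyed without being instantiated or
// launched, so the wrapped call is effectively a dry run under relaxed
// capture: host-side setup in func runs normally while any device work it
// enqueues on the stream is recorded and then discarded. A hypothetical
// warm-up use:
//
//   std::function<Error_Type(cudaStream_t)> fn = [](cudaStream_t s) {
//     return ffi_with_cuda_error_check();
//   };
//   auto err = wrapInStreamCapture(fn, stream);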
} // namespace jax
} // namespace transformer_engine
...@@ -6,13 +6,19 @@
#include "transformer_engine/gemm.h"
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string_view>
#include <tuple>
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/string.h"
#include "common/util/system.h"
#include "cuda_runtime.h"
#include "nccl.h"
#include "transformer_engine/swizzle.h"
#include "xla/ffi/api/c_api.h"
...@@ -66,12 +72,75 @@ std::tuple<TensorWrapper, std::vector<size_t>> xla_buffer_to_nvte_gemm_operand(
return std::make_tuple(std::move(input), input_shape);
}
Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias,
Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad,
Result_Type pre_gelu_out, Result_Type workspace,
JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed,
bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad,
bool use_split_accumulator, JAXX_Collective_Op collective_op) {
nvte_cublas_handle_init();
// Init UB buffer
if (collective_op != JAXX_Collective_Op::NONE) {
auto &comm_handler = CommunicatorHandler::get();
std::vector<size_t> lhs_shape = {
product(lhs.dimensions(), 0, lhs_axis_boundary),
product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())};
std::vector<size_t> rhs_shape = {
product(rhs.dimensions(), 0, rhs_axis_boundary),
product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())};
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size;
buffer_shape[1] = lhs_shape[1];
buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
} else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
buffer_shape[0] = out_shape[0];
buffer_shape[1] = out_shape[1];
}
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype,
collective_op);
}
return ffi_with_cuda_error_check();
}
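// Worked example (hypothetical shapes): for ALL_GATHER with lhs collapsing to
// lhs_shape = {1024, 512} and tp_size = 4, the communication buffer is
// {4096, 512} in the LHS dtype (it holds the gathered LHS); for
// REDUCE_SCATTER it is simply out_shape in the output dtype (it holds the
// full pre-scatter GEMM output).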
XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI,
FFI::Bind<FFI_Prepare>()
.Arg<Buffer_Type>() // lhs
.Arg<Buffer_Type>() // lhs_scale_inv
.Arg<Buffer_Type>() // rhs
.Arg<Buffer_Type>() // rhs_scale_inv
.Arg<Buffer_Type>() // bias
.Arg<Buffer_Type>() // gelu_input
.Ret<Buffer_Type>() // output
.Ret<Buffer_Type>() // bias_grad
.Ret<Buffer_Type>() // pre_gelu_out
.Ret<Buffer_Type>() // workspace
.Attr<JAXX_Scaling_Mode>("scaling_mode")
.Attr<int64_t>("lhs_axis_boundary")
.Attr<int64_t>("rhs_axis_boundary")
.Attr<bool>("lhs_transposed")
.Attr<bool>("rhs_transposed")
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator")
.Attr<JAXX_Collective_Op>("collective_op"));
Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs,
Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input,
Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out,
Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary,
int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed,
bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator,
JAXX_Collective_Op collective_op) {
// NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when
// device supports non-TN layouts (compute capability >= 10.0, excluding 12.x)
bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
...@@ -83,16 +152,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode,
rhs_axis_boundary, make_rhs_rowwise);
std::vector<size_t> out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0],
(rhs_transposed) ? rhs_shape[0] : rhs_shape[1]};
auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type());
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
...@@ -133,9 +195,62 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
// Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order
auto num_math_sm = cuda::sm_count() - getenv<int>("NVTE_EXT_MARGIN_SM", 0);
if (collective_op == JAXX_Collective_Op::NONE) {
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ",
to_string_like(out_shape), " but got ", output->element_count(), " elements ",
to_string_like(output->dimensions()));
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
use_split_accumulator, num_math_sm, stream);
} else {
std::vector<size_t> buffer_shape{0, 0};
DType buffer_dtype = out_dtype;
auto &comm_handler = CommunicatorHandler::get();
if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size;
buffer_shape[1] = lhs_shape[1];
out_shape[0] = out_shape[0] * comm_handler.tp_size;
buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type());
} else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
buffer_shape[0] = out_shape[0];
buffer_shape[1] = out_shape[1];
out_shape[0] = out_shape[0] / comm_handler.tp_size;
}
auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
buffer_shape, buffer_dtype, collective_op);
if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype);
// Prepare the auxiliary buffer for the reduce-scattered GEMM output
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(),
" elements ", to_string_like(out_shape), " but got ", output->element_count(),
" elements ", to_string_like(output->dimensions()));
// Launch GEMM+RS
executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_,
pre_gelu_, workspace_, grad, false, use_split_accumulator, out_,
stream);
} else if (collective_op == JAXX_Collective_Op::ALL_GATHER) {
auto aux_out_ = TensorWrapper(nullptr, std::vector<size_t>{0}, out_dtype); // Empty
auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype);
NVTE_CHECK(out_.numel() == output->element_count(),
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(),
" elements ", to_string_like(out_shape), " but got ", output->element_count(),
" elements ", to_string_like(output->dimensions()));
// Copy the distributed LHS operand into the local chunk of the communication buffer
executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise);
// Launch AG+GEMM
executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_,
workspace_, grad, false, use_split_accumulator, aux_out_, stream);
}
}
return ffi_with_cuda_error_check();
}
...@@ -161,7 +276,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
.Attr<bool>("fuse_bias")
.Attr<bool>("fuse_gelu")
.Attr<bool>("grad")
.Attr<bool>("use_split_accumulator")
.Attr<JAXX_Collective_Op>("collective_op"),
FFI_CudaGraph_Traits);
Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
...