Unverified Commit 1bbeab1c authored by kwyss-nvidia, committed by GitHub

Blockwise float8 quantizer and quantized tensor class (#1513)



* Blockwise float8 quantizer and quantized tensor class.

The classes are configurable for a 128x128 block size
or a 1x128 block size by setting block_scaling_dim to 2 or 1, respectively.

Scale tensors are stored in a format amenable to matrix multiplication;
however, the matmul integration is deferred to a separate story.
(A sketch of the resulting scale-inverse shapes follows at the end of this item.)

Fusions of quantization with DBIAS or activation functions are not yet
implemented, and dequantization is currently implemented in torch.

Tests for quantization are included at the C++ and PyTorch layers, with
exact comparison to reference quantizer behavior as well as an attempt
to hit interesting branches of the API, such as tensor creation
in PyTorch and C++ and dequantization for rowwise and columnwise usage.

Two CUDA kernels for quantization are included; they are direct ports
of equivalents in the kitchen repository, where a subchannel recipe
has been used for end-to-end training.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
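
A rough sketch of the scale-inverse shapes the quantizer allocates, based on
Float8BlockQuantizer::create_tensor later in this commit (the (M, K) sizes below
are illustrative only):

// For an (M, K) = (256, 512) rowwise input with kBlockLen = 128:
//   block_scaling_dim == 2 (128x128 blocks): scale_inv shape = {ceil(M/128), roundup(ceil(K/128), 4)} = {2, 4}
//   block_scaling_dim == 1 (1x128 blocks):   scale_inv shape = {ceil(K/128), roundup(M, 4)}           = {4, 256}
// get_scaling_mode() maps block_scaling_dim == 2 to NVTE_BLOCK_SCALING_2D and == 1 to NVTE_BLOCK_SCALING_1D.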

* Apply linting changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Alignment for 1D scaling for GEMM edge case.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Change API name.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix merge conflict with name change.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use common tensor map API.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Change API to use two scaling mode enums.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix typo.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update some call sites.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Tests for torch tensor API surface.

Since the quantized tensor is a tensor
subclass, these tests exercise torch hooks.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reuse scale calculation between quantizer refs.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Save memory by dropping reference to saved tensors.

Previously observed issues are resolved.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove constexpr parameters from kernel.

Code size is reduced with fewer constexpr params.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Merge conflict from rebase.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add shape implementations for block scaling.

nvte_shape was added upstream; logic added
for block-scaled fp8.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Move benchmark to te_playground.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove amax_epsilon and pow_2_scales from tensor.

The default values are hardcoded.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Lint changes.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix up MR changes that broke.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Safer ifdef in kernel.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Documentation prose.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Reuse compute_scale function from Current Scaling.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Bugfix on inf_value scale refactor.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Remove qopt calls from test.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update pytest list.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Add copyright to reference scale calc.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use ptx.cuh functions instead of cde.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update shape logic with allocation and reuse shape.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Usage defaults MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Copyright and header guard.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Updating torch dispatch code.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Fix exception type.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Use TypeInfo
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* MR feedback.
Signed-off-by: Keith Wyss <kwyss@nvidia.com>

* Update CS scale update test to use updated ref impl
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Update JAX scaling mode enum
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Skip tests on Lovelace
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Keith Wyss <kwyss@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 3e305f72
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cfloat>
#include <cuda/barrier>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#if (!defined(__CUDA_MINIMUM_ARCH__) && __CUDA_ARCH__ >= 900) || \
(defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900)
#define TMA_HW_SUPPORTED
#endif
namespace transformer_engine {
namespace {
// const values configuration
constexpr size_t kThreadsPerWarp = 32;
#ifdef TMA_HW_SUPPORTED
constexpr size_t BLOCK_TILE_DIM = 128;
constexpr size_t WARP_TILE_DIM_X = 32;
constexpr size_t WARP_TILE_DIM_Y = 64;
constexpr size_t THREAD_TILE_DIM_X = 16;
constexpr size_t THREAD_TILE_DIM_Y = 4;
#else
constexpr size_t BLOCK_TILE_DIM = 128;
constexpr size_t WARP_TILE_DIM_X = 64;
constexpr size_t WARP_TILE_DIM_Y = 32;
constexpr size_t THREAD_TILE_DIM_X = 8;
constexpr size_t THREAD_TILE_DIM_Y = 8;
#endif
#ifdef TMA_HW_SUPPORTED
constexpr size_t NUM_BYTES_PER_BANK = 4;
constexpr size_t NUM_BANKS_PER_SHARED_ELEM = THREAD_TILE_DIM_Y / NUM_BYTES_PER_BANK;
constexpr size_t SHARED_BLOCK_TILE_DIM_Y = BLOCK_TILE_DIM;
constexpr size_t SHARED_BLOCK_TILE_DIM_X_BANKS =
BLOCK_TILE_DIM / (NUM_BYTES_PER_BANK * NUM_BANKS_PER_SHARED_ELEM);
constexpr size_t NUM_BANKS_Y_IN_WARP = WARP_TILE_DIM_Y / NUM_BYTES_PER_BANK;
#endif
constexpr size_t ELE_PER_THREAD = THREAD_TILE_DIM_X * THREAD_TILE_DIM_Y;
constexpr size_t THREADS_PER_BLOCK = BLOCK_TILE_DIM * BLOCK_TILE_DIM / ELE_PER_THREAD;
constexpr size_t NUM_WARPS_X_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_X;
constexpr size_t NUM_WARPS_Y_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_Y;
constexpr size_t NUM_WARPS_IN_BLOCK = NUM_WARPS_X_IN_BLOCK * NUM_WARPS_Y_IN_BLOCK;
constexpr size_t NUM_THREADS_X_IN_WARP = WARP_TILE_DIM_X / THREAD_TILE_DIM_X;
constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP;
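// Added commentary (not in the original source): worked-out values of the constants above,
// as sanity checks that hold for both the TMA (16x4 thread tile) and non-TMA (8x8 thread tile)
// configurations. Each thread always owns 64 elements, so a 128x128 block tile is covered by
// 256 threads = 8 warps.
static_assert(ELE_PER_THREAD == 64, "each thread handles 64 elements of the 128x128 tile");
static_assert(THREADS_PER_BLOCK == 256, "128*128 elements / 64 per thread = 256 threads");
static_assert(NUM_WARPS_IN_BLOCK == THREADS_PER_BLOCK / kThreadsPerWarp, "8 warps per block");
static_assert(NUM_THREADS_X_IN_WARP * THREAD_TILE_DIM_X == WARP_TILE_DIM_X,
              "thread tiles cover the warp tile in X");
static_assert(NUM_THREADS_Y_IN_WARP * THREAD_TILE_DIM_Y == WARP_TILE_DIM_Y,
              "thread tiles cover the warp tile in Y");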
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c,
OType* const output_t, CType* const tile_scales_inv_c,
CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x,
const size_t scale_stride_y, const size_t scale_t_stride_x,
const size_t scale_t_stride_y, const float epsilon,
const __grid_constant__ CUtensorMap tensor_map_output_t,
bool pow_2_scaling) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
IVec thrd_tile_input[THREAD_TILE_DIM_Y];
constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1;
OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_];
const int tid_in_warp = threadIdx.x % kThreadsPerWarp;
const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP;
const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP;
const int warp_id_in_block = threadIdx.x / kThreadsPerWarp;
const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK;
const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK;
// This is ONLY true if the input is a full tile
const int tile_id_x = blockIdx.x;
const int tile_id_y = blockIdx.y;
const size_t block_tile_start_idx =
tile_id_y * BLOCK_TILE_DIM * row_length + tile_id_x * BLOCK_TILE_DIM;
const size_t warp_tile_start_idx =
block_tile_start_idx +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP;
const size_t thread_tile_start_idx = warp_tile_start_idx +
tid_in_warp_y * THREAD_TILE_DIM_Y * row_length +
tid_in_warp_x * THREAD_TILE_DIM_X;
CType warp_tile_amax;
CType block_tile_amax;
CType block_tile_scale;
CType amax = 0;
// Step 1: Load a block tile of input data into thread tiles on registers
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
thrd_tile_input[i].load_from(input + thread_tile_start_idx + i * row_length);
}
// Step 2: calculate block tile amax and scale
// Calculate thread_tile amax
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(static_cast<CType>(thrd_tile_input[i].data.elt[j])));
}
}
// Reduce amax in the warp (32x32 tile)
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] =
warp_tile_amax;
}
__syncthreads();
// only NUM_WARPS_IN_BLOCK (8) elements need reduction; a reduction tree would require multiple
// __syncthreads, so instead we just let thread 0 do the job
if (threadIdx.x == 0) {
CType blk_amax = block_tile_amax_shared[0];
#pragma unroll
for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) {
blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]);
}
block_tile_amax_shared[0] = blk_amax;
}
__syncthreads();
block_tile_amax = block_tile_amax_shared[0];
block_tile_scale =
compute_scale_from_types<IType, OType>(block_tile_amax, epsilon, pow_2_scaling);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3: Store cast output, Step 4: do transpose within thread tile
OVecCast tmp_output_c;
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length);
}
// Step 4: store transpose into shared memory
if constexpr (kReturnTranspose) {
#ifdef TMA_HW_SUPPORTED
__shared__ alignas(128)
OVecTrans block_tile_trans_shared[SHARED_BLOCK_TILE_DIM_Y][SHARED_BLOCK_TILE_DIM_X_BANKS];
OType(*block_tile_trans_shared_otype_ptr)[BLOCK_TILE_DIM] =
reinterpret_cast<OType(*)[BLOCK_TILE_DIM]>(block_tile_trans_shared);
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_X; i++) {
auto warp_id_in_block_x_ = warp_id_in_block_y;
auto warp_id_in_block_y_ = warp_id_in_block_x;
int row_idx = warp_id_in_block_y_ * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP +
tid_in_warp_x * THREAD_TILE_DIM_X + i;
int col_idx =
warp_id_in_block_x_ * (NUM_BANKS_Y_IN_WARP / NUM_BANKS_PER_SHARED_ELEM) + tid_in_warp_y;
block_tile_trans_shared[row_idx][col_idx] = thrd_tile_out_trans[i];
}
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Step 5: store transpose output
// Initiate TMA transfer to copy shared memory to global memory
if (threadIdx.x == 0) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t*>(&tensor_map_output_t), tile_id_y * BLOCK_TILE_DIM,
tile_id_x * BLOCK_TILE_DIM,
reinterpret_cast<uint64_t*>(block_tile_trans_shared_otype_ptr));
// Wait for TMA transfer to have finished reading shared memory.
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
// Wait for the group to have completed reading from shared memory.
ptx::cp_async_bulk_wait_group_read<0>();
}
#else
// Step 4 Alternative (when TMA is not available, skip writing to shared memory)
const size_t block_tile_t_start_idx =
tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM;
const size_t warp_tile_t_start_idx =
block_tile_t_start_idx +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP;
const size_t thread_tile_t_start_idx = warp_tile_t_start_idx +
tid_in_warp_x * THREAD_TILE_DIM_X * num_rows +
tid_in_warp_y * THREAD_TILE_DIM_Y;
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_X; i++) {
thrd_tile_out_trans[i].store_to(output_t + thread_tile_t_start_idx + i * num_rows);
}
#endif
}
}
template <bool kReturnTranspose, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned(
const IType* const input, OType* const output_c, OType* const output_t,
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
bool pow_2_scaling) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
IVec thrd_tile_input[THREAD_TILE_DIM_Y];
constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1;
OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_];
const int tid_in_warp = threadIdx.x % kThreadsPerWarp;
const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP;
const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP;
const int warp_id_in_block = threadIdx.x / kThreadsPerWarp;
const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK;
const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK;
const int tile_id_x = blockIdx.x;
const int tile_id_y = blockIdx.y;
const size_t block_tile_start_row_idx = tile_id_y * BLOCK_TILE_DIM;
const size_t block_tile_start_col_idx = tile_id_x * BLOCK_TILE_DIM;
const size_t block_tile_start_idx =
block_tile_start_row_idx * row_length + block_tile_start_col_idx;
const size_t warp_tile_start_idx =
block_tile_start_idx +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP;
const size_t thread_tile_start_idx = warp_tile_start_idx +
tid_in_warp_y * THREAD_TILE_DIM_Y * row_length +
tid_in_warp_x * THREAD_TILE_DIM_X;
// handle non-full tile
// check for three cases: full thread tile, nonfull thread tile, empty thread tile
// for empty thread tile, directly write zero to the transposed shared mem buffer
// for nonfull thread tile, fill zero to thread tile and act as if it's full
const size_t thread_tile_start_row_idx =
tile_id_y * BLOCK_TILE_DIM + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP +
tid_in_warp_y * THREAD_TILE_DIM_Y;
const size_t thread_tile_start_col_idx =
tile_id_x * BLOCK_TILE_DIM + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP +
tid_in_warp_x * THREAD_TILE_DIM_X;
const size_t thread_tile_end_row_idx = thread_tile_start_row_idx + THREAD_TILE_DIM_Y - 1;
const size_t thread_tile_end_col_idx = thread_tile_start_col_idx + THREAD_TILE_DIM_X - 1;
bool full_thrd_tile =
(thread_tile_end_row_idx < num_rows) && (thread_tile_end_col_idx < row_length);
bool empty_thrd_tile =
(thread_tile_start_row_idx >= num_rows) || (thread_tile_start_col_idx >= row_length);
bool nonfull_thrd_tile = (!full_thrd_tile) && (!empty_thrd_tile);
const size_t thread_tile_ncols =
MIN(THREAD_TILE_DIM_X,
(MIN(thread_tile_end_col_idx, row_length - 1) - thread_tile_start_col_idx + 1));
const size_t thread_tile_nrows =
MIN(THREAD_TILE_DIM_Y,
(MIN(thread_tile_end_row_idx, num_rows - 1) - thread_tile_start_row_idx + 1));
CType warp_tile_amax;
CType block_tile_amax;
CType block_tile_scale;
CType amax = 0;
if (!empty_thrd_tile) {
// Step 1: Load a block tile of input data into thread tiles on registers
// Edge case: nonfull thread tile case, will use the partial load function here
if (nonfull_thrd_tile) {
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
if (i >= thread_tile_nrows) {
thrd_tile_input[i].clear();
} else {
thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
}
}
} else {
#pragma unroll
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0,
THREAD_TILE_DIM_X);
}
}
// Step 2: calculate block tile amax and scale
// Calculate thread_tile amax
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(static_cast<CType>(thrd_tile_input[i].data.elt[j])));
}
}
}
// Reduce amax in the warp (32x32 tile)
warp_tile_amax = warp_reduce_max<kThreadsPerWarp>(amax);
// broadcast the amax to all threads in a warp from the lane 0
constexpr int lane_zero = 0;
warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero);
// reduce warp_tile_amax across multiple warps in a thread block using shared mem
if (tid_in_warp == 0) {
block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] =
warp_tile_amax;
}
__syncthreads();
// only NUM_WARPS_IN_BLOCK (8) elements need reduction; a reduction tree would require multiple
// __syncthreads, so instead we just let thread 0 do the job
if (threadIdx.x == 0) {
CType blk_amax = block_tile_amax_shared[0];
#pragma unroll
for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) {
blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]);
}
block_tile_amax_shared[0] = blk_amax;
}
__syncthreads();
block_tile_amax = block_tile_amax_shared[0];
block_tile_scale =
compute_scale_from_types<IType, OType>(block_tile_amax, epsilon, pow_2_scaling);
if (threadIdx.x == 0) {
static_assert(std::is_same<CType, float>::value);
const CType scale_inv = 1.0f / block_tile_scale;
size_t row_idx = tile_id_y;
size_t col_idx = tile_id_x;
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
if constexpr (kReturnTranspose) {
row_idx = tile_id_x;
col_idx = tile_id_y;
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
}
// Step 3: Store cast output, Step 4: do transpose within thread tile
// Edge case: in the non-full tile case, there are three subcases
// for full thread tile, it's the same thing here
// for nonfull thread tile, pay attention when saving tmp_output_c to global
// memory, cannot vec store_to, but need to elt store to for empty tile,
// it should not enter this step, skip to Step 4
// set thrd_tile_out_trans to all zero
if constexpr (kReturnTranspose) {
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
thrd_tile_out_trans[j].clear();
}
}
if (!empty_thrd_tile) {
OVecCast tmp_output_c;
for (int i = 0; i < THREAD_TILE_DIM_Y; i++) {
if (i >= thread_tile_nrows) {
continue;
}
#pragma unroll
for (int j = 0; j < THREAD_TILE_DIM_X; j++) {
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
thrd_tile_out_trans[j].data.elt[i] = scaled_elt;
}
}
tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0,
thread_tile_ncols);
}
if constexpr (kReturnTranspose) {
const size_t block_tile_t_start_idx =
tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM;
const size_t warp_tile_t_start_idx =
block_tile_t_start_idx +
warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows +
warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP;
const size_t thread_tile_t_start_idx = warp_tile_t_start_idx +
tid_in_warp_x * THREAD_TILE_DIM_X * num_rows +
tid_in_warp_y * THREAD_TILE_DIM_Y;
#pragma unroll
for (int i = 0; i < thread_tile_ncols; i++) {
thrd_tile_out_trans[i].store_to_elts(output_t + thread_tile_t_start_idx + i * num_rows, 0,
thread_tile_nrows);
}
}
}
}
template <typename OutputType>
CUtensorMap get_tensor_map(const SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) {
CUtensorMapDataType dataType;
if constexpr (std::is_same_v<OutputType, __nv_fp8_e4m3> ||
std::is_same_v<OutputType, __nv_fp8_e5m2>) {
dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else {
NVTE_CHECK(false, "Invalid Output type (must be FP8).");
}
CUtensorMap tensor_map_output_trans{};
create_2D_tensor_map(tensor_map_output_trans, tensor, global_dim_y, global_dim_x,
/*shmemY=*/BLOCK_TILE_DIM, /*shmemX=*/BLOCK_TILE_DIM,
/*stride_elems=*/global_dim_x, /*offset_elems=*/0, sizeof(OutputType));
return tensor_map_output_trans;
}
} // namespace
} // namespace transformer_engine
namespace transformer_engine::detail {
void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output,
SimpleTensor& output_t, const float epsilon,
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_square_blockwise);
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
}
NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions.");
size_t scale_k = scale_inv.shape[1];
const size_t scale_stride_x = 1;
const size_t scale_stride_y = scale_k;
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (return_transpose) {
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
if (output_t.shape.size() > 0) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions.");
scale_t_stride_x = 1;
scale_t_stride_y = scale_inv_t.shape[1];
}
const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output.dtype, OutputType,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose,
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
if (full_tile) {
CUtensorMap tensor_map_output_trans;
if (return_transpose) {
tensor_map_output_trans =
get_tensor_map<OutputType>(output_t, num_rows, row_length);
}
block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale);
} else {
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
OutputType>
<<<grid, THREADS_PER_BLOCK, 0, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
pow_2_scale);
} // full-tile
) // return_transpose
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace transformer_engine::detail
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cfloat>
#include <cuda/barrier>
#include <utility>
#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/utils.cuh"
namespace transformer_engine {
namespace {
// clang-format off
/*
Step 1: Load input to shared memory
* shared memory: 128x128 elements with type=InputType (the diagram below doesn't consider padding)
* 8 warps
* Loop 8 times
* What each thread does in each loop:
* 8 elements are read from the input at a time
* 2 elements are written to the shared memory at a time, for a total of 4 times
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 1 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| Warp 7 |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 8 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 2: Cast and store to output_c
* shared memory: 128x128 elements with type=InputType (the diagram below doesn't consider padding)
* 8 warps
* Loop 4 times
* What each thread does in each loop:
* 2 elements are read from the shared memory at a time, for a total of 8 times
* Every 8 consecutive threads do a reduction to calculate the amax of each row
* 16 elements are quantized and written to output_c at a time
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 |
| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 |
| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 |
| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 1 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| |
| Warp 7 |
| |
| |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
| ... |
| ... |
| ... |
| ... |
| Loop 4 times |
| ... |
| ... |
| ... |
| ... |
+-------------------------------+-------------------------------+-------------------------------+-------------------------------+
Step 3: Transpose, cast and store to output_t
* shared memory: 128x128 elements with type=InputType (the diagram below doesn't consider padding)
* 8 warps
* Loop 2 times
* What each thread does in each loop:
* 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times
* Every 8 consecutive threads do a reduction to calculate the amax of each column
* 16 elements are quantized and written to output_t at a time, for a total of 2 times
+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+
| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | |
| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | |
| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | |
| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 |
| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | |
| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | |
| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | |
| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | |
+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+
*/
// clang-format on
constexpr size_t kThreadsPerWarp = 32;
// Hyperparameters for performance tuning
constexpr int kTileDim = 128;  // Fixed to 128 because we are using 1x128 and 128x1 quantization
constexpr int kNVecIn = 8; // The number of elements each LDG touches
constexpr int kNVecOut = 16; // The number of elements each STG touches
constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches
constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total
// Auto-calculated constants (do not modify directly)
static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem");
static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem");
constexpr int kSMemRow = kTileDim;
constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1;
constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem;
constexpr int kNumThreadsLoad = kTileDim / kNVecIn;
constexpr int kNumThreadsStore = kTileDim / kNVecOut;
static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp");
static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp");
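// Added commentary (not in the original source): worked-out loop/tiling arithmetic matching the
// ASCII diagrams above, expressed as compile-time checks. The +1 column of padding in kSMemCol is
// the usual trick to stagger rows across shared-memory banks.
static_assert(kNumThreadsLoad == 16 && kNumThreadsStore == 8,
              "128/8 = 16 loader threads per row, 128/16 = 8 store threads per row");
static_assert(kTileDim / (kThreadsPerBlock / kNumThreadsLoad) == 8,
              "Step 1 loads 16 rows per iteration -> 8 iterations per 128-row tile");
static_assert(kTileDim / (kThreadsPerBlock / kNumThreadsStore) == 4,
              "Step 2 stores 32 rows per iteration -> 4 iterations per 128-row tile");
static_assert(kTileDim / ((kThreadsPerBlock / kNumThreadsStore) * kNVecSMem) == 2,
              "Step 3 covers 32*2 columns per iteration -> 2 iterations per 128-column tile");
static_assert(kSMemCol == 65, "128/2 columns of SMemVec plus one SMemVec of padding per row");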
template <bool kAligned, typename CType, typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock)
block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c,
OType* const output_t, CType* const tile_scales_inv_c,
CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x,
const size_t scale_stride_y,
const size_t scale_t_stride_x,
const size_t scale_t_stride_y, const float epsilon,
bool return_transpose, bool pow_2_scaling) {
using SMemVec = Vec<IType, kNVecSMem>;
using OVec = Vec<OType, kNVecOut>;
union IVec {
Vec<IType, kNVecIn> input_type;
Vec<SMemVec, kNVecIn / kNVecSMem> smem_type;
};
extern __shared__ char smem_base[];
SMemVec* smem = reinterpret_cast<SMemVec*>(&smem_base[0]);
// Step 1: Load input to shared memory
{
constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecIn), row_length - c_g)
: 0; // For not aligned case
const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
IVec input_vec;
// Step 1.1: Load from global memory (input) to registers
if constexpr (kAligned) {
input_vec.input_type.load_from(input_g);
} else {
if (r_g < num_rows) {
input_vec.input_type.load_from_elts(input_g, 0, num_ele);
} else {
input_vec.input_type.clear();
}
}
// Step 1.2: Write to shared memory
#pragma unroll
for (int i = 0; i < kNVecIn / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i];
}
// Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case)
input_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
__syncthreads();
// Step 2: Cast and store to output_c
{
constexpr int r_stride =
kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory
constexpr int num_iterations = kTileDim / r_stride;
const int c_s =
(threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory
int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory
const size_t c_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Column in global memory
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Row in global memory
const size_t stride_g = static_cast<size_t>(r_stride) * row_length; // Stride in global memory
const size_t num_ele = c_g < row_length ? min(static_cast<size_t>(kNVecOut), row_length - c_g)
: 0; // For not aligned case
OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory
// Each group of kNumThreadsStore threads within a warp processes one row; we need the lane id
// of the first thread in the group to do the reduction.
const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
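// Added commentary (not in the original source): worked example with kNumThreadsStore == 8.
// For threadIdx.x == 13: lane 13 in its warp, src_lane = (13 / 8) * 8 = 8, and
// mask = 0xFF << 8 = 0x0000FF00, so lanes 8..15 reduce together and lane 8 (the thread with
// is_src_lane == true) ends up writing the resulting scale_inv.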
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut / kNVecSMem];
// Step 2.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
int c = c_s + i;
int r = r_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
// Step 2.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j]));
}
}
// Step 2.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
CType scale;
// Step 2.4: Compute scale
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// Step 2.5: Write scale_inv
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g < num_rows);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
size_t col_idx = static_cast<size_t>(blockIdx.x);
tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv;
}
// Step 2.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut / kNVecSMem; ++i) {
#pragma unroll
for (int j = 0; j < kNVecSMem; ++j) {
output_vec.data.elt[i * kNVecSMem + j] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[j]) * scale);
}
}
// Step 2.7: Store output_c
if constexpr (kAligned) {
output_vec.store_to(output_g);
} else {
if (r_g < num_rows) {
output_vec.store_to_elts(output_g, 0, num_ele);
}
}
// Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case)
output_g += stride_g;
r_s += r_stride;
if constexpr (!kAligned) {
r_g += r_stride;
}
}
}
// Step 3: Transpose, cast and store to output_t
if (return_transpose) {
constexpr int c_stride =
kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory
constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem);
const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory
int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory
size_t r_g =
static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem; // Row in global memory
const size_t c_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; // Column in global memory
const size_t stride_g =
static_cast<size_t>(c_stride) * kNVecSMem * num_rows; // Stride in global memory
const size_t num_ele = c_g < num_rows ? min(static_cast<size_t>(kNVecOut), num_rows - c_g)
: 0; // For not aligned case
OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory
// Each group of kNumThreadsStore threads within a warp processes one column of the tile (a row
// of the transposed output); we need the lane id of the first thread in the group to do the
// reduction.
const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore;
// This mask represents which threads should do the reduction together.
const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane;
const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0;
#pragma unroll
for (int iter = 0; iter < num_iterations; ++iter) {
SMemVec smem_vec[kNVecOut];
// Step 3.1: Load from shared memory to registers
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
int r = r_s + i;
int c = c_s;
smem_vec[i] = smem[r * kSMemCol + c];
}
#pragma unroll
for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) {
// Step 3.2: Compute local amax
CType amax = 0;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx]));
}
// Step 3.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);
__builtin_assume(amax >= 0);
__builtin_assume(other_amax >= 0);
amax = fmaxf(amax, other_amax);
}
amax = __shfl_sync(mask, amax, src_lane);
// Step 3.4: Compute scale
CType scale;
scale = compute_scale_from_types<IType, OType>(amax, epsilon, pow_2_scaling);
// Step 3.5: Write scale_inv_t
bool write_scale_inv = is_src_lane;
if constexpr (!kAligned) {
write_scale_inv &= (r_g + smem_idx < row_length);
}
if (write_scale_inv) {
CType scale_inv = 1.0 / scale;
size_t row_idx = static_cast<size_t>(blockIdx.x) * kTileDim + c_s * kNVecSMem + smem_idx;
size_t col_idx = static_cast<size_t>(blockIdx.y);
tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv;
}
// Step 3.6: Quantize
OVec output_vec;
#pragma unroll
for (int i = 0; i < kNVecOut; ++i) {
output_vec.data.elt[i] =
static_cast<OType>(static_cast<CType>(smem_vec[i].data.elt[smem_idx]) * scale);
}
// Step 3.7: Store output_t
if constexpr (kAligned) {
output_vec.store_to(output_g + smem_idx * num_rows);
} else {
if (r_g + smem_idx < row_length) {
output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele);
}
}
}
// Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case)
output_g += stride_g;
c_s += c_stride;
if constexpr (!kAligned) {
r_g += c_stride * kNVecSMem;
}
}
}
}
} // namespace
} // namespace transformer_engine
namespace transformer_engine::detail {
void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv,
SimpleTensor& scale_inv_t, SimpleTensor& output,
SimpleTensor& output_t, const float epsilon,
const bool return_transpose, const bool pow2_scale,
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
num_elements *= input.shape.at(i);
}
// Early return if the input tensor is empty
if (num_elements == 0) {
return;
}
// Options for scale layout of cuBLAS GEMM kernel.
NVTE_CHECK(input.shape.size() == output.shape.size(),
"Input and output must have the same shape.");
size_t scale_stride_x = 0;
size_t scale_stride_y = 0;
NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2.");
size_t scale_k = scale_inv.shape[1];
scale_stride_x = scale_k;
scale_stride_y = 1;
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
if (return_transpose) {
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
if (output_t.shape.size() > 0) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
scale_t_stride_x = scale_inv_t.shape[1];
scale_t_stride_y = 1;
}
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output.dtype, OutputType,
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned,
size_t smem_bytes = kSMemSize * sizeof(InputType);
// Dynamic shared memory above the default 48 KB limit must be requested explicitly.
if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute(
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
} block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>(
reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, return_transpose,
pow2_scale);) // kAligned
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace transformer_engine::detail
......@@ -1262,6 +1262,30 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
workspace_tensor, stream);
break;
}
case NVTE_BLOCK_SCALING_2D: {
// TODO(kwyss): IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D");
constexpr bool force_pow_2_scales = true;
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data,
/*epsilon=*/0.0,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
// TODO(kwyss): IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D");
constexpr bool force_pow_2_scales = true;
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data,
/*epsilon=*/0.0,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
}
......
......@@ -349,6 +349,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0");
}
} else {
// TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + ".");
}
}
......
......@@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
: "memory");
}
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(waitComplete)
: "r"(mbar_ptr), "r"(parity)
: "memory");
return static_cast<bool>(waitComplete);
}
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
__device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr,
......@@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
: "memory");
}
__device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) {
uint32_t waitComplete;
asm volatile(
"{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2; \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(waitComplete)
: "r"(mbar_ptr), "r"(parity)
: "memory");
return static_cast<bool>(waitComplete);
}
__device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
while (!mbarrier_try_wait_parity(mbar_ptr, parity)) {
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
__device__ __forceinline__ void cp_async_bulk_wait_group() {
asm volatile("cp.async.bulk.wait_group 0;");
......@@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
asm volatile("cp.async.bulk.wait_group.read 4;");
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__ __forceinline__ void cp_async_bulk_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
// Proxy fence (bi-directional):
__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); }
__device__ __forceinline__ void fence_proxy_async_shared_cta() {
asm volatile("fence.proxy.async.shared::cta;");
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
} // namespace ptx
......
......@@ -183,8 +183,8 @@ class ScalingMode(Enum):
NVTE_DELAYED_TENSOR_SCALING = 0
NVTE_MXFP8_1D_SCALING = 1
NVTE_INVALID_SCALING = 2
NVTE_NO_SCALING = 3
NVTE_INVALID_SCALING = 4
NVTE_NO_SCALING = 5
def _get_impl(self) -> ScalingModeMetadataImpl:
"""Get the implementation for this scaling mode.
......
......@@ -24,6 +24,12 @@ TE_DType = {
torch.bfloat16: tex.DType.kBFloat16,
}
"""
This is a map: int -> torch.dtype
Used for resolving cuda extension types to torch.
Has one to one mapping with enum in
transformer_engine.h
"""
TE_DType_To_Torch = {
tex.DType.kByte: torch.uint8,
tex.DType.kFloat8E4M3: torch.float8_e4m3fn,
......
......@@ -163,6 +163,38 @@ class Float8CurrentScalingQuantizer : public Quantizer {
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
class Float8BlockQuantizer : public Quantizer {
public:
// Which float8 type is used for q data.
DType dtype;
private:
// Options about how to quantize the tensor
// Quantization scales are rounded down to powers of 2.
bool force_pow_2_scales = false;
// Amax within quantization tile has a floor of epsilon.
float amax_epsilon = 0.0;
int block_scaling_dim = 2;
public:
// Initializes from a python handle to a Float8BlockQuantizer
explicit Float8BlockQuantizer(const py::handle& quantizer);
NVTEScalingMode get_scaling_mode() const override {
return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D;
}
// Gets rowwise and columnwise_data from tensor and sets them on wrapper
void set_quantization_params(TensorWrapper* tensor) const override;
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(
const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> rowwise_data = std::nullopt) const override;
};
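// Illustrative usage sketch (added commentary, not part of this header): the quantizer is
// constructed from a Python Float8BlockQuantizer handle and allocates paired quantized data
// and float32 scale-inverse tensors. The handle name `py_quantizer` is hypothetical.
//
//   Float8BlockQuantizer quantizer(py_quantizer);
//   auto [te_tensor, py_tensor] = quantizer.create_tensor(/*shape=*/{1024, 1024},
//                                                         /*dtype=*/DType::kBFloat16);
//   // te_tensor: TensorWrapper carrying rowwise/columnwise data plus scale_inv tensors
//   // py_tensor: Float8BlockwiseQTensor (or the internal base class when `internal` is set)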
class MXFP8Quantizer : public Quantizer {
public:
DType dtype;
......
......@@ -12,6 +12,8 @@
// #include <torch/all.h>
#include <assert.h>
#include <limits>
// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>
......@@ -47,8 +49,8 @@ struct ComputeScaleAndScaleInvFunctor {
n -= chunk_idx * chunk_size;
for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8,
force_pow_2_scales, epsilon);
float scale_val = transformer_engine::compute_scale_from_amax(
amax[i_start], max_fp8, force_pow_2_scales, epsilon, std::numeric_limits<float>::max());
scale[i_start] = scale_val;
transformer_engine::reciprocal(scale_inv + i_start, scale_val);
}
......
......@@ -28,6 +28,9 @@ PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr;
PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove
PyTypeObject *MXFP8TensorBasePythonClass = nullptr;
PyTypeObject *MXFP8QuantizerClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr;
PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr;
PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
void init_float8_extension() {
if (Float8TensorPythonClass) return;
......@@ -61,9 +64,31 @@ void init_mxfp8_extension() {
"Internal error: could not initialize pyTorch MXFP8 extension.");
}
void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorBasePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
"transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base");
Float8BlockwiseQuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer"));
Float8BlockwiseQTensorBasePythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase"));
Float8BlockwiseQTensorPythonClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor"));
NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr,
"Internal error: could not initialize pyTorch float8blockwise extension.");
}
void init_extension() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
}
} // namespace transformer_engine::pytorch
......@@ -76,6 +101,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("output") = py::none(), py::arg("noop") = py::none());
m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
py::arg("otype"));
m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
"Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)",
......
......@@ -250,6 +250,142 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector<size_t>{1});
}
this->set_quantization_params(&tensor);
return {std::move(tensor), std::move(ret)};
}
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) {
this->dtype = quantizer.attr("dtype").cast<DType>();
this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast<int>();
NVTE_CHECK(quantizer.attr("force_pow_2_scales").cast<bool>(),
"Pending additional parameters to the nvte_quantize API, "
"float8 block quantization requires pow2 scales");
NVTE_CHECK(quantizer.attr("amax_epsilon").cast<float>() == 0.0,
"Pending additional parameters to the nvte_quantize API, "
"float8 block quantization requires amax_epsilon==0");
NVTE_CHECK(this->block_scaling_dim == 1 || this->block_scaling_dim == 2,
"Unsupported block scaling dim.");
}
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
// Change the rowwise and columnwise_data to the configured dtype.
// May be a switch between E5M2 and E4M3.
auto rowwise_data = tensor->get_rowwise_data();
rowwise_data.dtype = static_cast<NVTEDType>(dtype);
auto columnwise_data = tensor->get_columnwise_data();
columnwise_data.dtype = static_cast<NVTEDType>(dtype);
tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
rowwise_data.shape);
tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
columnwise_data.shape);
}
std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> rowwise_data) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
size_t numel = 1;
for (auto s : shape) {
torch_shape.emplace_back(static_cast<int64_t>(s));
numel *= s;
}
TensorWrapper tensor(this->get_scaling_mode());
at::TensorOptions opts;
at::TensorOptions scale_opts;
at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back();
size_t m_dim = numel / k_dim;
constexpr size_t kBlockLen = 128;
if (rowwise_usage) {
if (rowwise_data.has_value()) {
data_rowwise = std::move(*rowwise_data);
} else {
data_rowwise = at::empty(torch_shape, opts);
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup(m_dim, 4);
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor rowwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
scale_inv_rowwise = at::empty({sinv0, sinv1}, scale_opts);
tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{sinv0, sinv1});
}
if (columnwise_usage) {
std::vector<int64_t> torch_columnwise_shape;
std::vector<size_t> columnwise_shape;
NVTE_CHECK(torch_shape.size() == shape.size(),
           "Shape expected to match torch shape. Shape has ", shape.size(),
           " dims, torch shape has ", torch_shape.size(), " dims.");
if (torch_shape.size() > 0) {
torch_columnwise_shape.reserve(torch_shape.size());
columnwise_shape.reserve(shape.size());
torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]);
columnwise_shape.push_back(shape[shape.size() - 1]);
for (size_t i = 0; i < torch_shape.size() - 1; ++i) {
torch_columnwise_shape.push_back(torch_shape[i]);
columnwise_shape.push_back(shape[i]);
}
}
size_t sinv0 = 0;
size_t sinv1 = 0;
if (block_scaling_dim == 2) {
sinv0 = (k_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4);
} else if (block_scaling_dim == 1) {
sinv0 = (m_dim + kBlockLen - 1) / kBlockLen;
sinv1 = roundup(k_dim, 4);
} else {
NVTE_CHECK(false,
"Unsupported block_scaling_dim in create_tensor columnwise."
"Expected 1 or 2. Got ",
block_scaling_dim);
}
data_colwise = at::empty(torch_columnwise_shape, opts);
scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts);
tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape);
tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32,
std::vector<size_t>{sinv0, sinv1});
}
this->set_quantization_params(&tensor);
py::object ret;
if (internal) {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorBasePythonClass));
ret = Float8BlockwiseQTensorClass(
"rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise,
"rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise,
"fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer,
"is_2D_scaled"_a = (block_scaling_dim == 2));
} else {
py::handle Float8BlockwiseQTensorClass(
reinterpret_cast<PyObject*>(Float8BlockwiseQTensorPythonClass));
ret = Float8BlockwiseQTensorClass(
"shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise,
"columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise,
"columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype,
"quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2));
}
return {std::move(tensor), std::move(ret)};
}
......@@ -302,8 +438,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
auto sinv1 = roundup(last_dim / MXFP8_BLOCK_SIZE, 4);
rowwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape);
tensor.set_rowwise_scale_inv(rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1});
tensor.set_rowwise_scale_inv(
rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
}
if (columnwise_usage) {
......@@ -313,8 +450,9 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(
columnwise_scale_inv = at::zeros({sinv0, sinv1}, opts);
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape);
tensor.set_columnwise_scale_inv(columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{sinv0, sinv1});
tensor.set_columnwise_scale_inv(
columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0,
std::vector<size_t>{static_cast<size_t>(sinv0), static_cast<size_t>(sinv1)});
}
this->set_quantization_params(&tensor);
......
......@@ -84,6 +84,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer)
return ret;
}
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) {
const DType dtype = tensor.attr("_fp8_dtype").cast<DType>();
bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast<bool>();
bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none());
bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none());
auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
if (rowwise_usage) {
const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast<at::Tensor>();
void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr();
const auto &rowwise_shape = getTensorShape(data_rowwise);
ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape);
const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise);
ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape);
}
if (columnwise_usage) {
const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast<at::Tensor>();
const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast<at::Tensor>();
void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr();
const auto &shape = getTensorShape(data_colwise);
ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape);
const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise);
ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape);
}
quantizer->set_quantization_params(&ret);
return ret;
}
} // namespace detail
} // namespace transformer_engine::pytorch
......@@ -25,6 +25,9 @@ extern PyTypeObject *Float8CurrentScalingQuantizerClass;
extern PyTypeObject *MXFP8TensorPythonClass;
extern PyTypeObject *MXFP8TensorBasePythonClass;
extern PyTypeObject *MXFP8QuantizerClass;
extern PyTypeObject *Float8BlockwiseQTensorPythonClass;
extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass;
extern PyTypeObject *Float8BlockwiseQuantizerClass;
void init_extension();
......@@ -50,6 +53,15 @@ inline bool IsMXFP8Tensor(PyObject *obj) {
return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass;
}
inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQuantizerClass;
}
inline bool IsFloat8BlockwiseQTensor(PyObject *obj) {
return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass ||
Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass;
}
TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer);
template <typename T>
......@@ -61,6 +73,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizati
std::unique_ptr<Quantizer> CreateMXFP8Params(const py::handle params);
TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor,
Quantizer *quantization_params);
inline bool IsFloatingPointType(at::ScalarType type) {
return type == at::kFloat || type == at::kHalf || type == at::kBFloat16;
}
......@@ -71,7 +86,9 @@ constexpr std::array custom_types_converters = {
std::make_tuple(IsFloat8Tensor, IsFloat8CurrentScalingQuantizers, NVTETensorFromFloat8Tensor,
CreateQuantizer<Float8CurrentScalingQuantizer>),
std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor,
CreateQuantizer<MXFP8Quantizer>)};
CreateQuantizer<MXFP8Quantizer>),
std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers,
NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer<Float8BlockQuantizer>)};
} // namespace detail
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Mixin class holding data specific for Float8BlockwiseQTensor"""
from __future__ import annotations
import math
from typing import Optional, Dict, Any, Tuple
import torch
from transformer_engine_torch import DType as TE_DType
from ...constants import TE_DType_To_Torch
from ..quantized_tensor import Quantizer
class Float8BlockwiseQTensorBase:
"""Mixin class that holds data attributes of Float8BlockwiseQTensor.
Float8BlockwiseQTensor inherits from the PyTorch tensor class and this
mixin class. Instantiating this mixin directly yields the same data with
lower CPU overhead but less functionality, so it should only be done
for performance-critical internal usage.
"""
_rowwise_data: Optional[torch.Tensor]
_columnwise_data: Optional[torch.Tensor]
_quantizer: Quantizer
_fp8_dtype: TE_DType
_rowwise_scale_inv: Optional[torch.Tensor]
_columnwise_scale_inv: Optional[torch.Tensor]
_is_2D_scaled: bool
def __new__(
cls,
*args,
rowwise_data: torch.Tensor,
rowwise_scale_inv: torch.Tensor,
columnwise_data: Optional[torch.Tensor],
columnwise_scale_inv: Optional[torch.Tensor],
fp8_dtype: TE_DType,
quantizer: Quantizer,
is_2D_scaled: bool,
**kwargs,
):
instance = super().__new__(cls, *args, **kwargs)
instance._rowwise_data = rowwise_data
instance._columnwise_data = columnwise_data
instance._quantizer = quantizer
instance._fp8_dtype = fp8_dtype
instance._rowwise_scale_inv = rowwise_scale_inv
instance._columnwise_scale_inv = columnwise_scale_inv
instance._is_2D_scaled = is_2D_scaled
return instance
def get_metadata(self) -> Dict[str, Any]:
"""Get this tensor's metadata."""
return {
"rowwise_data": self._rowwise_data,
"rowwise_scale_inv": self._rowwise_scale_inv,
"columnwise_data": self._columnwise_data,
"columnwise_scale_inv": self._columnwise_scale_inv,
"fp8_dtype": self._fp8_dtype,
"quantizer": self._quantizer,
"is_2D_scaled": self._is_2D_scaled,
}
def prepare_for_saving(
self,
) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
"""Prepare the tensor base for saving for backward"""
tensors = [self._rowwise_data, self._columnwise_data]
self._rowwise_data = None
self._columnwise_data = None
return tensors, self
def restore_from_saved(
self, tensors: list[Optional[torch.Tensor]]
) -> list[Optional[torch.Tensor]]:
"""Restore the tensor base data from the saved tensors list."""
self._rowwise_data = tensors[0]
self._columnwise_data = tensors[1]
return tensors[2:]
def get_data_tensors(self):
"""Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data
def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
"""Takes dequantized columnwise data and permutes to a rowwise shape"""
if columnwise_dq.dim() < 2:
return columnwise_dq
permute_dims = list(range(1, columnwise_dq.dim()))
permute_dims.append(0)
return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous()
def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
block_len = 128
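# 1x128 (vectorwise) dequantization: with the data flattened to (q_M, q_K), each
# element is recovered as q[m, k] * scale_inv[k // block_len, m]; the scale tensor
# stores the K-tile index on its outer dim and M (padded to a multiple of 4) on its inner dim.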
q_M, q_K = 1, 1
if self._rowwise_data is not None:
q = self._rowwise_data
scale_inv = self._rowwise_scale_inv
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
orig_shape = q.shape
q = q.reshape(q_M, q_K)
k_tiles, scale_m = scale_inv.shape
if q_K % block_len != 0:
k_pad_amount = (block_len - (q_K % block_len)) % block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, 0), mode="constant", value=0
).contiguous()
_, padded_K = q.shape
q_tiled = q.reshape(q_M, k_tiles, block_len)
if scale_m > q_M:
# scale_m is 4 element aligned.
scale_inv = scale_inv[:, :q_M].contiguous()
dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(q_M, k_tiles, 1)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale
if padded_K != q_K:
result = result.reshape(q_M, padded_K)[:, :q_K]
result = result.to(dtype)
if len(orig_shape) == 0:
result = result.reshape([])
else:
result = result.reshape(*orig_shape).contiguous()
if transpose_output:
return self._transpose_dq_columnwise_output(result)
return result
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8BlockwiseQTensor
"""
block_len = 128
if not self._is_2D_scaled:
return self._dequantize_vectorwise(dtype=dtype)
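# 2D (128x128) dequantization: with the data flattened to (q_M, q_K), each element is
# recovered as q[m, k] * scale_inv[m // block_len, k // block_len], after trimming the
# GEMM-alignment padding from the scale tensor's inner dimension.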
def format_scale_as_logical_shape(q_K, scales, block_len):
# The GEMM for 2D blocks requires padding in the scales; trim back to the logical shape.
derived_scale_k_shape = math.ceil(q_K / block_len)
_, scale_K = scales.shape
if derived_scale_k_shape == scale_K:
return scales
return scales[:, :derived_scale_k_shape].contiguous()
q_M, q_K = 1, 1
if self._rowwise_data is not None:
q = self._rowwise_data
scale_inv = self._rowwise_scale_inv
transpose_output = False
if len(q.shape) >= 1:
q_K = q.shape[-1]
for i in range(len(q.shape) - 1):
q_M *= q.shape[i]
else:
assert self._columnwise_data is not None, "No data to dequantize"
q = self._columnwise_data
scale_inv = self._columnwise_scale_inv
transpose_output = True
if len(q.shape) >= 1:
q_M = q.shape[0]
for i in range(1, len(q.shape)):
q_K *= q.shape[i]
orig_shape = q.shape
q = q.reshape(q_M, q_K)
formatted_scales = format_scale_as_logical_shape(q_K, scale_inv, block_len)
assert len(formatted_scales.shape) == 2
m_tiles, k_tiles = formatted_scales.shape
unpadded_m, unpadded_k = q_M, q_K
m_block_len = block_len
k_block_len = block_len
if q_M % m_block_len != 0 or q_K % k_block_len != 0:
m_pad_amount = (m_block_len - (q_M % m_block_len)) % m_block_len
k_pad_amount = (k_block_len - (q_K % k_block_len)) % k_block_len
q = torch.nn.functional.pad(
q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
).contiguous()
padded_M, padded_K = q.shape
q_tiled = q.reshape(m_tiles, m_block_len, k_tiles, k_block_len)
torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
result = q_tiled.view(torch_q_dtype).to(torch.float32) * formatted_scales.view(
m_tiles, 1, k_tiles, 1
)
result = result.view(padded_M, padded_K).to(dtype)
if padded_M != unpadded_m or padded_K != unpadded_k:
result = result[:unpadded_m, :unpadded_k]
if len(orig_shape) == 0:
result = result.reshape([])
else:
result = result.reshape(*orig_shape).contiguous()
if transpose_output:
return self._transpose_dq_columnwise_output(result)
return result
def size(self, *args, **kwargs):
# pylint: disable=missing-function-docstring
if self._rowwise_data is not None:
return self._rowwise_data.size(*args, **kwargs)
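# Only columnwise data is present: it stores the last logical dimension first, so
# rotate the dims to report the original (rowwise) shape.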
dims = list(self._columnwise_data.size(*args, **kwargs))
reordered = []
for i in range(1, len(dims)):
reordered.append(dims[i])
reordered.append(dims[0])
return torch.Size(reordered)
def __repr__(self):
if self._rowwise_data is not None:
data = self.dequantize()
descriptor = "rowwise"
else:
data = self.dequantize()
descriptor = "columnwise"
return (
"Float8BlockwiseQTensorBase("
f"fp8_dtype={self._fp8_dtype}, "
f"{descriptor}_scaled_data={data}"
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Tensor class with FP8 data quantized with NxN tiles"""
from __future__ import annotations
from typing import Optional, Tuple, Iterable
import math
import torch
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
from ..utils import devices_match, round_up_to_nearest_multiple
aten = torch.ops.aten
class Float8BlockQuantizer(Quantizer):
"""Builder class for tensors quantized with current scaling using
NxN quantization tilings to choose scale.
This class is typically used to convert a high-precision tensor
(e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8).
"""
dtype: TE_DType
block_len: int
amax_epsilon: float
force_pow_2_scales: bool
block_scaling_dim: int
def __init__(
self,
fp8_dtype: TE_DType,
*,
rowwise: bool,
columnwise: bool,
amax_epsilon: float = 0.0,
force_pow_2_scales: bool = True,
block_scaling_dim: int = 2,
) -> None:
super().__init__(rowwise=rowwise, columnwise=columnwise)
assert rowwise
self.dtype = fp8_dtype
self.block_len = 128
self.force_pow_2_scales = force_pow_2_scales
self.amax_epsilon = amax_epsilon
self.block_scaling_dim = block_scaling_dim
def update_quantized(
self,
src: torch.Tensor,
dst: QuantizedTensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> QuantizedTensor:
"""Update the quantized tensor with data from the source tensor.
This method quantizes the input tensor and stores the result in the destination tensor.
Parameters
----------
src : torch.Tensor
Source tensor containing the data to be quantized
dst : QuantizedTensor
Destination tensor where the quantized data will be stored
noop_flag : Optional[torch.Tensor]
Optional flag tensor indicating whether to skip the quantization operation
Returns
-------
QuantizedTensor
The destination tensor containing the quantized data
Raises
------
AssertionError
If the destination tensor is not a Float8BlockwiseQTensor
"""
assert isinstance(
dst, Float8BlockwiseQTensor
), f"Cannot store quantized blockwise tensor in {type(dst)} type."
# Make sure input is in expected format
if not devices_match(src.device, dst.device):
src = src.to(device=dst.device)
if not src.is_contiguous():
src = src.contiguous()
# Launch cast kernel
tex.quantize(src, self, dst, noop_flag)
dst._fp8_dtype = self.dtype
return dst
def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
"""Calculate the shape of the scaling tensor for blockwise quantization.
This method determines the shape of the scaling tensor needed for blockwise quantization,
taking into account the input tensor shape and whether columnwise scaling is used.
The scales are padded to multiples of 4 on the inner dimension for compatibility with GEMM.
Parameters
----------
shape : Iterable[int]
Shape of the input tensor to be quantized
columnwise : bool
Whether to use columnwise scaling (True) or rowwise scaling (False)
Returns
-------
Tuple[int, int]
Shape of the scaling tensor as (outer_dim, inner_dim)
For 2D block scaling:
- If columnwise: (roundup(K/blocksize), round_to_multiple(roundup(M/blocksize), 4))
- If rowwise: (roundup(M/blocksize), round_to_multiple(roundup(K/blocksize), 4))
For 1D block scaling:
- If columnwise: (roundup(M/blocksize), round_to_multiple(K, 4))
- If rowwise: (roundup(K/blocksize), round_to_multiple(M, 4))
"""
M, K = 1, 1
for i in range(len(shape) - 1):
M *= shape[i]
if len(shape) > 0:
K = shape[-1]
if self.block_scaling_dim == 2:
if columnwise:
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(M / self.block_len), 4)
return (outer, inner)
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(math.ceil(K / self.block_len), 4)
return (outer, inner)
assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported"
if columnwise:
outer = math.ceil(M / self.block_len)
inner = round_up_to_nearest_multiple(K, 4)
return (outer, inner)
outer = math.ceil(K / self.block_len)
inner = round_up_to_nearest_multiple(M, 4)
return (outer, inner)
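# Worked example of the formulas above (illustrative only): for shape (512, 320) with
# block_len 128, 2D scaling gives rowwise scales of shape (4, 4) and columnwise scales
# of shape (3, 4); 1D scaling gives (3, 512) rowwise and (4, 320) columnwise.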
def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]:
"""Calculate the shape of a tensor after columnwise permutation.
This method rearranges the dimensions of a tensor to be columnwise,
moving the last dimension to the front and keeping the order of other dimensions.
Parameters
----------
shape : Iterable[int]
Original shape of the tensor
Returns
-------
Tuple[int, ...]
New shape with dimensions rearranged for columnwise layout.
For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
Returns empty tuple for empty input shape.
"""
if len(shape) == 0:
return tuple()
colwise_shape = [shape[-1]]
for i in range(len(shape) - 1):
colwise_shape.append(shape[i])
return tuple(colwise_shape)
def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""
if device is None:
device = torch.device("cuda")
# Allocate FP8 data
data = torch.empty(shape, dtype=torch.uint8, device=device)
scale_shape = self.get_scale_shape(shape, columnwise=False)
scale_inv = torch.empty(
scale_shape,
dtype=torch.float32,
device=device,
)
# Allocate FP8 data transpose if needed
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
self.get_columnwise_shape(shape), dtype=torch.uint8, device=device
)
columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
columnwise_scale_inv = torch.empty(
columnwise_scale_shape,
dtype=torch.float32,
device=device,
)
# Construct FP8 tensor
return Float8BlockwiseQTensor(
shape=shape,
dtype=dtype,
fp8_dtype=self.dtype,
rowwise_data=data,
rowwise_scale_inv=scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2,
requires_grad=requires_grad,
)
def calibrate(self, tensor: torch.Tensor) -> None:
# NOTE: This interface is specific to requirements like delayed scaling
# where state from an estimator influences distribution parameters.
pass
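# Illustrative usage sketch (variable names are hypothetical; the calls mirror the
# methods defined above):
#
#     quantizer = Float8BlockQuantizer(
#         fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True, block_scaling_dim=1
#     )
#     hp = torch.randn(256, 384, device="cuda", dtype=torch.bfloat16)
#     qt = quantizer.make_empty(hp.shape, dtype=hp.dtype, device=hp.device)
#     qt = quantizer.update_quantized(hp, qt)
#     restored = qt.dequantize()  # plain bfloat16 tensor, approximately equal to hp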
class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
"""Tensor class with FP8 data quantized via NxN blocks or 1xN blocks.
The tensor presents as having a standard, higher-precision dtype,
but the data itself is (scaled) FP8. For most tensor operations,
the data will be cast to the nominal dtype before performing the
operation.
Parameters
----------
rowwise_data: torch.Tensor
FP8 data in a uint8 tensor matching shape of dequantized tensor.
rowwise_scale_inv: torch.Tensor
FP32 dequantization scales in GEMM format for dequantizing rowwise_data.
columnwise_data: Optional[torch.Tensor]
FP8 data in a uint8 tensor matching shape of dequantized tensor transpose.
columnwise_scale_inv: Optional[torch.Tensor]
FP32 dequantization scales in GEMM format for dequantizing columnwise_data.
fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3
FP8 format.
quantizer: Quantizer
The Float8BlockQuantizer that produced this tensor; holds configuration for quantization and dequantization.
"""
def __repr__(self, *, tensor_contents=None):
return (
f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype},"
f" is_2D_scaled={self._is_2D_scaled},"
f" data={self.dequantize(dtype=self.dtype)})"
)
def _get_quantizer(self) -> Quantizer:
"""Get builder for quantized tensor
Quantizer can be used for in-place operations.
"""
assert self._quantizer is not None
return self._quantizer
def quantize_(
self,
tensor: torch.Tensor,
*,
noop_flag: Optional[torch.Tensor] = None,
) -> Float8BlockwiseQTensor:
"""Update FP8 data
Parameters
----------
tensor: torch.Tensor
Tensor to copy from
noop_flag: torch.Tensor, optional
float32 flag indicating whether to avoid performing update
"""
if isinstance(tensor, QuantizedTensor):
return self.quantize_(tensor.dequantize())
self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag)
return self
def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""
Construct plain PyTorch tensor from Float8BlockwiseQTensor
By default the resulting tensor's dtype is the
Float8BlockwiseQTensor's pre-quantized dtype.
"""
if dtype is not None:
dequant_dtype = dtype
else:
dequant_dtype = self.dtype
return super().dequantize(dtype=dequant_dtype)
def detach(self) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
return Float8BlockwiseQTensor.make_like(self)
def update_usage(
self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
):
"""
update_usage can be used to clear out one of two possible copies of the data.
"""
if rowwise_usage is None:
rowwise_usage = self._rowwise_data is not None
if columnwise_usage is None:
columnwise_usage = self._columnwise_data is not None
assert (
columnwise_usage or rowwise_usage
), "Must retain some data either columnwise or rowwise"
if columnwise_usage and rowwise_usage:
assert (
self._rowwise_data is not None
and self._rowwise_scale_inv is not None
and self._columnwise_data is not None
and self._columnwise_scale_inv is not None
), "Cannot update to rowwise and columnwise usage."
return
if rowwise_usage:
assert (
self._rowwise_data is not None and self._rowwise_scale_inv is not None
), "Cannot update to rowwise usage."
self._columnwise_data = None
self._columnwise_scale_inv = None
return
if columnwise_usage:
assert (
self._columnwise_data is not None and self._columnwise_scale_inv is not None
), "Cannot update to columnwise usage."
self._rowwise_data = None
self._rowwise_scale_inv = None
return
return
def clone(self) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
rowwise_data = None
if self._rowwise_data is not None:
rowwise_data = self._rowwise_data.detach().clone()
columnwise_data = None
if self._columnwise_data is not None:
columnwise_data = self._columnwise_data.detach().clone()
return _IdentityFunc.apply(
self,
{
"rowwise_data": rowwise_data,
"columnwise_data": columnwise_data,
},
)
def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
return _ViewFunc.apply(self, shape)
def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
return _ReshapeFunc.apply(self, shape)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
# View op
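# Shape-changing views are not supported because the block scale layout is tied to
# the tensor's (M, K) interpretation; only no-op views pass through.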
if func == aten.view.default:
tensor = args[0]
data = tensor._rowwise_data
if data is None:
# Columnwise data only.
return super().__torch_dispatch__(func, types, args, kwargs)
orig_size = data.size()
out_data = data.__torch_dispatch__(
func,
types,
[data] + list(args[1:]),
kwargs,
)
if orig_size != out_data.size():
raise NotImplementedError(
"Changing shape with view not implemented "
" (scales and columnwise data untouched)."
)
return Float8BlockwiseQTensor.make_like(tensor)
# Default case
return super().__torch_dispatch__(func, types, args, kwargs)
def contiguous(
self,
memory_format: torch.memory_format = torch.contiguous_format,
) -> Float8BlockwiseQTensor:
"""Returns tensor with data in provided memory format
Returns `self` if data is already in correct memory format.
"""
if (
self._rowwise_data is not None
and self._rowwise_data.is_contiguous(memory_format=memory_format)
and (
(self._columnwise_data is None)
or (self._columnwise_data.is_contiguous(memory_format=memory_format))
)
):
return self
raise ValueError("Float8BlockwiseQTensor does not support different memory formats!")
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
@classmethod
def _make_in_reduce_ex(
cls,
shape: torch.Size,
rowwise_data: torch.Tensor,
rowwise_scale_inv: torch.Tensor,
columnwise_data: torch.Tensor,
columnwise_scale_inv: torch.Tensor,
fp8_dtype: TE_DType,
dtype: torch.dtype,
quantizer: Quantizer,
is_2D_scaled: bool,
) -> Float8BlockwiseQTensor:
"""Build Float8BlockwiseQTensor, for use in __reduce__
__reduce_ex__ assumes object constructor has positional
arguments.
"""
return Float8BlockwiseQTensor(
shape=shape,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
fp8_dtype=fp8_dtype,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
dtype=dtype,
quantizer=quantizer,
is_2D_scaled=is_2D_scaled,
)
def __reduce_ex__(self, protocol: int) -> tuple:
"""Custom pickling to remove references to FP8 metadata objects"""
return (
Float8BlockwiseQTensor._make_in_reduce_ex,
(
self.shape,
self._rowwise_data,
self._rowwise_scale_inv,
self._columnwise_data,
self._columnwise_scale_inv,
self._fp8_dtype,
self.dtype,
self._quantizer,
self._is_2D_scaled,
),
)
def _get_data(self) -> Float8BlockwiseQTensor:
"""Get tensor data property"""
return self
@torch.no_grad()
def _set_data(self, tensor: torch.Tensor) -> None:
"""Set tensor data property
Just takes FP8 data if setting from a Float8BlockwiseQTensor. Otherwise
casts to FP8.
"""
# Tensor device
new_device = tensor.device if tensor.is_cuda else self.device
def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor):
dst._rowwise_data = src._rowwise_data
dst._columnwise_data = src._columnwise_data
dst._quantizer = src._quantizer
dst._fp8_dtype = src._fp8_dtype
dst._rowwise_scale_inv = src._rowwise_scale_inv
dst._columnwise_scale_inv = src._columnwise_scale_inv
dst.dtype = src.dtype
# Check that tensor dimensions match
if (
self.size() != tensor.size()
or self.stride() != tensor.stride()
or self.layout != tensor.layout
):
raise ValueError("Invalid tensor for updating Float8BlockwiseQTensor data")
# Just copy FP8 data if other tensor is Float8BlockwiseQTensor
if (
isinstance(tensor, Float8BlockwiseQTensor)
and self.storage_offset() == tensor.storage_offset()
and devices_match(self.device, new_device)
):
_set_from_tensor(self, tensor)
return
if isinstance(tensor, Float8BlockwiseQTensor):
assert tensor._quantizer is not None, "Can't quantize without a quantizer"
quantizer = tensor._quantizer
else:
assert self._quantizer is not None, "Can't quantize without a quantizer"
quantizer = self._quantizer
# Quantize to FP8
quantizer.update_quantized(tensor, self)
# Cast to FP8 when setting Float8BlockwiseQTensor.data
data = property(_get_data, _set_data)
class _ViewFunc(torch.autograd.Function):
"""View function
View the Float8BlockwiseQTensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: Float8BlockwiseQTensor,
shape: Optional[list[int]] = None,
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
if ctx is not None:
ctx.shape = tensor.shape
if shape is None:
return tensor
if list(shape) != list(tensor.shape):
raise NotImplementedError("View not implemented.")
return tensor
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
raise NotImplementedError("View bwd not implemented")
return grad.view(ctx.shape), None
class _ReshapeFunc(torch.autograd.Function):
"""Reshape function
Reshape the Float8BlockwiseQTensor using the provided shape.
"""
@staticmethod
def forward(
ctx,
tensor: Float8BlockwiseQTensor,
shape: Optional[list[int]] = None,
) -> Float8BlockwiseQTensor:
# pylint: disable=missing-function-docstring
# Return input tensor if shape is not provided
if ctx is not None:
ctx.shape = tensor.shape
if shape is None:
return tensor
# Canonicalize shape
if not isinstance(shape, Iterable):
shape = [shape]
elif len(shape) == 1 and isinstance(shape[0], Iterable):
shape = shape[0]
if -1 in shape:
shape = list(shape)
d_inferred = -math.prod(tensor.shape) // math.prod(shape)
for i, d in enumerate(shape):
if d == -1:
shape[i] = d_inferred
break
if list(shape) != list(tensor.shape):
raise NotImplementedError("Reshape not implemented yet.")
return tensor
@staticmethod
def backward(
ctx,
grad: torch.Tensor,
) -> Tuple[Optional[torch.Tensor], ...]:
# pylint: disable=missing-function-docstring
if isinstance(grad, Float8BlockwiseQTensor):
raise NotImplementedError("Reshape bwd not implemented yet.")
return grad.view(ctx.shape), None