Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -50,7 +50,7 @@ class MMParams:
Parameters
----------
use_split_accumulator : bool, default = `True`
use_split_accumulator : bool, default = True
Use FP8 fast accumulation on Hopper or Ada. For more details,
see the CUBLASLT_MATMUL_DESC_FAST_ACCUM option of cublasLtMatmul.
"""
......@@ -159,7 +159,7 @@ class DelayedScaling(Recipe):
recipe: DelayedScaling) -> Tensor
where `Tensor` is a framework tensor type.
reduce_amax: bool, default = `True`
reduce_amax: bool, default = True
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
call). This keeps the amaxes and scaling factors synced across the given
......@@ -167,13 +167,13 @@ class DelayedScaling(Recipe):
GPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
fp8_dpa: bool, default = `False`
fp8_dpa: bool, default = False
Whether to enable FP8 dot product attention (DPA). When the model is placed in an
`autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the
inputs from higher precision to FP8, performs attention in FP8, and casts tensors
back to higher precision as outputs. FP8 DPA is currently only supported in the
`FusedAttention` backend.
fp8_mha: bool, default = `False`
fp8_mha: bool, default = False
Whether to enable FP8 multi-head attention (MHA). When `True`, it removes the casting
operations mentioned above at the DPA boundaries. Currently, only standard MHA modules,
i.e. `LayerNormLinear/Linear + DPA + Linear`, are supported for this feature. When
......@@ -422,11 +422,11 @@ class NVFP4BlockScaling(Recipe):
----------
fp4_format : {Format.E2M1}, default = Format.E2M1
FP4 data type.
disable_rht : bool, default = `False`
disable_rht : bool, default = False
If set to `True`, random Hadamard transforms are not applied to any tensor.
disable_stochastic_rounding : bool, default = `False`
disable_stochastic_rounding : bool, default = False
If set to `True`, stochastic rounding is disabled during quantization for all tensors.
disable_2d_quantization : bool, default = `False`
disable_2d_quantization : bool, default = False
If set to `True`, 1D block scaling with block size 16 is used for all tensors.
"""
......@@ -492,17 +492,19 @@ class CustomRecipe(Recipe):
Parameters
----------
qfactory : Callable
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as:
qfactory(
role: str,
)
Where `role` is one of the following strings for e.g. te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
Factory callable that returns a quantizer instance for a
given semantic tensor role.
The callable is typically invoked as::
qfactory(
role: str,
)
Where `role` is one of the following strings, e.g. for te.Linear
(stable public contract):
- forward: "linear_input", "linear_weight", "linear_output"
- backward: "linear_grad_output", "linear_grad_input"
"""
qfactory: Callable[..., Any]
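As an illustration of this contract, a minimal sketch of a role-dispatching factory follows. The import path and the placeholder quantizer helper are assumptions; quantizer construction is not shown anywhere in this diff.

# Hedged sketch: role dispatch only; the placeholder quantizer helper below is
# a hypothetical stand-in for a real quantizer object.
from transformer_engine.common.recipe import CustomRecipe

def _make_placeholder_quantizer(kind: str):
    # Stand-in for a real quantizer instance (not a library API).
    return {"kind": kind}

def qfactory(role: str):
    forward_roles = {"linear_input", "linear_weight", "linear_output"}
    backward_roles = {"linear_grad_output", "linear_grad_input"}
    if role in forward_roles:
        return _make_placeholder_quantizer("forward")
    if role in backward_roles:
        return _make_placeholder_quantizer("backward")
    raise ValueError(f"Unexpected tensor role: {role}")

recipe = CustomRecipe(qfactory=qfactory)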
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/recipe.h>
#include "../common.h"
#include "../util/ptx.cuh"
#include "../utils.cuh"
namespace transformer_engine {
namespace mxfp8_scaling_recipe {
constexpr int rowwise_row_padding = 128; // Row padding of rowwise_scale and rowwise_amax
constexpr int rowwise_col_padding = 4; // Column padding of rowwise_scale and rowwise_amax
constexpr int colwise_row_padding = 4; // Row padding of colwise_scale and colwise_amax
constexpr int colwise_col_padding = 128; // Column padding of colwise_scale and colwise_amax
constexpr int kRowsPerTile = 32; // Rows each block processes
constexpr int kColsPerTile = 128; // Columns each block processes
constexpr int kThreadsPerBlock = 128;
template <typename IType>
__global__ void __launch_bounds__(kThreadsPerBlock)
mxfp8_scaling_compute_partial_amax_kernel(const IType *input, IType *amax_rowwise,
IType *amax_colwise, int amax_rowwise_stride,
int amax_colwise_stride, int rows, int cols,
size_t start_offset, size_t len) {
__shared__ float smem_amax_rowwise[kRowsPerTile][kColsPerTile / 32];
size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
int warp_idx = threadIdx.x / 32;
int lane_idx = threadIdx.x % 32;
int c = blockIdx.x * kColsPerTile + threadIdx.x;
int r = blockIdx.y * kRowsPerTile;
float col_amax = 0.0f;
#pragma unroll
for (int i = 0; i < kRowsPerTile; i++) {
size_t idx = r * cols + c;
float row_amax = 0.0f;
if (r < rows && c < cols && idx >= start_offset && idx < end_offset) {
float abs_input = fabs(static_cast<float>(input_minus_offset[idx]));
row_amax = fmaxf(row_amax, abs_input);
col_amax = fmaxf(col_amax, abs_input);
}
#pragma unroll
for (int delta = 16; delta > 0; delta /= 2) {
float other_row_amax = __shfl_down_sync(0xFFFFFFFF, row_amax, delta);
row_amax = fmaxf(row_amax, other_row_amax);
}
if (lane_idx == 0) {
smem_amax_rowwise[i][warp_idx] = row_amax;
}
r++;
}
amax_colwise[blockIdx.y * amax_colwise_stride + c] = static_cast<IType>(col_amax);
__syncthreads();
int r_ = threadIdx.x / (kColsPerTile / 32); // rows in shared memory
int c_ = threadIdx.x % (kColsPerTile / 32); // cols in shared memory
r = blockIdx.y * kRowsPerTile + r_;
c = blockIdx.x * kColsPerTile / 32 + c_;
amax_rowwise[r * amax_rowwise_stride + c] = static_cast<IType>(smem_amax_rowwise[r_][c_]);
}
template <typename IType, typename OType>
__global__ void __launch_bounds__(kThreadsPerBlock)
mxfp8_scaling_partial_cast_kernel(const IType *input, OType *output_rowwise,
OType *output_colwise, const e8m0_t *scale_inv_rowwise,
const e8m0_t *scale_inv_colwise, int scale_inv_rowwise_stride,
int scale_inv_colwise_stride, int rows, int cols,
size_t start_offset, size_t len) {
__shared__ float smem_scales_rowwise[kRowsPerTile][kColsPerTile / 32];
__shared__ float smem_scales_colwise[kColsPerTile];
// Load scales_rowwise
{
int r_ = threadIdx.x / (kColsPerTile / 32); // rows in shared memory
int c_ = threadIdx.x % (kColsPerTile / 32); // cols in shared memory
int r = blockIdx.y * kRowsPerTile + r_;
int c = blockIdx.x * kColsPerTile / 32 + c_;
size_t idx = r * scale_inv_rowwise_stride + c;
smem_scales_rowwise[r_][c_] = ptx::exp2f_rcp(scale_inv_rowwise[idx]);
}
// Load scales_colwise
{
int c_ = threadIdx.x;
int r = blockIdx.y * kRowsPerTile / 32;
int c = blockIdx.x * kColsPerTile + c_;
size_t idx = r * scale_inv_colwise_stride + c;
smem_scales_colwise[c_] = ptx::exp2f_rcp(scale_inv_colwise[idx]);
}
__syncthreads();
size_t end_offset = start_offset + len;
const IType *input_minus_offset = input - start_offset;
OType *output_rowwise_minus_offset = output_rowwise - start_offset;
OType *output_colwise_minus_offset = output_colwise - start_offset;
int warp_idx = threadIdx.x / 32;
// int lane_idx = threadIdx.x % 32;
int c = blockIdx.x * kColsPerTile + threadIdx.x;
int r = blockIdx.y * kRowsPerTile;
#pragma unroll
for (int i = 0; i < kRowsPerTile; i++) {
size_t idx = r * cols + c;
if (r < rows && c < cols && idx >= start_offset && idx < end_offset) {
float inp = static_cast<float>(input_minus_offset[idx]);
OType out_rowwise = static_cast<OType>(inp * smem_scales_rowwise[i][warp_idx]);
OType out_colwise = static_cast<OType>(inp * smem_scales_colwise[threadIdx.x]);
output_rowwise_minus_offset[idx] = out_rowwise;
output_colwise_minus_offset[idx] = out_colwise;
}
r++;
}
}
void mxfp8_scaling_compute_partial_amax(const Tensor input, Tensor amax_rowwise,
Tensor amax_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream) {
NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32");
NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");
NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols,
"Invalid start_offset");
NVTE_CHECK(amax_rowwise.data.shape.size() == 2, "amax_rowwise must be a 2D tensor");
NVTE_CHECK(amax_rowwise.data.shape[0] % rowwise_row_padding == 0,
"Wrong padding of amax_rowwise's rows");
NVTE_CHECK(amax_rowwise.data.shape[0] >= rows, "Invalid rows");
NVTE_CHECK(amax_rowwise.data.shape[1] % rowwise_col_padding == 0,
"Wrong padding of amax_rowwise's cols");
NVTE_CHECK(amax_rowwise.data.shape[1] >= cols / 32, "Invalid cols");
NVTE_CHECK(amax_rowwise.dtype() == input.dtype(), "Wrong dtype of amax_rowwise");
NVTE_CHECK(amax_colwise.data.shape.size() == 2, "amax_colwise must be a 2D tensor");
NVTE_CHECK(amax_colwise.data.shape[0] % colwise_row_padding == 0,
"Wrong padding of amax_colwise's rows");
NVTE_CHECK(amax_colwise.data.shape[0] >= rows / 32, "Invalid rows");
NVTE_CHECK(amax_colwise.data.shape[1] % colwise_col_padding == 0,
"Wrong padding of amax_colwise's cols");
NVTE_CHECK(amax_colwise.data.shape[1] >= cols, "Invalid cols");
NVTE_CHECK(amax_colwise.dtype() == input.dtype(), "Wrong dtype of amax_colwise");
int blocks_x = (cols + kColsPerTile - 1) / kColsPerTile;
int blocks_y = (rows + kRowsPerTile - 1) / kRowsPerTile;
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
mxfp8_scaling_compute_partial_amax_kernel<IType><<<grid, kColsPerTile, 0, stream>>>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<IType *>(amax_rowwise.data.dptr),
reinterpret_cast<IType *>(amax_colwise.data.dptr), amax_rowwise.data.shape[1],
amax_colwise.data.shape[1], rows, cols, start_offset, input.data.shape[0]);)
}
void mxfp8_scaling_partial_cast(const Tensor input, Tensor output_rowwise, Tensor output_colwise,
const Tensor scale_inv_rowwise, const Tensor scale_inv_colwise,
int rows, int cols, size_t start_offset, cudaStream_t stream) {
NVTE_CHECK(rows % 32 == 0, "rows must be divisible by 32");
NVTE_CHECK(cols % 32 == 0, "cols must be divisible by 32");
NVTE_CHECK(input.data.shape.size() == 1, "input must be a 1D tensor");
NVTE_CHECK(start_offset + input.data.shape[0] <= static_cast<size_t>(rows) * cols,
"Invalid start_offset");
NVTE_CHECK(output_rowwise.data.shape.size() == 1, "output_rowwise must be a 1D tensor");
NVTE_CHECK(output_colwise.data.shape.size() == 1, "output_colwise must be a 1D tensor");
NVTE_CHECK(output_rowwise.data.shape[0] == input.data.shape[0],
"Size of input and output_rowwise mismatch");
NVTE_CHECK(output_colwise.data.shape[0] == input.data.shape[0],
"Size of input and output_colwise mismatch");
NVTE_CHECK(output_rowwise.dtype() == DType::kFloat8E4M3 || output_rowwise.dtype() == DType::kByte,
"output_rowwise should be e4m3 or uint8");
NVTE_CHECK(output_colwise.dtype() == DType::kFloat8E4M3 || output_colwise.dtype() == DType::kByte,
"output_colwise should be e4m3 or uint8");
NVTE_CHECK(scale_inv_rowwise.data.shape.size() == 2, "scale_inv_rowwise must be a 2D tensor");
NVTE_CHECK(scale_inv_rowwise.data.shape[0] % rowwise_row_padding == 0,
"Wrong padding of scale_inv_rowwise's rows");
NVTE_CHECK(scale_inv_rowwise.data.shape[0] >= rows, "Invalid rows");
NVTE_CHECK(scale_inv_rowwise.data.shape[1] % rowwise_col_padding == 0,
"Wrong padding of scale_inv_rowwise's cols");
NVTE_CHECK(scale_inv_rowwise.data.shape[1] >= cols / 32, "Invalid cols");
NVTE_CHECK(scale_inv_rowwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_rowwise");
NVTE_CHECK(scale_inv_colwise.data.shape.size() == 2, "scale_inv_colwise must be a 2D tensor");
NVTE_CHECK(scale_inv_colwise.data.shape[0] % colwise_row_padding == 0,
"Wrong padding of scale_inv_colwise's rows");
NVTE_CHECK(scale_inv_colwise.data.shape[0] >= rows / 32, "Invalid rows");
NVTE_CHECK(scale_inv_colwise.data.shape[1] % colwise_col_padding == 0,
"Wrong padding of scale_inv_colwise's cols");
NVTE_CHECK(scale_inv_colwise.data.shape[1] >= cols, "Invalid cols");
NVTE_CHECK(scale_inv_colwise.dtype() == DType::kByte, "Wrong dtype of scale_inv_colwise");
int blocks_x = (cols + kColsPerTile - 1) / kColsPerTile;
int blocks_y = (rows + kRowsPerTile - 1) / kRowsPerTile;
dim3 grid(blocks_x, blocks_y);
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
mxfp8_scaling_partial_cast_kernel<IType, fp8e4m3><<<grid, kColsPerTile, 0, stream>>>(
reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<fp8e4m3 *>(output_rowwise.data.dptr),
reinterpret_cast<fp8e4m3 *>(output_colwise.data.dptr),
reinterpret_cast<const e8m0_t *>(scale_inv_rowwise.data.dptr),
reinterpret_cast<const e8m0_t *>(scale_inv_colwise.data.dptr),
scale_inv_rowwise.data.shape[1], scale_inv_colwise.data.shape[1], rows, cols,
start_offset, input.data.shape[0]);)
}
} // namespace mxfp8_scaling_recipe
} // namespace transformer_engine
void nvte_mxfp8_scaling_compute_partial_amax(const NVTETensor input, NVTETensor amax_rowwise,
NVTETensor amax_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream) {
NVTE_API_CALL(nvte_mxfp8_scaling_compute_partial_amax);
using namespace transformer_engine;
mxfp8_scaling_recipe::mxfp8_scaling_compute_partial_amax(
*convertNVTETensorCheck(input), *convertNVTETensorCheck(amax_rowwise),
*convertNVTETensorCheck(amax_colwise), rows, cols, start_offset, stream);
}
void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_rowwise,
NVTETensor output_colwise, const NVTETensor scale_inv_rowwise,
const NVTETensor scale_inv_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream) {
NVTE_API_CALL(nvte_mxfp8_scaling_partial_cast);
using namespace transformer_engine;
mxfp8_scaling_recipe::mxfp8_scaling_partial_cast(
*convertNVTETensorCheck(input), *convertNVTETensorCheck(output_rowwise),
*convertNVTETensorCheck(output_colwise), *convertNVTETensorCheck(scale_inv_rowwise),
*convertNVTETensorCheck(scale_inv_colwise), rows, cols, start_offset, stream);
}
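For reference, the shape contract enforced by the checks above can be written down as a small host-side helper. This is an illustrative sketch only (the struct and function are not part of the library); the padding values mirror the rowwise_/colwise_ padding constants defined at the top of this file.

#include <cstddef>

// Illustrative sketch (not a library API): minimal padded shapes of the
// row-wise / column-wise amax and scale_inv buffers expected by the partial
// amax / partial cast helpers above, for a (rows x cols) tensor with 1x32
// MXFP8 blocks. rows and cols must be divisible by 32.
struct MXFP8PartialShapes {
  size_t rowwise[2];  // shape of amax_rowwise / scale_inv_rowwise
  size_t colwise[2];  // shape of amax_colwise / scale_inv_colwise
};

constexpr size_t round_up(size_t x, size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

constexpr MXFP8PartialShapes mxfp8_partial_shapes(size_t rows, size_t cols) {
  return MXFP8PartialShapes{
      {round_up(rows, 128), round_up(cols / 32, 4)},   // rowwise_row/col_padding
      {round_up(rows / 32, 4), round_up(cols, 128)}};  // colwise_row/col_padding
}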
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -342,68 +342,122 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
} // namespace
void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
NVTE_CHECK(
input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
"Input tensor has invalid dtype (", to_string(input->dtype()), ").");
// Do nothing if tensor is empty
if (input->data.numel() == 0) {
return;
}
// Check scaling mode
const auto& scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
"Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
// Check tensors
CheckInputTensor(*input, "scaling_factor_input");
CheckInputTensor(*output, "scaling_factor_output");
NVTE_CHECK(!input->with_gemm_swizzled_scales,
"Expected input tensor with scales in compact format.");
NVTE_CHECK(output->with_gemm_swizzled_scales,
"Expected output tensor with scales in GEMM swizzled format.");
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING:
NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ",
to_string(input->dtype()), ").");
break;
case NVTE_NVFP4_1D_SCALING:
NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ",
to_string(input->dtype()), ").");
break;
default:
NVTE_ERROR("Invalid scaling mode");
}
auto& scaling_mode = input->scaling_mode;
NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
"Unsupported scaling mode for swizzling.");
bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
// Check if scaling factors are non-trivial
const bool has_rowwise_scale_inv = input->scale_inv.has_data();
const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
"Input tensor has both row-wise and column-wise scaling factors");
if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
return;
}
// 1D block scaling, row-wise or column-wise
int m, k;
if (input->has_data()) {
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else {
if (nvfp4) {
m = input->columnwise_scale_inv.shape[0];
k = input->columnwise_scale_inv.shape[1];
} else {
m = input->columnwise_scale_inv.shape[1];
k = input->columnwise_scale_inv.shape[0];
// Deduce tensor dims
int m{0}, k{0};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
if (has_rowwise_scale_inv) {
NVTE_CHECK(input->scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else if (has_columnwise_scale_inv) {
NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
".");
m = input->columnwise_scale_inv.shape[1];
k = input->columnwise_scale_inv.shape[0];
}
break;
}
case NVTE_NVFP4_1D_SCALING: {
if (has_rowwise_scale_inv) {
NVTE_CHECK(input->scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
m = input->scale_inv.shape[0];
k = input->scale_inv.shape[1];
} else if (has_columnwise_scale_inv) {
NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
"Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
".");
m = input->columnwise_scale_inv.shape[0];
k = input->columnwise_scale_inv.shape[1];
}
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}
// Check dims
constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output->has_data()) {
NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(),
output->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
// Check that output tensor matches input tensor
if (has_rowwise_scale_inv) {
NVTE_CHECK(output->scale_inv.has_data(),
"Output tensor does not have row-wise scaling factors.");
NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k,
" row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
}
if (output->has_columnwise_data()) {
NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(),
output->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
if (has_columnwise_scale_inv) {
NVTE_CHECK(output->columnwise_scale_inv.has_data(),
"Output tensor does not have column-wise scaling factors.");
NVTE_CHECK(
m * k == output->columnwise_scale_inv.numel(), "Expected output tensor to have ", m * k,
" column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, ".");
}
int num_tiles_m = m / SF_TILE_DIM_M;
int num_tiles_k = k / SF_TILE_DIM_K;
// Choose swizzle implementation
bool rowwise_swizzle{false}, columnwise_swizzle{false};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
rowwise_swizzle = has_rowwise_scale_inv;
columnwise_swizzle = has_columnwise_scale_inv;
break;
}
case NVTE_NVFP4_1D_SCALING: {
// NVFP4 column-wise data is transposed, so row-wise and
// column-wise scales have same swizzling format
rowwise_swizzle = true;
columnwise_swizzle = false;
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}
// For NVFP4, the scale inverse for transposed data needs rowwise swizzle.
const bool rowwise_swizzle = input->has_data() || nvfp4;
const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4;
const dim3 block_size(TB_DIM, TB_DIM);
const int num_tiles_m = m / SF_TILE_DIM_M;
const int num_tiles_k = k / SF_TILE_DIM_K;
dim3 block_size(TB_DIM, TB_DIM);
// Perform row-wise swizzle
if (rowwise_swizzle) {
int vec_load_size = (num_tiles_k - 1) % 4 + 1;
/* there is no int3, and int4/int2 would be misaligned */
......@@ -412,20 +466,32 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m);
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
int original_M, original_K;
void *input_scale_inv_ptr, *output_scale_inv_ptr;
if (!nvfp4 || input->has_data()) {
int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / block_scale_size;
input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr;
} else {
original_M = input->flat_last_dim();
original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
int original_M{0}, original_K{0};
void *input_scale_inv_ptr{nullptr}, *output_scale_inv_ptr{nullptr};
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr;
break;
}
case NVTE_NVFP4_1D_SCALING: {
if (has_rowwise_scale_inv) {
original_M = input->flat_first_dim();
original_K = input->flat_last_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->scale_inv.dptr;
output_scale_inv_ptr = output->scale_inv.dptr;
} else if (has_columnwise_scale_inv) {
original_M = input->flat_last_dim();
original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
}
break;
}
default:
NVTE_ERROR("Invalid scaling mode");
}
switch (vec_load_size) {
......@@ -481,7 +547,10 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
// Perform column-wise swizzle
if (columnwise_swizzle) {
int vec_load_size = (num_tiles_m - 1) % 4 + 1;
if (vec_load_size == 3) vec_load_size = 1; /* no int3, and int4/int2 would be misaligned */
......@@ -490,8 +559,6 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
const int original_M = input->flat_last_dim();
const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
// NVFP4 shouldn't end up here because it only needs rowwise swizzle
NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
switch (vec_load_size) {
#ifdef __HIP_PLATFORM_AMD__
......@@ -552,8 +619,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
NVTE_ERROR("Not valid vec_load_size.");
break;
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
NVTE_CHECK_CUDA(cudaGetLastError());
}
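As an illustration of the dimension handling above, the following standalone sketch (not a library helper) computes the compact scale_inv geometry and the swizzle tile counts for a row-wise MXFP8 tensor, using the SF_TILE_DIM_M = 128, SF_TILE_DIM_K = 4 tile shape and the one-scale-per-1x32-block layout assumed by this code.

#include <cstddef>

// Illustrative only: compact (pre-swizzle) scale_inv shape and 128x4 tile
// counts for a row-wise MXFP8 tensor of shape (rows, cols).
struct SwizzleGeometry {
  size_t scale_m, scale_k;          // padded compact scale_inv shape
  size_t num_tiles_m, num_tiles_k;  // number of 128x4 swizzle tiles
};

constexpr size_t pad_to(size_t x, size_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}

constexpr SwizzleGeometry mxfp8_rowwise_swizzle_geometry(size_t rows, size_t cols) {
  constexpr size_t kBlockSize = 32;  // MXFP8: one E8M0 scale per 1x32 block
  constexpr size_t kTileM = 128;     // SF_TILE_DIM_M
  constexpr size_t kTileK = 4;       // SF_TILE_DIM_K
  const size_t scale_m = pad_to(rows, kTileM);
  const size_t scale_k = pad_to(cols / kBlockSize, kTileK);
  return {scale_m, scale_k, scale_m / kTileM, scale_k / kTileK};
}

// Example: a (256, 4096) tensor has a (256, 128) compact scale_inv,
// i.e. 2 x 32 swizzle tiles of 128x4 scales each.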
template <int SF_TILE_DIM_M, int SF_TILE_DIM_K>
......@@ -702,18 +769,24 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
NVTE_CHECK(
(is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)),
"Not implemented scaling mode " + to_string(scaling_mode) + ".");
NVTE_CHECK(!input[i]->with_gemm_swizzled_scales,
"Expected input tensors with scales in compact format.");
NVTE_CHECK(output[i]->with_gemm_swizzled_scales,
"Expected output tensors with scales in GEMM swizzled format.");
// We don't allow empty tensors. They should be filtered out before calling this function.
if (input[i]->data.numel() == 0) {
NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
}
NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty.");
CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
all_has_data &= input[i]->has_data();
all_has_columnwise_data &= input[i]->has_columnwise_data();
all_nvfp4 &= is_nvfp4_scaling(scaling_mode);
all_has_data = all_has_data && input[i]->scale_inv.has_data();
all_has_columnwise_data =
(all_has_columnwise_data && input[i]->columnwise_scale_inv.has_data());
all_nvfp4 = all_nvfp4 && is_nvfp4_scaling(scaling_mode);
}
NVTE_CHECK(all_has_data || all_has_columnwise_data,
"All tensors should have data or columnwise data.");
NVTE_CHECK(!all_has_data || !all_has_columnwise_data,
"All tensors have both data and columnwise data.");
const bool rowwise_swizzle = all_has_data || all_nvfp4;
const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4;
......@@ -752,18 +825,19 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
if (output[i]->has_data()) {
NVTE_CHECK(
m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
"Input.scale_inv size is not equal to Output.scale_inv size!");
if (all_has_data) {
NVTE_CHECK(output[i]->scale_inv.has_data(), "Output tensor ", i,
" does not have row-wise scaling factors.");
NVTE_CHECK(m * k == output[i]->scale_inv.numel(), "Expected output tensor ", i, " to have ",
m * k, " row-wise scaling factors, but got shape=", output[i]->scale_inv.shape,
".");
}
if (output[i]->has_columnwise_data()) {
NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
output[i]->columnwise_scale_inv.shape.end(), 1,
std::multiplies<int>()),
"Input.columnwise_scale_inv size is not equal to "
"Output.columnwise_scale_inv size!");
if (all_has_columnwise_data) {
NVTE_CHECK(output[i]->columnwise_scale_inv.has_data(), "Output tensor ", i,
" does not have column-wise scaling factors.");
NVTE_CHECK(m * k == output[i]->columnwise_scale_inv.numel(), "Expected output tensor ", i,
" to have ", m * k, " column-wise scaling factors, but got shape=",
output[i]->columnwise_scale_inv.shape, ".");
}
int num_tiles_k = k / SF_TILE_DIM_K;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -99,7 +99,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
// calculate this warp's input base pointer
constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4);
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
const void* const warp_src =
(reinterpret_cast<const uint8_t*>(in) + in_tile_y * in_y_stride + in_tile_x * in_x_stride);
// load scaling factors for this lane's initial four 1x128 tiles
uint4 sf;
......@@ -114,7 +115,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
}
// pack the exponent bits of the scaling factors
uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
uint32_t packed_exponents = ((sf.x >> 23) & 0xFF) | (((sf.y >> 23) & 0xFF) << 8) |
(((sf.z >> 23) & 0xFF) << 16) | (((sf.w >> 23) & 0xFF) << 24);
// partially swizzle the scaling factors
constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches
......@@ -129,7 +131,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
// store them cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr uint32_t out_x_stride = 512;
void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
void* const warp_dst =
(reinterpret_cast<uint8_t*>(out) + out_tile_y * out_y_stride + out_tile_x * out_x_stride);
reinterpret_cast<uint4*>(warp_dst)[lane] = sf;
}
......@@ -193,21 +196,24 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
// calculate this warp's input base pointer
constexpr uint32_t in_x_stride = sizeof(float);
const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride;
const void* const warp_src =
(reinterpret_cast<const uint8_t*>(in) + in_tile_y * in_y_stride + in_tile_x * in_x_stride);
// load scaling factor for this warp's 128x128 tile
uint32_t sf = *reinterpret_cast<const uint32_t*>(warp_src);
// broadcast it to four scaling factors for 1x32 tiles
sf = (sf << 1) | (sf >> 7);
sf = sf | (sf >> 16);
// extract and broadcast the exponent byte to four bytes for E8M0 format
uint32_t exp_byte = (sf >> 23) & 0xFF;
sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);
// broadcast it to sixteen scaling factors for 1x32 tiles
const uint4 sf4{sf, sf, sf, sf};
// store it cooperatively for 512 1x32 tiles in a 128x128 tile
constexpr uint32_t out_x_stride = 512;
void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride;
void* const warp_dst =
(reinterpret_cast<uint8_t*>(out) + out_tile_y * out_y_stride + out_tile_x * out_x_stride);
reinterpret_cast<uint4*>(warp_dst)[lane] = sf4;
}
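Both corrected code paths above come down to extracting the E8M0 exponent byte (bits 30..23) of an FP32 scaling factor and placing one byte per 1x32 tile. A small host-side illustration of that packing (not a library function) follows.

#include <cstdint>
#include <cstring>

// Illustration only: pack the E8M0 exponent bytes of four FP32 scaling
// factors into one uint32_t, one byte per 1x32 tile, mirroring the packing
// in the kernels above.
inline uint32_t pack_e8m0_exponents(const float sf[4]) {
  uint32_t packed = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t bits;
    std::memcpy(&bits, &sf[i], sizeof(bits));     // reinterpret float as bits
    packed |= ((bits >> 23) & 0xFFu) << (8 * i);  // exponent byte -> byte i
  }
  return packed;
}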
......@@ -260,6 +266,9 @@ void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor*
NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0,
"Output must have E8M0 scaling factors");
NVTE_CHECK(output->with_gemm_swizzled_scales,
"Expected output tensor with scales in GEMM swizzled format.");
NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data");
NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input");
NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors");
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <transformer_engine/transformer_engine.h>
#include <algorithm>
#include <atomic>
#include <climits>
#include <cstring>
#include <iostream>
#include <mutex>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "common.h"
#include "common/util/cuda_runtime.h"
......@@ -81,7 +85,7 @@ std::string to_string(const NVTEScalingMode &mode) {
}
void CheckNoopTensor(const Tensor &t, const std::string &name) {
if (t.data.dptr != nullptr) {
if (t.data.has_data()) {
NVTE_CHECK(t.numel() == 1, "Expected 1 element for ", name, " noop, but found ", t.numel(),
".");
NVTE_CHECK(t.data.dtype == DType::kFloat32, "Found wrong dtype for ", name,
......@@ -92,15 +96,30 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
if (is_tensor_scaling(t.scaling_mode)) {
// per-tensor scaling
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected (1), got ",
t.columnwise_scale_inv.shape, ")");
if (is_fp8_dtype(t.dtype())) {
// FP8 tensor with tensor scaling
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid scale_inv shape (expected 1 entry, got ", t.scale_inv.shape,
")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected 1 entry, got ",
t.columnwise_scale_inv.shape, ")");
}
} else {
// High-precision tensor
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 0, "Tensor \"", name,
"\" has invalid scale_inv shape (expected 0 entries, got ", t.scale_inv.shape,
")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 0, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected 0 entries, got ",
t.columnwise_scale_inv.shape, ")");
}
}
} else {
if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
......@@ -163,7 +182,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
if (is_fp8_dtype(type) || is_int8_dtype(type)) {
// FP8 input needs to have scale_inv
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor input ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor input ", name,
......@@ -172,7 +191,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor input ", name,
NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
......@@ -185,7 +204,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
// TODO(ksivaman): Fix this to check for amaxes and other details.
// For now only needed for swizzle.
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor input ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name,
"_scale_inverse has invalid dtype "
......@@ -193,7 +212,7 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name,
NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor input ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ",
name,
......@@ -202,11 +221,10 @@ void CheckInputTensor(const Tensor &t, const std::string &name) {
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name);
NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 input ", name);
NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv is not supported for non-FP8 input ",
name);
}
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input ", name, " is not allocated!");
......@@ -217,14 +235,14 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
const DType type = t.dtype();
if (is_fp8_dtype(type) || is_int8_dtype(type)) {
// FP8 output needs to have scale, scale_inv and (if delayed scaling) amax
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.dptr != nullptr) {
if (t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING && t.amax.has_data()) {
NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Invalid amax dtype (expected ",
to_string(DType::kFloat32), ", got ", to_string(t.amax.dtype), ")");
NVTE_CHECK(product(t.amax.shape) == 1, "Invalid shape of amax in output ", name,
NVTE_CHECK(t.amax.numel() == 1, "Invalid shape of amax in output ", name,
" (expected 1 entry, got shape=", t.amax.shape, ")");
}
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
NVTE_CHECK(t.scale_inv.has_data(), "FP8 scaling factor output ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32 || t.scale_inv.dtype == DType::kFloat8E8M0,
"FP8 scaling factor output ", name,
......@@ -233,7 +251,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP8 scaling factor output ", name,
NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP8 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat32 ||
t.columnwise_scale_inv.dtype == DType::kFloat8E8M0,
......@@ -245,7 +263,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
} else if (is_fp4_dtype(type)) {
// FP4 output needs to have the scale_inv
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
NVTE_CHECK(t.scale_inv.has_data(), "FP4 scaling factor output ", name,
"_scale_inverse must be allocated");
NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name,
"_scale_inverse has invalid dtype "
......@@ -253,7 +271,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name,
NVTE_CHECK(t.columnwise_scale_inv.has_data(), "FP4 scaling factor output ", name,
"_columnwise_scale_inverse must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ",
name,
......@@ -262,12 +280,10 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
to_string(t.columnwise_scale_inv.dtype), ")");
}
} else {
NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name);
// Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax.
// NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name);
NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name);
NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr,
"Scale_inv is not supported for non-FP8 input ", name);
NVTE_CHECK(!t.scale.has_data(), "Scale is not supported for non-FP8 output ", name);
NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv is not supported for non-FP8 output ", name);
NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv is not supported for non-FP8 output ",
name);
}
if (!allow_empty) {
......@@ -277,6 +293,128 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
CheckScaleTensorShape(t, name);
}
void CheckGroupedTensorShapeArrays(const GroupedTensor &t, const std::string &name) {
NVTE_CHECK(t.num_tensors > 0, "Grouped tensor ", name, " has no tensors!");
// Helper lambda to validate shape arrays
// All three arrays are OPTIONAL:
// - first_dims: empty if all tensors have same first dimension
// - last_dims: empty if all tensors have same last dimension
// - tensor_offsets: empty if all tensors have same shape (offsets are predictable)
auto check_shape_array = [&](const SimpleTensor &arr, const char *arr_name) {
if (arr.has_data()) {
NVTE_CHECK(arr.shape.size() == 1, "Grouped tensor ", name, " ", arr_name, " must be 1D");
NVTE_CHECK(arr.dtype == DType::kInt64, "Grouped tensor ", name, " ", arr_name,
" must have dtype Int64");
NVTE_CHECK(arr.shape[0] == t.num_tensors, "Grouped tensor ", name, " ", arr_name, " size (",
arr.shape[0], ") must equal num_tensors (", t.num_tensors, ")");
}
};
// Validate shape arrays (all optional)
check_shape_array(t.first_dims, "first_dims");
check_shape_array(t.last_dims, "last_dims");
check_shape_array(t.tensor_offsets, "tensor_offsets");
// tensor_offsets is required if any dimension varies
// (i.e., required unless all_same_shape())
if (!t.all_same_shape()) {
NVTE_CHECK(
t.tensor_offsets.dptr != nullptr, "Grouped tensor ", name,
" must have tensor_offsets when any dimension varies (first_dims or last_dims is set)");
}
// Validate logical_shape
NVTE_CHECK(t.logical_shape.ndim == 2, "Grouped tensor ", name, " logical_shape must be 2D");
NVTE_CHECK(t.logical_shape.data[0] > 0 && t.logical_shape.data[1] > 0, "Grouped tensor ", name,
" logical_shape must have positive dimensions");
// Validate all data fields are 1D (flattened)
if (t.has_data()) {
NVTE_CHECK(t.data.shape.size() == 1, "Grouped tensor ", name, " data must be 1D");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_data.shape.size() == 1, "Grouped tensor ", name,
" columnwise_data must be 1D");
}
// Validate data size matches logical_shape
size_t expected_numel = t.logical_shape.data[0] * t.logical_shape.data[1];
if (t.has_data()) {
NVTE_CHECK(t.data.numel() == expected_numel, "Grouped tensor ", name, " data size (",
t.data.numel(), ") must match logical_shape size (", expected_numel, ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_data.numel() == expected_numel, "Grouped tensor ", name,
" columnwise_data size (", t.columnwise_data.numel(),
") must match logical_shape size (", expected_numel, ")");
}
}
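To make the optional-array contract above concrete, here is a standalone sketch (not a library helper) that builds the three arrays from per-tensor (first_dim, last_dim) pairs. It assumes tensor_offsets are element offsets into the flattened data buffer; the unit is not stated in this hunk.

#include <cstdint>
#include <utility>
#include <vector>

// Illustrative only: first_dims / last_dims may be left empty when the
// corresponding dimension is uniform, and tensor_offsets is required as soon
// as any shape varies (matching CheckGroupedTensorShapeArrays above).
struct GroupedShapeArrays {
  std::vector<int64_t> first_dims;      // empty if all first dims match
  std::vector<int64_t> last_dims;       // empty if all last dims match
  std::vector<int64_t> tensor_offsets;  // empty only if all shapes match
};

GroupedShapeArrays build_shape_arrays(const std::vector<std::pair<int64_t, int64_t>>& shapes) {
  GroupedShapeArrays out;
  if (shapes.empty()) return out;
  bool same_first = true, same_last = true;
  for (const auto& s : shapes) {
    same_first = same_first && (s.first == shapes[0].first);
    same_last = same_last && (s.second == shapes[0].second);
  }
  int64_t offset = 0;  // assumed: element offset into the flattened data
  for (const auto& s : shapes) {
    if (!same_first) out.first_dims.push_back(s.first);
    if (!same_last) out.last_dims.push_back(s.second);
    if (!same_first || !same_last) out.tensor_offsets.push_back(offset);
    offset += s.first * s.second;
  }
  return out;
}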
// Helper function to check scale_inv for both input and output
static void CheckGroupedScaleInv(const GroupedTensor &t, const std::string &name, bool is_output) {
const char *tensor_type = is_output ? "output" : "input";
// Helper to check scale_inv for both rowwise and columnwise layouts
auto check_scales = [&](DType expected_dtype) {
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.has_data(), tensor_type, " ", name,
" rowwise scale_inv must be allocated");
NVTE_CHECK(t.scale_inv.dtype == expected_dtype, tensor_type, " ", name,
" rowwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
", got ", to_string(t.scale_inv.dtype), ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.has_data(), tensor_type, " ", name,
" columnwise scale_inv must be allocated");
NVTE_CHECK(t.columnwise_scale_inv.dtype == expected_dtype, tensor_type, " ", name,
" columnwise scale_inv has invalid dtype (expected ", to_string(expected_dtype),
", got ", to_string(t.columnwise_scale_inv.dtype), ")");
}
};
// Determine expected dtype based on data type and scaling mode
if (is_fp8_dtype(t.dtype()) && is_tensor_scaling(t.scaling_mode)) {
check_scales(DType::kFloat32);
} else if (is_mxfp8_scaling(t.scaling_mode)) {
check_scales(DType::kFloat8E8M0);
} else if (is_nvfp4_scaling(t.scaling_mode)) {
check_scales(DType::kFloat8E4M3);
} else {
// Non-quantized types should not have scale/scale_inv
NVTE_CHECK(!t.scale_inv.has_data(), "Scale_inv not supported for non-quantized ", tensor_type,
" ", name);
NVTE_CHECK(!t.columnwise_scale_inv.has_data(), "Scale_inv not supported for non-quantized ",
tensor_type, " ", name);
}
}
void CheckInputGroupedTensor(const GroupedTensor &t, const std::string &name) {
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Input grouped tensor ", name,
" not allocated");
CheckGroupedScaleInv(t, name, false);
CheckGroupedTensorShapeArrays(t, name);
}
void CheckOutputGroupedTensor(const GroupedTensor &t, const std::string &name, bool allow_empty) {
if (!allow_empty) {
NVTE_CHECK(t.has_data() || t.has_columnwise_data(), "Output grouped tensor ", name,
" not allocated");
}
// Only perform dtype-specific validation if data is allocated
if (t.has_data() || t.has_columnwise_data()) {
// Amax validation for delayed scaling
if (is_fp8_dtype(t.dtype()) && t.scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
NVTE_CHECK(t.amax.has_data(), "Output ", name, " amax must be allocated");
NVTE_CHECK(t.amax.dtype == DType::kFloat32, "Output ", name, " amax must be Float32");
}
CheckGroupedScaleInv(t, name, true);
}
CheckGroupedTensorShapeArrays(t, name);
}
class TensorAllocator {
public:
static TensorAllocator &instance() {
......@@ -391,6 +529,89 @@ Tensor *convertNVTETensorCheck(const NVTETensor t) {
return ptr;
}
// GroupedTensor allocator - similar pattern to TensorAllocator
class GroupedTensorAllocator {
public:
static GroupedTensorAllocator &instance() {
static GroupedTensorAllocator allocator;
return allocator;
}
~GroupedTensorAllocator() {}
NVTEGroupedTensor Allocate(NVTEScalingMode mode, size_t num_tensors, NVTEShape logical_shape) {
std::lock_guard<std::mutex> lock(mutex);
if (!free_list.empty()) {
uintptr_t index = free_list.back();
NVTEGroupedTensor ret = reinterpret_cast<NVTEGroupedTensor>(index);
free_list.pop_back();
// 1-based indexing - fully reinitialize the tensor to avoid stale data
memory[index - 1].scaling_mode = mode;
memory[index - 1].num_tensors = num_tensors;
memory[index - 1].logical_shape = logical_shape;
memory[index - 1].nvte_tensor = ret;
return ret;
}
if (memory.size() < memory.capacity()) {
memory.emplace_back(mode, num_tensors);
GroupedTensor &t = memory.back();
size = memory.size();
// 1-based indexing
uintptr_t index = memory.size();
t.logical_shape = logical_shape;
t.nvte_tensor = reinterpret_cast<NVTEGroupedTensor>(index);
return reinterpret_cast<NVTEGroupedTensor>(index);
}
NVTE_ERROR(
"Cannot allocate a new NVTEGroupedTensor. Maximum number of grouped tensors reached: ",
MAX_GROUPED_TENSOR_NUM, ". There is probably a memory leak in your application.");
}
void Free(NVTEGroupedTensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
free_list.push_back(index);
// Clean up
memory[index - 1].clear();
}
GroupedTensor *convertNVTEGroupedTensor(NVTEGroupedTensor t) {
uintptr_t index = reinterpret_cast<uintptr_t>(t);
// 1-based indexing to enable 0-initialization of NVTEGroupedTensor
// to be invalid tensor
static_assert(nullptr == 0);
if (index != 0 && index <= size) {
return &(memory[index - 1]);
}
return nullptr;
}
private:
GroupedTensorAllocator() {
std::lock_guard<std::mutex> lock(mutex);
memory.reserve(MAX_GROUPED_TENSOR_NUM);
}
std::mutex mutex;
std::atomic<size_t> size;
// Allocate at most 20 MB for grouped tensors
const size_t MAX_GROUPED_TENSOR_NUM = 20 * 1024 * 1024 / sizeof(GroupedTensor);
std::vector<uintptr_t> free_list;
std::vector<GroupedTensor> memory;
};
GroupedTensor *convertNVTEGroupedTensor(const NVTEGroupedTensor t) {
return GroupedTensorAllocator::instance().convertNVTEGroupedTensor(t);
}
GroupedTensor *convertNVTEGroupedTensorCheck(const NVTEGroupedTensor t) {
GroupedTensor *ptr = GroupedTensorAllocator::instance().convertNVTEGroupedTensor(t);
NVTE_CHECK(ptr != nullptr, "Invalid grouped tensor.");
return ptr;
}
} // namespace transformer_engine
NVTETensor nvte_create_tensor(NVTEScalingMode scaling_mode) {
......@@ -427,7 +648,11 @@ NVTEShape nvte_make_shape(const size_t *data, size_t ndim) {
NVTE_CHECK(ndim <= sizeof(ret.data) / sizeof(ret.data[0]),
"Too many dims for NVTEShape (requested: ", ndim,
", max: ", sizeof(ret.data) / sizeof(ret.data[0]), ")");
std::copy(data, data + ndim, ret.data);
if (data == nullptr) {
std::fill(ret.data, ret.data + ndim, 0);
} else {
std::copy(data, data + ndim, ret.data);
}
ret.ndim = ndim;
return ret;
}
......@@ -540,7 +765,7 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
auto *t = transformer_engine::convertNVTETensor(tensor);
if (t == nullptr) {
return nvte_make_shape(nullptr, 0);
return nvte_make_shape(nullptr, 1);
}
return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
}
......@@ -573,13 +798,14 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
t->columnwise_amax = *param;
break;
default:
NVTE_ERROR("Unknown tensor parameter!");
NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param_name),
"). Consider using nvte_set_tensor_param_v2 instead.");
}
}
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
}
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
switch (param_name) {
......@@ -598,7 +824,148 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p
case kNVTEColumnwiseAmax:
return t.columnwise_amax;
default:
NVTE_ERROR("Unknown tensor parameter!");
NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param_name),
"). Consider using nvte_set_tensor_param_v2 instead.");
}
}
void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ", static_cast<int>(param),
")");
NVTE_CHECK(tensor != nullptr, "Tensor pointer can't be NULL.");
auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
const auto &attr_size = transformer_engine::Tensor::attr_sizes[param];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for tensor parameter "
"(parameter ",
static_cast<int>(param), " needs ", attr_size, " bytes, but buffer has ",
size_in_bytes, " bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
switch (param) {
case kNVTERowwiseData: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.data = *basic_tensor;
break;
}
case kNVTEColumnwiseData: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.columnwise_data = *basic_tensor;
break;
}
case kNVTEScale: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.scale = *basic_tensor;
break;
}
case kNVTEAmax: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.amax = *basic_tensor;
break;
}
case kNVTERowwiseScaleInv: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.scale_inv = *basic_tensor;
break;
}
case kNVTEColumnwiseScaleInv: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.columnwise_scale_inv = *basic_tensor;
break;
}
case kNVTEColumnwiseAmax: {
const NVTEBasicTensor *basic_tensor = reinterpret_cast<const NVTEBasicTensor *>(buf);
t.columnwise_amax = *basic_tensor;
break;
}
case kNVTEWithGEMMSwizzledScales:
t.with_gemm_swizzled_scales = static_cast<bool>(*reinterpret_cast<const uint8_t *>(buf));
break;
default:
NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param), ")");
}
}
void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, void *buf,
size_t size_in_bytes, size_t *size_written) {
using namespace transformer_engine;
// Check param
NVTE_CHECK(param < kNVTENumTensorParams, "Invalid NVTETensorParam (got ", static_cast<int>(param),
")");
// Write attribute size if provided
const auto &attr_size = Tensor::attr_sizes[param];
if (size_written != nullptr) {
*size_written = attr_size;
}
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for tensor parameter "
"(parameter ",
static_cast<int>(param), " needs ", attr_size, " bytes, but buffer has ",
size_in_bytes, " bytes)");
// Get C++ tensor
const Tensor *t = convertNVTETensor(tensor);
std::optional<Tensor> dummy;
if (t == nullptr) {
// Make dummy tensor if provided tensor is invalid
dummy.emplace();
t = &(*dummy);
}
// Write to buffer
switch (param) {
case kNVTERowwiseData: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->data);
break;
}
case kNVTEColumnwiseData: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_data);
break;
}
case kNVTEScale: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->scale);
break;
}
case kNVTEAmax: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->amax);
break;
}
case kNVTERowwiseScaleInv: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->scale_inv);
break;
}
case kNVTEColumnwiseScaleInv: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_scale_inv);
break;
}
case kNVTEColumnwiseAmax: {
NVTEBasicTensor *basic_tensor = reinterpret_cast<NVTEBasicTensor *>(buf);
*basic_tensor = static_cast<NVTEBasicTensor>(t->columnwise_amax);
break;
}
case kNVTEWithGEMMSwizzledScales:
*reinterpret_cast<uint8_t *>(buf) = static_cast<uint8_t>(t->with_gemm_swizzled_scales);
break;
default:
NVTE_ERROR("Unsupported tensor parameter (", static_cast<int>(param), ")");
}
}
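A short usage illustration of the v2 accessors above, assuming a valid NVTETensor handle. The query-then-read pattern and the uint8_t encoding of the swizzle flag follow directly from this code; the surrounding function and variable names are assumptions for illustration only.

#include <cstdint>
#include <transformer_engine/transformer_engine.h>

// Illustration only: read the row-wise data descriptor and the GEMM-swizzle
// flag through the v2 API, then write the flag back unchanged.
void inspect_tensor_params(NVTETensor t) {
  // Size query: pass a null buffer and read back the required size.
  size_t needed = 0;
  nvte_get_tensor_param_v2(t, kNVTERowwiseData, nullptr, 0, &needed);

  NVTEBasicTensor rowwise_data{};
  nvte_get_tensor_param_v2(t, kNVTERowwiseData, &rowwise_data, sizeof(rowwise_data), nullptr);

  uint8_t swizzled = 0;  // bool parameters are exchanged as uint8_t in this API
  nvte_get_tensor_param_v2(t, kNVTEWithGEMMSwizzledScales, &swizzled, sizeof(swizzled), nullptr);
  nvte_set_tensor_param_v2(t, kNVTEWithGEMMSwizzledScales, &swizzled, sizeof(swizzled));
}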
......@@ -624,14 +991,21 @@ void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) {
if (tensor == nullptr) return;
const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
// Zero out tensor data if allocated
if (t.data.dptr != nullptr) {
const size_t size_in_bytes = nvte_tensor_size_bytes(tensor);
NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream));
const auto size = t.data.buffer_size_bytes();
if (size > 0) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size, stream));
}
}
// Set amax to 0 if allocated
// Zero out amax if allocated
if (t.amax.dptr != nullptr) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream));
const auto size = t.amax.buffer_size_bytes();
if (size > 0) {
NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, size, stream));
}
}
}
......@@ -642,12 +1016,15 @@ NVTEQuantizationConfig nvte_create_quantization_config() {
void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written) {
using namespace transformer_engine;
// Write attribute size
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
*size_written = attr_size;
const auto &attr_size = QuantizationConfig::attr_sizes[attr];
if (size_written != nullptr) {
*size_written = attr_size;
}
// Return immediately if buffer is not provided
if (buf == nullptr) {
......@@ -661,12 +1038,18 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
  // sizeof(bool) is implementation-defined, so we explicitly use
  // uint8_t in the user-facing API.
auto bool_to_uint8 = [](bool in, void *out) {
*reinterpret_cast<uint8_t *>(out) = static_cast<uint8_t>(in);
};
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::QuantizationConfig *>(config);
const auto &config_ = *reinterpret_cast<const QuantizationConfig *>(config);
switch (attr) {
case kNVTEQuantizationConfigForcePow2Scales:
std::memcpy(buf, &config_.force_pow_2_scales, attr_size);
bool_to_uint8(config_.force_pow_2_scales, buf);
break;
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(buf, &config_.amax_epsilon, attr_size);
......@@ -674,8 +1057,23 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
case kNVTEQuantizationConfigNoopTensor:
std::memcpy(buf, &config_.noop_tensor, attr_size);
break;
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(buf, &config_.float8_block_scale_tensor_format, attr_size);
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: {
// Deprecated
const auto invalid = Float8BlockScaleTensorFormat::INVALID;
std::memcpy(buf, &invalid, attr_size);
break;
}
case kNVTEQuantizationConfigRNGState:
std::memcpy(buf, &config_.rng_state, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
bool_to_uint8(config_.nvfp4_2d_quantization, buf);
break;
case kNVTEQuantizationConfigStochasticRounding:
bool_to_uint8(config_.stochastic_rounding, buf);
break;
case kNVTEQuantizationConfigUseFastMath:
bool_to_uint8(config_.use_fast_math, buf);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
......@@ -685,10 +1083,12 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, const void *buf,
size_t size_in_bytes) {
using namespace transformer_engine;
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEQuantizationConfigNumAttributes,
"Invalid NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::QuantizationConfig::attr_sizes[attr];
const auto &attr_size = QuantizationConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for quantization config attribute "
"(attribute ",
......@@ -696,12 +1096,18 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
  // sizeof(bool) is implementation-defined, so we explicitly use
  // uint8_t in the user-facing API.
auto uint8_to_bool = [](const void *in, bool &out) {
out = static_cast<bool>(*reinterpret_cast<const uint8_t *>(in));
};
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEQuantizationConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::QuantizationConfig *>(config);
auto &config_ = *reinterpret_cast<QuantizationConfig *>(config);
switch (attr) {
case kNVTEQuantizationConfigForcePow2Scales:
std::memcpy(&config_.force_pow_2_scales, buf, attr_size);
uint8_to_bool(buf, config_.force_pow_2_scales);
break;
case kNVTEQuantizationConfigAmaxEpsilon:
std::memcpy(&config_.amax_epsilon, buf, attr_size);
......@@ -710,16 +1116,19 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
std::memcpy(&config_.noop_tensor, buf, attr_size);
break;
case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat:
std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size);
// Deprecated
break;
case kNVTEQuantizationConfigRNGState:
std::memcpy(&config_.rng_state, buf, attr_size);
break;
case kNVTEQuantizationConfigNVFP42DQuantization:
std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size);
uint8_to_bool(buf, config_.nvfp4_2d_quantization);
break;
case kNVTEQuantizationConfigStochasticRounding:
std::memcpy(&config_.stochastic_rounding, buf, attr_size);
uint8_to_bool(buf, config_.stochastic_rounding);
break;
case kNVTEQuantizationConfigUseFastMath:
uint8_to_bool(buf, config_.use_fast_math);
break;
default:
NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast<int>(attr), ")");
......@@ -736,12 +1145,146 @@ int nvte_is_non_tn_fp8_gemm_supported() {
#if USE_ROCM
return true;
#else
int deviceComputeCapability =
transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device());
// Note: this is temporary restriction and should be lifted in the future.
// (remove the note once it's done.)
return (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
std::call_once(flags[device_id], [&]() {
int deviceComputeCapability = transformer_engine::cuda::sm_arch(device_id);
    // Note: this is a temporary restriction and should be lifted in the future.
    // (Remove this note once that is done.)
cache[device_id] = (deviceComputeCapability >= 100 && deviceComputeCapability < 120) ||
deviceComputeCapability >= 130;
});
return cache[device_id];
#endif
}
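// Illustrative sketch: the query returns a cached 0/1 per device, so callers can branch
// on it cheaply when deciding whether non-TN FP8 GEMM layouts may be requested. The
// surrounding layout decision is hypothetical, not an API of this library.
static bool example_allow_non_tn_fp8_gemm() { return nvte_is_non_tn_fp8_gemm_supported() != 0; }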
// Grouped Tensor C API implementations
NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
NVTEShape logical_shape) {
NVTE_CHECK(num_tensors > 0, "Number of tensors must be greater than 0");
NVTE_CHECK(logical_shape.ndim == 2, "Logical shape must be 2D");
NVTE_CHECK(logical_shape.data[0] > 0 && logical_shape.data[1] > 0,
"Logical shape must have positive dimensions");
NVTEGroupedTensor ret = transformer_engine::GroupedTensorAllocator::instance().Allocate(
scaling_mode, num_tensors, logical_shape);
return ret;
}
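// Illustrative sketch: creating and later destroying a grouped tensor. The group size and
// 2D logical shape below are arbitrary example values; the shape must have positive extents.
static NVTEGroupedTensor example_make_grouped_tensor() {
  const size_t dims[2] = {4096, 1024};
  const NVTEShape logical_shape = nvte_make_shape(dims, 2);
  // Release the handle with nvte_destroy_grouped_tensor when no longer needed.
  return nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, /*num_tensors=*/8,
                                    logical_shape);
}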
void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor) {
transformer_engine::GroupedTensorAllocator::instance().Free(tensor);
}
void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
const NVTEBasicTensor *param) {
NVTE_CHECK(tensor != nullptr, "Grouped tensor pointer can't be NULL.");
auto *t = transformer_engine::convertNVTEGroupedTensor(*tensor);
NVTE_CHECK(t != nullptr, "Grouped tensor is not allocated.");
NVTE_CHECK(param != nullptr, "Grouped tensor param can't be NULL.");
switch (param_name) {
case kNVTEGroupedRowwiseData:
t->data = *param;
break;
case kNVTEGroupedColumnwiseData:
t->columnwise_data = *param;
break;
case kNVTEGroupedScale:
t->scale = *param;
break;
case kNVTEGroupedAmax:
t->amax = *param;
break;
case kNVTEGroupedRowwiseScaleInv:
t->scale_inv = *param;
break;
case kNVTEGroupedColumnwiseScaleInv:
t->columnwise_scale_inv = *param;
break;
case kNVTEGroupedColumnwiseAmax:
t->columnwise_amax = *param;
break;
case kNVTEGroupedFirstDims:
t->first_dims = *param;
// Validate it's Int64
NVTE_CHECK(t->first_dims.dtype == transformer_engine::DType::kInt64,
"first_dims must have dtype Int64");
break;
case kNVTEGroupedLastDims:
t->last_dims = *param;
// Validate it's Int64
NVTE_CHECK(t->last_dims.dtype == transformer_engine::DType::kInt64,
"last_dims must have dtype Int64");
break;
case kNVTEGroupedTensorOffsets:
t->tensor_offsets = *param;
// Validate it's Int64
NVTE_CHECK(t->tensor_offsets.dtype == transformer_engine::DType::kInt64,
"tensor_offsets must have dtype Int64");
break;
default:
NVTE_ERROR("Unknown grouped tensor parameter!");
}
}
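// Illustrative sketch: attaching per-member metadata to a grouped tensor. The Int64 device
// buffer `first_dims_ptr` (one entry per member tensor) is a hypothetical allocation owned
// by the caller.
static void example_set_first_dims(NVTEGroupedTensor *grouped, int64_t *first_dims_ptr,
                                   size_t num_tensors) {
  const NVTEBasicTensor first_dims{first_dims_ptr, kNVTEInt64, nvte_make_shape(&num_tensors, 1)};
  nvte_set_grouped_tensor_param(grouped, kNVTEGroupedFirstDims, &first_dims);
}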
NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
NVTEGroupedTensorParam param_name) {
if (tensor == nullptr) {
return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 1)};
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
switch (param_name) {
case kNVTEGroupedRowwiseData:
return t.data;
case kNVTEGroupedColumnwiseData:
return t.columnwise_data;
case kNVTEGroupedScale:
return t.scale;
case kNVTEGroupedAmax:
return t.amax;
case kNVTEGroupedRowwiseScaleInv:
return t.scale_inv;
case kNVTEGroupedColumnwiseScaleInv:
return t.columnwise_scale_inv;
case kNVTEGroupedColumnwiseAmax:
return t.columnwise_amax;
case kNVTEGroupedFirstDims:
return t.first_dims;
case kNVTEGroupedLastDims:
return t.last_dims;
case kNVTEGroupedTensorOffsets:
return t.tensor_offsets;
default:
NVTE_ERROR("Unknown grouped tensor parameter!");
}
}
size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor) {
auto *t = transformer_engine::convertNVTEGroupedTensor(tensor);
if (t == nullptr) return 0;
return t->num_tensors;
}
NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor) {
auto *t = transformer_engine::convertNVTEGroupedTensor(tensor);
if (t == nullptr) return kNVTEFloat32;
return static_cast<NVTEDType>(t->dtype());
}
NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor) {
if (tensor == nullptr) {
return NVTE_DELAYED_TENSOR_SCALING;
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
return t.scaling_mode;
}
NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor) {
if (tensor == nullptr) {
return nvte_make_shape(nullptr, 1);
}
const auto &t = *transformer_engine::convertNVTEGroupedTensorCheck(tensor);
return t.logical_shape;
}
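// Illustrative sketch: reading grouped-tensor metadata back through the query functions
// defined above. The `grouped` handle is assumed to be valid.
static void example_inspect_grouped_tensor(const NVTEGroupedTensor grouped) {
  const size_t num_tensors = nvte_grouped_tensor_num_tensors(grouped);
  const NVTEDType dtype = nvte_grouped_tensor_type(grouped);
  const NVTEScalingMode mode = nvte_grouped_tensor_scaling_mode(grouped);
  const NVTEShape logical_shape = nvte_get_grouped_tensor_logical_shape(grouped);
  const NVTEBasicTensor rowwise = nvte_get_grouped_tensor_param(grouped, kNVTEGroupedRowwiseData);
  (void)num_tensors; (void)dtype; (void)mode; (void)logical_shape; (void)rowwise;
}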
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -36,7 +36,7 @@ enum class FP8BlockwiseRowwiseOption {
NONE,
// Rowwise data, scales in GEMM format
ROWWISE_GEMM_READY,
// Rowwise data, scales in compact format, needs extra processing (padding, transposing) before GEMM
// Deprecated
ROWWISE_COMPACT
};
......@@ -50,8 +50,7 @@ enum class FP8BlockwiseColumnwiseOption {
// On Hopper sm90, GEMM_READY means that columnwise quantization also fuses transpose op
// On higher sm versions with TN,NT,NN fp8 gemm, GEMM_READY doesn't fuse transpose
COLUMNWISE_GEMM_READY,
// Columnwise data in original shape
// Scales in compact format, needs extra processing (padding, transposing) before GEMM
// Deprecated
COLUMNWISE_COMPACT
};
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -203,8 +203,6 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, /
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -908,7 +908,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
}
NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
const size_t row_length = input.shape.size() > 0 ? input.shape.back() : 1;
size_t num_rows = 1;
for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) {
num_rows *= input.shape.at(i);
......@@ -927,12 +927,14 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
if (return_transpose) {
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
NVTE_CHECK(output_t.shape.size() == input.shape.size(), "input (shape=", input.shape,
") and output_t (shape=", output_t.shape, ") have incompatible dims.");
if (output_t.shape.size() > 0) {
NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t.");
NVTE_CHECK(output_t.shape.front() == input.shape.back(), "input (shape=", input.shape,
") and output_t (shape=", output_t.shape, ") have incompatible dims.");
for (size_t i = 1; i < output_t.shape.size(); ++i) {
NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t");
NVTE_CHECK(output_t.shape[i] == input.shape[i - 1], "input (shape=", input.shape,
") and output_t (shape=", output_t.shape, ") have incompatible dims.");
}
}
NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type.");
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......