Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -131,7 +131,7 @@ enum NVTE_Mask_Type {
* NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1),
* NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and
* NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)),
* where alpha is a learnable parameter in shape [H].
* where alpha is a learnable parameter of shape [H].
*/
enum NVTE_Softmax_Type {
/*! Vanilla softmax */
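A scalar sketch of the off-by-one variant described above (illustrative host code, not the library's fused kernel; `s` is a single attention row):

#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> off_by_one_softmax(const std::vector<float>& s) {
  float denom = 1.0f;  // the extra "+1" term; vanilla softmax starts the sum at 0
  for (float v : s) denom += std::exp(v);
  std::vector<float> out(s.size());
  for (std::size_t i = 0; i < s.size(); ++i) out[i] = std::exp(s[i]) / denom;
  return out;
}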
......@@ -409,7 +409,6 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -673,7 +672,7 @@ void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor se
* \param[in] len batch_size x sequence_length.
* \param[in] stream CUDA stream used for this operation.
*/
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len,
uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlens, NVTETensor workspace, size_t len,
cudaStream_t stream);
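For context, a sketch of the cumulative-seqlens convention this parameter follows (a hypothetical host-side helper illustrating one plausible reading; the library's counting rule, e.g. its handling of padding, may differ):

#include <cstddef>

// For sequences of lengths {3, 5, 0, 2}: cu_seqlens = {0, 3, 8, 8, 10}.
// A zero-length entry adds no segment, so this example yields 3 segments.
std::size_t count_segments(const int* cu_seqlens, std::size_t batch_size) {
  std::size_t num_segments = 0;
  for (std::size_t i = 0; i < batch_size; ++i) {
    if (cu_seqlens[i + 1] > cu_seqlens[i]) ++num_segments;  // non-empty sequence
  }
  return num_segments;
}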
/*! \brief Set the seed and offset for RNG state.
......@@ -830,8 +829,7 @@ void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETens
* \param[in] tensor Input tensor.
* \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1].
* \param[out] new_tensor Output tensor.
* \param[in] b Batch size.
* \param[in] max_seq_len Maximum sequence length.
* \param[in] t Packed sequence length.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -57,24 +57,24 @@ NVTEMatmulConfig nvte_create_matmul_config();
/*! \brief Query an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value to.
* Ignored if NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
void *buf, size_t size_in_bytes, size_t *size_written);
/*! \brief Set an option in matrix multiplication configuration.
*
* \param[in] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
* \param[in/out] config Matrix multiplication configuration.
* \param[in] attr Option type.
* \param[in] buf Memory address to read option value from.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes);
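Taken together, the get/set pair is used like this (a hedged usage sketch based only on the signatures above):

NVTEMatmulConfig cfg = nvte_create_matmul_config();

// Set: pass the attribute value by address, with its size.
int32_t sm_count = 8;
nvte_set_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, &sm_count, sizeof(sm_count));

// Get: read it back; size_written reports how many bytes were produced.
int32_t queried = 0;
size_t written = 0;
nvte_get_matmul_config_attribute(cfg, kNVTEMatmulConfigSMCount, &queried, sizeof(queried),
                                 &written);

nvte_destroy_matmul_config(cfg);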
......@@ -280,9 +280,11 @@ class MatmulConfigWrapper {
MatmulConfigWrapper(const MatmulConfigWrapper &) = delete;
MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete;
/*! \brief Move constructor. */
MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
/*! \brief Move-assignment operator. */
MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_matmul_config(config_);
......@@ -319,14 +321,15 @@ class MatmulConfigWrapper {
/*! \brief Set whether to compute GELU in GEMM epilogue. */
void set_with_gelu_epilogue(bool with_gelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue,
&with_gelu_epilogue, sizeof(bool));
const auto val = static_cast<uint8_t>(with_gelu_epilogue);
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue, &val, sizeof(val));
}
/*! \brief Set whether to compute GELU backward in GEMM epilogue. */
void set_with_dgelu_epilogue(bool with_dgelu_epilogue) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue,
&with_dgelu_epilogue, sizeof(bool));
const auto val = static_cast<uint8_t>(with_dgelu_epilogue);
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue, &val,
sizeof(val));
}
/*! \brief Set auxiliary tensor for GEMM epilogue. */
......@@ -337,13 +340,15 @@ class MatmulConfigWrapper {
/*! \brief Set whether to use split accumulator for FP8 GEMM. */
void set_use_split_accumulator(bool use_split_accumulator) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator,
&use_split_accumulator, sizeof(bool));
const auto val = static_cast<uint8_t>(use_split_accumulator);
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator, &val,
sizeof(val));
}
/*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */
void set_sm_count(int sm_count) {
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int));
const auto val = static_cast<int32_t>(sm_count);
nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &val, sizeof(val));
}
private:
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -61,6 +61,69 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE
const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
/*! \brief Split a tensor along dimension 0 and compute RHT amaxes for each split.
*
* This function is experimental and the API is not stable.
*
* This is intended for quantizing to NVFP4 with random Hadamard
* transforms (RHT). For each tensor split, compute the maximum
* absolute value (amax) and populate the row-wise amax of the
* corresponding output tensor. Also, compute the amax after a
* transposed RHT and populate the column-wise amax of the
* corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] outputs Array of NVFP4 output tensors. Only the row-wise and
* column-wise amaxes are updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] random_sign_mask 16-bit sign mask for RHT.
* \param[in] random_sign_mask_t 16-bit sign mask for transposed RHT.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outputs,
const size_t* split_sections, size_t num_tensors,
int random_sign_mask, int random_sign_mask_t,
cudaStream_t stream);
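A hypothetical call sketch to make the split convention concrete (`input`, `stream`, and the pre-created output tensors are assumed; the mask values are arbitrary):

// outputs[0..2] are assumed to be pre-created NVFP4 tensors, one per split.
const size_t split_sections[3] = {4, 6, 2};  // rows per split; sums to dim0 = 12
nvte_group_hadamard_transform_amax(input, outputs, split_sections, /*num_tensors=*/3,
                                   /*random_sign_mask=*/0xABCD,
                                   /*random_sign_mask_t=*/0x1234, stream);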
/*!
* \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
*
* This function is experimental and the API is not stable. The group_ prefix indicates that the input is a contiguous concatenation of the tensor splits.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion_columnwise(
const NVTETensor input, NVTETensor* outputs, const NVTETensor hadamard_matrix,
const size_t* split_sections, size_t num_tensors, const NVTEQuantizationConfig quant_config,
cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor row quantize (without Hadamard) and columnwise Hadamard transform cast fusion operation.
*
* This function is experimental and the API is not stable. The group_ prefix indicates that the input is a contiguous concatenation of the tensor splits.
*
* \param[in] input Input tensor to apply Hadamard transform.
* \param[in,out] outputs Array of output tensors.
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] split_sections Array specifying splits in dimension 0 for each output tensor.
* \param[in] num_tensors Number of output tensors, must be > 0.
* \param[in] quant_config Quantization configuration.
* \param[in] quant_workspace Workspace buffer. Must be at least 4 bytes.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETensor* outputs,
const NVTETensor hadamard_matrix,
const size_t* split_sections, size_t num_tensors,
const NVTEQuantizationConfig quant_config,
NVTETensor quant_workspace, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -265,6 +265,37 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
float max_fp8, int force_pow_2_scales,
float epsilon, cudaStream_t stream);
/*! \brief Compute E8M0 scale_inv for a list of tensors.
*
* \warning This API is **experimental** and subject to change.
*
* \param[in] chunk_size Number of tensor elements processed by a CUDA block.
* \param[in,out] tensor_lists 2D array of input tensors.
* \param[in] num_tensor_lists Size (dim0) of tensor_lists.
* \param[in] num_tensors_per_list Size (dim1) of tensor_lists.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
cudaStream_t stream);
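A simplified scalar model of the per-element conversion (assumptions: FP8 E4M3 max normal is 448 and the exponent is rounded up; rounding and saturation details belong to the library's ptx::float_to_e8m0 and are elided here):

#include <cmath>
#include <cstdint>

uint8_t scale_inv_e8m0_ref(float amax) {
  const float x = amax / 448.0f;             // amax scaled by 1 / FP8_E4M3_MAX
  if (x <= 0.0f) return 0;                   // degenerate amax
  const int exponent = static_cast<int>(std::ceil(std::log2(x)));
  return static_cast<uint8_t>(exponent + 127);  // E8M0 = biased power-of-two exponent
}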
/*! \brief Split a tensor along dimension 0 and compute the amax for each split.
*
* This function is experimental and the API is not stable.
*
* For each tensor split, compute the maximum absolute value (amax)
* and populate the amax of the corresponding output tensor.
*
* \param[in] input Input tensor.
* \param[in,out] outputs Array of output tensors. Only the amax is updated.
* \param[in] split_sections Size of each tensor split along dimension 0.
* \param[in] num_tensors Number of tensor splits.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections,
size_t num_tensors, cudaStream_t stream);
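A host-side reference for the split-amax semantics (a sketch, not the CUDA kernel; assumes a row-major [rows, cols] input and one scalar amax slot per split):

#include <cmath>
#include <cstddef>

void group_amax_ref(const float* input, std::size_t cols, const std::size_t* split_sections,
                    std::size_t num_tensors, float* amaxes /* one per split */) {
  std::size_t row_begin = 0;
  for (std::size_t i = 0; i < num_tensors; ++i) {
    float amax = 0.0f;
    for (std::size_t r = row_begin; r < row_begin + split_sections[i]; ++r)
      for (std::size_t c = 0; c < cols; ++c)
        amax = std::fmax(amax, std::fabs(input[r * cols + c]));
    amaxes[i] = amax;
    row_begin += split_sections[i];  // splits tile dimension 0 contiguously
  }
}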
#ifdef __cplusplus
} // extern "C"
#endif
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -163,11 +163,16 @@ void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETen
NVTETensor dgamma, NVTETensor workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream);
/*! \brief Helper to enable cuDNN backend for normalization
/*! \brief Set whether to enable cuDNN backend for normalization forward.
*
* \param[in] bool Enable if True
* \param[in] enable Whether to enable cuDNN backend.
*/
void nvte_enable_cudnn_norm_fwd(bool enable);
/*! \brief Set whether to enable cuDNN backend for normalization backward.
*
* \param[in] enable Whether to enable cuDNN backend.
*/
void nvte_enable_cudnn_norm_bwd(bool enable);
/*! \brief Control whether norm computes `gamma += 1.0` for zero-centered gamma
......@@ -176,11 +181,14 @@ void nvte_enable_cudnn_norm_bwd(bool enable);
* Currently this only applies to the cuDNN backend. If cuDNN is not used,
* this setting has no effect.
*
* \param[in] bool Enable if True
* \param[in] enable Whether to enable zero-centered gamma.
*/
void nvte_enable_zero_centered_gamma_in_weight_dtype(bool enable);
#ifdef __cplusplus
/*! \brief Normalization function type */
enum class NVTE_Norm_Type { LayerNorm, RMSNorm };
#endif
#ifdef __cplusplus
} // extern "C"
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -23,7 +23,7 @@ extern "C" {
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
* @f$ 2^{-margin} \max_{fp8\_dtype} @f$.
*
* \param[in] amax_history History of maximum absolute values.
* Shape: [history_length, num_scales]
......@@ -54,7 +54,7 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update(
* the last, the last entry shifts to the second to last) and the
* first entry is set to zero. The scaling factor is estimated so the
* FP8 tensor's maximum absolute value is
* @f$ 2^{-\text{margin}} \text{max}_\text{fp8\_dtype} @f$.
* @f$ 2^{-margin} \max_{fp8\_dtype} @f$.
*
* \param[in] amax_reduction_buffer The contiguous buffer used for amax reduction.
* Shape: [num_scales * num_tensors]
......@@ -113,17 +113,200 @@ void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
void nvte_compute_scale_from_amax(NVTETensor output, const NVTEQuantizationConfig config,
cudaStream_t stream);
/*! \brief Compute partial amax for FP8 blockwise scaling.
*
* This function computes the maximum absolute values for each block of the original tensor.
* `inp` contains a continuous segment from the flattened original tensor. For each block,
* if it overlaps with the range [start_offset, start_offset+inp.length), the amax is
* computed from inp; otherwise, the amax is set to 0.
*
* Example: Original tensor (logically 512x512) divided into 16 blocks of size 128x128.
* `inp` contains continuous elements starting from position start_offset
* in the flattened original tensor.
*
* Logical view - Original Tensor (e.g., 512x512) divided into 16 blocks of size 128x128:
* ┌─────────┬─────────┬─────────┬─────────┐
* │ Block0 │ Block1 │ Block2 │ Block3 │ Each block: 128x128
* │ 128x128 │ 128x128 │ 128x128 │ 128x128 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block4 │ Block5 │ Block6 │ Block7 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block8 │ Block9 │ Block10 │ Block11 │
* ├─────────┼─────────┼─────────┼─────────┤
* │ Block12 │ Block13 │ Block14 │ Block15 │
* └─────────┴─────────┴─────────┴─────────┘
*
* Physical view - Flattened in row-major order:
* ┌────────────────────────────────────────────────────────────────┐
* │[0...128][128...256][256...384][384...512]...[261632...262143] │
* └────────────────────────────────────────────────────────────────┘
* ^ ^
* start_offset start_offset + inp.length
*
* For each 128x128 block, compute amax:
* - If the block overlaps with [start_offset, start_offset+inp.length), compute amax
* - If the block is completely outside this range, set amax = 0
*
* Amax output (one value per 128x128 block); blocks 1 and 2 are non-zero because they
* overlap with the [start_offset, start_offset+inp.length) range:
* ┌───────┬───────┬───────┬───────┐
* │ 0 │ amax │ amax │ 0 │ Block0-3
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block4-7
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block8-11
* ├───────┼───────┼───────┼───────┤
* │ 0 │ 0 │ 0 │ 0 │ Block12-15
* └───────┴───────┴───────┴───────┘
*
* \param[in] inp Input tensor (continuous slice of flattened original tensor).
* \param[in,out] amax Output tensor for maximum absolute values per block.
* \param[in] h Height dimension of the logical tensor.
* \param[in] w Width dimension of the logical tensor.
* \param[in] amax_stride_h Stride in height dimension for amax tensor.
* \param[in] amax_stride_w Stride in width dimension for amax tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] block_len Length of a quantization block to process.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_block_scaling_compute_partial_amax(const NVTETensor inp, NVTETensor amax, size_t h,
size_t w, size_t amax_stride_h,
size_t amax_stride_w, size_t start_offset,
size_t block_len, cudaStream_t stream);
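The overlap rule above can be expressed with a small hypothetical helper (host-side sketch, not the kernel; assumes row-major flattening as in the diagram):

#include <cstddef>

// A block covering rows [r0, r0+block_len) and cols [c0, c0+block_len) of an
// h x w tensor owns flattened elements r*w + c for each (r, c) it contains.
// Its amax is computed only from elements inside [start_offset, start_offset+inp_len).
bool block_touches_range(std::size_t r0, std::size_t c0, std::size_t w, std::size_t block_len,
                         std::size_t start_offset, std::size_t inp_len) {
  for (std::size_t r = r0; r < r0 + block_len; ++r) {
    std::size_t seg_begin = r * w + c0;             // one row-segment of the block
    std::size_t seg_end = seg_begin + block_len;
    if (seg_begin < start_offset + inp_len && seg_end > start_offset) return true;
  }
  return false;  // no overlap: this block's amax is set to 0
}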
/*! \brief Perform partial FP8 casting with blockwise scaling.
*
* This function casts the input tensor to FP8 format using blockwise scaling factors.
* `inp` contains a continuous segment from the flattened original tensor.
*
* \param[in] inp Input tensor.
* \param[out] out Output tensor in FP8 format.
* \param[in] scale Scaling factors per block.
* \param[in] h Height dimension of the tensor.
* \param[in] w Width dimension of the tensor.
* \param[in] scale_stride_h Stride in height dimension for scale tensor.
* \param[in] scale_stride_w Stride in width dimension for scale tensor.
* \param[in] start_offset Starting offset for partial computation.
* \param[in] block_len Length of the block to process.
* \param[in] out_dtype Output FP8 datatype.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out,
const NVTETensor scale, size_t h, size_t w,
size_t scale_stride_h, size_t scale_stride_w,
size_t start_offset, size_t block_len,
const NVTEDType out_dtype, cudaStream_t stream);
/*! \brief Compute partial amax for MXFP8 scaling.
*
* This function computes the maximum absolute values along both row and column dimensions.
* input contains a continuous segment from the flattened original tensor. For each row/column
* block, if it overlaps with the range starting from start_offset, the amax is computed from
* `input`; otherwise, the amax is set to 0.
*
* Example: Original tensor (64 rows x 64 cols).
* Rowwise amax granularity: 1x32 (each row divided into 2 blocks)
* Columnwise amax granularity: 32x1 (each column divided into 2 blocks)
* input contains a continuous segment starting from start_offset.
*
* Logical view - Original Tensor (64x64) with 1x32 and 32x1 blocks:
*
* Rowwise blocks (1x32): Each row has 2 blocks
* ┌──────────────┬──────────────┐
* row0 │ Block_r0_0 │ Block_r0_1 │ (cols 0-31, 32-63)
* ├──────────────┼──────────────┤
* row1 │ Block_r1_0 │ Block_r1_1 │
* ├──────────────┼──────────────┤
* ... │ ... │ ... │
* ├──────────────┼──────────────┤
* row63│ Block_r63_0 │ Block_r63_1 │
* └──────────────┴──────────────┘
*
* Columnwise blocks (32x1): Each column has 2 blocks
* ┌───┬───┬─────┬───┬───┐
* │c0 │c1 │ ... │c62│c63│
* ┌────┼───┼───┼─────┼───┼───┤
* │Blk0│ │ │ │ │ │ rows 0-31
* ├────┼───┼───┼─────┼───┼───┤
* │Blk1│ │ │ │ │ │ rows 32-63
* └────┴───┴───┴─────┴───┴───┘
*
* Physical view - Flattened in row-major order:
* Total elements: 64*64 = 4096
* ┌──────────────────────────────────────────────────────┐
* │[0...63][64...127][128...191]...[4032...4095] │
* └──────────────────────────────────────────────────────┘
* ^ ^
* start_offset=60 start_offset + input.length=130
*
* Row-wise amax output (one value per 1x32 block):
* ┌────────┬────────┐
* │   0    │  amax  │ row0 (block1 partially covered: elements 60-63)
* ├────────┼────────┤
* │  amax  │  amax  │ row1 (fully covered: elements 64-127)
* ├────────┼────────┤
* │  amax  │   0    │ row2 (block0 partially covered: elements 128-129)
* ├────────┼────────┤
* │  ...   │  ...   │
* ├────────┼────────┤
* │   0    │   0    │ row63 (not covered)
* └────────┴────────┘
*
* Column-wise amax output (one value per 32x1 block):
* ┌────────┬────────┬────────┬────────┬────────┬────────┬────────┐
* │  amax  │  amax  │  amax  │  amax  │  amax  │  amax  │  amax  │ ... rows 0-31
* ├────────┼────────┼────────┼────────┼────────┼────────┼────────┤
* │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ amax=0 │ ... rows 32-63
* └────────┴────────┴────────┴────────┴────────┴────────┴────────┘
* col0 col1 col2 col3 col4 col5 col6
*
* For each 1x32 or 32x1 block, if it overlaps with [start_offset, start_offset+input.length),
* compute amax; otherwise set to 0.
*
* \param[in] input Input tensor (continuous segment of flattened original tensor).
* \param[in,out] amax_rowwise Output tensor for row-wise maximum absolute values.
* \param[in,out] amax_colwise Output tensor for column-wise maximum absolute values.
* \param[in] rows Number of rows in the logical tensor.
* \param[in] cols Number of columns in the logical tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_mxfp8_scaling_compute_partial_amax(const NVTETensor input, NVTETensor amax_rowwise,
NVTETensor amax_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream);
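The 1x32 / 32x1 granularity maps to flattened indices as follows (a sketch under the row-major assumption stated above; assumes cols is a multiple of 32):

#include <cstddef>

// For flattened index i of a rows x cols tensor: row = i / cols, col = i % cols.
std::size_t rowwise_block_of(std::size_t i, std::size_t cols) {   // 1x32 block id
  return (i / cols) * (cols / 32) + (i % cols) / 32;
}
std::size_t colwise_block_of(std::size_t i, std::size_t cols) {   // 32x1 block id
  return ((i / cols) / 32) * cols + (i % cols);
}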
/*! \brief Perform partial MXFP8 casting.
*
* This function casts the input tensor to MXFP8 format, producing both row-wise and
* column-wise scaled outputs. input contains a continuous segment from the flattened
* original tensor.
*
* \param[in] input Input (continuous segment of flattened original tensor).
* \param[out] output_rowwise Output tensor with row-wise scaling (MXFP8 format).
* \param[out] output_colwise Output tensor with column-wise scaling (MXFP8 format).
* \param[in] scale_inv_rowwise Inverse scaling factors for row-wise scaling.
* \param[in] scale_inv_colwise Inverse scaling factors for column-wise scaling.
* \param[in] rows Number of rows in the logical tensor.
* \param[in] cols Number of columns in the logical tensor.
* \param[in] start_offset Starting offset in the flattened tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_mxfp8_scaling_partial_cast(const NVTETensor input, NVTETensor output_rowwise,
NVTETensor output_colwise, const NVTETensor scale_inv_rowwise,
const NVTETensor scale_inv_colwise, int rows, int cols,
size_t start_offset, cudaStream_t stream);
/*! \brief Compute per-tensor scaling factor for NVFP4 format.
*
* This function computes the scaling factor (alpha) for NVFP4 quantization based
* on the input tensors A and B, with options for using row-wise amax values.
*
* \param[in] inpA Input tensor A.
* \param[in] use_rowwise_amax_A Whether to use row-wise amax for tensor A.
* \param[in] inpB Input tensor B.
* \param[in] use_rowwise_amax_B Whether to use row-wise amax for tensor B.
* \param[in] alpha_in Input scaling factor.
* \param[out] alpha_out Output scaling factor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A,
const NVTETensor inpB, const bool use_rowwise_amax_B,
float alpha_in, NVTETensor alpha_out, cudaStream_t stream);
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file cast.h
* \brief Functions to cast to/from FP8.
/*! \file swizzle.h
* \brief Functions to convert scaling factors into format expected by GEMM.
*/
#ifndef TRANSFORMER_ENGINE_SWIZZLE_H_
......@@ -34,6 +34,7 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
*
* \param[in] inputs Input tensors with non-swizzled scale_inv.
* \param[in,out] outputs Output tensors which host the swizzled scale_inv.
* \param[in] num_tensors Number of input and output tensors.
* \param[in] stream CUDA stream used for the operation.
*
* Requirements:
......@@ -46,7 +47,7 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen
/*! \brief Swizzle the scaling factors of an FP8 block scaling tensor into the MXFP8 interleaved layout for GEMM
*
* \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv.
* \param[in] input Input FP8 block-scaled tensor.
* \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv.
* \param[in] stream CUDA stream used for the operation.
*
......@@ -56,7 +57,6 @@ void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETen
* Requirements:
* - input is an FP8 block scaling tensor
* - input has rowwise usage
* - input.scale_inv is in GEMM_READY format
* - output is an MXFP8 tensor
* - output has rowwise usage
* - output.scale_inv has appropriate shape
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -13,6 +13,7 @@
#include <cuda_runtime_api.h>
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
......@@ -52,8 +53,11 @@ struct NVTEShape {
* It does not own the memory it points to.
*/
struct NVTEBasicTensor {
/*! Pointer to data buffer. */
void *data_ptr;
/*! Data type. */
NVTEDType dtype;
/*! Tensor shape. */
NVTEShape shape;
};
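A minimal usage sketch of the non-owning view (illustrative host code using the fields documented above; real usage would typically wrap device memory):

float buf[16];                      // existing buffer, not owned by the view
size_t dims[2] = {4, 4};
NVTEBasicTensor t;
t.data_ptr = buf;
t.dtype = kNVTEFloat32;
t.shape = nvte_make_shape(dims, 2);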
......@@ -61,13 +65,14 @@ struct NVTEBasicTensor {
* \brief Indicates the kind of the tensor parameter to set/get.
*/
enum NVTETensorParam {
kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */
kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */
kNVTEScale = 2, /*!< Scale tensor */
kNVTEAmax = 3, /*!< Amax tensor */
kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
kNVTERowwiseData = 0, /*!< Data usable in rowwise manner */
kNVTEColumnwiseData = 1, /*!< Data usable in columnwise manner */
kNVTEScale = 2, /*!< Scale tensor */
kNVTEAmax = 3, /*!< Amax tensor */
kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
kNVTEWithGEMMSwizzledScales = 7, /*!< Whether scaling factors are in format expected by GEMM */
kNVTENumTensorParams
};
......@@ -143,8 +148,9 @@ void *nvte_tensor_columnwise_data(const NVTETensor tensor);
/*! \brief Construct a shape from an array of dimension sizes.
*
* \param[data] Pointer to start of shape array.
* \param[data] Number of dimensions (must be <= 14)
* \param[data] Pointer to start of shape array. If NULL, the shape
* will be filled with zeros.
* \param[ndim] Number of dimensions (must be <= 14)
*
* \return A shape. The shape will own its own copy of the data.
*/
......@@ -177,7 +183,7 @@ size_t nvte_tensor_ndims(const NVTETensor tensor);
/*! \brief Get the size of a specific tensor dimension.
*
* \param[in] tensor Tensor.
* \param[in] size_t Dimension index.
* \param[in] dim Dimension index.
*
* \return Size of the tensor at the specified dimension.
*/
......@@ -258,12 +264,13 @@ NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor);
/*! \brief Reset tensor value to zero.
*
* \param[in] tensor Tensor.
*
* \return A scale_inv shape of the input tensor.
* \param[in] stream CUDA stream to use for the operation.
*/
void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream);
/*! \brief Set a parameter of the tensor.
*
* \warning Deprecated in favor of nvte_set_tensor_param_v2.
*
* \param[in/out] tensor Tensor.
* \param[in] param_name The parameter to be set.
......@@ -273,12 +280,38 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
const NVTEBasicTensor *param);
/*! \brief Get a value of the parameter of the tensor.
*
* \warning Deprecated in favor of nvte_get_tensor_param_v2.
*
* \param[in] tensor Tensor.
* \param[in] param_name The parameter to be queried.
*/
NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name);
/*! \brief Set a tensor parameter.
*
* \param[in/out] tensor Tensor.
* \param[in] param Tensor parameter type.
* \param[in] buf Memory address to read parameter value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const void *buf,
size_t size_in_bytes);
/*! \brief Query a tensor parameter.
*
* \param[in] tensor Tensor.
* \param[in] param Tensor parameter type.
* \param[out] buf Memory address to write parameter value.
* Ignored if NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, void *buf,
size_t size_in_bytes, size_t *size_written);
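A sketch of the size-probe pattern implied by the size_written semantics (`tensor` is assumed to be a valid NVTETensor):

// Probe the required size with buf = NULL, then read into a sized buffer.
size_t needed = 0;
nvte_get_tensor_param_v2(tensor, kNVTERowwiseData, nullptr, 0, &needed);
NVTEBasicTensor data;
if (needed == sizeof(data)) {
  nvte_get_tensor_param_v2(tensor, kNVTERowwiseData, &data, sizeof(data), &needed);
}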
/*! \brief Get the granularity of scaling of this tensor.
*
* \param[in] tensor Tensor.
......@@ -324,12 +357,7 @@ enum NVTEQuantizationConfigAttribute {
conditional early even when captured in a static CUDA graph.
*/
kNVTEQuantizationConfigNoopTensor = 2,
/*! Data format for an FP8 block-scaled tensor
*
* This is not the right design since the tensor format is a
* property of the tensor, not the quantization. This enum will
* likely be refactored away in the future.
*/
/*! \warning Deprecated */
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
/*! RNG state (NVTETensor with 2 elements - seed and offset) */
kNVTEQuantizationConfigRNGState = 4,
......@@ -337,6 +365,12 @@ enum NVTEQuantizationConfigAttribute {
kNVTEQuantizationConfigNVFP42DQuantization = 5,
/*! Whether to enable stochastic rounding */
kNVTEQuantizationConfigStochasticRounding = 6,
/*! Whether to enable fast math operations with reduced accuracy.
*
* Optimizations are kernel-specific and they may be applied
* inconsistently between kernels.
*/
kNVTEQuantizationConfigUseFastMath = 7,
kNVTEQuantizationConfigNumAttributes
};
......@@ -347,14 +381,14 @@ NVTEQuantizationConfig nvte_create_quantization_config();
/*! \brief Query an option in quantization config.
*
* \param[in] config Quantization config.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
* \param[in] config Quantization config.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value.
* Ignored if NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, void *buf,
......@@ -362,10 +396,10 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config,
/*! \brief Set an option in quantization config.
*
* \param[in] config Quantization config.
* \param[in] attr Option type.
* \param[out] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
* \param[in/out] config Quantization config.
* \param[in] attr Option type.
* \param[in] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config,
NVTEQuantizationConfigAttribute attr, const void *buf,
......@@ -394,6 +428,114 @@ int nvte_is_non_tn_fp8_gemm_supported();
*/
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream);
/*! \brief TE Grouped Tensor type
*
* NVTEGroupedTensor is a collection of tensors with potentially different shapes
* but the same dtype and scaling mode. It does not own the memory it points to.
*/
typedef void *NVTEGroupedTensor;
/*! \enum NVTEGroupedTensorParam
* \brief Indicates the kind of the grouped tensor parameter to set/get.
*/
enum NVTEGroupedTensorParam {
kNVTEGroupedRowwiseData = 0, /*!< Data usable in rowwise manner */
kNVTEGroupedColumnwiseData = 1, /*!< Data usable in columnwise manner */
kNVTEGroupedScale = 2, /*!< Scale tensor */
kNVTEGroupedAmax = 3, /*!< Amax tensor */
kNVTEGroupedRowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
kNVTEGroupedColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
kNVTEGroupedColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
kNVTEGroupedFirstDims = 7, /*!< First dimension sizes (device pointer to int64_t array) */
kNVTEGroupedLastDims = 8, /*!< Last dimension sizes (device pointer to int64_t array) */
kNVTEGroupedTensorOffsets =
9, /*!< Tensor offsets for contiguous layout (device pointer to int64_t array) */
kNVTENumGroupedTensorParams
};
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Create a new TE grouped tensor.
*
* Create a new TE grouped tensor. Before use its parameters need to be set.
* TE grouped tensors are just wrappers on top of raw data and do not
* own memory.
*
* \param[in] scaling_mode Scaling mode of the grouped tensor.
* \param[in] num_tensors Number of tensors in the group (must be > 0).
* \param[in] logical_shape Logical 2D shape of the grouped data.
*
* \return A new TE grouped tensor.
*/
NVTEGroupedTensor nvte_create_grouped_tensor(NVTEScalingMode scaling_mode, size_t num_tensors,
NVTEShape logical_shape);
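A hypothetical setup sketch (device_ptr is assumed to be an existing device buffer; the dimensions are arbitrary):

size_t dims[2] = {12, 64};          // logical 2D shape of the concatenated data
NVTEGroupedTensor g = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING,
                                                 /*num_tensors=*/3,
                                                 nvte_make_shape(dims, 2));
NVTEBasicTensor data = {device_ptr, kNVTEFloat32, nvte_make_shape(dims, 2)};
nvte_set_grouped_tensor_param(&g, kNVTEGroupedRowwiseData, &data);
// Per-split shapes and offsets go in kNVTEGroupedFirstDims / kNVTEGroupedLastDims /
// kNVTEGroupedTensorOffsets (device pointers to int64_t arrays).
nvte_destroy_grouped_tensor(g);     // does not free the underlying buffers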
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Destroy a TE grouped tensor.
*
* Since the TE grouped tensor does not own memory, the underlying
* data is not freed during this operation.
*
* \param[in] tensor Grouped tensor to be destroyed.
*/
void nvte_destroy_grouped_tensor(NVTEGroupedTensor tensor);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Set a parameter of the grouped tensor.
*
* \param[in/out] tensor Grouped tensor.
* \param[in] param_name The parameter to be set.
* \param[in] param The value to be set (NVTEBasicTensor).
*/
void nvte_set_grouped_tensor_param(NVTEGroupedTensor *tensor, NVTEGroupedTensorParam param_name,
const NVTEBasicTensor *param);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a value of the parameter of the grouped tensor.
*
* \param[in] tensor Grouped tensor.
* \param[in] param_name The parameter to be queried.
*
* \return NVTEBasicTensor containing the parameter data.
*/
NVTEBasicTensor nvte_get_grouped_tensor_param(const NVTEGroupedTensor tensor,
NVTEGroupedTensorParam param_name);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the number of tensors in a grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Number of tensors in the group.
*/
size_t nvte_grouped_tensor_num_tensors(const NVTEGroupedTensor tensor);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a grouped tensor's data type.
*
* \param[in] tensor Grouped tensor.
*
* \return A data type of the grouped tensor.
*/
NVTEDType nvte_grouped_tensor_type(const NVTEGroupedTensor tensor);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get a scaling mode of the grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Scaling mode of the grouped tensor.
*/
NVTEScalingMode nvte_grouped_tensor_scaling_mode(const NVTEGroupedTensor tensor);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Get the logical shape of a grouped tensor.
*
* \param[in] tensor Grouped tensor.
*
* \return Logical 2D shape.
*/
NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor);
#ifdef __cplusplus
} // extern "C"
......@@ -435,7 +577,7 @@ inline bool is_int8_dtype(const DType t) {
/*! \brief Check if TE datatype is FP8
*
* Return true if TE datatype is FP8
* \param[in] DType TE Datatype of interest
* \param[in] t TE Datatype of interest
*/
inline bool is_fp8_dtype(const DType t) {
return t == DType::kFloat8E4M3 || t == DType::kFloat8E5M2;
......@@ -444,7 +586,7 @@ inline bool is_fp8_dtype(const DType t) {
/*! \brief Check if TE datatype is FP4
*
* Return true if TE datatype is FP4
* \param[in] DType TE Datatype of interest
* \param[in] t TE Datatype of interest
*/
inline bool is_fp4_dtype(const DType t) {
return t == DType::kFloat4E2M1;
......@@ -453,7 +595,7 @@ inline bool is_fp4_dtype(const DType t) {
/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16)
*
* Return true if TE datatype is high precision
* \param[in] DType TE Datatype of interest
* \param[in] t TE Datatype of interest
*/
inline bool is_high_precision_dtype(const DType t) {
return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16;
......@@ -477,20 +619,28 @@ class TensorWrapper {
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_shape Shape of scale_inv
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scaling_mode Tensor data format.
*/
TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr,
const NVTEShape scale_inv_shape = defaultShape,
NVTEShape scale_inv_shape = defaultShape,
const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) {
tensor_ = nvte_create_tensor(scaling_mode);
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(dtype), shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data);
NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape};
nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax);
NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape};
nvte_set_tensor_param(&tensor_, kNVTEScale, &scale);
nvte_set_tensor_param_v2(tensor_, kNVTERowwiseData, &data, sizeof(data));
NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32,
amax_dptr != nullptr ? defaultShape : emptyShape};
nvte_set_tensor_param_v2(tensor_, kNVTEAmax, &amax, sizeof(amax));
NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32,
scale_dptr != nullptr ? defaultShape : emptyShape};
nvte_set_tensor_param_v2(tensor_, kNVTEScale, &scale, sizeof(scale));
if (scale_inv_dptr == nullptr && scale_inv_shape.ndim == defaultShape.ndim &&
scale_inv_shape.ndim == 1 && scale_inv_shape.data[0] == defaultShape.data[0]) {
// Scale-inv pointer has not been provided and shape matches default
scale_inv_shape = emptyShape;
}
NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape};
nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv);
nvte_set_tensor_param_v2(tensor_, kNVTERowwiseScaleInv, &scale_inv, sizeof(scale_inv));
}
/*! \brief Constructs new TensorWrapper.
......@@ -506,6 +656,7 @@ class TensorWrapper {
* \param[in] scale_dptr Pointer to the scale value.
* \param[in] scale_inv_shape Shape of scale_inv
* \param[in] scale_inv_dptr Pointer to the inverse of scale value.
* \param[in] scaling_mode Tensor data format.
*/
TensorWrapper(void *dptr, const std::vector<size_t> &shape, const DType dtype,
float *amax_dptr = nullptr, float *scale_dptr = nullptr,
......@@ -560,7 +711,7 @@ class TensorWrapper {
const ShapeType &shape) noexcept {
NVTEShape nvte_shape = this->convertShape(shape);
NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(type), nvte_shape};
nvte_set_tensor_param(&tensor_, param, &data);
nvte_set_tensor_param_v2(tensor_, param, &data, sizeof(data));
return *this;
}
......@@ -599,10 +750,17 @@ class TensorWrapper {
return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape);
}
void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales) {
const auto val = static_cast<uint8_t>(with_gemm_swizzled_scales);
nvte_set_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &val, sizeof(val));
}
// Parameter getters
NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept {
return nvte_get_tensor_param(tensor_, param);
NVTEBasicTensor ret;
nvte_get_tensor_param_v2(tensor_, param, &ret, sizeof(ret), nullptr);
return ret;
}
NVTEBasicTensor get_rowwise_data() const noexcept { return get_parameter(kNVTERowwiseData); }
......@@ -627,6 +785,12 @@ class TensorWrapper {
return get_parameter(kNVTEColumnwiseAmax);
}
bool get_with_gemm_swizzled_scales() const {
uint8_t val = 0;
nvte_get_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &val, sizeof(val), nullptr);
return static_cast<bool>(val);
}
/*! \brief Get an underlying NVTETensor.
*
* \return NVTETensor held by this TensorWrapper.
......@@ -639,7 +803,7 @@ class TensorWrapper {
*/
const NVTEShape shape() const noexcept {
if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
return emptyShape;
}
return nvte_tensor_shape(tensor_);
}
......@@ -650,14 +814,14 @@ class TensorWrapper {
*/
const NVTEShape columnwise_shape() const noexcept {
if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
return emptyShape;
}
return nvte_tensor_columnwise_shape(tensor_);
}
/*! \brief Get the size of this TensorWrapper in the given dimension.
*
* \param[in] size_t Dimension index.
* \param[in] dim Dimension index.
*
* \return Size of this TensorWrapper in given dimension.
*/
......@@ -774,7 +938,7 @@ class TensorWrapper {
*/
const NVTEShape scale_inv_shape() const noexcept {
if (tensor_ == nullptr) {
return nvte_make_shape(nullptr, 0);
return emptyShape;
}
return nvte_tensor_scale_inv_shape(tensor_);
}
......@@ -793,6 +957,7 @@ class TensorWrapper {
static constexpr size_t defaultData = 1;
static constexpr NVTEShape defaultShape = {
{defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
private:
NVTEShape convertShape(const NVTEShape &s) { return s; }
......@@ -805,15 +970,8 @@ class TensorWrapper {
NVTETensor tensor_ = nullptr;
};
/*! \enum Float8BlockScaleTensorFormat
* \brief Data format for an FP8 block-scaled tensor
*/
enum class Float8BlockScaleTensorFormat {
/*! FP8 data is transposed if needed and scales are swizzled */
GEMM_READY = 0,
/*! FP8 data is untransposed and scales are not swizzled or padded */
COMPACT = 1
};
/*! \warning Deprecated */
enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID };
/*! \struct QuantizationConfigWrapper
* \brief C++ wrapper for NVTEQuantizationConfigWrapper.
......@@ -825,9 +983,11 @@ class QuantizationConfigWrapper {
QuantizationConfigWrapper(const QuantizationConfigWrapper &) = delete;
QuantizationConfigWrapper &operator=(const QuantizationConfigWrapper &) = delete;
/*! \brief Move constructor. */
QuantizationConfigWrapper(QuantizationConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
/*! \brief Move-assignment operator. */
QuantizationConfigWrapper &operator=(QuantizationConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_quantization_config(config_);
......@@ -852,8 +1012,9 @@ class QuantizationConfigWrapper {
/*! \brief Set whether to force power of 2 scales */
void set_force_pow_2_scales(bool force_pow_2_scales) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales,
&force_pow_2_scales, sizeof(bool));
const auto val = static_cast<uint8_t>(force_pow_2_scales);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigForcePow2Scales, &val,
sizeof(val));
}
/*! \brief Set small value to add to amax */
......@@ -868,12 +1029,8 @@ class QuantizationConfigWrapper {
sizeof(NVTETensor));
}
/*! \brief Set FP8 block-scaled tensor format */
void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {
nvte_set_quantization_config_attribute(config_,
kNVTEQuantizationConfigFloat8BlockScaleTensorFormat,
&format, sizeof(Float8BlockScaleTensorFormat));
}
/*! \warning Deprecated */
void set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat format) {}
/*! \brief Set stochastic rounding state */
void set_rng_state(NVTETensor rng_state) {
......@@ -883,14 +1040,23 @@ class QuantizationConfigWrapper {
/*! \brief Set whether to use 2D block scaling for NVFP4 */
void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) {
const auto val = static_cast<uint8_t>(nvfp4_2d_quantization);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization,
&nvfp4_2d_quantization, sizeof(bool));
&val, sizeof(val));
}
/*! \brief Set whether to use stochastic rounding */
void set_stochastic_rounding(bool stochastic_rounding) {
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding,
&stochastic_rounding, sizeof(bool));
const auto val = static_cast<uint8_t>(stochastic_rounding);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding, &val,
sizeof(val));
}
/*! \brief Set whether to enable fast math operations */
void set_use_fast_math(bool use_fast_math) {
const auto val = static_cast<uint8_t>(use_fast_math);
nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigUseFastMath, &val,
sizeof(val));
}
private:
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -231,7 +231,7 @@ void nvte_cast_transpose_dbias_dsrelu(const NVTETensor input, const NVTETensor a
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
......@@ -250,7 +250,7 @@ void nvte_dgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_inp
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
......@@ -269,7 +269,7 @@ void nvte_dswiglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
......@@ -288,7 +288,7 @@ void nvte_dreglu_cast_transpose(const NVTETensor input, const NVTETensor act_inp
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
......@@ -307,7 +307,7 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in
* - columnwise data of `output` is equal to `transpose(cast(dact(input)))`
*
* \param[in] input Input tensor of shape [N, H].
* \param[in] gated_act_input Tensor used as input to the forward of
* \param[in] act_input Tensor used as input to the forward of
* gated activation operation.
* Shape [N, H * 2].
* \param[in,out] output Result of the cast.
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -17,6 +17,7 @@
#include <sstream>
#include "../recipe/recipe_common.cuh"
#include "../util/ptx.cuh"
#include "../utils.cuh"
#include "multi_tensor_apply.cuh"
......@@ -58,6 +59,28 @@ struct ComputeScaleAndScaleInvFunctor {
}
};
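// Converts per-chunk bf16 amax values into E8M0 inverse scales:
// scale_inv[i] = float_to_e8m0(amax[i] / FP8_E4M3_MAX). Each CUDA block
// handles one chunk of one (amax in, scale_inv out) tensor pair.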
struct ComputeScaleInvE8M0Functor {
__device__ __forceinline__ void operator()(int chunk_size, volatile int *unused,
TensorListMetadata<2> &tl) {
int tensor_loc = tl.block_to_tensor[blockIdx.x];
int chunk_idx = tl.block_to_chunk[blockIdx.x];
int n = tl.sizes[tensor_loc];
bf16 *amax = reinterpret_cast<bf16 *>(tl.addresses[0][tensor_loc]);
amax += chunk_idx * chunk_size;
e8m0_t *scale_inv = reinterpret_cast<e8m0_t *>(tl.addresses[1][tensor_loc]);
scale_inv += chunk_idx * chunk_size;
n -= chunk_idx * chunk_size;
for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
scale_inv[i_start] = ptx::float_to_e8m0(static_cast<float>(amax[i_start]) *
Quantized_Limits<fp8e4m3>::max_norm_rcp);
}
}
};
void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag,
std::vector<std::vector<Tensor *>> tensor_lists,
float max_fp8, bool force_pow_2_scales,
......@@ -68,6 +91,19 @@ void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_f
NVTE_CHECK_CUDA(cudaGetLastError());
}
void multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size,
std::vector<std::vector<Tensor *>> tensor_lists,
cudaStream_t stream) {
NVTE_CHECK(tensor_lists[0][0]->data.dtype == DType::kBFloat16, "amax should be bf16");
auto scale_inv_dtype = tensor_lists[1][0]->data.dtype;
NVTE_CHECK(scale_inv_dtype == DType::kByte || scale_inv_dtype == DType::kFloat8E8M0,
"scale_inv should be e8m0/uint8");
Tensor dummy;
multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, dummy, tensor_lists, ComputeScaleInvE8M0Functor(),
stream);
NVTE_CHECK_CUDA(cudaGetLastError());
}
} // namespace multi_tensor_compute_scale
} // namespace transformer_engine
......@@ -85,3 +121,15 @@ void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETens
convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8,
force_pow_2_scales, epsilon, stream);
}
void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **tensor_lists,
const size_t num_tensor_lists,
const size_t num_tensors_per_list,
cudaStream_t stream) {
NVTE_API_CALL(nvte_multi_tensor_compute_scale_inv_e8m0_cuda);
using namespace transformer_engine;
multi_tensor_compute_scale::multi_tensor_compute_scale_inv_e8m0_cuda(
chunk_size, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
stream);
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......