Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
@@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank

/*! \brief Destroy a comm-gemm context.
 *
 * \param[in] ctx Context to destroy.
 *
 * It's the caller's responsibility to synchronize all streams involved before calling this function.
 */
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);
...
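Given the synchronization requirement documented above, here is a minimal, hedged teardown sketch; the two stream handles are hypothetical placeholders, not part of the API:

```cpp
#include <cuda_runtime.h>

// Hypothetical teardown order: synchronize every stream that touched the
// comm-gemm context before destroying it (per the doc comment above).
void shutdown_comm_gemm(NVTECommGemmCtx* ctx, cudaStream_t main_stream, cudaStream_t comm_stream) {
  cudaStreamSynchronize(main_stream);  // wait for pending GEMM/communication work
  cudaStreamSynchronize(comm_stream);
  nvte_comm_gemm_ctx_destroy(ctx);     // safe: no stream still references ctx
}
```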
@@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);

 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
 * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
 * \param[in] deterministic Whether determinism is required or not.
 */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
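For reference, a hedged sketch of querying the backend with the new trailing `deterministic` flag; every value below is an illustrative placeholder chosen by the caller, not a library default:

```cpp
// Illustrative wrapper: ask which fused-attention backend can run deterministically.
// All enum/size arguments are forwarded by the caller; nothing here is a default.
NVTE_Fused_Attn_Backend pick_deterministic_backend(
    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, size_t num_attn_heads,
    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
  return nvte_get_fused_attn_backend(
      /*is_training=*/true, q_dtype, kv_dtype, qkv_layout, bias_type, attn_mask_type, softmax_type,
      /*dropout=*/0.0f, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
      /*head_dim_qk=*/head_dim, /*head_dim_v=*/head_dim,
      /*window_size_left=*/-1, /*window_size_right=*/-1,  // -1/-1: no sliding window (illustrative)
      /*return_max_logit=*/false, /*cuda_graph=*/false, /*deterministic=*/true);
}
```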
/*! \brief Compute dot product attention with packed QKV input.
 *
@@ -269,22 +270,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
 * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
 */
[[deprecated(
    "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
    "Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_qkvpacked(
    const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
    NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
    bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
    bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);
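The new `bottom_right_diagonal` flag is easiest to see on a rectangular score matrix. The sketch below is an illustration of the usual convention, not text from the header, for `max_seqlen_q = 3`, `max_seqlen_kv = 5` with a causal-style diagonal ('x' = attended, '.' = masked):

```text
bottom_right_diagonal = false             bottom_right_diagonal = true
(diagonal anchored at the top-left)       (diagonal anchored at the bottom-right)

      kv: 0 1 2 3 4                             kv: 0 1 2 3 4
    q0    x . . . .                           q0    x x x . .
    q1    x x . . .                           q1    x x x x .
    q2    x x x . .                           q2    x x x x x
```

With bottom-right alignment the last query row attends to the full KV history, which is the common choice when the queries form a suffix of the keys/values (e.g. decoding with a KV cache); this is stated as the usual convention, not as a guarantee from this header.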
/*! \brief Compute the backward of the dot product attention with packed QKV input.
 *
@@ -332,6 +332,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
 * \param[in] deterministic Whether to execute with deterministic behaviours.
 * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
 * \param[in] workspace Workspace tensor.
@@ -346,8 +347,8 @@ void nvte_fused_attn_bwd_qkvpacked(
    NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
    size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
    bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
 *
@@ -409,6 +410,7 @@ void nvte_fused_attn_bwd_qkvpacked(
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
 * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
 */
@@ -424,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked(
    size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
    cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
 *
@@ -478,6 +481,7 @@ void nvte_fused_attn_fwd_kvpacked(
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
 * \param[in] deterministic Whether to execute with deterministic behaviours.
 * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
 * \param[in] workspace Workspace tensor.
@@ -494,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked(
    const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
    float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
    int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
    NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
 *
@@ -559,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked(
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
 * \param[in] workspace Workspace tensor.
 * \param[in] stream CUDA stream used for this operation.
 */
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
                         const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
                         NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
                         const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
                         const NVTETensor cu_seqlens_q_padded,
                         const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
                         const NVTETensor page_table_v, const NVTETensor rng_state,
                         size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
                         bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                         int64_t window_size_left, int64_t window_size_right,
                         bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
 *
@@ -628,6 +636,7 @@ void nvte_fused_attn_fwd(
 * \param[in] softmax_type Attention softmax type.
 * \param[in] window_size_left Sliding window size (the left half).
 * \param[in] window_size_right Sliding window size (the right half).
 * \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
 * \param[in] deterministic Whether to execute with deterministic behaviours.
 * \param[in] cuda_graph Whether cuda graph capture is enabled or not.
 * \param[in] workspace Workspace tensor.
@@ -643,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
                         size_t max_seqlen_kv, float attn_scale, float dropout,
                         NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                         NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
                         int64_t window_size_left, int64_t window_size_right,
                         bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
                         NVTETensor workspace, cudaStream_t stream);

/*! \brief Update the RNG state with the seed and calculated offset.
 *
...
@@ -11,6 +11,8 @@

#ifndef TRANSFORMER_ENGINE_GEMM_H_
#define TRANSFORMER_ENGINE_GEMM_H_

#include <stdint.h>

#include "transformer_engine.h"

#ifdef __cplusplus
@@ -20,6 +22,9 @@ extern "C" {

/*! \brief Configuration for matrix multiplication. */
typedef void *NVTEMatmulConfig;

/*! \brief Configuration for grouped matrix multiplication. */
typedef void *NVTEGroupedMatmulConfig;

/*! \enum NVTEMatmulConfigAttribute
 * \brief Type of option for matrix multiplication.
 */
@@ -52,6 +57,36 @@ enum NVTEMatmulConfigAttribute {
  kNVTEMatmulConfigNumAttributes
};
/*! \enum NVTEGroupedMatmulConfigAttribute
* \brief Type of option for grouped matrix multiplication.
*/
enum NVTEGroupedMatmulConfigAttribute {
/*! Average M dimension hint
*
* Optional hint for average M dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from D's logical shape.
*/
kNVTEGroupedMatmulConfigAvgM = 0,
/*! Average N dimension hint
*
* Optional hint for average N dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from D's logical shape.
*/
kNVTEGroupedMatmulConfigAvgN = 1,
/*! Average K (reduction) dimension hint
*
* Optional hint for average K dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from A's logical shape.
*/
kNVTEGroupedMatmulConfigAvgK = 2,
/*! Number of streaming multiprocessors to use in GEMM kernel. */
kNVTEGroupedMatmulConfigSMCount = 3,
kNVTEGroupedMatmulConfigNumAttributes
};
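Since the three Avg hints are plain int64 values, they can be derived on the host from the per-matrix shapes before the grouped GEMM is launched. A small hedged helper follows; the name and container choice are illustrative, not part of the library:

```cpp
#include <cstdint>
#include <vector>

// Illustrative helper: average a list of per-matrix dimensions (e.g. all M[i])
// into the single int64 hint expected by the Avg{M,N,K} attributes.
inline int64_t average_dim_hint(const std::vector<int64_t> &dims) {
  if (dims.empty()) return 0;
  int64_t sum = 0;
  for (int64_t d : dims) sum += d;
  return sum / static_cast<int64_t>(dims.size());
}
```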
/*! \brief Create a matrix multiplication configuration. */
NVTEMatmulConfig nvte_create_matmul_config();
@@ -82,6 +117,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA

/*! \brief Destroy a matrix multiplication configuration. */
void nvte_destroy_matmul_config(NVTEMatmulConfig config);
/*! \brief Create a grouped matrix multiplication configuration. */
NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config();
/*! \brief Query an option in grouped matrix multiplication configuration.
*
* \param[in] config Grouped matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written);
/*! \brief Set an option in grouped matrix multiplication configuration.
*
* \param[in] config Grouped matrix multiplication configuration.
* \param[in] attr Option type.
 * \param[in] buf Memory address to read option value.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes);
/*! \brief Destroy a grouped matrix multiplication configuration. */
void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config);
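Putting the new C-level calls together, a hedged configuration round-trip; the attribute values are arbitrary examples, not library defaults:

```cpp
#include <cstdint>

// Illustrative configuration round-trip using the C API declared above.
void configure_grouped_matmul_example() {
  NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();

  int64_t avg_m = 1024;  // arbitrary example hint for cuBLASLt heuristics
  nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &avg_m,
                                           sizeof(avg_m));

  int sm_count = 16;     // arbitrary example: restrict the GEMM kernel to 16 SMs
  nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigSMCount, &sm_count,
                                           sizeof(sm_count));

  // Read an attribute back to confirm the value round-trips.
  int64_t readback = 0;
  size_t written = 0;
  nvte_get_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &readback,
                                           sizeof(readback), &written);

  nvte_destroy_grouped_matmul_config(cfg);
}
```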
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
 *
 * This has been deprecated in favor of nvte_cublas_gemm_v2.
@@ -229,6 +296,46 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
                            bool accumulate, bool use_split_accumulator, int math_sm_count,
                            cudaStream_t stream);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C
*
* \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture.
* Will error at runtime if compiled with an older cuBLAS version or run on
* a pre-Blackwell GPU.
*
* Performs batched GEMM on a collection of matrices with potentially different shapes.
* All tensors in the group must have compatible dimensions for matrix multiplication.
* Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous
* memory layout and shape metadata.
*
* \param[in] A Input grouped tensor A.
* \param[in] transa Whether to transpose A matrices.
* \param[in] B Input grouped tensor B.
* \param[in] transb Whether to transpose B matrices.
* \param[in] C Input grouped tensor C (can be NULL for beta=0).
* \param[out] D Output grouped tensor D.
* \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements).
* \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements).
* \param[in] workspace_setup Workspace tensor for pointer array setup.
* \param[in] workspace_cublas Workspace tensor for cuBLAS operations.
* \param[in] config Additional configuration (can be NULL for defaults).
* \param[in] stream CUDA stream for the operation.
*
* Requirements:
* - cuBLAS 13.2+ (CUDA 13.1+)
* - Blackwell (SM100) or newer GPU architecture
* - A, B, C (if provided), D must have the same num_tensors
* - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
* - Shape compatibility: if transa=false, transb=false:
* - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
*/
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
cudaStream_t stream);
#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
                                        const NVTETensor* bias, NVTETensor* pre_gelu_out,
@@ -356,6 +463,70 @@ class MatmulConfigWrapper {
  NVTEMatmulConfig config_ = nullptr;
};
/*! \struct GroupedMatmulConfigWrapper
* \brief C++ wrapper for NVTEGroupedMatmulConfig.
*/
class GroupedMatmulConfigWrapper {
public:
GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {}
GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete;
GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete;
GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_grouped_matmul_config(config_);
}
config_ = other.config_;
other.config_ = nullptr;
return *this;
}
~GroupedMatmulConfigWrapper() {
if (config_ != nullptr) {
nvte_destroy_grouped_matmul_config(config_);
config_ = nullptr;
}
}
/*! \brief Get the underlying NVTEGroupedMatmulConfig.
*
* \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper.
*/
operator NVTEGroupedMatmulConfig() const noexcept { return config_; }
/*! \brief Set average M dimension hint for algorithm selection. */
void set_avg_m(int64_t avg_m) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m,
sizeof(int64_t));
}
/*! \brief Set average N dimension hint for algorithm selection. */
void set_avg_n(int64_t avg_n) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n,
sizeof(int64_t));
}
/*! \brief Set average K dimension hint for algorithm selection. */
void set_avg_k(int64_t avg_k) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k,
sizeof(int64_t));
}
/*! \brief Set number of streaming multiprocessors to use. */
void set_sm_count(int sm_count) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count,
sizeof(int));
}
private:
/*! \brief Wrapped NVTEGroupedMatmulConfig. */
NVTEGroupedMatmulConfig config_ = nullptr;
};
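A hedged end-to-end sketch in the wrapper style above. It assumes the grouped tensors, the per-matrix alpha/beta tensors, and the two workspaces were created elsewhere with matching num_tensors, since their construction is outside this diff:

```cpp
// Illustrative only: A, B, D, alpha, beta, and the workspaces are assumed to
// be valid handles created elsewhere; C is omitted (NULL) because beta is 0.
void run_grouped_gemm_example(const NVTEGroupedTensor A, const NVTEGroupedTensor B,
                              NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta,
                              NVTETensor workspace_setup, NVTETensor workspace_cublas,
                              cudaStream_t stream) {
  transformer_engine::GroupedMatmulConfigWrapper cfg;
  cfg.set_avg_m(512);    // optional heuristics hints; example values only
  cfg.set_avg_k(4096);

  // Per group index i: D[i] = alpha[i] * A[i] @ B[i] + beta[i] * C[i].
  nvte_grouped_gemm(A, /*transa=*/0, B, /*transb=*/0, /*C=*/nullptr, D, alpha, beta,
                    workspace_setup, workspace_cublas, cfg, stream);
}
```

The wrapper converts implicitly to `NVTEGroupedMatmulConfig` through its conversion operator, so it can be passed directly as the `config` argument.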
} // namespace transformer_engine
#endif // __cplusplus
...
@@ -86,6 +86,24 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp
                                        int random_sign_mask, int random_sign_mask_t,
                                        cudaStream_t stream);
/*! \brief Grouped-tensor amax with Hadamard transform (graph safe, device-managed grouping).
*
* This function is experimental and the API is not stable.
*
* This API assumes that the split info (grouping of tensors) is on device and unknown to the host;
* therefore, this is a graph safe API and the grouped-tensor argument is passed as a single device structure.
*
* \param[in] input NVTEGroupedTensor representing grouped input tensors.
* \param[in,out] output NVTEGroupedTensor for output amax (row/col). Only the row-wise and
* column-wise amaxes are updated.
* \param[in] random_sign_mask 16-bit sign mask for RHT.
* \param[in] random_sign_mask_t 16-bit sign mask for transposed RHT.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input,
NVTEGroupedTensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream);
/*!
 * \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
 *
@@ -124,6 +142,22 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
                                               const NVTEQuantizationConfig quant_config,
                                               NVTETensor quant_workspace, cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor Hadamard transform cast fusion operation in graph-safe mode.
*
 * This function is experimental and the API is not stable. The group_ prefix indicates that the grouped inputs are concatenated into one contiguous buffer.
*
* \param[in] input NVTEGroupedTensor representing grouped input tensors.
* \param[in,out] output NVTEGroupedTensor for output (row/column-wise quantized results).
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] quant_config Quantization configuration.
* \param[in] quant_workspace Workspace buffer. Must be at least 4 bytes.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion_graph_safe(
const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
...
@@ -296,6 +296,17 @@ void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **
void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections,
                     size_t num_tensors, cudaStream_t stream);
/*! \brief Grouped-tensor amax without applying a Hadamard transform.
 *
 * This function is experimental and the API is not stable.
 *
 * \param[in]     input   Grouped input tensor (NVTEGroupedTensor).
 * \param[in,out] output  Grouped output tensor (NVTEGroupedTensor) receiving the computed amax values.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
...
@@ -88,33 +88,40 @@ class Recipe:

    Base recipe class.
    """

    @classmethod
    def nvfp4(cls):
        """Whether the given recipe is NVFP4 1D block scaling."""
        return issubclass(cls, NVFP4BlockScaling)

    @classmethod
    def mxfp8(cls):
        """Whether the given recipe is MXFP8 block scaling."""
        return issubclass(cls, MXFP8BlockScaling)

    @classmethod
    def delayed(cls):
        """Whether the given recipe is delayed scaling."""
        return issubclass(cls, DelayedScaling)

    @classmethod
    def float8_current_scaling(cls):
        """Whether the given recipe is (per-tensor) current scaling."""
        return issubclass(cls, Float8CurrentScaling)

    @classmethod
    def float8_per_tensor_scaling(cls):
        """Whether the given recipe is per-tensor scaling."""
        return issubclass(cls, (DelayedScaling, Float8CurrentScaling))

    @classmethod
    def float8_block_scaling(cls):
        """Whether the given recipe is float8 blockwise scaling."""
        return issubclass(cls, Float8BlockScaling)

    @classmethod
    def custom(cls):
        """Whether the given recipe is custom."""
        return issubclass(cls, CustomRecipe)


@dataclass()
...
@@ -458,9 +458,9 @@ class TensorAllocator {
  }

  void Free(NVTETensor t) {
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
    std::lock_guard<std::mutex> lock(mutex);
    NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
    free_list.push_back(index);
    // Clean up
@@ -568,9 +568,9 @@ class GroupedTensorAllocator {
  }

  void Free(NVTEGroupedTensor t) {
    uintptr_t index = reinterpret_cast<uintptr_t>(t);
    if (index == 0) return;
    std::lock_guard<std::mutex> lock(mutex);
    NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
    free_list.push_back(index);
    // Clean up
...
@@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel(
        split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
    ).to(tl.int32)
    input_split_sizes_cumsum = tl.cumsum(input_split_sizes)

    # Compute total valid tokens and skip phantom/padding tokens.
    # When the input buffer is larger than sum(split_sizes), tokens beyond
    # the valid range should map to themselves (identity mapping) to avoid
    # corrupting valid output positions.
    total_valid_tokens = tl.sum(input_split_sizes)

    input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
    input_chunk_idx = tl.sum(input_split_sizes_mask)
    input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
@@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel(
    ).to(tl.int32)
    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset

    # For tokens beyond the valid range (pid >= total_valid_tokens),
    # use identity mapping to avoid corrupting valid data.
    dst_row = tl.where(pid < total_valid_tokens, dst_row, pid)

    tl.store(dst_rows_ptr + pid, dst_row)
...
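The intent of the two added blocks is easier to see outside Triton: whatever destination row the sort produces, rows at or beyond sum(split_sizes) map to themselves. A hedged host-side C++ analogue of that guard (illustration only, not the kernel):

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Hedged host-side analogue of the added guard: whatever destination row the
// sort logic produced, rows past sum(split_sizes) map back to themselves so
// padding never overwrites valid output positions.
std::vector<int64_t> apply_identity_fallback(std::vector<int64_t> dst_rows,
                                             const std::vector<int64_t> &split_sizes) {
  const int64_t total_valid =
      std::accumulate(split_sizes.begin(), split_sizes.end(), static_cast<int64_t>(0));
  for (int64_t row = 0; row < static_cast<int64_t>(dst_rows.size()); ++row) {
    if (row >= total_valid) dst_rows[row] = row;  // phantom/padding token
  }
  return dst_rows;
}
```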
@@ -6,6 +6,8 @@

#include "../util/cuda_runtime.h"

#include <cublasLt.h>

#include <filesystem>
#include <mutex>
@@ -232,6 +234,12 @@ int cudart_version() {
  return version;
}

size_t cublas_version() {
  // Cache version to avoid cuBLAS logging overhead
  static size_t version = cublasLtGetVersion();
  return version;
}

} // namespace cuda
} // namespace transformer_engine
@@ -85,6 +85,12 @@ const std::string &include_directory(bool required = false);
 */
int cudart_version();

/* \brief cuBLAS version number at run-time
 *
 * Versions may differ between compile-time and run-time.
 */
size_t cublas_version();

} // namespace cuda
} // namespace transformer_engine
...
@@ -141,12 +141,12 @@

#ifdef NVTE_WITH_CUBLASMP
#define NVTE_CHECK_CUBLASMP(expr)                                       \
  do {                                                                  \
    const cublasMpStatus_t status = (expr);                            \
    if (status != CUBLASMP_STATUS_SUCCESS) {                           \
      NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status)); \
    }                                                                   \
  } while (false)
#endif // NVTE_WITH_CUBLASMP
...
@@ -172,6 +172,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(
uint64_t *mbar, const uint32_t tx_count) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr),
"r"(tx_count));
#else
NVTE_DEVICE_ERROR(
"mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  asm volatile("fence.mbarrier_init.release.cluster;");
@@ -251,13 +263,86 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar,
uint32_t phase_parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile(
"{\n\t"
".reg .b64 r1; \n\t"
".reg .pred waitComplete; \n\t" // predicate representing if barrier condition is met
"WAIT: \n\t" // loop around barrier wait
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t"
"@waitComplete bra DONE; \n\t" // mbarrier conditions are met
"bra WAIT; \n\t" // just a time-out, try again
"DONE: \n\t"
"}\n\t"
:
: "r"(mbar_ptr), "r"(phase_parity)
: "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) {
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::"
"all.b128 "
"[%0], [%1];" ::"r"(workID_response),
"r"(mbar_ptr));
} else {
NVTE_DEVICE_ERROR(
"Cluster Launch Control PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
}
__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr,
int32_t &ctaid_X, int32_t &ctaid_Y) {
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
asm volatile(
"{\n\t"
".reg .s32 x_ctaid; \n\t"
".reg .s32 y_ctaid; \n\t"
"mov .s32 x_ctaid, -1; \n\t"
"mov .s32 y_ctaid, -1; \n\t"
".reg.b128 try_cancel_response; \n\t"
"ld.shared.b128 try_cancel_response, [%2]; \n\t"
".reg .pred P1; \n\t"
"clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response; \n\t"
"@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, "
"_}, try_cancel_response; \n\t"
"mov .s32 %0, x_ctaid; \n\t"
"mov .s32 %1, y_ctaid; \n\t"
"}\n\t"
: "=r"(ctaid_X), "=r"(ctaid_Y)
: "r"(workID_response)
: "memory");
} else {
NVTE_DEVICE_ERROR(
"Cluster Launch Control PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
}
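To show how these new primitives are intended to compose, here is a rough, untested sketch of a cluster-launch-control work-stealing loop. It assumes the code lives alongside the helpers in the ptx namespace; `process_tile` and the mbarrier initialization are hypothetical and outside this diff:

```cpp
// Rough sketch only: the mbarrier is assumed to be initialised once with an
// arrival count of 1 before entering the loop, and process_tile is a
// hypothetical placeholder for the per-tile work.
__device__ void clc_worker_loop_sketch(uint64_t *mbar /* shared memory */,
                                       __uint128_t *clc_response /* shared memory */) {
  uint32_t parity = 0;
  int32_t tile_x = blockIdx.x;
  int32_t tile_y = blockIdx.y;
  while (true) {
    // process_tile(tile_x, tile_y);  // hypothetical per-tile work

    if (threadIdx.x == 0) {
      // The try_cancel response is 16 bytes, delivered via complete_tx on the mbarrier.
      mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(mbar, 16);
      try_cancel_cta(mbar, clc_response);
    }
    mbarrier_wait_parity_acquire_cta_shared_cta(mbar, parity);
    parity ^= 1;

    // get_cancelled_cta_id_2D leaves -1 in both ids when no launch was cancelled.
    get_cancelled_cta_id_2D(clc_response, tile_x, tile_y);
    if (tile_x < 0) break;  // nothing left to steal
  }
}
```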
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;

__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
  // Handle the special case of NaN.
  if (biased_exp == 255) return __int_as_float(0x7fffffff);
  // Handle the special case where the unbiased exponent is 127, so the reciprocal is 2^-127, which needs the first bit
  // of the mantissa to be 1. That cannot be obtained by shifting `FP32_MANTISSA_BITS` bits to the left.
  if (biased_exp == 254) return __int_as_float(0x00400000);
  // Fast path when the unbiased exponent is in [-126, 126]: only the exponent field is needed to express the reciprocal.
  return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);
}
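The three branches can be sanity-checked against plain ldexp on the host: for biased exponents 0..253 the reciprocal of 2^(e-127) is a normal float whose exponent field is 254-e, 254 maps to the subnormal 2^-127, and 255 (the E8M0 NaN encoding) maps to NaN. A hedged reference, illustration only and not part of the library:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Host-side reference for the branches above.
float exp2f_rcp_reference(uint8_t biased_exp) {
  if (biased_exp == 255) return std::nanf("");                         // E8M0 NaN encoding
  return std::ldexp(1.0f, 127 - static_cast<int>(biased_exp));         // 2^-(e-127), incl. 2^-127
}

void check_exp2f_rcp_bits() {
  for (int e = 0; e <= 254; ++e) {
    const uint32_t bits = (e == 254) ? 0x00400000u                      // subnormal 2^-127
                                     : static_cast<uint32_t>(254 - e) << 23;  // normal 2^(127-e)
    float device_style;
    std::memcpy(&device_style, &bits, sizeof(bits));
    assert(device_style == exp2f_rcp_reference(static_cast<uint8_t>(e)));
  }
}
```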
__device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
@@ -671,6 +756,179 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
    return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
  }
} }
template <typename SCALING_COEFFICIENT_TYPE>
__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest(
const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) {
uint32_t out_8x = 0;
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
asm volatile(
"{\n"
".reg.f32 zero; \n\t"
"mov.b32 zero, 0; \n\t"
".reg.b16 scaling_coeff; \n\t"
"mov.b16 scaling_coeff, %3; \n\t"
".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t"
"mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t"
"mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t"
".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t"
".reg.b8 f0, f1, f2, f3; \n\t"
// Elements reordered to match e2m1x4 packing order (v1,v0)
"cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t"
"mov.b32 %0, {f0, f1, f2, f3};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)));
} else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
asm volatile(
"{\n"
".reg.b64 scaling_coeff_2x; \n\t"
"mov.b64 scaling_coeff_2x, {%3, %3}; \n\t"
".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t"
"mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t"
".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"cvt.f32.bf16 v4, v4_bf16; \n\t"
"cvt.f32.bf16 v5, v5_bf16; \n\t"
"cvt.f32.bf16 v6, v6_bf16; \n\t"
"cvt.f32.bf16 v7, v7_bf16; \n\t"
".reg.b64 v01, v23, v45, v67; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mov.b64 v45, {v4, v5}; \n\t"
"mov.b64 v67, {v6, v7}; \n\t"
"mul.f32x2 v01, v01, scaling_coeff_2x; \n\t"
"mul.f32x2 v23, v23, scaling_coeff_2x; \n\t"
"mul.f32x2 v45, v45, scaling_coeff_2x; \n\t"
"mul.f32x2 v67, v67, scaling_coeff_2x; \n\t"
// Elements reordered to match the packing order (v1,v0)
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"mov.b64 {v5, v4}, v45; \n\t"
"mov.b64 {v7, v6}, v67; \n\t"
".reg.b8 f0, f1, f2, f3; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f2, v4, v5;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f3, v6, v7;\n\t"
"mov.b32 %0, {f0, f1, f2, f3};\n\t"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "f"(scaling_coefficient));
} else {
NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
}
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return out_8x;
}
template <typename SCALING_COEFFICIENT_TYPE>
__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding(
const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient,
const uint32_t rbits03, const uint32_t rbits47) {
uint32_t out_8x = 0;
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
asm volatile(
"{\n"
".reg.f32 zero; \n\t"
"mov.b32 zero, 0; \n\t"
".reg.b16 scaling_coeff; \n\t"
"mov.b16 scaling_coeff, %3; \n\t"
".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t"
"mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t"
"mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t"
".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t"
".reg.b16 b03, b47; \n\t"
// Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
"cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t"
"mov.b32 %0, {b03, b47};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)),
"r"(rbits03), "r"(rbits47));
} else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
asm volatile(
"{\n"
".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t"
"mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t"
".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"cvt.f32.bf16 v4, v4_bf16; \n\t"
"cvt.f32.bf16 v5, v5_bf16; \n\t"
"cvt.f32.bf16 v6, v6_bf16; \n\t"
"cvt.f32.bf16 v7, v7_bf16; \n\t"
"mul.f32 v0, v0, %3; \n\t"
"mul.f32 v1, v1, %3; \n\t"
"mul.f32 v2, v2, %3; \n\t"
"mul.f32 v3, v3, %3; \n\t"
"mul.f32 v4, v4, %3; \n\t"
"mul.f32 v5, v5, %3; \n\t"
"mul.f32 v6, v6, %3; \n\t"
"mul.f32 v7, v7, %3; \n\t"
".reg.b16 b03, b47; \n\t"
// Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
"cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t"
"mov.b32 %0, {b03, b47};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "f"(scaling_coefficient), "r"(rbits03), "r"(rbits47));
} else {
NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
}
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return out_8x;
}
#endif  // FP4_TYPE_SUPPORTED

// SIMD like "Fused" cast + multiplication (x2)
@@ -1521,6 +1779,59 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) {
               : "r"(in2[0]), "r"(in2[1]));
  return out;
}
// Loads single BF16/FP16 element from shared memory state space
__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) {
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
bf16 dst;
asm volatile("ld.shared.b16 %0, [%1];"
: "=h"(reinterpret_cast<uint16_t &>(dst))
: "r"(src_smem_ptr));
return dst;
}
// Loads pair of BF16/FP16 values from shared memory state space
__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) {
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
bf16x2 dst;
asm volatile("ld.shared.b32 %0, [%1];"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(src_smem_ptr));
return dst;
}
// Loads 8x BF16 values from shared memory state space
__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) {
uint64_t elts03, elts47;
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
asm volatile(
"{\n\t"
".reg.b128 xy; \n\t"
"ld.shared.b128 xy, [%2]; \n\t"
"mov.b128 {%0, %1}, xy; \n"
"}\n"
: "=l"(elts03), "=l"(elts47)
: "r"(src_smem_ptr));
return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03);
}
#if FP4_TYPE_SUPPORTED
// Vectorized store of x8 FP4 elements into shared memory state space
__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem,
uint32_t fp4_pack_x8) {
const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8));
}
#endif
// Vectorized store of x16 FP4 elements into shared memory state space
#if FP4_TYPE_SUPPORTED
__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem,
uint64_t fp4_pack_x16) {
const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16));
}
#endif
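A hedged device-side sketch of how the new shared-memory helpers and the 8x BF16-to-FP4 converter can chain together; the decode scale and buffer layout are placeholders, not library conventions:

```cpp
#if FP4_TYPE_SUPPORTED
// Illustrative only: load 8 packed bf16 values from shared memory, quantize
// them to 8 FP4 values with a single scaling coefficient, and store the
// packed 32-bit result back to shared memory.
__device__ __forceinline__ void quantize_8_bf16_to_fp4_sketch(const bf16 *in_smem,
                                                              fp4e2m1x2 *out_smem,
                                                              float decode_scale /* placeholder */) {
  const __uint128_t packed = ld_shared_b128(in_smem);          // 8 x bf16
  const uint64_t in03 = static_cast<uint64_t>(packed);         // elements 0..3
  const uint64_t in47 = static_cast<uint64_t>(packed >> 64);   // elements 4..7
  const uint32_t fp4x8 =
      mul_cvt_bf16_to_fp4_8x_round_to_nearest(in03, in47, decode_scale);
  st_shared_b32(out_smem, fp4x8);                              // 8 x fp4 (4 bytes)
}
#endif  // FP4_TYPE_SUPPORTED
```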
#endif

}  // namespace ptx
...
@@ -2,17 +2,28 @@
#
# See LICENSE for license information.
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect

DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
"""

import warnings

from nvdlfw_inspect.registry import Registry

from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM


@Registry.register_feature(namespace="transformer_engine")
class DisableFP8GEMM(DisableQuantizationGEMM):
    """
    GEMM operations are executed in higher precision, even when FP8 autocast is enabled.

    .. deprecated::
        Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
        backward compatibility only. DisableQuantizationGEMM works with all quantization
        formats (FP8, NVFP4, etc.), not just FP8.

    Parameters
    ----------
@@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper):
            layers:
                layer_types: [fc1]
            transformer_engine:
                DisableFP8GEMM:  # Deprecated: use DisableQuantizationGEMM
                    enabled: True
                    gemms: [dgrad, wgrad]
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "DisableFP8GEMM is deprecated. "
            "Use DisableQuantizationGEMM instead, which works with all quantization "
            "formats (FP8, NVFP4, etc.).",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
@@ -2,17 +2,27 @@
#
# See LICENSE for license information.
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect

DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
New code should use DisableQuantizationLayer instead, which works with all quantization formats.
"""

import warnings

from nvdlfw_inspect.registry import Registry

from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer


@Registry.register_feature(namespace="transformer_engine")
class DisableFP8Layer(DisableQuantizationLayer):
    """
    Disables all FP8 GEMMs in the layer.

    .. deprecated::
        Use :class:`DisableQuantizationLayer` instead. This class is maintained for
        backward compatibility only. DisableQuantizationLayer works with all quantization
        formats (FP8, NVFP4, etc.), not just FP8.

    Example
    -------
@@ -20,36 +30,19 @@ class DisableFP8Layer:
        example_disable_fp8_layer:
            enabled: True
            layers:
                layer_types: [fc1]
            transformer_engine:
                DisableFP8Layer:  # Deprecated: use DisableQuantizationLayer
                    enabled: True
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "DisableFP8Layer is deprecated. "
            "Use DisableQuantizationLayer instead, which works with all quantization "
            "formats (FP8, NVFP4, etc.).",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationGEMM(TEConfigAPIMapper):
"""
Disables specific GEMM operations from using quantization, forcing high-precision execution.
Works with any quantization format (FP8, NVFP4, etc.).
Parameters
----------
gemms: List[str]
list of gemms to disable quantization for
- fprop
- dgrad
- wgrad
Example
-------
.. code-block:: yaml
example_disable_quantization_gemm:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableQuantizationGEMM:
enabled: True
gemms: [dgrad, wgrad]
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and quantized GEMM execution.
Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
but it applies to all quantization formats (FP8, NVFP4, etc.).
"""
for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect"""
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationLayer:
"""
Disables all quantized GEMMs in the layer, forcing high-precision execution.
Works with any quantization format (FP8, NVFP4, etc.).
Example
-------
.. code-block:: yaml
example_disable_quantization_layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableQuantizationLayer:
enabled: True
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and quantized GEMM execution.
Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
but it applies to all quantization formats (FP8, NVFP4, etc.).
"""
for key in config:
if key not in ["enabled", "gemm"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If quantized training, disable quantization for the selected layers if this feature is enabled.
debug_api.log_message("Quantization Disabled", layer_name)
# If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API.
DisableQuantizationLayer is the only feature provided by the Transformer Engine
which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
parsing gemms and tensors fields from the config, which are not needed for this feature.
Explanation of the parse_config_and_api can be found in the
nvidia-dlframework-inspect documentation.
"""
return config["enabled"], None
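A minimal end-to-end sketch of enabling this feature from a training script; the debug_api.initialize() arguments below follow typical nvdlfw_inspect usage and are assumptions rather than a confirmed signature, so check the nvidia-dlframework-inspect documentation before copying them.

import nvdlfw_inspect.api as debug_api

debug_api.initialize(
    config_file="disable_quantization_layer.yaml",       # contains the YAML example above
    feature_dirs=["transformer_engine/debug/features"],  # assumed location of the TE debug features
)
# Layers matched by layers.layer_types (e.g. fc1) now answer fp8_gemm_enabled()
# with False, so all of their GEMMs run in high precision regardless of the
# active recipe (FP8, NVFP4, ...).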
...@@ -6,15 +6,17 @@ ...@@ -6,15 +6,17 @@
from typing import Dict, Optional, List, Tuple from typing import Dict, Optional, List, Tuple
from contextlib import contextmanager from contextlib import contextmanager
import warnings
import torch import torch
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
import transformer_engine_torch as tex
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import ( from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer, Float8Quantizer,
...@@ -22,7 +24,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import ( ...@@ -22,7 +24,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
) )
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
try:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
_nvfp4_available = True
except ImportError:
_nvfp4_available = False
NVFP4Quantizer = None
ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]
...@@ -39,6 +48,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]): ...@@ -39,6 +48,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]):
return "mxfp8" return "mxfp8"
if isinstance(quantizer, Float8BlockQuantizer): if isinstance(quantizer, Float8BlockQuantizer):
return "fp8_block_scaling" return "fp8_block_scaling"
if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer):
return "nvfp4"
raise ValueError(f"Unsupported quantizer type: {type(quantizer)}") raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")
...@@ -164,6 +175,16 @@ class LogFp8TensorStats(BaseLogTensorStats): ...@@ -164,6 +175,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")
# Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work)
# But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer
if recipe_from_stat == "nvfp4":
raise ValueError(
f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats."
" FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for"
" NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.,"
" 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons."
)
if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
raise ValueError( raise ValueError(
f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for" f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for"
...@@ -189,6 +210,7 @@ class LogFp8TensorStats(BaseLogTensorStats): ...@@ -189,6 +210,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
def get_recipe_from_stat(self, stat: str, default_recipe: str = ""): def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
"""Returns the recipe name from the stat string.""" """Returns the recipe name from the stat string."""
columnwise_stat = stat.endswith("_columnwise") columnwise_stat = stat.endswith("_columnwise")
for recipe_name in ALL_RECIPE_NAMES: for recipe_name in ALL_RECIPE_NAMES:
if recipe_name in stat: if recipe_name in stat:
...@@ -213,7 +235,7 @@ class LogFp8TensorStats(BaseLogTensorStats): ...@@ -213,7 +235,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
Yields the aux_dict. Yields the aux_dict.
Needs to clean up after usage, because it may change the usage of the quantized tensor. Needs to clean up after usage, because it may change the usage of the quantized tensor.
""" """
fp8_dtype = None fp8_dtype = tex.DType.kFloat8E4M3
if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
assert isinstance( assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)
...@@ -277,14 +299,26 @@ class LogFp8TensorStats(BaseLogTensorStats): ...@@ -277,14 +299,26 @@ class LogFp8TensorStats(BaseLogTensorStats):
API call used to collect the data about the tensor after process_tensor()/quantization. API call used to collect the data about the tensor after process_tensor()/quantization.
""" """
assert rowwise_quantized_tensor is columnwise_quantized_tensor assert rowwise_quantized_tensor is columnwise_quantized_tensor
assert (
quantizer is not None # Skip logging if quantizer is None (layer runs in high precision)
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe." if quantizer is None:
warnings.warn(
f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
)
return
quantized_tensor = rowwise_quantized_tensor quantized_tensor = rowwise_quantized_tensor
assert isinstance(
quantized_tensor, QuantizedTensor # Skip logging if quantized_tensor is not a QuantizedTensor (incompatible precision)
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor." if not isinstance(quantized_tensor, QuantizedTensor):
warnings.warn(
f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})."
)
return
recipe_name = _get_recipe_name(quantizer) recipe_name = _get_recipe_name(quantizer)
for stat in config["stats"]: for stat in config["stats"]:
......
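A quick illustration of the stat-name parsing used by get_recipe_from_stat() above; the helper below is a simplified stand-in, not the actual method, and the empty-string fallback mirrors the default_recipe argument:

ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]

def parse_stat(stat: str, default_recipe: str = ""):
    # Recipe prefix and the optional "_columnwise" suffix are both read from the stat string.
    columnwise = stat.endswith("_columnwise")
    recipe = next((r for r in ALL_RECIPE_NAMES if r in stat), default_recipe)
    return recipe, columnwise

assert parse_stat("mxfp8_underflows%") == ("mxfp8", False)
assert parse_stat("fp8_block_scaling_mse_columnwise") == ("fp8_block_scaling", True)
assert parse_stat("underflows%") == ("", False)  # no prefix: default_recipe is returned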
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Optional
from contextlib import contextmanager
import warnings
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
@Registry.register_feature(namespace="transformer_engine")
class LogNvfp4TensorStats(BaseLogTensorStats):
"""Logs statistics of NVFP4 quantized tensors.
In distributed runs each rank first computes its local statistics; the values
are gathered the next time `debug_api.step()` is called. Remember to call
`debug_api.step()` every training step so the logs are flushed.
The feature is micro-batch aware: if several forward/backward passes occur
between successive `debug_api.step()` calls, statistics are accumulated for all
tensors except weights.
Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the
overhead, and if the feature is skipped for a step the additional cost is
minimal. When no other debug feature is active, the layer runs at normal
Transformer Engine speed.
Parameters
----------
stats: List[str]
List of statistics to collect. Available stats:
- underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data)
- mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements
tensors/tensors_struct: List[str]
list of tensors to log
- activation,
- gradient,
- weight,
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
start step of logging stats
end_step: Optional[int], default = None
end step of logging stats
start_end_list: Optional[list([int, int])], default = None
non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step
Example
-------
.. code-block:: yaml
example_nvfp4_tensor_stat_collection:
enabled: True
layers:
layer_types: [layernorm_linear]
transformer_engine:
LogNvfp4TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%, mse]
freq: 1
- tensor: gradient
stats: [underflows%, mse]
freq: 5
start_step: 0
end_step: 80
"""
def check_if_stat_is_supported(self, stat: str):
"""Returns True if stat is supported, raises ValueError otherwise."""
supported_stats = [
"underflows%",
"mse",
]
if stat not in supported_stats:
raise ValueError(
f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
)
return True
def get_stat_with_prefix(self, stat: str) -> str:
"""Add nvfp4_ prefix to stat name for use in stats_computation."""
return f"nvfp4_{stat}"
@contextmanager
def update_aux_dict(
self,
aux_dict: Dict,
quantized_tensor: QuantizedTensor,
quantizer: Quantizer, # pylint: disable=unused-argument
original_tensor: torch.Tensor,
):
"""
Populates the aux_dict with the NVFP4 quantized tensor and the original high-precision tensor.
Yields the aux_dict.
"""
aux_dict = {
"nvfp4": quantized_tensor,
"original_tensor": original_tensor,
}
try:
yield aux_dict
finally:
pass
@api_method
def inspect_tensor_enabled(
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run inspect_tensor() in the forward."""
run_current, next_iter = next_enabled_iter(
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
config.get("freq", 1),
iteration,
)
STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
return run_current, next_iter
@api_method
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
iteration: int,
tp_group,
tensor: torch.Tensor,
rowwise_quantized_tensor: Optional[QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert rowwise_quantized_tensor is columnwise_quantized_tensor
# Skip logging if quantizer is None (layer runs in high precision)
if quantizer is None:
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
)
return
quantized_tensor = rowwise_quantized_tensor
# Skip logging if not NVFP4 quantizer (incompatible precision)
if not isinstance(quantizer, NVFP4Quantizer):
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected NVFP4Quantizer, got {type(quantizer).__name__})."
)
return
# Skip logging if quantized tensor is not NVFP4TensorStorage (incompatible precision)
if not isinstance(quantized_tensor, NVFP4TensorStorage):
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected NVFP4TensorStorage, got {type(quantized_tensor).__name__})."
)
return
for stat in config["stats"]:
self.check_if_stat_is_supported(stat)
start_step = config.get("start_step", None)
end_step = config.get("end_step", None)
start_end_list = config.get("start_end_list", None)
if start_end_list is not None:
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
options = (
start_step,
end_step,
start_end_list,
"nvfp4",
)
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
)
# Add nvfp4_ prefix to all stats for internal use
prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]]
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
tensor_name=tensor_name,
stats=prefixed_stats,
options=options,
reduction_group=reduction_group,
reduce_within_microbatch=reduce_within_microbatch,
)
with self.update_aux_dict(
aux_dict={},
quantized_tensor=quantized_tensor,
quantizer=quantizer,
original_tensor=tensor,
) as aux_dict:
STATS_BUFFERS.feed(
layer_name,
tensor_name,
options,
tensor,
iteration,
skip_reduction,
aux_dict=aux_dict,
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
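As the docstring above says, stats are only gathered and written out when debug_api.step() is called, so the training loop must call it every step. A minimal sketch (the training function is a placeholder, and debug_api.initialize() with a config enabling LogNvfp4TensorStats is assumed to have run already):

import nvdlfw_inspect.api as debug_api

def train_one_step():
    # Placeholder: forward/backward over one or more micro-batches; stats for
    # activations and gradients accumulate across micro-batches, weight stats do not.
    pass

for _ in range(3):
    train_one_step()
    debug_api.step()  # reduces per-rank stats and flushes the logs for this step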
...@@ -443,3 +443,65 @@ for _columnwise in [True, False]: ...@@ -443,3 +443,65 @@ for _columnwise in [True, False]:
add_underflows_stats(_recipe_name, _columnwise) add_underflows_stats(_recipe_name, _columnwise)
add_scale_inv_stats(_recipe_name, _columnwise) add_scale_inv_stats(_recipe_name, _columnwise)
add_mse_stats(_recipe_name, _columnwise) add_mse_stats(_recipe_name, _columnwise)
# NVFP4-specific statistics
def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor:
"""Count the number of non-zero elements in the FP4 data.
FP4 data is stored as two 4-bit values per byte (uint8).
We need to unpack and count non-zeros.
"""
# Each byte contains two FP4 values
# In FP4 E2M1, zero is encoded as code 0 (+0.0) or code 8 (-0.0)
zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8)
# Extract first and second nibbles
first_nibble = fp4_data % 16
second_nibble = fp4_data // 16
# Count zeros
first_zeros = torch.isin(first_nibble, zero_vals).sum()
second_zeros = torch.isin(second_nibble, zero_vals).sum()
total_elements = fp4_data.numel() * 2
return total_elements - first_zeros - second_zeros
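A tiny self-contained check of the unpacking logic above: pack six FP4 codes into three bytes by hand and count the non-zero ones. Codes 0 (+0.0) and 8 (-0.0) are the only zero encodings in E2M1.

import torch

# Low nibble first, high nibble second: the bytes below encode the code pairs
# (0, 3), (8, 0), (2, 1) -> three of the six codes are zero encodings.
packed = torch.tensor([0x30, 0x08, 0x12], dtype=torch.uint8)
zero_vals = torch.tensor([0, 8], dtype=torch.uint8)
low, high = packed % 16, packed // 16
nonzero = packed.numel() * 2 - torch.isin(low, zero_vals).sum() - torch.isin(high, zero_vals).sum()
assert int(nonzero) == 3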
def add_nvfp4_underflows_stats():
"""Register underflow stats for NVFP4.
Computes underflows by counting zeros in packed FP4 data vs original tensor.
"""
stat_num = "nvfp4_underflows_num"
stat_pct = "nvfp4_underflows%"
stats_to_num[stat_num] = len(stats_to_num)
stats_to_num[stat_pct] = len(stats_to_num)
# Count non-zeros in original vs FP4 packed data
STATS[stat_num] = (
lambda x, aux_dict: x.count_nonzero()
- count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data),
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
)
STATS[stat_pct] = (
lambda x, aux_dict: (
x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data)
)
/ aux_dict["nvfp4"].numel()
* 100,
lambda buffers, _sn_num=stat_num: 100
* sum(_get(buffers, _sn_num))
/ sum(_get(buffers, "numel")),
)
DEPENDENCIES[stat_num] = {stat_num}
DEPENDENCIES[stat_pct] = {stat_num, "numel"}
# Register NVFP4 stats
add_nvfp4_underflows_stats()
add_mse_stats("nvfp4") # Reuse existing MSE function
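Illustrative arithmetic for the two underflow stats registered above (the numbers are made up): note that the percentage is taken over the total element count, not only over the original non-zeros.

original_nonzero, quantized_nonzero, numel = 1000, 900, 1024
underflows_num = original_nonzero - quantized_nonzero  # 100 elements clipped to zero
underflows_pct = underflows_num / numel * 100          # about 9.77
assert underflows_num == 100 and round(underflows_pct, 2) == 9.77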
...@@ -36,7 +36,7 @@ _tensor_to_gemm_names_map = { ...@@ -36,7 +36,7 @@ _tensor_to_gemm_names_map = {
} }
API_CALL_MODIFY = "modify_tensor()" API_CALL_MODIFY = "modify_tensor()"
STANDARD_FP8_QUANTIZE = "FP8 Quantize" STANDARD_QUANTIZE = "Quantize"
HIGH_PRECISION = "High Precision" HIGH_PRECISION = "High Precision"
...@@ -88,7 +88,7 @@ class DebugQuantizer(Quantizer): ...@@ -88,7 +88,7 @@ class DebugQuantizer(Quantizer):
# inspect_tensor*_enabled are bool fields, # inspect_tensor*_enabled are bool fields,
# indicating whether some feature will need to run inspect_tensor_* calls. # indicating whether some feature will need to run inspect_tensor_* calls.
# #
# *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION] # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION]
# determining what will happen when the quantizer is used for that tensor. # determining what will happen when the quantizer is used for that tensor.
self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"] self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
if self.output_tensor: if self.output_tensor:
...@@ -170,7 +170,7 @@ class DebugQuantizer(Quantizer): ...@@ -170,7 +170,7 @@ class DebugQuantizer(Quantizer):
def get_tensors_plan(self): def get_tensors_plan(self):
""" """
Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior
of this quantizer with respect to these tensors. of this quantizer with respect to these tensors.
""" """
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
...@@ -191,16 +191,16 @@ class DebugQuantizer(Quantizer): ...@@ -191,16 +191,16 @@ class DebugQuantizer(Quantizer):
rowwise_plan = API_CALL_MODIFY rowwise_plan = API_CALL_MODIFY
else: else:
if self.parent_quantizer is not None: if self.parent_quantizer is not None:
fp8_quantize = self.process_enabled_api_call( quantize_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled( debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility
layer_name=self.layer_name, layer_name=self.layer_name,
gemm=self.rowwise_gemm_name, gemm=self.rowwise_gemm_name,
iteration=self.iteration, iteration=self.iteration,
) )
) )
if fp8_quantize: if quantize_enabled:
rowwise_plan = STANDARD_FP8_QUANTIZE rowwise_plan = STANDARD_QUANTIZE
if rowwise_plan is None: if rowwise_plan is None:
rowwise_plan = HIGH_PRECISION rowwise_plan = HIGH_PRECISION
...@@ -218,16 +218,16 @@ class DebugQuantizer(Quantizer): ...@@ -218,16 +218,16 @@ class DebugQuantizer(Quantizer):
columnwise_plan = API_CALL_MODIFY columnwise_plan = API_CALL_MODIFY
else: else:
if self.parent_quantizer is not None: if self.parent_quantizer is not None:
fp8_quantize = self.process_enabled_api_call( quantize_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled( debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility
layer_name=self.layer_name, layer_name=self.layer_name,
gemm=self.columnwise_gemm_name, gemm=self.columnwise_gemm_name,
iteration=self.iteration, iteration=self.iteration,
) )
) )
if fp8_quantize: if quantize_enabled:
columnwise_plan = STANDARD_FP8_QUANTIZE columnwise_plan = STANDARD_QUANTIZE
if columnwise_plan is None: if columnwise_plan is None:
columnwise_plan = HIGH_PRECISION columnwise_plan = HIGH_PRECISION
...@@ -278,7 +278,7 @@ class DebugQuantizer(Quantizer): ...@@ -278,7 +278,7 @@ class DebugQuantizer(Quantizer):
del args["quantizer"] del args["quantizer"]
if ( if (
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_rowwise and self.inspect_tensor_postquantize_enabled_rowwise
): ):
args["tensor"] = rowwise_gemm_tensor args["tensor"] = rowwise_gemm_tensor
...@@ -286,7 +286,7 @@ class DebugQuantizer(Quantizer): ...@@ -286,7 +286,7 @@ class DebugQuantizer(Quantizer):
debug_api.transformer_engine.inspect_tensor_postquantize(**args) debug_api.transformer_engine.inspect_tensor_postquantize(**args)
if ( if (
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_columnwise and self.inspect_tensor_postquantize_enabled_columnwise
): ):
args["tensor"] = columnwise_gemm_tensor args["tensor"] = columnwise_gemm_tensor
...@@ -317,14 +317,14 @@ class DebugQuantizer(Quantizer): ...@@ -317,14 +317,14 @@ class DebugQuantizer(Quantizer):
self.parent_quantizer.set_usage(rowwise=True) self.parent_quantizer.set_usage(rowwise=True)
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
quantized_tensor = self.parent_quantizer(tensor) quantized_tensor = self.parent_quantizer(tensor)
# if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized,
# one tensor with columnwise=True and rowwise=True is computed # one tensor with columnwise=True and rowwise=True is computed
# and both rowwise_tensor_plan and columnwise_tensor_plan point to it. # and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE: if self.rowwise_tensor_plan == STANDARD_QUANTIZE:
rowwise_gemm_tensor = quantized_tensor rowwise_gemm_tensor = quantized_tensor
if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE: if self.columnwise_tensor_plan == STANDARD_QUANTIZE:
columnwise_gemm_tensor = quantized_tensor columnwise_gemm_tensor = quantized_tensor
# 2. modify_tensor() is called, if it is used. # 2. modify_tensor() is called, if it is used.
...@@ -379,7 +379,7 @@ class DebugQuantizer(Quantizer): ...@@ -379,7 +379,7 @@ class DebugQuantizer(Quantizer):
"""This call is invoked after the gemm to inspect and modify the output tensor.""" """This call is invoked after the gemm to inspect and modify the output tensor."""
import nvdlfw_inspect.api as debug_api import nvdlfw_inspect.api as debug_api
assert self.parent_quantizer is None, "FP8 output is not supported for debug=True." assert self.parent_quantizer is None, "Quantized output is not supported for debug=True."
assert self.output_tensor assert self.output_tensor
tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"} tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
if self.rowwise_tensor_plan == API_CALL_MODIFY: if self.rowwise_tensor_plan == API_CALL_MODIFY:
...@@ -420,9 +420,9 @@ class DebugQuantizer(Quantizer): ...@@ -420,9 +420,9 @@ class DebugQuantizer(Quantizer):
): ):
return True return True
if self.parent_quantizer is not None: if self.parent_quantizer is not None:
if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: if self.rowwise_tensor_plan != STANDARD_QUANTIZE:
return True return True
if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: if self.columnwise_tensor_plan != STANDARD_QUANTIZE:
return True return True
return False return False
...@@ -446,7 +446,7 @@ class DebugQuantizer(Quantizer): ...@@ -446,7 +446,7 @@ class DebugQuantizer(Quantizer):
if self.parent_quantizer is not None: if self.parent_quantizer is not None:
if ( if (
dst.rowwise_gemm_tensor is not None dst.rowwise_gemm_tensor is not None
and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE and self.rowwise_tensor_plan == STANDARD_QUANTIZE
): ):
if hasattr(dst.rowwise_gemm_tensor, "quantize_"): if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None) dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
...@@ -455,7 +455,7 @@ class DebugQuantizer(Quantizer): ...@@ -455,7 +455,7 @@ class DebugQuantizer(Quantizer):
updated_rowwise_gemm = True updated_rowwise_gemm = True
if ( if (
dst.columnwise_gemm_tensor is not None dst.columnwise_gemm_tensor is not None
and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE and self.columnwise_tensor_plan == STANDARD_QUANTIZE
and not updated_rowwise_gemm and not updated_rowwise_gemm
): ):
if hasattr(dst.columnwise_gemm_tensor, "quantize_"): if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
...@@ -540,14 +540,12 @@ class DebugQuantizer(Quantizer): ...@@ -540,14 +540,12 @@ class DebugQuantizer(Quantizer):
""" """
Updates the usage of the parent quantizer. Updates the usage of the parent quantizer.
""" """
rowwise_gemm_quantize = ( rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
columnwise_gemm_quantize = ( columnwise_gemm_quantize = (
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE
) )
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage( self.parent_quantizer.set_usage(
rowwise=rowwise_gemm_quantize, rowwise=rowwise_gemm_quantize,
columnwise=columnwise_gemm_quantize, columnwise=columnwise_gemm_quantize,
......
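For readers following the rename, a simplified sketch of how DebugQuantizer resolves the per-tensor plan after these changes; the helper is illustrative only, but the constants and the precedence (modify_tensor() first, then the recipe-agnostic fp8_gemm_enabled() answer) match the code above.

API_CALL_MODIFY = "modify_tensor()"
STANDARD_QUANTIZE = "Quantize"
HIGH_PRECISION = "High Precision"

def resolve_plan(modify_enabled: bool, has_parent_quantizer: bool, gemm_quantize_enabled: bool) -> str:
    if modify_enabled:
        return API_CALL_MODIFY
    if has_parent_quantizer and gemm_quantize_enabled:
        return STANDARD_QUANTIZE
    return HIGH_PRECISION

assert resolve_plan(False, True, True) == STANDARD_QUANTIZE
assert resolve_plan(False, True, False) == HIGH_PRECISION  # e.g. DisableQuantizationGEMM active
assert resolve_plan(True, True, True) == API_CALL_MODIFY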