OpenDAS / TransformerEngine · Commits

Commit 9df0c4a3
Authored Mar 03, 2026 by wenjh

    Merge branch 'nv_main'

Parents: 0d874a4e, f122b07d
Changes: 221 · Showing 20 changed files with 1165 additions and 143 deletions (+1165 / -143)
transformer_engine/common/include/transformer_engine/comm_gemm.h                +2    -0
transformer_engine/common/include/transformer_engine/fused_attn.h              +38   -28
transformer_engine/common/include/transformer_engine/gemm.h                   +171    -0
transformer_engine/common/include/transformer_engine/hadamard_transform.h      +34    -0
transformer_engine/common/include/transformer_engine/multi_tensor.h            +11    -0
transformer_engine/common/recipe/__init__.py                                   +21   -14
transformer_engine/common/transformer_engine.cpp                                +2    -2
transformer_engine/common/triton/permutation.py                                +12    -0
transformer_engine/common/util/cuda_runtime.cpp                                 +8    -0
transformer_engine/common/util/cuda_runtime.h                                   +6    -0
transformer_engine/common/util/logging.h                                        +6    -6
transformer_engine/common/util/ptx.cuh                                        +314    -3
transformer_engine/debug/features/disable_fp8_gemm.py                          +25   -19
transformer_engine/debug/features/disable_fp8_layer.py                         +28   -35
transformer_engine/debug/features/disable_quantization_gemm.py                 +59    -0
transformer_engine/debug/features/disable_quantization_layer.py                +61    -0
transformer_engine/debug/features/log_fp8_tensor_stats.py                      +43    -9
transformer_engine/debug/features/log_nvfp4_tensor_stats.py                   +237    -0
transformer_engine/debug/features/utils/stats_computation.py                   +62    -0
transformer_engine/debug/pytorch/debug_quantization.py                         +25   -27
transformer_engine/common/include/transformer_engine/comm_gemm.h

@@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank

 /*! \brief Destroy a comm-gemm context.
  *
  *  \param[in]  ctx  Context to destroy.
+ *
+ *  It's the caller's responsibility to synchronize all streams involved before calling this function.
  */
 void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);
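Illustration (not part of the patch): a minimal sketch of the ordering the new note asks for, synchronizing every stream that was used with the context before destroying it. The helper name, the `streams` container, and the include path are placeholders, not part of the header.

#include <cuda_runtime.h>
#include <vector>
#include <transformer_engine/comm_gemm.h>  // include path assumed

// Sketch only: `streams` is assumed to hold every CUDA stream passed to
// comm-gemm calls that used `ctx`.
void shutdown_comm_gemm(NVTECommGemmCtx* ctx, const std::vector<cudaStream_t>& streams) {
  for (cudaStream_t s : streams) {
    cudaStreamSynchronize(s);  // caller-side synchronization required by the API note
  }
  nvte_comm_gemm_ctx_destroy(ctx);
}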
transformer_engine/common/include/transformer_engine/fused_attn.h

@@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  *  \param[in]     window_size_right     Sliding window size (the right half).
  *  \param[in]     return_max_logit      Whether to produce Max and Sum_Exp, or Stats.
  *  \param[in]     cuda_graph            Whether cuda graph capture is enabled or not.
+ *  \param[in]     deterministic         Whether determinism is required or not.
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);

 /*! \brief Compute dot product attention with packed QKV input.
  *

@@ -269,22 +270,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *  \param[in]     softmax_type             Attention softmax type.
  *  \param[in]     window_size_left         Sliding window size (the left half).
  *  \param[in]     window_size_right        Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal    Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     workspace                Workspace tensor.
  *  \param[in]     stream                   CUDA stream used for this operation.
  */
 [[deprecated(
     "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
     "Q, K, V tensors instead.")]]
 void nvte_fused_attn_fwd_qkvpacked(
     const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
     NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
     const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
     bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
     NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    NVTETensor workspace, cudaStream_t stream);
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed QKV input.
  *

@@ -332,6 +332,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  *  \param[in]     softmax_type             Attention softmax type.
  *  \param[in]     window_size_left         Sliding window size (the left half).
  *  \param[in]     window_size_right        Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal    Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     deterministic            Whether to execute with deterministic behaviours.
  *  \param[in]     cuda_graph               Whether cuda graph capture is enabled or not.
  *  \param[in]     workspace                Workspace tensor.

@@ -346,8 +347,8 @@ void nvte_fused_attn_bwd_qkvpacked(
     NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
     size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
-    NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute dot product attention with packed KV input.
  *

@@ -409,6 +410,7 @@ void nvte_fused_attn_bwd_qkvpacked(
  *  \param[in]     softmax_type             Attention softmax type.
  *  \param[in]     window_size_left         Sliding window size (the left half).
  *  \param[in]     window_size_right        Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal    Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     workspace                Workspace tensor.
  *  \param[in]     stream                   CUDA stream used for this operation.
  */

@@ -424,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked(
     size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
+    cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed KV input.
  *

@@ -478,6 +481,7 @@ void nvte_fused_attn_fwd_kvpacked(
  *  \param[in]     softmax_type             Attention softmax type.
  *  \param[in]     window_size_left         Sliding window size (the left half).
  *  \param[in]     window_size_right        Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal    Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     deterministic            Whether to execute with deterministic behaviours.
  *  \param[in]     cuda_graph               Whether cuda graph capture is enabled or not.
  *  \param[in]     workspace                Workspace tensor.

@@ -494,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
-    cudaStream_t stream);
+    int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
+    NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute dot product attention with separate Q, K and V.
  *

@@ -559,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked(
  *  \param[in]     softmax_type             Attention softmax type.
  *  \param[in]     window_size_left         Sliding window size (the left half).
  *  \param[in]     window_size_right        Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal    Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     workspace                Workspace tensor.
  *  \param[in]     stream                   CUDA stream used for this operation.
  */
 void nvte_fused_attn_fwd(
     const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
     const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
     const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
     const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
     const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
     bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
-    cudaStream_t stream);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
  *

@@ -628,6 +636,7 @@ void nvte_fused_attn_fwd(
  *  \param[in]     softmax_type             Attention softmax type.
  *  \param[in]     window_size_left         Sliding window size (the left half).
  *  \param[in]     window_size_right        Sliding window size (the right half).
+ *  \param[in]     bottom_right_diagonal    Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]     deterministic            Whether to execute with deterministic behaviours.
  *  \param[in]     cuda_graph               Whether cuda graph capture is enabled or not.
  *  \param[in]     workspace                Workspace tensor.

@@ -643,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic,
-    bool cuda_graph, NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Update the RNG state with the seed and calculated offset.
  *
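Illustration (not part of the header): a sketch of what the new bottom_right_diagonal flag changes when seqlen_q differs from seqlen_kv. The predicate below, the treatment of negative window sizes as unlimited, and the function name are assumptions made for the example; only the flag's description comes from the documentation above.

#include <cstdint>

// Which key column k a query row q may attend to, given a sliding window around
// the diagonal. Top-left alignment puts the diagonal at k == q; bottom-right
// alignment shifts it so the last query row lines up with the last key column.
bool attention_allowed(int64_t q, int64_t k, int64_t seqlen_q, int64_t seqlen_kv,
                       int64_t window_left, int64_t window_right, bool bottom_right_diagonal) {
  const int64_t diag = bottom_right_diagonal ? q + (seqlen_kv - seqlen_q) : q;
  const bool left_ok = (window_left < 0) || (k >= diag - window_left);    // negative = unlimited (assumed)
  const bool right_ok = (window_right < 0) || (k <= diag + window_right);
  return left_ok && right_ok;
}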
transformer_engine/common/include/transformer_engine/gemm.h

@@ -11,6 +11,8 @@
 #ifndef TRANSFORMER_ENGINE_GEMM_H_
 #define TRANSFORMER_ENGINE_GEMM_H_

+#include <stdint.h>
+
 #include "transformer_engine.h"

 #ifdef __cplusplus

@@ -20,6 +22,9 @@ extern "C" {
 /*! \brief Configuration for matrix multiplication. */
 typedef void *NVTEMatmulConfig;

+/*! \brief Configuration for grouped matrix multiplication. */
+typedef void *NVTEGroupedMatmulConfig;
+
 /*! \enum NVTEMatmulConfigAttribute
  *  \brief Type of option for matrix multiplication.
  */

@@ -52,6 +57,36 @@ enum NVTEMatmulConfigAttribute {
   kNVTEMatmulConfigNumAttributes
 };

+/*! \enum NVTEGroupedMatmulConfigAttribute
+ *  \brief Type of option for grouped matrix multiplication.
+ */
+enum NVTEGroupedMatmulConfigAttribute {
+  /*! Average M dimension hint
+   *
+   * Optional hint for average M dimension across all matrices in the group.
+   * Used by cuBLASLt for algorithm selection heuristics. If not set,
+   * computed automatically from D's logical shape.
+   */
+  kNVTEGroupedMatmulConfigAvgM = 0,
+  /*! Average N dimension hint
+   *
+   * Optional hint for average N dimension across all matrices in the group.
+   * Used by cuBLASLt for algorithm selection heuristics. If not set,
+   * computed automatically from D's logical shape.
+   */
+  kNVTEGroupedMatmulConfigAvgN = 1,
+  /*! Average K (reduction) dimension hint
+   *
+   * Optional hint for average K dimension across all matrices in the group.
+   * Used by cuBLASLt for algorithm selection heuristics. If not set,
+   * computed automatically from A's logical shape.
+   */
+  kNVTEGroupedMatmulConfigAvgK = 2,
+  /*! Number of streaming multiprocessors to use in GEMM kernel. */
+  kNVTEGroupedMatmulConfigSMCount = 3,
+  kNVTEGroupedMatmulConfigNumAttributes
+};
+
 /*! \brief Create a matrix multiplication configuration. */
 NVTEMatmulConfig nvte_create_matmul_config();

@@ -82,6 +117,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
 /*! \brief Destroy a matrix multiplication configuration. */
 void nvte_destroy_matmul_config(NVTEMatmulConfig config);

+/*! \brief Create a grouped matrix multiplication configuration. */
+NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config();
+
+/*! \brief Query an option in grouped matrix multiplication configuration.
+ *
+ *  \param[in]     config            Grouped matrix multiplication configuration.
+ *  \param[in]     attr              Option type.
+ *  \param[out]    buf               Memory address to write option value. Ignored if NULL.
+ *  \param[in]     size_in_bytes     Size of buf.
+ *  \param[out]    size_written      Number of bytes that have been written to buf. If buf is
+ *                                   NULL, then the number of bytes that would have been written.
+ */
+void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
+                                              NVTEGroupedMatmulConfigAttribute attr, void *buf,
+                                              size_t size_in_bytes, size_t *size_written);
+
+/*! \brief Set an option in grouped matrix multiplication configuration.
+ *
+ *  \param[in]     config            Grouped matrix multiplication configuration.
+ *  \param[in]     attr              Option type.
+ *  \param[out]    buf               Memory address to read option value.
+ *  \param[in]     size_in_bytes     Size of buf.
+ */
+void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
+                                              NVTEGroupedMatmulConfigAttribute attr,
+                                              const void *buf, size_t size_in_bytes);
+
+/*! \brief Destroy a grouped matrix multiplication configuration. */
+void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config);
+
 /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
  *
  * This has been deprecated in favor of nvte_cublas_gemm_v2.

@@ -229,6 +296,46 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
                             bool accumulate, bool use_split_accumulator, int math_sm_count,
                             cudaStream_t stream);

+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C
+ *
+ *  \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture.
+ *        Will error at runtime if compiled with an older cuBLAS version or run on
+ *        a pre-Blackwell GPU.
+ *
+ *  Performs batched GEMM on a collection of matrices with potentially different shapes.
+ *  All tensors in the group must have compatible dimensions for matrix multiplication.
+ *  Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous
+ *  memory layout and shape metadata.
+ *
+ *  \param[in]     A                  Input grouped tensor A.
+ *  \param[in]     transa             Whether to transpose A matrices.
+ *  \param[in]     B                  Input grouped tensor B.
+ *  \param[in]     transb             Whether to transpose B matrices.
+ *  \param[in]     C                  Input grouped tensor C (can be NULL for beta=0).
+ *  \param[out]    D                  Output grouped tensor D.
+ *  \param[in]     alpha              Scale multipliers for A @ B (NVTETensor with num_tensors elements).
+ *  \param[in]     beta               Scale multipliers for C (NVTETensor with num_tensors elements).
+ *  \param[in]     workspace_setup    Workspace tensor for pointer array setup.
+ *  \param[in]     workspace_cublas   Workspace tensor for cuBLAS operations.
+ *  \param[in]     config             Additional configuration (can be NULL for defaults).
+ *  \param[in]     stream             CUDA stream for the operation.
+ *
+ *  Requirements:
+ *  - cuBLAS 13.2+ (CUDA 13.1+)
+ *  - Blackwell (SM100) or newer GPU architecture
+ *  - A, B, C (if provided), D must have the same num_tensors
+ *  - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
+ *  - Shape compatibility: if transa=false, transb=false:
+ *    - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
+ */
+void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B,
+                       int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D,
+                       const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup,
+                       NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
+                       cudaStream_t stream);
+
 #ifdef __HIP_PLATFORM_AMD__
 void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                         const NVTETensor *bias, NVTETensor *pre_gelu_out,

@@ -356,6 +463,70 @@ class MatmulConfigWrapper {
   NVTEMatmulConfig config_ = nullptr;
 };

+/*! \struct GroupedMatmulConfigWrapper
+ * \brief C++ wrapper for NVTEGroupedMatmulConfig.
+ */
+class GroupedMatmulConfigWrapper {
+ public:
+  GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {}
+  GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete;
+  GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete;
+  GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} {
+    other.config_ = nullptr;
+  }
+  GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) {
+    if (config_ != nullptr) {
+      nvte_destroy_grouped_matmul_config(config_);
+    }
+    config_ = other.config_;
+    other.config_ = nullptr;
+    return *this;
+  }
+  ~GroupedMatmulConfigWrapper() {
+    if (config_ != nullptr) {
+      nvte_destroy_grouped_matmul_config(config_);
+      config_ = nullptr;
+    }
+  }
+
+  /*! \brief Get the underlying NVTEGroupedMatmulConfig.
+   *
+   * \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper.
+   */
+  operator NVTEGroupedMatmulConfig() const noexcept { return config_; }
+
+  /*! \brief Set average M dimension hint for algorithm selection. */
+  void set_avg_m(int64_t avg_m) {
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m,
+                                             sizeof(int64_t));
+  }
+  /*! \brief Set average N dimension hint for algorithm selection. */
+  void set_avg_n(int64_t avg_n) {
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n,
+                                             sizeof(int64_t));
+  }
+  /*! \brief Set average K dimension hint for algorithm selection. */
+  void set_avg_k(int64_t avg_k) {
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k,
+                                             sizeof(int64_t));
+  }
+  /*! \brief Set number of streaming multiprocessors to use. */
+  void set_sm_count(int sm_count) {
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count,
+                                             sizeof(int));
+  }
+
+ private:
+  /*! \brief Wrapped NVTEGroupedMatmulConfig. */
+  NVTEGroupedMatmulConfig config_ = nullptr;
+};
+
 }  // namespace transformer_engine
 #endif  // __cplusplus
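Illustration (not part of the patch): a sketch of how the new grouped-GEMM configuration and dispatch fit together. The grouped tensors, scale tensors, and workspaces are assumed to have been built elsewhere; the include path, function name, and hint values are placeholders. Only GroupedMatmulConfigWrapper, its setters, and nvte_grouped_gemm come from the declarations above.

#include <transformer_engine/gemm.h>  // include path assumed

using transformer_engine::GroupedMatmulConfigWrapper;

// Sketch only: A, B, D, alpha, beta, and the two workspaces are prepared by the caller.
void run_grouped_gemm(const NVTEGroupedTensor A, const NVTEGroupedTensor B, NVTEGroupedTensor D,
                      const NVTETensor alpha, const NVTETensor beta, NVTETensor ws_setup,
                      NVTETensor ws_cublas, cudaStream_t stream) {
  GroupedMatmulConfigWrapper config;
  config.set_avg_m(1024);  // optional heuristics hints; if omitted they are
  config.set_avg_n(4096);  // derived from the tensors' logical shapes
  config.set_avg_k(4096);

  // C = NULL corresponds to beta = 0 per the documentation above.
  nvte_grouped_gemm(A, /*transa=*/0, B, /*transb=*/0, /*C=*/nullptr, D, alpha, beta,
                    ws_setup, ws_cublas, config, stream);
}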
transformer_engine/common/include/transformer_engine/hadamard_transform.h

@@ -86,6 +86,24 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp
                                         int random_sign_mask, int random_sign_mask_t,
                                         cudaStream_t stream);

+/*! \brief Grouped-tensor amax with Hadamard transform (graph safe, device-managed grouping).
+ *
+ *  This function is experimental and the API is not stable.
+ *
+ *  This API assumes that the split info (grouping of tensors) is on device and unknown to the host;
+ *  therefore, this is a graph safe API and the grouped-tensor argument is passed as a single device structure.
+ *
+ *  \param[in]      input                 NVTEGroupedTensor representing grouped input tensors.
+ *  \param[in,out]  output                NVTEGroupedTensor for output amax (row/col). Only the row-wise and
+ *                                        column-wise amaxes are updated.
+ *  \param[in]      random_sign_mask      16-bit sign mask for RHT.
+ *  \param[in]      random_sign_mask_t    16-bit sign mask for transposed RHT.
+ *  \param[in]      stream                CUDA stream used for the operation.
+ */
+void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input,
+                                                   NVTEGroupedTensor output, int random_sign_mask,
+                                                   int random_sign_mask_t, cudaStream_t stream);
+
 /*!
  * \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
  *

@@ -124,6 +142,22 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
                                                const NVTEQuantizationConfig quant_config,
                                                NVTETensor quant_workspace, cudaStream_t stream);

+/*!
+ * \brief Perform the grouped-tensor Hadamard transform cast fusion operation in graph-safe mode.
+ *
+ * This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated.
+ *
+ *  \param[in]      input              NVTEGroupedTensor representing grouped input tensors.
+ *  \param[in,out]  output             NVTEGroupedTensor for output (row/column-wise quantized results).
+ *  \param[in]      hadamard_matrix    Hadamard matrix to use for transformation.
+ *  \param[in]      quant_config       Quantization configuration.
+ *  \param[in]      quant_workspace    Workspace buffer. Must be at least 4 bytes.
+ *  \param[in]      stream             CUDA stream used for the operation.
+ */
+void nvte_group_hadamard_transform_cast_fusion_graph_safe(
+    const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix,
+    const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
transformer_engine/common/include/transformer_engine/multi_tensor.h

@@ -296,6 +296,17 @@ void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **
 void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections,
                      size_t num_tensors, cudaStream_t stream);

+/*! \brief Grouped-tensor amax without doing hadamard transform.
+ *
+ *  This function is experimental and the API is not stable.
+ *
+ *  \param[in]      input     NVTEGroupedTensor Input tensor.
+ *  \param[in,out]  output    NVTEGroupedTensor Output tensor.
+ *  \param[in]      stream    CUDA stream used for the operation.
+ */
+void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output,
+                                cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
transformer_engine/common/recipe/__init__.py

@@ -88,33 +88,40 @@ class Recipe:
     Base recipe class.
     """

-    def nvfp4(self):
+    @classmethod
+    def nvfp4(cls):
         """Whether the given recipe is NVFP4 1D block scaling."""
-        return isinstance(self, NVFP4BlockScaling)
+        return issubclass(cls, NVFP4BlockScaling)

-    def mxfp8(self):
+    @classmethod
+    def mxfp8(cls):
         """Whether the given recipe is MXFP8 block scaling."""
-        return isinstance(self, MXFP8BlockScaling)
+        return issubclass(cls, MXFP8BlockScaling)

-    def delayed(self):
+    @classmethod
+    def delayed(cls):
         """Whether the given recipe is delayed scaling."""
-        return isinstance(self, DelayedScaling)
+        return issubclass(cls, DelayedScaling)

-    def float8_current_scaling(self):
+    @classmethod
+    def float8_current_scaling(cls):
         """Whether the given recipe is (per-tensor) current scaling."""
-        return isinstance(self, Float8CurrentScaling)
+        return issubclass(cls, Float8CurrentScaling)

-    def float8_per_tensor_scaling(self):
+    @classmethod
+    def float8_per_tensor_scaling(cls):
         """Whether the given recipe is per-tensor scaling."""
-        return isinstance(self, (DelayedScaling, Float8CurrentScaling))
+        return issubclass(cls, (DelayedScaling, Float8CurrentScaling))

-    def float8_block_scaling(self):
+    @classmethod
+    def float8_block_scaling(cls):
         """Whether the given recipe is float8 blockwise scaling."""
-        return isinstance(self, Float8BlockScaling)
+        return issubclass(cls, Float8BlockScaling)

-    def custom(self):
+    @classmethod
+    def custom(cls):
         """Whether the given recipe is custom."""
-        return isinstance(self, CustomRecipe)
+        return issubclass(cls, CustomRecipe)


 @dataclass()
transformer_engine/common/transformer_engine.cpp

@@ -458,9 +458,9 @@ class TensorAllocator {
   }

   void Free(NVTETensor t) {
-    std::lock_guard<std::mutex> lock(mutex);
     uintptr_t index = reinterpret_cast<uintptr_t>(t);
     if (index == 0) return;
+    std::lock_guard<std::mutex> lock(mutex);
     NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
     free_list.push_back(index);

     // Clean up

@@ -568,9 +568,9 @@ class GroupedTensorAllocator {
   }

   void Free(NVTEGroupedTensor t) {
-    std::lock_guard<std::mutex> lock(mutex);
     uintptr_t index = reinterpret_cast<uintptr_t>(t);
     if (index == 0) return;
+    std::lock_guard<std::mutex> lock(mutex);
     NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
     free_list.push_back(index);

     // Clean up
transformer_engine/common/triton/permutation.py

@@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel(
         split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
     ).to(tl.int32)
     input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
+    # Compute total valid tokens and skip phantom/padding tokens.
+    # When the input buffer is larger than sum(split_sizes), tokens beyond
+    # the valid range should map to themselves (identity mapping) to avoid
+    # corrupting valid output positions.
+    total_valid_tokens = tl.sum(input_split_sizes)
     input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
     input_chunk_idx = tl.sum(input_split_sizes_mask)
     input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)

@@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel(
     ).to(tl.int32)
     output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
     dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
+    # For tokens beyond the valid range (pid >= total_valid_tokens),
+    # use identity mapping to avoid corrupting valid data
+    dst_row = tl.where(pid < total_valid_tokens, dst_row, pid)
     tl.store(dst_rows_ptr + pid, dst_row)
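Illustration (not part of the patch): a tiny host-side sketch of the guard this kernel change introduces. The destination computed by the chunk-sort logic is abstracted into one argument; the function name and types are placeholders.

#include <cstdint>

// dst_row_for_valid_token stands in for the destination computed by the
// chunk-sort logic; the guard below mirrors the behaviour added in this commit.
int64_t map_token(int64_t pid, int64_t total_valid_tokens, int64_t dst_row_for_valid_token) {
  // Padding ("phantom") tokens past sum(split_sizes) map to themselves so they
  // never land on a destination row owned by a real token.
  return (pid < total_valid_tokens) ? dst_row_for_valid_token : pid;
}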
transformer_engine/common/util/cuda_runtime.cpp

@@ -6,6 +6,8 @@
 #include "../util/cuda_runtime.h"

+#include <cublasLt.h>
+
 #include <filesystem>
 #include <mutex>

@@ -232,6 +234,12 @@ int cudart_version() {
   return version;
 }

+size_t cublas_version() {
+  // Cache version to avoid cuBLAS logging overhead
+  static size_t version = cublasLtGetVersion();
+  return version;
+}
+
 }  // namespace cuda

 }  // namespace transformer_engine
transformer_engine/common/util/cuda_runtime.h

@@ -85,6 +85,12 @@ const std::string &include_directory(bool required = false);
 */
 int cudart_version();

+/* \brief cuBLAS version number at run-time
+ *
+ * Versions may differ between compile-time and run-time.
+ */
+size_t cublas_version();
+
 }  // namespace cuda

 }  // namespace transformer_engine
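Illustration (not part of the patch): the kind of check the new helper enables, comparing the run-time cuBLAS against the version the binary was compiled with. The include path, function name, and the assumption that cublasLtGetVersion() and the CUBLAS_VERSION macro use a comparable encoding are all assumptions for the sketch.

#include <cstdio>
#include <cublasLt.h>            // pulls in CUBLAS_VERSION at compile time
#include "util/cuda_runtime.h"   // the TE header above; path assumed

// Sketch only: warn when the loaded cuBLAS differs from the compile-time one.
// cublas_version() caches cublasLtGetVersion() as shown in cuda_runtime.cpp.
void warn_on_cublas_mismatch() {
  const size_t runtime_version = transformer_engine::cuda::cublas_version();
  if (runtime_version != static_cast<size_t>(CUBLAS_VERSION)) {
    std::printf("cuBLAS compile-time version %d != run-time version %zu\n",
                CUBLAS_VERSION, runtime_version);
  }
}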
transformer_engine/common/util/logging.h

@@ -141,12 +141,12 @@
 #ifdef NVTE_WITH_CUBLASMP

 #define NVTE_CHECK_CUBLASMP(expr)                                      \
   do {                                                                 \
     const cublasMpStatus_t status = (expr);                            \
     if (status != CUBLASMP_STATUS_SUCCESS) {                           \
-      NVTE_ERROR("cuBLASMp Error: ", std::to_string(status));          \
+      NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status)); \
     }                                                                  \
   } while (false)

 #endif  // NVTE_WITH_CUBLASMP
transformer_engine/common/util/ptx.cuh

@@ -172,6 +172,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }

+__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(
+    uint64_t *mbar, const uint32_t tx_count) {
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
+  asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr),
+               "r"(tx_count));
+#else
+  NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+.");
+#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+}
+
 __device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
 #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   asm volatile("fence.mbarrier_init.release.cluster;");

@@ -251,13 +263,86 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }

+__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar,
+                                                                            uint32_t phase_parity) {
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
+  asm volatile(
+      "{\n\t"
+      ".reg .b64 r1;\n\t"
+      ".reg .pred waitComplete;\n\t"  // predicate representing if barrier condition is met
+      "WAIT:\n\t"                     // loop around barrier wait
+      "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1;\n\t"
+      "@waitComplete bra DONE;\n\t"   // mbarrier conditions are met
+      "bra WAIT;\n\t"                 // just a time-out, try again
+      "DONE:\n\t"
+      "}\n\t"
+      :
+      : "r"(mbar_ptr), "r"(phase_parity)
+      : "memory");
+#else
+  NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+.");
+#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+}
+
+__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) {
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
+  if constexpr (is_blackwell) {
+    uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
+    uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
+    asm volatile(
+        "clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::"
+        "all.b128 "
+        "[%0], [%1];" ::"r"(workID_response),
+        "r"(mbar_ptr));
+  } else {
+    NVTE_DEVICE_ERROR(
+        "Cluster Launch Control PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+}
+
+__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr,
+                                                        int32_t &ctaid_X, int32_t &ctaid_Y) {
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
+  if constexpr (is_blackwell) {
+    uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
+    asm volatile(
+        "{\n\t"
+        ".reg .s32 x_ctaid;\n\t"
+        ".reg .s32 y_ctaid;\n\t"
+        "mov .s32 x_ctaid, -1;\n\t"
+        "mov .s32 y_ctaid, -1;\n\t"
+        ".reg.b128 try_cancel_response;\n\t"
+        "ld.shared.b128 try_cancel_response, [%2];\n\t"
+        ".reg .pred P1;\n\t"
+        "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response;\n\t"
+        "@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, "
+        "_}, try_cancel_response;\n\t"
+        "mov .s32 %0, x_ctaid;\n\t"
+        "mov .s32 %1, y_ctaid;\n\t"
+        "}\n\t"
+        : "=r"(ctaid_X), "=r"(ctaid_Y)
+        : "r"(workID_response)
+        : "memory");
+  } else {
+    NVTE_DEVICE_ERROR(
+        "Cluster Launch Control PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+}
+
 constexpr uint32_t FP32_MANTISSA_BITS = 23;
 constexpr uint32_t FP32_EXPONENT_BIAS = 127;

 __device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
-  return (biased_exp == 0) ? 1
-                           : __int_as_float((254 - biased_exp)
-                                            << FP32_MANTISSA_BITS);  // 127 - (biased_exp - 127)
+  // Handle the special case of NaN.
+  if (biased_exp == 255) return __int_as_float(0x7fffffff);
+  // Handle the special case where the unbiased exponent is 127, so the reciprocal is 2^-127 which needs the first bit of
+  // the mantissa to be 1, which can't be obtained by shifting `FP32_MANTISSA_BITS` bits to the left.
+  if (biased_exp == 254) return __int_as_float(0x00400000);
+  // Fast calculation when the unbiased exp is in [-126, 126], and only the exponent part is used to express the reciprocal.
+  return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);
 }
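Illustration (not part of the patch): a host-side check of the two new special cases and the fast path of exp2f_rcp. The bit patterns mirror the device code above (0x7fffffff is a NaN pattern, 0x00400000 is the subnormal 2^-127); the standalone re-implementation and names are for the sketch only.

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Host-side mirror of the device helper: reciprocal of the e8m0 scale 2^(biased_exp - 127).
float exp2f_rcp_host(uint8_t biased_exp) {
  auto bits_to_float = [](uint32_t bits) { float f; std::memcpy(&f, &bits, sizeof(f)); return f; };
  if (biased_exp == 255) return bits_to_float(0x7fffffffu);  // e8m0 NaN encoding
  if (biased_exp == 254) return bits_to_float(0x00400000u);  // 2^-127 is subnormal in FP32
  return bits_to_float(static_cast<uint32_t>(254 - biased_exp) << 23);
}

int main() {
  assert(std::isnan(exp2f_rcp_host(255)));
  assert(exp2f_rcp_host(254) == std::ldexp(1.0f, -127));  // reciprocal of 2^127
  assert(exp2f_rcp_host(127) == 1.0f);                    // reciprocal of 2^0
  assert(exp2f_rcp_host(1) == std::ldexp(1.0f, 126));     // reciprocal of 2^-126
  return 0;
}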
 __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {

@@ -671,6 +756,179 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
     return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
   }
 }

+template <typename SCALING_COEFFICIENT_TYPE>
+__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest(
+    const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) {
+  uint32_t out_8x = 0;
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
+  if constexpr (is_blackwell) {
+    if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
+      asm volatile(
+          "{\n"
+          ".reg.f32 zero;\n\t"
+          "mov.b32 zero, 0;\n\t"
+          ".reg.b16 scaling_coeff;\n\t"
+          "mov.b16 scaling_coeff, %3;\n\t"
+          ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h;\n\t"
+          "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1;\n\t"
+          "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2;\n\t"
+          ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7;\n\t"
+          "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero;\n\t"
+          ".reg.b8 f0, f1, f2, f3;\n\t"
+          // Elements reordered to match e2m1x4 packing order (v1,v0)
+          "cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t"
+          "mov.b32 %0, {f0, f1, f2, f3};\n"
+          "}"
+          : "=r"(out_8x)
+          : "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)));
+    } else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
+      asm volatile(
+          "{\n"
+          ".reg.b64 scaling_coeff_2x;\n\t"
+          "mov.b64 scaling_coeff_2x, {%3, %3};\n\t"
+          ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16;\n\t"
+          "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1;\n\t"
+          "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2;\n\t"
+          ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7;\n\t"
+          "cvt.f32.bf16 v0, v0_bf16;\n\t"
+          "cvt.f32.bf16 v1, v1_bf16;\n\t"
+          "cvt.f32.bf16 v2, v2_bf16;\n\t"
+          "cvt.f32.bf16 v3, v3_bf16;\n\t"
+          "cvt.f32.bf16 v4, v4_bf16;\n\t"
+          "cvt.f32.bf16 v5, v5_bf16;\n\t"
+          "cvt.f32.bf16 v6, v6_bf16;\n\t"
+          "cvt.f32.bf16 v7, v7_bf16;\n\t"
+          ".reg.b64 v01, v23, v45, v67;\n\t"
+          "mov.b64 v01, {v0, v1};\n\t"
+          "mov.b64 v23, {v2, v3};\n\t"
+          "mov.b64 v45, {v4, v5};\n\t"
+          "mov.b64 v67, {v6, v7};\n\t"
+          "mul.f32x2 v01, v01, scaling_coeff_2x;\n\t"
+          "mul.f32x2 v23, v23, scaling_coeff_2x;\n\t"
+          "mul.f32x2 v45, v45, scaling_coeff_2x;\n\t"
+          "mul.f32x2 v67, v67, scaling_coeff_2x;\n\t"
+          // Elements reordered to match the packing order (v1,v0)
+          "mov.b64 {v1, v0}, v01;\n\t"
+          "mov.b64 {v3, v2}, v23;\n\t"
+          "mov.b64 {v5, v4}, v45;\n\t"
+          "mov.b64 {v7, v6}, v67;\n\t"
+          ".reg.b8 f0, f1, f2, f3;\n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f2, v4, v5;\n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f3, v6, v7;\n\t"
+          "mov.b32 %0, {f0, f1, f2, f3};\n\t"
+          "}"
+          : "=r"(out_8x)
+          : "l"(in03), "l"(in47), "f"(scaling_coefficient));
+    } else {
+      NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
+    }
+  } else {
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+  return out_8x;
+}
+
+template <typename SCALING_COEFFICIENT_TYPE>
+__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding(
+    const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient,
+    const uint32_t rbits03, const uint32_t rbits47) {
+  uint32_t out_8x = 0;
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
+    if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
+      asm volatile(
+          "{\n"
+          ".reg.f32 zero;\n\t"
+          "mov.b32 zero, 0;\n\t"
+          ".reg.b16 scaling_coeff;\n\t"
+          "mov.b16 scaling_coeff, %3;\n\t"
+          ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h;\n\t"
+          "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1;\n\t"
+          "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2;\n\t"
+          ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7;\n\t"
+          "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero;\n\t"
+          "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero;\n\t"
+          ".reg.b16 b03, b47;\n\t"
+          // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
+          "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4;\n\t"
+          "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5;\n\t"
+          "mov.b32 %0, {b03, b47};\n"
+          "}"
+          : "=r"(out_8x)
+          : "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)),
+            "r"(rbits03), "r"(rbits47));
+    } else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
+      asm volatile(
+          "{\n"
+          ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16;\n\t"
+          "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1;\n\t"
+          "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2;\n\t"
+          ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7;\n\t"
+          "cvt.f32.bf16 v0, v0_bf16;\n\t"
+          "cvt.f32.bf16 v1, v1_bf16;\n\t"
+          "cvt.f32.bf16 v2, v2_bf16;\n\t"
+          "cvt.f32.bf16 v3, v3_bf16;\n\t"
+          "cvt.f32.bf16 v4, v4_bf16;\n\t"
+          "cvt.f32.bf16 v5, v5_bf16;\n\t"
+          "cvt.f32.bf16 v6, v6_bf16;\n\t"
+          "cvt.f32.bf16 v7, v7_bf16;\n\t"
+          "mul.f32 v0, v0, %3;\n\t"
+          "mul.f32 v1, v1, %3;\n\t"
+          "mul.f32 v2, v2, %3;\n\t"
+          "mul.f32 v3, v3, %3;\n\t"
+          "mul.f32 v4, v4, %3;\n\t"
+          "mul.f32 v5, v5, %3;\n\t"
+          "mul.f32 v6, v6, %3;\n\t"
+          "mul.f32 v7, v7, %3;\n\t"
+          ".reg.b16 b03, b47;\n\t"
+          // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
+          "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4;\n\t"
+          "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5;\n\t"
+          "mov.b32 %0, {b03, b47};\n"
+          "}"
+          : "=r"(out_8x)
+          : "l"(in03), "l"(in47), "f"(scaling_coefficient), "r"(rbits03), "r"(rbits47));
+    } else {
+      NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
+    }
+  } else {
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+  return out_8x;
+}
+
 #endif  // FP4_TYPE_SUPPORTED

 // SIMD like "Fused" cast + multiplication (x2)

@@ -1521,6 +1779,59 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) {
                : "r"(in2[0]), "r"(in2[1]));
   return out;
 }

+// Loads single BF16/FP16 element from shared memory state space
+__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) {
+  const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
+  bf16 dst;
+  asm volatile("ld.shared.b16 %0, [%1];"
+               : "=h"(reinterpret_cast<uint16_t &>(dst))
+               : "r"(src_smem_ptr));
+  return dst;
+}
+
+// Loads pair of BF16/FP16 values from shared memory state space
+__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) {
+  const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
+  bf16x2 dst;
+  asm volatile("ld.shared.b32 %0, [%1];"
+               : "=r"(reinterpret_cast<uint32_t &>(dst))
+               : "r"(src_smem_ptr));
+  return dst;
+}
+
+// Loads 8x BF16 values from shared memory state space
+__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) {
+  uint64_t elts03, elts47;
+  const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
+  asm volatile(
+      "{\n\t"
+      ".reg.b128 xy;\n\t"
+      "ld.shared.b128 xy, [%2];\n\t"
+      "mov.b128 {%0, %1}, xy;\n"
+      "}\n"
+      : "=l"(elts03), "=l"(elts47)
+      : "r"(src_smem_ptr));
+  return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03);
+}
+
+#if FP4_TYPE_SUPPORTED
+// Vectorized store of x8 FP4 elements into shared memory state space
+__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem,
+                                              uint32_t fp4_pack_x8) {
+  const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
+  asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8));
+}
+#endif
+
+// Vectorized store of x16 FP4 elements into shared memory state space
+#if FP4_TYPE_SUPPORTED
+__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem,
+                                              uint64_t fp4_pack_x16) {
+  const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
+  asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16));
+}
+#endif
+
 #endif
 }  // namespace ptx
transformer_engine/debug/features/disable_fp8_gemm.py

@@ -2,17 +2,28 @@
 #
 # See LICENSE for license information.
-"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
+"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect
+
+DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
+New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
+"""
+import warnings
 
-from nvdlfw_inspect.registry import Registry, api_method
+from nvdlfw_inspect.registry import Registry
 
-from transformer_engine.debug.features.api import TEConfigAPIMapper
+from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM
 
 
 @Registry.register_feature(namespace="transformer_engine")
-class DisableFP8GEMM(TEConfigAPIMapper):
+class DisableFP8GEMM(DisableQuantizationGEMM):
     """
     GEMM operations are executed in higher precision, even when FP8 autocast is enabled.
 
+    .. deprecated::
+        Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
+        backward compatibility only. DisableQuantizationGEMM works with all quantization
+        formats (FP8, NVFP4, etc.), not just FP8.
+
     Parameters
     ----------

@@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper):
     layers:
       layer_types: [fc1]
       transformer_engine:
         DisableFP8GEMM:
+          # Deprecated: use DisableQuantizationGEMM
           enabled: True
           gemms: [dgrad, wgrad]
     """
 
-    @api_method
-    def fp8_gemm_enabled(
-        self, config, layer_name: str, gemm: str, iteration: int
-    ):  # pylint: disable=unused-argument
-        """API call responsible for choice between high-precision and FP8 GEMM execution."""
-
-        for key in config:
-            if key != "gemm":
-                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-
-        # If this feature is invoked, then FP8 GEMM is disabled.
-        # If not, then default behaviour in TransformerEngineAPI
-        # is that fp8_gemm() API call returns True.
-        return False, iteration + 1
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "DisableFP8GEMM is deprecated. "
+            "Use DisableQuantizationGEMM instead, which works with all quantization "
+            "formats (FP8, NVFP4, etc.).",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
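Because the shim above subclasses the new feature and only adds a warning, existing configs that reference DisableFP8GEMM keep working unchanged. A minimal sketch of the expected behaviour (assuming the feature class can be constructed without arguments; the test harness below is illustrative, not part of the commit):

import warnings

from transformer_engine.debug.features.disable_fp8_gemm import DisableFP8GEMM
from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    feature = DisableFP8GEMM()  # assumed no-arg construction; the old name still works
    assert isinstance(feature, DisableQuantizationGEMM)  # behaves as the new feature
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)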
transformer_engine/debug/features/disable_fp8_layer.py

@@ -2,17 +2,27 @@
 #
 # See LICENSE for license information.
-"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
+"""DisableFP8Layer Feature support for nvidia-dlframework-inspect
+
+DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
+New code should use DisableQuantizationLayer instead, which works with all quantization formats.
+"""
+import warnings
 
-import nvdlfw_inspect.api as debug_api
-from nvdlfw_inspect.registry import Registry, api_method
+from nvdlfw_inspect.registry import Registry
+
+from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer
 
 
 @Registry.register_feature(namespace="transformer_engine")
-class DisableFP8Layer:
+class DisableFP8Layer(DisableQuantizationLayer):
     """
     Disables all FP8 GEMMs in the layer.
 
+    .. deprecated::
+        Use :class:`DisableQuantizationLayer` instead. This class is maintained for
+        backward compatibility only. DisableQuantizationLayer works with all quantization
+        formats (FP8, NVFP4, etc.), not just FP8.
+
     Example
     -------

@@ -20,36 +30,19 @@ class DisableFP8Layer:
     example_disable_fp8_layer:
       enabled: True
       layers:
         layer_types: [fc1]
       transformer_engine:
         DisableFP8Layer:
+          # Deprecated: use DisableQuantizationLayer
           enabled: True
     """
 
-    @api_method
-    def fp8_gemm_enabled(
-        self, config, layer_name: str, gemm: str, iteration: int
-    ):  # pylint: disable=unused-argument
-        """API call responsible for selecting between high-precision and FP8 GEMM execution."""
-
-        for key in config:
-            if key not in ["enabled", "gemm"]:
-                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-
-        # If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
-        debug_api.log_message("FP8 Disabled", layer_name)
-
-        # If this feature is invoked, then FP8 GEMM is disabled.
-        # If not, then default behavior in TransformerEngineAPI
-        # is that fp8_gemm() API call returns True.
-        return False, iteration + 1
-
-    def parse_config_and_api(self, config, **_kwargs):
-        """Determines whether to run the API
-
-        DisableFP8Layer is the only feature provided by the Transformer Engine
-        which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for
-        parsing gemms and tensors fields from the config, which are not needed for this feature.
-
-        Explanation of the parse_config_and_api can be found in the
-        nvidia-dlframework-inspect documentation.
-        """
-        return config["enabled"], None
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "DisableFP8Layer is deprecated. "
+            "Use DisableQuantizationLayer instead, which works with all quantization "
+            "formats (FP8, NVFP4, etc.).",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
transformer_engine/debug/features/disable_quantization_gemm.py (new file)

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""

from nvdlfw_inspect.registry import Registry, api_method

from transformer_engine.debug.features.api import TEConfigAPIMapper


@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationGEMM(TEConfigAPIMapper):
    """
    Disables specific GEMM operations from using quantization, forcing high-precision execution.
    Works with any quantization format (FP8, NVFP4, etc.).

    Parameters
    ----------
    gemms: List[str]
        list of gemms to disable quantization for
            - fprop
            - dgrad
            - wgrad

    Example
    -------
    .. code-block:: yaml

        example_disable_quantization_gemm:
          enabled: True
          layers:
            layer_types: [fc1]
          transformer_engine:
            DisableQuantizationGEMM:
              enabled: True
              gemms: [dgrad, wgrad]
    """

    @api_method
    def fp8_gemm_enabled(
        self, config, layer_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call responsible for choice between high-precision and quantized GEMM execution.

        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
        but it applies to all quantization formats (FP8, NVFP4, etc.).
        """
        for key in config:
            if key != "gemm":
                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
        # If not, then default behavior in TransformerEngineAPI
        # is that fp8_gemm() API call returns True.
        return False, iteration + 1
transformer_engine/debug/features/disable_quantization_layer.py (new file)

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect"""

import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method


@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationLayer:
    """
    Disables all quantized GEMMs in the layer, forcing high-precision execution.
    Works with any quantization format (FP8, NVFP4, etc.).

    Example
    -------
    .. code-block:: yaml

        example_disable_quantization_layer:
          enabled: True
          layers:
            layer_types: [fc1]
          transformer_engine:
            DisableQuantizationLayer:
              enabled: True
    """

    @api_method
    def fp8_gemm_enabled(
        self, config, layer_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call responsible for selecting between high-precision and quantized GEMM execution.

        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
        but it applies to all quantization formats (FP8, NVFP4, etc.).
        """
        for key in config:
            if key not in ["enabled", "gemm"]:
                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

        # If quantized training, disable quantization for the selected layers if this feature is enabled.
        debug_api.log_message("Quantization Disabled", layer_name)

        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
        # If not, then default behavior in TransformerEngineAPI
        # is that fp8_gemm() API call returns True.
        return False, iteration + 1

    def parse_config_and_api(self, config, **_kwargs):
        """Determines whether to run the API.

        DisableQuantizationLayer is the only feature provided by the Transformer Engine
        which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
        parsing gemms and tensors fields from the config, which are not needed for this feature.

        Explanation of the parse_config_and_api can be found in the
        nvidia-dlframework-inspect documentation.
        """
        return config["enabled"], None
transformer_engine/debug/features/log_fp8_tensor_stats.py

@@ -6,15 +6,17 @@
 from typing import Dict, Optional, List, Tuple
 from contextlib import contextmanager
+import warnings
 
 import torch
 
 import nvdlfw_inspect.api as debug_api
+import transformer_engine_torch as tex
 from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
 from nvdlfw_inspect.registry import Registry, api_method
 
 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
+from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
 from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Quantizer,

@@ -22,7 +24,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
+
+try:
+    from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+
+    _nvfp4_available = True
+except ImportError:
+    _nvfp4_available = False
+    NVFP4Quantizer = None
 
 ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]

@@ -39,6 +48,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]):
         return "mxfp8"
     if isinstance(quantizer, Float8BlockQuantizer):
         return "fp8_block_scaling"
+    if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer):
+        return "nvfp4"
     raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")

@@ -164,6 +175,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
         if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
             raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")
+
+        # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work)
+        # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer
+        if recipe_from_stat == "nvfp4":
+            raise ValueError(
+                f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats."
+                " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for"
+                " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.,"
+                " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons."
+            )
         if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
             raise ValueError(
                 f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for"

@@ -189,6 +210,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
     def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
         """Returns the recipe name from the stat string."""
         columnwise_stat = stat.endswith("_columnwise")
         for recipe_name in ALL_RECIPE_NAMES:
             if recipe_name in stat:

@@ -213,7 +235,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
         Yields the aux_dict.
         Needs to clean after usage, because it possibly change the usage of the quantized tensor.
         """
-        fp8_dtype = None
+        fp8_dtype = tex.DType.kFloat8E4M3
         if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
             assert isinstance(
                 quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)

@@ -277,14 +299,26 @@ class LogFp8TensorStats(BaseLogTensorStats):
         API call used to collect the data about the tensor after process_tensor()/quantization.
         """
         assert rowwise_quantized_tensor is columnwise_quantized_tensor
-        assert (
-            quantizer is not None
-        ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."
+
+        # Skip logging if quantizer is None (layer runs in high precision)
+        if quantizer is None:
+            warnings.warn(
+                f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
+            )
+            return
+
         quantized_tensor = rowwise_quantized_tensor
-        assert isinstance(
-            quantized_tensor, QuantizedTensor
-        ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
+
+        # Skip logging if quantized_tensor is not a QuantizedTensor (incompatible precision)
+        if not isinstance(quantized_tensor, QuantizedTensor):
+            warnings.warn(
+                f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': incompatible precision "
+                f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})."
+            )
+            return
 
         recipe_name = _get_recipe_name(quantizer)
         for stat in config["stats"]:
transformer_engine/debug/features/log_nvfp4_tensor_stats.py (new file)

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect"""

from typing import Dict, Optional
from contextlib import contextmanager
import warnings

import torch

import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method

from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage


@Registry.register_feature(namespace="transformer_engine")
class LogNvfp4TensorStats(BaseLogTensorStats):
    """Logs statistics of NVFP4 quantized tensors.

    In distributed runs each rank first computes its local statistics; the values
    are gathered the next time `debug_api.step()` is called. Remember to call
    `debug_api.step()` every training step so the logs are flushed.

    The feature is micro-batch aware: if several forward/backward passes occur
    between successive `debug_api.step()` calls, statistics are accumulated for all
    tensors except weights.

    Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the
    overhead, and if the feature is skipped for a step the additional cost is
    minimal. When no other debug feature is active, the layer runs at normal
    Transformer Engine speed.

    Parameters
    ----------
    stats: List[str]
        List of statistics to collect. Available stats:
            - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data)
            - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements
    tensors/tensors_struct: List[str]
        list of tensors to log
            - activation,
            - gradient,
            - weight,
    freq: Optional[int], default = 1
        frequency of logging stats, stats will be logged every `freq` steps
    start_step: Optional[int], default = None
        start step of logging stats
    end_step: Optional[int], default = None
        end step of logging stats
    start_end_list: Optional[list([int, int])], default = None
        non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step

    Example
    -------
    .. code-block:: yaml

        example_nvfp4_tensor_stat_collection:
          enabled: True
          layers:
            layer_types: [layernorm_linear]
          transformer_engine:
            LogNvfp4TensorStats:
              enabled: True
              tensors_struct:
                - tensor: activation
                  stats: [underflows%, mse]
                  freq: 1
                - tensor: gradient
                  stats: [underflows%, mse]
                  freq: 5
              start_step: 0
              end_step: 80
    """

    def check_if_stat_is_supported(self, stat: str):
        """Returns True if stat is supported, raises ValueError otherwise."""
        supported_stats = [
            "underflows%",
            "mse",
        ]
        if stat not in supported_stats:
            raise ValueError(
                f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
            )
        return True

    def get_stat_with_prefix(self, stat: str) -> str:
        """Add nvfp4_ prefix to stat name for use in stats_computation."""
        return f"nvfp4_{stat}"

    @contextmanager
    def update_aux_dict(
        self,
        aux_dict: Dict,
        quantized_tensor: QuantizedTensor,
        quantizer: Quantizer,  # pylint: disable=unused-argument
        original_tensor: torch.Tensor,
    ):
        """
        Updates the aux_dict with the quantized tensor and additional NVFP4-specific data.
        Yields the aux_dict.
        """
        aux_dict = {
            "nvfp4": quantized_tensor,
            "original_tensor": original_tensor,
        }
        try:
            yield aux_dict
        finally:
            pass

    @api_method
    def inspect_tensor_enabled(
        self, config: Dict, layer_name: str, tensor_name: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call used to determine whether to run inspect_tensor() in the forward."""
        run_current, next_iter = next_enabled_iter(
            config.get("start_step", None),
            config.get("end_step", None),
            config.get("start_end_list", None),
            config.get("freq", 1),
            iteration,
        )
        STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
        return run_current, next_iter

    @api_method
    def inspect_tensor(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        iteration: int,
        tp_group,
        tensor: torch.Tensor,
        rowwise_quantized_tensor: Optional[QuantizedTensor] = None,
        columnwise_quantized_tensor: Optional[QuantizedTensor] = None,
        quantizer: Optional[Quantizer] = None,
    ):
        """
        API call used to collect the data about the tensor after process_tensor()/quantization.
        """
        assert rowwise_quantized_tensor is columnwise_quantized_tensor

        # Skip logging if quantizer is None (layer runs in high precision)
        if quantizer is None:
            warnings.warn(
                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
                f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
            )
            return

        quantized_tensor = rowwise_quantized_tensor

        # Skip logging if not NVFP4 quantizer (incompatible precision)
        if not isinstance(quantizer, NVFP4Quantizer):
            warnings.warn(
                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
                f"tensor '{tensor_name}': incompatible precision "
                f"(expected NVFP4Quantizer, got {type(quantizer).__name__})."
            )
            return

        # Skip logging if quantized tensor is not NVFP4TensorStorage (incompatible precision)
        if not isinstance(quantized_tensor, NVFP4TensorStorage):
            warnings.warn(
                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
                f"tensor '{tensor_name}': incompatible precision "
                f"(expected NVFP4TensorStorage, got {type(quantized_tensor).__name__})."
            )
            return

        for stat in config["stats"]:
            self.check_if_stat_is_supported(stat)

        start_step = config.get("start_step", None)
        end_step = config.get("end_step", None)
        start_end_list = config.get("start_end_list", None)
        if start_end_list is not None:
            start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
        options = (
            start_step,
            end_step,
            start_end_list,
            "nvfp4",
        )

        skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
            tensor_name, tp_group
        )

        # Add nvfp4_ prefix to all stats for internal use
        prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]]

        STATS_BUFFERS.try_add_buffer(
            layer_name=layer_name,
            tensor_name=tensor_name,
            stats=prefixed_stats,
            options=options,
            reduction_group=reduction_group,
            reduce_within_microbatch=reduce_within_microbatch,
        )

        with self.update_aux_dict(
            aux_dict={},
            quantized_tensor=quantized_tensor,
            quantizer=quantizer,
            original_tensor=tensor,
        ) as aux_dict:
            STATS_BUFFERS.feed(
                layer_name,
                tensor_name,
                options,
                tensor,
                iteration,
                skip_reduction,
                aux_dict=aux_dict,
            )

        debug_api.log_message(
            f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}",
            layer_name,
            extra_cachable_args=(tensor_name,),
        )
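As a point of reference for the `mse` statistic defined in the docstring above, the quantity being logged is just the elementwise squared error between the dequantized tensor and the original input, normalized by the element count. A rough sketch of the arithmetic (with a crude rounding step standing in for the NVFP4 quantize/dequantize round trip, not the feature's internal implementation):

import torch

def mse_between(original: torch.Tensor, dequantized: torch.Tensor) -> torch.Tensor:
    # sum((quantized_tensor - original_tensor)**2) / num_elements, per the docstring
    return ((dequantized.float() - original.float()) ** 2).sum() / original.numel()

# Toy example: coarse rounding stands in for quantize -> dequantize.
x = torch.randn(4, 8)
x_roundtrip = (x * 4).round() / 4
print(mse_between(x, x_roundtrip))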
transformer_engine/debug/features/utils/stats_computation.py

@@ -443,3 +443,65 @@ for _columnwise in [True, False]:
         add_underflows_stats(_recipe_name, _columnwise)
         add_scale_inv_stats(_recipe_name, _columnwise)
         add_mse_stats(_recipe_name, _columnwise)
+
+
+# NVFP4-specific statistics
+def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor:
+    """Count the number of non-zero elements in the FP4 data.
+
+    FP4 data is stored as 2 4-bit values per byte (uint8).
+    We need to unpack and count non-zeros.
+    """
+    # Each byte contains two FP4 values
+    # Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0)
+    zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8)
+
+    # Extract first and second nibbles
+    first_nibble = fp4_data % 16
+    second_nibble = fp4_data // 16
+
+    # Count zeros
+    first_zeros = torch.isin(first_nibble, zero_vals).sum()
+    second_zeros = torch.isin(second_nibble, zero_vals).sum()
+
+    total_elements = fp4_data.numel() * 2
+    return total_elements - first_zeros - second_zeros
+
+
+def add_nvfp4_underflows_stats():
+    """Register underflow stats for NVFP4.
+
+    Computes underflows by counting zeros in packed FP4 data vs original tensor.
+    """
+    stat_num = "nvfp4_underflows_num"
+    stat_pct = "nvfp4_underflows%"
+    stats_to_num[stat_num] = len(stats_to_num)
+    stats_to_num[stat_pct] = len(stats_to_num)
+
+    # Count non-zeros in original vs FP4 packed data
+    STATS[stat_num] = (
+        lambda x, aux_dict: x.count_nonzero()
+        - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data),
+        lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
+    )
+    STATS[stat_pct] = (
+        lambda x, aux_dict: (
+            x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data)
+        )
+        / aux_dict["nvfp4"].numel()
+        * 100,
+        lambda buffers, _sn_num=stat_num: 100
+        * sum(_get(buffers, _sn_num))
+        / sum(_get(buffers, "numel")),
+    )
+    DEPENDENCIES[stat_num] = {stat_num}
+    DEPENDENCIES[stat_pct] = {stat_num, "numel"}
+
+
+# Register NVFP4 stats
+add_nvfp4_underflows_stats()
+add_mse_stats("nvfp4")  # Reuse existing MSE function
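The nibble handling in count_nonzero_nvfp4 can be sanity-checked in isolation: pack a few known FP4 E2M1 codes two per byte and confirm that only the 0x0 and 0x8 codes (positive and negative zero) are treated as zero. A small standalone sketch that mirrors the logic above (it re-implements it rather than importing it):

import torch

def count_nonzero_packed_fp4(fp4_data: torch.Tensor) -> torch.Tensor:
    # Two 4-bit codes per uint8 byte; 0x0 and 0x8 encode +0.0 / -0.0 in E2M1.
    zero_vals = torch.tensor([0, 8], dtype=torch.uint8)
    lo = fp4_data % 16
    hi = fp4_data // 16
    zeros = torch.isin(lo, zero_vals).sum() + torch.isin(hi, zero_vals).sum()
    return fp4_data.numel() * 2 - zeros

# Bytes pack (low, high) nibbles: (0x1, 0x0), (0x8, 0x3), (0x0, 0x8).
packed = torch.tensor([0x01, 0x38, 0x80], dtype=torch.uint8)
assert count_nonzero_packed_fp4(packed) == 2  # only the 0x1 and 0x3 codes are non-zero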
transformer_engine/debug/pytorch/debug_quantization.py

@@ -36,7 +36,7 @@ _tensor_to_gemm_names_map = {
 }
 
 API_CALL_MODIFY = "modify_tensor()"
-STANDARD_FP8_QUANTIZE = "FP8 Quantize"
+STANDARD_QUANTIZE = "Quantize"
 HIGH_PRECISION = "High Precision"

@@ -88,7 +88,7 @@ class DebugQuantizer(Quantizer):
         # inspect_tensor*_enabled are bool fields,
         # indicating whether some feature will need to run inspect_tensor_* calls.
         #
-        # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
+        # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION]
         # determining what will happen when the quantizer is used for that tensor.
         self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
         if self.output_tensor:

@@ -170,7 +170,7 @@ class DebugQuantizer(Quantizer):
     def get_tensors_plan(self):
         """
         Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
-        API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
+        API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior
         of this quantizer with respect to these tensors.
         """
         import nvdlfw_inspect.api as debug_api

@@ -191,16 +191,16 @@ class DebugQuantizer(Quantizer):
             rowwise_plan = API_CALL_MODIFY
         else:
             if self.parent_quantizer is not None:
-                fp8_quantize = self.process_enabled_api_call(
-                    debug_api.transformer_engine.fp8_gemm_enabled(
+                quantize_enabled = self.process_enabled_api_call(
+                    debug_api.transformer_engine.fp8_gemm_enabled(  # API name kept for compatibility
                         layer_name=self.layer_name,
                         gemm=self.rowwise_gemm_name,
                         iteration=self.iteration,
                     )
                 )
-                if fp8_quantize:
-                    rowwise_plan = STANDARD_FP8_QUANTIZE
+                if quantize_enabled:
+                    rowwise_plan = STANDARD_QUANTIZE
         if rowwise_plan is None:
             rowwise_plan = HIGH_PRECISION

@@ -218,16 +218,16 @@ class DebugQuantizer(Quantizer):
             columnwise_plan = API_CALL_MODIFY
         else:
             if self.parent_quantizer is not None:
-                fp8_quantize = self.process_enabled_api_call(
-                    debug_api.transformer_engine.fp8_gemm_enabled(
+                quantize_enabled = self.process_enabled_api_call(
+                    debug_api.transformer_engine.fp8_gemm_enabled(  # API name kept for compatibility
                         layer_name=self.layer_name,
                         gemm=self.columnwise_gemm_name,
                         iteration=self.iteration,
                     )
                 )
-                if fp8_quantize:
-                    columnwise_plan = STANDARD_FP8_QUANTIZE
+                if quantize_enabled:
+                    columnwise_plan = STANDARD_QUANTIZE
         if columnwise_plan is None:
             columnwise_plan = HIGH_PRECISION

@@ -278,7 +278,7 @@ class DebugQuantizer(Quantizer):
             del args["quantizer"]
         if (
-            self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
+            self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
             and self.inspect_tensor_postquantize_enabled_rowwise
         ):
             args["tensor"] = rowwise_gemm_tensor

@@ -286,7 +286,7 @@ class DebugQuantizer(Quantizer):
             debug_api.transformer_engine.inspect_tensor_postquantize(**args)
         if (
-            self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
+            self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
             and self.inspect_tensor_postquantize_enabled_columnwise
         ):
             args["tensor"] = columnwise_gemm_tensor

@@ -317,14 +317,14 @@ class DebugQuantizer(Quantizer):
             self.parent_quantizer.set_usage(rowwise=True)
 
         rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
-        if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
+        if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
             quantized_tensor = self.parent_quantizer(tensor)
-            # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
+            # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized,
             # one tensor with columnwise=True and rowwise=True is computed
             # and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
-            if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
+            if self.rowwise_tensor_plan == STANDARD_QUANTIZE:
                 rowwise_gemm_tensor = quantized_tensor
-            if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
+            if self.columnwise_tensor_plan == STANDARD_QUANTIZE:
                 columnwise_gemm_tensor = quantized_tensor
 
         # 2. modify_tensor() is called, if it is used.

@@ -379,7 +379,7 @@ class DebugQuantizer(Quantizer):
         """This call is invoked after the gemm to inspect and modify the output tensor."""
         import nvdlfw_inspect.api as debug_api
 
-        assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
+        assert self.parent_quantizer is None, "Quantized output is not supported for debug=True."
         assert self.output_tensor
         tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
         if self.rowwise_tensor_plan == API_CALL_MODIFY:

@@ -420,9 +420,9 @@ class DebugQuantizer(Quantizer):
         ):
             return True
         if self.parent_quantizer is not None:
-            if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
+            if self.rowwise_tensor_plan != STANDARD_QUANTIZE:
                 return True
-            if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
+            if self.columnwise_tensor_plan != STANDARD_QUANTIZE:
                 return True
         return False

@@ -446,7 +446,7 @@ class DebugQuantizer(Quantizer):
         if self.parent_quantizer is not None:
             if (
                 dst.rowwise_gemm_tensor is not None
-                and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
+                and self.rowwise_tensor_plan == STANDARD_QUANTIZE
             ):
                 if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
                     dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)

@@ -455,7 +455,7 @@ class DebugQuantizer(Quantizer):
                 updated_rowwise_gemm = True
             if (
                 dst.columnwise_gemm_tensor is not None
-                and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
+                and self.columnwise_tensor_plan == STANDARD_QUANTIZE
                 and not updated_rowwise_gemm
            ):
                 if hasattr(dst.columnwise_gemm_tensor, "quantize_"):

@@ -540,14 +540,12 @@ class DebugQuantizer(Quantizer):
         """
         Updates the usage of the parent quantizer.
         """
-        rowwise_gemm_quantize = (
-            self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
-        )
+        rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE
         columnwise_gemm_quantize = (
-            self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
+            self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE
         )
-        if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
+        if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
             self.parent_quantizer.set_usage(
                 rowwise=rowwise_gemm_quantize,
                 columnwise=columnwise_gemm_quantize,
...
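Taken together, the renaming makes the per-tensor decision in get_tensors_plan() read as a simple three-way choice. A condensed sketch of the selection order implied by the diff (pick_plan is a hypothetical helper written for illustration, not a method of DebugQuantizer):

API_CALL_MODIFY = "modify_tensor()"
STANDARD_QUANTIZE = "Quantize"
HIGH_PRECISION = "High Precision"

def pick_plan(modify_enabled: bool, has_parent_quantizer: bool, quantize_enabled: bool) -> str:
    # Order mirrors get_tensors_plan(): an explicit modify_tensor() hook wins;
    # otherwise quantize only when a parent quantizer exists and the
    # fp8_gemm_enabled() hook (name kept for compatibility) did not opt out.
    if modify_enabled:
        return API_CALL_MODIFY
    if has_parent_quantizer and quantize_enabled:
        return STANDARD_QUANTIZE
    return HIGH_PRECISION

assert pick_plan(False, True, False) == HIGH_PRECISION   # e.g. DisableQuantizationGEMM active
assert pick_plan(False, True, True) == STANDARD_QUANTIZE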