OpenDAS / TransformerEngine — Commits

Commit 9df0c4a3, authored Mar 03, 2026 by wenjh
Merge branch 'nv_main'
Parents: 0d874a4e, f122b07d
Changes: 221
Showing 20 changed files with 1165 additions and 143 deletions (+1165 −143)
transformer_engine/common/include/transformer_engine/comm_gemm.h                  +2    −0
transformer_engine/common/include/transformer_engine/fused_attn.h                 +38   −28
transformer_engine/common/include/transformer_engine/gemm.h                       +171  −0
transformer_engine/common/include/transformer_engine/hadamard_transform.h         +34   −0
transformer_engine/common/include/transformer_engine/multi_tensor.h               +11   −0
transformer_engine/common/recipe/__init__.py                                      +21   −14
transformer_engine/common/transformer_engine.cpp                                  +2    −2
transformer_engine/common/triton/permutation.py                                   +12   −0
transformer_engine/common/util/cuda_runtime.cpp                                   +8    −0
transformer_engine/common/util/cuda_runtime.h                                     +6    −0
transformer_engine/common/util/logging.h                                          +6    −6
transformer_engine/common/util/ptx.cuh                                            +314  −3
transformer_engine/debug/features/disable_fp8_gemm.py                             +25   −19
transformer_engine/debug/features/disable_fp8_layer.py                            +28   −35
transformer_engine/debug/features/disable_quantization_gemm.py                    +59   −0
transformer_engine/debug/features/disable_quantization_layer.py                   +61   −0
transformer_engine/debug/features/log_fp8_tensor_stats.py                         +43   −9
transformer_engine/debug/features/log_nvfp4_tensor_stats.py                       +237  −0
transformer_engine/debug/features/utils/stats_computation.py                      +62   −0
transformer_engine/debug/pytorch/debug_quantization.py                            +25   −27
transformer_engine/common/include/transformer_engine/comm_gemm.h

@@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
 /*! \brief Destroy a comm-gemm context.
  *
  *  \param[in]  ctx  Context to destroy.
+ *
+ *  It's the caller's responsibility to synchronize all streams involved before calling this function.
  */
 void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx *ctx);
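For illustration, a minimal teardown sketch that honors the new synchronization note; the context and stream handles here are the caller's own and purely hypothetical, not part of this commit:

// Hypothetical teardown: synchronize every stream that issued comm-gemm work, then destroy.
void shutdown_comm_gemm(NVTECommGemmCtx *ctx, cudaStream_t *streams, int num_streams) {
  for (int i = 0; i < num_streams; ++i) {
    cudaStreamSynchronize(streams[i]);  // caller's responsibility per the doc note above
  }
  nvte_comm_gemm_ctx_destroy(ctx);
}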
transformer_engine/common/include/transformer_engine/fused_attn.h

@@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  *  \param[in]  window_size_right  Sliding window size (the right half).
  *  \param[in]  return_max_logit   Whether to produce Max and Sum_Exp, or Stats.
  *  \param[in]  cuda_graph         Whether cuda graph capture is enabled or not.
+ *  \param[in]  deterministic      Whether determinism is required or not.
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
     bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
     float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
     size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
-    int64_t window_size_right, bool return_max_logit, bool cuda_graph);
+    int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);

 /*! \brief Compute dot product attention with packed QKV input.
  *

@@ -269,22 +270,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  *  \param[in]  softmax_type           Attention softmax type.
  *  \param[in]  window_size_left       Sliding window size (the left half).
  *  \param[in]  window_size_right      Sliding window size (the right half).
+ *  \param[in]  bottom_right_diagonal  Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]  workspace              Workspace tensor.
  *  \param[in]  stream                 CUDA stream used for this operation.
  */
 [[deprecated(
     "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
     "Q, K, V tensors instead.")]]
-void nvte_fused_attn_fwd_qkvpacked(
-    const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
-    NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
-    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
-    bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
-    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
-    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
-    NVTETensor workspace, cudaStream_t stream);
+void nvte_fused_attn_fwd_qkvpacked(
+    const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
+    NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
+    const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
+    bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
+    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+    NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
+    bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed QKV input.
  *

@@ -332,6 +332,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  *  \param[in]  softmax_type           Attention softmax type.
  *  \param[in]  window_size_left       Sliding window size (the left half).
  *  \param[in]  window_size_right      Sliding window size (the right half).
+ *  \param[in]  bottom_right_diagonal  Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]  deterministic          Whether to execute with deterministic behaviours.
  *  \param[in]  cuda_graph             Whether cuda graph capture is enabled or not.
  *  \param[in]  workspace              Workspace tensor.

@@ -346,8 +347,8 @@ void nvte_fused_attn_bwd_qkvpacked(
     NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
     size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
-    NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute dot product attention with packed KV input.
  *

@@ -409,6 +410,7 @@ void nvte_fused_attn_bwd_qkvpacked(
  *  \param[in]  softmax_type           Attention softmax type.
  *  \param[in]  window_size_left       Sliding window size (the left half).
  *  \param[in]  window_size_right      Sliding window size (the right half).
+ *  \param[in]  bottom_right_diagonal  Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]  workspace              Workspace tensor.
  *  \param[in]  stream                 CUDA stream used for this operation.
  */

@@ -424,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked(
     size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
+    cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with packed KV input.
  *

@@ -478,6 +481,7 @@ void nvte_fused_attn_fwd_kvpacked(
  *  \param[in]  softmax_type           Attention softmax type.
  *  \param[in]  window_size_left       Sliding window size (the left half).
  *  \param[in]  window_size_right      Sliding window size (the right half).
+ *  \param[in]  bottom_right_diagonal  Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]  deterministic          Whether to execute with deterministic behaviours.
  *  \param[in]  cuda_graph             Whether cuda graph capture is enabled or not.
  *  \param[in]  workspace              Workspace tensor.

@@ -494,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked(
     const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
     float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
     NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
-    int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
-    cudaStream_t stream);
+    int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
+    NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute dot product attention with separate Q, K and V.
  *

@@ -559,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked(
  *  \param[in]  softmax_type           Attention softmax type.
  *  \param[in]  window_size_left       Sliding window size (the left half).
  *  \param[in]  window_size_right      Sliding window size (the right half).
+ *  \param[in]  bottom_right_diagonal  Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]  workspace              Workspace tensor.
  *  \param[in]  stream                 CUDA stream used for this operation.
  */
-void nvte_fused_attn_fwd(
-    const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
-    const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
-    const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
-    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
-    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
-    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
-    bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
-    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
-    cudaStream_t stream);
+void nvte_fused_attn_fwd(
+    const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
+    const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+    const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
+    const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
+    const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
+    size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
+    bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    NVTETensor workspace, cudaStream_t stream);

 /*! \brief Compute the backward of the dot product attention with separate Q, K and V.
  *

@@ -628,6 +636,7 @@ void nvte_fused_attn_fwd(
  *  \param[in]  softmax_type           Attention softmax type.
  *  \param[in]  window_size_left       Sliding window size (the left half).
  *  \param[in]  window_size_right      Sliding window size (the right half).
+ *  \param[in]  bottom_right_diagonal  Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
  *  \param[in]  deterministic          Whether to execute with deterministic behaviours.
  *  \param[in]  cuda_graph             Whether cuda graph capture is enabled or not.
  *  \param[in]  workspace              Workspace tensor.

@@ -643,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
     size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
     NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
-    int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
-    NVTETensor workspace, cudaStream_t stream);
+    int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
+    bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Update the RNG state with the seed and calculated offset.
  *
transformer_engine/common/include/transformer_engine/gemm.h

@@ -11,6 +11,8 @@
 #ifndef TRANSFORMER_ENGINE_GEMM_H_
 #define TRANSFORMER_ENGINE_GEMM_H_

+#include <stdint.h>
+
 #include "transformer_engine.h"

 #ifdef __cplusplus
@@ -20,6 +22,9 @@ extern "C" {
 /*! \brief Configuration for matrix multiplication. */
 typedef void *NVTEMatmulConfig;

+/*! \brief Configuration for grouped matrix multiplication. */
+typedef void *NVTEGroupedMatmulConfig;
+
 /*! \enum NVTEMatmulConfigAttribute
  *  \brief Type of option for matrix multiplication.
  */
@@ -52,6 +57,36 @@ enum NVTEMatmulConfigAttribute {
   kNVTEMatmulConfigNumAttributes
 };

+/*! \enum NVTEGroupedMatmulConfigAttribute
+ *  \brief Type of option for grouped matrix multiplication.
+ */
+enum NVTEGroupedMatmulConfigAttribute {
+  /*! Average M dimension hint
+   *
+   *  Optional hint for average M dimension across all matrices in the group.
+   *  Used by cuBLASLt for algorithm selection heuristics. If not set,
+   *  computed automatically from D's logical shape.
+   */
+  kNVTEGroupedMatmulConfigAvgM = 0,
+  /*! Average N dimension hint
+   *
+   *  Optional hint for average N dimension across all matrices in the group.
+   *  Used by cuBLASLt for algorithm selection heuristics. If not set,
+   *  computed automatically from D's logical shape.
+   */
+  kNVTEGroupedMatmulConfigAvgN = 1,
+  /*! Average K (reduction) dimension hint
+   *
+   *  Optional hint for average K dimension across all matrices in the group.
+   *  Used by cuBLASLt for algorithm selection heuristics. If not set,
+   *  computed automatically from A's logical shape.
+   */
+  kNVTEGroupedMatmulConfigAvgK = 2,
+  /*! Number of streaming multiprocessors to use in GEMM kernel. */
+  kNVTEGroupedMatmulConfigSMCount = 3,
+  kNVTEGroupedMatmulConfigNumAttributes
+};
+
 /*! \brief Create a matrix multiplication configuration. */
 NVTEMatmulConfig nvte_create_matmul_config();
@@ -82,6 +117,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
 /*! \brief Destroy a matrix multiplication configuration. */
 void nvte_destroy_matmul_config(NVTEMatmulConfig config);

+/*! \brief Create a grouped matrix multiplication configuration. */
+NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config();
+
+/*! \brief Query an option in grouped matrix multiplication configuration.
+ *
+ *  \param[in]   config         Grouped matrix multiplication configuration.
+ *  \param[in]   attr           Option type.
+ *  \param[out]  buf            Memory address to write option value. Ignored if NULL.
+ *  \param[in]   size_in_bytes  Size of buf.
+ *  \param[out]  size_written   Number of bytes that have been written to buf. If buf is NULL,
+ *                              then the number of bytes that would have been written.
+ */
+void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
+                                              NVTEGroupedMatmulConfigAttribute attr, void *buf,
+                                              size_t size_in_bytes, size_t *size_written);
+
+/*! \brief Set an option in grouped matrix multiplication configuration.
+ *
+ *  \param[in]   config         Grouped matrix multiplication configuration.
+ *  \param[in]   attr           Option type.
+ *  \param[out]  buf            Memory address to read option value.
+ *  \param[in]   size_in_bytes  Size of buf.
+ */
+void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
+                                              NVTEGroupedMatmulConfigAttribute attr,
+                                              const void *buf, size_t size_in_bytes);
+
+/*! \brief Destroy a grouped matrix multiplication configuration. */
+void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config);
+
 /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
  *
  *  This has been deprecated in favor of nvte_cublas_gemm_v2.
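As an illustration of the configuration API added above, a minimal lifecycle sketch; the attribute values are arbitrary, and the NULL-buffer size query follows the size_written semantics documented in the hunk:

NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();

int64_t avg_m = 2048;  // illustrative hint only
nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &avg_m, sizeof(avg_m));
int sm_count = 64;     // illustrative hint only
nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigSMCount, &sm_count,
                                         sizeof(sm_count));

// Query back: a NULL buf first reports how many bytes would be written.
size_t needed = 0;
nvte_get_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, nullptr, 0, &needed);
int64_t avg_m_readback = 0;
nvte_get_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &avg_m_readback,
                                         sizeof(avg_m_readback), &needed);

nvte_destroy_grouped_matmul_config(cfg);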
@@ -229,6 +296,46 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
                             bool accumulate, bool use_split_accumulator, int math_sm_count,
                             cudaStream_t stream);

+/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
+/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C
+ *
+ *  \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture.
+ *        Will error at runtime if compiled with an older cuBLAS version or run on
+ *        a pre-Blackwell GPU.
+ *
+ *  Performs batched GEMM on a collection of matrices with potentially different shapes.
+ *  All tensors in the group must have compatible dimensions for matrix multiplication.
+ *  Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous
+ *  memory layout and shape metadata.
+ *
+ *  \param[in]   A                 Input grouped tensor A.
+ *  \param[in]   transa            Whether to transpose A matrices.
+ *  \param[in]   B                 Input grouped tensor B.
+ *  \param[in]   transb            Whether to transpose B matrices.
+ *  \param[in]   C                 Input grouped tensor C (can be NULL for beta=0).
+ *  \param[out]  D                 Output grouped tensor D.
+ *  \param[in]   alpha             Scale multipliers for A @ B (NVTETensor with num_tensors elements).
+ *  \param[in]   beta              Scale multipliers for C (NVTETensor with num_tensors elements).
+ *  \param[in]   workspace_setup   Workspace tensor for pointer array setup.
+ *  \param[in]   workspace_cublas  Workspace tensor for cuBLAS operations.
+ *  \param[in]   config            Additional configuration (can be NULL for defaults).
+ *  \param[in]   stream            CUDA stream for the operation.
+ *
+ *  Requirements:
+ *    - cuBLAS 13.2+ (CUDA 13.1+)
+ *    - Blackwell (SM100) or newer GPU architecture
+ *    - A, B, C (if provided), D must have the same num_tensors
+ *    - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
+ *    - Shape compatibility: if transa=false, transb=false:
+ *      - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
+ */
+void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B,
+                       int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D,
+                       const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup,
+                       NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
+                       cudaStream_t stream);
+
 #ifdef __HIP_PLATFORM_AMD__
 void nvte_multi_stream_cublas_batchgemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D,
                                         const NVTETensor *bias, NVTETensor *pre_gelu_out,
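A hedged sketch of one grouped-GEMM dispatch through the declaration above; building the NVTEGroupedTensor handles, the per-matrix alpha/beta tensors, and the two workspaces is outside this hunk, so they appear here only as caller-provided parameters:

void run_grouped_gemm(NVTEGroupedTensor A, NVTEGroupedTensor B, NVTEGroupedTensor D,
                      NVTETensor alpha, NVTETensor beta, NVTETensor workspace_setup,
                      NVTETensor workspace_cublas, cudaStream_t stream) {
  NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();
  int64_t avg_k = 4096;  // optional heuristic hint, illustrative value
  nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgK, &avg_k, sizeof(avg_k));

  // C is NULL here, i.e. beta is effectively 0 for every matrix in the group.
  nvte_grouped_gemm(A, /*transa=*/0, B, /*transb=*/0, /*C=*/nullptr, D, alpha, beta,
                    workspace_setup, workspace_cublas, cfg, stream);

  // Assumes the config is only read while the call is set up on the host.
  nvte_destroy_grouped_matmul_config(cfg);
}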
@@ -356,6 +463,70 @@ class MatmulConfigWrapper {
   NVTEMatmulConfig config_ = nullptr;
 };

+/*! \struct GroupedMatmulConfigWrapper
+ *  \brief C++ wrapper for NVTEGroupedMatmulConfig.
+ */
+class GroupedMatmulConfigWrapper {
+ public:
+  GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {}
+
+  GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete;
+  GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete;
+
+  GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} {
+    other.config_ = nullptr;
+  }
+  GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) {
+    if (config_ != nullptr) {
+      nvte_destroy_grouped_matmul_config(config_);
+    }
+    config_ = other.config_;
+    other.config_ = nullptr;
+    return *this;
+  }
+
+  ~GroupedMatmulConfigWrapper() {
+    if (config_ != nullptr) {
+      nvte_destroy_grouped_matmul_config(config_);
+      config_ = nullptr;
+    }
+  }
+
+  /*! \brief Get the underlying NVTEGroupedMatmulConfig.
+   *
+   *  \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper.
+   */
+  operator NVTEGroupedMatmulConfig() const noexcept { return config_; }
+
+  /*! \brief Set average M dimension hint for algorithm selection. */
+  void set_avg_m(int64_t avg_m) {
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m,
+                                             sizeof(int64_t));
+  }
+
+  /*! \brief Set average N dimension hint for algorithm selection. */
+  void set_avg_n(int64_t avg_n) {
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n,
+                                             sizeof(int64_t));
+  }
+
+  /*! \brief Set average K dimension hint for algorithm selection. */
+  void set_avg_k(int64_t avg_k) {
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k,
+                                             sizeof(int64_t));
+  }
+
+  /*! \brief Set number of streaming multiprocessors to use. */
+  void set_sm_count(int sm_count) {
+    nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count,
+                                             sizeof(int));
+  }
+
+ private:
+  /*! \brief Wrapped NVTEGroupedMatmulConfig. */
+  NVTEGroupedMatmulConfig config_ = nullptr;
+};
+
 }  // namespace transformer_engine

 #endif  // __cplusplus
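The same dispatch written against the RAII wrapper just added; a sketch only, with the same caller-provided handles as in the earlier example:

void run_grouped_gemm_with_wrapper(NVTEGroupedTensor A, NVTEGroupedTensor B, NVTEGroupedTensor D,
                                   NVTETensor alpha, NVTETensor beta, NVTETensor ws_setup,
                                   NVTETensor ws_cublas, cudaStream_t stream) {
  transformer_engine::GroupedMatmulConfigWrapper cfg;
  cfg.set_avg_m(1024);  // illustrative hints
  cfg.set_avg_n(1024);
  cfg.set_sm_count(64);
  // The wrapper converts implicitly to NVTEGroupedMatmulConfig and destroys it in its destructor.
  nvte_grouped_gemm(A, 0, B, 0, nullptr, D, alpha, beta, ws_setup, ws_cublas, cfg, stream);
}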
transformer_engine/common/include/transformer_engine/hadamard_transform.h

@@ -86,6 +86,24 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp
                                         int random_sign_mask, int random_sign_mask_t,
                                         cudaStream_t stream);

+/*! \brief Grouped-tensor amax with Hadamard transform (graph safe, device-managed grouping).
+ *
+ *  This function is experimental and the API is not stable.
+ *
+ *  This API assumes that the split info (grouping of tensors) is on device and unknown to the host;
+ *  therefore, this is a graph safe API and the grouped-tensor argument is passed as a single device structure.
+ *
+ *  \param[in]      input               NVTEGroupedTensor representing grouped input tensors.
+ *  \param[in,out]  output              NVTEGroupedTensor for output amax (row/col). Only the row-wise and
+ *                                      column-wise amaxes are updated.
+ *  \param[in]      random_sign_mask    16-bit sign mask for RHT.
+ *  \param[in]      random_sign_mask_t  16-bit sign mask for transposed RHT.
+ *  \param[in]      stream              CUDA stream used for the operation.
+ */
+void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input,
+                                                   NVTEGroupedTensor output, int random_sign_mask,
+                                                   int random_sign_mask_t, cudaStream_t stream);
+
 /*!
  * \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
  *

@@ -124,6 +142,22 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
                                                const NVTEQuantizationConfig quant_config,
                                                NVTETensor quant_workspace, cudaStream_t stream);

+/*!
+ * \brief Perform the grouped-tensor Hadamard transform cast fusion operation in graph-safe mode.
+ *
+ * This function is experimental and the API is not stable. Group_ prefix means contiguous input concatenated.
+ *
+ * \param[in]      input            NVTEGroupedTensor representing grouped input tensors.
+ * \param[in,out]  output           NVTEGroupedTensor for output (row/column-wise quantized results).
+ * \param[in]      hadamard_matrix  Hadamard matrix to use for transformation.
+ * \param[in]      quant_config     Quantization configuration.
+ * \param[in]      quant_workspace  Workspace buffer. Must be at least 4 bytes.
+ * \param[in]      stream           CUDA stream used for the operation.
+ */
+void nvte_group_hadamard_transform_cast_fusion_graph_safe(
+    const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix,
+    const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
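Because the grouping/split info lives on the device, these graph-safe variants can be captured into a CUDA graph and replayed without host-side re-grouping; a sketch of that pattern (the NVTEGroupedTensor handles are built elsewhere, the sign masks are illustrative values, and CUDA 12-style graph APIs are assumed):

void capture_group_amax_rht(NVTEGroupedTensor in, NVTEGroupedTensor out, cudaStream_t stream) {
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  nvte_group_hadamard_transform_amax_graph_safe(in, out, /*random_sign_mask=*/0x5A5A,
                                                /*random_sign_mask_t=*/0xA5A5, stream);
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, 0);
  cudaGraphLaunch(graph_exec, stream);  // replay as many times as needed
}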
transformer_engine/common/include/transformer_engine/multi_tensor.h

@@ -296,6 +296,17 @@ void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **
 void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections,
                      size_t num_tensors, cudaStream_t stream);

+/*! \brief Grouped-tensor amax without doing hadamard transform.
+ *
+ *  This function is experimental and the API is not stable.
+ *
+ *  \param[in]      input   NVTEGroupedTensor Input tensor.
+ *  \param[in,out]  output  NVTEGroupedTensor Output tensor.
+ *  \param[in]      stream  CUDA stream used for the operation.
+ */
+void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output,
+                                cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
transformer_engine/common/recipe/__init__.py

@@ -88,33 +88,40 @@ class Recipe:
     Base recipe class.
     """

-    def nvfp4(self):
+    @classmethod
+    def nvfp4(cls):
         """Whether the given recipe is NVFP4 1D block scaling."""
-        return isinstance(self, NVFP4BlockScaling)
+        return issubclass(cls, NVFP4BlockScaling)

-    def mxfp8(self):
+    @classmethod
+    def mxfp8(cls):
         """Whether the given recipe is MXFP8 block scaling."""
-        return isinstance(self, MXFP8BlockScaling)
+        return issubclass(cls, MXFP8BlockScaling)

-    def delayed(self):
+    @classmethod
+    def delayed(cls):
         """Whether the given recipe is delayed scaling."""
-        return isinstance(self, DelayedScaling)
+        return issubclass(cls, DelayedScaling)

-    def float8_current_scaling(self):
+    @classmethod
+    def float8_current_scaling(cls):
         """Whether the given recipe is (per-tensor) current scaling."""
-        return isinstance(self, Float8CurrentScaling)
+        return issubclass(cls, Float8CurrentScaling)

-    def float8_per_tensor_scaling(self):
+    @classmethod
+    def float8_per_tensor_scaling(cls):
         """Whether the given recipe is per-tensor scaling."""
-        return isinstance(self, (DelayedScaling, Float8CurrentScaling))
+        return issubclass(cls, (DelayedScaling, Float8CurrentScaling))

-    def float8_block_scaling(self):
+    @classmethod
+    def float8_block_scaling(cls):
         """Whether the given recipe is float8 blockwise scaling."""
-        return isinstance(self, Float8BlockScaling)
+        return issubclass(cls, Float8BlockScaling)

-    def custom(self):
+    @classmethod
+    def custom(cls):
         """Whether the given recipe is custom."""
-        return isinstance(self, CustomRecipe)
+        return issubclass(cls, CustomRecipe)


 @dataclass()
transformer_engine/common/transformer_engine.cpp

@@ -458,9 +458,9 @@ class TensorAllocator {
   }

   void Free(NVTETensor t) {
-    std::lock_guard<std::mutex> lock(mutex);
     uintptr_t index = reinterpret_cast<uintptr_t>(t);
     if (index == 0) return;
+    std::lock_guard<std::mutex> lock(mutex);
     NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
     free_list.push_back(index);
     // Clean up

@@ -568,9 +568,9 @@ class GroupedTensorAllocator {
   }

   void Free(NVTEGroupedTensor t) {
-    std::lock_guard<std::mutex> lock(mutex);
     uintptr_t index = reinterpret_cast<uintptr_t>(t);
     if (index == 0) return;
+    std::lock_guard<std::mutex> lock(mutex);
     NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
     free_list.push_back(index);
     // Clean up
transformer_engine/common/triton/permutation.py

@@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel(
         split_sizes_ptr + load_split_offset,
         mask=load_split_offset < num_splits,
         other=0,
     ).to(tl.int32)
     input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
+    # Compute total valid tokens and skip phantom/padding tokens.
+    # When the input buffer is larger than sum(split_sizes), tokens beyond
+    # the valid range should map to themselves (identity mapping) to avoid
+    # corrupting valid output positions.
+    total_valid_tokens = tl.sum(input_split_sizes)
     input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
     input_chunk_idx = tl.sum(input_split_sizes_mask)
     input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)

@@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel(
     ).to(tl.int32)
     output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
     dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
+    # For tokens beyond the valid range (pid >= total_valid_tokens),
+    # use identity mapping to avoid corrupting valid data
+    dst_row = tl.where(pid < total_valid_tokens, dst_row, pid)
     tl.store(dst_rows_ptr + pid, dst_row)
transformer_engine/common/util/cuda_runtime.cpp

@@ -6,6 +6,8 @@
 #include "../util/cuda_runtime.h"

+#include <cublasLt.h>
+
 #include <filesystem>
 #include <mutex>

@@ -232,6 +234,12 @@ int cudart_version() {
   return version;
 }

+size_t cublas_version() {
+  // Cache version to avoid cuBLAS logging overhead
+  static size_t version = cublasLtGetVersion();
+  return version;
+}
+
 }  // namespace cuda

 }  // namespace transformer_engine
transformer_engine/common/util/cuda_runtime.h

@@ -85,6 +85,12 @@ const std::string &include_directory(bool required = false);
  */
 int cudart_version();

+/* \brief cuBLAS version number at run-time
+ *
+ * Versions may differ between compile-time and run-time.
+ */
+size_t cublas_version();
+
 }  // namespace cuda

 }  // namespace transformer_engine
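Since the grouped-GEMM path added in gemm.h requires cuBLAS 13.2+, a runtime check against this helper is the natural gate; a minimal sketch, assuming the returned value follows the usual cuBLAS encoding (major*10000 + minor*100 + patch), which is an assumption and not stated in this commit:

// Hypothetical gate for the nvte_grouped_gemm path; 130200 would correspond to cuBLAS 13.2.0
// under the assumed major*10000 + minor*100 + patch encoding.
bool grouped_gemm_supported() {
  return transformer_engine::cuda::cublas_version() >= 130200;
}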
transformer_engine/common/util/logging.h

@@ -145,7 +145,7 @@
   do {                                                                   \
     const cublasMpStatus_t status = (expr);                              \
     if (status != CUBLASMP_STATUS_SUCCESS) {                             \
-      NVTE_ERROR("cuBLASMp Error: ", std::to_string(status));            \
+      NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status));   \
     }                                                                    \
   } while (false)
transformer_engine/common/util/ptx.cuh

@@ -172,6 +172,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }

+__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(
+    uint64_t *mbar, const uint32_t tx_count) {
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
+  asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr),
+               "r"(tx_count));
+#else
+  NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+.");
+#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+}
+
 __device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
 #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   asm volatile("fence.mbarrier_init.release.cluster;");
@@ -251,13 +263,86 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 }

+__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar,
+                                                                            uint32_t phase_parity) {
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
+  asm volatile(
+      "{ \n\t"
+      ".reg .b64 r1; \n\t"
+      ".reg .pred waitComplete; \n\t"  // predicate representing if barrier condition is met
+      "WAIT: \n\t"                     // loop around barrier wait
+      "mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t"
+      "@waitComplete bra DONE; \n\t"   // mbarrier conditions are met
+      "bra WAIT; \n\t"                 // just a time-out, try again
+      "DONE: \n\t"
+      "} \n\t"
+      :
+      : "r"(mbar_ptr), "r"(phase_parity)
+      : "memory");
+#else
+  NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+.");
+#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+}
+
+__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) {
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
+  if constexpr (is_blackwell) {
+    uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
+    uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
+    asm volatile(
+        "clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::"
+        "all.b128 "
+        "[%0], [%1];" ::"r"(workID_response),
+        "r"(mbar_ptr));
+  } else {
+    NVTE_DEVICE_ERROR(
+        "Cluster Launch Control PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+}
+
+__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr,
+                                                        int32_t &ctaid_X, int32_t &ctaid_Y) {
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
+  if constexpr (is_blackwell) {
+    uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
+    asm volatile(
+        "{ \n\t"
+        ".reg .s32 x_ctaid; \n\t"
+        ".reg .s32 y_ctaid; \n\t"
+        "mov .s32 x_ctaid, -1; \n\t"
+        "mov .s32 y_ctaid, -1; \n\t"
+        ".reg.b128 try_cancel_response; \n\t"
+        "ld.shared.b128 try_cancel_response, [%2]; \n\t"
+        ".reg .pred P1; \n\t"
+        "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response; \n\t"
+        "@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, "
+        "_}, try_cancel_response; \n\t"
+        "mov .s32 %0, x_ctaid; \n\t"
+        "mov .s32 %1, y_ctaid; \n\t"
+        "} \n\t"
+        : "=r"(ctaid_X), "=r"(ctaid_Y)
+        : "r"(workID_response)
+        : "memory");
+  } else {
+    NVTE_DEVICE_ERROR(
+        "Cluster Launch Control PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+}
+
 constexpr uint32_t FP32_MANTISSA_BITS = 23;
 constexpr uint32_t FP32_EXPONENT_BIAS = 127;

 __device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
-  return (biased_exp == 0) ? 1 : __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);  // 127 - (biased_exp - 127)
+  // Handle the special case of NaN.
+  if (biased_exp == 255) return __int_as_float(0x7fffffff);
+  // Handle the special case where the unbiased exponent is 127, so the reciprocal is 2^-127 which
+  // needs the first bit of the mantissa to be 1, which can't be obtained by shifting
+  // `FP32_MANTISSA_BITS` bits to the left.
+  if (biased_exp == 254) return __int_as_float(0x00400000);
+  // Fast calculation when the unbiased exp is in [-126, 126], and only the exponent part is used
+  // to express the reciprocal.
+  return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);
 }

 __device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
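For reference, the bit manipulation in the rewritten exp2f_rcp follows directly from the FP32 layout; a short derivation (editorial note, not part of the commit):

\[
\text{An E8M0 biased exponent } e \text{ encodes the scale } 2^{\,e-127},\ \text{so its reciprocal is } 2^{-(e-127)} = 2^{\,127-e}.
\]
\[
\text{A normal FP32 with exponent field } E \text{ and zero mantissa equals } 2^{\,E-127};\ \text{choosing } E = 254-e \text{ gives } 2^{(254-e)-127} = 2^{\,127-e}.
\]
\[
\text{Special cases: } e=255 \text{ is NaN (returned as 0x7fffffff)};\quad e=254 \Rightarrow 2^{-127}, \text{ a subnormal, encoded as } \texttt{0x00400000} = 0.5\cdot 2^{-126} = 2^{-127}.
\]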
@@ -671,6 +756,179 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
     return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
   }
 }

+template <typename SCALING_COEFFICIENT_TYPE>
+__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest(
+    const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) {
+  uint32_t out_8x = 0;
+  constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
+  if constexpr (is_blackwell) {
+    if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
+      asm volatile(
+          "{ \n"
+          ".reg.f32 zero; \n\t"
+          "mov.b32 zero, 0; \n\t"
+          ".reg.b16 scaling_coeff; \n\t"
+          "mov.b16 scaling_coeff, %3; \n\t"
+          ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t"
+          "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t"
+          "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t"
+          ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
+          "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t"
+          ".reg.b8 f0, f1, f2, f3; \n\t"
+          // Elements reordered to match e2m1x4 packing order (v1,v0)
+          "cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0; \n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2; \n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4; \n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6; \n\t"
+          "mov.b32 %0, {f0, f1, f2, f3}; \n"
+          "}"
+          : "=r"(out_8x)
+          : "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)));
+    } else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
+      asm volatile(
+          "{ \n"
+          ".reg.b64 scaling_coeff_2x; \n\t"
+          "mov.b64 scaling_coeff_2x, {%3, %3}; \n\t"
+          ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t"
+          "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t"
+          "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t"
+          ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
+          "cvt.f32.bf16 v0, v0_bf16; \n\t"
+          "cvt.f32.bf16 v1, v1_bf16; \n\t"
+          "cvt.f32.bf16 v2, v2_bf16; \n\t"
+          "cvt.f32.bf16 v3, v3_bf16; \n\t"
+          "cvt.f32.bf16 v4, v4_bf16; \n\t"
+          "cvt.f32.bf16 v5, v5_bf16; \n\t"
+          "cvt.f32.bf16 v6, v6_bf16; \n\t"
+          "cvt.f32.bf16 v7, v7_bf16; \n\t"
+          ".reg.b64 v01, v23, v45, v67; \n\t"
+          "mov.b64 v01, {v0, v1}; \n\t"
+          "mov.b64 v23, {v2, v3}; \n\t"
+          "mov.b64 v45, {v4, v5}; \n\t"
+          "mov.b64 v67, {v6, v7}; \n\t"
+          "mul.f32x2 v01, v01, scaling_coeff_2x; \n\t"
+          "mul.f32x2 v23, v23, scaling_coeff_2x; \n\t"
+          "mul.f32x2 v45, v45, scaling_coeff_2x; \n\t"
+          "mul.f32x2 v67, v67, scaling_coeff_2x; \n\t"
+          // Elements reordered to match the packing order (v1,v0)
+          "mov.b64 {v1, v0}, v01; \n\t"
+          "mov.b64 {v3, v2}, v23; \n\t"
+          "mov.b64 {v5, v4}, v45; \n\t"
+          "mov.b64 {v7, v6}, v67; \n\t"
+          ".reg.b8 f0, f1, f2, f3; \n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1; \n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3; \n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f2, v4, v5; \n\t"
+          "cvt.rn.satfinite.e2m1x2.f32 f3, v6, v7; \n\t"
+          "mov.b32 %0, {f0, f1, f2, f3}; \n\t"
+          "}"
+          : "=r"(out_8x)
+          : "l"(in03), "l"(in47), "f"(scaling_coefficient));
+    } else {
+      NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
+    }
+  } else {
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+  return out_8x;
+}
+
+template <typename SCALING_COEFFICIENT_TYPE>
+__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding(
+    const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient,
+    const uint32_t rbits03, const uint32_t rbits47) {
+  uint32_t out_8x = 0;
+  constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
+  if constexpr (has_rs) {
+    if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
+      asm volatile(
+          "{ \n"
+          ".reg.f32 zero; \n\t"
+          "mov.b32 zero, 0; \n\t"
+          ".reg.b16 scaling_coeff; \n\t"
+          "mov.b16 scaling_coeff, %3; \n\t"
+          ".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t"
+          "mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t"
+          "mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t"
+          ".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
+          "fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t"
+          "fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t"
+          ".reg.b16 b03, b47; \n\t"
+          // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
+          "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t"
+          "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t"
+          "mov.b32 %0, {b03, b47}; \n"
+          "}"
+          : "=r"(out_8x)
+          : "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)),
+            "r"(rbits03), "r"(rbits47));
+    } else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
+      asm volatile(
+          "{ \n"
+          ".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t"
+          "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t"
+          "mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t"
+          ".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
+          "cvt.f32.bf16 v0, v0_bf16; \n\t"
+          "cvt.f32.bf16 v1, v1_bf16; \n\t"
+          "cvt.f32.bf16 v2, v2_bf16; \n\t"
+          "cvt.f32.bf16 v3, v3_bf16; \n\t"
+          "cvt.f32.bf16 v4, v4_bf16; \n\t"
+          "cvt.f32.bf16 v5, v5_bf16; \n\t"
+          "cvt.f32.bf16 v6, v6_bf16; \n\t"
+          "cvt.f32.bf16 v7, v7_bf16; \n\t"
+          "mul.f32 v0, v0, %3; \n\t"
+          "mul.f32 v1, v1, %3; \n\t"
+          "mul.f32 v2, v2, %3; \n\t"
+          "mul.f32 v3, v3, %3; \n\t"
+          "mul.f32 v4, v4, %3; \n\t"
+          "mul.f32 v5, v5, %3; \n\t"
+          "mul.f32 v6, v6, %3; \n\t"
+          "mul.f32 v7, v7, %3; \n\t"
+          ".reg.b16 b03, b47; \n\t"
+          // Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
+          "cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t"
+          "cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t"
+          "mov.b32 %0, {b03, b47}; \n"
+          "}"
+          : "=r"(out_8x)
+          : "l"(in03), "l"(in47), "f"(scaling_coefficient), "r"(rbits03), "r"(rbits47));
+    } else {
+      NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
+    }
+  } else {
+    NVTE_DEVICE_ERROR(
+        "FP4 cvt PTX instructions are architecture-specific. "
+        "Try recompiling with sm_XXXa instead of sm_XXX.");
+  }
+  return out_8x;
+}
+
 #endif  // FP4_TYPE_SUPPORTED

 // SIMD like "Fused" cast + multiplication (x2)
@@ -1521,6 +1779,59 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) {
       : "r"(in2[0]), "r"(in2[1]));
   return out;
 }

+// Loads single BF16/FP16 element from shared memory state space
+__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) {
+  const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
+  bf16 dst;
+  asm volatile("ld.shared.b16 %0, [%1];"
+               : "=h"(reinterpret_cast<uint16_t &>(dst))
+               : "r"(src_smem_ptr));
+  return dst;
+}
+
+// Loads pair of BF16/FP16 values from shared memory state space
+__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) {
+  const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
+  bf16x2 dst;
+  asm volatile("ld.shared.b32 %0, [%1];"
+               : "=r"(reinterpret_cast<uint32_t &>(dst))
+               : "r"(src_smem_ptr));
+  return dst;
+}
+
+// Loads 8x BF16 values from shared memory state space
+__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) {
+  uint64_t elts03, elts47;
+  const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
+  asm volatile(
+      "{ \n\t"
+      ".reg.b128 xy; \n\t"
+      "ld.shared.b128 xy, [%2]; \n\t"
+      "mov.b128 {%0, %1}, xy; \n"
+      "} \n"
+      : "=l"(elts03), "=l"(elts47)
+      : "r"(src_smem_ptr));
+  return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03);
+}
+
+#if FP4_TYPE_SUPPORTED
+// Vectorized store of x8 FP4 elements into shared memory state space
+__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem,
+                                              uint32_t fp4_pack_x8) {
+  const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
+  asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8));
+}
+#endif
+
+// Vectorized store of x16 FP4 elements into shared memory state space
+#if FP4_TYPE_SUPPORTED
+__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem,
+                                              uint64_t fp4_pack_x16) {
+  const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
+  asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16));
+}
+#endif
+
 #endif

 }  // namespace ptx
transformer_engine/debug/features/disable_fp8_gemm.py

@@ -2,17 +2,28 @@
 #
 # See LICENSE for license information.

-"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect
-"""
+"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect
+
+DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
+New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
+"""
+
+import warnings

-from nvdlfw_inspect.registry import Registry, api_method
-from transformer_engine.debug.features.api import TEConfigAPIMapper
+from nvdlfw_inspect.registry import Registry
+from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM


 @Registry.register_feature(namespace="transformer_engine")
-class DisableFP8GEMM(TEConfigAPIMapper):
+class DisableFP8GEMM(DisableQuantizationGEMM):
     """
     GEMM operations are executed in higher precision, even when FP8 autocast is enabled.

+    .. deprecated::
+        Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
+        backward compatibility only. DisableQuantizationGEMM works with all quantization
+        formats (FP8, NVFP4, etc.), not just FP8.

     Parameters
     ----------

@@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper):
           layers:
             layer_types: [fc1]
           transformer_engine:
-            DisableFP8GEMM:
+            DisableFP8GEMM:  # Deprecated: use DisableQuantizationGEMM
               enabled: True
               gemms: [dgrad, wgrad]
     """

-    @api_method
-    def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int):
-        # pylint: disable=unused-argument
-        """API call responsible for choice between high-precision and FP8 GEMM execution."""
-        for key in config:
-            if key != "gemm":
-                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-        # If this feature is invoked, then FP8 GEMM is disabled.
-        # If not, then default behaviour in TransformerEngineAPI
-        # is that fp8_gemm() API call returns True.
-        return False, iteration + 1
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "DisableFP8GEMM is deprecated. "
+            "Use DisableQuantizationGEMM instead, which works with all quantization "
+            "formats (FP8, NVFP4, etc.).",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
transformer_engine/debug/features/disable_fp8_layer.py

@@ -2,17 +2,27 @@
 #
 # See LICENSE for license information.

-"""DisableFP8Layer Feature support for nvidia-dlframework-inspect
-"""
+"""DisableFP8Layer Feature support for nvidia-dlframework-inspect
+
+DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
+New code should use DisableQuantizationLayer instead, which works with all quantization formats.
+"""
+
+import warnings

-import nvdlfw_inspect.api as debug_api
-from nvdlfw_inspect.registry import Registry, api_method
+from nvdlfw_inspect.registry import Registry
+from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer


 @Registry.register_feature(namespace="transformer_engine")
-class DisableFP8Layer:
+class DisableFP8Layer(DisableQuantizationLayer):
     """
     Disables all FP8 GEMMs in the layer.

+    .. deprecated::
+        Use :class:`DisableQuantizationLayer` instead. This class is maintained for
+        backward compatibility only. DisableQuantizationLayer works with all quantization
+        formats (FP8, NVFP4, etc.), not just FP8.

     Example
     -------

@@ -23,33 +33,16 @@ class DisableFP8Layer:
           layers:
             layer_types: [fc1]
           transformer_engine:
-            DisableFP8Layer:
+            DisableFP8Layer:  # Deprecated: use DisableQuantizationLayer
               enabled: True
     """

-    @api_method
-    def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int):
-        # pylint: disable=unused-argument
-        """API call responsible for selecting between high-precision and FP8 GEMM execution."""
-        for key in config:
-            if key not in ["enabled", "gemm"]:
-                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-        # If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
-        debug_api.log_message("FP8 Disabled", layer_name)
-        # If this feature is invoked, then FP8 GEMM is disabled.
-        # If not, then default behavior in TransformerEngineAPI
-        # is that fp8_gemm() API call returns True.
-        return False, iteration + 1
-
-    def parse_config_and_api(self, config, **_kwargs):
-        """Determines whether to run the API
-
-        DisableFP8Layer is the only feature provided by the Transformer Engine
-        which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for
-        parsing gemms and tensors fields from the config, which are not needed for this feature.
-
-        Explanation of the parse_config_and_api can be found in the
-        nvidia-dlframework-inspect documentation.
-        """
-        return config["enabled"], None
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "DisableFP8Layer is deprecated. "
+            "Use DisableQuantizationLayer instead, which works with all quantization "
+            "formats (FP8, NVFP4, etc.).",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
transformer_engine/debug/features/disable_quantization_gemm.py  (new file, 0 → 100644)

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""

from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper


@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationGEMM(TEConfigAPIMapper):
    """
    Disables specific GEMM operations from using quantization, forcing high-precision execution.
    Works with any quantization format (FP8, NVFP4, etc.).

    Parameters
    ----------
    gemms: List[str]
        list of gemms to disable quantization for
            - fprop
            - dgrad
            - wgrad

    Example
    -------
    .. code-block:: yaml

        example_disable_quantization_gemm:
          enabled: True
          layers:
            layer_types: [fc1]
          transformer_engine:
            DisableQuantizationGEMM:
              enabled: True
              gemms: [dgrad, wgrad]
    """

    @api_method
    def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int):
        # pylint: disable=unused-argument
        """API call responsible for choice between high-precision and quantized GEMM execution.

        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
        but it applies to all quantization formats (FP8, NVFP4, etc.).
        """
        for key in config:
            if key != "gemm":
                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
        # If not, then default behavior in TransformerEngineAPI
        # is that fp8_gemm() API call returns True.
        return False, iteration + 1
transformer_engine/debug/features/disable_quantization_layer.py  (new file, 0 → 100644)

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect"""

import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method


@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationLayer:
    """
    Disables all quantized GEMMs in the layer, forcing high-precision execution.
    Works with any quantization format (FP8, NVFP4, etc.).

    Example
    -------
    .. code-block:: yaml

        example_disable_quantization_layer:
          enabled: True
          layers:
            layer_types: [fc1]
          transformer_engine:
            DisableQuantizationLayer:
              enabled: True
    """

    @api_method
    def fp8_gemm_enabled(self, config, layer_name: str, gemm: str, iteration: int):
        # pylint: disable=unused-argument
        """API call responsible for selecting between high-precision and quantized GEMM execution.

        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
        but it applies to all quantization formats (FP8, NVFP4, etc.).
        """
        for key in config:
            if key not in ["enabled", "gemm"]:
                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
        # If quantized training, disable quantization for the selected layers if this feature is enabled.
        debug_api.log_message("Quantization Disabled", layer_name)
        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
        # If not, then default behavior in TransformerEngineAPI
        # is that fp8_gemm() API call returns True.
        return False, iteration + 1

    def parse_config_and_api(self, config, **_kwargs):
        """Determines whether to run the API.

        DisableQuantizationLayer is the only feature provided by the Transformer Engine
        which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
        parsing gemms and tensors fields from the config, which are not needed for this feature.

        Explanation of the parse_config_and_api can be found in the
        nvidia-dlframework-inspect documentation.
        """
        return config["enabled"], None
transformer_engine/debug/features/log_fp8_tensor_stats.py

@@ -6,15 +6,17 @@
 from typing import Dict, Optional, List, Tuple
+from contextlib import contextmanager
+import warnings

 import torch

 import nvdlfw_inspect.api as debug_api
+import transformer_engine_torch as tex
 from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
 from nvdlfw_inspect.registry import Registry, api_method

 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
-from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
 from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Quantizer,

@@ -22,7 +24,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
+from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
+
+try:
+    from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+
+    _nvfp4_available = True
+except ImportError:
+    _nvfp4_available = False
+    NVFP4Quantizer = None

 ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]

@@ -39,6 +48,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]):
         return "mxfp8"
     if isinstance(quantizer, Float8BlockQuantizer):
         return "fp8_block_scaling"
+    if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer):
+        return "nvfp4"
     raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")

@@ -164,6 +175,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
         if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
             raise ValueError(
                 f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}"
             )
+        # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work)
+        # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer
+        if recipe_from_stat == "nvfp4":
+            raise ValueError(
+                f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats."
+                " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for"
+                " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.,"
+                " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons."
+            )
         if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
             raise ValueError(
                 f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for"

@@ -189,6 +210,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
     def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
         """Returns the recipe name from the stat string."""
+        columnwise_stat = stat.endswith("_columnwise")
         for recipe_name in ALL_RECIPE_NAMES:
             if recipe_name in stat:

@@ -213,7 +235,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
         Yields the aux_dict.
         Needs to clean after usage, because it possibly change the usage of the quantized tensor.
         """
-        fp8_dtype = None
+        fp8_dtype = tex.DType.kFloat8E4M3
         if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
             assert isinstance(
                 quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)

@@ -277,14 +299,26 @@ class LogFp8TensorStats(BaseLogTensorStats):
         API call used to collect the data about the tensor after process_tensor()/quantization.
         """
         assert rowwise_quantized_tensor is columnwise_quantized_tensor
-        assert (
-            quantizer is not None
-        ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."
+        # Skip logging if quantizer is None (layer runs in high precision)
+        if quantizer is None:
+            warnings.warn(
+                f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
+            )
+            return
         quantized_tensor = rowwise_quantized_tensor
-        assert isinstance(
-            quantized_tensor, QuantizedTensor
-        ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
+        # Skip logging if quantized_tensor is not a QuantizedTensor (incompatible precision)
+        if not isinstance(quantized_tensor, QuantizedTensor):
+            warnings.warn(
+                f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
+                f"tensor '{tensor_name}': incompatible precision "
+                f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})."
+            )
+            return

         recipe_name = _get_recipe_name(quantizer)
         for stat in config["stats"]:
transformer_engine/debug/features/log_nvfp4_tensor_stats.py
0 → 100644
View file @ 9df0c4a3
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect"""

from typing import Dict, Optional
from contextlib import contextmanager
import warnings

import torch

import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method

from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage


@Registry.register_feature(namespace="transformer_engine")
class LogNvfp4TensorStats(BaseLogTensorStats):
    """Logs statistics of NVFP4 quantized tensors.

    In distributed runs each rank first computes its local statistics; the values
    are gathered the next time `debug_api.step()` is called. Remember to call
    `debug_api.step()` every training step so the logs are flushed.

    The feature is micro-batch aware: if several forward/backward passes occur
    between successive `debug_api.step()` calls, statistics are accumulated for all
    tensors except weights.

    Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the
    overhead, and if the feature is skipped for a step the additional cost is
    minimal. When no other debug feature is active, the layer runs at normal
    Transformer Engine speed.

    Parameters
    ----------
    stats: List[str]
        List of statistics to collect. Available stats:
            - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data)
            - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements
    tensors/tensors_struct: List[str]
        list of tensors to log
            - activation,
            - gradient,
            - weight,
    freq: Optional[int], default = 1
        frequency of logging stats, stats will be logged every `freq` steps
    start_step: Optional[int], default = None
        start step of logging stats
    end_step: Optional[int], default = None
        end step of logging stats
    start_end_list: Optional[list([int, int])], default = None
        non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step

    Example
    -------
    .. code-block:: yaml

        example_nvfp4_tensor_stat_collection:
            enabled: True
            layers:
                layer_types: [layernorm_linear]
            transformer_engine:
                LogNvfp4TensorStats:
                    enabled: True
                    tensors_struct:
                        - tensor: activation
                          stats: [underflows%, mse]
                          freq: 1
                        - tensor: gradient
                          stats: [underflows%, mse]
                          freq: 5
                    start_step: 0
                    end_step: 80
    """

    def check_if_stat_is_supported(self, stat: str):
        """Returns True if stat is supported, raises ValueError otherwise."""
        supported_stats = [
            "underflows%",
            "mse",
        ]
        if stat not in supported_stats:
            raise ValueError(
                f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
            )
        return True

    def get_stat_with_prefix(self, stat: str) -> str:
        """Add nvfp4_ prefix to stat name for use in stats_computation."""
        return f"nvfp4_{stat}"

    @contextmanager
    def update_aux_dict(
        self,
        aux_dict: Dict,
        quantized_tensor: QuantizedTensor,
        quantizer: Quantizer,  # pylint: disable=unused-argument
        original_tensor: torch.Tensor,
    ):
        """
        Updates the aux_dict with the quantized tensor and additional NVFP4-specific data.
        Yields the aux_dict.
        """
        aux_dict = {
            "nvfp4": quantized_tensor,
            "original_tensor": original_tensor,
        }
        try:
            yield aux_dict
        finally:
            pass

    @api_method
    def inspect_tensor_enabled(
        self, config: Dict, layer_name: str, tensor_name: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call used to determine whether to run inspect_tensor() in the forward."""
        run_current, next_iter = next_enabled_iter(
            config.get("start_step", None),
            config.get("end_step", None),
            config.get("start_end_list", None),
            config.get("freq", 1),
            iteration,
        )
        STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
        return run_current, next_iter

    @api_method
    def inspect_tensor(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        iteration: int,
        tp_group,
        tensor: torch.Tensor,
        rowwise_quantized_tensor: Optional[QuantizedTensor] = None,
        columnwise_quantized_tensor: Optional[QuantizedTensor] = None,
        quantizer: Optional[Quantizer] = None,
    ):
        """
        API call used to collect the data about the tensor after process_tensor()/quantization.
        """
        assert rowwise_quantized_tensor is columnwise_quantized_tensor

        # Skip logging if quantizer is None (layer runs in high precision)
        if quantizer is None:
            warnings.warn(
                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
                f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
            )
            return

        quantized_tensor = rowwise_quantized_tensor

        # Skip logging if not NVFP4 quantizer (incompatible precision)
        if not isinstance(quantizer, NVFP4Quantizer):
            warnings.warn(
                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
                f"tensor '{tensor_name}': incompatible precision "
                f"(expected NVFP4Quantizer, got {type(quantizer).__name__})."
            )
            return

        # Skip logging if quantized tensor is not NVFP4TensorStorage (incompatible precision)
        if not isinstance(quantized_tensor, NVFP4TensorStorage):
            warnings.warn(
                f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
                f"tensor '{tensor_name}': incompatible precision "
                f"(expected NVFP4TensorStorage, got {type(quantized_tensor).__name__})."
            )
            return

        for stat in config["stats"]:
            self.check_if_stat_is_supported(stat)

        start_step = config.get("start_step", None)
        end_step = config.get("end_step", None)
        start_end_list = config.get("start_end_list", None)
        if start_end_list is not None:
            start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
        options = (
            start_step,
            end_step,
            start_end_list,
            "nvfp4",
        )

        skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
            tensor_name, tp_group
        )

        # Add nvfp4_ prefix to all stats for internal use
        prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]]

        STATS_BUFFERS.try_add_buffer(
            layer_name=layer_name,
            tensor_name=tensor_name,
            stats=prefixed_stats,
            options=options,
            reduction_group=reduction_group,
            reduce_within_microbatch=reduce_within_microbatch,
        )

        with self.update_aux_dict(
            aux_dict={},
            quantized_tensor=quantized_tensor,
            quantizer=quantizer,
            original_tensor=tensor,
        ) as aux_dict:
            STATS_BUFFERS.feed(
                layer_name,
                tensor_name,
                options,
                tensor,
                iteration,
                skip_reduction,
                aux_dict=aux_dict,
            )

        debug_api.log_message(
            f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}",
            layer_name,
            extra_cachable_args=(tensor_name,),
        )
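A minimal usage sketch for the new feature follows. It assumes the nvdlfw_inspect initialization entry point (debug_api.initialize(config_file=..., feature_dirs=...)) and a placeholder training loop; the YAML content mirrors the docstring example above, and only debug_api.step() is explicitly required by it.

# Hedged sketch: wiring LogNvfp4TensorStats into a training script.
# The initialize() signature and feature_dirs path are assumptions;
# debug_api.step() is the call the docstring above asks for once per step.
import pathlib
import nvdlfw_inspect.api as debug_api

config_yaml = """
nvfp4_stats:
  enabled: True
  layers:
    layer_types: [layernorm_linear]
  transformer_engine:
    LogNvfp4TensorStats:
      enabled: True
      tensors: [activation, gradient]
      stats: [underflows%, mse]
      freq: 1
"""
pathlib.Path("debug_config.yaml").write_text(config_yaml)

debug_api.initialize(
    config_file="debug_config.yaml",
    feature_dirs=["transformer_engine/debug/features"],  # assumed feature location
)

num_steps = 10            # placeholder
def train_step():         # placeholder for forward/backward of NVFP4-quantized TE layers
    pass

for _ in range(num_steps):
    train_step()
    debug_api.step()      # gathers per-rank NVFP4 stats and flushes the logs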
transformer_engine/debug/features/utils/stats_computation.py
View file @ 9df0c4a3
...
...
@@ -443,3 +443,65 @@ for _columnwise in [True, False]:
        add_underflows_stats(_recipe_name, _columnwise)
        add_scale_inv_stats(_recipe_name, _columnwise)
        add_mse_stats(_recipe_name, _columnwise)


# NVFP4-specific statistics
def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor:
    """Count the number of non-zero elements in the FP4 data.

    FP4 data is stored as 2 4-bit values per byte (uint8).
    We need to unpack and count non-zeros.
    """
    # Each byte contains two FP4 values
    # Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0)
    zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8)
    # Extract first and second nibbles
    first_nibble = fp4_data % 16
    second_nibble = fp4_data // 16
    # Count zeros
    first_zeros = torch.isin(first_nibble, zero_vals).sum()
    second_zeros = torch.isin(second_nibble, zero_vals).sum()
    total_elements = fp4_data.numel() * 2
    return total_elements - first_zeros - second_zeros


def add_nvfp4_underflows_stats():
    """Register underflow stats for NVFP4.

    Computes underflows by counting zeros in packed FP4 data vs original tensor.
    """
    stat_num = "nvfp4_underflows_num"
    stat_pct = "nvfp4_underflows%"
    stats_to_num[stat_num] = len(stats_to_num)
    stats_to_num[stat_pct] = len(stats_to_num)

    # Count non-zeros in original vs FP4 packed data
    STATS[stat_num] = (
        lambda x, aux_dict: x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data),
        lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
    )
    STATS[stat_pct] = (
        lambda x, aux_dict: (
            x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data)
        )
        / aux_dict["nvfp4"].numel()
        * 100,
        lambda buffers, _sn_num=stat_num: 100 * sum(_get(buffers, _sn_num)) / sum(_get(buffers, "numel")),
    )
    DEPENDENCIES[stat_num] = {stat_num}
    DEPENDENCIES[stat_pct] = {stat_num, "numel"}


# Register NVFP4 stats
add_nvfp4_underflows_stats()
add_mse_stats("nvfp4")  # Reuse existing MSE function
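A small standalone check of the nibble-unpacking idea used by count_nonzero_nvfp4 above; this sketch re-implements the logic on a hand-packed uint8 tensor rather than importing TE internals, and the byte values are made up for illustration.

# Two FP4 E2M1 codes per uint8 byte; codes 0 and 8 decode to +0.0 / -0.0.
import torch

packed = torch.tensor([0x01, 0x80, 0x23, 0x00], dtype=torch.uint8)  # 8 packed FP4 values
zero_vals = torch.tensor([0, 8], dtype=torch.uint8)
low = packed % 16    # first code in each byte
high = packed // 16  # second code in each byte
zeros = torch.isin(low, zero_vals).sum() + torch.isin(high, zero_vals).sum()
nonzeros = packed.numel() * 2 - zeros
print(int(nonzeros))  # 3 -> only the codes 0x1, 0x3, and 0x2 are non-zero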
transformer_engine/debug/pytorch/debug_quantization.py
View file @ 9df0c4a3
...
...
@@ -36,7 +36,7 @@ _tensor_to_gemm_names_map = {
}
API_CALL_MODIFY = "modify_tensor()"
-STANDARD_FP8_QUANTIZE = "FP8 Quantize"
+STANDARD_QUANTIZE = "Quantize"
HIGH_PRECISION = "High Precision"
...
...
@@ -88,7 +88,7 @@ class DebugQuantizer(Quantizer):
        # inspect_tensor*_enabled are bool fields,
        # indicating whether some feature will need to run inspect_tensor_* calls.
        #
-       # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
+       # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION]
        # determining what will happen when the quantizer is used for that tensor.
        self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
        if self.output_tensor:
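To make the three plan values concrete, here is an illustrative sketch (not TE's actual dispatch code) of what each plan means for a tensor; apply_plan, modify_fn, and parent_quantizer are stand-ins for the machinery in this class.

# Illustrative only: how a per-tensor plan string selects the tensor's path.
API_CALL_MODIFY = "modify_tensor()"
STANDARD_QUANTIZE = "Quantize"
HIGH_PRECISION = "High Precision"

def apply_plan(plan, tensor, parent_quantizer=None, modify_fn=None):
    if plan == API_CALL_MODIFY:
        return modify_fn(tensor)         # a debug feature rewrites the tensor via modify_tensor()
    if plan == STANDARD_QUANTIZE:
        return parent_quantizer(tensor)  # delegate to the recipe quantizer (FP8/MXFP8/NVFP4, ...)
    return tensor                        # HIGH_PRECISION: keep the unquantized tensor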
...
...
@@ -170,7 +170,7 @@ class DebugQuantizer(Quantizer):
    def get_tensors_plan(self):
        """
        Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
-       API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
+       API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior
        of this quantizer with respect to these tensors.
        """
        import nvdlfw_inspect.api as debug_api
...
...
@@ -191,16 +191,16 @@ class DebugQuantizer(Quantizer):
                rowwise_plan = API_CALL_MODIFY
        else:
            if self.parent_quantizer is not None:
-               fp8_quantize = self.process_enabled_api_call(
-                   debug_api.transformer_engine.fp8_gemm_enabled(
+               quantize_enabled = self.process_enabled_api_call(
+                   debug_api.transformer_engine.fp8_gemm_enabled(  # API name kept for compatibility
                        layer_name=self.layer_name,
                        gemm=self.rowwise_gemm_name,
                        iteration=self.iteration,
                    )
                )
-               if fp8_quantize:
-                   rowwise_plan = STANDARD_FP8_QUANTIZE
+               if quantize_enabled:
+                   rowwise_plan = STANDARD_QUANTIZE
            if rowwise_plan is None:
                rowwise_plan = HIGH_PRECISION
...
...
@@ -218,16 +218,16 @@ class DebugQuantizer(Quantizer):
                columnwise_plan = API_CALL_MODIFY
        else:
            if self.parent_quantizer is not None:
-               fp8_quantize = self.process_enabled_api_call(
-                   debug_api.transformer_engine.fp8_gemm_enabled(
+               quantize_enabled = self.process_enabled_api_call(
+                   debug_api.transformer_engine.fp8_gemm_enabled(  # API name kept for compatibility
                        layer_name=self.layer_name,
                        gemm=self.columnwise_gemm_name,
                        iteration=self.iteration,
                    )
                )
-               if fp8_quantize:
-                   columnwise_plan = STANDARD_FP8_QUANTIZE
+               if quantize_enabled:
+                   columnwise_plan = STANDARD_QUANTIZE
            if columnwise_plan is None:
                columnwise_plan = HIGH_PRECISION
...
...
@@ -278,7 +278,7 @@ class DebugQuantizer(Quantizer):
            del args["quantizer"]
        if (
-           self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
+           self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
            and self.inspect_tensor_postquantize_enabled_rowwise
        ):
            args["tensor"] = rowwise_gemm_tensor
...
...
@@ -286,7 +286,7 @@ class DebugQuantizer(Quantizer):
            debug_api.transformer_engine.inspect_tensor_postquantize(**args)
        if (
-           self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
+           self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
            and self.inspect_tensor_postquantize_enabled_columnwise
        ):
            args["tensor"] = columnwise_gemm_tensor
...
...
@@ -317,14 +317,14 @@ class DebugQuantizer(Quantizer):
            self.parent_quantizer.set_usage(rowwise=True)
        rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
-       if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
+       if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
            quantized_tensor = self.parent_quantizer(tensor)
-           # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
+           # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized,
            # one tensor with columnwise=True and rowwise=True is computed
            # and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
-           if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
+           if self.rowwise_tensor_plan == STANDARD_QUANTIZE:
                rowwise_gemm_tensor = quantized_tensor
-           if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
+           if self.columnwise_tensor_plan == STANDARD_QUANTIZE:
                columnwise_gemm_tensor = quantized_tensor
        # 2. modify_tensor() is called, if it is used.
...
...
@@ -379,7 +379,7 @@ class DebugQuantizer(Quantizer):
        """This call is invoked after the gemm to inspect and modify the output tensor."""
        import nvdlfw_inspect.api as debug_api
-       assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
+       assert self.parent_quantizer is None, "Quantized output is not supported for debug=True."
        assert self.output_tensor
        tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
...
...
@@ -420,9 +420,9 @@ class DebugQuantizer(Quantizer):
        ):
            return True
        if self.parent_quantizer is not None:
-           if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
+           if self.rowwise_tensor_plan != STANDARD_QUANTIZE:
                return True
-           if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
+           if self.columnwise_tensor_plan != STANDARD_QUANTIZE:
                return True
        return False
...
...
@@ -446,7 +446,7 @@ class DebugQuantizer(Quantizer):
        if self.parent_quantizer is not None:
            if (
                dst.rowwise_gemm_tensor is not None
-               and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
+               and self.rowwise_tensor_plan == STANDARD_QUANTIZE
            ):
                if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
                    dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
...
...
@@ -455,7 +455,7 @@ class DebugQuantizer(Quantizer):
                updated_rowwise_gemm = True
            if (
                dst.columnwise_gemm_tensor is not None
-               and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
+               and self.columnwise_tensor_plan == STANDARD_QUANTIZE
                and not updated_rowwise_gemm
            ):
                if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
...
...
@@ -540,14 +540,12 @@ class DebugQuantizer(Quantizer):
        """
        Updates the usage of the parent quantizer.
        """
-       rowwise_gemm_quantize = (
-           self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
-       )
+       rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE
        columnwise_gemm_quantize = (
-           self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
+           self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE
        )
-       if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
+       if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
            self.parent_quantizer.set_usage(
                rowwise=rowwise_gemm_quantize,
                columnwise=columnwise_gemm_quantize,
...
...