Commit 9df0c4a3 authored by wenjh

Merge branch 'nv_main'

parents 0d874a4e f122b07d
......@@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
/*! \brief Destroy a comm-gemm context.
*
* \param[in] ctx Context to destroy.
*
* It's the caller's responsibility to synchronize all streams involved before calling this function.
*/
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);
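For illustration, a minimal teardown sketch (hedged: comm_stream and ctx are placeholder handles; every stream that issued work through the context must be synchronized first):

    // Synchronize all streams used with this context, then destroy it.
    cudaStreamSynchronize(comm_stream);  // placeholder: repeat for each stream involved
    nvte_comm_gemm_ctx_destroy(ctx);
    ctx = nullptr;  // avoid reusing the dangling handle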
......
......@@ -208,13 +208,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
* \param[in] cuda_graph          Whether CUDA graph capture is enabled or not.
* \param[in] deterministic Whether determinism is required or not.
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph);
int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic);
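A hedged call sketch for the extended signature (values are illustrative; the NVTE_Softmax_Type enumerator name is assumed, see fused_attn.h for the exact set):

    // Query which fused-attention backend supports a deterministic BF16 training setup.
    NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
        /*is_training=*/true, kNVTEBFloat16, kNVTEBFloat16, NVTE_BSHD_BSHD_BSHD,
        NVTE_NO_BIAS, NVTE_CAUSAL_MASK, NVTE_VANILLA_SOFTMAX,  // softmax enumerator assumed
        /*dropout=*/0.0f, /*num_attn_heads=*/32, /*num_gqa_groups=*/8,
        /*max_seqlen_q=*/4096, /*max_seqlen_kv=*/4096,
        /*head_dim_qk=*/128, /*head_dim_v=*/128,
        /*window_size_left=*/-1, /*window_size_right=*/-1,
        /*return_max_logit=*/false, /*cuda_graph=*/false, /*deterministic=*/true);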
/*! \brief Compute dot product attention with packed QKV input.
*
......@@ -269,22 +270,21 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
[[deprecated(
"nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate "
"Q, K, V tensors instead.")]]
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, bool return_max_logit,
bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream);
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed QKV input.
*
......@@ -332,6 +332,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] cuda_graph              Whether CUDA graph capture is enabled or not.
* \param[in] workspace Workspace tensor.
......@@ -346,8 +347,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream);
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with packed KV input.
*
......@@ -409,6 +410,7 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
......@@ -424,7 +426,8 @@ void nvte_fused_attn_fwd_kvpacked(
size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with packed KV input.
*
......@@ -478,6 +481,7 @@ void nvte_fused_attn_fwd_kvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] cuda_graph              Whether CUDA graph capture is enabled or not.
* \param[in] workspace Workspace tensor.
......@@ -494,8 +498,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
cudaStream_t stream);
int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute dot product attention with separate Q, K and V.
*
......@@ -559,19 +563,23 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] workspace Workspace tensor.
* \param[in] stream CUDA stream used for this operation.
*/
void nvte_fused_attn_fwd(
const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V,
const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv,
const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded,
const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit,
bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream);
const NVTETensor cu_seqlens_q_padded,
const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k,
const NVTETensor page_table_v, const NVTETensor rng_state,
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training,
bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream);
/*! \brief Compute the backward of the dot product attention with separate Q, K and V.
*
......@@ -628,6 +636,7 @@ void nvte_fused_attn_fwd(
* \param[in] softmax_type Attention softmax type.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
* \param[in] bottom_right_diagonal Whether to align sliding window and ALiBi diagonal to the bottom right corner of the softmax matrix.
* \param[in] deterministic Whether to execute with deterministic behaviours.
* \param[in] cuda_graph              Whether CUDA graph capture is enabled or not.
* \param[in] workspace Workspace tensor.
......@@ -643,8 +652,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
bool cuda_graph, NVTETensor workspace, cudaStream_t stream);
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Update the RNG state with the seed and calculated offset.
*
......
......@@ -11,6 +11,8 @@
#ifndef TRANSFORMER_ENGINE_GEMM_H_
#define TRANSFORMER_ENGINE_GEMM_H_
#include <stdint.h>
#include "transformer_engine.h"
#ifdef __cplusplus
......@@ -20,6 +22,9 @@ extern "C" {
/*! \brief Configuration for matrix multiplication. */
typedef void *NVTEMatmulConfig;
/*! \brief Configuration for grouped matrix multiplication. */
typedef void *NVTEGroupedMatmulConfig;
/*! \enum NVTEMatmulConfigAttribute
* \brief Type of option for matrix multiplication.
*/
......@@ -52,6 +57,36 @@ enum NVTEMatmulConfigAttribute {
kNVTEMatmulConfigNumAttributes
};
/*! \enum NVTEGroupedMatmulConfigAttribute
* \brief Type of option for grouped matrix multiplication.
*/
enum NVTEGroupedMatmulConfigAttribute {
/*! Average M dimension hint
*
* Optional hint for average M dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from D's logical shape.
*/
kNVTEGroupedMatmulConfigAvgM = 0,
/*! Average N dimension hint
*
* Optional hint for average N dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from D's logical shape.
*/
kNVTEGroupedMatmulConfigAvgN = 1,
/*! Average K (reduction) dimension hint
*
* Optional hint for average K dimension across all matrices in the group.
* Used by cuBLASLt for algorithm selection heuristics. If not set,
* computed automatically from A's logical shape.
*/
kNVTEGroupedMatmulConfigAvgK = 2,
/*! Number of streaming multiprocessors to use in GEMM kernel. */
kNVTEGroupedMatmulConfigSMCount = 3,
kNVTEGroupedMatmulConfigNumAttributes
};
/*! \brief Create a matrix multiplication configuration. */
NVTEMatmulConfig nvte_create_matmul_config();
......@@ -82,6 +117,38 @@ void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigA
/*! \brief Destroy a matrix multiplication configuration. */
void nvte_destroy_matmul_config(NVTEMatmulConfig config);
/*! \brief Create a grouped matrix multiplication configuration. */
NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config();
/*! \brief Query an option in grouped matrix multiplication configuration.
*
* \param[in] config Grouped matrix multiplication configuration.
* \param[in] attr Option type.
* \param[out] buf Memory address to write option value. Ignored if
* NULL.
* \param[in] size_in_bytes Size of buf.
* \param[out] size_written Number of bytes that have been written to
* buf. If buf is NULL, then the number of
* bytes that would have been written.
*/
void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written);
/*! \brief Set an option in grouped matrix multiplication configuration.
*
* \param[in] config Grouped matrix multiplication configuration.
* \param[in] attr Option type.
* \param[in] buf             Memory address to read option value from.
* \param[in] size_in_bytes Size of buf.
*/
void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes);
/*! \brief Destroy a grouped matrix multiplication configuration. */
void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config);
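Putting the lifecycle together, a hedged sketch using only the functions declared above (the hint value is illustrative):

    NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();
    int64_t avg_k = 512;  // optional heuristic hint for algorithm selection
    nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgK,
                                             &avg_k, sizeof(avg_k));
    // ... pass cfg to nvte_grouped_gemm (declared below) ...
    nvte_destroy_grouped_matmul_config(cfg);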
/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated).
*
* This has been deprecated in favor of nvte_cublas_gemm_v2.
......@@ -229,6 +296,46 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor
bool accumulate, bool use_split_accumulator, int math_sm_count,
cudaStream_t stream);
/* EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. */
/*! \brief Grouped matrix multiplication: D = alpha * op(A) @ op(B) + beta * C
*
* \note Requires cuBLAS 13.2+ (CUDA 13.1+) and Blackwell (SM100) or newer GPU architecture.
* Will error at runtime if compiled with an older cuBLAS version or run on
* a pre-Blackwell GPU.
*
* Performs batched GEMM on a collection of matrices with potentially different shapes.
* All tensors in the group must have compatible dimensions for matrix multiplication.
* Uses NVTEGroupedTensor to efficiently handle collections of tensors with contiguous
* memory layout and shape metadata.
*
* \param[in] A Input grouped tensor A.
* \param[in] transa Whether to transpose A matrices.
* \param[in] B Input grouped tensor B.
* \param[in] transb Whether to transpose B matrices.
* \param[in] C Input grouped tensor C (can be NULL for beta=0).
* \param[out] D Output grouped tensor D.
* \param[in] alpha Scale multipliers for A @ B (NVTETensor with num_tensors elements).
* \param[in] beta Scale multipliers for C (NVTETensor with num_tensors elements).
* \param[in] workspace_setup Workspace tensor for pointer array setup.
* \param[in] workspace_cublas Workspace tensor for cuBLAS operations.
* \param[in] config Additional configuration (can be NULL for defaults).
* \param[in] stream CUDA stream for the operation.
*
* Requirements:
* - cuBLAS 13.2+ (CUDA 13.1+)
* - Blackwell (SM100) or newer GPU architecture
* - A, B, C (if provided), D must have the same num_tensors
* - For each i: D[i] = alpha[i] * op(A[i]) @ op(B[i]) + beta[i] * C[i]
* - Shape compatibility: if transa=false, transb=false:
* - A[i]: (M[i], K[i]), B[i]: (K[i], N[i]), D[i]: (M[i], N[i])
*/
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
cudaStream_t stream);
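A hedged call sketch (grouped_A/grouped_B/grouped_D, alphas, betas, and the two workspace tensors are assumed to be created elsewhere; only the call pattern and the NULL conventions documented above are shown):

    // D[i] = alpha[i] * A[i] @ B[i] for every matrix in the group (beta = 0, so C is NULL).
    nvte_grouped_gemm(grouped_A, /*transa=*/0, grouped_B, /*transb=*/0,
                      /*C=*/NULL, grouped_D, alphas, betas,
                      workspace_setup, workspace_cublas,
                      /*config=*/NULL,  // NULL selects the default configuration
                      stream);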
#ifdef __HIP_PLATFORM_AMD__
void nvte_multi_stream_cublas_batchgemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D,
const NVTETensor* bias, NVTETensor* pre_gelu_out,
......@@ -356,6 +463,70 @@ class MatmulConfigWrapper {
NVTEMatmulConfig config_ = nullptr;
};
/*! \struct GroupedMatmulConfigWrapper
* \brief C++ wrapper for NVTEGroupedMatmulConfig.
*/
class GroupedMatmulConfigWrapper {
public:
GroupedMatmulConfigWrapper() : config_{nvte_create_grouped_matmul_config()} {}
GroupedMatmulConfigWrapper(const GroupedMatmulConfigWrapper &) = delete;
GroupedMatmulConfigWrapper &operator=(const GroupedMatmulConfigWrapper &) = delete;
GroupedMatmulConfigWrapper(GroupedMatmulConfigWrapper &&other) : config_{other.config_} {
other.config_ = nullptr;
}
GroupedMatmulConfigWrapper &operator=(GroupedMatmulConfigWrapper &&other) {
if (config_ != nullptr) {
nvte_destroy_grouped_matmul_config(config_);
}
config_ = other.config_;
other.config_ = nullptr;
return *this;
}
~GroupedMatmulConfigWrapper() {
if (config_ != nullptr) {
nvte_destroy_grouped_matmul_config(config_);
config_ = nullptr;
}
}
/*! \brief Get the underlying NVTEGroupedMatmulConfig.
*
* \return NVTEGroupedMatmulConfig held by this GroupedMatmulConfigWrapper.
*/
operator NVTEGroupedMatmulConfig() const noexcept { return config_; }
/*! \brief Set average M dimension hint for algorithm selection. */
void set_avg_m(int64_t avg_m) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgM, &avg_m,
sizeof(int64_t));
}
/*! \brief Set average N dimension hint for algorithm selection. */
void set_avg_n(int64_t avg_n) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgN, &avg_n,
sizeof(int64_t));
}
/*! \brief Set average K dimension hint for algorithm selection. */
void set_avg_k(int64_t avg_k) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigAvgK, &avg_k,
sizeof(int64_t));
}
/*! \brief Set number of streaming multiprocessors to use. */
void set_sm_count(int sm_count) {
nvte_set_grouped_matmul_config_attribute(config_, kNVTEGroupedMatmulConfigSMCount, &sm_count,
sizeof(int));
}
private:
/*! \brief Wrapped NVTEGroupedMatmulConfig. */
NVTEGroupedMatmulConfig config_ = nullptr;
};
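A short usage sketch for the wrapper (sm_count is a placeholder): the config is created in the constructor, converts implicitly to the C handle, and is freed on scope exit.

    {
      transformer_engine::GroupedMatmulConfigWrapper cfg;
      cfg.set_avg_m(1024);         // optional heuristic hints
      cfg.set_sm_count(sm_count);  // placeholder int
      NVTEGroupedMatmulConfig raw = cfg;  // implicit conversion; wrapper keeps ownership
    }  // nvte_destroy_grouped_matmul_config() runs here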
} // namespace transformer_engine
#endif // __cplusplus
......
......@@ -86,6 +86,24 @@ void nvte_group_hadamard_transform_amax(const NVTETensor input, NVTETensor* outp
int random_sign_mask, int random_sign_mask_t,
cudaStream_t stream);
/*! \brief Grouped-tensor amax with Hadamard transform (graph-safe, device-managed grouping).
*
* This function is experimental and the API is not stable.
*
* This API assumes that the split info (grouping of tensors) is on device and unknown to the host;
* therefore, this is a graph-safe API and the grouped-tensor argument is passed as a single device structure.
*
* \param[in] input NVTEGroupedTensor representing grouped input tensors.
* \param[in,out] output NVTEGroupedTensor for output amax (row/col). Only the row-wise and
* column-wise amaxes are updated.
* \param[in] random_sign_mask 16-bit sign mask for RHT.
* \param[in] random_sign_mask_t 16-bit sign mask for transposed RHT.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input,
NVTEGroupedTensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream);
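Because the grouping metadata stays on device, the call can be captured into a CUDA graph; a hedged capture sketch (all handles are placeholders):

    // Capture the grouped amax computation into a CUDA graph for later replay.
    cudaGraph_t graph;
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    nvte_group_hadamard_transform_amax_graph_safe(grouped_in, grouped_amax,
                                                  sign_mask, sign_mask_t, stream);
    cudaStreamEndCapture(stream, &graph);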
/*!
* \brief Perform the grouped-tensor columnwise Hadamard transform cast fusion operation.
*
......@@ -124,6 +142,22 @@ void nvte_group_hadamard_transform_cast_fusion(const NVTETensor input, NVTETenso
const NVTEQuantizationConfig quant_config,
NVTETensor quant_workspace, cudaStream_t stream);
/*!
* \brief Perform the grouped-tensor Hadamard transform cast fusion operation in graph-safe mode.
*
* This function is experimental and the API is not stable. The group_ prefix indicates that the grouped inputs are concatenated contiguously in memory.
*
* \param[in] input NVTEGroupedTensor representing grouped input tensors.
* \param[in,out] output NVTEGroupedTensor for output (row/column-wise quantized results).
* \param[in] hadamard_matrix Hadamard matrix to use for transformation.
* \param[in] quant_config Quantization configuration.
* \param[in] quant_workspace Workspace buffer. Must be at least 4 bytes.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_hadamard_transform_cast_fusion_graph_safe(
const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -296,6 +296,17 @@ void nvte_multi_tensor_compute_scale_inv_e8m0_cuda(int chunk_size, NVTETensor **
void nvte_group_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections,
size_t num_tensors, cudaStream_t stream);
/*! \brief Grouped-tensor amax without applying a Hadamard transform.
*
* This function is experimental and the API is not stable.
*
* \param[in]     input      Input grouped tensor.
* \param[in,out] output     Output grouped tensor for the computed amaxes.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -88,33 +88,40 @@ class Recipe:
Base recipe class.
"""
def nvfp4(self):
@classmethod
def nvfp4(cls):
"""Whether the given recipe is NVFP4 1D block scaling."""
return isinstance(self, NVFP4BlockScaling)
return issubclass(cls, NVFP4BlockScaling)
def mxfp8(self):
@classmethod
def mxfp8(cls):
"""Whether the given recipe is MXFP8 block scaling."""
return isinstance(self, MXFP8BlockScaling)
return issubclass(cls, MXFP8BlockScaling)
def delayed(self):
@classmethod
def delayed(cls):
"""Whether the given recipe is delayed scaling."""
return isinstance(self, DelayedScaling)
return issubclass(cls, DelayedScaling)
def float8_current_scaling(self):
@classmethod
def float8_current_scaling(cls):
"""Whether the given recipe is (per-tensor) current scaling."""
return isinstance(self, Float8CurrentScaling)
return issubclass(cls, Float8CurrentScaling)
def float8_per_tensor_scaling(self):
@classmethod
def float8_per_tensor_scaling(cls):
"""Whether the given recipe is per-tensor scaling."""
return isinstance(self, (DelayedScaling, Float8CurrentScaling))
return issubclass(cls, (DelayedScaling, Float8CurrentScaling))
def float8_block_scaling(self):
@classmethod
def float8_block_scaling(cls):
"""Whether the given recipe is float8 blockwise scaling."""
return isinstance(self, Float8BlockScaling)
return issubclass(cls, Float8BlockScaling)
def custom(self):
@classmethod
def custom(cls):
"""Whether the given recipe is custom."""
return isinstance(self, CustomRecipe)
return issubclass(cls, CustomRecipe)
@dataclass()
......
......@@ -458,9 +458,9 @@ class TensorAllocator {
}
void Free(NVTETensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
std::lock_guard<std::mutex> lock(mutex);
NVTE_CHECK(index <= memory.size(), "Invalid tensor.");
free_list.push_back(index);
// Clean up
......@@ -568,9 +568,9 @@ class GroupedTensorAllocator {
}
void Free(NVTEGroupedTensor t) {
std::lock_guard<std::mutex> lock(mutex);
uintptr_t index = reinterpret_cast<uintptr_t>(t);
if (index == 0) return;
std::lock_guard<std::mutex> lock(mutex);
NVTE_CHECK(index <= memory.size(), "Invalid grouped tensor.");
free_list.push_back(index);
// Clean up
......
......@@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel(
split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
).to(tl.int32)
input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
# Compute total valid tokens and skip phantom/padding tokens.
# When the input buffer is larger than sum(split_sizes), tokens beyond
# the valid range should map to themselves (identity mapping) to avoid
# corrupting valid output positions.
total_valid_tokens = tl.sum(input_split_sizes)
input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
input_chunk_idx = tl.sum(input_split_sizes_mask)
input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
......@@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel(
).to(tl.int32)
output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
# For tokens beyond the valid range (pid >= total_valid_tokens),
# use identity mapping to avoid corrupting valid data
dst_row = tl.where(pid < total_valid_tokens, dst_row, pid)
tl.store(dst_rows_ptr + pid, dst_row)
......
......@@ -6,6 +6,8 @@
#include "../util/cuda_runtime.h"
#include <cublasLt.h>
#include <filesystem>
#include <mutex>
......@@ -232,6 +234,12 @@ int cudart_version() {
return version;
}
size_t cublas_version() {
// Cache version to avoid cuBLAS logging overhead
static size_t version = cublasLtGetVersion();
return version;
}
} // namespace cuda
} // namespace transformer_engine
......@@ -85,6 +85,12 @@ const std::string &include_directory(bool required = false);
*/
int cudart_version();
/*! \brief cuBLAS version number at run-time
*
* Versions may differ between compile-time and run-time.
*/
size_t cublas_version();
} // namespace cuda
} // namespace transformer_engine
......
......@@ -145,7 +145,7 @@
do { \
const cublasMpStatus_t status = (expr); \
if (status != CUBLASMP_STATUS_SUCCESS) { \
NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \
NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status)); \
} \
} while (false)
......
......@@ -172,6 +172,18 @@ __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(
uint64_t *mbar, const uint32_t tx_count) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.shared::cta.b64 _, [%0], %1;" ::"r"(mbar_ptr),
"r"(tx_count));
#else
NVTE_DEVICE_ERROR(
"mbarrier_arrive_expect_tx_cta_relaxed_shared_cta is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void fence_mbarrier_init_release_cluster() {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm volatile("fence.mbarrier_init.release.cluster;");
......@@ -251,13 +263,86 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void mbarrier_wait_parity_acquire_cta_shared_cta(uint64_t *mbar,
uint32_t phase_parity) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
asm volatile(
"{\n\t"
".reg .b64 r1; \n\t"
".reg .pred waitComplete; \n\t" // predicate representing if barrier condition is met
"WAIT: \n\t" // loop around barrier wait
"mbarrier.try_wait.parity.acquire.cta.shared::cta.b64 waitComplete, [%0], %1; \n\t"
"@waitComplete bra DONE; \n\t" // mbarrier conditions are met
"bra WAIT; \n\t" // just a time-out, try again
"DONE: \n\t"
"}\n\t"
:
: "r"(mbar_ptr), "r"(phase_parity)
: "memory");
#else
NVTE_DEVICE_ERROR("mbarrier_wait_parity_acquire_cta_shared_cta is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__ __forceinline__ void try_cancel_cta(uint64_t *mbar, __uint128_t *response_data_ptr) {
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
asm volatile(
"clusterlaunchcontrol.try_cancel.async.mbarrier::complete_tx::bytes.multicast::cluster::"
"all.b128 "
"[%0], [%1];" ::"r"(workID_response),
"r"(mbar_ptr));
} else {
NVTE_DEVICE_ERROR(
"Cluster Launch Control PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
}
__device__ __forceinline__ void get_cancelled_cta_id_2D(__uint128_t *response_data_ptr,
int32_t &ctaid_X, int32_t &ctaid_Y) {
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
uint32_t workID_response = __cvta_generic_to_shared(response_data_ptr);
asm volatile(
"{\n\t"
".reg .s32 x_ctaid; \n\t"
".reg .s32 y_ctaid; \n\t"
"mov .s32 x_ctaid, -1; \n\t"
"mov .s32 y_ctaid, -1; \n\t"
".reg.b128 try_cancel_response; \n\t"
"ld.shared.b128 try_cancel_response, [%2]; \n\t"
".reg .pred P1; \n\t"
"clusterlaunchcontrol.query_cancel.is_canceled.pred.b128 P1, try_cancel_response; \n\t"
"@P1 clusterlaunchcontrol.query_cancel.get_first_ctaid.v4.b32.b128 {x_ctaid, y_ctaid, _, "
"_}, try_cancel_response; \n\t"
"mov .s32 %0, x_ctaid; \n\t"
"mov .s32 %1, y_ctaid; \n\t"
"}\n\t"
: "=r"(ctaid_X), "=r"(ctaid_Y)
: "r"(workID_response)
: "memory");
} else {
NVTE_DEVICE_ERROR(
"Cluster Launch Control PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
}
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1
: __int_as_float((254 - biased_exp)
<< FP32_MANTISSA_BITS); // 127 - (biased_exp - 127)
// Handle the special case of NaN.
if (biased_exp == 255) return __int_as_float(0x7fffffff);
// Handle the special case where the unbiased exponent is 127: the reciprocal 2^-127 is subnormal
// in FP32, so it must be encoded with the leading mantissa bit set rather than via the exponent field.
if (biased_exp == 254) return __int_as_float(0x00400000);
// Fast path when the unbiased exponent is in [-126, 126]: the reciprocal is a normal FP32 value, so only the exponent field is needed.
return __int_as_float((254 - biased_exp) << FP32_MANTISSA_BITS);
}
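As a sanity check of the bit manipulation (illustrative host-side code, not part of the diff): for a biased exponent e in [1, 253], (254 - e) << 23 is exactly the FP32 encoding of 2^(127 - e), while 0x00400000 is the subnormal encoding of 2^-127.

    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstring>

    static float bits_to_float(uint32_t u) {
      float f;
      std::memcpy(&f, &u, sizeof(f));
      return f;
    }

    int main() {
      for (uint32_t e = 1; e <= 253; ++e) {
        // The exponent-field shift covers every normal reciprocal 2^(127 - e).
        assert(bits_to_float((254 - e) << 23) == std::exp2f(127.0f - static_cast<float>(e)));
      }
      // e == 254 (reciprocal 2^-127) needs the leading mantissa bit instead.
      assert(bits_to_float(0x00400000u) == std::exp2f(-127.0f));
      return 0;
    }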
__device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
......@@ -671,6 +756,179 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c
return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits);
}
}
template <typename SCALING_COEFFICIENT_TYPE>
__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_round_to_nearest(
const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient) {
uint32_t out_8x = 0;
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
asm volatile(
"{\n"
".reg.f32 zero; \n\t"
"mov.b32 zero, 0; \n\t"
".reg.b16 scaling_coeff; \n\t"
"mov.b16 scaling_coeff, %3; \n\t"
".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t"
"mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t"
"mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t"
".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t"
".reg.b8 f0, f1, f2, f3; \n\t"
// Elements reordered to match e2m1x4 packing order (v1,v0)
"cvt.rn.satfinite.e2m1x2.f32 f0, v1, v0;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v3, v2;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f2, v5, v4;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f3, v7, v6;\n\t"
"mov.b32 %0, {f0, f1, f2, f3};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)));
} else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
asm volatile(
"{\n"
".reg.b64 scaling_coeff_2x; \n\t"
"mov.b64 scaling_coeff_2x, {%3, %3}; \n\t"
".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t"
"mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t"
".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"cvt.f32.bf16 v4, v4_bf16; \n\t"
"cvt.f32.bf16 v5, v5_bf16; \n\t"
"cvt.f32.bf16 v6, v6_bf16; \n\t"
"cvt.f32.bf16 v7, v7_bf16; \n\t"
".reg.b64 v01, v23, v45, v67; \n\t"
"mov.b64 v01, {v0, v1}; \n\t"
"mov.b64 v23, {v2, v3}; \n\t"
"mov.b64 v45, {v4, v5}; \n\t"
"mov.b64 v67, {v6, v7}; \n\t"
"mul.f32x2 v01, v01, scaling_coeff_2x; \n\t"
"mul.f32x2 v23, v23, scaling_coeff_2x; \n\t"
"mul.f32x2 v45, v45, scaling_coeff_2x; \n\t"
"mul.f32x2 v67, v67, scaling_coeff_2x; \n\t"
// Elements reordered to match the packing order (v1,v0)
"mov.b64 {v1, v0}, v01; \n\t"
"mov.b64 {v3, v2}, v23; \n\t"
"mov.b64 {v5, v4}, v45; \n\t"
"mov.b64 {v7, v6}, v67; \n\t"
".reg.b8 f0, f1, f2, f3; \n\t"
"cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f2, v4, v5;\n\t"
"cvt.rn.satfinite.e2m1x2.f32 f3, v6, v7;\n\t"
"mov.b32 %0, {f0, f1, f2, f3};\n\t"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "f"(scaling_coefficient));
} else {
NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
}
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return out_8x;
}
template <typename SCALING_COEFFICIENT_TYPE>
__device__ __forceinline__ uint32_t mul_cvt_bf16_to_fp4_8x_stochastic_rounding(
const uint64_t in03, const uint64_t in47, const SCALING_COEFFICIENT_TYPE scaling_coefficient,
const uint32_t rbits03, const uint32_t rbits47) {
uint32_t out_8x = 0;
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, bf16>::value) {
asm volatile(
"{\n"
".reg.f32 zero; \n\t"
"mov.b32 zero, 0; \n\t"
".reg.b16 scaling_coeff; \n\t"
"mov.b16 scaling_coeff, %3; \n\t"
".reg.b16 v0_h, v1_h, v2_h, v3_h, v4_h, v5_h, v6_h, v7_h; \n\t"
"mov.b64 {v0_h, v1_h, v2_h, v3_h}, %1; \n\t"
"mov.b64 {v4_h, v5_h, v6_h, v7_h}, %2; \n\t"
".reg.f32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"fma.rn.f32.bf16 v0, v0_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v1, v1_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v2, v2_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v3, v3_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v4, v4_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v5, v5_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v6, v6_h, scaling_coeff, zero; \n\t"
"fma.rn.f32.bf16 v7, v7_h, scaling_coeff, zero; \n\t"
".reg.b16 b03, b47; \n\t"
// Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
"cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t"
"mov.b32 %0, {b03, b47};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "h"(reinterpret_cast<const uint16_t &>(scaling_coefficient)),
"r"(rbits03), "r"(rbits47));
} else if constexpr (std::is_same<SCALING_COEFFICIENT_TYPE, float>::value) {
asm volatile(
"{\n"
".reg.b16 v0_bf16, v1_bf16, v2_bf16, v3_bf16, v4_bf16, v5_bf16, v6_bf16, v7_bf16; \n\t"
"mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16}, %1; \n\t"
"mov.b64 {v4_bf16, v5_bf16, v6_bf16, v7_bf16}, %2; \n\t"
".reg.b32 v0, v1, v2, v3, v4, v5, v6, v7; \n\t"
"cvt.f32.bf16 v0, v0_bf16; \n\t"
"cvt.f32.bf16 v1, v1_bf16; \n\t"
"cvt.f32.bf16 v2, v2_bf16; \n\t"
"cvt.f32.bf16 v3, v3_bf16; \n\t"
"cvt.f32.bf16 v4, v4_bf16; \n\t"
"cvt.f32.bf16 v5, v5_bf16; \n\t"
"cvt.f32.bf16 v6, v6_bf16; \n\t"
"cvt.f32.bf16 v7, v7_bf16; \n\t"
"mul.f32 v0, v0, %3; \n\t"
"mul.f32 v1, v1, %3; \n\t"
"mul.f32 v2, v2, %3; \n\t"
"mul.f32 v3, v3, %3; \n\t"
"mul.f32 v4, v4, %3; \n\t"
"mul.f32 v5, v5, %3; \n\t"
"mul.f32 v6, v6, %3; \n\t"
"mul.f32 v7, v7, %3; \n\t"
".reg.b16 b03, b47; \n\t"
// Elements reordered to match e2m1x4 packing order (v3,v2,v1,v0)
"cvt.rs.satfinite.e2m1x4.f32 b03, {v3, v2, v1, v0}, %4; \n\t"
"cvt.rs.satfinite.e2m1x4.f32 b47, {v7, v6, v5, v4}, %5; \n\t"
"mov.b32 %0, {b03, b47};\n"
"}"
: "=r"(out_8x)
: "l"(in03), "l"(in47), "f"(scaling_coefficient), "r"(rbits03), "r"(rbits47));
} else {
NVTE_DEVICE_ERROR("Not supported scaling coefficient type.");
}
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return out_8x;
}
#endif // FP4_TYPE_SUPPORTED
// SIMD like "Fused" cast + multiplication (x2)
......@@ -1521,6 +1779,59 @@ __device__ __forceinline__ floatx4 up_cast(const bf16x4 &in) {
: "r"(in2[0]), "r"(in2[1]));
return out;
}
// Loads a single BF16/FP16 element from the shared memory state space
__device__ __forceinline__ bf16 ld_shared_b16(const bf16 *__restrict__ src_smem) {
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
bf16 dst;
asm volatile("ld.shared.b16 %0, [%1];"
: "=h"(reinterpret_cast<uint16_t &>(dst))
: "r"(src_smem_ptr));
return dst;
}
// Loads a pair of BF16/FP16 values from the shared memory state space
__device__ __forceinline__ bf16x2 ld_shared_b32(const bf16x2 *__restrict__ src_smem) {
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
bf16x2 dst;
asm volatile("ld.shared.b32 %0, [%1];"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(src_smem_ptr));
return dst;
}
// Loads 8x BF16 values from shared memory state space
__device__ __forceinline__ __uint128_t ld_shared_b128(const bf16 *__restrict__ src_smem) {
uint64_t elts03, elts47;
const uint32_t src_smem_ptr = __cvta_generic_to_shared(src_smem);
asm volatile(
"{\n\t"
".reg.b128 xy; \n\t"
"ld.shared.b128 xy, [%2]; \n\t"
"mov.b128 {%0, %1}, xy; \n"
"}\n"
: "=l"(elts03), "=l"(elts47)
: "r"(src_smem_ptr));
return (static_cast<__uint128_t>(elts47) << 64) | static_cast<__uint128_t>(elts03);
}
#if FP4_TYPE_SUPPORTED
// Vectorized store of x8 FP4 elements into shared memory state space
__device__ __forceinline__ void st_shared_b32(fp4e2m1x2 *__restrict__ dst_smem,
uint32_t fp4_pack_x8) {
const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
asm volatile("st.shared.b32 [%0], %1;" : : "r"(dst_smem_ptr), "r"(fp4_pack_x8));
}
#endif
#if FP4_TYPE_SUPPORTED
// Vectorized store of x16 FP4 elements into shared memory state space
__device__ __forceinline__ void st_shared_b64(fp4e2m1x2 *__restrict__ dst_smem,
uint64_t fp4_pack_x16) {
const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst_smem);
asm volatile("st.shared.b64 [%0], %1;" : : "r"(dst_smem_ptr), "l"(fp4_pack_x16));
}
#endif
#endif
} // namespace ptx
......
......@@ -2,17 +2,28 @@
#
# See LICENSE for license information.
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
"""
import warnings
from nvdlfw_inspect.registry import Registry
from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM
@Registry.register_feature(namespace="transformer_engine")
class DisableFP8GEMM(TEConfigAPIMapper):
class DisableFP8GEMM(DisableQuantizationGEMM):
"""
GEMM operations are executed in higher precision, even when FP8 autocast is enabled.
.. deprecated::
Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
backward compatibility only. DisableQuantizationGEMM works with all quantization
formats (FP8, NVFP4, etc.), not just FP8.
Parameters
----------
......@@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper):
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8GEMM:
DisableFP8GEMM: # Deprecated: use DisableQuantizationGEMM
enabled: True
gemms: [dgrad, wgrad]
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and FP8 GEMM execution."""
for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behaviour in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def __init__(self, *args, **kwargs):
warnings.warn(
"DisableFP8GEMM is deprecated. "
"Use DisableQuantizationGEMM instead, which works with all quantization "
"formats (FP8, NVFP4, etc.).",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
......@@ -2,17 +2,27 @@
#
# See LICENSE for license information.
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
New code should use DisableQuantizationLayer instead, which works with all quantization formats.
"""
import warnings
from nvdlfw_inspect.registry import Registry
from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer
@Registry.register_feature(namespace="transformer_engine")
class DisableFP8Layer:
class DisableFP8Layer(DisableQuantizationLayer):
"""
Disables all FP8 GEMMs in the layer.
.. deprecated::
Use :class:`DisableQuantizationLayer` instead. This class is maintained for
backward compatibility only. DisableQuantizationLayer works with all quantization
formats (FP8, NVFP4, etc.), not just FP8.
Example
-------
......@@ -23,33 +33,16 @@ class DisableFP8Layer:
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer:
DisableFP8Layer: # Deprecated: use DisableQuantizationLayer
enabled: True
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
for key in config:
if key not in ["enabled", "gemm"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
debug_api.log_message("FP8 Disabled", layer_name)
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API
DisableFP8Layer is the only feature provided by the Transformer Engine
which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
parsing gemms and tensors fields from the config, which are not needed for this feature.
Explanation of the parse_config_and_api can be found in the
nvidia-dlframework-inspect documentation.
"""
return config["enabled"], None
def __init__(self, *args, **kwargs):
warnings.warn(
"DisableFP8Layer is deprecated. "
"Use DisableQuantizationLayer instead, which works with all quantization "
"formats (FP8, NVFP4, etc.).",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationGEMM(TEConfigAPIMapper):
"""
Disables specific GEMM operations from using quantization, forcing high-precision execution.
Works with any quantization format (FP8, NVFP4, etc.).
Parameters
----------
gemms: List[str]
list of gemms to disable quantization for
- fprop
- dgrad
- wgrad
Example
-------
.. code-block:: yaml
example_disable_quantization_gemm:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableQuantizationGEMM:
enabled: True
gemms: [dgrad, wgrad]
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and quantized GEMM execution.
Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
but it applies to all quantization formats (FP8, NVFP4, etc.).
"""
for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect"""
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationLayer:
"""
Disables all quantized GEMMs in the layer, forcing high-precision execution.
Works with any quantization format (FP8, NVFP4, etc.).
Example
-------
.. code-block:: yaml
example_disable_quantization_layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableQuantizationLayer:
enabled: True
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and quantized GEMM execution.
Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
but it applies to all quantization formats (FP8, NVFP4, etc.).
"""
for key in config:
if key not in ["enabled", "gemm"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If quantized training, disable quantization for the selected layers if this feature is enabled.
debug_api.log_message("Quantization Disabled", layer_name)
# If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API.
DisableQuantizationLayer is the only feature provided by the Transformer Engine
which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
parsing gemms and tensors fields from the config, which are not needed for this feature.
Explanation of the parse_config_and_api can be found in the
nvidia-dlframework-inspect documentation.
"""
return config["enabled"], None
......@@ -6,15 +6,17 @@
from typing import Dict, Optional, List, Tuple
from contextlib import contextmanager
import warnings
import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine_torch as tex
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
......@@ -22,7 +24,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
try:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
_nvfp4_available = True
except ImportError:
_nvfp4_available = False
NVFP4Quantizer = None
ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]
......@@ -39,6 +48,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]):
return "mxfp8"
if isinstance(quantizer, Float8BlockQuantizer):
return "fp8_block_scaling"
if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer):
return "nvfp4"
raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")
......@@ -164,6 +175,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")
# Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work)
# But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer
if recipe_from_stat == "nvfp4":
raise ValueError(
f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats."
" FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for"
" NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.,"
" 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons."
)
if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
raise ValueError(
f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for"
......@@ -189,6 +210,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
"""Returns the recipe name from the stat string."""
columnwise_stat = stat.endswith("_columnwise")
for recipe_name in ALL_RECIPE_NAMES:
if recipe_name in stat:
......@@ -213,7 +235,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
Yields the aux_dict.
Needs to be cleaned up after usage, because it may change the usage of the quantized tensor.
"""
fp8_dtype = None
fp8_dtype = tex.DType.kFloat8E4M3
if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)
......@@ -277,14 +299,26 @@ class LogFp8TensorStats(BaseLogTensorStats):
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert rowwise_quantized_tensor is columnwise_quantized_tensor
assert (
quantizer is not None
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."
# Skip logging if quantizer is None (layer runs in high precision)
if quantizer is None:
warnings.warn(
f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
)
return
quantized_tensor = rowwise_quantized_tensor
assert isinstance(
quantized_tensor, QuantizedTensor
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
# Skip logging if quantized_tensor is not a QuantizedTensor (incompatible precision)
if not isinstance(quantized_tensor, QuantizedTensor):
warnings.warn(
f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})."
)
return
recipe_name = _get_recipe_name(quantizer)
for stat in config["stats"]:
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Optional
from contextlib import contextmanager
import warnings
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
@Registry.register_feature(namespace="transformer_engine")
class LogNvfp4TensorStats(BaseLogTensorStats):
"""Logs statistics of NVFP4 quantized tensors.
In distributed runs each rank first computes its local statistics; the values
are gathered the next time `debug_api.step()` is called. Remember to call
`debug_api.step()` every training step so the logs are flushed.
The feature is micro-batch aware: if several forward/backward passes occur
between successive `debug_api.step()` calls, statistics are accumulated for all
tensors except weights.
Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the
overhead, and if the feature is skipped for a step the additional cost is
minimal. When no other debug feature is active, the layer runs at normal
Transformer Engine speed.
Parameters
----------
stats: List[str]
List of statistics to collect. Available stats:
- underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data)
- mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements
tensors/tensors_struct: List[str]
list of tensors to log
- activation,
- gradient,
- weight,
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
start step of logging stats
end_step: Optional[int], default = None
end step of logging stats
start_end_list: Optional[list([int, int])], default = None
non-overlapping list of (start, end) pairs in increasing order. If not None, start_step and end_step are ignored
Example
-------
.. code-block:: yaml
example_nvfp4_tensor_stat_collection:
enabled: True
layers:
layer_types: [layernorm_linear]
transformer_engine:
LogNvfp4TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%, mse]
freq: 1
- tensor: gradient
stats: [underflows%, mse]
freq: 5
start_step: 0
end_step: 80
"""
def check_if_stat_is_supported(self, stat: str):
"""Returns True if stat is supported, raises ValueError otherwise."""
supported_stats = [
"underflows%",
"mse",
]
if stat not in supported_stats:
raise ValueError(
f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
)
return True
def get_stat_with_prefix(self, stat: str) -> str:
"""Add nvfp4_ prefix to stat name for use in stats_computation."""
return f"nvfp4_{stat}"
@contextmanager
def update_aux_dict(
self,
aux_dict: Dict,
quantized_tensor: QuantizedTensor,
quantizer: Quantizer, # pylint: disable=unused-argument
original_tensor: torch.Tensor,
):
"""
Updates the aux_dict with the quantized tensor and additional NVFP4-specific data.
Yields the aux_dict.
"""
aux_dict = {
"nvfp4": quantized_tensor,
"original_tensor": original_tensor,
}
try:
yield aux_dict
finally:
pass
@api_method
def inspect_tensor_enabled(
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run inspect_tensor() in the forward."""
run_current, next_iter = next_enabled_iter(
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
config.get("freq", 1),
iteration,
)
STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
return run_current, next_iter
@api_method
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
iteration: int,
tp_group,
tensor: torch.Tensor,
rowwise_quantized_tensor: Optional[QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert rowwise_quantized_tensor is columnwise_quantized_tensor
# Skip logging if quantizer is None (layer runs in high precision)
if quantizer is None:
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
)
return
quantized_tensor = rowwise_quantized_tensor
# Skip logging if not NVFP4 quantizer (incompatible precision)
if not isinstance(quantizer, NVFP4Quantizer):
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected NVFP4Quantizer, got {type(quantizer).__name__})."
)
return
# Skip logging if quantized tensor is not NVFP4TensorStorage (incompatible precision)
if not isinstance(quantized_tensor, NVFP4TensorStorage):
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected NVFP4TensorStorage, got {type(quantized_tensor).__name__})."
)
return
for stat in config["stats"]:
self.check_if_stat_is_supported(stat)
start_step = config.get("start_step", None)
end_step = config.get("end_step", None)
start_end_list = config.get("start_end_list", None)
if start_end_list is not None:
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
options = (
start_step,
end_step,
start_end_list,
"nvfp4",
)
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
)
# Add nvfp4_ prefix to all stats for internal use
prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]]
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
tensor_name=tensor_name,
stats=prefixed_stats,
options=options,
reduction_group=reduction_group,
reduce_within_microbatch=reduce_within_microbatch,
)
with self.update_aux_dict(
aux_dict={},
quantized_tensor=quantized_tensor,
quantizer=quantizer,
original_tensor=tensor,
) as aux_dict:
STATS_BUFFERS.feed(
layer_name,
tensor_name,
options,
tensor,
iteration,
skip_reduction,
aux_dict=aux_dict,
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
......@@ -443,3 +443,65 @@ for _columnwise in [True, False]:
add_underflows_stats(_recipe_name, _columnwise)
add_scale_inv_stats(_recipe_name, _columnwise)
add_mse_stats(_recipe_name, _columnwise)
# NVFP4-specific statistics
def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor:
"""Count the number of non-zero elements in the FP4 data.
FP4 data is stored as 2 4-bit values per byte (uint8).
We need to unpack and count non-zeros.
"""
# Each byte contains two FP4 values
# Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0)
zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8)
# Extract first and second nibbles
first_nibble = fp4_data % 16
second_nibble = fp4_data // 16
# Count zeros
first_zeros = torch.isin(first_nibble, zero_vals).sum()
second_zeros = torch.isin(second_nibble, zero_vals).sum()
total_elements = fp4_data.numel() * 2
return total_elements - first_zeros - second_zeros
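# A minimal sanity check of the nibble unpacking above (illustrative values,
# not part of the module): byte 0x80 packs low nibble 0x0 (+0.0) and high
# nibble 0x8 (-0.0), both zero-valued; byte 0x12 packs 0x2 and 0x1, both
# non-zero.
#   packed = torch.tensor([0x80, 0x12], dtype=torch.uint8)
#   count_nonzero_nvfp4(packed)  # tensor(2): 4 packed values, 2 of them zero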
def add_nvfp4_underflows_stats():
"""Register underflow stats for NVFP4.
Computes underflows by counting zeros in packed FP4 data vs original tensor.
"""
stat_num = "nvfp4_underflows_num"
stat_pct = "nvfp4_underflows%"
stats_to_num[stat_num] = len(stats_to_num)
stats_to_num[stat_pct] = len(stats_to_num)
# Count non-zeros in original vs FP4 packed data
STATS[stat_num] = (
lambda x, aux_dict: x.count_nonzero()
- count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data),
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
)
STATS[stat_pct] = (
lambda x, aux_dict: (
x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data)
)
/ aux_dict["nvfp4"].numel()
* 100,
lambda buffers, _sn_num=stat_num: 100
* sum(_get(buffers, _sn_num))
/ sum(_get(buffers, "numel")),
)
DEPENDENCIES[stat_num] = {stat_num}
DEPENDENCIES[stat_pct] = {stat_num, "numel"}
# Register NVFP4 stats
add_nvfp4_underflows_stats()
add_mse_stats("nvfp4") # Reuse existing MSE function
......@@ -36,7 +36,7 @@ _tensor_to_gemm_names_map = {
}
API_CALL_MODIFY = "modify_tensor()"
STANDARD_FP8_QUANTIZE = "FP8 Quantize"
STANDARD_QUANTIZE = "Quantize"
HIGH_PRECISION = "High Precision"
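# Illustrative dispatch for the three plan values (a sketch; the actual
# selection happens in get_tensors_plan below):
#   API_CALL_MODIFY   -> a feature's modify_tensor() hook produces the tensor
#   STANDARD_QUANTIZE -> the parent quantizer quantizes it (FP8, NVFP4, ...)
#   HIGH_PRECISION    -> the tensor is used in high precision, unquantized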
......@@ -88,7 +88,7 @@ class DebugQuantizer(Quantizer):
# inspect_tensor*_enabled are bool fields,
# indicating whether some feature will need to run inspect_tensor_* calls.
#
# *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
# *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION]
# determining what will happen when the quantizer is used for that tensor.
self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
if self.output_tensor:
......@@ -170,7 +170,7 @@ class DebugQuantizer(Quantizer):
def get_tensors_plan(self):
"""
Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior
of this quantizer with respect to these tensors.
"""
import nvdlfw_inspect.api as debug_api
......@@ -191,16 +191,16 @@ class DebugQuantizer(Quantizer):
rowwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled(
quantize_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
iteration=self.iteration,
)
)
if fp8_quantize:
rowwise_plan = STANDARD_FP8_QUANTIZE
if quantize_enabled:
rowwise_plan = STANDARD_QUANTIZE
if rowwise_plan is None:
rowwise_plan = HIGH_PRECISION
......@@ -218,16 +218,16 @@ class DebugQuantizer(Quantizer):
columnwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled(
quantize_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
iteration=self.iteration,
)
)
if fp8_quantize:
columnwise_plan = STANDARD_FP8_QUANTIZE
if quantize_enabled:
columnwise_plan = STANDARD_QUANTIZE
if columnwise_plan is None:
columnwise_plan = HIGH_PRECISION
......@@ -278,7 +278,7 @@ class DebugQuantizer(Quantizer):
del args["quantizer"]
if (
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_rowwise
):
args["tensor"] = rowwise_gemm_tensor
......@@ -286,7 +286,7 @@ class DebugQuantizer(Quantizer):
debug_api.transformer_engine.inspect_tensor_postquantize(**args)
if (
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_columnwise
):
args["tensor"] = columnwise_gemm_tensor
......@@ -317,14 +317,14 @@ class DebugQuantizer(Quantizer):
self.parent_quantizer.set_usage(rowwise=True)
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
quantized_tensor = self.parent_quantizer(tensor)
# if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
# if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized,
# one tensor with columnwise=True and rowwise=True is computed
# and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
if self.rowwise_tensor_plan == STANDARD_QUANTIZE:
rowwise_gemm_tensor = quantized_tensor
if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
if self.columnwise_tensor_plan == STANDARD_QUANTIZE:
columnwise_gemm_tensor = quantized_tensor
# 2. modify_tensor() is called, if it is used.
......@@ -379,7 +379,7 @@ class DebugQuantizer(Quantizer):
"""This call is invoked after the gemm to inspect and modify the output tensor."""
import nvdlfw_inspect.api as debug_api
assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
assert self.parent_quantizer is None, "Quantized output is not supported for debug=True."
assert self.output_tensor
tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
if self.rowwise_tensor_plan == API_CALL_MODIFY:
......@@ -420,9 +420,9 @@ class DebugQuantizer(Quantizer):
):
return True
if self.parent_quantizer is not None:
if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
if self.rowwise_tensor_plan != STANDARD_QUANTIZE:
return True
if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
if self.columnwise_tensor_plan != STANDARD_QUANTIZE:
return True
return False
......@@ -446,7 +446,7 @@ class DebugQuantizer(Quantizer):
if self.parent_quantizer is not None:
if (
dst.rowwise_gemm_tensor is not None
and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
and self.rowwise_tensor_plan == STANDARD_QUANTIZE
):
if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
......@@ -455,7 +455,7 @@ class DebugQuantizer(Quantizer):
updated_rowwise_gemm = True
if (
dst.columnwise_gemm_tensor is not None
and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
and self.columnwise_tensor_plan == STANDARD_QUANTIZE
and not updated_rowwise_gemm
):
if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
......@@ -540,14 +540,12 @@ class DebugQuantizer(Quantizer):
"""
Updates the usage of the parent quantizer.
"""
rowwise_gemm_quantize = (
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE
columnwise_gemm_quantize = (
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE
)
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage(
rowwise=rowwise_gemm_quantize,
columnwise=columnwise_gemm_quantize,
......