"src/lib/vscode:/vscode.git/clone" did not exist on "f74f2ea76517041bab1e245650a4690793fe22f9"
Unverified commit 74faf7ec authored by Paweł Gadziński, committed by GitHub

[PyTorch Debug] NVFP4 debug stats support (#2296)



* init

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* year update in license

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 29b84c16
@@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
 - log the statistics for each of the tensors in every matrix multiply (GEMM) operation,
 - run selected GEMMs in higher precision,
 - run current scaling - with one scaling factor per tensor - for particular GEMMs,
-- test new precisions and integrate them with FP8 training,
+- test new precisions and integrate them with quantized training (FP8, NVFP4, etc.),
 - ... and many more.

 There are 4 things one needs to do to use Transformer Engine debug features:
...
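For orientation, a minimal config sketch showing how the new NVFP4 stats logging is switched on. This is adapted from the test config added later in this commit; the top-level section name, layer pattern, and step settings are placeholders, while the feature name and stat names come from the new LogNvfp4TensorStats feature:

```yaml
# Hypothetical debug config enabling the new NVFP4 stats (placeholder section/layer names).
log:
  layers:
    layer_name_regex_pattern: .*          # placeholder: match all layers
  enabled: True
  transformer_engine:
    LogNvfp4TensorStats:
      enabled: True
      stats: [underflows%, mse]           # the two stats supported by this feature
      tensors: [activation, gradient, weight]
      freq: 2                             # log every 2 steps
```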
@@ -8,7 +8,10 @@ Debug features
 .. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
 .. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
-.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
-.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
+.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats
+.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM
+.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer
 .. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
 .. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
+.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
+.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
\ No newline at end of file
@@ -15,6 +15,7 @@ from transformer_engine.pytorch import (
     is_fp8_available,
     is_mxfp8_available,
     is_fp8_block_scaling_available,
+    is_nvfp4_available,
 )
 from transformer_engine.pytorch.quantization import RecipeState
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
@@ -29,6 +30,7 @@ mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
 fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
     return_reason=True
 )
+nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)

 LOG_QUANTIZED_CONFIG_BASE = """
 log:
@@ -363,6 +365,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
     TEDebugState._reset()


+# NVFP4 tests
+LOG_NVFP4_CONFIG_BASE = """
+log:
+  layers:
+    layer_name_regex_pattern: .*
+  enabled:
+    True
+  transformer_engine:
+    LogNvfp4TensorStats:
+      enabled: True
+      stats: [
+        {stats}
+      ]
+      tensors: [activation, gradient, weight]
+      freq: 2
+      start_step: 0
+      end_step: 10
+"""
+
+
+def test_nvfp4_numeric(feature_dirs):
+    """Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
+    if not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)
+
+    log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")
+
+    with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
+        from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+        from transformer_engine.pytorch.quantization import RecipeState
+
+        recipe_state = RecipeState.create(
+            recipe.NVFP4BlockScaling(),
+            mode="forward",
+            num_quantizers=3,
+        )
+
+        # Create test tensor with known distribution
+        torch.manual_seed(42)
+        tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
+        # Add some small values that should underflow to zero in FP4
+        tensor[0, :16] = 0.0001
+
+        quantizer = recipe_state.make_quantizers()[0]
+        quantized_tensor = quantizer(tensor)
+
+        debug_api.transformer_engine.inspect_tensor(
+            layer_name="test_layer",
+            tensor_name="activation",
+            iteration=0,
+            tp_group=None,
+            tensor=tensor,
+            quantizer=quantizer,
+            rowwise_quantized_tensor=quantized_tensor,
+            columnwise_quantized_tensor=quantized_tensor,
+        )
+        debug_api.step()
+
+        dequantized_tensor = quantized_tensor.dequantize()
+
+        output = read_log(log_dir)
+
+        # Validate both stats are present
+        assert "nvfp4_underflows%" in output, "underflows% stat missing"
+        assert "nvfp4_mse" in output, "mse stat missing"
+
+        # Extract values and validate numerics
+        underflows_value = None
+        mse_value = None
+        for line in output.splitlines():
+            if "nvfp4_underflows%" in line and "value=" in line:
+                underflows_value = float(line.split("value=")[1].split()[0])
+            if "nvfp4_mse" in line and "value=" in line:
+                mse_value = float(line.split("value=")[1].split()[0])
+
+        # Compute expected underflows: non-zero elements that became zero after quantization
+        orig_nonzero_mask = tensor != 0
+        dequant_zero_mask = dequantized_tensor == 0
+        expected_underflows = (
+            (orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
+        )
+        # Allow some tolerance
+        assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)
+
+        # Compute expected MSE
+        expected_mse = torch.nn.functional.mse_loss(
+            dequantized_tensor.float(), tensor.float(), reduction="mean"
+        )
+        assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)
+
+
+def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
+    """Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
+    if not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)
+
+    # Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
+    log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")
+
+    with debug_session(log_fp8_config, feature_dirs) as log_dir:
+        model = te.Linear(128, 128, params_dtype=torch.bfloat16)
+        inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
+
+        # Should work - recipe-prefixed stats compute MXFP8 separately for comparison
+        for _ in range(2):
+            with te.autocast(recipe=recipe.NVFP4BlockScaling()):
+                output = model(inp)
+            loss = output.sum()
+            loss.backward()
+            debug_api.step()
+
+        output = read_log(log_dir)
+
+        # Should have logged MXFP8 MSE stat (what-if scenario)
+        assert "mxfp8_mse" in output
+
+
 def test_log_grouped_gemm(feature_dirs):
     if not fp8_available:
         pytest.skip(reason_for_no_fp8)
...
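The two quantities checked in `test_nvfp4_numeric` can be reproduced in isolation. A minimal sketch, assuming only an original tensor and its dequantized counterpart are available (no Transformer Engine imports; the toy values below stand in for a real quantize/dequantize round trip):

```python
import torch

def underflow_pct_and_mse(original: torch.Tensor, dequantized: torch.Tensor):
    """Underflow% = share of originally non-zero elements that quantize to zero; MSE over all elements."""
    underflows = ((original != 0) & (dequantized == 0)).sum().float() / original.numel() * 100
    mse = torch.nn.functional.mse_loss(dequantized.float(), original.float(), reduction="mean")
    return underflows.item(), mse.item()

# Toy example: the tiny first value flushes to zero after quantization.
orig = torch.tensor([0.0001, 0.5, -1.0, 0.0])
deq = torch.tensor([0.0, 0.5, -1.0, 0.0])
print(underflow_pct_and_mse(orig, deq))  # (25.0, ~2.5e-09)
```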
@@ -2,17 +2,28 @@
 #
 # See LICENSE for license information.

-"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
+"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect

-from nvdlfw_inspect.registry import Registry, api_method
-from transformer_engine.debug.features.api import TEConfigAPIMapper
+DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
+New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
+"""
+
+import warnings
+
+from nvdlfw_inspect.registry import Registry
+
+from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM


 @Registry.register_feature(namespace="transformer_engine")
-class DisableFP8GEMM(TEConfigAPIMapper):
+class DisableFP8GEMM(DisableQuantizationGEMM):
     """
     GEMM operations are executed in higher precision, even when FP8 autocast is enabled.

+    .. deprecated::
+        Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
+        backward compatibility only. DisableQuantizationGEMM works with all quantization
+        formats (FP8, NVFP4, etc.), not just FP8.
+
     Parameters
     ----------
@@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper):
             layers:
                 layer_types: [fc1]
             transformer_engine:
-                DisableFP8GEMM:
+                DisableFP8GEMM:  # Deprecated: use DisableQuantizationGEMM
                     enabled: True
                     gemms: [dgrad, wgrad]
     """

-    @api_method
-    def fp8_gemm_enabled(
-        self, config, layer_name: str, gemm: str, iteration: int
-    ):  # pylint: disable=unused-argument
-        """API call responsible for choice between high-precision and FP8 GEMM execution."""
-
-        for key in config:
-            if key != "gemm":
-                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-
-        # If this feature is invoked, then FP8 GEMM is disabled.
-        # If not, then default behaviour in TransformerEngineAPI
-        # is that fp8_gemm() API call returns True.
-        return False, iteration + 1
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "DisableFP8GEMM is deprecated. "
+            "Use DisableQuantizationGEMM instead, which works with all quantization "
+            "formats (FP8, NVFP4, etc.).",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
@@ -2,17 +2,27 @@
 #
 # See LICENSE for license information.

-"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
+"""DisableFP8Layer Feature support for nvidia-dlframework-inspect

-import nvdlfw_inspect.api as debug_api
-from nvdlfw_inspect.registry import Registry, api_method
+DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
+New code should use DisableQuantizationLayer instead, which works with all quantization formats.
+"""
+
+import warnings
+
+from nvdlfw_inspect.registry import Registry
+
+from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer


 @Registry.register_feature(namespace="transformer_engine")
-class DisableFP8Layer:
+class DisableFP8Layer(DisableQuantizationLayer):
     """
     Disables all FP8 GEMMs in the layer.

+    .. deprecated::
+        Use :class:`DisableQuantizationLayer` instead. This class is maintained for
+        backward compatibility only. DisableQuantizationLayer works with all quantization
+        formats (FP8, NVFP4, etc.), not just FP8.
+
     Example
     -------
@@ -20,36 +30,19 @@ class DisableFP8Layer:
         example_disable_fp8_layer:
             enabled: True
             layers:
                layer_types: [fc1]
             transformer_engine:
-                DisableFP8Layer:
+                DisableFP8Layer:  # Deprecated: use DisableQuantizationLayer
                     enabled: True
     """

-    @api_method
-    def fp8_gemm_enabled(
-        self, config, layer_name: str, gemm: str, iteration: int
-    ):  # pylint: disable=unused-argument
-        """API call responsible for selecting between high-precision and FP8 GEMM execution."""
-        for key in config:
-            if key not in ["enabled", "gemm"]:
-                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
-
-        # If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
-        debug_api.log_message("FP8 Disabled", layer_name)
-
-        # If this feature is invoked, then FP8 GEMM is disabled.
-        # If not, then default behavior in TransformerEngineAPI
-        # is that fp8_gemm() API call returns True.
-        return False, iteration + 1
-
-    def parse_config_and_api(self, config, **_kwargs):
-        """Determines whether to run the API
-
-        DisableFP8Layer is the only feature provided by the Transformer Engine
-        which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for
-        parsing gemms and tensors fields from the config, which are not needed for this feature.
-
-        Explanation of the parse_config_and_api can be found in the
-        nvidia-dlframework-inspect documentation.
-        """
-        return config["enabled"], None
+    def __init__(self, *args, **kwargs):
+        warnings.warn(
+            "DisableFP8Layer is deprecated. "
+            "Use DisableQuantizationLayer instead, which works with all quantization "
+            "formats (FP8, NVFP4, etc.).",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        super().__init__(*args, **kwargs)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""

from nvdlfw_inspect.registry import Registry, api_method

from transformer_engine.debug.features.api import TEConfigAPIMapper


@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationGEMM(TEConfigAPIMapper):
    """
    Disables specific GEMM operations from using quantization, forcing high-precision execution.
    Works with any quantization format (FP8, NVFP4, etc.).

    Parameters
    ----------
    gemms: List[str]
        list of gemms to disable quantization for
            - fprop
            - dgrad
            - wgrad

    Example
    -------
    .. code-block:: yaml

        example_disable_quantization_gemm:
            enabled: True
            layers:
                layer_types: [fc1]
            transformer_engine:
                DisableQuantizationGEMM:
                    enabled: True
                    gemms: [dgrad, wgrad]
    """

    @api_method
    def fp8_gemm_enabled(
        self, config, layer_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call responsible for choice between high-precision and quantized GEMM execution.

        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
        but it applies to all quantization formats (FP8, NVFP4, etc.).
        """
        for key in config:
            if key != "gemm":
                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
        # If not, then default behavior in TransformerEngineAPI
        # is that fp8_gemm() API call returns True.
        return False, iteration + 1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect"""

import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method


@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationLayer:
    """
    Disables all quantized GEMMs in the layer, forcing high-precision execution.
    Works with any quantization format (FP8, NVFP4, etc.).

    Example
    -------
    .. code-block:: yaml

        example_disable_quantization_layer:
            enabled: True
            layers:
                layer_types: [fc1]
            transformer_engine:
                DisableQuantizationLayer:
                    enabled: True
    """

    @api_method
    def fp8_gemm_enabled(
        self, config, layer_name: str, gemm: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call responsible for selecting between high-precision and quantized GEMM execution.

        Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
        but it applies to all quantization formats (FP8, NVFP4, etc.).
        """
        for key in config:
            if key not in ["enabled", "gemm"]:
                raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

        # If quantized training, disable quantization for the selected layers if this feature is enabled.
        debug_api.log_message("Quantization Disabled", layer_name)

        # If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
        # If not, then default behavior in TransformerEngineAPI
        # is that fp8_gemm() API call returns True.
        return False, iteration + 1

    def parse_config_and_api(self, config, **_kwargs):
        """Determines whether to run the API.

        DisableQuantizationLayer is the only feature provided by the Transformer Engine
        which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
        parsing gemms and tensors fields from the config, which are not needed for this feature.

        Explanation of the parse_config_and_api can be found in the
        nvidia-dlframework-inspect documentation.
        """
        return config["enabled"], None
@@ -9,12 +9,13 @@ from contextlib import contextmanager

 import torch
 import nvdlfw_inspect.api as debug_api
+import transformer_engine_torch as tex

 from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
 from nvdlfw_inspect.registry import Registry, api_method

 from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
+from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
 from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
 from transformer_engine.pytorch.tensor.float8_tensor import (
     Float8Quantizer,
@@ -22,7 +23,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
-from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
+
+try:
+    from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
+
+    _nvfp4_available = True
+except ImportError:
+    _nvfp4_available = False
+    NVFP4Quantizer = None

 ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]
@@ -39,6 +47,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]):
         return "mxfp8"
     if isinstance(quantizer, Float8BlockQuantizer):
         return "fp8_block_scaling"
+    if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer):
+        return "nvfp4"
     raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")
@@ -164,6 +174,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
         if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
             raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")

+        # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work)
+        # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer
+        if recipe_from_stat == "nvfp4":
+            raise ValueError(
+                f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats."
+                " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for"
+                " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.,"
+                " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons."
+            )
+
         if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
             raise ValueError(
                 f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for"
@@ -189,6 +209,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
     def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
         """Returns the recipe name from the stat string."""
         columnwise_stat = stat.endswith("_columnwise")
+
         for recipe_name in ALL_RECIPE_NAMES:
             if recipe_name in stat:
@@ -213,7 +234,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
         Yields the aux_dict.
         Needs to clean after usage, because it possibly change the usage of the quantized tensor.
         """
-        fp8_dtype = None
+        fp8_dtype = tex.DType.kFloat8E4M3
         if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
             assert isinstance(
                 quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)
@@ -282,6 +303,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
         ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."
         quantized_tensor = rowwise_quantized_tensor
+
         assert isinstance(
             quantized_tensor, QuantizedTensor
         ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
...
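To make the recipe-prefixed stat naming above concrete, here is a small standalone sketch, not the library's actual implementation, of how a prefix such as `mxfp8_` or `fp8_block_scaling_` and an optional `_columnwise` suffix can be split off a stat name:

```python
# Illustrative only: mirrors the naming convention used by LogFp8TensorStats
# (optional recipe prefix + base stat + optional "_columnwise" suffix).
RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]

def split_stat_name(stat: str):
    """Return (recipe, base_stat, columnwise) for names like 'mxfp8_underflows%_columnwise'."""
    columnwise = stat.endswith("_columnwise")
    if columnwise:
        stat = stat[: -len("_columnwise")]
    for recipe in RECIPE_NAMES:
        if stat.startswith(recipe + "_"):
            return recipe, stat[len(recipe) + 1:], columnwise
    return "", stat, columnwise

assert split_stat_name("mxfp8_mse") == ("mxfp8", "mse", False)
assert split_stat_name("underflows%") == ("", "underflows%", False)
```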
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect"""

from typing import Dict, Optional
from contextlib import contextmanager

import torch
import nvdlfw_inspect.api as debug_api

from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method

from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage


@Registry.register_feature(namespace="transformer_engine")
class LogNvfp4TensorStats(BaseLogTensorStats):
    """Logs statistics of NVFP4 quantized tensors.

    In distributed runs each rank first computes its local statistics; the values
    are gathered the next time `debug_api.step()` is called. Remember to call
    `debug_api.step()` every training step so the logs are flushed.

    The feature is micro-batch aware: if several forward/backward passes occur
    between successive `debug_api.step()` calls, statistics are accumulated for all
    tensors except weights.

    Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the
    overhead, and if the feature is skipped for a step the additional cost is
    minimal. When no other debug feature is active, the layer runs at normal
    Transformer Engine speed.

    Parameters
    ----------
    stats: List[str]
        List of statistics to collect. Available stats:
            - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data)
            - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements
    tensors/tensors_struct: List[str]
        list of tensors to log
            - activation,
            - gradient,
            - weight,
    freq: Optional[int], default = 1
        frequency of logging stats, stats will be logged every `freq` steps
    start_step: Optional[int], default = None
        start step of logging stats
    end_step: Optional[int], default = None
        end step of logging stats
    start_end_list: Optional[list([int, int])], default = None
        non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step

    Example
    -------
    .. code-block:: yaml

        example_nvfp4_tensor_stat_collection:
            enabled: True
            layers:
                layer_types: [layernorm_linear]
            transformer_engine:
                LogNvfp4TensorStats:
                    enabled: True
                    tensors_struct:
                        - tensor: activation
                          stats: [underflows%, mse]
                          freq: 1
                        - tensor: gradient
                          stats: [underflows%, mse]
                          freq: 5
                    start_step: 0
                    end_step: 80
    """

    def check_if_stat_is_supported(self, stat: str):
        """Returns True if stat is supported, raises ValueError otherwise."""
        supported_stats = [
            "underflows%",
            "mse",
        ]
        if stat not in supported_stats:
            raise ValueError(
                f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
            )
        return True

    def get_stat_with_prefix(self, stat: str) -> str:
        """Add nvfp4_ prefix to stat name for use in stats_computation."""
        return f"nvfp4_{stat}"

    @contextmanager
    def update_aux_dict(
        self,
        aux_dict: Dict,
        quantized_tensor: QuantizedTensor,
        quantizer: Quantizer,  # pylint: disable=unused-argument
        original_tensor: torch.Tensor,
    ):
        """
        Updates the aux_dict with the quantized tensor and additional NVFP4-specific data.
        Yields the aux_dict.
        """
        aux_dict = {
            "nvfp4": quantized_tensor,
            "original_tensor": original_tensor,
        }
        try:
            yield aux_dict
        finally:
            pass

    @api_method
    def inspect_tensor_enabled(
        self, config: Dict, layer_name: str, tensor_name: str, iteration: int
    ):  # pylint: disable=unused-argument
        """API call used to determine whether to run inspect_tensor() in the forward."""
        run_current, next_iter = next_enabled_iter(
            config.get("start_step", None),
            config.get("end_step", None),
            config.get("start_end_list", None),
            config.get("freq", 1),
            iteration,
        )
        STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
        return run_current, next_iter

    @api_method
    def inspect_tensor(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        iteration: int,
        tp_group,
        tensor: torch.Tensor,
        rowwise_quantized_tensor: Optional[QuantizedTensor] = None,
        columnwise_quantized_tensor: Optional[QuantizedTensor] = None,
        quantizer: Optional[Quantizer] = None,
    ):
        """
        API call used to collect the data about the tensor after process_tensor()/quantization.
        """
        assert rowwise_quantized_tensor is columnwise_quantized_tensor
        assert (
            quantizer is not None
        ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats cannot be run without NVFP4 quantizer."
        quantized_tensor = rowwise_quantized_tensor

        # Ensure we're working with NVFP4 tensors
        if not isinstance(quantizer, NVFP4Quantizer):
            raise ValueError(
                "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, "
                f"but got {type(quantizer).__name__}"
            )
        assert isinstance(quantized_tensor, NVFP4TensorStorage), (
            "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a"
            " NVFP4TensorStorage."
        )

        for stat in config["stats"]:
            self.check_if_stat_is_supported(stat)

        start_step = config.get("start_step", None)
        end_step = config.get("end_step", None)
        start_end_list = config.get("start_end_list", None)
        if start_end_list is not None:
            start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
        options = (
            start_step,
            end_step,
            start_end_list,
            "nvfp4",
        )

        skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
            tensor_name, tp_group
        )

        # Add nvfp4_ prefix to all stats for internal use
        prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]]

        STATS_BUFFERS.try_add_buffer(
            layer_name=layer_name,
            tensor_name=tensor_name,
            stats=prefixed_stats,
            options=options,
            reduction_group=reduction_group,
            reduce_within_microbatch=reduce_within_microbatch,
        )

        with self.update_aux_dict(
            aux_dict={},
            quantized_tensor=quantized_tensor,
            quantizer=quantizer,
            original_tensor=tensor,
        ) as aux_dict:
            STATS_BUFFERS.feed(
                layer_name,
                tensor_name,
                options,
                tensor,
                iteration,
                skip_reduction,
                aux_dict=aux_dict,
            )

        debug_api.log_message(
            f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}",
            layer_name,
            extra_cachable_args=(tensor_name,),
        )
@@ -443,3 +443,65 @@ for _columnwise in [True, False]:
         add_underflows_stats(_recipe_name, _columnwise)
         add_scale_inv_stats(_recipe_name, _columnwise)
         add_mse_stats(_recipe_name, _columnwise)
+
+
+# NVFP4-specific statistics
+def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor:
+    """Count the number of non-zero elements in the FP4 data.
+
+    FP4 data is stored as 2 4-bit values per byte (uint8).
+    We need to unpack and count non-zeros.
+    """
+    # Each byte contains two FP4 values
+    # Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0)
+    zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8)
+    # Extract first and second nibbles
+    first_nibble = fp4_data % 16
+    second_nibble = fp4_data // 16
+    # Count zeros
+    first_zeros = torch.isin(first_nibble, zero_vals).sum()
+    second_zeros = torch.isin(second_nibble, zero_vals).sum()
+    total_elements = fp4_data.numel() * 2
+    return total_elements - first_zeros - second_zeros
+
+
+def add_nvfp4_underflows_stats():
+    """Register underflow stats for NVFP4.
+
+    Computes underflows by counting zeros in packed FP4 data vs original tensor.
+    """
+    stat_num = "nvfp4_underflows_num"
+    stat_pct = "nvfp4_underflows%"
+    stats_to_num[stat_num] = len(stats_to_num)
+    stats_to_num[stat_pct] = len(stats_to_num)
+
+    # Count non-zeros in original vs FP4 packed data
+    STATS[stat_num] = (
+        lambda x, aux_dict: x.count_nonzero()
+        - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data),
+        lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
+    )
+    STATS[stat_pct] = (
+        lambda x, aux_dict: (
+            x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data)
+        )
+        / aux_dict["nvfp4"].numel()
+        * 100,
+        lambda buffers, _sn_num=stat_num: 100
+        * sum(_get(buffers, _sn_num))
+        / sum(_get(buffers, "numel")),
+    )
+    DEPENDENCIES[stat_num] = {stat_num}
+    DEPENDENCIES[stat_pct] = {stat_num, "numel"}
+
+
+# Register NVFP4 stats
+add_nvfp4_underflows_stats()
+add_mse_stats("nvfp4")  # Reuse existing MSE function
@@ -36,7 +36,7 @@ _tensor_to_gemm_names_map = {
 }

 API_CALL_MODIFY = "modify_tensor()"
-STANDARD_FP8_QUANTIZE = "FP8 Quantize"
+STANDARD_QUANTIZE = "Quantize"
 HIGH_PRECISION = "High Precision"
@@ -88,7 +88,7 @@ class DebugQuantizer(Quantizer):
         # inspect_tensor*_enabled are bool fields,
         # indicating whether some feature will need to run inspect_tensor_* calls.
         #
-        # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
+        # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION]
         # determining what will happen when the quantizer is used for that tensor.
         self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
         if self.output_tensor:
@@ -170,7 +170,7 @@ class DebugQuantizer(Quantizer):
     def get_tensors_plan(self):
         """
         Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
-        API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
+        API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior
         of this quantizer with respect to these tensors.
         """
         import nvdlfw_inspect.api as debug_api
@@ -191,16 +191,16 @@ class DebugQuantizer(Quantizer):
             rowwise_plan = API_CALL_MODIFY
         else:
             if self.parent_quantizer is not None:
-                fp8_quantize = self.process_enabled_api_call(
-                    debug_api.transformer_engine.fp8_gemm_enabled(
+                quantize_enabled = self.process_enabled_api_call(
+                    debug_api.transformer_engine.fp8_gemm_enabled(  # API name kept for compatibility
                         layer_name=self.layer_name,
                         gemm=self.rowwise_gemm_name,
                         iteration=self.iteration,
                     )
                 )
-                if fp8_quantize:
-                    rowwise_plan = STANDARD_FP8_QUANTIZE
+                if quantize_enabled:
+                    rowwise_plan = STANDARD_QUANTIZE
         if rowwise_plan is None:
             rowwise_plan = HIGH_PRECISION
@@ -218,16 +218,16 @@ class DebugQuantizer(Quantizer):
             columnwise_plan = API_CALL_MODIFY
         else:
             if self.parent_quantizer is not None:
-                fp8_quantize = self.process_enabled_api_call(
-                    debug_api.transformer_engine.fp8_gemm_enabled(
+                quantize_enabled = self.process_enabled_api_call(
+                    debug_api.transformer_engine.fp8_gemm_enabled(  # API name kept for compatibility
                         layer_name=self.layer_name,
                         gemm=self.columnwise_gemm_name,
                         iteration=self.iteration,
                     )
                 )
-                if fp8_quantize:
-                    columnwise_plan = STANDARD_FP8_QUANTIZE
+                if quantize_enabled:
+                    columnwise_plan = STANDARD_QUANTIZE
         if columnwise_plan is None:
             columnwise_plan = HIGH_PRECISION
@@ -278,7 +278,7 @@ class DebugQuantizer(Quantizer):
         del args["quantizer"]

         if (
-            self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
+            self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
             and self.inspect_tensor_postquantize_enabled_rowwise
         ):
             args["tensor"] = rowwise_gemm_tensor
@@ -286,7 +286,7 @@ class DebugQuantizer(Quantizer):
             debug_api.transformer_engine.inspect_tensor_postquantize(**args)

         if (
-            self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
+            self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
             and self.inspect_tensor_postquantize_enabled_columnwise
         ):
             args["tensor"] = columnwise_gemm_tensor
@@ -317,14 +317,14 @@ class DebugQuantizer(Quantizer):
                 self.parent_quantizer.set_usage(rowwise=True)

         rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
-        if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
+        if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
             quantized_tensor = self.parent_quantizer(tensor)
-            # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
+            # if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized,
             # one tensor with columnwise=True and rowwise=True is computed
             # and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
-            if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
+            if self.rowwise_tensor_plan == STANDARD_QUANTIZE:
                 rowwise_gemm_tensor = quantized_tensor
-            if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
+            if self.columnwise_tensor_plan == STANDARD_QUANTIZE:
                 columnwise_gemm_tensor = quantized_tensor

         # 2. modify_tensor() is called, if it is used.
@@ -379,7 +379,7 @@ class DebugQuantizer(Quantizer):
         """This call is invoked after the gemm to inspect and modify the output tensor."""
         import nvdlfw_inspect.api as debug_api

-        assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
+        assert self.parent_quantizer is None, "Quantized output is not supported for debug=True."
         assert self.output_tensor
         tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
         if self.rowwise_tensor_plan == API_CALL_MODIFY:
@@ -420,9 +420,9 @@ class DebugQuantizer(Quantizer):
         ):
             return True
         if self.parent_quantizer is not None:
-            if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
+            if self.rowwise_tensor_plan != STANDARD_QUANTIZE:
                 return True
-            if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
+            if self.columnwise_tensor_plan != STANDARD_QUANTIZE:
                 return True
         return False
@@ -446,7 +446,7 @@ class DebugQuantizer(Quantizer):
         if self.parent_quantizer is not None:
             if (
                 dst.rowwise_gemm_tensor is not None
-                and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
+                and self.rowwise_tensor_plan == STANDARD_QUANTIZE
             ):
                 if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
                     dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
@@ -455,7 +455,7 @@ class DebugQuantizer(Quantizer):
                 updated_rowwise_gemm = True
             if (
                 dst.columnwise_gemm_tensor is not None
-                and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
+                and self.columnwise_tensor_plan == STANDARD_QUANTIZE
                 and not updated_rowwise_gemm
             ):
                 if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
@@ -540,14 +540,12 @@ class DebugQuantizer(Quantizer):
         """
         Updates the usage of the parent quantizer.
         """
-        rowwise_gemm_quantize = (
-            self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
-        )
+        rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE
         columnwise_gemm_quantize = (
-            self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
+            self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE
         )
-        if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
+        if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
             self.parent_quantizer.set_usage(
                 rowwise=rowwise_gemm_quantize,
                 columnwise=columnwise_gemm_quantize,
...