Unverified Commit 74faf7ec authored by Paweł Gadziński, committed by GitHub

[PyTorch Debug] NVFP4 debug stats support (#2296)



* init
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* year update in license
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 29b84c16
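
A minimal usage sketch of the new feature, assembled from the tests and docstrings in this diff. Import paths follow the usual Transformer Engine layout and may differ; it also assumes debug_api has already been initialized with a YAML config that enables LogNvfp4TensorStats (such as the LOG_NVFP4_CONFIG_BASE string used in the tests below):

    import torch
    import nvdlfw_inspect.api as debug_api
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # Run a small layer under the NVFP4 recipe; LogNvfp4TensorStats hooks into the
    # quantization path and records underflows% / mse for the configured tensors.
    model = te.Linear(128, 128, params_dtype=torch.bfloat16)
    inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()

    for _ in range(2):
        with te.autocast(recipe=recipe.NVFP4BlockScaling()):
            out = model(inp)
        out.sum().backward()
        debug_api.step()  # flush/reduce the collected NVFP4 statistics each step
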
......@@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
- log the statistics for each of the tensors in every matrix multiply (GEMM) operation,
- run selected GEMMs in higher precision,
- run current scaling - with one scaling factor per tensor - for particular GEMMs,
- test new precisions and integrate them with FP8 training,
- test new precisions and integrate them with quantized training (FP8, NVFP4, etc.),
- ... and many more.
There are 4 things one needs to do to use Transformer Engine debug features:
......
......@@ -8,7 +8,10 @@ Debug features
.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats
.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM
.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer
.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
\ No newline at end of file
......@@ -15,6 +15,7 @@ from transformer_engine.pytorch import (
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
is_nvfp4_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
......@@ -29,6 +30,7 @@ mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)
LOG_QUANTIZED_CONFIG_BASE = """
log:
......@@ -363,6 +365,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
TEDebugState._reset()
# NVFP4 tests
LOG_NVFP4_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogNvfp4TensorStats:
enabled: True
stats: [
{stats}
]
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
def test_nvfp4_numeric(feature_dirs):
"""Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
if not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")
with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.quantization import RecipeState
recipe_state = RecipeState.create(
recipe.NVFP4BlockScaling(),
mode="forward",
num_quantizers=3,
)
# Create test tensor with known distribution
torch.manual_seed(42)
tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
# Add some small values that should underflow to zero in FP4
tensor[0, :16] = 0.0001
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)
debug_api.transformer_engine.inspect_tensor(
layer_name="test_layer",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()
dequantized_tensor = quantized_tensor.dequantize()
output = read_log(log_dir)
# Validate both stats are present
assert "nvfp4_underflows%" in output, "underflows% stat missing"
assert "nvfp4_mse" in output, "mse stat missing"
# Extract values and validate numerics
underflows_value = None
mse_value = None
for line in output.splitlines():
if "nvfp4_underflows%" in line and "value=" in line:
underflows_value = float(line.split("value=")[1].split()[0])
if "nvfp4_mse" in line and "value=" in line:
mse_value = float(line.split("value=")[1].split()[0])
# Compute expected underflows: non-zero elements that became zero after quantization
orig_nonzero_mask = tensor != 0
dequant_zero_mask = dequantized_tensor == 0
expected_underflows = (
(orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
)
# Allow some tolerance
assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)
# Compute expected MSE
expected_mse = torch.nn.functional.mse_loss(
dequantized_tensor.float(), tensor.float(), reduction="mean"
)
assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)
def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
"""Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
if not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)
# Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")
with debug_session(log_fp8_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
# Should work - recipe-prefixed stats compute MXFP8 separately for comparison
for _ in range(2):
with te.autocast(recipe=recipe.NVFP4BlockScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()
output = read_log(log_dir)
# Should have logged MXFP8 MSE stat (what-if scenario)
assert "mxfp8_mse" in output
def test_log_grouped_gemm(feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
......
......@@ -2,17 +2,28 @@
#
# See LICENSE for license information.
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
"""
import warnings
from nvdlfw_inspect.registry import Registry
from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM
@Registry.register_feature(namespace="transformer_engine")
class DisableFP8GEMM(TEConfigAPIMapper):
class DisableFP8GEMM(DisableQuantizationGEMM):
"""
GEMM operations are executed in higher precision, even when FP8 autocast is enabled.
.. deprecated::
Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
backward compatibility only. DisableQuantizationGEMM works with all quantization
formats (FP8, NVFP4, etc.), not just FP8.
Parameters
----------
......@@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper):
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8GEMM:
DisableFP8GEMM: # Deprecated: use DisableQuantizationGEMM
enabled: True
gemms: [dgrad, wgrad]
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and FP8 GEMM execution."""
for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behaviour in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def __init__(self, *args, **kwargs):
warnings.warn(
"DisableFP8GEMM is deprecated. "
"Use DisableQuantizationGEMM instead, which works with all quantization "
"formats (FP8, NVFP4, etc.).",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
......@@ -2,17 +2,27 @@
#
# See LICENSE for license information.
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
New code should use DisableQuantizationLayer instead, which works with all quantization formats.
"""
import warnings
from nvdlfw_inspect.registry import Registry
from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer
@Registry.register_feature(namespace="transformer_engine")
class DisableFP8Layer:
class DisableFP8Layer(DisableQuantizationLayer):
"""
Disables all FP8 GEMMs in the layer.
.. deprecated::
Use :class:`DisableQuantizationLayer` instead. This class is maintained for
backward compatibility only. DisableQuantizationLayer works with all quantization
formats (FP8, NVFP4, etc.), not just FP8.
Example
-------
......@@ -20,36 +30,19 @@ class DisableFP8Layer:
example_disable_fp8_layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer: # Deprecated: use DisableQuantizationLayer
enabled: True
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
for key in config:
if key not in ["enabled", "gemm"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
debug_api.log_message("FP8 Disabled", layer_name)
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API
DisableFP8Layer is the only feature provided by the Transformer Engine
which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
parsing gemms and tensors fields from the config, which are not needed for this feature.
Explanation of the parse_config_and_api can be found in the
nvidia-dlframework-inspect documentation.
"""
return config["enabled"], None
def __init__(self, *args, **kwargs):
warnings.warn(
"DisableFP8Layer is deprecated. "
"Use DisableQuantizationLayer instead, which works with all quantization "
"formats (FP8, NVFP4, etc.).",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationGEMM(TEConfigAPIMapper):
"""
Disables specific GEMM operations from using quantization, forcing high-precision execution.
Works with any quantization format (FP8, NVFP4, etc.).
Parameters
----------
gemms: List[str]
list of gemms to disable quantization for
- fprop
- dgrad
- wgrad
Example
-------
.. code-block:: yaml
example_disable_quantization_gemm:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableQuantizationGEMM:
enabled: True
gemms: [dgrad, wgrad]
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and quantized GEMM execution.
Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
but it applies to all quantization formats (FP8, NVFP4, etc.).
"""
for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableQuantizationLayer Feature support for nvidia-dlframework-inspect"""
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationLayer:
"""
Disables all quantized GEMMs in the layer, forcing high-precision execution.
Works with any quantization format (FP8, NVFP4, etc.).
Example
-------
.. code-block:: yaml
example_disable_quantization_layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableQuantizationLayer:
enabled: True
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and quantized GEMM execution.
Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
but it applies to all quantization formats (FP8, NVFP4, etc.).
"""
for key in config:
if key not in ["enabled", "gemm"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If quantized training, disable quantization for the selected layers if this feature is enabled.
debug_api.log_message("Quantization Disabled", layer_name)
# If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API.
DisableQuantizationLayer is the only feature provided by the Transformer Engine
which does not inherit from TEConfigAPIMapper - this mapper is primarily responsible for
parsing gemms and tensors fields from the config, which are not needed for this feature.
Explanation of the parse_config_and_api can be found in the
nvidia-dlframework-inspect documentation.
"""
return config["enabled"], None
......@@ -9,12 +9,13 @@ from contextlib import contextmanager
import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine_torch as tex
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
......@@ -22,7 +23,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import (
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
try:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
_nvfp4_available = True
except ImportError:
_nvfp4_available = False
NVFP4Quantizer = None
ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]
......@@ -39,6 +47,8 @@ def _get_recipe_name(quantizer: Optional[Quantizer]):
return "mxfp8"
if isinstance(quantizer, Float8BlockQuantizer):
return "fp8_block_scaling"
if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer):
return "nvfp4"
raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")
......@@ -164,6 +174,16 @@ class LogFp8TensorStats(BaseLogTensorStats):
if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")
# Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work)
# But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer
if recipe_from_stat == "nvfp4":
raise ValueError(
f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats."
" FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for"
" NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.,"
" 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons."
)
if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
raise ValueError(
f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for"
......@@ -189,6 +209,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
"""Returns the recipe name from the stat string."""
columnwise_stat = stat.endswith("_columnwise")
for recipe_name in ALL_RECIPE_NAMES:
if recipe_name in stat:
......@@ -213,7 +234,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
Yields the aux_dict.
Needs to clean up after usage, because it may change the usage of the quantized tensor.
"""
fp8_dtype = None
fp8_dtype = tex.DType.kFloat8E4M3
if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)
......@@ -282,6 +303,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."
quantized_tensor = rowwise_quantized_tensor
assert isinstance(
quantized_tensor, QuantizedTensor
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
......
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LogNvfp4TensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Optional
from contextlib import contextmanager
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
@Registry.register_feature(namespace="transformer_engine")
class LogNvfp4TensorStats(BaseLogTensorStats):
"""Logs statistics of NVFP4 quantized tensors.
In distributed runs each rank first computes its local statistics; the values
are gathered the next time `debug_api.step()` is called. Remember to call
`debug_api.step()` every training step so the logs are flushed.
The feature is micro-batch aware: if several forward/backward passes occur
between successive `debug_api.step()` calls, statistics are accumulated for all
tensors except weights.
Collecting NVFP4 statistics is expensive. Choosing a larger `freq` reduces the
overhead, and if the feature is skipped for a step the additional cost is
minimal. When no other debug feature is active, the layer runs at normal
Transformer Engine speed.
Parameters
----------
stats: List[str]
List of statistics to collect. Available stats:
- underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data)
- mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements
tensors/tensors_struct: List[str]
list of tensors to log
- activation,
- gradient,
- weight,
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
start step of logging stats
end_step: Optional[int], default = None
end step of logging stats
start_end_list: Optional[list([int, int])], default = None
non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step
Example
-------
.. code-block:: yaml
example_nvfp4_tensor_stat_collection:
enabled: True
layers:
layer_types: [layernorm_linear]
transformer_engine:
LogNvfp4TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%, mse]
freq: 1
- tensor: gradient
stats: [underflows%, mse]
freq: 5
start_step: 0
end_step: 80
"""
def check_if_stat_is_supported(self, stat: str):
"""Returns True if stat is supported, raises ValueError otherwise."""
supported_stats = [
"underflows%",
"mse",
]
if stat not in supported_stats:
raise ValueError(
f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}"
)
return True
def get_stat_with_prefix(self, stat: str) -> str:
"""Add nvfp4_ prefix to stat name for use in stats_computation."""
return f"nvfp4_{stat}"
@contextmanager
def update_aux_dict(
self,
aux_dict: Dict,
quantized_tensor: QuantizedTensor,
quantizer: Quantizer, # pylint: disable=unused-argument
original_tensor: torch.Tensor,
):
"""
Updates the aux_dict with the quantized tensor and additional NVFP4-specific data.
Yields the aux_dict.
"""
aux_dict = {
"nvfp4": quantized_tensor,
"original_tensor": original_tensor,
}
try:
yield aux_dict
finally:
pass
@api_method
def inspect_tensor_enabled(
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run inspect_tensor() in the forward."""
run_current, next_iter = next_enabled_iter(
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
config.get("freq", 1),
iteration,
)
STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
return run_current, next_iter
@api_method
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
iteration: int,
tp_group,
tensor: torch.Tensor,
rowwise_quantized_tensor: Optional[QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert rowwise_quantized_tensor is columnwise_quantized_tensor
assert (
quantizer is not None
), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats cannot be run without NVFP4 quantizer."
quantized_tensor = rowwise_quantized_tensor
# Ensure we're working with NVFP4 tensors
if not isinstance(quantizer, NVFP4Quantizer):
raise ValueError(
"[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, "
f"but got {type(quantizer).__name__}"
)
assert isinstance(quantized_tensor, NVFP4TensorStorage), (
"[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a"
" NVFP4TensorStorage."
)
for stat in config["stats"]:
self.check_if_stat_is_supported(stat)
start_step = config.get("start_step", None)
end_step = config.get("end_step", None)
start_end_list = config.get("start_end_list", None)
if start_end_list is not None:
start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list)
options = (
start_step,
end_step,
start_end_list,
"nvfp4",
)
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
)
# Add nvfp4_ prefix to all stats for internal use
prefixed_stats = [self.get_stat_with_prefix(stat) for stat in config["stats"]]
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
tensor_name=tensor_name,
stats=prefixed_stats,
options=options,
reduction_group=reduction_group,
reduce_within_microbatch=reduce_within_microbatch,
)
with self.update_aux_dict(
aux_dict={},
quantized_tensor=quantized_tensor,
quantizer=quantizer,
original_tensor=tensor,
) as aux_dict:
STATS_BUFFERS.feed(
layer_name,
tensor_name,
options,
tensor,
iteration,
skip_reduction,
aux_dict=aux_dict,
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
......@@ -443,3 +443,65 @@ for _columnwise in [True, False]:
add_underflows_stats(_recipe_name, _columnwise)
add_scale_inv_stats(_recipe_name, _columnwise)
add_mse_stats(_recipe_name, _columnwise)
# NVFP4-specific statistics
def count_nonzero_nvfp4(fp4_data: torch.Tensor) -> torch.Tensor:
"""Count the number of non-zero elements in the FP4 data.
FP4 data is stored as 2 4-bit values per byte (uint8).
We need to unpack and count non-zeros.
"""
# Each byte contains two FP4 values
# Value 0 in FP4 E2M1 format is represented as 0 (and also 8 for -0.0)
zero_vals = torch.tensor([0, 8], device=fp4_data.device, dtype=torch.uint8)
# Extract first and second nibbles
first_nibble = fp4_data % 16
second_nibble = fp4_data // 16
# Count zeros
first_zeros = torch.isin(first_nibble, zero_vals).sum()
second_zeros = torch.isin(second_nibble, zero_vals).sum()
total_elements = fp4_data.numel() * 2
return total_elements - first_zeros - second_zeros
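# Worked example of the nibble arithmetic above: a packed byte 0x83 (131) splits
# into first_nibble = 131 % 16 = 3 (a non-zero FP4 value) and
# second_nibble = 131 // 16 = 8 (-0.0, counted as zero), so that byte contributes
# exactly one element to the non-zero count.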
def add_nvfp4_underflows_stats():
"""Register underflow stats for NVFP4.
Computes underflows by counting zeros in packed FP4 data vs original tensor.
"""
stat_num = "nvfp4_underflows_num"
stat_pct = "nvfp4_underflows%"
stats_to_num[stat_num] = len(stats_to_num)
stats_to_num[stat_pct] = len(stats_to_num)
# Count non-zeros in original vs FP4 packed data
STATS[stat_num] = (
lambda x, aux_dict: x.count_nonzero()
- count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data),
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
)
STATS[stat_pct] = (
lambda x, aux_dict: (
x.count_nonzero() - count_nonzero_nvfp4(aux_dict["nvfp4"]._rowwise_data)
)
/ aux_dict["nvfp4"].numel()
* 100,
lambda buffers, _sn_num=stat_num: 100
* sum(_get(buffers, _sn_num))
/ sum(_get(buffers, "numel")),
)
DEPENDENCIES[stat_num] = {stat_num}
DEPENDENCIES[stat_pct] = {stat_num, "numel"}
# Register NVFP4 stats
add_nvfp4_underflows_stats()
add_mse_stats("nvfp4") # Reuse existing MSE function
......@@ -36,7 +36,7 @@ _tensor_to_gemm_names_map = {
}
API_CALL_MODIFY = "modify_tensor()"
STANDARD_FP8_QUANTIZE = "FP8 Quantize"
STANDARD_QUANTIZE = "Quantize"
HIGH_PRECISION = "High Precision"
......@@ -88,7 +88,7 @@ class DebugQuantizer(Quantizer):
# inspect_tensor*_enabled are bool fields,
# indicating whether some feature will need to run inspect_tensor_* calls.
#
# *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
# *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_QUANTIZE, HIGH_PRECISION]
# determining what will happen when the quantizer is used for that tensor.
self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
if self.output_tensor:
......@@ -170,7 +170,7 @@ class DebugQuantizer(Quantizer):
def get_tensors_plan(self):
"""
Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
API_CALL_MODIFY, STANDARD_QUANTIZE, or HIGH_PRECISION, indicating the behavior
of this quantizer with respect to these tensors.
"""
import nvdlfw_inspect.api as debug_api
......@@ -191,16 +191,16 @@ class DebugQuantizer(Quantizer):
rowwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled(
quantize_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
iteration=self.iteration,
)
)
if fp8_quantize:
rowwise_plan = STANDARD_FP8_QUANTIZE
if quantize_enabled:
rowwise_plan = STANDARD_QUANTIZE
if rowwise_plan is None:
rowwise_plan = HIGH_PRECISION
......@@ -218,16 +218,16 @@ class DebugQuantizer(Quantizer):
columnwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled(
quantize_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled( # API name kept for compatibility
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
iteration=self.iteration,
)
)
if fp8_quantize:
columnwise_plan = STANDARD_FP8_QUANTIZE
if quantize_enabled:
columnwise_plan = STANDARD_QUANTIZE
if columnwise_plan is None:
columnwise_plan = HIGH_PRECISION
......@@ -278,7 +278,7 @@ class DebugQuantizer(Quantizer):
del args["quantizer"]
if (
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_rowwise
):
args["tensor"] = rowwise_gemm_tensor
......@@ -286,7 +286,7 @@ class DebugQuantizer(Quantizer):
debug_api.transformer_engine.inspect_tensor_postquantize(**args)
if (
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_columnwise
):
args["tensor"] = columnwise_gemm_tensor
......@@ -317,14 +317,14 @@ class DebugQuantizer(Quantizer):
self.parent_quantizer.set_usage(rowwise=True)
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
quantized_tensor = self.parent_quantizer(tensor)
# if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
# if both rowwise_tensor_plan and columnwise_tensor_plan need to be quantized,
# one tensor with columnwise=True and rowwise=True is computed
# and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
if self.rowwise_tensor_plan == STANDARD_QUANTIZE:
rowwise_gemm_tensor = quantized_tensor
if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
if self.columnwise_tensor_plan == STANDARD_QUANTIZE:
columnwise_gemm_tensor = quantized_tensor
# 2. modify_tensor() is called, if it is used.
......@@ -379,7 +379,7 @@ class DebugQuantizer(Quantizer):
"""This call is invoked after the gemm to inspect and modify the output tensor."""
import nvdlfw_inspect.api as debug_api
assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
assert self.parent_quantizer is None, "Quantized output is not supported for debug=True."
assert self.output_tensor
tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
if self.rowwise_tensor_plan == API_CALL_MODIFY:
......@@ -420,9 +420,9 @@ class DebugQuantizer(Quantizer):
):
return True
if self.parent_quantizer is not None:
if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
if self.rowwise_tensor_plan != STANDARD_QUANTIZE:
return True
if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
if self.columnwise_tensor_plan != STANDARD_QUANTIZE:
return True
return False
......@@ -446,7 +446,7 @@ class DebugQuantizer(Quantizer):
if self.parent_quantizer is not None:
if (
dst.rowwise_gemm_tensor is not None
and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
and self.rowwise_tensor_plan == STANDARD_QUANTIZE
):
if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
......@@ -455,7 +455,7 @@ class DebugQuantizer(Quantizer):
updated_rowwise_gemm = True
if (
dst.columnwise_gemm_tensor is not None
and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
and self.columnwise_tensor_plan == STANDARD_QUANTIZE
and not updated_rowwise_gemm
):
if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
......@@ -540,14 +540,12 @@ class DebugQuantizer(Quantizer):
"""
Updates the usage of the parent quantizer.
"""
rowwise_gemm_quantize = (
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
rowwise_gemm_quantize = self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_QUANTIZE
columnwise_gemm_quantize = (
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_QUANTIZE
)
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
if STANDARD_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage(
rowwise=rowwise_gemm_quantize,
columnwise=columnwise_gemm_quantize,
......