Unverified Commit 77a00635 authored by Paweł Gadziński, committed by GitHub

[PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137)



* code drop
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

* fixes
  Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 3d76218e
@@ -18,7 +18,11 @@ from transformer_engine.pytorch import (
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.debug.features.utils.stats_computation import (
    compute_max_blockwise_dynamic_range,
    BlockwiseDynamicRangeStat,
)
import math

fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True)
@@ -154,7 +158,7 @@ fp8_recipes = [
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
-def test_numerics(fp8_recipe, feature_dirs):
+def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs):
    if not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
@@ -210,6 +214,107 @@ def test_numerics(fp8_recipe, feature_dirs):
    assert overflows == pytest.approx(expected.cpu(), abs=1e-4)

LOG_HIGH_PRECISION_CONFIG = """
log:
  layers:
    layer_name_regex_pattern: .*
  enabled:
    True
  transformer_engine:
    LogTensorStats:
      enabled: True
      stats:
        - dynamic_range
        - max_blockwise_dynamic_range:
            block_size: 4
            dims: 1
        - max_blockwise_dynamic_range:
            block_size: 4
            dims: 2
      tensors: [activation, gradient, weight]
      freq: 2
      start_step: 0
      end_step: 10
"""
@pytest.mark.parametrize("tensor_name", ["activation", "weight", "gradient"])
def test_log_stats_numerics(feature_dirs, tensor_name):
"""Check correctness of dynamic range and max blockwise dynamic range stats.
Tests different tensor types:
- activation/weight: use both orientations (rowwise + columnwise), takes max
- gradient/dgrad: use single orientation (rowwise only)
"""
log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG
with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
# There is 1024 x 1024 tensor with very small epsilon values in almost all elements,
# one row of large value A and three rows of large value B.
epsilon = 1e-10
A = 1000
B = 50
tensor = torch.zeros(1024, 1024).cuda() + epsilon
tensor[0, :] = A
tensor[1:4, :] = B
debug_api.transformer_engine.inspect_tensor(
layer_name="layer_name",
tensor_name=tensor_name,
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=None,
rowwise_quantized_tensor=None,
columnwise_quantized_tensor=None,
)
debug_api.step()
output = read_log(log_dir)
max_over_orientations = tensor_name in ["activation", "weight"]
max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else ""
# Track which stats were found to ensure all are present
found_dims_1 = False
found_dims_2 = False
found_dynamic_range = False
for line in output.splitlines():
if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line:
max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1])
if max_over_orientations:
# Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
expected = math.log2(A) - math.log2(B)
else:
# Rowwise blocks have uniform values -> dynamic_range = 0
expected = 0
assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx(
expected, abs=1e-4
)
found_dims_1 = True
elif (
f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line
):
max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1])
# For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows
expected = math.log2(A) - math.log2(B)
assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx(
expected, abs=1e-4
)
found_dims_2 = True
elif "_dynamic_range" in line and "max_blockwise_dynamic_range" not in line:
dynamic_range = float(line.split("value=")[1])
expected = math.log2(A) - math.log2(epsilon)
assert dynamic_range == pytest.approx(expected, abs=1e-4)
found_dynamic_range = True
# Ensure all expected stats were found in the output
assert found_dims_1, "max_blockwise_dynamic_range (dims=1) not found in output"
assert found_dims_2, "max_blockwise_dynamic_range (dims=2) not found in output"
assert found_dynamic_range, "dynamic_range not found in output"
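
For reference, a small self-contained sketch of the value extraction the test performs. The sample line below is hypothetical — only the stat-name substring and the `split("value=")` step are grounded in the test above; the rest of the line layout is an assumption:

import math

# Hypothetical log line; the exact layout of the real log file is an assumption.
sample = (
    "layer_name_activation_max_blockwise_dynamic_range_block_size_4_dims_1"
    "_max_over_orientations value=4.321928"
)
stat_value = float(sample.split("value=")[1])
# With A=1000 and B=50 as in the test, the expected columnwise value is log2(A/B).
assert math.isclose(stat_value, math.log2(1000) - math.log2(50), rel_tol=1e-6)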
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
if not fp8_available:
......@@ -256,3 +361,92 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
debug_api.end_debug()
TEDebugState._reset()

def test_compute_max_blockwise_dynamic_range_direct():
    """Direct unit test for the compute_max_blockwise_dynamic_range function.

    Tests the function with various configurations to ensure correct behavior
    for different block sizes, dimensions, and orientation settings.
    """
    # Create a test tensor with uniform rows but mixed columns.
    # Row 0: all 1000, rows 1-3: all 50, remaining: all 0.01.
    epsilon = 0.01
    A = 1000.0
    B = 50.0
    tensor = torch.zeros(1024, 1024).cuda() + epsilon
    tensor[0, :] = A
    tensor[1:4, :] = B

    # Test 1: dims=1, max_over_orientations=False (rowwise only)
    # Rowwise blocks have uniform values -> dynamic_range should be 0
    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False)
    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
    assert result.item() == pytest.approx(
        0.0, abs=1e-4
    ), "Rowwise 1D blocks with uniform values should have dynamic_range=0"

    # Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise)
    # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B)
    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
    expected = math.log2(A) - math.log2(B)
    assert result.item() == pytest.approx(expected, abs=1e-4), (
        f"Max over orientations should capture columnwise dynamic_range, expected {expected}, got"
        f" {result.item()}"
    )

    # Test 3: dims=2, block_size=4 (4x4 tiles)
    # 2D blocks span multiple rows -> always have mixed values
    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False)
    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
    expected = math.log2(A) - math.log2(B)
    assert result.item() == pytest.approx(expected, abs=1e-4), (
        f"2D blocks should capture mixed values from different rows, expected {expected}, got"
        f" {result.item()}"
    )

    # Test 4: different block size
    # With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon],
    # so max=A, min=epsilon (not B anymore)
    stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True)
    result = compute_max_blockwise_dynamic_range(tensor, stat_config)
    expected = math.log2(A) - math.log2(epsilon)  # min is epsilon, not B
    assert result.item() == pytest.approx(
        expected, abs=1e-4
    ), f"Block size 8 should work correctly, expected {expected}, got {result.item()}"

    # Test 5: tensor with all uniform values -> dynamic_range should be 0
    uniform_tensor = torch.ones(64, 64).cuda() * 42.0
    stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True)
    result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config)
    assert result.item() == pytest.approx(
        0.0, abs=1e-4
    ), "Uniform tensor should have dynamic_range=0"

    # Test 6: 3D tensor flattening validation using a 2D/3D comparison.
    # Create a 4x4 tensor with distinct 2x2 blocks and compute with dims=2, block_size=2,
    # then reshape to 3D and compute again - results should match if flattening is correct.
    tensor_2d = torch.tensor(
        [
            [1.0, 1.0, 10.0, 10.0],
            [1.0, 1.0, 10.0, 10.0],
            [100.0, 100.0, 1000.0, 1000.0],
            [100.0, 100.0, 1000.0, 1000.0],
        ]
    ).cuda()

    # Compute on the 2D tensor: 4 blocks of 2x2, max range is log2(1000/100)
    stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False)
    result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config)

    # Reshape to 3D [2, 2, 4] and compute - should give the same result if flattening is correct
    tensor_3d = tensor_2d.reshape(2, 2, 4)
    result_3d = compute_max_blockwise_dynamic_range(tensor_3d, stat_config)
    assert result_2d.item() == pytest.approx(result_3d.item(), abs=1e-6), (
        "3D tensor [2,2,4] flattened to [4,4] must give same result as original 2D, got"
        f" 2D={result_2d.item()}, 3D={result_3d.item()}"
    )

    print("All direct tests for compute_max_blockwise_dynamic_range passed!")
@@ -4,7 +4,7 @@
"""LogTensorStats Feature support for nvidia-dlframework-inspect"""

-from typing import Dict, Optional
+from typing import Dict, Optional, List

import torch
@@ -19,6 +19,10 @@ from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage
from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
from transformer_engine.debug.features.utils.stats_computation import (
    add_max_blockwise_dynamic_range_stats,
    BlockwiseDynamicRangeStat,
)

@Registry.register_feature(namespace="transformer_engine")
@@ -44,7 +48,14 @@ class LogTensorStats(BaseLogTensorStats):
        - l1_norm
        - l2_norm
        - cur_amax – maximal absolute value of a tensor,
-       - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)`
+       - dynamic_range – equal to `torch.log2(amax) - torch.log2(nonzero_amin)`
+       - max_blockwise_dynamic_range – the maximum dynamic range `log2(amax) - log2(nonzero_amin)`
+         across all blocks of size block_size within the tensor. If both the tensor and its
+         transpose are needed in training, this stat is computed for both orientations and the
+         maximum is returned. For `dims=1` a block is block_size consecutive elements; for
+         `dims=2` a block is a block_size x block_size tile.
+           - block_size: int, default = 32
+           - dims: int, default = 1, allowed values are 1 and 2
    tensors/tensors_struct: List[str]
        list of tensors to log
@@ -88,6 +99,60 @@
            stats: [dynamic_range]
    """

    def _is_supported_stat(self, stat: str | Dict):
        """Returns True if the stat is supported by this feature, False otherwise."""
        if isinstance(stat, dict):
            stat_name = list(stat.keys())[0]
            if stat_name == "max_blockwise_dynamic_range":
                stat_dict = stat[stat_name]
                if not isinstance(stat_dict, dict):
                    return False

                # Ensure only supported keys are present
                allowed_keys = {"block_size", "dims"}
                if any(k not in allowed_keys for k in stat_dict.keys()):
                    return False

                block_size = stat_dict.get("block_size", 32)
                dims = stat_dict.get("dims", 1)

                # Type and value validation
                if not isinstance(block_size, int) or not isinstance(dims, int):
                    return False
                if block_size > 0 and dims in [1, 2]:
                    return True
            return False
        return stat in BaseLogTensorStats._get_supported_stats_list(None) | {
            "cur_amax",
            "dynamic_range",
        }
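
A few dict shapes and how this validator treats them — a sketch; the dict literals follow the config schema documented above:

# Accepted: only known keys, positive int block_size, dims in {1, 2}.
ok = {"max_blockwise_dynamic_range": {"block_size": 16, "dims": 2}}
# Rejected: unknown key "tile".
bad_key = {"max_blockwise_dynamic_range": {"block_size": 16, "tile": 2}}
# Rejected: non-positive block_size and unsupported dims.
bad_val = {"max_blockwise_dynamic_range": {"block_size": 0, "dims": 3}}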

    def _parse_max_blockwise_dynamic_range_stats(
        self, stats: List[str | Dict], tensor_name: str
    ) -> List[str | BlockwiseDynamicRangeStat]:
        """
        Adds all max_blockwise_dynamic_range stats to the stat computation logic.

        Converts each Dict stat into a BlockwiseDynamicRangeStat named tuple; all other stats
        are passed through unchanged. For example,
        [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}] becomes
        [BlockwiseDynamicRangeStat(block_size=32, dims=1, max_over_orientations=True)] or
        [BlockwiseDynamicRangeStat(block_size=32, dims=1, max_over_orientations=False)],
        depending on tensor_name.
        """
        max_over_orientations = tensor_name in ["activation", "weight"]

        parsed_stats = []
        for stat in stats:
            if isinstance(stat, dict):
                block_size = stat["max_blockwise_dynamic_range"].get("block_size", 32)
                dims = stat["max_blockwise_dynamic_range"].get("dims", 1)
                # Register the stat and return the named tuple
                parsed_stat = add_max_blockwise_dynamic_range_stats(
                    block_size, dims, max_over_orientations
                )
                parsed_stats.append(parsed_stat)
            else:
                parsed_stats.append(stat)
        return parsed_stats

    def _get_supported_stats_list(self):
        """Returns stats this feature can log."""
        return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"}
@@ -147,14 +212,16 @@
        )

        for stat in config["stats"]:
-           assert (
-               stat in self._get_supported_stats_list()
+           assert self._is_supported_stat(
+               stat
            ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."

        stats = self._parse_max_blockwise_dynamic_range_stats(config["stats"], tensor_name)
        STATS_BUFFERS.try_add_buffer(
            layer_name=layer_name,
            tensor_name=tensor_name,
-           stats=config["stats"],
+           stats=stats,
            options=options,
            reduction_group=reduction_group,
            reduce_within_microbatch=reduce_within_microbatch,
@@ -130,8 +130,12 @@
        for stat_name in self.stats_to_log:
            combiner = STATS[stat_name][1]
            stat_value = combiner(gathered_helper_stats)

            # Convert the stat key to a string for logging (uses __str__ for named tuples)
            stat_name_str = str(stat_name)
            MetricLogger.log_scalar(
-               f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration
+               f"{self.layer_name}_{self.tensor_name}_{stat_name_str}", stat_value, self.iteration
            )
            output[(self.layer_name, self.tensor_name, stat_name, self.iteration)] = (
                stat_value  # for debugging purposes
            )
@@ -7,12 +7,25 @@ Mathematical functions used for tensor statistics computation.
"""

import math
from collections import namedtuple

import torch
import torch.nn.functional as F
import transformer_engine_torch as tex

from transformer_engine.common.recipe import Format


class BlockwiseDynamicRangeStat(
    namedtuple("BlockwiseDynamicRangeStat", ["block_size", "dims", "max_over_orientations"])
):
    """Named tuple representing a blockwise dynamic range statistic configuration."""

    def __str__(self) -> str:
        """Convert to a string representation for the stat name. Used for logging."""
        suffix = "_max_over_orientations" if self.max_over_orientations else ""
        return f"max_blockwise_dynamic_range_block_size_{self.block_size}_dims_{self.dims}{suffix}"

@torch.compile
def _compute_dynamic_range_top(tensor):
    """Computes the log2 of the amax of the tensor"""
@@ -26,6 +39,7 @@ def _compute_dynamic_range_top(tensor):
    return torch.log2(amax)


@torch.compile
def _compute_dynamic_range_bottom(tensor):
    """Computes the log2 of the amin of the tensor"""
    tensor_abs = tensor.abs()
@@ -37,6 +51,76 @@ def _compute_dynamic_range_bottom(tensor):
    return torch.log2(amin)

def compute_max_blockwise_dynamic_range(tensor, stat_config):
    """
    Computes the maximum blockwise dynamic range (log2 of max / nonzero min) within blocks.

    Flattens the tensor to 2D and computes the maximum dynamic range within blocks. If
    max_over_orientations is True, computes the stat for both rowwise and columnwise orientations
    and returns the maximum, capturing the worst case regardless of how the tensor is used in
    GEMM operations. If False, computes only the rowwise orientation.

    Returns 0 if all blocks are zeros; otherwise computes the dynamic range over non-zero blocks.

    Args:
        tensor: Input tensor (will be flattened to 2D)
        stat_config: BlockwiseDynamicRangeStat named tuple with:
            - block_size: Size of blocks (int)
            - dims: 1 for 1D blocks (consecutive elements), 2 for 2D blocks (tiles)
            - max_over_orientations: If True, compute the max over rowwise and columnwise
              orientations
    """
    # Extract parameters from stat_config
    block_size = stat_config.block_size
    dims = stat_config.dims
    max_over_orientations = stat_config.max_over_orientations

    def _compute_for_one_orientation(tensor):
        total_numel = tensor.numel()
        assert dims in [1, 2], f"dims must be 1 or 2, got {dims}"
        # torch.compile friendly code - standard ** power does not work with jit
        total_block_size = block_size * block_size if dims == 2 else block_size
        assert (
            total_numel % total_block_size == 0
        ), f"Tensor numel ({total_numel}) is not divisible by the block size ({total_block_size})."

        tensor = tensor.abs().float()
        if dims == 1:
            tensor = tensor.reshape(-1, block_size)
            per_block_amax = tensor.amax(dim=1)
            per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=1)
        else:
            # We want a tensor of shape [nr_blocks, block_size, block_size],
            # where each block is a block_size x block_size tile of the original tensor.
            dim_y = tensor.shape[-1] // block_size
            tensor = (
                tensor.reshape(-1, block_size, dim_y, block_size)
                .permute(0, 2, 1, 3)
                .reshape(-1, block_size, block_size)
            )
            per_block_amax = tensor.amax(dim=(1, 2))
            per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2))

        # Identify blocks that contain any non-zero element
        nonzero_blocks = per_block_amax != 0
        dynamic_range_per_block = torch.where(
            nonzero_blocks,
            torch.log2(per_block_amax) - torch.log2(per_block_amin),
            torch.zeros_like(per_block_amax, dtype=torch.float32),
        )
        return dynamic_range_per_block.max()

    # Flatten to 2D
    tensor_2d = tensor.reshape(-1, tensor.shape[-1])

    if max_over_orientations:
        return max(
            _compute_for_one_orientation(tensor_2d),  # rowwise orientation
            _compute_for_one_orientation(tensor_2d.transpose(-2, -1)),  # columnwise orientation
        )
    return _compute_for_one_orientation(tensor_2d)
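
A minimal usage sketch of the function above, with values chosen so the result is easy to verify by hand (a CPU tensor works here since only reshape/amax/log2 are used; the tests in this commit use CUDA tensors):

import math
import torch

# A 1x8 tensor with block_size=4, dims=1 yields two rowwise blocks:
# [1, 2, 4, 8] and [16, 16, 16, 16].
t = torch.tensor([[1.0, 2.0, 4.0, 8.0, 16.0, 16.0, 16.0, 16.0]])
cfg = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False)
result = compute_max_blockwise_dynamic_range(t, cfg)
# First block: log2(8) - log2(1) = 3; second block: 0. The max is 3.
assert math.isclose(result.item(), 3.0, rel_tol=1e-6)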

@torch.compile
def compute_variance(variances, numels, sums):
    """Welford's algorithm is used for numerically stable distributed variance computation."""
    mean = torch.sum(sums) / torch.sum(numels)
@@ -45,6 +129,7 @@ def compute_variance(variances, numels, sums):
    return var


@torch.compile
def compute_std(variances, numels, sums):
    """Computes the standard deviation."""
    return torch.sqrt(compute_variance(variances, numels, sums))
@@ -316,6 +401,37 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False):
    DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"}


def add_max_blockwise_dynamic_range_stats(
    block_size: int, dims: int, max_over_orientations: bool = False
):
    """Register a max_blockwise_dynamic_range stat.

    Args:
        block_size: Size of blocks for computing the blockwise dynamic range
        dims: 1 for 1D blocks, 2 for 2D blocks
        max_over_orientations: Whether to compute the max over rowwise and columnwise orientations

    Returns:
        BlockwiseDynamicRangeStat named tuple representing this stat (used as the stat key)
    """
    # Use the named tuple directly as the stat key - cleaner than string keys
    stat_key = BlockwiseDynamicRangeStat(block_size, dims, max_over_orientations)

    if stat_key in stats_to_num:
        return stat_key  # already registered

    assert dims in [1, 2], f"dims must be 1 or 2, got {dims}"

    stats_to_num[stat_key] = len(stats_to_num)
    DEPENDENCIES[stat_key] = {stat_key}
    STATS[stat_key] = (
        lambda x, aux_dict, _stat_key=stat_key: compute_max_blockwise_dynamic_range(x, _stat_key),
        lambda buffers, _stat_key=stat_key: max(_get(buffers, _stat_key)),
    )
    return stat_key
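
A short sketch of how registration behaves; the expected string matches BlockwiseDynamicRangeStat.__str__ defined earlier in this diff:

# Registering twice returns the same named-tuple key; str(key) is the logged stat name.
key = add_max_blockwise_dynamic_range_stats(block_size=32, dims=1, max_over_orientations=True)
assert key == add_max_blockwise_dynamic_range_stats(32, 1, True)  # idempotent
assert str(key) == "max_blockwise_dynamic_range_block_size_32_dims_1_max_over_orientations"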

for _columnwise in [True, False]:
    for _recipe_name in [
        "",  # default recipe