# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Utilities for debugging numerical issues with FP8"""
from typing import Tuple
import torch
from transformer_engine.common import recipe

# Module-level debug switch, rebound by debug() below.
# NOTE(review): nothing in this file reads it -- presumably consumed elsewhere.
_NUMERICS_DEBUG = False


def debug(enabled: bool = True) -> None:
    """Turn FP8 numerics debugging on or off.

    Args:
        enabled: Value stored into the module-level ``_NUMERICS_DEBUG``
            flag. Defaults to ``True`` (debugging on).
    """
    # ``global`` makes the assignment rebind the module attribute instead of
    # creating a function-local name.
    global _NUMERICS_DEBUG
    _NUMERICS_DEBUG = enabled


def fp8_tensor_statistics(tensor: torch.Tensor, fp8_format: str = "E4M3") -> Tuple[int, int]:
    """Count elements of ``tensor`` sitting at the edges of an FP8 range.

    Nothing is printed; the counts are returned to the caller.

    Args:
        tensor: Tensor to inspect. NOTE(review): presumably holds values that
            went through an FP8 cast (so saturated entries equal +/- the
            format maximum) -- confirm against callers.
        fp8_format: ``"E4M3"`` or ``"E5M2"`` (case-insensitive). Selects the
            maximum magnitude via ``recipe.Format[...].value.max_fwd``.

    Returns:
        ``(num_underflows, num_overflows)``, where ``num_underflows`` is the
        number of elements exactly equal to zero and ``num_overflows`` is the
        number of elements whose magnitude is at least the format's forward
        maximum.

    Raises:
        ValueError: If ``fp8_format`` is not ``"E4M3"`` or ``"E5M2"``.
    """
    fp8_format = fp8_format.upper()
    # Explicit check rather than ``assert`` so validation still happens under
    # ``python -O`` instead of falling through to the ``recipe.Format`` lookup.
    if fp8_format not in ("E4M3", "E5M2"):
        raise ValueError(f"fp8_format must be 'E4M3' or 'E5M2', got {fp8_format!r}")

    fp8_max = recipe.Format[fp8_format].value.max_fwd

    # Saturation is symmetric: count entries at -max as well as +max (and
    # anything beyond the representable range), not only exact +max hits.
    num_overflows = (tensor.abs() >= fp8_max).sum().item()
    # Underflow is approximated by exact zeros; values that rounded to zero
    # are indistinguishable here from genuine zeros.
    num_underflows = (tensor == 0).sum().item()
    return (num_underflows, num_overflows)