Unverified Commit b33dd083 authored by Paweł Gadziński, committed by GitHub

[Pytorch] NVIDIA-DL-Framework-Inspect support – part 2 – features (#1613)



* features drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* Update transformer_engine/debug/features/utils/stats_computation.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Update transformer_engine/debug/features/disable_fp8_layer.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Update transformer_engine/debug/features/log_fp8_tensor_stats.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Update transformer_engine/debug/features/utils/stats_buffer.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Update transformer_engine/debug/features/per_tensor_scaling.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Update transformer_engine/debug/features/per_tensor_scaling.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Update transformer_engine/debug/features/disable_fp8_gemm.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Update transformer_engine/debug/features/per_tensor_scaling.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* Update transformer_engine/debug/features/per_tensor_scaling.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* changes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* temporarily removed saturations
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* Update transformer_engine/debug/features/_test_dummy_feature.py
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

* docs fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent b9e7b0b8
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Debug features."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Feature doing nothing, used for testing purposes."""
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
@Registry.register_feature(namespace="transformer_engine")
class TestDummyFeature(TEConfigAPIMapper):
"""
This is a feature used only in tests. It invokes look_at_tensor_before_process
and does nothing.
If no features are used, the TE layer automatically switches to non-debug mode;
this feature is invoked for each GEMM to prevent that behavior.
"""
@api_method
def inspect_tensor_enabled(self, *_args, **_kwargs):
"""API call used to determine whether to run look_at_tensor_before_process
in the forward pass."""
return True
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""API definition for nvidia-dlframework-inspect."""
import copy
from typing import Dict, Union
from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper
from nvdlfw_inspect.registry import Registry
import torch
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import all_tensor_types
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
class TEConfigAPIMapper(BaseConfigAPIMapper):
"""Class responsible for determining which NV DLFW Inspect API should be run for each tensor and gemm."""
def parse_config_and_api(self, config, **kwargs):
"""Process the config and returns True if the config and api args match, along with processed config."""
processed_config = None
config_copy = copy.deepcopy(config)
gemm_parsing = kwargs.get("gemm_parsing", False)
tensor_parsing = kwargs.get("tensor_parsing", False)
if gemm_parsing:
# parse with GEMM and/or tensor
processed_config = self._process_transformer_engine_config(config_copy, **kwargs)
elif tensor_parsing:
# parse with only tensor
processed_config = self._process_tensor_config(config_copy, kwargs["tensor_name"])
if not processed_config:
return False, None
if "enabled" in processed_config:
processed_config.pop("enabled")
return True, processed_config
def _validate_gemm(self, gemm):
assert gemm in ["fprop", "wgrad", "dgrad"], (
f"[NVTORCH INSPECT ERROR] Invalid gemm: {gemm}. It must be one of the ['fprop',"
" 'wgrad', 'dgrad']."
)
def _process_transformer_engine_config(self, config, **kwargs):
"""
Return config specific to a particular tensor name and gemm that matches the api args.
"""
if "gemms_struct" in config:
for cfg in config["gemms_struct"]:
self._validate_gemm(cfg["gemm"])
if cfg["gemm"] == kwargs["gemm"]:
if kwargs["tensor_parsing"]:
cfg = self._process_tensor_config(cfg, kwargs["tensor_name"])
if not cfg:
return None
cfg_copy = copy.deepcopy(cfg)
config.pop("gemms_struct")
assert (
"enabled" not in cfg_copy
), "[NVTORCH INSPECT ERROR] Enabled field should not be part of gemms_struct"
config.update(cfg_copy)
return config
return None
if "gemms" in config:
for gemm in config["gemms"]:
self._validate_gemm(gemm)
if kwargs["gemm"] in config["gemms"]:
if kwargs["tensor_parsing"]:
cfg = self._process_tensor_config(config, kwargs["tensor_name"])
if not cfg:
return None
config["gemm"] = kwargs["gemm"]
config.pop("gemms")
return config
return None
raise ValueError(
"[NVTORCH INSPECT ERROR] Provide 'gemms_struct: List[Dict]' or 'gemms: List[str]'"
" in the config yaml"
)
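# Illustrative sketch (not part of the original module): how a per-feature config from
# config.yaml is narrowed down to a single GEMM/tensor pair by parse_config_and_api.
# The dict layout and the expected result are assumptions based on the parsing code
# above; the tensor-level matching itself is implemented by BaseConfigAPIMapper in
# nvdlfw_inspect.
def _example_parse_config():  # pragma: no cover - documentation sketch
    mapper = TEConfigAPIMapper()
    cfg = {
        "enabled": True,
        "gemms_struct": [
            {"gemm": "fprop", "tensors": ["activation", "weight"]},
            {"gemm": "dgrad", "tensors": ["gradient"]},
        ],
    }
    ran, parsed = mapper.parse_config_and_api(
        cfg,
        gemm_parsing=True,
        tensor_parsing=True,
        gemm="fprop",
        tensor_name="activation",
    )
    # Expected (assumption): ran is True and parsed describes only the matched
    # fprop/activation entry, with the "enabled" field stripped.
    return ran, parsed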
required_kwargs = {
"fp8_gemm_enabled": ["gemm"],
"modify_tensor_enabled": ["tensor_name", "gemm"],
"modify_tensor": ["tensor_name", "gemm"],
"inspect_tensor": ["tensor_name"],
"inspect_tensor_postquantize": ["tensor_name"],
"inspect_tensor_enabled": ["tensor_name"],
"inspect_tensor_postquantize_enabled": ["tensor_name"],
"default": ["tensor_name", "gemm"],
}
# pylint: disable=unused-argument
class TEDefaultFeatures:
"""Transformer Engine API calls default behavior."""
def fp8_gemm_enabled(self, config: Dict, layer_name: str, gemm: str, iteration: int) -> bool:
"""
If the tensor is not processed using *modify_tensor* and the fp8 recipe is enabled,
then the decision whether to cast it to fp8 is based on the value returned by the call *fp8_gemm_enabled*.
If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled,
the result of this call does not matter.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is True
"""
return True  # If this returns False, the FP8 GEMM is turned off; otherwise nothing happens.
def modify_tensor_enabled(
self,
config: Dict,
layer_name: str,
gemm: str,
tensor_name: str,
iteration: int,
) -> bool:
"""
It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name. It has **higher priority** than *fp8_gemm*: if *modify_tensor_enabled* returns True, then the *modify_tensor* call is invoked for the respective tensor no matter what.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False
def modify_tensor(
self,
config: Dict,
layer_name: str,
gemm: str,
tensor_name: str,
tensor: torch.Tensor,
default_quantizer: Quantizer,
iteration: int,
out: Union[torch.Tensor, QuantizedTensor],
) -> Union[torch.Tensor, QuantizedTensor, None]:
"""
It allows tensor modification.
For example, feature `FakeQuant` uses it to emulate casting to FP8.
It can be invoked at most once for each tensor within a given GEMM operation.
This call is invoked if `modify_tensor_enabled` returns `True` and the feature is enabled for the *tensor_name* and *gemm*.
Then it is called **instead of** the default quantization.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor: torch.Tensor
tensor in high precision,
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
default_quantizer : Quantizer
quantizer which is used to cast the tensor to lower precision
if *modify_tensor* is not invoked. For example,
feature per tensor scale uses it to obtain FP8 dtype of the tensor.
If the recipe indicates that the tensor is not cast - for example,
if running without FP8 autocast, then `default_quantizer=None`,
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
out: Union[torch.Tensor, QuantizedTensor]
output tensor, used in the weight caching mechanism.
Returns
-------
Union[torch.Tensor, transformer_engine.pytorch.QuantizedTensor, None]
can be `torch.Tensor` or one of Transformer Engine's `QuantizedTensor` types -
the rule is that both tensors fed into a given GEMM should have the same type.
If both are `Float8Tensor`, then the GEMM is run in FP8.
If both are `torch.Tensor`, the GEMM is run in high precision.
Please take that into account, especially if only one tensor of the GEMM
is processed by `modify_tensor()`. For example, `FakeQuant`
disables FP8 GEMM to ensure that the second tensor is also in high precision.
If the tensor is not an input to any GEMM - namely `output`,
`wgrad` and `dgrad` - the return type should match the input type.
Should return `None` if `out` is not `None`.
"""
raise NotImplementedError(
"modify_tensor_enabled() returned True, modify_tensor() was invoked, but it is not"
" handled by any API."
)
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
tensor: torch.Tensor,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
) -> None:
"""
This call is invoked if *inspect_tensor_enabled* returns `True`. It can be used to inspect the high-precision tensor; for example, it is used by the `LogTensorStats` feature.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
tensor in high precision,
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
tp_group: torch.distributed.ProcessGroup
process group of the tensor-parallel group, used for weight statistics reduction.
This is not the reduction group from debug_api.
Returns
-------
Should return nothing.
"""
def inspect_tensor_postquantize(
self,
config: Dict,
layer_name: str,
tensor_name: str,
gemm: str,
tensor: torch.Tensor,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
) -> None:
"""
Similar to *inspect_tensor*, but run after the FP8 cast or *modify_tensor*, whichever is performed. If neither the FP8 cast nor *modify_tensor* is invoked, then *inspect_tensor_postquantize* is not invoked either. The `LogFp8TensorStats` feature uses this call to collect FP8 statistics after quantization.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
tensor in fp8 or processed tensor after the modify_tensor call,
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
tp_group: torch.distributed.ProcessGroup
process group of the tensor-parallel group, used for weight statistics reduction.
This is not the reduction group from debug_api.
Returns
-------
Should return nothing.
"""
def inspect_tensor_enabled(
self,
config: Dict,
layer_name: str,
tensor_name: str,
iteration: int,
) -> bool:
"""
This is a routing call, run at layer initialization. If it returns `True`, then *inspect_tensor* will be invoked for the given GEMM and tensor.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`].
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False
def inspect_tensor_postquantize_enabled(
self,
config: Dict,
layer_name: str,
gemm: str,
tensor_name: str,
iteration: int,
) -> bool:
"""
This is a routing call, run at layer initialization.
If it returns `True`, then *inspect_tensor_postquantize* will be invoked for
the given GEMM and tensor.
Parameters
----------
config: Dict
dictionary containing information from `config.yaml` corresponding to the feature, tensor_name and gemm.
layer_name: str
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
tensor_name: str
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
Returns
-------
bool - default is False
"""
return False
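# Illustrative sketch (not part of the original module): a minimal feature built on the
# hooks documented above. The class name and its behaviour are hypothetical; real
# features live under transformer_engine/debug/features/ and are registered the same way.
from nvdlfw_inspect.registry import api_method  # noqa: E402  (sketch-only import)


@Registry.register_feature(namespace="transformer_engine")
class _ExampleLogActivationAmax(TEConfigAPIMapper):
    """Hypothetical feature: print the amax of the activation tensor before quantization."""

    @api_method
    def inspect_tensor_enabled(self, config, layer_name, tensor_name, iteration):
        # Only look at high-precision activations.
        return tensor_name == "activation"

    @api_method
    def inspect_tensor(self, config, layer_name, tensor_name, tensor, iteration, tp_group):
        print(
            f"{layer_name}/{tensor_name} amax at iteration {iteration}: "
            f"{tensor.abs().max().item()}"
        )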
@Registry.register_namespace_api(namespace="transformer_engine")
class TransformerEngineAPI(BaseNamespaceAPI):
"""
Transformer Engine API class that contains default APIs that are invoked when a config is not provided
or a layer is not selected in the config.
TransformerEngine specific features must override these APIs wherever required.
The overridden APIs will be invoked whenever the corresponding feature is enabled in the config.
"""
def __init__(self):
BaseNamespaceAPI.__init__(self)
self._default_api_impl = TEDefaultFeatures()
self._cacheable_api_kwargs_map = {
"fp8_gemm": ["gemm"],
"modify_tensor": ["tensor_name", "gemm"],
"inspect_tensor": ["tensor_name"],
"inspect_tensor_postquantize": ["tensor_name"],
"inspect_tensor_enabled": ["tensor_name"],
"inspect_tensor_postquantize_enabled": ["tensor_name"],
"modify_tensor_enabled": ["tensor_name"],
}
def is_multiple_feature_invocation_allowed(self, api_name):
"""
Check if API allows executing multiple features for a single call
"""
return api_name in {
"fp8_gemm_enabled",
"inspect_tensor",
"inspect_tensor_postquantize",
"inspect_tensor_enabled",
"inspect_tensor_postquantize_enabled",
}
def input_assertions_hook(self, api_name, **kwargs):
"""
These args must be passed as kwargs in the API call for all TransformerEngine specific APIs.
"""
if api_name in required_kwargs:
for kwarg in required_kwargs[api_name]:
assert kwarg in kwargs, (
f"[NVTORCH INSPECT ERROR] Cannot route API, too ambiguous. Provide {kwarg} in"
f" {api_name}."
)
else:
for kwarg in required_kwargs["default"]:
assert kwarg in kwargs, (
f"[NVTORCH INSPECT ERROR] Cannot route API, too ambiguous. Provide {kwarg} in"
f" {api_name}."
)
def routing_condition(self, api_name, config, _, feature_obj, **kwargs):
"""
Overridden APIs are selected based on the GEMM name in the config and kwargs.
"""
tensor_parsing = "tensor_name" in required_kwargs[api_name]
gemm_parsing = "gemm" in required_kwargs[api_name]
status, modified_config = feature_obj.parse_config_and_api(
config, gemm_parsing=gemm_parsing, tensor_parsing=tensor_parsing, **kwargs
)
return status, modified_config
def output_assertions_hook(self, api_name, ret, **kwargs):
"""Output hooks used to check correctness of the outputs of the API calls."""
if "enabled" in api_name or api_name == "fp8_gemm":
assert isinstance(ret, bool)
if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
assert ret is None
if api_name == "modify_tensor":
assert type(ret) in all_tensor_types
if (
type(ret) == torch.Tensor # pylint: disable=unidiomatic-typecheck
and "dtype" in kwargs
):
if kwargs["dtype"] is not None:
assert ret.dtype == kwargs["dtype"]
def step(self):
"""This function is called by the nvidia-dlframework-inspect after every debug_api.step()"""
STATS_BUFFERS.log_stats()
def end_debug(self):
"""This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()"""
TEDebugState.reset()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
@Registry.register_feature(namespace="transformer_engine")
class DisableFP8GEMM(TEConfigAPIMapper):
"""
GEMM operations are executed in higher precision, even when FP8 autocast is enabled.
Parameters
----------
gemms: List[str]
list of gemms to disable
- fprop
- dgrad
- wgrad
Example
-------
.. code-block:: yaml
example_disable_fp8_gemm:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [dgrad, wgrad]
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and FP8 GEMM execution."""
for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behaviour in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
@Registry.register_feature(namespace="transformer_engine")
class DisableFP8Layer:
"""
Disables all FP8 GEMMs in the layer.
Example
-------
.. code-block:: yaml
example_disable_fp8_layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer:
enabled: True
"""
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
for key in config:
if key not in ["enabled", "gemm"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
debug_api.log_message("FP8 Disabled", layer_name)
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API
DisableFP8Layer is the only feature provided by the Transformer Engine
which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for
parsing gemms and tensors fields from the config, which are not needed for this feature.
Explanation of the parse_config_and_api can be found in the
nvidia-dlframework-inspect documentation.
"""
return config["enabled"], None
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FakeQuant Feature support for nvidia-dlframework-inspect"""
from typing import Optional
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
from nvdlfw_inspect.utils import append_parent_docstring
import transformer_engine_torch as tex
from transformer_engine.debug.features.api import TEConfigAPIMapper
from transformer_engine.common.recipe import Format
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.fp8 import _default_sf_compute
def fake_quantize(tensor: torch.Tensor, fp8_format: str, out=None):
"""Input tensor is quantized to fp8 and then dequantized."""
assert tensor.dtype in (
torch.float,
torch.float16,
torch.bfloat16,
), "[NVTORCH INSPECT ERROR] Unsupported tensor type."
assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor."
assert fp8_format in {
"FP8E4M3",
"FP8E5M2",
"MXFP8E4M3",
"MXFP8E5M2",
}, (
"[NVTORCH INSPECT ERROR] Only 4 FP8 types: FP8E4M3, FP8E5M2, MXFP8E4M3, MXFP8E5M2 are"
" supported in TE."
)
if fp8_format in ["FP8E4M3", "FP8E5M2"]:
if fp8_format == "FP8E4M3":
fp8_max = Format.E4M3.value.max_fwd
fp8_dtype = tex.DType.kFloat8E4M3
else:
fp8_max = Format.E5M2.value.max_fwd
fp8_dtype = tex.DType.kFloat8E5M2
amax = tensor.abs().max().float()
one = torch.ones(1, device=tensor.device)
scale = _default_sf_compute(amax, one, fp8_max)
quantizer = Float8Quantizer(scale, amax, fp8_dtype)
else:
quantizer = MXFP8Quantizer(fp8_dtype=fp8_format)
if out is not None:
out.copy_(quantizer(tensor).dequantize())
return None
return quantizer(tensor).dequantize()
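# Illustrative sketch (not part of the original module): fake-quantize a random BF16
# tensor to E4M3 and inspect the rounding error. Requires a CUDA device; the tensor
# shape and the error metric are arbitrary choices for the example.
def _example_fake_quantize():  # pragma: no cover - documentation sketch
    x = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)
    x_fq = fake_quantize(x, "FP8E4M3")  # quantize to FP8 and dequantize back
    max_err = (x.float() - x_fq.float()).abs().max().item()
    print(f"max abs fake-quantization error: {max_err:.6f}")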
@Registry.register_feature(namespace="transformer_engine")
@append_parent_docstring(parent=TEConfigAPIMapper)
class FakeQuant(TEConfigAPIMapper):
"""
Disables FP8 GEMM. Fake quantizes chosen tensors to FP8 - using per-tensor scaling factor, not delayed scaling - and runs high-precision GEMM.
.. figure:: ./img/fake_quant.svg
:align: center
Fig 1: Comparison of FP8 FPROP GEMM with the same GEMM in BF16 with fake quantization of activation tensor. Green tensors have the same values, but different dtypes.
Parameters
----------
gemms/gemms_struct: List[str]
list of gemms to fake quantize
- fprop
- dgrad
- wgrad
tensors/tensors_struct: List[str]
list of tensors to fake quantize
- activation
- gradient
- weight
- output
- wgrad
- dgrad
quant_format: str
specifies the FP8 format to use:
- FP8E5M2
- FP8E4M3
- MXFP8E5M2
- MXFP8E4M3
Example
-------
.. code-block:: yaml
example_fake_quant_fp8:
enabled: True
layers:
layer_types: [transformer_layer.layernorm_mlp.fc1]
transformer_engine:
FakeQuant:
enabled: True
quant_format: FP8E5M2
gemms_struct:
- gemm: fprop
tensors: [activation, weight]
- gemm: dgrad
tensors: [gradient]
"""
def _supported_formats(self):
"""Returns formats that one can fake quantize tensor to."""
return ["FP8E4M3", "FP8E5M2", "MXFP8E4M3", "MXFP8E5M2"]
@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
return False
@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run process_tensor() in the forward."""
return True
@api_method
def modify_tensor(
self,
config,
layer_name: str,
gemm: str,
tensor_name: str,
tensor: torch.Tensor,
iteration: int,
default_quantizer: Quantizer,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
): # pylint: disable=unused-argument
"""API call used to process the tensor."""
for key in config.keys():
if key not in ["gemm", "tensor", "quant_format"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
if "quant_format" not in config:
raise ValueError(
f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor:"
f" quant_format missing for Tensor: {tensor_name} in the config yaml for"
" FakeQuant feature which is a required field"
)
if config["quant_format"] not in self._supported_formats():
raise ValueError(
f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor:"
f" quant_format: {config['quant_format']} for Tensor: {tensor_name} in the config"
" yaml for FakeQuant feature is not supported"
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=process_tensor: {gemm}, {tensor_name}",
layer_name,
extra_cachable_args=(gemm, tensor_name),
)
quant_format = config["quant_format"]
q_tensor = fake_quantize(tensor, quant_format, out=out)
if dtype is not None:
q_tensor = q_tensor.to(dtype)
return q_tensor
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LogFp8TensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Union
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.pytorch.debug_state import TEDebugState
@Registry.register_feature(namespace="transformer_engine")
class LogFp8TensorStats(BaseLogTensorStats):
"""
This feature handles logging of FP8 tensor stats.
In a distributed setting, the auxiliary stats are computed on each rank and gathered after
the `debug_api.step()` call. Do not forget to invoke `debug_api.step()` at every step to log
stats!
`LogFp8TensorStats` supports micro-batching. If multiple forward/backward passes are invoked
per `debug_api.step()`, then stats for all tensors except weights will be accumulated.
`LogFp8TensorStats` can induce significant overhead. To mitigate this issue, logging stats
with `freq > 1` is recommended. If `LogFp8TensorStats` is not used in a given step, the
overhead is smaller. If no other feature is used for the layer, the TE layer will
run as fast as it would without `debug_api` initialized.
Parameters
----------
stats: List[str]
list of statistics to log
- underflows% - percentage of elements of the tensor equal to 0,
tensors/tensors_struct: List[str]
list of tensors to log
- activation
- gradient
- weight
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
start step of logging stats
end_step: Optional[int], default = None
end step of logging stats
start_end_list: Optional[list([int, int])], default = None
non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step
Example
-------
.. code-block:: yaml
example_fp8_tensor_stat_collection:
enabled: True
layers:
layer_types: [layernorm_linear]
transformer_engine:
LogFp8TensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%]
freq: 1
- tensor: gradient
stats: [underflows%]
freq: 5
start_step: 0
end_step: 80
"""
def _get_supported_stats_list(self):
"""Returns stats this feature can log."""
return {"underflows%"}
@api_method
def inspect_tensor_postquantize_enabled(
self, config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run inspect_tensor_postquantize() in the forward."""
# check whether logging should happen in this iteration
return self._check_params(config, layer_name, iteration=iteration)
@api_method
def inspect_tensor_postquantize(
self,
config: Dict,
layer_name: str,
tensor_name: str,
tensor: Union[torch.Tensor, QuantizedTensor],
rowwise: bool,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], (
f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using"
" log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors."
)
# This API can be invoked twice - with the tensor and with the transpose.
# We want to collect the stats once.
if not rowwise:
return # tensor was already seen rowwise in the other gemm
tensor = tensor._data
options = (
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
"fp8",
)
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
else:
skip_reduction = True
for stat in config["stats"]:
assert (
stat in self._get_supported_stats_list()
), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
tensor_name=tensor_name,
stats=config["stats"],
options=options,
reduction_group=reduction_group,
reduce_within_microbatch=reduce_within_microbatch,
)
STATS_BUFFERS.feed(layer_name, tensor_name, options, tensor, iteration, skip_reduction)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=inspect_tensor_postquantize: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
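# Illustrative sketch (not part of the original module): what the "underflows%" statistic
# measures - the percentage of FP8 elements whose encoded value is exactly zero, i.e.
# values that underflowed during quantization. Requires a CUDA device; the quantizer
# choice and its constructor arguments follow per_tensor_cast in per_tensor_scaling.py
# and are an assumption here.
def _example_underflow_percentage():  # pragma: no cover - documentation sketch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer

    x = torch.randn(1024, 1024, device="cuda") * 1e-6
    x[0, 0] = 1.0  # a large outlier pins the scale to its amax, so the tiny values underflow
    x_fp8 = Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3)(x)
    pct = (x_fp8._data == 0).sum().item() / x_fp8._data.numel() * 100
    print(f"underflows%: {pct:.2f}")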
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""LogTensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Union
import torch
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
import nvdlfw_inspect.api as debug_api
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
@Registry.register_feature(namespace="transformer_engine")
class LogTensorStats(BaseLogTensorStats):
"""
This feature handles the logging of basic tensor statistics.
For a distributed setting, the auxiliary stats are computed for each node and gathered after the `debug_api.step()` call. Do not forget to invoke `debug_api.step()` at every step to log stats!
`LogTensorStats` supports micro-batching. If multiple forward/backward passes are invoked per `debug_api.step()`, then stats for all tensors except weights will be accumulated.
`LogTensorStats` can induce significant overhead. To mitigate this issue, logging stats with `freq > 1` is recommended. If `LogTensorStats` is not used in a given step, the overhead is smaller. Moreover, if no other feature is used for the layer, the TE layer will run as fast as it would without `debug_api` initialized.
Parameters
----------
stats: List[str]
list of statistics to log
- min
- max
- mean
- std
- l1_norm
- l2_norm
- cur_amax – maximal absolute value of a tensor,
- dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)`
tensors/tensors_struct: List[str]
list of tensors to log
- activation
- gradient
- weight
- output
- wgrad
- dgrad
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
start step of logging stats
end_step: Optional[int], default = None
end step of logging stats
start_end_list: Optional[list([int, int])], default = None
non-overlapping list of (start, end) pairs in incremental order. If not None, will ignore start_step and end_step
Example
-------
.. code-block:: yaml
example_tensor_stat_collection:
enabled: True
layers:
layer_name_regex_pattern: .*(fc1|self_attention).*
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [mean]
freq: 10
start_step: 5
end_step: 100
- tensor: gradient
stats: [mean, max, min]
freq: 2
start_end_list: [[0, 20], [80, 100]]
- tensor: weight
stats: [dynamic_range]
"""
def _get_supported_stats_list(self):
"""Returns stats this feature can log."""
return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"}
@api_method
def inspect_tensor_enabled(
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run look_at_tensor_before_process() in the forward."""
return self._check_params(config, layer_name, iteration=iteration)
@api_method
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
tensor: Union[torch.Tensor, QuantizedTensor],
iteration: int,
tp_group: torch.distributed.ProcessGroup,
):
"""API call used to collect the data about the tensor before process_tensor()/quantization."""
assert (
type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase]
and tensor.dtype != torch.uint8
), (
f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using"
" log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors."
)
options = (
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
)
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
else:
skip_reduction = True
for stat in config["stats"]:
assert (
stat in self._get_supported_stats_list()
), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
tensor_name=tensor_name,
stats=config["stats"],
options=options,
reduction_group=reduction_group,
reduce_within_microbatch=reduce_within_microbatch,
)
STATS_BUFFERS.feed(layer_name, tensor_name, options, tensor, iteration, skip_reduction)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=look_at_tensor_before_process: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PerTensorScaling Feature support for nvidia-dlframework-inspect"""
from typing import Optional
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor import Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Tensor,
Float8CurrentScalingQuantizer,
)
from transformer_engine.debug.features.api import TEConfigAPIMapper
def per_tensor_cast(
tensor: torch.Tensor, fp8_dtype: tex.DType, out: Float8Tensor = None
) -> Float8Tensor:
"""
This function computes the scaling factor based on the tensor amax and then casts the tensor to FP8.
"""
assert tensor.dtype in (
torch.float,
torch.float16,
torch.bfloat16,
), "[NVTORCH INSPECT ERROR] Unsupported tensor type for per tensor current scaling"
assert tensor.is_cuda, "[NVTORCH INSPECT ERROR] Must be a GPU tensor."
assert fp8_dtype in {
tex.DType.kFloat8E4M3,
tex.DType.kFloat8E5M2,
}, "[NVTORCH INSPECT ERROR] Only 2 FP8 types: E4M3 and E5M2 are supported in TE."
tensor = tensor.contiguous()
quantizer = Float8CurrentScalingQuantizer(fp8_dtype)
if out is not None:
quantizer.update_quantized(tensor, out)
return None
return quantizer(tensor)
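# Illustrative sketch (not part of the original module): cast a BF16 tensor to FP8 with a
# scale computed from its own amax, then dequantize to check the round-trip error.
# Requires a CUDA device; shapes are arbitrary.
def _example_per_tensor_cast():  # pragma: no cover - documentation sketch
    x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
    x_fp8 = per_tensor_cast(x, tex.DType.kFloat8E4M3)  # returns a Float8Tensor
    x_back = x_fp8.dequantize()
    print("round-trip max error:", (x.float() - x_back.float()).abs().max().item())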
@Registry.register_feature(namespace="transformer_engine")
class PerTensorScaling(TEConfigAPIMapper):
"""
Allows using per-tensor current scaling for the specific tensors.
Can be used only within `DelayedScaling` recipe autocast.
Parameters
----------
gemms/gemms_struct: List[str]
list of gemms to enable per-tensor current scaling for
- fprop
- dgrad
- wgrad
tensors/tensors_struct: List[str]
list of tensors to enable per-tensor current scaling for
- activation
- gradient
- weight
Example
-------
.. code-block:: yaml
example_per_tensor_scaling:
enabled: True
layers:
layer_types: [transformer_layer.self_attn.layernorm_q]
transformer_engine:
PerTensorScaling:
enabled: True
margin: 1
gemms: [dgrad]
tensors: [weight, activation]
"""
@api_method
def fp8_gemm(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
return False
@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run process_tensor() in the forward."""
return True
@api_method
def modify_tensor(
self,
config,
layer_name: str,
gemm: str,
tensor_name: str,
tensor: torch.Tensor,
iteration: int,
default_quantizer: Quantizer,
out: Optional[Float8Tensor] = None,
dtype: Optional[torch.dtype] = None,
): # pylint: disable=unused-argument
"""API call used to process the tensor."""
for key in config.keys():
if key not in ["gemm", "tensor"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
assert isinstance(default_quantizer, Float8CurrentScalingQuantizer), (
f"[NVTORCH INSPECT ERROR] Feature={self.__class__.__name__}, API=process_tensor: "
"Per-tensor current scaling can be used only within `DelayedScaling` recipe autocast."
f" {layer_name}"
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=process_tensor: {gemm}, {tensor_name}",
layer_name,
extra_cachable_args=(gemm, tensor_name),
)
fp8_tensor = per_tensor_cast(tensor, default_quantizer.dtype, out=out)
return fp8_tensor
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utils for the debug features.
"""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Buffer used for LogTensorStats and LogFp8TensorStats features.
Buffers are fed with tensors; they compute the necessary stats and save them.
When log() is called, they gather stats from all nodes, compute the combined final stats, and log them.
"""
from collections import defaultdict
import torch
from nvdlfw_inspect.utils import gather_along_first_dim
from nvdlfw_inspect.logging import MetricLogger
from transformer_engine.debug.features.utils.stats_computation import (
STATS,
DEPENDENCIES,
stats_to_num,
)
class _Buffer:
"""
Buffer stores temporary statistics for one tensor for one layer.
It can also synchronize stats between nodes and log the final stats.
"""
def __init__(self, layer_name, tensor_name, stats, reduction_group, reduce_within_microbatch):
self.layer_name = layer_name
self.tensor_name = tensor_name
self.reduction_group = reduction_group
self.reduce_within_microbatch = reduce_within_microbatch
self.stats_to_log = stats
self.stats_to_compute = set()
for stat in stats:
self.stats_to_compute = self.stats_to_compute | DEPENDENCIES[stat]
self._buffer = torch.zeros(len(STATS), dtype=torch.float32).cuda()
self._new_buffer = self._buffer.clone()
self._tmp_buffer = self._buffer.clone()
# With data parallelism it is possible that the layer is not run on some nodes.
# `modified` is set to True on nodes where the layer was run;
# nodes that did not run the layer are excluded from the reduction.
self.modified = torch.tensor([False], dtype=torch.bool).cuda()
self.iteration = None
self.skip_reduction = False
def _reset_before_next_step(self):
"""Resets the state after the logging."""
self.modified[0] = False
def _gather_helper_stats(self) -> torch.Tensor:
"""
If tensor stats should be accumulated among many nodes,
this method gathers all stats from the nodes where the stat was modified.
"""
if self.skip_reduction:
return self._buffer.unsqueeze(0)
mask = gather_along_first_dim(self.modified, process_group=self.reduction_group)[0]
gathered_buffer, _ = gather_along_first_dim(
self._buffer.unsqueeze(0), process_group=self.reduction_group
)
return gathered_buffer[mask.to(bool)]
def feed(self, tensor, iteration):
"""
feed() is used to add a tensor for computing the statistics.
Because of micro-batching, feed() can be called multiple
times for one log().
The main reason for this design is the need to combine results for already processed
tensors with the result of the new tensor.
"""
self.iteration = iteration
# If the stats are not reduced within a micro-batch and the
# buffer was already fed, we do not change the stats for the tensor.
# This is used for weights with micro-batching.
if self.modified[0] and not self.reduce_within_microbatch:
return
# save stats for tensor to tmp buffer
for stat_name in self.stats_to_compute:
fn, _ = STATS[stat_name]
self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor)
# [num_buffers, num_stats]
buffers = torch.cat((self._buffer.unsqueeze(0), self._tmp_buffer.unsqueeze(0)), dim=0)
for stat_name in self.stats_to_compute:
fn, combinator = STATS[stat_name]
if self.modified[0]:
self._new_buffer[stats_to_num[stat_name]] = combinator(buffers)
else:
fn = STATS[stat_name][0]
self._new_buffer[stats_to_num[stat_name]] = fn(tensor)
self._buffer.copy_(self._new_buffer)
self.modified[0] = True
def log(self):
"""
Logs the tensor stats and resets the buffer.
"""
# [num_active_nodes, num_stats]
gathered_helper_stats = self._gather_helper_stats()
if not self.modified[0]:
return {}
output = {}
for stat_name in self.stats_to_log:
combiner = STATS[stat_name][1]
stat_value = combiner(gathered_helper_stats)
MetricLogger.log_scalar(
f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration
)
output[(self.layer_name, self.tensor_name, stat_name, self.iteration)] = (
stat_value # for debugging purposes
)
self._reset_before_next_step()
return output
class StatsBuffers:
"""
StatsBuffers class represents all buffers of the statistics for all tensors.
It is used to feed the tensors to the correct buffers.
"""
def __init__(self):
self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list)
def reset(self):
"""Resets all buffers."""
self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list)
def try_add_buffer(
self, layer_name, tensor_name, stats, options, reduction_group, reduce_within_microbatch
):
"""If buffer for such combination of stats/tensor_name/... is not present, this method creates it."""
if (layer_name, tensor_name, options) in self.buffers:
return
buffer = _Buffer(layer_name, tensor_name, stats, reduction_group, reduce_within_microbatch)
self.buffers[(layer_name, tensor_name, options)] = buffer
self.reduction_group_to_buffer[reduction_group].append(buffer)
def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction):
"""Feeds the tensor into the respective buffer."""
buffer = self.buffers[(layer_name, tensor_name, options)]
buffer.feed(tensor, iteration)
buffer.skip_reduction = skip_reduction
def log_stats(self):
"""Logs the stats from all the buffers."""
output = {}
for reduction_group, buffers in self.reduction_group_to_buffer.items():
changed_buffers = [
(i, buffer)
for i, buffer in enumerate(buffers)
if gather_along_first_dim(
buffer.modified.unsqueeze(0), process_group=reduction_group
)[0].any()
]
for _, buffer in changed_buffers:
stats = buffer.log()
output.update(stats)
return output
STATS_BUFFERS = StatsBuffers()
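# Illustrative sketch (not part of the original module): how the logging features use the
# global STATS_BUFFERS object (compare LogTensorStats.inspect_tensor). Assumes a CUDA
# device and a single process (reduction_group=None); the layer and tensor names are
# hypothetical.
def _example_stats_buffers():  # pragma: no cover - documentation sketch
    layer, tensor_name = "decoder.0.fc1", "activation"
    options = (None, None, None)  # (start_step, end_step, start_end_list)
    STATS_BUFFERS.try_add_buffer(
        layer_name=layer,
        tensor_name=tensor_name,
        stats=["mean", "max"],
        options=options,
        reduction_group=None,
        reduce_within_microbatch=True,
    )
    for _ in range(2):  # feed() may be called once per micro-batch
        x = torch.randn(64, 64, device="cuda")
        STATS_BUFFERS.feed(layer, tensor_name, options, x, iteration=0, skip_reduction=True)
    # During training, TransformerEngineAPI.step() calls STATS_BUFFERS.log_stats()
    # after debug_api.step(); it returns {(layer, tensor, stat, iteration): value}.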
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Mathematical functions used for tensor statistics computation.
"""
import math
import torch
MAX_FP8_VALUE_INT8 = 126
@torch.compile
def _compute_dynamic_range_top(tensor):
"""Computes the log2 of the amax of the tensor"""
tensor_abs = tensor.abs()
tensor_abs = tensor_abs[tensor_abs != 0]
if tensor_abs.any():
    amax = tensor_abs.max().float()
else:
    # All elements are zero - fall back to 1 so that log2() stays well defined.
    amax = torch.tensor(1, device=tensor.device).to(torch.float)
return torch.log2(amax)
def _compute_dynamic_range_bottom(tensor):
"""Computes the log2 of the amin of the tensor"""
tensor_abs = tensor.abs()
tensor_abs = tensor_abs[tensor_abs != 0]
if tensor_abs.any():
amin = tensor_abs.min().float()
else:
amin = torch.tensor(1, device=tensor.device).to(torch.float)
return torch.log2(amin)
def compute_variance(variances, numels, sums):
"""Welford algorithm is used for numerically stable distributed variance computation."""
mean = torch.sum(sums) / torch.sum(numels)
means = sums / numels
# Parallel combination: the total variance is the weighted mean of the per-chunk
# variances plus the weighted spread of the per-chunk means around the global mean.
var = torch.sum(numels * (variances + torch.pow((means - mean), 2))) / torch.sum(numels)
return var
def compute_std(variances, numels, sums):
"""Computates standard deviation."""
return torch.sqrt(compute_variance(variances, numels, sums))
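# Illustrative sketch (not part of the original module): combining per-shard statistics
# into a global variance, as is done when stats from several ranks or micro-batches are
# merged. The comparison is approximate because torch.var() uses the unbiased estimator
# while the combination formula is population-based; the shard contents are arbitrary.
def _example_combined_variance():  # pragma: no cover - documentation sketch
    a, b = torch.randn(10_000), torch.randn(5_000) + 1.0
    variances = torch.stack([a.var(), b.var()])
    numels = torch.tensor([a.numel(), b.numel()], dtype=torch.float32)
    sums = torch.stack([a.sum(), b.sum()])
    combined = compute_variance(variances, numels, sums)
    reference = torch.cat([a, b]).var()
    print(f"combined={combined:.4f} reference={reference:.4f}")  # nearly equal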
# buffers is tensor of shape [nr_buffers, nr_stats]
def _get(buffers, stat_name):
stat_nr = stats_to_num[stat_name]
return buffers[:, stat_nr]
stats_to_num = {
"min": 0,
"max": 1,
"sum": 2,
"mean": 3,
"numel": 4,
"l1_norm": 5,
"l2_norm_square": 6,
"l2_norm": 7,
"variance": 8,
"cur_amax": 9,
"dynamic_range_top": 10,
"dynamic_range_bottom": 11,
"underflows_num": 12,
"std": 13,
"dynamic_range": 14,
"underflows%": 15,
}
DEPENDENCIES = {
"min": {"min"},
"max": {"max"},
"sum": {"sum"},
"mean": {"sum", "numel"},
"numel": {"numel"},
"l1_norm": {"l1_norm"},
"l2_norm_square": {"l2_norm_square", "numel"},
"l2_norm": {"l2_norm_square"},
"variance": {"variance", "numel", "sum"},
"cur_amax": {"cur_amax"},
"dynamic_range_top": {"dynamic_range_top"},
"dynamic_range_bottom": {"dynamic_range_bottom"},
"underflows_num": {"underflows_num"},
"std": {"variance", "numel", "sum"},
"dynamic_range": {"dynamic_range_top", "dynamic_range_bottom"},
"underflows%": {"underflows_num", "numel"},
}
STATS = {
"min": (torch.min, lambda buffers: min(_get(buffers, "min"))),
"max": (torch.max, lambda buffers: max(_get(buffers, "max"))),
"sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))),
"mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))),
"numel": (lambda x: x.numel(), lambda buffers: sum(_get(buffers, "numel"))),
"l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))),
"l2_norm_square": (
lambda x: torch.sum(x**2),
lambda buffers: sum(_get(buffers, "l2_norm_square")),
),
"l2_norm": (
lambda x: torch.norm(x, p=2),
lambda buffers: math.sqrt(sum(_get(buffers, "l2_norm_square"))),
),
"variance": (
torch.var,
lambda buffers: compute_variance(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
),
),
"cur_amax": (lambda x: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))),
"dynamic_range_top": (
_compute_dynamic_range_top,
lambda buffers: max(_get(buffers, "dynamic_range_top")),
),
"dynamic_range_bottom": (
_compute_dynamic_range_bottom,
lambda buffers: min(_get(buffers, "dynamic_range_bottom")),
),
"underflows_num": (
lambda x: (x._data == 0).sum(),
lambda buffers: sum(_get(buffers, "underflows_num")),
),
"std": (
torch.std,
lambda buffers: compute_std(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
),
),
"dynamic_range": (
lambda x: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x),
lambda buffers: max(_get(buffers, "dynamic_range_top"))
- min(_get(buffers, "dynamic_range_bottom")),
),
"underflows%": (
lambda x: (x == 0).sum() / x.numel() * 100,
lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")),
),
}
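# Illustrative sketch (not part of the original module): each STATS entry is a pair of
# (per-tensor function, combiner over a [num_buffers, num_stats] helper matrix). The
# example computes dynamic_range for two chunks separately and then combines the helper
# stats, which matches computing it directly over the concatenated tensor. The chunk
# contents are arbitrary; the first call may trigger a torch.compile compilation.
def _example_dynamic_range_combination():  # pragma: no cover - documentation sketch
    chunks = [torch.randn(1024), 10.0 * torch.randn(1024)]
    buffers = torch.zeros(len(chunks), len(stats_to_num))
    for i, chunk in enumerate(chunks):
        buffers[i, stats_to_num["dynamic_range_top"]] = _compute_dynamic_range_top(chunk)
        buffers[i, stats_to_num["dynamic_range_bottom"]] = _compute_dynamic_range_bottom(chunk)
    combined = STATS["dynamic_range"][1](buffers)
    direct = STATS["dynamic_range"][0](torch.cat(chunks))
    print(f"combined={combined:.3f} direct={direct:.3f}")  # should agree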