Unverified commit b6b3abce authored by Paweł Gadziński, committed by GitHub

[PyTorch debug] Improve precision debug tools performance (#1909)



* turn on userbuffers for layers without debug
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* working change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* tests and fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* update nvinspect version
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix ci
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9f9b4816
......@@ -14,11 +14,19 @@
FAIL=0
# It is not installed as a requirement,
# because it is not available on PyPI.
pip uninstall -y nvdlfw-inspect
pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard sanity and numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
......
......@@ -21,6 +21,11 @@ FAILED_CASES=""
mkdir -p "$XML_LOG_DIR"
# It is not installed as a requirement,
# because it is not available on PyPI.
pip uninstall -y nvdlfw-inspect
pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
......
......@@ -24,22 +24,22 @@ def test_transformer_engine_no_config(feature_dirs):
# FP8 enabled - true by default
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
# modify_tensor_enabled - False by default
# modify_tensor_enabled - (False, None) by default
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
# inspect_tensor_enabled - False by default
# inspect_tensor_enabled - (False, None) by default
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.attn.qkv", tensor_name="activation", iteration=0
)
)[0]
# inspect_tensor_postquantize - False by default
# inspect_tensor_postquantize - (False, None) by default
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
finally:
debug_api.end_debug()
......@@ -51,24 +51,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
......@@ -80,22 +80,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="fprop", iteration=0
)
)[0]
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", iteration=0
)
)[0]
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
)[0]
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
......@@ -111,22 +111,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
# check modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
)
)[0]
# check modify_tensor
......@@ -168,14 +168,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs):
gemm="wgrad",
tensor_name="gradient",
iteration=0,
)
)[0]
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc4",
gemm="fprop",
tensor_name="activation",
iteration=0,
)
)[0]
finally:
debug_api.end_debug()
......@@ -191,11 +191,11 @@ def test_fake_quant(configs_dir, feature_dirs):
# modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
)[0]
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
)[0]
# modify_tensor
debug_api.transformer_engine.modify_tensor(
......@@ -218,11 +218,11 @@ def test_fake_quant(configs_dir, feature_dirs):
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
)[0]
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
)[0]
finally:
debug_api.end_debug()
......@@ -265,21 +265,20 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)
)[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="activation", iteration=200
)
)[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)
)[0]
expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)
# TE FP8 tensor stats --
assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
)[0]
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.1.mlp.fc1",
tensor=tensor_fp8,
......@@ -295,10 +294,10 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
)
)[0]
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
)[0]
# Second config in same yaml
tensor = torch.rand((100, 100, 5))
......@@ -328,7 +327,7 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201
)
)[0]
assert_empty()
finally:
......
test:
enabled: True
layers:
layer_name_regex_pattern: .*
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
start_step: 1
freq: 3
LogFp8TensorStats:
enabled: True
tensors: weight
stats: [underflows%]
start_step: 1
freq: 5
\ No newline at end of file
test:
enabled: True
layers:
layer_name_regex_pattern: .*1
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
start_step: 0
freq: 100000
\ No newline at end of file
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import tempfile
import os
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
# If a layer does not invoke any feature in the current iteration,
# then it switches into non-debug mode.
# This test checks whether this works correctly -
# non-quantized statistics should be logged every 3 iterations,
# and quantized statistics should be logged every 5 iterations.
with tempfile.TemporaryDirectory() as temp_dir:
debug_api.initialize(
config_file=configs_dir + "/log_config.yaml",
feature_dirs=feature_dirs,
log_dir=temp_dir,
)
if layer == "linear":
model = te.Linear(128, 128, name="linear1")
elif layer == "transformer":
model = te.TransformerLayer(128, 128, 4, name="transformer1")
else:
raise ValueError(f"Invalid layer: {layer}")
for i in range(11):
x = torch.randn(4, 128, 128).cuda()
with te.fp8_autocast(enabled=True):
y = model(x)
y.sum().backward()
debug_api.step()
with open(
os.path.join(
temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
),
"r",
) as f:
file_content = f.read()
for i in range(1, 11):
if i % 3 == 0 or i % 5 == 0:
assert f"iteration={i:06d}" in file_content
else:
assert f"iteration={i:06d}" not in file_content
debug_api.end_debug()
TEDebugState._reset()
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import torch
import transformer_engine.pytorch as te
import time
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
debug_api.end_debug()
TEDebugState._reset()
if debug_tools_initialized:
# This config logs stats starting from iteration 0, every N iterations for huge N >> NUM_ITERS.
# So after 1 warm-up iteration, these layers should work in non-debug mode.
debug_api.initialize(
config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs
)
try:
if layer == "linear":
model = torch.nn.Sequential(
te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
).cuda()
NUM_ITERS = 18000
elif layer == "transformer":
model = torch.nn.Sequential(
te.TransformerLayer(1, 1, 1, name="transformer1"),
te.TransformerLayer(1, 1, 1, name="transformer2"),
).cuda()
NUM_ITERS = 2000
x = torch.randn(1, 1, 1).cuda()
y = model(x)
y.sum().backward()
debug_api.step()
torch.cuda.synchronize()
time_start = time.time()
for i in range(NUM_ITERS):
y = model(x)
y.sum().backward()
if debug_tools_initialized:
debug_api.step()
torch.cuda.synchronize()
time_end = time.time()
finally:
if debug_tools_initialized:
debug_api.end_debug()
return time_end - time_start
@pytest.mark.parametrize("layer", ["linear", "transformer"])
def test_cpu_overhead(layer, configs_dir, feature_dirs):
# Runs one layer many times on a very small tensor
# - GPU time should be negligible, so total time should be dominated by CPU time.
# If a layer does not invoke any feature in the current iteration,
# then it switches into non-debug mode and should not add any non-negligible CPU overhead
# compared to a layer without debug tools initialized.
with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)
print(f"with_debug_tools: {with_debug_tools} s")
print(f"without_debug_tools: {without_debug_tools} s")
assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin
......@@ -5,7 +5,7 @@
"""API definition for nvidia-dlframework-inspect."""
import copy
from typing import Dict, Union
from typing import Dict, Union, Tuple, Optional
from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper
from nvdlfw_inspect.registry import Registry
......@@ -101,13 +101,23 @@ required_kwargs = {
class TEDefaultFeatures:
"""Transformer Engine API calls default behavior."""
def fp8_gemm_enabled(self, config: Dict, layer_name: str, gemm: str, iteration: int) -> bool:
def fp8_gemm_enabled(
self,
config: Dict,
layer_name: str,
gemm: str,
iteration: int,
) -> bool | Tuple[bool, Optional[int]]:
"""
If the tensor is not processed using *modify_tensor* and the fp8 recipe is enabled,
then the decision whether to cast it to fp8 is based on the value returned by the call *fp8_gemm_enabled*.
If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled,
the result of this call does not matter.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
It can return (bool, None) if the feature will never be enabled for that layer and gemm.
Returning the next enabled iteration can help optimize CPU usage.
Parameters
----------
......@@ -122,9 +132,9 @@ class TEDefaultFeatures:
Returns
-------
bool - default is True
Union[bool, Tuple[bool, Optional[int]]] - default is (True, None)
"""
return True # if it is false, fp8_gemm will be turned off. Otherwise nothing happens.
return True, None  # If it is False, the FP8 GEMM is turned off; otherwise nothing happens.
def modify_tensor_enabled(
self,
......@@ -133,9 +143,16 @@ class TEDefaultFeatures:
gemm: str,
tensor_name: str,
iteration: int,
) -> bool:
) -> bool | Tuple[bool, Optional[int]]:
"""
It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name. It has **higher priority** than fp8_gemm, if *modify_tensor_enabled* returns True, then modify_tensor call is invoked for the respective tensor no matter what.
It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name.
It has **higher priority** than fp8_gemm; if *modify_tensor_enabled* returns True or (True, next_enabled_iter),
then the *modify_tensor* call is invoked for the respective tensor no matter what.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor.
Returning the next enabled iteration can help optimize CPU usage, especially when the interval between modify_tensor is large.
Returning only a bool is deprecated.
Parameters
----------
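
As an illustration of the new return convention, a hypothetical feature that runs its *modify_tensor* only every tenth step could implement this call as sketched below. This is not part of the diff; the Registry/api_method imports follow the feature files later in this commit, while the TEConfigAPIMapper import path is assumed.

from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper  # path assumed


@Registry.register_feature(namespace="transformer_engine")
class ModifyEveryTenSteps(TEConfigAPIMapper):  # hypothetical feature
    """Runs modify_tensor on every iteration divisible by 10."""

    @api_method
    def modify_tensor_enabled(self, config, layer_name, tensor_name, gemm, iteration):
        if iteration % 10 == 0:
            # Enabled now; the next run is exactly 10 iterations away.
            return True, iteration + 10
        # Disabled now; report the next multiple of 10 so the layer
        # can stay in non-debug mode until then.
        return False, iteration + 10 - (iteration % 10)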
......@@ -153,9 +170,9 @@ class TEDefaultFeatures:
Returns
-------
bool - default is False
Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
"""
return False
return False, None
def modify_tensor(
self,
......@@ -167,7 +184,7 @@ class TEDefaultFeatures:
default_quantizer: Quantizer,
iteration: int,
out: Union[torch.Tensor, QuantizedTensor],
) -> Union[torch.Tensor, QuantizedTensor, None]:
) -> torch.Tensor | QuantizedTensor | None:
"""
It allows tensor modification.
For example, feature `FakeQuant` uses it to emulate casting to FP8.
......@@ -298,9 +315,15 @@ class TEDefaultFeatures:
layer_name: str,
tensor_name: str,
iteration: int,
) -> bool:
) -> bool | Tuple[bool, Optional[int]]:
"""
It is a routing call, which is run at the initialization of the layer. If it returns true, then *inspect_tensor* for a given GEMM and tensor will be invoked.
It is a routing call, which is run at the initialization of the layer.
Determines if *inspect_tensor* for a given GEMM and tensor will be invoked.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
It can return (bool, None) if the feature will never be enabled for that layer and tensor.
Returning the next enabled iteration can help optimize CPU usage, especially when the interval between inspect_tensor is large.
Returning only a bool is deprecated.
Parameters
----------
......@@ -316,9 +339,9 @@ class TEDefaultFeatures:
Returns
-------
bool - default is False
Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
"""
return False
return False, None
def inspect_tensor_postquantize_enabled(
self,
......@@ -327,11 +350,16 @@ class TEDefaultFeatures:
gemm: str,
tensor_name: str,
iteration: int,
) -> bool:
) -> bool | Tuple[bool, Optional[int]]:
"""
It is a routing call, which is run at the initialization of the layer.
If it returns true, then *inspect_tensor_postquantize* for
a given GEMM and tensor will be invoked.
Determines if *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor name.
Returning the next enabled iteration can help optimize CPU usage,
especially when the interval between inspect_tensor_postquantize is large.
Returning only a bool is deprecated.
Parameters
----------
......@@ -349,9 +377,9 @@ class TEDefaultFeatures:
Returns
-------
bool - default is False
Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
"""
return False
return False, None
@Registry.register_namespace_api(namespace="transformer_engine")
......@@ -420,7 +448,7 @@ class TransformerEngineAPI(BaseNamespaceAPI):
def output_assertions_hook(self, api_name, ret, **kwargs):
"""Output hooks used to check correctness of the outputs of the API calls."""
if "enabled" in api_name or api_name == "fp8_gemm":
assert isinstance(ret, bool)
assert isinstance(ret, (bool, tuple))
if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
assert ret is None
if api_name == "modify_tensor":
......@@ -432,6 +460,38 @@ class TransformerEngineAPI(BaseNamespaceAPI):
if kwargs["dtype"] is not None:
assert ret.dtype == kwargs["dtype"]
def handle_multi_feature_output(
self, api_name, multi_feature_outputs, features_to_invoke, **kwargs
):
"""
Handle multi-feature output of the API calls.
"""
if "enabled" in api_name:
# *_enabled feature calls can return bool, or tuple (bool, Optional[int]).
# If any of them returns a bare bool, we return (bool, None) - we cannot state anything
# about enablement in the following steps.
# If all of them return a tuple (bool, Optional[int]), we return the minimum next iteration,
# i.e. the earliest iteration at which any feature will be enabled again.
# If the second value is None, that feature will never be enabled again.
all_ret_tuple = all(
isinstance(feature_output, tuple)
for feature_output in multi_feature_outputs.values()
)
if all_ret_tuple:
run_current = any(
feature_output[0] for feature_output in multi_feature_outputs.values()
)
next_iter = None
for feature_output in multi_feature_outputs.values():
if feature_output[1] is not None:
next_iter = feature_output[1] if next_iter is None else min(next_iter, feature_output[1])
return run_current, next_iter
run_current = any(
feature_output[0] if isinstance(feature_output, tuple) else feature_output
for feature_output in multi_feature_outputs.values()
)
return run_current, None
return super().handle_multi_feature_output(
api_name, multi_feature_outputs, features_to_invoke, **kwargs
)
def step(self):
"""This function is called by the nvidia-dlframework-inspect after every debug_api.step()"""
STATS_BUFFERS.log_stats()
......
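
For intuition, here is a self-contained mirror of the combination rule above, exercised with hypothetical feature outputs (a sketch, not the library code):

from typing import Dict, Optional, Tuple, Union

EnabledOut = Union[bool, Tuple[bool, Optional[int]]]

def combine(outputs: Dict[str, EnabledOut]) -> Tuple[bool, Optional[int]]:
    # All features returned tuples: run now if any says so, and the next
    # debug iteration is the earliest one any feature reports.
    if all(isinstance(o, tuple) for o in outputs.values()):
        run_now = any(o[0] for o in outputs.values())
        next_iters = [o[1] for o in outputs.values() if o[1] is not None]
        return run_now, min(next_iters) if next_iters else None
    # At least one feature returned a bare bool: nothing can be predicted.
    return any(o[0] if isinstance(o, tuple) else o for o in outputs.values()), None

assert combine({"A": (False, 7), "B": (True, 12)}) == (True, 7)
assert combine({"A": (False, None), "B": (False, None)}) == (False, None)
assert combine({"A": True, "B": (False, 7)}) == (True, None)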
......@@ -50,4 +50,4 @@ class DisableFP8GEMM(TEConfigAPIMapper):
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behaviour in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False
return False, None
......@@ -41,7 +41,7 @@ class DisableFP8Layer:
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False
return False, None
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API
......
......@@ -127,14 +127,14 @@ class FakeQuant(TEConfigAPIMapper):
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
return False
return False, None
@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run process_tensor() in the forward."""
return True
return True, iteration + 1
@api_method
def modify_tensor(
......
......@@ -13,6 +13,7 @@ from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as Bas
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
......@@ -92,8 +93,15 @@ class LogFp8TensorStats(BaseLogTensorStats):
self, config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run inspect_tensor_postquantize() in the forward."""
# check whether logging should happen in this iteration
return self._check_params(config, layer_name, iteration=iteration)
run_current, next_iter = next_enabled_iter(
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
config.get("freq", 1),
iteration,
)
STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
return run_current, next_iter
@api_method
def inspect_tensor_postquantize(
......
......@@ -19,6 +19,7 @@ from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter
@Registry.register_feature(namespace="transformer_engine")
......@@ -97,7 +98,15 @@ class LogTensorStats(BaseLogTensorStats):
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run look_at_tensor_before_process() in the forward."""
return self._check_params(config, layer_name, iteration=iteration)
run_current, next_iter = next_enabled_iter(
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
config.get("freq", 1),
iteration,
)
STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
return run_current, next_iter
@api_method
def inspect_tensor(
......
......@@ -91,14 +91,14 @@ class PerTensorScaling(TEConfigAPIMapper):
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
return False
return False, None
@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run process_tensor() in the forward."""
return True
return True, iteration + 1
@api_method
def modify_tensor(
......
......@@ -5,3 +5,38 @@
"""
Utils for the debug features.
"""
def next_enabled_iter(start_step, end_step, start_end_list, freq, iteration):
"""
Determines whether the feature should be enabled at the current iteration,
and computes the next iteration at which the feature will be enabled.
Returns
-------
run_current : bool
True if the feature should be enabled at the current iteration.
next_iter : Optional[int]
The next iteration index at which the feature will be enabled,
or None if the feature will never be enabled again.
"""
run_current = False
if start_end_list:
intervals = sorted(start_end_list)
else:
start_step = 0 if start_step is None else start_step
end = float("inf") if end_step is None else end_step
intervals = [(start_step, end)]
for s, e in intervals:
if iteration % freq == 0 and s <= iteration <= e:
run_current = True
first = max(iteration + 1, s)
offset = first % freq
candidate = first if offset == 0 else first + (freq - offset)
if candidate <= e:
return run_current, candidate
return run_current, None # No next iteration found
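
A few worked calls against this helper (parameter values are hypothetical, except the last one, which appears to match the perf_config.yaml used by test_perf.py above):

# freq=3 starting at step 1: iteration 6 is enabled, the next one is 9.
assert next_enabled_iter(1, None, None, 3, 6) == (True, 9)
# Interval list [(0, 10)] with freq=5: iteration 12 is past the interval.
assert next_enabled_iter(None, None, [(0, 10)], 5, 12) == (False, None)
# start_step=0 with a huge freq: only iteration 0 runs for a long time.
assert next_enabled_iter(0, None, None, 100000, 0) == (True, 100000)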
......@@ -10,6 +10,7 @@ When log() is called, they gather stats from all nodes, compute combined final s
from collections import defaultdict
from typing import Dict
import torch
from nvdlfw_inspect.utils import gather_along_first_dim
......@@ -20,6 +21,7 @@ from transformer_engine.debug.features.utils.stats_computation import (
DEPENDENCIES,
stats_to_num,
)
from transformer_engine.debug.pytorch.debug_state import TEDebugState
class _Buffer:
......@@ -146,10 +148,41 @@ class StatsBuffers:
self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list)
# Logging stats involves synchronization between nodes
# and non-trivial CPU overhead.
# It should only be done if absolutely necessary.
# These variables help determine whether we can skip the reduction.
self.at_least_one_layer_fed = False
self.layers_to_next_iter: Dict[str, int] = {}
def _if_run_reduction(self) -> bool:
"""
Returns True if the reduction should be run.
This is the case if at least one layer logged stats on this node.
If not, some layer may simply not have been run on this node;
if we know that no such layer logs on any other node this time,
we can skip the reduction. Otherwise, we must reduce.
To ensure correctness, we assume that every layer is invoked on the first forward pass.
If this is not the case, a hang might happen.
"""
if self.at_least_one_layer_fed:
return True
iteration = TEDebugState.get_iteration()
for next_iter in self.layers_to_next_iter.values():
# Note that a layer may not be run for many iterations;
# in that case we synchronize at every step until we get any information from it.
# next_iter may be None if the layer will never log again.
if next_iter is not None and iteration >= next_iter:
return True
return False
def reset(self):
"""Resets all buffers."""
self.buffers = {} # (layer_name, tensor_name) -> buffer
self.reduction_group_to_buffer = defaultdict(list)
self.at_least_one_layer_fed = False
self.layers_to_next_iter: Dict[str, int] = {}
def try_add_buffer(
self, layer_name, tensor_name, stats, options, reduction_group, reduce_within_microbatch
......@@ -163,12 +196,16 @@ class StatsBuffers:
def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction):
"""Feeds the tensor into the respective buffer."""
self.at_least_one_layer_fed = True
buffer = self.buffers[(layer_name, tensor_name, options)]
buffer.feed(tensor, iteration)
buffer.skip_reduction = skip_reduction
def log_stats(self):
"""Logs the stats from all the buffers."""
if not self._if_run_reduction():
return {}
output = {}
for reduction_group, buffers in self.reduction_group_to_buffer.items():
changed_buffers = [
......@@ -181,7 +218,7 @@ class StatsBuffers:
for _, buffer in changed_buffers:
stats = buffer.log()
output.update(stats)
self.at_least_one_layer_fed = False
return output
......
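
To make the skip condition concrete, here is a standalone sketch of the decision (a hypothetical re-implementation with made-up layer names, not the class method itself):

def if_run_reduction(at_least_one_layer_fed, layers_to_next_iter, iteration):
    if at_least_one_layer_fed:
        return True
    return any(next_iter is not None and iteration >= next_iter
               for next_iter in layers_to_next_iter.values())

# Nothing was fed and no layer can log before iteration 9 on any rank,
# so the collective reduction can safely be skipped.
assert if_run_reduction(False, {"fc1": 9, "fc2": 12}, iteration=5) is False
# "fc1" may log at iteration 9 on some other rank, so we must participate.
assert if_run_reduction(False, {"fc1": 9, "fc2": 12}, iteration=9) is True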
......@@ -22,6 +22,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import (
prepare_for_saving,
restore_from_saved,
)
from transformer_engine.debug.pytorch.debug_state import TEDebugState
aten = torch.ops.aten
......@@ -53,14 +54,13 @@ class DebugQuantizer(Quantizer):
parent_quantizer: Optional[Quantizer],
tp_group: torch.distributed.ProcessGroup,
):
import nvdlfw_inspect.api as debug_api
super().__init__(rowwise=True, columnwise=True)
self.layer_name = layer_name
self.tensor_name = tensor_name
self.parent_quantizer = parent_quantizer
self.tp_group = tp_group # used in inspect_tensor calls
self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count
self.iteration = TEDebugState.get_iteration()
# .internal = True is slightly faster, but results
# in errors when caching the weights.
......@@ -70,6 +70,12 @@ class DebugQuantizer(Quantizer):
self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]
# Next iteration at which this quantizer will call any API.
# It is None at init and is computed after the *_enabled API calls.
# None means that, unless something changes,
# this quantizer will never call any API.
self.next_debug_iter = None
# The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
# rowwise_tensor_plan, and columnwise_tensor_plan are computed.
# These fields indicate the path where API calls will be inserted.
......@@ -102,15 +108,21 @@ class DebugQuantizer(Quantizer):
"""
import nvdlfw_inspect.api as debug_api
inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
inspect_tensor_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
)
)
modify_enabled = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
modify_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
)
plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION
return inspect_tensor_enabled, plan
......@@ -121,10 +133,13 @@ class DebugQuantizer(Quantizer):
"""
import nvdlfw_inspect.api as debug_api
inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
inspect_tensor_enabled = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_enabled(
layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
)
)
inspect_tensor_postquantize_enabled_rowwise = (
inspect_tensor_postquantize_enabled_rowwise = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
......@@ -132,7 +147,8 @@ class DebugQuantizer(Quantizer):
gemm=self.rowwise_gemm_name,
)
)
inspect_tensor_postquantize_enabled_columnwise = (
inspect_tensor_postquantize_enabled_columnwise = self.process_enabled_api_call(
debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
layer_name=self.layer_name,
tensor_name=self.tensor_name,
......@@ -158,42 +174,54 @@ class DebugQuantizer(Quantizer):
rowwise_plan = None
columnwise_plan = None
modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
modify_rowwise = self.process_enabled_api_call(
debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
)
if modify_rowwise:
rowwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
iteration=self.iteration,
fp8_quantize = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.rowwise_gemm_name,
iteration=self.iteration,
)
)
if fp8_quantize:
rowwise_plan = STANDARD_FP8_QUANTIZE
if rowwise_plan is None:
rowwise_plan = HIGH_PRECISION
if self.columnwise_gemm_name is not None:
modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
modify_columnwise = self.process_enabled_api_call(
debug_api.transformer_engine.modify_tensor_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
tensor_name=self.tensor_name,
iteration=self.iteration,
)
)
if modify_columnwise:
columnwise_plan = API_CALL_MODIFY
else:
if self.parent_quantizer is not None:
fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
iteration=self.iteration,
fp8_quantize = self.process_enabled_api_call(
debug_api.transformer_engine.fp8_gemm_enabled(
layer_name=self.layer_name,
gemm=self.columnwise_gemm_name,
iteration=self.iteration,
)
)
if fp8_quantize:
columnwise_plan = STANDARD_FP8_QUANTIZE
if columnwise_plan is None:
......@@ -229,7 +257,7 @@ class DebugQuantizer(Quantizer):
"layer_name": self.layer_name,
"tensor": tensor,
"tensor_name": self.tensor_name,
"iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count,
"iteration": TEDebugState.get_iteration(),
"tp_group": self.tp_group,
}
if tensor is not None and self.inspect_tensor_enabled:
......@@ -270,22 +298,14 @@ class DebugQuantizer(Quantizer):
# 1. If there is fp8 quantization in at least one of the gemms,
# the quantization using the self.parent_quantizer is performed.
# rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise
rowwise_gemm_quantize = (
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
columnwise_gemm_quantize = (
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
if columnwise_gemm_quantize and not rowwise_gemm_quantize:
rowwise_gemm_quantize = True # only columnwise quantization not implemented
self._update_parent_quantizer_usage()
# Only columnwise quantization is not supported.
if self.parent_quantizer is not None:
if not self.parent_quantizer.rowwise_usage and self.parent_quantizer.columnwise_usage:
self.parent_quantizer.set_usage(rowwise=True)
rowwise_gemm_tensor, columnwise_gemm_tensor = None, None
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage(
rowwise=True,
columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported
)
quantized_tensor = self.parent_quantizer(tensor)
# if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
# one tensor with columnwise=True and rowwise=True is computed
......@@ -341,7 +361,6 @@ class DebugQuantizer(Quantizer):
quantizer=self,
layer_name=self.layer_name,
tensor_name=self.tensor_name,
original_tensor=tensor,
)
def process_gemm_output(self, tensor: torch.Tensor):
......@@ -375,6 +394,25 @@ class DebugQuantizer(Quantizer):
return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
return torch.empty(shape, dtype=dtype, device=device)
def any_feature_enabled(self) -> bool:
"""Returns bool if there is at least one API call enabled."""
if self.output_tensor:
return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
if (
self.inspect_tensor_enabled
or self.inspect_tensor_postquantize_enabled_rowwise
or self.inspect_tensor_postquantize_enabled_columnwise
or self.rowwise_tensor_plan == API_CALL_MODIFY
or self.columnwise_tensor_plan == API_CALL_MODIFY
):
return True
if self.parent_quantizer is not None:
if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
return True
if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
return True
return False
def calibrate(self, tensor: torch.Tensor):
"""Calibration override, should not be invoked."""
raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported")
......@@ -446,29 +484,70 @@ class DebugQuantizer(Quantizer):
self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)
def any_feature_enabled(self) -> bool:
"""Returns bool if there is at least one API call enabled."""
if self.output_tensor:
return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
if (
self.inspect_tensor_enabled
or self.inspect_tensor_postquantize_enabled_rowwise
or self.inspect_tensor_postquantize_enabled_columnwise
or self.rowwise_tensor_plan == API_CALL_MODIFY
or self.columnwise_tensor_plan == API_CALL_MODIFY
):
return True
if self.parent_quantizer is not None:
if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
return True
if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
return True
return False
def get_next_debug_iter(self) -> Optional[int]:
"""
Returns the next iteration for which debug is enabled for this tensor.
If this returns None, debug will never be enabled for this tensor.
"""
return self.next_debug_iter
def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
"""Probably not needed for debug quantizer"""
return None
def process_enabled_api_call(
self, enabled_call_output: bool | Tuple[bool, Optional[int]]
) -> bool:
"""
Processes the output of an *_enabled API call.
Updates the self.next_debug_iter field accordingly.
Returns a bool indicating whether the API call is enabled.
"""
if isinstance(enabled_call_output, tuple):
assert len(enabled_call_output) == 2, "Expected a tuple of length 2"
enabled_bool, next_iter = enabled_call_output
else:
enabled_bool = enabled_call_output
next_iter = self.iteration + 1
if self.next_debug_iter is None:
self.next_debug_iter = next_iter
elif next_iter is not None:
# If next iter is None, that means that call will never be enabled.
self.next_debug_iter = min(self.next_debug_iter, next_iter)
return enabled_bool
def supports_only_rowwise_all_gather(self) -> bool:
if self.parent_quantizer is not None:
return self.parent_quantizer.supports_only_rowwise_all_gather()
return False
def _update_parent_quantizer_usage(self):
"""
Updates the usage of the parent quantizer.
"""
rowwise_gemm_quantize = (
self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
columnwise_gemm_quantize = (
self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
)
if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
self.parent_quantizer.set_usage(
rowwise=rowwise_gemm_quantize,
columnwise=columnwise_gemm_quantize,
)
def set_usage(self, rowwise: bool = None, columnwise: bool = None):
"""
Sets the usage of the quantizer.
"""
super().set_usage(rowwise=rowwise, columnwise=columnwise)
if not self.output_tensor:
self._update_parent_quantizer_usage()
class DebugQuantizedTensor(QuantizedTensorBase):
"""
......@@ -484,7 +563,6 @@ class DebugQuantizedTensor(QuantizedTensorBase):
quantizer,
layer_name=None,
tensor_name=None,
original_tensor=None,
):
self.rowwise_gemm_tensor = rowwise_gemm_tensor
......@@ -492,7 +570,6 @@ class DebugQuantizedTensor(QuantizedTensorBase):
self.quantizer = quantizer
self._layer_name = layer_name
self._tensor_name = tensor_name
self._original_tensor = original_tensor
def prepare_for_saving(self):
""" " Prepare for saving method override"""
......@@ -501,6 +578,7 @@ class DebugQuantizedTensor(QuantizedTensorBase):
if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor
else [self.rowwise_gemm_tensor]
)
tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
self.tensors_to_save = tensor_objects_list
# pylint: disable=unbalanced-tuple-unpacking
......@@ -519,6 +597,7 @@ class DebugQuantizedTensor(QuantizedTensorBase):
else:
self.rowwise_gemm_tensor = tensor_objects_list[0]
self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
return saved_tensors
def quantize_(self, tensor, *, noop_flag=None):
......@@ -542,3 +621,27 @@ class DebugQuantizedTensor(QuantizedTensorBase):
def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None):
"""Update usage of the tensor."""
if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor:
# If the same object is used for both the rowwise and the columnwise gemm,
# there is no benefit in erasing the usage of one of them.
# There are also scenarios where keeping both usages is needed,
# for example when we want to recreate the columnwise tensor from the rowwise one.
if rowwise_usage is False:
self.rowwise_gemm_tensor = None
if columnwise_usage is False:
self.columnwise_gemm_tensor = None
if isinstance(self.rowwise_gemm_tensor, QuantizedTensor):
self.rowwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage)
if isinstance(self.columnwise_gemm_tensor, QuantizedTensor):
self.columnwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage)
if rowwise_usage and self.rowwise_gemm_tensor is None:
raise RuntimeError(
"Cannot recreate rowwise tensor from columnwise tensor in debug mode."
)
if columnwise_usage and self.columnwise_gemm_tensor is None:
raise RuntimeError(
"Cannot recreate columnwise tensor from rowwise tensor is debug mode."
)
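
A compact trace of how *process_enabled_api_call* accumulates the next debug iteration, using a standalone re-implementation with made-up values (the real method updates self.next_debug_iter on the quantizer):

from typing import Optional, Tuple, Union

def fold(next_debug_iter, output, iteration):
    # Mirrors process_enabled_api_call: tuple outputs carry the next
    # enabled iteration; bare bools force a re-check on the next step.
    if isinstance(output, tuple):
        enabled, next_iter = output
    else:
        enabled, next_iter = output, iteration + 1
    if next_debug_iter is None:
        next_debug_iter = next_iter
    elif next_iter is not None:
        next_debug_iter = min(next_debug_iter, next_iter)
    return enabled, next_debug_iter

_, nd = fold(None, (False, 12), iteration=10)  # nd == 12
_, nd = fold(nd, (True, None), iteration=10)   # enabled now, never again: nd stays 12
_, nd = fold(nd, True, iteration=10)           # deprecated bool: nd == min(12, 11) == 11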
......@@ -62,6 +62,13 @@ class TEDebugState:
"""Sets weight tensor reduction mode."""
cls.weight_tensor_tp_group_reduce = enabled
@classmethod
def get_iteration(cls):
"""Returns the current iteration."""
import nvdlfw_inspect.api as debug_api
return debug_api.DEBUG_MANAGER._trainer_iteration_count
def set_weight_tensor_tp_group_reduce(enabled):
"""Sets weight tensor reduction mode."""
......
......@@ -4,6 +4,25 @@
"""Utils functions for the debug module."""
from typing import Optional
def next_iter_when_debug_should_be_run(quantizers) -> Optional[int]:
"""
Returns the next iteration at which debug should be run.
If debug will never be run for this layer, returns None.
"""
out = None
for q in quantizers:
if q.get_next_debug_iter() is not None:
if out is None:
out = q.get_next_debug_iter()
else:
out = min(out, q.get_next_debug_iter())
return out
def any_feature_enabled(quantizers):
"""Returns True if at least one API call is made from DebugQuantizer."""
......
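
A sketch of the intended call-site pattern for this helper (hypothetical: the import path is assumed from the sibling debug_state import seen elsewhere in this diff, and the fake quantizer stands in for DebugQuantizer):

from typing import Optional
from transformer_engine.debug.pytorch.utils import next_iter_when_debug_should_be_run

class FakeQuantizer:
    def __init__(self, next_iter: Optional[int]):
        self.next_iter = next_iter

    def get_next_debug_iter(self) -> Optional[int]:
        return self.next_iter

quantizers = [FakeQuantizer(None), FakeQuantizer(42)]
next_debug = next_iter_when_debug_should_be_run(quantizers)  # -> 42
iteration = 7
if next_debug is None or iteration < next_debug:
    pass  # fast non-debug path, e.g. userbuffers can stay enabled
else:
    pass  # full debug path with API calls for this layer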
......@@ -981,6 +981,15 @@ def _all_gather_fp8(
return out, handle
def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]:
"""Get quantizer format."""
if isinstance(quantizer, DebugQuantizer):
quantizer = quantizer.parent_quantizer
if isinstance(quantizer, Float8BlockQuantizer):
return quantizer.all_gather_usage
return None
def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None:
"""Make quantizer compact"""
_quantizer = quantizer
......@@ -1343,6 +1352,44 @@ def gather_along_first_dim(
inp = quantizer(inp)
return inp, None
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = DebugQuantizedTensor(
rowwise_gemm_tensor=inp.rowwise_gemm_tensor,
columnwise_gemm_tensor=inp.columnwise_gemm_tensor,
quantizer=inp.quantizer,
layer_name=inp._layer_name,
tensor_name=inp._tensor_name,
)
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
# shapes
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
rowwise_total = None
if rowwise is not None:
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[
0
]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
columnwise_total = None
if columnwise is not None:
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
# Sometimes the same object is used both for rowwise and columnwise gemms,
# and we want to avoid double all-gathers.
out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# Output tensor dims
out_shape = list(inp.size())
out_shape[0] *= world_size
......@@ -1380,34 +1427,6 @@ def gather_along_first_dim(
out_shape=out_shape,
)
# Debug case - call gather_along_first_dim on each tensor
if isinstance(inp, DebugQuantizedTensor):
out_obj = inp
rowwise = inp.get_tensor(False)
columnwise = inp.get_tensor(True)
final_quantizer = (
None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer
)
# Temporary fix for TP communication of Float8BlockwiseQTensorBase
if isinstance(rowwise, Float8BlockwiseQTensorBase):
rowwise = inp._original_tensor
rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0]
out_obj.rowwise_gemm_tensor = rowwise_total
if rowwise is not columnwise:
final_quantizer_columnwise = (
None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer
)
# Temporary fix for TP communication of Float8BlockwiseQTensorBase
if isinstance(columnwise, Float8BlockwiseQTensorBase):
columnwise = inp._original_tensor
columnwise_total, _ = gather_along_first_dim(
columnwise, process_group, False, final_quantizer_columnwise
)
out_obj.columnwise_gemm_tensor = columnwise_total
else:
out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor
return out_obj, None
# High-precision communication for quantized tensors
if quantizer is not None:
warnings.warn(
......@@ -1418,6 +1437,7 @@ def gather_along_first_dim(
inp = inp.dequantize()
# Falling back to high-precision all-gather for Float8BlockQuantizer
# means that it should directly output GEMM_READY format
compact = _get_quantizer_format(quantizer)
_set_quantizer_format(quantizer, compact=False)
out = torch.empty(
out_shape,
......@@ -1427,6 +1447,7 @@ def gather_along_first_dim(
)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
out = quantizer(out)
_set_quantizer_format(quantizer, compact=compact)
return out, None
# Dequantize quantized tensor if not supported
......