Unverified commit 2645eaec authored by Paweł Gadziński, committed by GitHub

[Pytorch] NVIDIA-DL-Framework-Inspect support – part 3 – tests (#1612)



* tests drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* move dir
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* tests fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 1d903f5e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
: ${TE_PATH:=/opt/transformerengine}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
: ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/}
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
FAIL=0
pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard numerics tests with initialized debug
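# Note (assumption, not stated in the script): PYTORCH_JIT=0 and NVTE_TORCH_COMPILE=0
# presumably disable JIT/compiled paths that bypass the debug hooks, and
# NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 keeps results reproducible for numeric comparison.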
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
exit $FAIL
@@ -20,6 +20,7 @@ FAILED_CASES=""
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
@@ -30,6 +31,19 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_use
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
# debug tests
# Config with the dummy feature which prevents nvinspect from being disabled.
# Nvinspect will be disabled if no feature is active.
: ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
: ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py"
# standard numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py"
if [ "$RET" -ne 0 ]; then if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES" echo "Error in the following test cases:$FAILED_CASES"
exit 1 exit 1
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
def pytest_addoption(parser):
parser.addoption(
"--feature_dirs", nargs="+", action="store", default="", help="List of feature directories"
)
parser.addoption(
"--configs_dir",
action="store",
default="",
type=str,
help="Path to the directory with configs.",
)
@pytest.fixture
def feature_dirs(request):
return request.config.getoption("--feature_dirs")
@pytest.fixture
def configs_dir(request):
return request.config.getoption("--configs_dir")
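# Example invocation (paths illustrative), assuming the options registered above:
#   pytest tests/pytorch/debug/test_api_features.py \
#       --feature_dirs transformer_engine/debug/features \
#       --configs_dir tests/pytorch/debug/test_configs/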
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import tempfile
import functools
import os
import itertools
import random
import argparse
import re
import torch
import torch.distributed as dist
import transformer_engine
import transformer_engine_torch as tex
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug import set_weight_tensor_tp_group_reduce
from test_numerics import (
_emulate_linear,
_init_debug,
disable_fp8_gemms_create_config,
DISABLE_FP8_LAYER_CONFIG,
_cmp,
IN_SIZE,
OUT_SIZE,
_init_model,
SEED,
SEQ_LEN,
BATCH_SIZE,
FP8_RECIPE,
fake_quant_fp8_create_config,
_get_current_scale,
_prepare_per_tensor_scaling_config,
AMAX_HISTORY_LEN,
set_scaling_factors,
set_current_scaling_factors,
)
WORLD_RANK, WORLD_SIZE = None, None
NCCL_WORLD = None
FEATURE_DIRS = None
all_boolean = [True, False]
TEST_NR = 0
def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None):
if tp_size is None:
tp_size = WORLD_SIZE
tp_rank = WORLD_RANK
torch.manual_seed(weight_seed)
weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda()
torch.manual_seed(data_seed)
in_split_size = IN_SIZE // tp_size
out_split_size = OUT_SIZE // tp_size
x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda()
if parallel_mode == "row":
x = x[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size]
x.retain_grad()
with torch.no_grad():
if parallel_mode == "column":
weight = weight[tp_rank * out_split_size : (tp_rank + 1) * out_split_size, :]
else:
weight = weight[:, tp_rank * in_split_size : (tp_rank + 1) * in_split_size]
return x, weight.contiguous()
def _init_model(weight, parallel_mode=None, tp_group=None, name="linear"):
model = transformer_engine.pytorch.Linear(
IN_SIZE,
OUT_SIZE,
name=name,
parallel_mode=parallel_mode,
tp_group=(tp_group or NCCL_WORLD if parallel_mode else None),
)
with torch.no_grad():
model.weight.copy_(weight)
return model
class AllGather(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, dim, group=None):
if group is None:
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()
else:
world_size = torch.distributed.get_world_size(group=group)
rank = torch.distributed.get_rank(group=group)
dist.barrier()
# Create a list to gather tensors from all processes
y_list = [torch.zeros_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(y_list, tensor, group=group)
# Save the world size and rank for backward computation
ctx.world_size = world_size
ctx.rank = rank
ctx.dim = dim
# Concatenate the gathered tensors along the feature dimension
y_full = torch.cat(y_list, dim=dim)
return y_full
@staticmethod
def backward(ctx, grad_output):
# Split the gradient output and return the portion corresponding to this rank
grad_input = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)[ctx.rank]
return grad_input, None, None
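# Minimal usage sketch (illustrative): gather per-rank shards along a dimension
# and let autograd hand each rank back its own gradient slice.
#   full = AllGather.apply(local_shard, 0)  # forward: concat of all ranks' shards
#   full.sum().backward()                   # backward: local_shard.grad is this rank's chunk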
def _run_forward_backward(x, model, parallel_mode=None, group=None):
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x)
y.requires_grad_(True)
y.retain_grad()
if parallel_mode == "column":
y = AllGather.apply(y, -1, group)
y.requires_grad_(True)
y.retain_grad()
l = y.sum()
l.backward()
elif parallel_mode == "row":
l = y.sum()
l.backward()
debug_api.step()
return y
def _emulate_linear_distributed(*args, parallel_mode=None, **kwargs):
assert parallel_mode in ["column", "row"]
def split(gradient):
split_size = OUT_SIZE // WORLD_SIZE
gradient = gradient[:, WORLD_RANK * split_size : (WORLD_RANK + 1) * split_size]
return gradient
activation_sync = None
gradient_sync = None
if parallel_mode == "column":
activation_sync = lambda x: AllGather.apply(x, -1)
gradient_sync = split
else:
activation_sync = (
lambda activation: dist.all_reduce(activation, op=dist.ReduceOp.SUM) or activation
)
output = _emulate_linear(
*args, activation_sync=activation_sync, gradient_sync=gradient_sync, **kwargs
)
if parallel_mode == "column":
dist.all_reduce(output["dgrad"], op=dist.ReduceOp.SUM)
return output
def check_debug_log(msg):
with open(f"log/debug_logs/debug_log_globalrank-{WORLD_RANK}.log", "r") as f:
for line in f.readlines():
if msg in line:
return True
return False
def run_debug_test(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
rank = dist.get_rank()
temp_file_name = None
temp_logdir_name = None
if rank == 0:
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
temp_file_name = temp_file.name
temp_dir_obj = tempfile.TemporaryDirectory()
temp_logdir_name = temp_dir_obj.name
# Store the TemporaryDirectory object to prevent it from being deleted
wrapper.temp_dir_obj = temp_dir_obj
temp_file_name_list = [temp_file_name]
temp_logdir_name_list = [temp_logdir_name]
# Broadcast the temporary file and directory names to all processes
dist.broadcast_object_list(temp_file_name_list, src=0)
dist.broadcast_object_list(temp_logdir_name_list, src=0)
temp_file_name = temp_file_name_list[0]
temp_logdir_name = temp_logdir_name_list[0]
dist.barrier()
config_file = open(temp_file_name, mode="r+", buffering=1)
try:
kwargs["config_file"] = config_file
kwargs["log_dir"] = temp_logdir_name
if rank == 0:
global TEST_NR
print(f"Running test {TEST_NR} {func.__name__} with args = {args}.")
TEST_NR += 1
func(*args, **kwargs)
finally:
if rank == 0 and temp_file_name is not None:
os.unlink(temp_file_name)
debug_api.end_debug()
if rank == 0 and hasattr(wrapper, "temp_dir_obj"):
wrapper.temp_dir_obj.cleanup()
return wrapper
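# Note on the decorator above: only rank 0 creates the temp config file and log
# directory; their names are broadcast so every rank opens the same paths, and
# dist.barrier() ensures the file exists on disk before any rank opens it.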
CONFIG_LOG_TEST_DISTRIBUTED = """log_distributed:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
stats: [underflows%]
start_step : 0
end_step: 1
"""
def _prepare_config_test_log_distributed(config_file):
if WORLD_RANK != 0:
return
config_file.write(CONFIG_LOG_TEST_DISTRIBUTED)
config_file.flush()
def _compute_dynamic_range(tensor):
tensor_abs = tensor.abs()
tensor_abs = tensor_abs[tensor_abs != 0]
if tensor_abs.any():
amin = tensor_abs.min().float()
else:
amin = torch.tensor(1, device=tensor.device).to(torch.float)
amax = tensor_abs.max().float()
if not amax.all():
amax = torch.tensor(1, device=tensor.device).to(torch.float)
dynamic_range = torch.log2(amax) - torch.log2(amin)
return dynamic_range
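# Sanity check (illustrative): for a tensor whose nonzero magnitudes span
# [2**-3, 2**4], the dynamic range is log2(16) - log2(0.125) = 7.
#   t = torch.tensor([0.125, 16.0])
#   assert torch.isclose(_compute_dynamic_range(t), torch.tensor(7.0))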
@run_debug_test
def test_log_distributed(parallel_mode, gather_weight, **kwargs):
_prepare_config_test_log_distributed(kwargs["config_file"])
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
set_weight_tensor_tp_group_reduce(gather_weight)
if WORLD_SIZE % 2 != 0:
return # skip
TP_SIZE = WORLD_SIZE // 2
DP_SIZE = 2
TP_RANK = WORLD_RANK % TP_SIZE
DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
debug_api.set_tensor_reduction_group(NCCL_WORLD)
x, weight = _get_tensors(
parallel_mode,
weight_seed=TP_RANK * 1234,
data_seed=DP_RANK * 1234,
tp_size=TP_SIZE,
tp_rank=TP_RANK,
)
tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
tp_group = dist.new_group(ranks=tp_group_ranks)
dp_group_ranks = [i for i in range(TP_RANK, WORLD_SIZE, TP_SIZE)]
dp_group = dist.new_group(ranks=dp_group_ranks)
model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
output = _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
gathered_activation = AllGather.apply(x.contiguous(), 0)
gathered_weight = AllGather.apply(weight.contiguous(), 0, tp_group)
gathered_gradient = AllGather.apply(output.grad.contiguous(), 0, dp_group)
if parallel_mode == "row":
gathered_gradient = AllGather.apply(gathered_gradient, 0, tp_group)
log_file = kwargs["log_dir"] + "/nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log"
dist.barrier()
if WORLD_RANK != 0:
return # stats are gathered on node 0
with open(log_file) as f:
content = f.read()
def get_stat(tensor, stat):
regex = r".*_{tensor}_{stat}\s+.*iteration=(\d+)\s+.*value=([-+]?\d*\.?\d+)".format(
tensor=tensor, stat=stat
)
for line in content.splitlines():
match = re.search(regex, line)
if match:
value = float(match.group(2))
return value
rf = lambda x: round(float(x), 4)
stats = []
tensors = {
"activation": gathered_activation,
"weight": gathered_weight if gather_weight else weight,
"gradient": gathered_gradient,
}
stats = {
"min": torch.min,
"max": torch.max,
"mean": torch.mean,
"std": torch.std,
"l1_norm": lambda x: torch.norm(x, p=1),
"l2_norm": lambda x: torch.norm(x, p=2),
"cur_amax": lambda x: x.abs().max(),
"dynamic_range": _compute_dynamic_range,
}
for stat_key in stats.keys():
for tensor_key in tensors.keys():
torch.testing.assert_close(
get_stat(tensor_key, stat_key),
rf(stats[stat_key](tensors[tensor_key])),
atol=0.0001,
rtol=0.0001,
)
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def test_log_expert_parallel(**kwargs):
"""
This test covers the scenario where one data-parallel rank does not invoke the debug layer.
This occurs naturally in expert parallelism, when an expert receives no input on one rank
but does on the others. If there were an all_gather inside forward(), this would deadlock.
"""
_prepare_config_test_log_distributed(kwargs["config_file"])
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
debug_api.set_tensor_reduction_group(NCCL_WORLD)
x, weight = _get_tensors(
"row", weight_seed=WORLD_RANK * 1234, data_seed=WORLD_RANK * 1234, tp_size=1, tp_rank=0
) # data parallel
model = _init_model(weight, parallel_mode=None, name="linear1")
model1 = _init_model(weight, parallel_mode=None, name="linear2")
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y1 = model(x)
y2 = model1(x)
y = y1 + y2
y.sum().backward()
debug_api.step()
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x)
if WORLD_RANK != 0:
y = y + model1(x)
y.sum().backward()
@run_debug_test
def test_disable_fp8_gemms(fprop_fp8, dgrad_fp8, wgrad_fp8, parallel_mode, **kwargs):
disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"])
fp8_kwargs = {
"fprop_fp8": fprop_fp8,
"dgrad_fp8": dgrad_fp8,
"wgrad_fp8": wgrad_fp8,
}
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
x, weight = _get_tensors(parallel_mode)
model = _init_model(weight, parallel_mode=parallel_mode)
y = _run_forward_backward(x, model, parallel_mode=parallel_mode)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
x.grad.zero_()
ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
_cmp(ground_truth, output)
@run_debug_test
def test_disable_fp8_layer(parallel_mode, **kwargs):
if WORLD_RANK == 0:
kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG)
kwargs["config_file"].flush()
dist.barrier()
x, weight = _get_tensors(parallel_mode)
ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode)
x.grad.zero_()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
model = _init_model(weight, parallel_mode)
y = _run_forward_backward(x, model, parallel_mode)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
_cmp(ground_truth, output)
@run_debug_test
def test_per_tensor_scaling(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
parallel_mode,
**kwargs,
):
input_kwargs = {
"fprop_inp": fprop_inp,
"fprop_weight": fprop_weight,
"dgrad_weight": dgrad_weight,
"dgrad_grad": dgrad_grad,
"wgrad_input": wgrad_input,
"wgrad_grad": wgrad_grad,
}
fp8_kwargs = {
"fprop_fp8": True,
"dgrad_fp8": True,
"wgrad_fp8": True,
}
"""
Runs a test to validate per-tensor (current) scaling in FP8 computations.
The function performs warm-up iterations to populate the amax buffer of the model and compute scaling factors based on delayed scaling.
Subsequently, weights and inputs are switched to ensure their current scaling factors differ from those based on delayed scaling;
similarly, the loss is multiplied by a large factor to alter the gradient's magnitude,
creating a discrepancy between the original (delayed) and per-tensor (current) scaling factors.
Finally, a linear pass is emulated, and the results are compared.
"""
_prepare_per_tensor_scaling_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
warmup_input, warmup_weight = _get_tensors(parallel_mode=parallel_mode)
model = _init_model(warmup_weight, parallel_mode=parallel_mode)
# Warmup run to setup amax and scaling factors.
for _ in range(AMAX_HISTORY_LEN):
_run_forward_backward(warmup_input, model, parallel_mode=parallel_mode)
x, weight = _get_tensors(
parallel_mode=parallel_mode, weight_seed=WORLD_RANK * 2137, data_seed=WORLD_RANK * 2137
)
model.weight.data = weight.data
x.retain_grad()
# delayed scaling factor
# need to be collected before forward pass with test data,
# because this forward pass changes scaling factors
set_scaling_factors(model, input_kwargs, fp8_kwargs)
LOSS_MULTIPLIER = 100
with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x)
model.zero_grad()
if parallel_mode == "column":
y = AllGather.apply(y, -1)
y.retain_grad()
(
LOSS_MULTIPLIER * y.sum()
).backward()  # Loss multiplication to change gradient's order of magnitude
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
# per tensor - current - scaling factors
# need to be collected after forward pass with test data,
# because gradient(y.grad) cannot be accessed before forward,
# but it needs to be collected.
set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs)
ground_truth = _emulate_linear_distributed(
x, weight, parallel_mode=parallel_mode, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs
)
_cmp(ground_truth, output)
@run_debug_test
def test_fake_quant_fp8(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
parallel_mode,
**kwargs,
):
fp8_kwargs = {
"fprop_input_fake_quant": fprop_inp,
"fprop_weight_fake_quant": fprop_weight,
"dgrad_gradient_fake_quant": dgrad_grad,
"dgrad_weight_fake_quant": dgrad_weight,
"wgrad_gradient_fake_quant": wgrad_grad,
"wgrad_input_fake_quant": wgrad_input,
"fprop_fp8": not (fprop_inp or fprop_weight),
"dgrad_fp8": not (dgrad_weight or dgrad_grad),
"wgrad_fp8": not (wgrad_grad or wgrad_input),
}
if WORLD_RANK == 0:
fake_quant_fp8_create_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
dist.barrier()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
x, weight = _get_tensors(parallel_mode)
model = _init_model(weight, parallel_mode)
y = _run_forward_backward(x, model, parallel_mode)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
fp8_kwargs["fprop_input_scale"] = (
_get_current_scale(x, fprop_inp) if not fp8_kwargs["fprop_fp8"] else None
)
fp8_kwargs["fprop_weight_scale"] = (
_get_current_scale(weight, fprop_weight) if not fp8_kwargs["fprop_fp8"] else None
)
fp8_kwargs["dgrad_gradient_scale"] = (
_get_current_scale(y.grad, dgrad_grad) if not fp8_kwargs["dgrad_fp8"] else None
)
fp8_kwargs["dgrad_weight_scale"] = (
_get_current_scale(weight, dgrad_weight) if not fp8_kwargs["dgrad_fp8"] else None
)
fp8_kwargs["wgrad_gradient_scale"] = (
_get_current_scale(y.grad, wgrad_grad) if not fp8_kwargs["wgrad_fp8"] else None
)
fp8_kwargs["wgrad_input_scale"] = (
_get_current_scale(x, wgrad_input) if not fp8_kwargs["wgrad_fp8"] else None
)
ground_truth = _emulate_linear_distributed(x, weight, parallel_mode=parallel_mode, **fp8_kwargs)
_cmp(ground_truth, output)
def _init_distributed():
global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, FP8
WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node
assert LOCAL_SIZE <= torch.cuda.device_count()
dist_init_kwargs = {
"backend": "nccl",
"rank": WORLD_RANK,
"world_size": WORLD_SIZE,
}
dist_init_kwargs["init_method"] = "env://"
dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
assert dist.is_nccl_available()
torch.cuda.set_device(LOCAL_RANK)
dist.init_process_group(**dist_init_kwargs)
NCCL_WORLD = dist.new_group(backend="nccl")
WORLD_SIZE = dist.get_world_size()
def _run_test_with_combinations(
test_function, values_list, num_repeat, extra_args, sample_size=None
):
combinations = itertools.product(values_list, repeat=num_repeat)
total_combinations = itertools.product(combinations, extra_args)
if sample_size is not None:
total_combinations = random.sample(list(total_combinations), sample_size)
for comb, arg in total_combinations:
test_function(*comb, arg)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--feature_dirs", type=str)
args = parser.parse_args()
FEATURE_DIRS = args.feature_dirs
random.seed(SEED)
_init_distributed()
test_log_expert_parallel()
for parallel_mode in ["column", "row"]:
for gather_weight in [True, False]:
test_log_distributed(parallel_mode, gather_weight)
for parallel_mode in ["row", "column"]:
test_disable_fp8_layer(parallel_mode)
# test_disable_fp8_gemms
_run_test_with_combinations(
test_disable_fp8_gemms, all_boolean, num_repeat=3, extra_args=["column", "row"]
)
# test_fake_quant_fp8
dtype_options = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None]
_run_test_with_combinations(
test_fake_quant_fp8,
dtype_options,
num_repeat=6,
extra_args=["column", "row"],
sample_size=20,
)
_run_test_with_combinations(
test_per_tensor_scaling,
all_boolean,
num_repeat=6,
extra_args=["column"],
sample_size=20,
)
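# Launched under torchrun, e.g. (paths illustrative):
#   torchrun --nproc_per_node=2 tests/pytorch/debug/run_distributed.py \
#       --feature_dirs=transformer_engine/debug/features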
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
import nvdlfw_inspect.api as debug_api
try:
import transformer_engine
import transformer_engine_torch as tex
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine package.")
exit(1)
def test_transformer_engine_no_config(feature_dirs):
debug_api.initialize("", feature_dirs=feature_dirs)
try:
tensor = torch.rand(24, 2046).cuda()
# fp8_gemm_enabled - True by default
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
# modify_tensor_enabled - False by default
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
# inspect_tensor_enabled - False by default
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.attn.qkv", tensor_name="activation", iteration=0
)
# inspect_tensor_postquantize - False by default
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)
finally:
debug_api.end_debug()
def test_disable_fp8_gemm(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_disable_fp8_layer(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="fprop", iteration=0
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", iteration=0
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="fprop", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="wgrad", iteration=0
)
assert not debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.attn.qkv", gemm="dgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_per_tensor_scaling(configs_dir, feature_dirs):
try:
debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs)
tensor = torch.rand(24, 2046).cuda()
# check modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
)
# check modify_tensor
default_quantizer1 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
default_quantizer2 = Float8Quantizer(
scale=torch.tensor([1]).cuda(),
amax=torch.tensor([0]).cuda(),
fp8_dtype=tex.DType.kFloat8E5M2,
)
output1 = debug_api.transformer_engine.modify_tensor(
layer_name="decoder.1.mlp.fc1",
gemm="fprop",
tensor_name="activation",
default_quantizer=default_quantizer1,
iteration=0,
tensor=tensor,
)
assert type(output1) == Float8Tensor
assert output1._fp8_dtype == tex.DType.kFloat8E4M3
output2 = debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="dgrad",
tensor=tensor,
tensor_name="gradient",
default_quantizer=default_quantizer2,
iteration=0,
)
assert type(output2) == Float8Tensor
assert output2._fp8_dtype == tex.DType.kFloat8E5M2
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1",
gemm="wgrad",
tensor_name="gradient",
iteration=0,
)
assert not debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc4",
gemm="fprop",
tensor_name="activation",
iteration=0,
)
finally:
debug_api.end_debug()
def test_fake_quant(configs_dir, feature_dirs):
try:
debug_api.initialize(
configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs
)
tensor = torch.rand(24, 2046).cuda()
# modify_tensor_enabled
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
)
assert debug_api.transformer_engine.modify_tensor_enabled(
"decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
)
# modify_tensor
debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="fprop",
tensor=tensor,
tensor_name="activation",
iteration=0,
default_quantizer=None,
)
debug_api.transformer_engine.modify_tensor(
"decoder.1.mlp.fc1",
gemm="dgrad",
tensor=tensor,
tensor_name="gradient",
iteration=0,
default_quantizer=None,
)
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
# caching
assert debug_api.transformer_engine.fp8_gemm_enabled(
"decoder.1.fc2", gemm="wgrad", iteration=0
)
finally:
debug_api.end_debug()
def test_statistics_collection(configs_dir, feature_dirs):
try:
debug_api.initialize(
config_file=configs_dir + "stats_collection_test_config.yaml",
feature_dirs=feature_dirs,
default_logging_enabled=False,
)
tensor = torch.randn((100, 100, 5)).cuda()
tensor_fp8 = Float8Tensor(
data=tensor.to(torch.uint8).cuda(),
fp8_scale_inv=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=tensor.shape,
dtype=torch.float32,
)
def log():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
return STATS_BUFFERS.log_stats()
def assert_empty():
stats = log()
assert len(stats) == 0
# TE tensor stats --
debug_api.transformer_engine.inspect_tensor(
"decoder.1.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
)
stats = log()
assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="activation", iteration=200
)
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)
expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5)
# TE FP8 tensor stats --
assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.1.mlp.fc1",
tensor=tensor_fp8,
tensor_name="gradient",
iteration=200,
rowwise=True,
tp_group=None,
)
stats = log()
torch.testing.assert_close(
stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
)
# Second config in same yaml
tensor = torch.rand((100, 100, 5))
debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
assert all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"])
assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()
debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1",
tensor=tensor,
tensor_name="weight",
iteration=200,
tp_group=None,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
assert all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201
)
assert_empty()
finally:
debug_api.end_debug()
def test_statistics_multi_run(configs_dir, feature_dirs):
try:
debug_api.initialize(
config_file=configs_dir + "stats_collection_test_config.yaml",
feature_dirs=feature_dirs,
default_logging_enabled=False,
)
def feed(tensor, tensor_fp8):
debug_api.transformer_engine.inspect_tensor(
"decoder.5.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=1,
tp_group=None,
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.5.mlp.fc1",
tensor=tensor_fp8,
tensor_name="activation",
iteration=1,
rowwise=True,
tp_group=None,
)
def log_stats():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
return STATS_BUFFERS.log_stats()
def fp8_tensor(t):
return Float8Tensor(
data=t.to(torch.uint8).cuda(),
fp8_scale_inv=torch.ones([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=t.shape,
dtype=torch.float32,
)
shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
feed(tensors[0], tensors_fp8[0])
feed(tensors[1], tensors_fp8[1])
stats1 = log_stats()
tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
fp8tensor2 = fp8_tensor(tensor2)
feed(tensor2, fp8tensor2)
stats2 = log_stats()
assert len(stats1.keys()) > 0
for k in stats1.keys():
torch.testing.assert_close(stats1[k], stats2[k])
finally:
debug_api.end_debug()
if __name__ == "__main__":
pass
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pathlib, os
from nvdlfw_inspect.config_manager import ConfigManager
import nvdlfw_inspect.api as debug_api
try:
import transformer_engine
from transformer_engine.debug.features.api import TEConfigAPIMapper
except (ImportError, ModuleNotFoundError):
print("Could not find TransformerEngine debug module.")
exit(1)
def test_transformer_engine_config_parsing(feature_dirs):
debug_api.initialize(
config_file=pathlib.Path(__file__).resolve().parent
/ "test_configs/tensor_manipulation_transformer_engine.yaml",
feature_dirs=feature_dirs,
log_dir="./log",
)
cfg_fc1 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc1")["transformer_engine"]
cfg_fc2 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc2")["transformer_engine"]
assert cfg_fc1 and cfg_fc2
gemm_parsing = True
tensor_parsing = True
# Per tensor scaling set for dgrad, filter based on gemm
ret, _ = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="activation",
)
assert not ret
# per tensor scaling set for gradient, filter based on tensor name
ret, _ = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="activation",
)
assert not ret
ret, parsed_cfg_fc1 = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc1 == {"gemm": "dgrad", "tensor": "gradient"}
# Test tensor struct
ret, parsed_cfg_fc1_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="activation",
)
ret, parsed_cfg_fc1_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc1["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc1_act == {
"gemm": "fprop",
"tensor": "activation",
"quant_format": "FP8E4M3",
}
assert parsed_cfg_fc1_wei == {
"gemm": "fprop",
"tensor": "weight",
"quant_format": "FP8E4M3",
}
# Test gemms struct
ret, parsed_cfg_fc2_grad = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc2_grad == {"gemm": "dgrad", "tensor": "gradient", "quant_format": "FP8E5M2"}
ret, parsed_cfg_fc2_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["FakeQuant"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="dgrad",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc2_wei == {"gemm": "dgrad", "tensor": "weight", "quant_format": "FP8E5M2"}
# Test gemm + tensor struct
ret, parsed_cfg_fc2_fprop_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="activation",
)
assert ret
assert parsed_cfg_fc2_fprop_act == {"gemm": "fprop", "tensor": "activation"}
ret, parsed_cfg_fc2_fprop_wei = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="fprop",
tensor_name="weight",
)
assert ret
assert parsed_cfg_fc2_fprop_wei == {"gemm": "fprop", "tensor": "weight"}
ret, parsed_cfg_fc2_wgrad_act = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="activation",
)
assert ret
assert parsed_cfg_fc2_wgrad_act == {"gemm": "wgrad", "tensor": "activation"}
ret, parsed_cfg_fc2_wgrad_grad = TEConfigAPIMapper().parse_config_and_api(
cfg_fc2["PerTensorScaling"],
gemm_parsing=gemm_parsing,
tensor_parsing=tensor_parsing,
gemm="wgrad",
tensor_name="gradient",
)
assert ret
assert parsed_cfg_fc2_wgrad_grad == {"gemm": "wgrad", "tensor": "gradient"}
ConfigManager.reset()
test_disable_fp8_gemm_1:
enabled: True
layers:
layer_types: [qkv, fc2]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [dgrad, wgrad]
test_disable_fp8_layer:
enabled: True
layers:
layer_types: [qkv]
transformer_engine:
DisableFP8Layer:
enabled: True
dummy_feature_everywhere:
enabled: True
layers:
layer_name_regex_pattern: .*
transformer_engine:
TestDummyFeature:
enabled: True
tensors: [weight, activation, gradient, output, wgrad, dgrad]
gemms: [wgrad, dgrad, fprop]
test_fake_quant_fp8:
enabled: True
layers:
layer_numbers: [1]
layer_types: [fc1, fc2]
transformer_engine:
FakeQuant:
enabled: True
gemms: [fprop, dgrad]
tensors_struct:
- tensor: activation
quant_format: FP8E4M3
- tensor: gradient
quant_format: FP8E5M2
test_per_tensor_scaling:
enabled: True
layers:
layer_numbers: [1]
layer_types: [fc1, fc2]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [wgrad]
PerTensorScaling:
enabled: True
gemms_struct:
- gemm: fprop
tensors_struct:
- tensor: activation
- tensor: weight
- gemm: dgrad
tensors_struct:
- tensor: gradient
stat_collection_test_1:
enabled: True
layers:
layer_numbers: [1, 3]
LogTensorStats:
enabled: True
stats: [mean, std, l1_norm, l2_norm]
tensors: [activation]
freq: 1
start_step: 100
end_step: 500
transformer_engine:
LogTensorStats:
enabled: True
stats: [cur_amax, dynamic_range]
tensors: [activation]
freq: 2
start_step: 100
end_step: 500
LogFp8TensorStats:
enabled: True
stats: [underflows%]
tensors: [gradient]
freq: 5
start_step: 100
end_step: 500
stat_collection_test_2:
enabled: True
layers:
layer_numbers: [6, 7]
transformer_engine:
LogTensorStats:
enabled: True
tensors_struct:
- tensor: activation
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
freq: 2
start_step: 100
end_step: 500
- tensor: weight
stats: [mean, std, l1_norm, min, max]
freq: 5
start_step: 100
end_step: 500
stat_collection_test_4:
enabled: True
layers:
layer_numbers: [5]
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation]
stats: [cur_amax, dynamic_range, mean, std, l1_norm]
LogFp8TensorStats:
enabled: True
stats: [underflows%]
tensors: [activation]
# This config is used when FP8 training is ON
transformer_engine_fc1_manipulation:
enabled: True
layers:
layer_name_regex_pattern: .*(fc1) # Select layers if they end in fc1
transformer_engine: # namespace
DisableFP8GEMM: # Disable FP8 GEMM. FProp run in high precision
enabled: True
gemms: [fprop]
PerTensorScaling: # Scale DGrad gradients using per tensor current scaling and run FP8 GEMM
enabled: True
gemms: [dgrad]
tensors: [gradient]
FakeQuant: # Disable FP8 GEMM for Wgrad. Fake quantize activations to Wgrad and run high precision GEMM
enabled: True
gemms: [fprop]
tensors_struct:
- tensor: activation
quant_format: FP8E4M3
- tensor: weight
quant_format: FP8E4M3
transformer_engine_fc2_manipulation:
enabled: True
layers:
layer_name_regex_pattern: .*(fc2) # Select layers if they end in fc2
transformer_engine: # namespace
PerTensorScaling: # Scale WGrad and Fprop inputs using per tensor current scaling and run FP8 GEMM
enabled: True
gemms_struct:
- gemm: fprop
tensors_struct:
- tensor: activation
- tensor: weight
- gemm: wgrad
tensors_struct:
- tensor: activation
- tensor: gradient
FakeQuant: # Disable FP8 GEMM for DGrad. Fake quantize weights and gradients to DGrad and run high precision GEMM
enabled: True
gemms_struct:
- gemm: dgrad
tensors: [weight, gradient]
quant_format: FP8E5M2
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
from pathlib import Path
import pytest
import torch
"""
Distributed numerics tests

These tests verify the numerical correctness of the TransformerEngine layers.
Tests are parametrized by the layer and FP8 precision.
One test consists of running multiple configurations from the file run_numerics.py.
This design is used because initializing a test is slow: 2 processes need to
start and load torch and TE. Running multiple configurations in one test
reduces the initialization overhead.
"""
if torch.cuda.device_count() < 2:
    pytest.skip("Distributed training needs at least 2 GPUs.", allow_module_level=True)
TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
def test_debug_distributed(feature_dirs):
test_path = TEST_ROOT / "run_distributed.py"
test_cmd = LAUNCH_CMD + [str(test_path), f"--feature_dirs={feature_dirs[0]}"]
result = subprocess.run(test_cmd, env=os.environ, capture_output=True, check=False)
if result.returncode != 0:
raise AssertionError(result.stderr.decode())
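# The assembled command is equivalent to (illustrative):
#   torchrun --nproc_per_node=<NUM_PROCS> <TEST_ROOT>/run_distributed.py \
#       --feature_dirs=<first feature dir>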
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import itertools
import os
import random
import tempfile
from string import Template
import pytest
import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as tepytorch
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.module.base import (
_2X_ACC_DGRAD,
_2X_ACC_FPROP,
_2X_ACC_WGRAD,
)
all_boolean = [True, False]
FP8_FORMAT = Format.HYBRID
AMAX_HISTORY_LEN = 16
FP8_RECIPE = DelayedScaling(
fp8_format=FP8_FORMAT, amax_history_len=AMAX_HISTORY_LEN, amax_compute_algo="max"
)
SEED = 1234
IN_SIZE = 128
OUT_SIZE = 64
BATCH_SIZE = 16
SEQ_LEN = 128
LOSS_FN = torch.nn.functional.cross_entropy
def _cast_to_fp8(tensor, scale, dtype):
tensor = tensor.contiguous()
if type(scale) == torch.Tensor:
amax = scale.abs().max().float()
quantizer = Float8Quantizer(scale, amax, dtype)
else:
quantizer = Float8CurrentScalingQuantizer(scale, device=tensor.device)
return quantizer(tensor)
def _get_current_scale(tensor, fp8_dtype):
if fp8_dtype == tex.DType.kFloat8E4M3:
fp8_max = Format.E4M3.value.max_fwd
else:
fp8_max = Format.E5M2.value.max_fwd
amax = tensor.abs().max().float()
one = torch.ones(1, device=tensor.device)
return _default_sf_compute(amax, one, fp8_max, 0).detach()
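# Example (illustrative, assuming _default_sf_compute with margin 0 returns
# fp8_max / amax): for amax = 2.0 and E4M3 (max_fwd = 448) the scale is
# 448 / 2.0 = 224, so the largest value maps to the top of the FP8 range.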
def _fake_cast(tensor, fp8_dtype, scale):
scale = scale or _get_current_scale(tensor, fp8_dtype)
fp8_tensor = _cast_to_fp8(tensor, scale, fp8_dtype)
return fp8_tensor.dequantize()
def _fp8_gemm_kernel(tensor1, scale1, dtype1, tensor2, scale2, dtype2, use_split_accumulator):
fp8_tensor1 = _cast_to_fp8(tensor1, scale1, dtype1)
fp8_tensor2 = _cast_to_fp8(tensor2, scale2, dtype2)
out, *_ = tepytorch.cpp_extensions.general_gemm(
fp8_tensor1,
fp8_tensor2,
tepytorch.module.base.get_workspace(),
torch.float32,
use_split_accumulator=use_split_accumulator,
)
out.requires_grad = True
return out.T
def _emulate_linear(
input: torch.Tensor,
weight: torch.Tensor,
fprop_fp8: bool = False,
fprop_input_fake_quant: tex.DType = None,
fprop_input_scale: torch.Tensor = None,
fprop_weight_fake_quant: tex.DType = None,
fprop_weight_scale: torch.Tensor = None,
dgrad_fp8: bool = False,
dgrad_gradient_fake_quant: tex.DType = None,
dgrad_gradient_scale: torch.Tensor = None,
dgrad_weight_fake_quant: tex.DType = None,
dgrad_weight_scale: torch.Tensor = None,
wgrad_fp8: bool = False,
wgrad_gradient_fake_quant: tex.DType = None,
wgrad_gradient_scale: torch.Tensor = None,
wgrad_input_fake_quant: tex.DType = None,
wgrad_input_scale: torch.Tensor = None,
loss_multiplier: float = 1.0,
activation_sync=None,
gradient_sync=None,
):
_scalar = lambda x: torch.Tensor([x]).cuda() if type(x) in [float, torch.Tensor] else x
if fprop_fp8:
activation = _fp8_gemm_kernel(
input,
_scalar(fprop_input_scale or 1.0),
tex.DType.kFloat8E4M3,
weight,
_scalar(fprop_weight_scale or 1.0),
tex.DType.kFloat8E4M3,
_2X_ACC_FPROP,
)
activation = activation.clone().detach().contiguous().requires_grad_(True)
else:
fprop_input = (
_fake_cast(input, fprop_input_fake_quant, _scalar(fprop_input_scale))
if fprop_input_fake_quant is not None
else input
)
fprop_weight = (
_fake_cast(weight, fprop_weight_fake_quant, _scalar(fprop_weight_scale))
if fprop_weight_fake_quant is not None
else weight
)
activation = (fprop_input @ fprop_weight.T).contiguous()
if activation_sync:
activation = activation_sync(activation)
activation.retain_grad()
(loss_multiplier * activation.sum()).backward(retain_graph=True)
gradient = activation.grad.clone()
if gradient_sync:
gradient = gradient_sync(gradient)
if dgrad_fp8:
dgrad = _fp8_gemm_kernel(
weight.T,
_scalar(dgrad_weight_scale or 1.0),
tex.DType.kFloat8E4M3,
gradient,
_scalar(dgrad_gradient_scale or 1.0),
tex.DType.kFloat8E5M2,
_2X_ACC_DGRAD,
).T
else:
dgrad_gradient = (
_fake_cast(gradient, dgrad_gradient_fake_quant, _scalar(dgrad_gradient_scale))
if dgrad_gradient_fake_quant is not None
else gradient
)
dgrad_weight = (
_fake_cast(weight, dgrad_weight_fake_quant, _scalar(dgrad_weight_scale))
if dgrad_weight_fake_quant is not None
else weight
)
dgrad = dgrad_gradient @ dgrad_weight
if wgrad_fp8:
wgrad = _fp8_gemm_kernel(
input.T,
_scalar(wgrad_input_scale or 1.0),
tex.DType.kFloat8E4M3,
gradient.T,
_scalar(wgrad_gradient_scale or 1.0),
tex.DType.kFloat8E5M2,
_2X_ACC_WGRAD,
).T
else:
wgrad_gradient = (
_fake_cast(gradient, wgrad_gradient_fake_quant, _scalar(wgrad_gradient_scale))
if wgrad_gradient_fake_quant is not None
else gradient
)
wgrad_input = (
_fake_cast(input, wgrad_input_fake_quant, _scalar(wgrad_input_scale))
if wgrad_input_fake_quant is not None
else input
)
wgrad_input = wgrad_input.contiguous()
wgrad_gradient = wgrad_gradient.contiguous()
wgrad, *_ = tepytorch.cpp_extensions.general_gemm(
wgrad_input,
wgrad_gradient,
tepytorch.module.base.get_workspace(),
torch.float32,
layout="NT",
grad=True,
use_split_accumulator=_2X_ACC_WGRAD,
)
return {"activation": activation, "wgrad": wgrad, "dgrad": dgrad}
def _init_debug(config_name, log_dir, feature_dirs):
debug_api.initialize(
config_file=config_name,
feature_dirs=feature_dirs,
log_dir=log_dir,
default_logging_enabled=True,
)
def create_config_file(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
with tempfile.TemporaryDirectory() as temp_dir:
try:
kwargs["config_file"] = temp_file
kwargs["log_dir"] = temp_dir
result = func(*args, **kwargs)
finally:
temp_file_name = temp_file.name
debug_api.end_debug()
os.unlink(temp_file_name)
return result
return wrapper
def _cmp(ground_truth, output):
torch.testing.assert_close(ground_truth["activation"], output["activation"])
torch.testing.assert_close(ground_truth["wgrad"], output["wgrad"])
torch.testing.assert_close(ground_truth["dgrad"], output["dgrad"])
def _init_model(weight):
model = transformer_engine.pytorch.Linear(IN_SIZE, OUT_SIZE, name="linear")
with torch.no_grad():
model.weight.copy_(weight.contiguous())
return model
def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None):
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=is_first_microbatch)
(y.sum() * loss_scale).backward()
debug_api.step()
return y
def _get_tensors():
torch.manual_seed(SEED)
x = torch.randn((SEQ_LEN * BATCH_SIZE, IN_SIZE), requires_grad=True).cuda()
x.retain_grad()
weight = torch.randn((OUT_SIZE, IN_SIZE)).cuda()
return x, weight
DISABLE_FP8_CONFIG = Template(
"""disable_fp8_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
DisableFP8GEMM:
enabled: True
gemms: [$gemms]
"""
)
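# For example, DISABLE_FP8_CONFIG.safe_substitute(gemms="fprop,dgrad") renders the
# DisableFP8GEMM section with `gemms: [fprop,dgrad]`, leaving wgrad in FP8.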
@pytest.mark.parametrize("fprop_fp8", all_boolean)
@pytest.mark.parametrize("dgrad_fp8", all_boolean)
@pytest.mark.parametrize("wgrad_fp8", all_boolean)
def test_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8):
run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8)
def disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, config_file):
gemms = ""
if not fprop_fp8:
gemms += "fprop,"
if not dgrad_fp8:
gemms += "dgrad,"
if not wgrad_fp8:
gemms += "wgrad,"
if len(gemms) > 0:
gemms = gemms[:-1] # remove last ','
config_file.write(DISABLE_FP8_CONFIG.safe_substitute(gemms=gemms))
config_file.flush()
@create_config_file
def run_disable_fp8_gemms(feature_dirs, fprop_fp8, dgrad_fp8, wgrad_fp8, **kwargs):
disable_fp8_gemms_create_config(fprop_fp8, dgrad_fp8, wgrad_fp8, kwargs["config_file"])
fp8_kwargs = {
"fprop_fp8": fprop_fp8,
"dgrad_fp8": dgrad_fp8,
"wgrad_fp8": wgrad_fp8,
}
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
x, weight = _get_tensors()
model = _init_model(weight)
y = _run_forward_backward(x, model)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
x.grad.zero_()
ground_truth = _emulate_linear(x, weight, **fp8_kwargs)
_cmp(ground_truth, output)
def test_disable_fp8_layer(feature_dirs):
run_disable_fp8_layer(feature_dirs)
DISABLE_FP8_LAYER_CONFIG = """disable_fp8_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
DisableFP8Layer:
enabled: True
"""
@create_config_file
def run_disable_fp8_layer(feature_dirs, **kwargs):
kwargs["config_file"].write(DISABLE_FP8_LAYER_CONFIG)
kwargs["config_file"].flush()
x, weight = _get_tensors()
ground_truth = _emulate_linear(x, weight)
x.grad.zero_()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
model = _init_model(weight)
y = _run_forward_backward(x, model)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
_cmp(ground_truth, output)
random.seed(1234)
all_combinations = list(itertools.product(all_boolean, repeat=6))
subset_combinations = random.sample(all_combinations, 20)
@pytest.mark.parametrize(
"fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad",
subset_combinations,
)
def test_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
pytest.skip("Skipping test because all parameters are False")
run_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
)
PER_TENSOR_SCALING_CONFIG = Template(
"""per_tensor_scaling_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
PerTensorScaling:
enabled: True
gemms_struct:
$gemms
"""
)
def _prepare_per_tensor_scaling_config(
fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file
):
gemms = ""
title = lambda x: f" - gemm: {x}\n tensors: ["
def add_tensor(if_add, gemm_name):
nonlocal gemms
if if_add:
gemms += gemm_name + ","
if fprop_inp or fprop_weight:
gemms += title("fprop")
add_tensor(fprop_inp, "activation")
add_tensor(fprop_weight, "weight")
gemms = gemms[:-1] + "]\n"
if dgrad_weight or dgrad_grad:
gemms += title("dgrad")
add_tensor(dgrad_weight, "weight")
add_tensor(dgrad_grad, "gradient")
gemms = gemms[:-1] + "]\n"
if wgrad_input or wgrad_grad:
gemms += title("wgrad")
add_tensor(wgrad_input, "activation")
add_tensor(wgrad_grad, "gradient")
gemms = gemms[:-1] + "]\n"
config_file.write(PER_TENSOR_SCALING_CONFIG.safe_substitute(gemms=gemms))
config_file.flush()
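# For example, with only fprop_inp=True and dgrad_grad=True the substituted
# $gemms block contains (indentation approximate):
#   - gemm: fprop
#     tensors: [activation]
#   - gemm: dgrad
#     tensors: [gradient]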
def set_scaling_factors(model, input_kwargs, fp8_kwargs):
# Copy fp8 scaling factors into fp8_kwargs dict if respective flag in input_kwargs is set.
if not input_kwargs["fprop_inp"]:
fp8_kwargs["fprop_input_scale"] = model.fp8_meta["scaling_fwd"].scale[0].clone()
if not input_kwargs["fprop_weight"]:
fp8_kwargs["fprop_weight_scale"] = model.fp8_meta["scaling_fwd"].scale[1].clone()
if not input_kwargs["dgrad_grad"]:
fp8_kwargs["dgrad_gradient_scale"] = model.fp8_meta["scaling_bwd"].scale[0].clone()
if not input_kwargs["dgrad_weight"]:
fp8_kwargs["dgrad_weight_scale"] = model.fp8_meta["scaling_fwd"].scale[1].clone()
if not input_kwargs["wgrad_grad"]:
fp8_kwargs["wgrad_gradient_scale"] = model.fp8_meta["scaling_bwd"].scale[0].clone()
if not input_kwargs["wgrad_input"]:
fp8_kwargs["wgrad_input_scale"] = model.fp8_meta["scaling_fwd"].scale[0].clone()
def set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs):
# Compute per tensor scaling factor if respective flag in input_kwargs is set.
if input_kwargs["fprop_inp"]:
fp8_kwargs["fprop_input_scale"] = tex.DType.kFloat8E4M3
if input_kwargs["fprop_weight"]:
fp8_kwargs["fprop_weight_scale"] = tex.DType.kFloat8E4M3
if input_kwargs["dgrad_grad"]:
fp8_kwargs["dgrad_gradient_scale"] = tex.DType.kFloat8E5M2
if input_kwargs["dgrad_weight"]:
fp8_kwargs["dgrad_weight_scale"] = tex.DType.kFloat8E4M3
if input_kwargs["wgrad_grad"]:
fp8_kwargs["wgrad_gradient_scale"] = tex.DType.kFloat8E5M2
if input_kwargs["wgrad_input"]:
fp8_kwargs["wgrad_input_scale"] = tex.DType.kFloat8E4M3
@create_config_file
def run_per_tensor_scaling(
feature_dirs,
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
**kwargs,
):
    """
    Runs a test to validate per-tensor (current) scaling in FP8 computations.

    The function performs warm-up iterations to populate the amax buffer of the
    model and compute scaling factors based on delayed scaling. Subsequently,
    the weights and inputs are swapped out so that their current scaling factors
    differ from the delayed-scaling ones; similarly, the loss is multiplied by a
    large factor to alter the gradient's magnitude, creating a discrepancy
    between the original (delayed) and per-tensor (current) scaling factors.
    Finally, a linear pass is emulated and the results are compared.
    """
    input_kwargs = {
        "fprop_inp": fprop_inp,
        "fprop_weight": fprop_weight,
        "dgrad_weight": dgrad_weight,
        "dgrad_grad": dgrad_grad,
        "wgrad_input": wgrad_input,
        "wgrad_grad": wgrad_grad,
    }
    fp8_kwargs = {
        "fprop_fp8": True,
        "dgrad_fp8": True,
        "wgrad_fp8": True,
    }
_prepare_per_tensor_scaling_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
warmup_input, warmup_weight = _get_tensors()
model = _init_model(warmup_weight)
# Warmup run to setup amax and scaling factors.
for _ in range(AMAX_HISTORY_LEN):
_run_forward_backward(warmup_input, model)
x = torch.randn_like(warmup_input, requires_grad=True).cuda()
weight = torch.randn_like(warmup_weight, requires_grad=True).cuda()
model.weight.data = weight.data
x.retain_grad()
# delayed scaling factor
# need to be collected before forward pass with test data,
# because this forward pass changes scaling factors
set_scaling_factors(model, input_kwargs, fp8_kwargs)
LOSS_MULTIPLIER = 100
with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE):
y = model(x, is_first_microbatch=True)
model.zero_grad()
y.retain_grad()
(
LOSS_MULTIPLIER * y.sum()
    ).backward()  # Loss multiplication to change the gradient's order of magnitude
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
# per tensor - current - scaling factors
# need to be collected after forward pass with test data,
# because gradient(y.grad) cannot be accessed before forward,
# but it needs to be collected.
set_current_scaling_factors(x, weight, y, input_kwargs, fp8_kwargs)
ground_truth = _emulate_linear(x, weight, loss_multiplier=LOSS_MULTIPLIER, **fp8_kwargs)
_cmp(ground_truth, output)
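# Minimal sketch of the per-tensor (current) scaling rule exercised above: the
# scale maps the tensor's current amax onto the FP8 format's representable
# maximum (448 for E4M3, 57344 for E5M2). The helper name and the plain division
# are illustrative assumptions, not the production implementation.
def _current_scale_sketch(tensor, fp8_max=448.0):
    amax = tensor.abs().max().float()
    return fp8_max / torch.clamp(amax, min=1e-12)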
@pytest.mark.parametrize(
"fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad",
subset_combinations,
)
def test_microbatching_per_tensor_scaling(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
if not any([fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad]):
pytest.skip("Skipping test because all parameters are False")
@create_config_file
def run_microbatching_test(
feature_dirs,
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
**kwargs,
):
# Prepare the configuration file
_prepare_per_tensor_scaling_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
# Initialize debug
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
# Get data
x_full, weight = _get_tensors()
microbatch_size = x_full.size(0) // 2
x_mb1 = x_full[:microbatch_size, ...].clone().detach().requires_grad_(True)
x_mb2 = x_full[microbatch_size:, ...].clone().detach().requires_grad_(True)
def init_and_warmup():
model = _init_model(weight)
_run_forward_backward(x_mb1, model, loss_scale=0.5)
_run_forward_backward(x_mb2, model, loss_scale=0.5)
return model
# Run without is_first_microbatch
model = init_and_warmup() # running next 2 iters does not change amaxes and scaling factors
y_mb1 = _run_forward_backward(x_mb1, model, loss_scale=0.5)
y_mb2 = _run_forward_backward(x_mb2, model, loss_scale=0.5)
# Collect outputs
output1 = {
"activation": torch.cat([y_mb1.clone(), y_mb2.clone()], dim=0),
"wgrad": model.weight.grad.clone(),
"dgrad": torch.cat([x_mb1.grad.clone(), x_mb2.grad.clone()], dim=0),
}
# Run with is_first_microbatch
model = init_and_warmup() # running next 2 iters does not change amaxes and scaling factors
y_mb1 = _run_forward_backward(x_mb1, model, loss_scale=0.5, is_first_microbatch=True)
y_mb2 = _run_forward_backward(x_mb2, model, loss_scale=0.5, is_first_microbatch=False)
# Collect outputs
output2 = {
"activation": torch.cat([y_mb1.clone(), y_mb2.clone()], dim=0),
"wgrad": model.weight.grad.clone(),
"dgrad": torch.cat([x_mb1.grad.clone(), x_mb2.grad.clone()], dim=0),
}
# Compare outputs
torch.testing.assert_close(output1["activation"], output2["activation"], atol=1.0, rtol=0.5)
torch.testing.assert_close(output1["dgrad"], output2["dgrad"], atol=1.0, rtol=0.5)
torch.testing.assert_close(output1["wgrad"], output2["wgrad"], atol=1.0, rtol=0.5)
# Run the test
run_microbatching_test(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
)
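# Note on the flag exercised above: with is_first_microbatch=True the module
# quantizes its weights to FP8 and caches them; with is_first_microbatch=False
# the cached FP8 weights are reused. Both runs should therefore apply the same
# effective weight scaling, which is why their outputs are compared (within
# loose tolerances).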
all_combinations = list(
itertools.product([tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, None], repeat=6)
)
subset_combinations = random.sample(all_combinations, 10)
@pytest.mark.parametrize(
"fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad",
subset_combinations,
)
def test_fake_quant_fp8(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
):
run_fake_quant_fp8(
feature_dirs, fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad
)
FAKE_QUANT_CONFIG = Template(
"""fake_quant_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
FakeQuant:
enabled: True
gemms_struct:
$gemms
"""
)
def fake_quant_fp8_create_config(
fprop_inp, fprop_weight, dgrad_weight, dgrad_grad, wgrad_input, wgrad_grad, config_file
):
format_to_str = {tex.DType.kFloat8E4M3: "FP8E4M3", tex.DType.kFloat8E5M2: "FP8E5M2"}
gemms = ""
def _add_tensor(quant_format, tensor):
nonlocal gemms
if quant_format:
gemms += " " * 8 + "- tensor: " + tensor + "\n"
gemms += " " * 8 + " quant_format: " + format_to_str[quant_format] + "\n"
    title = lambda x: f"      - gemm: {x}\n        tensors_struct:\n"
if fprop_inp or fprop_weight:
gemms += title("fprop")
_add_tensor(fprop_inp, "activation")
_add_tensor(fprop_weight, "weight")
gemms = gemms[:-1] + "\n"
if dgrad_weight or dgrad_grad:
gemms += title("dgrad")
_add_tensor(dgrad_weight, "weight")
_add_tensor(dgrad_grad, "gradient")
gemms = gemms[:-1] + "\n"
if wgrad_input or wgrad_grad:
gemms += title("wgrad")
_add_tensor(wgrad_input, "activation")
_add_tensor(wgrad_grad, "gradient")
gemms = gemms[:-1] + "\n"
config = FAKE_QUANT_CONFIG.safe_substitute(gemms=gemms)
config_file.write(config)
config_file.flush()
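# For reference, with fprop_inp=kFloat8E4M3 and dgrad_grad=kFloat8E5M2 the
# builder above emits a gemms_struct fragment roughly like this (whitespace
# illustrative):
#
#       - gemm: fprop
#         tensors_struct:
#         - tensor: activation
#           quant_format: FP8E4M3
#       - gemm: dgrad
#         tensors_struct:
#         - tensor: gradient
#           quant_format: FP8E5M2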
@create_config_file
def run_fake_quant_fp8(
feature_dirs,
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
**kwargs,
):
fp8_kwargs = {
"fprop_input_fake_quant": fprop_inp,
"fprop_weight_fake_quant": fprop_weight,
"dgrad_gradient_fake_quant": dgrad_grad,
"dgrad_weight_fake_quant": dgrad_weight,
"wgrad_gradient_fake_quant": wgrad_grad,
"wgrad_input_fake_quant": wgrad_input,
"fprop_fp8": not (fprop_inp or fprop_weight),
"dgrad_fp8": not (dgrad_weight or dgrad_grad),
"wgrad_fp8": not (wgrad_grad or wgrad_input),
}
fake_quant_fp8_create_config(
fprop_inp,
fprop_weight,
dgrad_weight,
dgrad_grad,
wgrad_input,
wgrad_grad,
kwargs["config_file"],
)
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], feature_dirs)
x, weight = _get_tensors()
model = _init_model(weight)
y = _run_forward_backward(x, model)
output = {"activation": y.clone(), "wgrad": model.weight.grad.clone(), "dgrad": x.grad.clone()}
ground_truth = _emulate_linear(x, weight, **fp8_kwargs)
_cmp(ground_truth, output)
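# Fake quantization here means that a tensor is cast to FP8 and immediately back
# to high precision, so the affected GEMM itself runs in high precision. This is
# why the *_fp8 flags above are disabled for any GEMM with a fake-quantized
# operand, and _emulate_linear is expected to mirror that behavior.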
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import functools
import itertools
import os
import random
import tempfile
from string import Template
import pytest
import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import _default_sf_compute
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
from test_numerics import create_config_file
B, S, H, D = 64, 64, 64, 64
model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]
configs = {
"": "",
"log": """log:
layers:
layer_types: [linear]
enabled:
True
transformer_engine:
LogTensorStats:
enabled: True
tensors: [activation, gradient, weight, output, wgrad, dgrad]
stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
start_step : 0
end_step: 1
LogFp8TensorStats:
enabled: True
tensors: [activation, gradient, weight]
stats: [underflows, overflows]
start_step : 0
end_step: 1
""",
"fake_quant": """
fake_quant_config:
enabled: True
layers:
layer_types: [linear]
transformer_engine:
FakeQuant:
enabled: True
gemms: [fprop, dgrad, wgrad]
quant_format: FP8E5M2
""",
}
def _get_model(model_key):
if model_key == "linear":
return te.Linear(D, D)
if model_key == "layernorm_linear":
return te.LayerNormLinear(D, D)
if model_key == "layernorm_mlp":
return te.LayerNormMLP(D, D, D)
if model_key == "mha_attention":
return te.MultiheadAttention(D, H)
if model_key == "transformer_layer":
return te.TransformerLayer(D, D, H)
def _run_forward_backward(model, fp8):
for _ in range(3):
inp = torch.randn((S, B, H)).cuda()
with te.fp8_autocast(enabled=fp8):
out = model(inp)
out.sum().backward()
debug_api.step()
@create_config_file
def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
    try:
        if config != "":
            config_file.write(config)
            config_file.flush()
        config_file_name = config_file.name if config != "" else ""
        debug_api.initialize(feature_dirs=feature_dirs, config_file=config_file_name)
        model = _get_model(model_key)
        _run_forward_backward(model, fp8)
    finally:
        debug_api.end_debug()
@pytest.mark.parametrize("model_key", model_keys)
@pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("config_key", configs.keys())
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
_run_test(model_key, fp8, configs[config_key], feature_dirs)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
LOG_FILE = os.path.join("nvdlfw_inspect_logs", "nvdlfw_inspect_globalrank-0.log")
def reset_debug_log():
    if os.path.isfile(LOG_FILE):
        # Truncate the file to drop all previous content.
        with open(LOG_FILE, "w"):
            pass
def check_debug_log(msg):
    with open(LOG_FILE, "r") as f:
        return any(msg in line for line in f)
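# Hypothetical usage sketch for the helpers above:
#   reset_debug_log()
#   ... run a debug-instrumented forward and backward pass ...
#   assert check_debug_log("expected log message")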
...@@ -34,6 +34,18 @@ NCCL_WORLD = None
LOSS_FN = nn.MSELoss()
QUANTIZATION = None
if os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False):
    # The numerics of all layers should be identical when debug=True. A dummy
    # feature is supplied to prevent debug mode from being switched off, which
    # happens when no feature is active.
import nvdlfw_inspect.api as debug_api
debug_api.initialize(
os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)
# Disable TF32
torch.backends.cuda.matmul.allow_tf32 = False
...
...@@ -102,6 +102,20 @@ all_normalizations = ["LayerNorm", "RMSNorm"]
mask_types = ["causal", "no_mask"]
NVTE_TEST_NVINSPECT_ENABLED = os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", False)
if NVTE_TEST_NVINSPECT_ENABLED:
    # The numerics of all layers should be identical when debug=True. A dummy
    # feature is supplied to prevent debug mode from being switched off, which
    # happens when no feature is active.
import nvdlfw_inspect.api as debug_api
debug_api.initialize(
os.environ["NVTE_TEST_NVINSPECT_CONFIG_FILE"],
feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"],
)
fp8_recipes = [
    recipe.MXFP8BlockScaling(),
    recipe.DelayedScaling(),
...@@ -568,6 +582,8 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_m
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
...@@ -682,6 +698,8 @@ def test_gpt_full_activation_recompute(
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
...@@ -1726,6 +1744,8 @@ def test_grouped_linear_accuracy(
        pytest.skip(reason_for_no_fp8)
    if fp8 and recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
    if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
...@@ -1924,6 +1944,8 @@ def test_padding_grouped_linear_accuracy(
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
...@@ -2039,6 +2061,8 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", ["126m"])
def test_gpt_cuda_graph(dtype, bs, model):
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("Cuda Graphs are not supported in debug mode.")
    config = model_configs[model]
    sigma = 0.023
...@@ -2136,6 +2160,8 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe):
        pytest.skip(reason_for_no_fp8)
    if recipe.mxfp8() and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
if NVTE_TEST_NVINSPECT_ENABLED:
pytest.skip("FP8 parameters are not supported in debug mode.")
    if recipe.float8_block_scaling() and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
...
...@@ -12,7 +12,7 @@ from nvdlfw_inspect.registry import Registry
import torch
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import get_all_tensor_types
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
...@@ -424,7 +424,7 @@ class TransformerEngineAPI(BaseNamespaceAPI):
        if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
            assert ret is None
        if api_name == "modify_tensor":
            assert type(ret) in get_all_tensor_types()
        if (
            type(ret) == torch.Tensor  # pylint: disable=unidiomatic-typecheck
            and "dtype" in kwargs
...@@ -438,4 +438,4 @@ class TransformerEngineAPI(BaseNamespaceAPI):
    def end_debug(self):
        """This function is called by the nvidia-dlframework-inspect after every debug_api.end_debug()"""
        TEDebugState._reset()