# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Sanity tests: run every TE layer forward/backward under each debug-tool config."""

import pytest
import torch

import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te
from test_numerics import create_config_file

fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)

# Batch, sequence length, attention heads and hidden size shared by all models.
B, S, H, D = 64, 64, 64, 64

# Keys identifying the TE modules exercised by the sanity test.
model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]

# Debug-tool YAML configs keyed by name; "" means "run with no config file".
# NOTE(review): the YAML nesting below was reconstructed from a whitespace-mangled
# source — confirm indentation against the debug-tool config schema.
configs = {
    "": "",
    "log": """log:
  layers:
    layer_types: [linear]
  enabled: True
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors: [activation, gradient, weight, output, wgrad, dgrad]
      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
      start_step : 0
      end_step: 1
""",
    "log_fp8": """log_fp8:
  layers:
    layer_types: [linear]
  enabled: True
  transformer_engine:
    LogFp8TensorStats:
      enabled: True
      tensors: [activation, gradient, weight]
      stats: [underflows%]
      start_step : 0
      end_step: 1
""",
    "fake_quant": """
fake_quant_config:
  enabled: True
  layers:
    layer_types: [linear]
  transformer_engine:
    FakeQuant:
      enabled: True
      gemms: [fprop, dgrad, wgrad]
      tensors: [activation, weight, gradient]
      quant_format: FP8E5M2
""",
}

# Configs that require FP8 to be enabled
fp8_required_configs = {"log_fp8"}


def _get_model(model_key):
    """Build the TE module named by ``model_key``, sized with the module-level dims.

    Raises ``ValueError`` on an unknown key instead of silently returning ``None``.
    """
    if model_key == "linear":
        return te.Linear(D, D, name="layer")
    if model_key == "layernorm_linear":
        return te.LayerNormLinear(D, D, name="layer")
    if model_key == "layernorm_mlp":
        return te.LayerNormMLP(D, D, D, name="layer")
    if model_key == "mha_attention":
        return te.MultiheadAttention(D, H, name="layer")
    if model_key == "transformer_layer":
        return te.TransformerLayer(D, D, H, name="layer")
    # Fail loudly: the original fell through and returned None, deferring the
    # error to a confusing crash inside _run_forward_backward.
    raise ValueError(f"Unknown model key: {model_key}")


def _run_forward_backward(model, fp8):
    """Run three forward/backward iterations (optionally under FP8 autocast),
    stepping the debug API after each iteration."""
    for _ in range(3):
        inp = torch.randn((S, B, H)).cuda()
        with te.autocast(enabled=fp8):
            out = model(inp)
        out.sum().backward()
        debug_api.step()


@create_config_file
def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
    """Write ``config`` to ``config_file`` (if non-empty), initialize the debug
    API, run the model, and always tear the debug API down afterwards.

    ``config_file`` and ``log_dir`` are injected by ``create_config_file``.
    """
    # The original wrapped this in `except Exception as error: raise error`,
    # a no-op re-raise that only added a traceback frame; removed — exceptions
    # propagate unchanged and `finally` still guarantees cleanup.
    try:
        if config != "":
            config_file.write(config)
            config_file.flush()
        config_file_name = config_file.name if config != "" else ""
        debug_api.initialize(feature_dirs=feature_dirs, config_file=config_file_name)
        model = _get_model(model_key)
        _run_forward_backward(model, fp8)
    finally:
        # end_debug() must run even on failure so later tests start clean.
        debug_api.end_debug()


@pytest.mark.parametrize("model_key", model_keys)
@pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("config_key", configs.keys())
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
    """Smoke-test every (model, fp8, config) combination of the debug tools."""
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if not fp8 and config_key in fp8_required_configs:
        pytest.skip(f"Config '{config_key}' requires FP8")
    _run_test(model_key, fp8, configs[config_key], feature_dirs)