# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te

from test_numerics import create_config_file

13
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
14

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
B, S, H, D = 64, 64, 64, 64

model_keys = ["linear", "layernorm_linear", "layernorm_mlp", "mha_attention", "transformer_layer"]

configs = {
    "": "",
    "log": """log:
  layers:
    layer_types: [linear]
  enabled:
    True
  transformer_engine:
    LogTensorStats:
      enabled: True
      tensors: [activation, gradient, weight, output, wgrad, dgrad]
      stats: [min, max, mean, std, l1_norm, l2_norm, cur_amax, dynamic_range]
      start_step : 0
      end_step: 1
    LogFp8TensorStats:
      enabled: True
      tensors: [activation, gradient, weight]
      stats: [underflows, overflows]
      start_step : 0
      end_step: 1
""",
    "fake_quant": """
fake_quant_config:
  enabled: True
  layers:
    layer_types: [linear]
  transformer_engine:
    FakeQuant:
      enabled: True
      gemms: [fprop, dgrad, wgrad]
      quant_format: FP8E5M2
""",
}


def _get_model(model_key):
    if model_key == "linear":
        return te.Linear(D, D)
    if model_key == "layernorm_linear":
        return te.LayerNormLinear(D, D)
    if model_key == "layernorm_mlp":
        return te.LayerNormMLP(D, D, D)
    if model_key == "mha_attention":
        return te.MultiheadAttention(D, H)
    if model_key == "transformer_layer":
        return te.TransformerLayer(D, D, H)


def _run_forward_backward(model, fp8):
    for _ in range(3):
        inp = torch.randn((S, B, H)).cuda()
70
        with te.autocast(enabled=fp8):
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
            out = model(inp)
        out.sum().backward()
        debug_api.step()


@create_config_file
def _run_test(model_key, fp8, config, feature_dirs, config_file, log_dir):
    try:
        if config != "":
            config_file.write(config)
            config_file.flush()
        config_file_name = config_file.name if config != "" else ""
        debug_api.initialize(feature_dirs=feature_dirs, config_file=config_file_name)
        model = _get_model(model_key)
        _run_forward_backward(model, fp8)
    except Exception as error:
        raise error
    finally:
        debug_api.end_debug()


@pytest.mark.parametrize("model_key", model_keys)
@pytest.mark.parametrize("fp8", [False, True])
@pytest.mark.parametrize("config_key", configs.keys())
def test_sanity_debug(model_key, fp8, config_key, feature_dirs):
96
97
    if fp8 and not fp8_available:
        pytest.skip(reason_for_no_fp8)
98
    _run_test(model_key, fp8, configs[config_key], feature_dirs)