# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import pathlib
import sys

from nvdlfw_inspect.config_manager import ConfigManager

import nvdlfw_inspect.api as debug_api

# These tests only make sense when the TransformerEngine debug module is
# installed; bail out of the whole test module immediately if it is missing.
try:
    import transformer_engine
    from transformer_engine.debug.features.api import TEConfigAPIMapper
except (ImportError, ModuleNotFoundError):
    print("Could not find TransformerEngine debug module.")
    # Use sys.exit rather than the site-provided exit() builtin, which is
    # not guaranteed to exist in all interpreter run modes.
    sys.exit(1)


def test_transformer_engine_config_parsing(feature_dirs):
    """Check TEConfigAPIMapper filtering/parsing of a per-layer debug config.

    Loads ``tensor_manipulation_transformer_engine.yaml`` and verifies, for
    layers ``decoder.1.mlp.fc1`` / ``decoder.1.mlp.fc2``, that
    ``parse_config_and_api``:

    * rejects gemm/tensor combinations not enabled in the config, and
    * returns the expected parsed dict (gemm, tensor, and any extra fields
      such as ``quant_format``) for enabled combinations.

    Args:
        feature_dirs: pytest fixture providing the feature directories to
            register with ``debug_api.initialize``.
    """
    debug_api.initialize(
        config_file=pathlib.Path(__file__).resolve().parent
        / "test_configs/tensor_manipulation_transformer_engine.yaml",
        feature_dirs=feature_dirs,
        log_dir="./log",
    )

    # Ensure global ConfigManager state is reset even if an assert fails,
    # so a failure here cannot leak config into subsequent tests.
    try:
        cfg_fc1 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc1")["transformer_engine"]
        cfg_fc2 = ConfigManager.get_config_for_layer("decoder.1.mlp.fc2")["transformer_engine"]
        assert cfg_fc1 and cfg_fc2

        def parse(cfg, gemm, tensor_name):
            # Every case in this test exercises both gemm and tensor parsing.
            return TEConfigAPIMapper().parse_config_and_api(
                cfg,
                gemm_parsing=True,
                tensor_parsing=True,
                gemm=gemm,
                tensor_name=tensor_name,
            )

        # Per tensor scaling set for dgrad, filter based on gemm
        ret, _ = parse(cfg_fc1["PerTensorScaling"], "wgrad", "activation")
        assert not ret

        # per tensor scaling set for gradient, filter based on tensor name
        ret, _ = parse(cfg_fc1["PerTensorScaling"], "dgrad", "activation")
        assert not ret

        ret, parsed_cfg_fc1 = parse(cfg_fc1["PerTensorScaling"], "dgrad", "gradient")
        assert ret
        assert parsed_cfg_fc1 == {"gemm": "dgrad", "tensor": "gradient"}

        # Test tensor struct
        ret, parsed_cfg_fc1_act = parse(cfg_fc1["FakeQuant"], "fprop", "activation")
        # Bug fix: the original overwrote `ret` with the weight-parse result
        # before asserting, so the activation parse was never checked.
        assert ret
        ret, parsed_cfg_fc1_wei = parse(cfg_fc1["FakeQuant"], "fprop", "weight")
        assert ret
        assert parsed_cfg_fc1_act == {
            "gemm": "fprop",
            "tensor": "activation",
            "quant_format": "FP8E4M3",
        }
        assert parsed_cfg_fc1_wei == {
            "gemm": "fprop",
            "tensor": "weight",
            "quant_format": "FP8E4M3",
        }

        # Test gemms struct
        ret, parsed_cfg_fc2_grad = parse(cfg_fc2["FakeQuant"], "dgrad", "gradient")
        assert ret
        assert parsed_cfg_fc2_grad == {
            "gemm": "dgrad",
            "tensor": "gradient",
            "quant_format": "FP8E5M2",
        }
        ret, parsed_cfg_fc2_wei = parse(cfg_fc2["FakeQuant"], "dgrad", "weight")
        assert ret
        assert parsed_cfg_fc2_wei == {
            "gemm": "dgrad",
            "tensor": "weight",
            "quant_format": "FP8E5M2",
        }

        # Test gemm + tensor struct: all four enabled fc2 combinations parse.
        for gemm, tensor in (
            ("fprop", "activation"),
            ("fprop", "weight"),
            ("wgrad", "activation"),
            ("wgrad", "gradient"),
        ):
            ret, parsed = parse(cfg_fc2["PerTensorScaling"], gemm, tensor)
            assert ret
            assert parsed == {"gemm": gemm, "tensor": tensor}
    finally:
        ConfigManager.reset()