test_api_features.py 13.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer

import nvdlfw_inspect.api as debug_api

try:
    import transformer_engine
    import transformer_engine_torch as tex
except (ImportError, ModuleNotFoundError):
    print("Could not find TransformerEngine package.")
    exit(1)


def test_transformer_engine_no_config(feature_dirs):
    """Check the debug API's default answers when no config file is given.

    With an empty config path, FP8 GEMMs are expected to stay enabled while
    tensor modification, tensor inspection and post-quantization inspection
    are all expected to be disabled.
    """
    debug_api.initialize("", feature_dirs=feature_dirs)
    try:
        # FP8 enabled - true by default
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )

        # modify_tensor_enabled - False by default
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
        )

        # inspect_tensor_enabled - False by default
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.attn.qkv", tensor_name="activation", iteration=0
        )

        # inspect_tensor_postquantize - False by default
        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
        )

    finally:
        debug_api.end_debug()


def test_disable_fp8_gemm(configs_dir, feature_dirs):
    """Under disable_fp8_gemms.yaml only the qkv fprop GEMM keeps FP8.

    Every query is issued twice: the second pass checks that repeated
    (cached) answers agree with the first ones.
    """
    qkv_layer = "decoder.1.attn.qkv"
    # Insertion order (fprop, dgrad, wgrad) fixes the order of API queries.
    fp8_by_gemm = {"fprop": True, "dgrad": False, "wgrad": False}
    try:
        debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs)

        for _round in ("first", "caching"):
            for gemm_kind, want_fp8 in fp8_by_gemm.items():
                assert (
                    bool(
                        debug_api.transformer_engine.fp8_gemm_enabled(
                            qkv_layer, gemm=gemm_kind, iteration=0
                        )
                    )
                    is want_fp8
                )

    finally:
        debug_api.end_debug()


def test_disable_fp8_layer(configs_dir, feature_dirs):
    """disable_fp8_layer.yaml disables FP8 for all GEMMs of the qkv layer only.

    All three GEMMs of decoder.1.mlp.fc1 must report FP8 enabled, and all
    three GEMMs of decoder.1.attn.qkv must report it disabled.
    """
    gemm_kinds = ("fprop", "wgrad", "dgrad")
    # (layer name, expected fp8_gemm_enabled for every GEMM of that layer)
    per_layer_expectation = (
        ("decoder.1.mlp.fc1", True),
        ("decoder.1.attn.qkv", False),
    )
    try:
        debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs)

        for layer_name, want_fp8 in per_layer_expectation:
            for gemm_kind in gemm_kinds:
                assert (
                    bool(
                        debug_api.transformer_engine.fp8_gemm_enabled(
                            layer_name, gemm=gemm_kind, iteration=0
                        )
                    )
                    is want_fp8
                )

    finally:
        debug_api.end_debug()


def test_per_tensor_scaling(configs_dir, feature_dirs):
    """Exercise the feature configured in per_tensor_scaling.yaml.

    First checks which (gemm, tensor) pairs of decoder.1.mlp.fc1 report
    modify_tensor_enabled. Then checks that modify_tensor returns a
    Float8Tensor whose FP8 dtype matches the default quantizer passed in.
    """
    fc1 = "decoder.1.mlp.fc1"
    # (gemm, tensor_name, expected modify_tensor_enabled) for fc1 at iteration 0.
    enabled_table = (
        ("fprop", "activation", True),
        ("fprop", "weight", True),
        ("dgrad", "gradient", True),
        ("dgrad", "weight", False),
        ("wgrad", "gradient", False),
        ("wgrad", "activation", False),
    )

    def modify_enabled(layer_name, gemm, tensor_name):
        # Truthiness of modify_tensor_enabled for the given query at iteration 0.
        return bool(
            debug_api.transformer_engine.modify_tensor_enabled(
                layer_name, gemm=gemm, tensor_name=tensor_name, iteration=0
            )
        )

    def make_quantizer(fp8_dtype):
        # Unit scale and zero amax, both moved to the GPU.
        return Float8Quantizer(
            scale=torch.tensor([1]).cuda(),
            amax=torch.tensor([0]).cuda(),
            fp8_dtype=fp8_dtype,
        )

    try:

        debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs)

        tensor = torch.rand(24, 2046).cuda()

        for gemm, tensor_name, expected in enabled_table:
            assert modify_enabled(fc1, gemm, tensor_name) is expected

        e4m3_quantizer = make_quantizer(tex.DType.kFloat8E4M3)
        e5m2_quantizer = make_quantizer(tex.DType.kFloat8E5M2)

        fprop_out = debug_api.transformer_engine.modify_tensor(
            layer_name=fc1,
            gemm="fprop",
            tensor_name="activation",
            default_quantizer=e4m3_quantizer,
            iteration=0,
            tensor=tensor,
        )
        assert type(fprop_out) is Float8Tensor
        assert fprop_out._fp8_dtype == tex.DType.kFloat8E4M3

        dgrad_out = debug_api.transformer_engine.modify_tensor(
            fc1,
            gemm="dgrad",
            tensor=tensor,
            tensor_name="gradient",
            default_quantizer=e5m2_quantizer,
            iteration=0,
        )
        assert type(dgrad_out) is Float8Tensor
        assert dgrad_out._fp8_dtype == tex.DType.kFloat8E5M2

        # Re-check fc1's wgrad gradient, then a layer (fc4) expected not to be modified.
        assert modify_enabled(fc1, "wgrad", "gradient") is False
        assert modify_enabled("decoder.1.mlp.fc4", "fprop", "activation") is False
    finally:
        debug_api.end_debug()


def test_fake_quant(configs_dir, feature_dirs):
    """Smoke-test the feature configured in fake_quantization_config.yaml.

    Checks the following:
    - fc1's fprop activation and dgrad gradient must report
      modify_tensor_enabled.
    - modify_tensor must run on them with no default quantizer.
    - decoder.1.fc2 must report FP8 wgrad enabled; it is queried twice to
      cover caching.
    """
    fc1 = "decoder.1.mlp.fc1"
    # (gemm, tensor_name) pairs of fc1 that are expected to be modified.
    modified_pairs = (("fprop", "activation"), ("dgrad", "gradient"))
    try:
        debug_api.initialize(
            configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs
        )

        tensor = torch.rand(24, 2046).cuda()

        for gemm, tensor_name in modified_pairs:
            assert debug_api.transformer_engine.modify_tensor_enabled(
                fc1, gemm=gemm, tensor_name=tensor_name, iteration=0
            )

        for gemm, tensor_name in modified_pairs:
            # The result is not inspected; the call only has to succeed.
            debug_api.transformer_engine.modify_tensor(
                fc1,
                gemm=gemm,
                tensor=tensor,
                tensor_name=tensor_name,
                iteration=0,
                default_quantizer=None,
            )

        # Second query exercises caching.
        for _ in range(2):
            assert debug_api.transformer_engine.fp8_gemm_enabled(
                "decoder.1.fc2", gemm="wgrad", iteration=0
            )
    finally:
        debug_api.end_debug()


def test_statistics_collection(configs_dir, feature_dirs):
    """Check stats produced by inspect_tensor / inspect_tensor_postquantize.

    Uses stats_collection_test_config.yaml. The stats returned by
    ``STATS_BUFFERS.log_stats()`` are keyed by
    ``(layer_name, tensor_name, stat_name, iteration)`` tuples.
    """
    try:
        debug_api.initialize(
            config_file=configs_dir + "stats_collection_test_config.yaml",
            feature_dirs=feature_dirs,
            default_logging_enabled=False,
        )

        tensor = torch.randn((100, 100, 5)).cuda()
        # FP8 tensor built from a uint8 cast of `tensor`, with unit scale_inv.
        tensor_fp8 = Float8Tensor(
            data=tensor.to(torch.uint8).cuda(),
            fp8_scale_inv=torch.full([1], 1.0).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
            shape=tensor.shape,
            dtype=torch.float32,
        )

        def log():
            # Return the stats gathered by the debug stats buffers.
            from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS

            return STATS_BUFFERS.log_stats()

        def assert_empty():
            stats = log()
            assert len(stats) == 0

        # TE tensor stats --
        debug_api.transformer_engine.inspect_tensor(
            "decoder.1.mlp.fc1",
            tensor=tensor,
            tensor_name="activation",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
        )
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.2.mlp.fc1", tensor_name="activation", iteration=200
        )
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
        )

        # Percentage of FP8 data elements that are exactly zero.
        expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)

        # TE FP8 tensor stats --
        assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
        )
        debug_api.transformer_engine.inspect_tensor_postquantize(
            "decoder.1.mlp.fc1",
            tensor=tensor_fp8,
            tensor_name="gradient",
            iteration=200,
            rowwise=True,
            tp_group=None,
        )
        stats = log()
        torch.testing.assert_close(
            stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
        )

        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
        )
        assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
            "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
        )

        # Second config in same yaml
        tensor = torch.rand((100, 100, 5))
        debug_api.transformer_engine.inspect_tensor(
            "decoder.6.mlp.fc1",
            tensor=tensor,
            tensor_name="activation",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        # The stat name is element 2 of each key (element 3 is the iteration),
        # and the membership check must be asserted or it does nothing.
        stats_names = {key[2] for key in stats}
        assert all(
            s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"]
        )
        assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()

        debug_api.transformer_engine.inspect_tensor(
            "decoder.7.mlp.fc1",
            tensor=tensor,
            tensor_name="weight",
            iteration=200,
            tp_group=None,
        )
        stats = log()
        stats_names = {key[2] for key in stats}
        assert all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
        assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()

        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.7.mlp.fc1", tensor_name="weight", iteration=201
        )
        assert_empty()

    finally:
        debug_api.end_debug()


def test_statistics_multi_run(configs_dir, feature_dirs):
    """Stats from two separate feeds must equal stats from one concatenated feed.

    The test works as follows:
    - Two random 1024x1024 tensors are fed separately for decoder.5.mlp.fc1
      at iteration 1, and the stats are logged.
    - Their concatenation is then fed once and the stats are logged again.
    - Every stat from the first log must match the same key in the second.
    """
    layer = "decoder.5.mlp.fc1"

    def log_stats():
        from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS

        return STATS_BUFFERS.log_stats()

    def fp8_tensor(t):
        # Wrap a uint8 cast of t as E4M3 FP8 data with unit scale_inv on the GPU.
        return Float8Tensor(
            data=t.to(torch.uint8).cuda(),
            fp8_scale_inv=torch.ones([1]).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
            shape=t.shape,
            dtype=torch.float32,
        )

    def feed(hp_tensor, q_tensor):
        # Report the high-precision and the FP8 tensor for the same iteration.
        debug_api.transformer_engine.inspect_tensor(
            layer,
            tensor=hp_tensor,
            tensor_name="activation",
            iteration=1,
            tp_group=None,
        )
        debug_api.transformer_engine.inspect_tensor_postquantize(
            layer,
            tensor=q_tensor,
            tensor_name="activation",
            iteration=1,
            rowwise=True,
            tp_group=None,
        )

    try:
        debug_api.initialize(
            config_file=configs_dir + "stats_collection_test_config.yaml",
            feature_dirs=feature_dirs,
            default_logging_enabled=False,
        )

        shape = [1024, 1024]
        first, second = torch.randn(shape), torch.randn(shape)
        halves = (first, second)
        halves_fp8 = [fp8_tensor(half) for half in halves]

        for hp_half, fp8_half in zip(halves, halves_fp8):
            feed(hp_half, fp8_half)
        split_stats = log_stats()

        joined = torch.cat(halves).cuda()
        feed(joined, fp8_tensor(joined))
        joined_stats = log_stats()

        assert len(split_stats) > 0
        for key in split_stats:
            torch.testing.assert_close(split_stats[key], joined_stats[key])
    finally:
        debug_api.end_debug()


# The test functions above take fixture arguments (configs_dir, feature_dirs),
# so they are presumably collected by pytest; running this module directly
# intentionally does nothing.
if __name__ == "__main__":
    pass