"tests/cpp/vscode:/vscode.git/clone" did not exist on "fbee89906a990fa640593b83739345e1f81099dc"
test_api_features.py 13.5 KB
Newer Older
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import torch

import nvdlfw_inspect.api as debug_api

# Import TransformerEngine inside the guard so a missing package produces the
# message below instead of a bare ImportError.
try:
    import transformer_engine
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer
except ImportError:
    print("Could not find TransformerEngine package.")
    exit(1)


def test_transformer_engine_no_config(feature_dirs):
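    """With an empty config path, every query should fall back to the
    defaults: FP8 GEMMs enabled, modify_tensor and inspect_tensor disabled."""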
    debug_api.initialize("", feature_dirs=feature_dirs)
    try:

        tensor = torch.rand(24, 2046).cuda()

        # fp8_gemm_enabled - True by default
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )[0]

        # modify_tensor_enabled - (False, None) by default
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
        )[0]

        # inspect_tensor_enabled - (False, None) by default
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.attn.qkv", tensor_name="activation", iteration=0
        )[0]

    finally:
        debug_api.end_debug()


def test_disable_fp8_gemm(configs_dir, feature_dirs):
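    """Per the assertions below, disable_fp8_gemms.yaml keeps fprop in FP8 for
    decoder.1.attn.qkv but disables FP8 for its dgrad and wgrad GEMMs. The
    queries are then repeated to exercise the caching path."""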
    try:
        debug_api.initialize(configs_dir + "disable_fp8_gemms.yaml", feature_dirs=feature_dirs)

        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
        )[0]

        # caching - repeating the same queries should hit the cache and
        # return identical results
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
        )[0]

    finally:
        debug_api.end_debug()


def test_disable_fp8_layer(configs_dir, feature_dirs):
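    """Per the assertions below, disable_fp8_layer.yaml disables FP8 for all
    three GEMMs of decoder.1.attn.qkv while decoder.1.mlp.fc1 keeps FP8."""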
    try:
        debug_api.initialize(configs_dir + "disable_fp8_layer.yaml", feature_dirs=feature_dirs)

        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", iteration=0
        )[0]
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", iteration=0
        )[0]
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", iteration=0
        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="fprop", iteration=0
        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="wgrad", iteration=0
        )[0]
        assert not debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.attn.qkv", gemm="dgrad", iteration=0
        )[0]

    finally:
        debug_api.end_debug()


def test_per_tensor_scaling(configs_dir, feature_dirs):
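    """Per the assertions below, per_tensor_scaling.yaml enables modify_tensor
    only for selected (gemm, tensor) pairs of decoder.1.mlp.fc1; modify_tensor
    should return a Float8Tensor with the dtype of the quantizer passed in."""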
    try:

        debug_api.initialize(configs_dir + "per_tensor_scaling.yaml", feature_dirs=feature_dirs)

        tensor = torch.rand(24, 2046).cuda()

        # check modify_tensor_enabled
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
        )[0]
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0
        )[0]
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
        )[0]
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0
        )[0]
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0
        )[0]
        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0
        )[0]

        # check modify_tensor
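        # Two FP8 quantizers with distinct dtypes are defined below, so the
        # assertions can verify that each output carries the dtype of the
        # default quantizer passed in.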

        default_quantizer1 = Float8Quantizer(
            scale=torch.tensor([1]).cuda(),
            amax=torch.tensor([0]).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        default_quantizer2 = Float8Quantizer(
            scale=torch.tensor([1]).cuda(),
            amax=torch.tensor([0]).cuda(),
            fp8_dtype=tex.DType.kFloat8E5M2,
        )

        output1 = debug_api.transformer_engine.modify_tensor(
            layer_name="decoder.1.mlp.fc1",
            gemm="fprop",
            tensor_name="activation",
            default_quantizer=default_quantizer1,
            iteration=0,
            tensor=tensor,
        )
        assert type(output1) is Float8Tensor
        assert output1._fp8_dtype == tex.DType.kFloat8E4M3

        output2 = debug_api.transformer_engine.modify_tensor(
            "decoder.1.mlp.fc1",
            gemm="dgrad",
            tensor=tensor,
            tensor_name="gradient",
            default_quantizer=default_quantizer2,
            iteration=0,
        )
        assert type(output2) is Float8Tensor
        assert output2._fp8_dtype == tex.DType.kFloat8E5M2

        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1",
            gemm="wgrad",
            tensor_name="gradient",
            iteration=0,
        )[0]

        assert not debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc4",
            gemm="fprop",
            tensor_name="activation",
            iteration=0,
        )[0]
    finally:
        debug_api.end_debug()


def test_fake_quant(configs_dir, feature_dirs):
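    """Per the assertions below, fake_quantization_config.yaml enables
    modify_tensor for the fprop activation and dgrad gradient of
    decoder.1.mlp.fc1, while FP8 GEMMs elsewhere stay enabled."""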
    try:
        debug_api.initialize(
            configs_dir + "fake_quantization_config.yaml", feature_dirs=feature_dirs
        )

        tensor = torch.rand(24, 2046).cuda()

        # modify_tensor_enabled
        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0
        )[0]

        assert debug_api.transformer_engine.modify_tensor_enabled(
            "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0
        )[0]

        # modify_tensor - the configured calls should run without error
        # (the outputs are not checked here)
        debug_api.transformer_engine.modify_tensor(
            "decoder.1.mlp.fc1",
            gemm="fprop",
            tensor=tensor,
            tensor_name="activation",
            iteration=0,
            default_quantizer=None,
        )

        debug_api.transformer_engine.modify_tensor(
            "decoder.1.mlp.fc1",
            gemm="dgrad",
            tensor=tensor,
            tensor_name="gradient",
            iteration=0,
            default_quantizer=None,
        )

        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.fc2", gemm="wgrad", iteration=0
        )[0]
        # caching - the repeated query should hit the cache and return the
        # same result
        assert debug_api.transformer_engine.fp8_gemm_enabled(
            "decoder.1.fc2", gemm="wgrad", iteration=0
        )[0]
    finally:
        debug_api.end_debug()


def test_statistics_collection(configs_dir, feature_dirs):
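    """Checks the stats returned by inspect_tensor against values computed
    directly on the inputs (cur_amax, underflows%, mean, max), and that
    inspect_tensor_enabled is False outside the configured layers and
    iterations."""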
    try:
        debug_api.initialize(
            config_file=configs_dir + "stats_collection_test_config.yaml",
            feature_dirs=feature_dirs,
            default_logging_enabled=False,
        )

        tensor = torch.randn((100, 100, 5)).cuda()
        quantizer = Float8Quantizer(
            scale=torch.full([1], 1.0).cuda(),
            amax=torch.full([1], 1.0).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
        tensor_fp8 = quantizer(tensor)

        def log():
            from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS

            return STATS_BUFFERS.log_stats()

        def assert_empty():
            stats = log()
            assert len(stats) == 0

        # TE tensor stats --
        debug_api.transformer_engine.inspect_tensor(
            "decoder.1.mlp.fc1",
            tensor=tensor,
            tensor_name="activation",
            iteration=200,
            tp_group=None,
            quantizer=quantizer,
            rowwise_quantized_tensor=tensor_fp8,
            columnwise_quantized_tensor=tensor_fp8,
        )
        stats = log()
        assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
        )[0]
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.2.mlp.fc1", tensor_name="activation", iteration=200
        )[0]

        # underflows% = percentage of elements that quantized to zero in FP8
        # but were nonzero in the original tensor
        expected_underflows = (
            ((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
        )

        # TE FP8 tensor stats --
        assert debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
        )[0]
        debug_api.transformer_engine.inspect_tensor(
            "decoder.1.mlp.fc1",
            tensor_name="gradient",
            iteration=200,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=tensor_fp8,
            columnwise_quantized_tensor=tensor_fp8,
        )
        stats = log()
        torch.testing.assert_close(
            stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
        )

        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.1.mlp.fc1", tensor_name="activation", iteration=201
        )[0]
        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.2.mlp.fc1", tensor_name="gradient", iteration=200
        )[0]

        # Second config in same yaml
        tensor = torch.rand((100, 100, 5))
        debug_api.transformer_engine.inspect_tensor(
            "decoder.6.mlp.fc1",
            tensor_name="activation",
            iteration=200,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=tensor_fp8,
            columnwise_quantized_tensor=tensor_fp8,
        )
        stats = log()
        # the stat name is the third element of the (layer, tensor, stat, iteration) key
        stats_names = [x[2] for x in stats.keys()]
        assert all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"])
        assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()

        debug_api.transformer_engine.inspect_tensor(
            "decoder.7.mlp.fc1",
            tensor_name="weight",
            iteration=200,
            tp_group=None,
            tensor=tensor,
            quantizer=quantizer,
            rowwise_quantized_tensor=tensor_fp8,
            columnwise_quantized_tensor=tensor_fp8,
        )
        stats = log()
        stats_names = [x[2] for x in stats.keys()]
        assert all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
        assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()

        assert not debug_api.transformer_engine.inspect_tensor_enabled(
            "decoder.7.mlp.fc1", tensor_name="weight", iteration=201
        )[0]
        assert_empty()

    finally:
        debug_api.end_debug()


def test_statistics_multi_run(configs_dir, feature_dirs):
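    """Stats accumulated over two inspect_tensor calls in one logging window
    should equal the stats of the concatenated tensor fed in a single call."""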
    try:
        debug_api.initialize(
            config_file=configs_dir + "stats_collection_test_config.yaml",
            feature_dirs=feature_dirs,
            default_logging_enabled=False,
        )

        def feed(tensor, tensor_fp8, quantizer):
            debug_api.transformer_engine.inspect_tensor(
                "decoder.5.mlp.fc1",
                tensor=tensor,
                tensor_name="activation",
                iteration=1,
                tp_group=None,
                quantizer=quantizer,
                rowwise_quantized_tensor=tensor_fp8,
                columnwise_quantized_tensor=tensor_fp8,
            )

        def log_stats():
            from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS

            return STATS_BUFFERS.log_stats()

        quantizer = Float8Quantizer(
            scale=torch.full([1], 1.0).cuda(),
            amax=torch.full([1], 1.0).cuda(),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )

        def fp8_tensor(t):
            return quantizer(t.cuda())

        shape = [1024, 1024]
        tensors = [torch.randn(shape) for _ in range(2)]
        tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]

        feed(tensors[0], tensors_fp8[0], quantizer)
        feed(tensors[1], tensors_fp8[1], quantizer)
        stats1 = log_stats()

        tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
        fp8tensor2 = fp8_tensor(tensor2)
        feed(tensor2, fp8tensor2, quantizer)
        stats2 = log_stats()

        assert len(stats1.keys()) > 0
        for k in stats1.keys():
            torch.testing.assert_close(stats1[k], stats2[k])
    finally:
        debug_api.end_debug()


if __name__ == "__main__":
    pass