test_recipe.py 20.2 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
#
# See LICENSE for license information.

5
from typing import Iterable, Optional
6
7
8

import pytest
import torch
9
import warnings
10
11
12

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
13
14
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
15
import transformer_engine_torch as tex
16
17
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
18
    _amax_and_scale_update,
19
    fp8_model_init,
20
)
21
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
22
import transformer_engine.pytorch.ops as te_ops
23
from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear
24
25
from transformer_engine.pytorch.distributed import fp8_autocast
from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling
26
import transformer_engine_torch as tex
27
28
29

# Probe which FP8 flavours the current setup supports. Every probe yields an
# (available, reason) pair; the reason strings feed the pytest skip markers.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
(
    fp8_block_scaling_available,
    reason_for_no_fp8_block_scaling,
) = FP8GlobalStateManager.is_fp8_block_scaling_available()


# FP8 per tensor delayed scaling
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA random number generators for this test class.

        pytest invokes ``setup_class`` once per class before its tests run;
        declaring it as a ``classmethod`` matches the documented xunit-style
        signature, so ``cls`` is bound by Python rather than passed positionally.
        """
        # Fixed seed so randomly initialised modules and inputs are reproducible
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("amax_history_len", [31, 1024])
    @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
    @pytest.mark.parametrize("is_first_microbatch", [None, True, False])
    def test_fp8_scale_update_with_linear_module(
        self,
        amax_history_len: int,
        amax_compute_algo: str,
        is_first_microbatch: Optional[bool],
        margin: int = 2,
    ):
        """Check delayed-scaling amax history and scale updates for ``te.Linear``.

        The module's FP8 meta tensors are overwritten with known values, one
        forward/backward step is run, and the resulting amax histories and
        scales are compared against a reference computed in PyTorch.

        Args:
            amax_history_len: length of the recipe's amax history.
            amax_compute_algo: ``"max"`` or ``"most_recent"``.
            is_first_microbatch: forwarded to the module; when ``False`` the
                weight amax (column 1 of the forward history) is not updated.
            margin: recipe margin; the reference scale is divided by ``2**margin``.
        """

        # Construct linear module
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=amax_history_len,
            amax_compute_algo=amax_compute_algo,
        )
        with te.fp8_autocast(fp8_recipe=recipe):
            module = te.Linear(16, 16)
            y = module(
                torch.randn([16, 16], device="cuda"),
                is_first_microbatch=True,
            )
        y.backward(torch.zeros_like(y))

        # Get amax history and scaling factors. The test relies on these being
        # the module's own meta tensors, so the in-place tweaks below feed the
        # next update and the post-step checks observe it.
        fp8_meta = module.fp8_meta
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
        amax_history_forward = fp8_meta[forward_key].amax_history
        scale_forward = fp8_meta[forward_key].scale
        backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
        amax_history_backward = fp8_meta[backward_key].amax_history
        scale_backward = fp8_meta[backward_key].scale

        # Tweak amax history and scaling factors
        amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
        amax_history_forward[0, :].zero_()
        scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
        amax_history_backward[0, :].zero_()

        # Expected amax history after update: every updated column is rolled
        # back by one row and row 0 is cleared.
        # Note: amax history is only updated when amax is updated, so the
        # weight column (1) is left alone when is_first_microbatch is False.
        update_weight_amax = is_first_microbatch is None or is_first_microbatch
        ref_amax_history_forward = amax_history_forward.clone()
        ref_amax_history_forward[:, 0].copy_(torch.roll(amax_history_forward[:, 0], -1))
        if update_weight_amax:
            ref_amax_history_forward[:, 1].copy_(torch.roll(amax_history_forward[:, 1], -1))
        ref_amax_history_forward[0, :].zero_()
        ref_amax_history_backward = amax_history_backward.clone()
        ref_amax_history_backward[:, 0].copy_(torch.roll(amax_history_backward[:, 0], -1))
        ref_amax_history_backward[0, :].zero_()

        # Perform forward, backward, and optimizer steps to update fp8_meta
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            x = torch.randn([16, 16], device="cuda")
            y = module(x, is_first_microbatch=is_first_microbatch)
        y.backward(torch.randn_like(y))

        # Check that amax history matches expected values. The last row is
        # excluded: the reference rolls the zeroed row 0 into it, whereas the
        # real update presumably places the amax measured during this step there.
        torch.testing.assert_close(
            amax_history_forward[:-1],
            ref_amax_history_forward[:-1],
        )
        torch.testing.assert_close(
            amax_history_backward[:-1],
            ref_amax_history_backward[:-1],
        )

        # Expected scale, derived from the post-update amax history
        if amax_compute_algo == "max":
            ref_amax_forward = amax_history_forward.max(dim=0).values
            ref_amax_backward = amax_history_backward.max(dim=0).values
        elif amax_compute_algo == "most_recent":
            ref_amax_forward = amax_history_forward[-1]
            ref_amax_backward = amax_history_backward[-1]
        else:
            raise ValueError(f"{amax_compute_algo=} is not supported")
        ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
        ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)

        # Check that scale matches expected values
        # Note: scale is only updated when amax is updated
        torch.testing.assert_close(scale_forward[0], ref_scale_forward[0])
        if update_weight_amax:
            torch.testing.assert_close(scale_forward[1], ref_scale_forward[1])
        torch.testing.assert_close(scale_backward[0], ref_scale_backward[0])

    @pytest.mark.parametrize("amax_history_len", [31, 1024])
    @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
    def test_fp8_scale_update_with_linear_fuser_op(
        self,
        amax_history_len: int,
        amax_compute_algo: str,
        margin: float = 2,
        num_steps: int = 4,
        in_shape: tuple[int, ...] = (16, 16),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Check FP8 meta updates of ``te_ops.BasicLinear`` over several steps.

        Each step fills the input, weight, and grad output with a distinct
        constant, so every tensor's amax is known exactly. After each step the
        quantizer scales and amax histories are compared against a reference
        computed in Python from the recorded constants.

        Args:
            amax_history_len: length of the recipe's amax history.
            amax_compute_algo: ``"max"`` or ``"most_recent"``.
            margin: recipe margin; the reference scale is divided by ``2**margin``.
            num_steps: number of training steps to run.
            in_shape: shape of the input and grad output; its last dim is the
                op's in/out feature size.
            dtype: dtype of the input and grad output tensors.
            device: device on which tensors are created.
        """

        # Construct linear op
        op = te_ops.BasicLinear(in_shape[-1], in_shape[-1])

        # FP8 recipe
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
        backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=amax_history_len,
            amax_compute_algo=amax_compute_algo,
        )

        # FP8 max values for each stage, taken from the recipe's format (the
        # same source the linear-module test uses) instead of magic numbers
        max_vals = {
            "forward": fp8_format.value.max_fwd,
            "backward": fp8_format.value.max_bwd,
        }

        def check_metas(
            step: int,
            test_scale: float,
            test_amax_history: torch.Tensor,
            ref_amax_history_list: list[float],
            stage: str,
        ):
            """Check that meta tensors match expected values after ``step``"""

            # Compute amax
            # NOTE(review): this truncation keeps amax_history_len + 1 entries.
            # It is never reached with the default num_steps (4 < 31), so that
            # window size is unverified -- confirm it against the update kernel
            # before running with num_steps > amax_history_len.
            if len(ref_amax_history_list) > amax_history_len:
                ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :]
            ref_amax_history = torch.tensor(
                ref_amax_history_list,
                dtype=torch.float32,
                device=device,
            )
            if amax_compute_algo == "max":
                ref_amax = max(ref_amax_history_list)
            elif amax_compute_algo == "most_recent":
                ref_amax = ref_amax_history_list[-1]
            else:
                raise RuntimeError(f"{amax_compute_algo=} is not supported")

            # Compare amax history: the step + 1 recorded amaxes are expected
            # at the end of the op's history, oldest first
            tols = dict(rtol=0, atol=0)
            torch.testing.assert_close(
                test_amax_history[-(step + 1) :],
                ref_amax_history[: (step + 1)],
                **tols,
            )

            # Compute and compare scale
            ref_scale = (max_vals[stage] / ref_amax) / (2**margin)
            torch.testing.assert_close(
                test_scale,
                ref_scale,
            )

        # Perform training steps
        x_history = []
        w_history = []
        dy_history = []
        for step in range(num_steps):

            # Fill tensors with known values
            x_history.append(step + 0.25)
            w_history.append(step + 0.5)
            dy_history.append(step + 0.75)
            x = torch.full(
                in_shape,
                x_history[-1],
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            dy = torch.full(
                in_shape,
                dy_history[-1],
                dtype=dtype,
                device=device,
            )
            with torch.no_grad():
                op.weight.fill_(w_history[-1])

            # Forward and backward pass
            with te.fp8_autocast(fp8_recipe=recipe):
                y = op(x)
            y.backward(dy)

            # Get scaling factors
            x_test_scale = op.get_quantizer("forward", 0).scale.item()
            w_test_scale = op.get_quantizer("forward", 1).scale.item()
            dy_test_scale = op.get_quantizer("backward", 0).scale.item()

            # Get amax histories
            x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0]
            w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1]
            dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0]

            # Check that results match expected values
            check_metas(step, x_test_scale, x_test_history, x_history, "forward")
            check_metas(step, w_test_scale, w_test_history, w_history, "forward")
            check_metas(step, dy_test_scale, dy_test_history, dy_history, "backward")

    @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
    @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
    @pytest.mark.parametrize(
        "fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=["E4M3", "E5M2"]
    )
    def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype):
        """Check scale updates for degenerate and ordinary amax values.

        A single-entry amax history (zero, a value small enough that the scale
        would exceed the FP32 range, a random positive value, inf, or nan) is
        pushed through either the fused kernel or the non-fused Python update.
        The resulting scale is then compared against the expected value.

        Raises:
            ValueError: for an unsupported ``fp8_dtype`` or ``amax_case``.
        """

        if fp8_dtype == tex.DType.kFloat8E4M3:
            fp8_format = transformer_engine.common.recipe.Format.E4M3
            fp8_max = fp8_format.value.max_fwd
        elif fp8_dtype == tex.DType.kFloat8E5M2:
            fp8_format = transformer_engine.common.recipe.Format.HYBRID
            fp8_max = fp8_format.value.max_bwd
        else:
            raise ValueError(f"{fp8_dtype=} is not supported")

        # NOTE(review): the custom scaling_factor_compute_algo is only set when
        # fused_update is True, yet the fused branch below never passes it to
        # the kernel; only the non-fused `_amax_and_scale_update(..., recipe)`
        # call receives the recipe. Confirm whether this condition is intended.
        scaling_factor_compute_algo = None
        if fused_update:
            scaling_factor_compute_algo = (
                lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
                    amax, scale, fp8_max, recipe.margin
                )
            )
        recipe = transformer_engine.common.recipe.DelayedScaling(
            fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
        )

        # Setup fp8_meta dictionary by running one FP8 step on a small module
        def setup_fp8_meta():
            with te.fp8_autocast(fp8_recipe=recipe):
                module = te.Linear(16, 16)
                y = module(torch.zeros([16, 16], device="cuda"))
            y.backward(torch.zeros_like(y))
            return module.fp8_meta

        fp8_meta = setup_fp8_meta()
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)

        # Replace the fp8_meta[forward_key] with a new TensorMeta for test purpose
        fp8_meta[forward_key] = tex.FP8TensorMeta()
        fp8_meta[forward_key].scale = torch.ones(1, dtype=torch.float32, device="cuda")
        fp8_meta[forward_key].scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")

        # Populate a one-entry amax history for the scenario and the scale the
        # update is expected to produce from it
        if amax_case == "zero":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[0]], dtype=torch.float32, device="cuda"
            )
            # expected to equal the initial scale of 1.0
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        elif amax_case == "tiny":
            # calculate the minimum amax value that results in a FP32 maximum scale
            fp32_max = torch.tensor(torch.finfo(torch.float32).max)
            tiny_amax = fp8_max / fp32_max
            # make the amax less than the minimum amax so that the scale will be infinite
            amax_value = tiny_amax / 2
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[amax_value]], dtype=torch.float32, device="cuda"
            )
            # expected scale is FP32_max
            expected_scale = fp32_max.view(1).cuda()
        elif amax_case == "normal":
            # plus a small epsilon to avoid zero amax
            amax_value = torch.rand(1, dtype=torch.float32, device="cuda") + 1e-5
            fp8_meta[forward_key].amax_history = amax_value.view(1, 1)
            # assumes the recipe's default margin is 0 (no 2**margin factor) -- TODO confirm
            expected_scale = fp8_max / amax_value
        elif amax_case == "inf":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[torch.inf]], dtype=torch.float32, device="cuda"
            )
            # expected to equal the initial scale of 1.0
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        elif amax_case == "nan":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[torch.nan]], dtype=torch.float32, device="cuda"
            )
            # expected to equal the initial scale of 1.0
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        else:
            # Fail loudly instead of hitting an unbound `expected_scale` below
            raise ValueError(f"{amax_case=} is not supported")

        if fused_update:
            tex.fused_amax_and_scale_update_after_reduction(
                fp8_meta[forward_key].amax_history.clone().view(-1),
                [fp8_meta[forward_key].amax_history],
                [fp8_meta[forward_key].scale],
                recipe.amax_compute_algo,
                fp8_dtype,
                recipe.margin,
            )
        else:
            _amax_and_scale_update(
                fp8_meta[forward_key].amax_history,
                fp8_meta[forward_key].scale,
                fp8_max,
                recipe,
            )

        torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)

    @pytest.mark.parametrize(
        "model_init_recipe",
        [
            pytest.param(
                MXFP8BlockScaling(),
                marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
            ),
            pytest.param(
                Float8BlockScaling(),
                marks=pytest.mark.skipif(
                    not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
                ),
            ),
        ],
    )
    def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe):
        """A module initialised under one recipe must reject a mismatching recipe.

        The ``Linear`` layer is built inside ``fp8_model_init`` with
        ``model_init_recipe``; running it under ``DelayedScaling`` is expected
        to raise a ``RuntimeError`` reporting a recipe mismatch.
        """
        with fp8_model_init(enabled=True, recipe=model_init_recipe):
            layer = Linear(32, 32).cuda()

        inp = torch.randn(32, 32, device="cuda")
        mismatched_recipe = DelayedScaling()
        with fp8_autocast(enabled=True, fp8_recipe=mismatched_recipe):
            # `match` performs a regex search on str(exception); the pattern
            # has no metacharacters, so this is a plain substring check
            with pytest.raises(RuntimeError, match="Recipe mismatch for "):
                layer(inp)

    @pytest.mark.parametrize(
        "target_recipe_class, expected_quantizer_type, available_flag, reason",
        [
            pytest.param(
                MXFP8BlockScaling,
                MXFP8Quantizer,
                mxfp8_available,
                reason_for_no_mxfp8,
                id="DelayedScaling->MXFP8BlockScaling",
            ),
            pytest.param(
                Float8BlockScaling,
                Float8BlockQuantizer,
                fp8_block_scaling_available,
                reason_for_no_fp8_block_scaling,
                id="DelayedScaling->Float8BlockScaling",
            ),
        ],
    )
    def test_dynamic_recipe_update(
        self, target_recipe_class, expected_quantizer_type, available_flag, reason
    ):
        """Switch a ``Linear`` from ``DelayedScaling`` to another recipe mid-training.

        After three ``DelayedScaling`` steps the forward quantizers must be
        ``Float8Quantizer``. The first step under the target recipe must warn
        ("Recipe type changed") and leave the forward quantizers of
        ``expected_quantizer_type``. The following steps must emit no warnings.
        """
        if not available_flag:
            pytest.skip(reason)

        in_features = 32
        out_features = 32
        batch_size = 32
        linear = Linear(in_features, out_features).cuda()

        def make_input():
            # Fresh random batch for one training step
            return torch.randn(batch_size, in_features, device="cuda")

        def forward(inp, active_recipe):
            # One FP8 forward pass of the module under `active_recipe`
            with fp8_autocast(enabled=True, fp8_recipe=active_recipe):
                return linear(inp)

        def assert_fwd_quantizers(quantizer_type):
            fwd_quantizers = linear.quantizers["scaling_fwd"]
            assert all(isinstance(q, quantizer_type) for q in fwd_quantizers)

        # Phase 1: train a few steps with DelayedScaling
        initial_recipe = DelayedScaling()
        for _ in range(3):
            out = forward(make_input(), initial_recipe)
            out.mean().backward()
        assert_fwd_quantizers(Float8Quantizer)

        # Phase 2: switch recipes; only the first step may (and must) warn
        target_recipe = target_recipe_class()
        for step in range(3):
            inp = make_input()
            if step == 0:
                with pytest.warns(UserWarning, match="Recipe type changed"):
                    out = forward(inp, target_recipe)
                assert_fwd_quantizers(expected_quantizer_type)
            else:
                with warnings.catch_warnings():
                    warnings.simplefilter("error")  # Raise error if unexpected warning occurs
                    out = forward(inp, target_recipe)
            out.mean().backward()

        # Final check
        assert_fwd_quantizers(expected_quantizer_type)

    @pytest.mark.parametrize(
        "module_class",
        [
            Linear,
            LayerNormLinear,
            LayerNormMLP,
            GroupedLinear,
        ],
    )
    def test_quantizer_update(self, module_class):
        in_features = 32
        out_features = 32
        batch_size = 32

        recipe = DelayedScaling(amax_history_len=1024)
        with fp8_model_init(recipe=recipe):
            if module_class == GroupedLinear:
                module = module_class(1, in_features, out_features).cuda()
            else:
                module = module_class(in_features, out_features).cuda()

        x = torch.randn(batch_size, in_features, device="cuda")
        recipe = DelayedScaling(amax_history_len=1)
        with fp8_autocast(enabled=True, fp8_recipe=recipe):
            warn_msg = "Quantizer is being updated, this may affect model behavior"
            with pytest.warns(UserWarning, match=warn_msg):
                if module_class == GroupedLinear:
                    y = module(x, [batch_size])
                else:
                    y = module(x)