# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

5
from typing import Iterable, Optional
6
7
8
9
10
11

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
12
import transformer_engine_torch as tex
13
14
from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
15
    _amax_and_scale_update,
16
17
    get_default_fp8_recipe,
)
18
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
19
20
import transformer_engine.pytorch.ops as te_ops
import transformer_engine_torch as tex
21
22
23
24

# Check if FP8 is supported. The returned flag and human-readable reason are
# consumed by the `skipif` marker on TestFP8Recipe below.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

25

26
# Tests for FP8 per-tensor delayed scaling
27
28
29
30
31
32
33
34
35
36
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs once for all tests in this class.

        pytest calls ``setup_class`` with the test class before running the
        class's tests; a classmethod is the documented form for xunit-style
        class setup.
        """
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("amax_history_len", [31, 1024])
    @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
    @pytest.mark.parametrize("is_first_microbatch", [None, True, False])
    def test_fp8_scale_update_with_linear_module(
        self,
        amax_history_len: int,
        amax_compute_algo: str,
        is_first_microbatch: Optional[bool],
        margin: int = 2,
    ):
        """Check delayed-scaling amax history and scale updates in ``te.Linear``.

        The module is run once to initialize its FP8 metadata. The metadata is
        then partially overwritten with random values, and a second
        forward/backward pass is run. The resulting amax histories and scales
        are compared against reference values derived from the tweaked
        metadata.

        Column 1 of the forward amax history (tracked here as the weight amax)
        and the corresponding scale are only expected to update when
        ``is_first_microbatch`` is ``None`` or ``True``.
        """

        # Construct linear module
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=amax_history_len,
            amax_compute_algo=amax_compute_algo,
        )
        with te.fp8_autocast(fp8_recipe=recipe):
            module = te.Linear(16, 16)
            y = module(
                torch.randn([16, 16], device="cuda"),
                is_first_microbatch=True,
            )
        y.backward(torch.zeros_like(y))

        # Get amax history and scaling factors
        fp8_meta = module.fp8_meta
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)
        amax_history_forward = fp8_meta[forward_key].amax_history
        scale_forward = fp8_meta[forward_key].scale
        backward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=False)
        amax_history_backward = fp8_meta[backward_key].amax_history
        scale_backward = fp8_meta[backward_key].scale

        # Tweak amax history and scaling factors
        amax_history_forward.copy_(2 * torch.rand_like(amax_history_forward) + 0.5)
        amax_history_forward[0, :].zero_()
        scale_forward.copy_(2 * torch.rand_like(scale_forward) + 0.5)
        amax_history_backward[0, :].zero_()

        # Expected amax history after update
        # Note: amax history is only updated when amax is updated
        update_weight_amax = is_first_microbatch is None or is_first_microbatch
        ref_amax_history_forward = amax_history_forward.clone()
        ref_amax_history_forward[:, 0].copy_(torch.roll(amax_history_forward[:, 0], -1))
        if update_weight_amax:
            ref_amax_history_forward[:, 1].copy_(torch.roll(amax_history_forward[:, 1], -1))
        ref_amax_history_forward[0, :].zero_()
        ref_amax_history_backward = amax_history_backward.clone()
        ref_amax_history_backward[:, 0].copy_(torch.roll(amax_history_backward[:, 0], -1))
        ref_amax_history_backward[0, :].zero_()

        # Perform forward and backward passes to update fp8_meta
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
            x = torch.randn([16, 16], device="cuda")
            y = module(x, is_first_microbatch=is_first_microbatch)
        y.backward(torch.randn_like(y))

        # Check that amax history matches expected values
        # (the last history row is excluded from the comparison)
        torch.testing.assert_close(
            amax_history_forward[:-1],
            ref_amax_history_forward[:-1],
        )
        torch.testing.assert_close(
            amax_history_backward[:-1],
            ref_amax_history_backward[:-1],
        )

        # Expected scales, computed from the amax history after the update
        if amax_compute_algo == "max":
            ref_amax_forward = amax_history_forward.max(dim=0).values
            ref_amax_backward = amax_history_backward.max(dim=0).values
        elif amax_compute_algo == "most_recent":
            ref_amax_forward = amax_history_forward[-1]
            ref_amax_backward = amax_history_backward[-1]
        else:
            raise ValueError(f"{amax_compute_algo=} is not supported")
        ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
        ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)

        # Check that scales match expected values
        # Note: scales are only updated when amax is updated, so the weight
        # scale (index 1) is only checked when its amax was updated
        torch.testing.assert_close(
            scale_forward[0],
            ref_scale_forward[0],
        )
        if update_weight_amax:
            torch.testing.assert_close(
                scale_forward[1],
                ref_scale_forward[1],
            )
        torch.testing.assert_close(
            scale_backward[0],
            ref_scale_backward[0],
        )

    @pytest.mark.parametrize("amax_history_len", [31, 1024])
    @pytest.mark.parametrize("amax_compute_algo", ["max", "most_recent"])
    def test_fp8_scale_update_with_linear_fuser_op(
        self,
        amax_history_len: int,
        amax_compute_algo: str,
        margin: float = 2,
        num_steps: int = 4,
        in_shape: tuple[int, ...] = (16, 16),
        dtype: torch.dtype = torch.float32,
        device: torch.device = "cuda",
    ):
        """Check delayed-scaling scale updates in the ``te_ops.BasicLinear`` op.

        Each step fills the input, weight, and output gradient with a known
        constant, runs forward and backward under ``fp8_autocast``, and checks
        that the input, weight, and grad-output quantizer scales match values
        computed from the history of those constants.
        """

        # Construct linear op
        op = te_ops.BasicLinear(in_shape[-1], in_shape[-1])

        # FP8 recipe
        fp8_format = transformer_engine.common.recipe.Format.HYBRID
        recipe = transformer_engine.common.recipe.DelayedScaling(
            margin=margin,
            fp8_format=fp8_format,
            amax_history_len=amax_history_len,
            amax_compute_algo=amax_compute_algo,
        )

        # Get quantizers for input, weight, and grad output
        with te.fp8_autocast(fp8_recipe=recipe):
            x_quantizer = op.get_quantizer("forward", 0)
            w_quantizer = op.get_quantizer("forward", 1)
            dy_quantizer = op.get_quantizer("backward", 0)

        # FP8 max values per stage for the HYBRID format
        max_vals = {
            "forward": fp8_format.value.max_fwd,
            "backward": fp8_format.value.max_bwd,
        }

        def check_scale(
            quantizer: Float8Quantizer,
            ref_amax_history: list[float],
            stage: str,
        ) -> None:
            """Check that the quantizer's scale matches the expected value"""

            # Compute amax
            if len(ref_amax_history) > amax_history_len:
                ref_amax_history = ref_amax_history[-(amax_history_len + 1) :]
            if amax_compute_algo == "max":
                ref_amax = max(ref_amax_history)
            elif amax_compute_algo == "most_recent":
                ref_amax = ref_amax_history[-1]
            else:
                raise RuntimeError(f"{amax_compute_algo=} is not supported")

            # Compute scale
            ref_scale = (max_vals[stage] / ref_amax) / (2**margin)

            # Check value held by the quantizer
            torch.testing.assert_close(
                quantizer.scale.item(),
                ref_scale,
            )

        # Perform training steps
        x_history = []
        w_history = []
        dy_history = []
        for step in range(num_steps):

            # Fill tensors with known values
            x_history.append(step + 0.25)
            w_history.append(step + 0.5)
            dy_history.append(step + 0.75)
            x = torch.full(
                in_shape,
                x_history[-1],
                dtype=dtype,
                device=device,
                requires_grad=True,
            )
            dy = torch.full(
                in_shape,
                dy_history[-1],
                dtype=dtype,
                device=device,
            )
            with torch.no_grad():
                op.weight.fill_(w_history[-1])

            # Forward and backward pass
            with te.fp8_autocast(fp8_recipe=recipe):
                y = op(x)
            y.backward(dy)

            # Check that results match expected values
            check_scale(x_quantizer, x_history, "forward")
            check_scale(w_quantizer, w_history, "forward")
            check_scale(dy_quantizer, dy_history, "backward")

    @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"])
    @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"])
    @pytest.mark.parametrize(
        "fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2], ids=["E4M3", "E5M2"]
    )
    def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype):
        """Check scale updates for normal and degenerate single-entry amax histories.

        Starting from a scale of 1:
        - a zero, infinite, or NaN amax is expected to leave the scale at 1;
        - an amax small enough that ``fp8_max / amax`` overflows FP32 is
          expected to clamp the scale to the FP32 maximum;
        - a normal amax is expected to give ``fp8_max / amax``.

        Both ``tex.fused_amax_and_scale_update_after_reduction`` (fused) and
        ``_amax_and_scale_update`` (non-fused) are exercised.
        """

        if fp8_dtype == tex.DType.kFloat8E4M3:
            fp8_format = transformer_engine.common.recipe.Format.E4M3
            fp8_max = fp8_format.value.max_fwd
        elif fp8_dtype == tex.DType.kFloat8E5M2:
            fp8_format = transformer_engine.common.recipe.Format.HYBRID
            fp8_max = fp8_format.value.max_bwd
        else:
            raise ValueError(f"{fp8_dtype=} is not supported")

        # NOTE(review): a custom scaling_factor_compute_algo (wrapping
        # te.fp8._default_sf_compute with recipe.margin) is only set for the
        # fused case; the explicit fused update below does not read it —
        # confirm this is intended.
        scaling_factor_compute_algo = None
        if fused_update:
            scaling_factor_compute_algo = (
                lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
                    amax, scale, fp8_max, recipe.margin
                )
            )
        recipe = transformer_engine.common.recipe.DelayedScaling(
            fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
        )

        # Setup fp8_meta dictionary
        def setup_fp8_meta():
            with te.fp8_autocast(fp8_recipe=recipe):
                module = te.Linear(16, 16)
                y = module(torch.zeros([16, 16], device="cuda"))
            y.backward(torch.zeros_like(y))
            return module.fp8_meta

        fp8_meta = setup_fp8_meta()
        forward_key = FP8GlobalStateManager.get_meta_tensor_key(forward=True)

        # Replace the fp8_meta[forward_key] with a new TensorMeta for test purpose
        fp8_meta[forward_key] = tex.FP8TensorMeta()
        fp8_meta[forward_key].scale = torch.ones(1, dtype=torch.float32, device="cuda")
        fp8_meta[forward_key].scale_inv = torch.ones(1, dtype=torch.float32, device="cuda")

        # test different scenarios
        if amax_case == "zero":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[0]], dtype=torch.float32, device="cuda"
            )
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        elif amax_case == "tiny":
            # calculate the minimum amax value that results in a FP32 maximum scale
            fp32_max = torch.tensor(torch.finfo(torch.float32).max)
            tiny_amax = fp8_max / fp32_max
            # make the amax less than the minimum amax so that the scale will be infinite
            amax_value = tiny_amax / 2
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[amax_value]], dtype=torch.float32, device="cuda"
            )
            # expected scale is FP32_max
            expected_scale = fp32_max.view(1).cuda()
        elif amax_case == "normal":
            # plus a small epsilon to avoid zero amax
            amax_value = torch.rand(1, dtype=torch.float32, device="cuda") + 1e-5
            fp8_meta[forward_key].amax_history = amax_value.view(1, 1)
            expected_scale = fp8_max / amax_value
        elif amax_case == "inf":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[torch.inf]], dtype=torch.float32, device="cuda"
            )
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        elif amax_case == "nan":
            fp8_meta[forward_key].amax_history = torch.tensor(
                [[torch.nan]], dtype=torch.float32, device="cuda"
            )
            expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
        else:
            # Fail loudly instead of hitting an unbound `expected_scale` below
            raise ValueError(f"{amax_case=} is not supported")

        if fused_update:
            tex.fused_amax_and_scale_update_after_reduction(
                fp8_meta[forward_key].amax_history.clone().view(-1),
                [fp8_meta[forward_key].amax_history],
                [fp8_meta[forward_key].scale],
                recipe.amax_compute_algo,
                fp8_dtype,
                recipe.margin,
            )
        else:
            _amax_and_scale_update(
                fp8_meta[forward_key].amax_history,
                fp8_meta[forward_key].scale,
                fp8_max,
                recipe,
            )

        torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)