# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections.abc import Iterable
6
import io
7
from typing import Any, Dict, List, Tuple, Union, Optional
8
9
10
11
12
13
14

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
15
16
17
18
19
20
21
from transformer_engine.pytorch.tensor.float8_tensor import (
    Float8Quantizer,
    Float8Tensor,
    Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch
from transformer_engine.pytorch.utils import non_tn_fp8_gemm_supported
22
import transformer_engine_torch as tex
23

24
25
from references.ref_per_tensor_cs import ref_per_tensor_cs_cast

# PyTorch tensor dtypes exercised as the nominal (high-precision) dtype
_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16]
# TE FP8 dtypes exercised as the storage format
_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]

# Numerical tolerances for comparisons involving each FP8 type
_tols: Dict[tex.DType, Dict[str, float]] = {
    tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675),  # epsilon = 0.0625
    tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125),  # epsilon = 0.125
}

37

38
39
40
41
42
43
44
def _to_list(x: Union[Iterable, Any]) -> List:
    """Convert to list if iterable, otherwise put in singleton list"""
    if isinstance(x, Iterable):
        return list(x)
    else:
        return [x]

# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]

# Check if FP8 is supported; the test classes below are skipped otherwise
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

# delayed scaling
def to_float8(
    tensor: torch.Tensor,
    fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
    scale: float = 1.0,
) -> Float8Tensor:
    """Cast tensor to FP8 with a fixed-scale ``Float8Quantizer``.

    Args:
        tensor: High-precision tensor to quantize; it is moved to CUDA first.
        fp8_dtype: Target FP8 storage format.
        scale: Value placed in the quantizer's 1-element FP32 scale tensor.

    Returns:
        The ``Float8Tensor`` produced by the quantizer.
    """
    quantizer = Float8Quantizer(
        scale=torch.full([1], scale, dtype=torch.float32, device="cuda"),
        # Left uninitialized; this helper never reads the amax buffer itself.
        amax=torch.empty([1], dtype=torch.float32, device="cuda"),
        fp8_dtype=fp8_dtype,
    )
    return quantizer(tensor.cuda())


# current scaling
def to_float8_CS(
    tensor: torch.Tensor,
    fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
    return_transpose: bool = False,
    force_pow_2_scales: bool = False,
    amax_epsilon: float = 0.0,
) -> Float8Tensor:
    """Cast tensor to FP8 with a ``Float8CurrentScalingQuantizer``.

    Args:
        tensor: High-precision tensor to quantize; it is moved to CUDA first.
        fp8_dtype: Target FP8 storage format.
        return_transpose: If True, also request columnwise usage from the
            quantizer in addition to rowwise usage.
        force_pow_2_scales: Forwarded to the quantizer.
        amax_epsilon: Forwarded to the quantizer.

    Returns:
        The ``Float8Tensor`` produced by the quantizer.
    """
    tensor = tensor.cuda()
    quantizer = Float8CurrentScalingQuantizer(
        fp8_dtype=fp8_dtype,
        device=tensor.device,
        force_pow_2_scales=force_pow_2_scales,
        amax_epsilon=amax_epsilon,
    )
    # Rowwise data is always produced; columnwise only on request.
    quantizer.set_usage(rowwise=True, columnwise=return_transpose)
    return quantizer(tensor)


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
    """Tests for ``Float8Tensor`` values built with the fixed-scale ``to_float8`` helper."""

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs once before the tests in this class run."""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_constructor(
        self,
        dims: DimsType = 1,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale_inv: float = 0.375,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Call constructor and perform sanity checks"""
        dims = _to_list(dims)
        tensor = Float8Tensor(
            shape=dims,
            dtype=dtype,
            data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
            fp8_dtype=fp8_dtype,
            fp8_scale_inv=torch.full([1], scale_inv),
        )
        assert list(tensor.size()) == dims, "Incorrect dims"
        assert tensor.dtype == dtype, "Incorrect nominal dtype"
        assert tensor.is_cuda, "Incorrect device"

    def _test_quantize_dequantize(
        self,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
        dims: DimsType = 23,
    ) -> None:
        """Check numerical error when casting to FP8 and back"""

        # Initialize random data in [-1, 1)
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1

        # Cast to FP8 and back, then move to CPU to compare with the reference
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_dequant = x_fp8.dequantize().cpu()

        # Check results
        torch.testing.assert_close(x_dequant, x_ref, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_dequant, -x_ref, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_quantize_dequantize_dtypes(
        self,
        fp8_dtype: tex.DType,
        dtype: torch.dtype,
    ) -> None:
        """Round-trip every combination of FP8 and high-precision dtype."""
        self._test_quantize_dequantize(fp8_dtype=fp8_dtype, dtype=dtype)

    @pytest.mark.parametrize("scale", [0.375, 1, 3.5])
    def test_quantize_dequantize_scales(self, scale: float) -> None:
        """Round-trip with several quantization scales."""
        self._test_quantize_dequantize(scale=scale)

    @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
    def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
        """Round-trip tensors of several ranks, including a 0-d tensor."""
        self._test_quantize_dequantize(dims=dims)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize("noop", [True, False])
    def test_quantize_dequantize_noop(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, noop: bool
    ) -> None:
        """Check that in-place ``quantize_`` honors ``noop_flag``.

        With the flag set to 1 the FP8 tensor must be left unchanged; with
        the flag set to 0 it must be requantized from the new input.
        """
        noop_tensor = torch.full([1], 1.0 if noop else 0.0, dtype=torch.float32, device="cuda")
        dims = 23
        scale: float = 3.5

        # Initialize random data
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1

        # Cast to FP8, then try to overwrite it in place with different data
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_ref_noop_test = 2 * x_ref.cuda()
        x_fp8_orig = x_fp8.clone()
        x_fp8.quantize_(x_ref_noop_test, noop_flag=noop_tensor)
        # Branch on the Python bool rather than noop_tensor.item() to avoid a device sync
        if noop:
            # if noop, the output should still be the original quantized tensor
            torch.testing.assert_close(x_fp8, x_fp8_orig, atol=0, rtol=0)
        else:
            torch.testing.assert_close(x_fp8, x_ref_noop_test, **_tols[fp8_dtype])

    def test_basic_ops(
        self,
        dims: DimsType = 23,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test basic out-of-place ops"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale)
        # Use the dequantized values as references so only op error is measured
        x_ref = x_fp8.dequantize()
        y_ref = y_fp8.dequantize()

        # Exact operations
        torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)

        # Operations with numerical error
        tols = _tols[fp8_dtype]
        torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
        torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
        torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)

        # Make sure we are not trivially passing tests
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)

    @pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]])
    def test_chunk_op(
        self,
        dims: DimsType,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test for ops for which shape of inputs and outputs differ."""

        # Initialize random data (the `scale` parameter was previously ignored)
        dims = _to_list(dims)
        x_ref = torch.randn(dims, dtype=dtype, device="cpu")
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)

        # Get chunks.
        chunk1, chunk2 = x_fp8.chunk(2, dim=0)

        # Test chunks.
        torch.testing.assert_close(x_fp8[0 : dims[0] // 2,], chunk1, atol=0, rtol=0)
        torch.testing.assert_close(x_fp8[dims[0] // 2 :,], chunk2, atol=0, rtol=0)

        # Check shapes.
        assert (
            chunk1.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
        ), "Wrong shape for chunk1"
        assert (
            chunk2.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
        ), "Wrong shape for chunk2"

    def test_inplace_ops(
        self,
        dims: DimsType = 23,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test in-place ops"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_ref = x_fp8.dequantize()
        y_ref = y_fp8.dequantize()

        # In-place operations; the reference is refreshed from the FP8 tensor
        # after each step so quantization error does not accumulate across steps
        tols = _tols[fp8_dtype]
        x_fp8 += y_ref
        x_ref += y_ref
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.dequantize()
        x_fp8 -= y_fp8
        x_ref -= y_fp8
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.dequantize()
        x_fp8 *= 2
        x_ref *= 2
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.dequantize()

        # Make sure we are not trivially passing tests
        x_ref += 123
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, x_ref, **tols)

    def test_serialization(
        self,
        dims: DimsType = (2, 3, 5),
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 0.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Round-trip a Float8Tensor through torch.save / torch.load."""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_ref = x_fp8.dequantize()

        # Serialize tensor
        byte_stream = io.BytesIO()
        torch.save(x_fp8, byte_stream)
        x_bytes = byte_stream.getvalue()

        # Mess up and delete old tensor
        x_fp8._data.zero_()
        x_fp8._scale_inv.zero_()
        del x_fp8, byte_stream

        # Deserialize tensor (weights_only=False: the payload is a custom
        # tensor subclass produced by this test, not untrusted input)
        x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False)
        del x_bytes

        # Check results
        tols = dict(rtol=0, atol=0)
        torch.testing.assert_close(x_fp8, x_ref, **tols)

        # Make sure we are not trivially passing tests
        x_fp8._data.zero_()
        x_fp8._scale_inv.zero_()
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, x_ref, **tols)

    def test_set_data(self):
        """Test directly setting .data attr"""

        # Initialize Float8Tensor
        x0 = torch.zeros(4, dtype=torch.float32)
        x = to_float8(x0)
        assert isinstance(x, Float8Tensor)
        assert x0.size() == x.size() == x._data.size()
        assert x.dtype == torch.float32
        assert x.is_cuda and x._data.is_cuda
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert x.size() == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device

        # Set data to plain tensor
        x0 = torch.zeros((3, 2), dtype=torch.float16, device=x.device)
        x.data = x0
        assert isinstance(x, Float8Tensor)
        assert x0.size() == x.size() == x._data.size()
        assert x0.dtype == x.dtype
        assert x0.device == x.device == x._data.device
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert x.size() == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device

        # Set data to Float8Tensor; the underlying buffers must be shared
        x0 = to_float8(torch.zeros((4, 3, 1), dtype=torch.float32))
        x.data = x0
        assert isinstance(x, Float8Tensor)
        assert x0.size() == x.size() == x._data.size()
        assert x0.dtype == x.dtype
        assert x0.device == x.device == x._data.device
        assert x0._data is x._data
        assert x0._scale_inv is x._scale_inv
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert x.size() == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device


@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestCurrentScalingFloat8Tensor:

    @classmethod
    def setup_class(cls) -> None:
        """Seed the CPU and CUDA RNGs once before the tests in this class run."""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    @pytest.mark.parametrize(
        "dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3], [128, 128], [611, 782]]
    )
    @pytest.mark.parametrize("return_transpose", [True, False], ids=str)
    @pytest.mark.parametrize("force_pow_2_scales", [True, False], ids=str)
    @pytest.mark.parametrize("amax_epsilon", [0.0, 1e-6], ids=str)
    def test_quantize(
        self,
        fp8_dtype: tex.DType,
        dtype: torch.dtype,
        dims: DimsType,
        return_transpose: bool,
        force_pow_2_scales: bool,
        amax_epsilon: float,
    ) -> None:
        """Compare current-scaling FP8 quantization bit-exactly with the reference.

        The FP8 data and the inverse scale are always compared. The transposed
        FP8 data is compared too when ``return_transpose`` is requested.
        """
        # Transposed output is skipped on systems that report non-TN FP8 GEMM support
        if non_tn_fp8_gemm_supported() and return_transpose:
            pytest.skip("FP8 transpose is neither needed nor supported on current system")

        # High-precision input drawn uniformly from [-1, 1)
        x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1

        # Both implementations receive exactly the same quantization options
        cs_options = dict(
            fp8_dtype=fp8_dtype,
            return_transpose=return_transpose,
            force_pow_2_scales=force_pow_2_scales,
            amax_epsilon=amax_epsilon,
        )
        x_fp8 = to_float8_CS(x_hp, **cs_options)
        x_fp8_ref, sx_ref, x_fp8_t_ref, _ = ref_per_tensor_cs_cast(x_hp, **cs_options)

        # Results must match bit-for-bit; FP8 payloads are compared as raw bytes
        exact = dict(atol=0.0, rtol=0.0)
        torch.testing.assert_close(x_fp8._data, x_fp8_ref.view(torch.uint8), **exact)
        torch.testing.assert_close(x_fp8._scale_inv, sx_ref, **exact)
        if return_transpose:
            torch.testing.assert_close(
                x_fp8._transpose, x_fp8_t_ref.view(torch.uint8), **exact
            )

    @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3], ids=str)
    @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=str)
    @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
    def test_quantize_dequantize(
        self, fp8_dtype: tex.DType, dtype: torch.dtype, dims: DimsType
    ) -> None:
        """Round-trip data through current-scaling FP8 and check the error."""
        # High-precision input drawn uniformly from [-1, 1)
        x_hp = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cuda") - 1

        # Quantize with current scaling, then dequantize
        x_roundtrip = to_float8_CS(x_hp, fp8_dtype=fp8_dtype).dequantize()

        tols = _tols[fp8_dtype]
        torch.testing.assert_close(x_roundtrip, x_hp, **tols)

        # Guard against a vacuous pass: the negated input must not match
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_roundtrip, -x_hp, **tols)