# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from collections.abc import Iterable
6
import io
7
8
9
10
11
12
13
14
from typing import Any, Dict, List, Tuple, Union

import pytest
import torch

import transformer_engine.common.recipe
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
15
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8Tensor
16
import transformer_engine_torch as tex
17
18
19
20
21
22
23
24
25
26
27
28

# Nominal (high-precision) PyTorch dtypes exercised by the tests
_dtypes: List[torch.dtype] = [
    torch.float32,
    torch.float16,
    torch.bfloat16,
]
# Transformer Engine FP8 storage formats exercised by the tests
_fp8_dtypes: List[tex.DType] = [
    tex.DType.kFloat8E4M3,
    tex.DType.kFloat8E5M2,
]

# Keyword arguments for torch.testing.assert_close, keyed by FP8 format
_tols: Dict[tex.DType, Dict[str, float]] = {
    tex.DType.kFloat8E4M3: {"rtol": 0.125, "atol": 0.0675},  # epsilon = 0.0625
    tex.DType.kFloat8E5M2: {"rtol": 0.25, "atol": 0.125},  # epsilon = 0.125
}

29

30
31
32
33
34
35
36
def _to_list(x: Union[Iterable, Any]) -> List:
    """Convert to list if iterable, otherwise put in singleton list"""
    if isinstance(x, Iterable):
        return list(x)
    else:
        return [x]

37

38
39
40
41
42
43
# Types that can be interpreted as tensor dims: either a single int or
# an iterable of ints. Tests normalize these values with `_to_list`.
DimsType = Union[Iterable[int], int]

# Check if FP8 is supported. The flag and the accompanying reason string
# are used in the `pytest.mark.skipif` marker on the test class.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

44

45
46
47
48
49
50
51
52
53
54
55
56
57
58
def to_float8(
    tensor: torch.Tensor,
    fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
    scale: float = 1.0,
) -> Float8Tensor:
    """Quantize a tensor to FP8 with a fixed scaling factor

    The input is moved to the CUDA device and passed through a freshly
    constructed ``Float8Quantizer``.
    """
    # Single-element FP32 buffers handed to the quantizer
    scale_buf = torch.full([1], scale, dtype=torch.float32, device="cuda")
    amax_buf = torch.empty([1], dtype=torch.float32, device="cuda")
    quantize = Float8Quantizer(scale=scale_buf, amax=amax_buf, fp8_dtype=fp8_dtype)
    return quantize(tensor.cuda())


59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:

    @classmethod
    def setup_class(cls) -> None:
        """Class-level setup hook: seed the CPU and CUDA random number generators"""
        # Configure RNG
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_constructor(
        self,
        dims: DimsType = 1,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale_inv: float = 0.375,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Call constructor and perform sanity checks"""
        dims = _to_list(dims)

        # Construct directly from raw uint8 data and a scale-inverse tensor
        # NOTE(review): the scale-inverse tensor is created without an
        # explicit device while the data lives on CUDA -- confirm this
        # is intended.
        tensor = Float8Tensor(
            shape=dims,
            dtype=dtype,
            data=torch.zeros(dims, device="cuda", dtype=torch.uint8),
            fp8_dtype=fp8_dtype,
            fp8_scale_inv=torch.full([1], scale_inv),
        )

        # The tensor should report the requested shape, nominal dtype, and device
        assert list(tensor.size()) == dims, "Incorrect dims"
        assert tensor.dtype == dtype, "Incorrect nominal dtype"
        assert tensor.is_cuda, "Incorrect device"

    def _test_quantize_dequantize(
        self,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
        dims: DimsType = 23,
    ) -> None:
        """Check numerical error when casting to FP8 and back

        Shared helper for the parametrized quantize/dequantize tests.
        """

        # Initialize random data in [-1, 1)
        x_ref = 2 * torch.rand(_to_list(dims), dtype=dtype, device="cpu") - 1

        # Cast to FP8 and back
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_fp8 = x_fp8.dequantize().cpu()

        # Check results
        torch.testing.assert_close(x_fp8, x_ref, **_tols[fp8_dtype])

        # Make sure we are not trivially passing the test
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, -x_ref, **_tols[fp8_dtype])

    @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes)
    @pytest.mark.parametrize("dtype", _dtypes)
    def test_quantize_dequantize_dtypes(
        self,
        fp8_dtype: tex.DType,
        dtype: torch.dtype,
    ) -> None:
        """Quantize/dequantize round trip for each FP8 format and nominal dtype"""
        self._test_quantize_dequantize(dtype=dtype, fp8_dtype=fp8_dtype)

    @pytest.mark.parametrize("scale", [0.375, 1, 3.5])
    def test_quantize_dequantize_scales(self, scale: float) -> None:
        """Quantize/dequantize round trip for several scaling factors"""
        kwargs = {"scale": scale}
        self._test_quantize_dequantize(**kwargs)

125
    @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
    def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
        """Quantize/dequantize round trip for scalar, 1-D, and multi-dimensional shapes"""
        self._test_quantize_dequantize(dims=dims)

    def test_basic_ops(
        self,
        dims: DimsType = 23,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test basic out-of-place ops"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale)

        # Use the dequantized values as reference so that the initial
        # quantization error is not counted against the ops under test
        x_ref = x_fp8.dequantize()
        y_ref = y_fp8.dequantize()

        # Exact operations
        torch.testing.assert_close(-x_fp8, -x_ref, rtol=0, atol=0)
        torch.testing.assert_close(x_fp8.abs(), x_ref.abs(), rtol=0, atol=0)

        # Operations with numerical error
        tols = _tols[fp8_dtype]
        torch.testing.assert_close(x_fp8 + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_fp8 - y_fp8, x_ref - y_ref, **tols)
        torch.testing.assert_close(x_fp8 * y_fp8, x_ref * y_ref, **tols)
        # Mixed FP8 / plain-tensor operands, in both orders
        torch.testing.assert_close(x_fp8 + y_ref, x_ref + y_ref, **tols)
        torch.testing.assert_close(x_ref + y_fp8, x_ref + y_ref, **tols)
        torch.testing.assert_close(torch.sin(x_fp8), torch.sin(x_ref), **tols)

        # Make sure we are not trivially passing tests
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192

    @pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]])
    def test_chunk_op(
        self,
        dims: DimsType,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test for ops for which shape of inputs and outputs differ."""

        # Initialize random data. The `scale` argument is forwarded to the
        # quantizer so that it is not silently ignored.
        dims = _to_list(dims)
        x_ref = torch.randn(dims, dtype=dtype, device="cpu")
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)

        # Get chunks.
        chunk1, chunk2 = x_fp8.chunk(2, dim=0)

        # Test chunks: each must match the corresponding half of the
        # leading dimension exactly.
        half = dims[0] // 2
        torch.testing.assert_close(x_fp8[0:half,], chunk1, atol=0, rtol=0)
        torch.testing.assert_close(x_fp8[half:,], chunk2, atol=0, rtol=0)

        # Check shapes.
        expected_shape = torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
        assert chunk1.shape == expected_shape, "Wrong shape for chunk1"
        assert chunk2.shape == expected_shape, "Wrong shape for chunk2"
193
194
195
196
197
198
199
200
201
202
203
204
205
206

    def test_inplace_ops(
        self,
        dims: DimsType = 23,
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 3.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Test in-place ops"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        y_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        y_fp8 = to_float8(y_ref, fp8_dtype=fp8_dtype, scale=scale)

        # Use the dequantized values as reference
        x_ref = x_fp8.dequantize()
        y_ref = y_fp8.dequantize()

        # In-place operations. After each check the reference is reset
        # to the dequantized FP8 tensor so that errors do not accumulate
        # across steps.
        tols = _tols[fp8_dtype]
        x_fp8 += y_ref
        x_ref += y_ref
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.dequantize()
        # Both sides subtract the Float8Tensor operand here
        x_fp8 -= y_fp8
        x_ref -= y_fp8
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.dequantize()
        x_fp8 *= 2
        x_ref *= 2
        torch.testing.assert_close(x_fp8, x_ref, **tols)
        x_ref = x_fp8.dequantize()

        # Make sure we are not trivially passing tests
        x_ref += 123
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, x_ref, **tols)

232
233
    def test_serialization(
        self,
        dims: DimsType = (2, 3, 5),
        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
        scale: float = 0.5,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """Check that a Float8Tensor round-trips through torch.save / torch.load"""

        # Initialize random data
        dims = _to_list(dims)
        x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1
        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=scale)
        x_ref = x_fp8.dequantize()

        # Serialize tensor
        byte_stream = io.BytesIO()
        torch.save(x_fp8, byte_stream)
        x_bytes = byte_stream.getvalue()

        # Mess up and delete old tensor, so that the loaded tensor cannot
        # pass the check by sharing state with the original
        x_fp8._data.zero_()
        x_fp8._scale_inv.zero_()
        del x_fp8, byte_stream

        # Deserialize tensor. weights_only=False enables full unpickling;
        # the bytes were produced by this test just above.
        x_fp8 = torch.load(io.BytesIO(x_bytes), weights_only=False)
        del x_bytes

        # Check results
        tols = dict(rtol=0, atol=0)
        torch.testing.assert_close(x_fp8, x_ref, **tols)

        # Make sure we are not trivially passing tests
        x_fp8._data.zero_()
        x_fp8._scale_inv.zero_()
        with pytest.raises(AssertionError):
            torch.testing.assert_close(x_fp8, x_ref, **tols)
269
270
271
272
273
274

    def test_set_data(self) -> None:
        """Test directly setting .data attr"""

        # Initialize Float8Tensor
        x0 = torch.zeros(4, dtype=torch.float32)
        x = to_float8(x0)
        assert isinstance(x, Float8Tensor)
        assert x0.size() == x.size() == x._data.size()
        assert x.dtype == torch.float32
        assert x.is_cuda and x._data.is_cuda
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert x.size() == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device

        # Set data to plain tensor. The Float8Tensor is expected to take
        # on the new shape and nominal dtype.
        x0 = torch.zeros((3, 2), dtype=torch.float16, device=x.device)
        x.data = x0
        assert isinstance(x, Float8Tensor)
        assert x0.size() == x.size() == x._data.size()
        assert x0.dtype == x.dtype
        assert x0.device == x.device == x._data.device
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert x.size() == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device

        # Set data to Float8Tensor. The underlying FP8 buffers are
        # expected to be shared with the source tensor rather than copied.
        x0 = to_float8(torch.zeros((4, 3, 1), dtype=torch.float32))
        x.data = x0
        assert isinstance(x, Float8Tensor)
        assert x0.size() == x.size() == x._data.size()
        assert x0.dtype == x.dtype
        assert x0.device == x.device == x._data.device
        assert x0._data is x._data
        assert x0._scale_inv is x._scale_inv
        y = x.dequantize()
        assert not isinstance(y, Float8Tensor)
        assert x.size() == y.size()
        assert x.dtype == y.dtype
        assert x.device == y.device