# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Tests for GroupedTensor class"""

from typing import List, Tuple
import pytest
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.pytorch import (
    Quantizer,
    Float8Quantizer,
    Float8CurrentScalingQuantizer,
    Float8BlockQuantizer,
    MXFP8Quantizer,
    NVFP4Quantizer,
)
from transformer_engine.pytorch.constants import TE_DType_To_Torch
import transformer_engine_torch as tex

# Check available recipes on this device/driver; each probe returns a
# (supported, reason) pair so skipped tests can report why.
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
    return_reason=True
)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)

# Parametrization values for the quantization schemes under test. Each entry
# is skipped (with the hardware/driver reason) when its recipe is unavailable.
_quantization_params = [
    pytest.param(
        "fp8_delayed_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
    ),
    pytest.param(
        "fp8_current_scaling",
        marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8),
    ),
    pytest.param(
        "fp8_blockwise",
        marks=pytest.mark.skipif(
            not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling
        ),
    ),
    pytest.param(
        "mxfp8",
        marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8),
    ),
    pytest.param(
        "nvfp4",
        marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4),
    ),
]


def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
    """Build a quantizer instance for the requested quantization scheme.

    ``num_tensors`` and ``shape`` are accepted for API symmetry with the
    grouped-tensor helpers; constructing the quantizer itself does not need
    them.

    Raises:
        ValueError: if ``quantization`` is not a recognized scheme name.
    """
    recognized = (
        "fp8_delayed_scaling",
        "fp8_current_scaling",
        "fp8_blockwise",
        "mxfp8",
        "nvfp4",
    )
    if quantization not in recognized:
        raise ValueError(f"Unknown quantization scheme: {quantization}")

    if quantization == "fp8_delayed_scaling":
        built = Float8Quantizer(
            scale=torch.ones(1, dtype=torch.float32, device="cuda"),
            amax=torch.zeros(1, dtype=torch.float32, device="cuda"),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
    elif quantization == "fp8_current_scaling":
        built = Float8CurrentScalingQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            device="cuda",
        )
        built.set_usage(rowwise=True, columnwise=False)
    elif quantization == "fp8_blockwise":
        built = Float8BlockQuantizer(
            fp8_dtype=tex.DType.kFloat8E4M3,
            rowwise=True,
            columnwise=False,
            force_pow_2_scales=True,
            amax_epsilon=0.0,
            block_scaling_dim=1,
        )
    elif quantization == "mxfp8":
        built = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    else:  # nvfp4
        built = NVFP4Quantizer(
            with_rht=False,
            with_post_rht_amax=False,
            with_2d_quantization=False,
            stochastic_rounding=False,
            with_random_sign_mask=False,
        )

    # Tests inspect the Python-side quantized tensor objects, so do not mark
    # the quantizer as internal-only.
    built.internal = False

    return built


def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor:
    if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"):
        return qtensor._data
    if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"):
        return qtensor._rowwise_data
    raise ValueError(f"Unknown quantization scheme: {quantization}")


def _rowwise_offset_bytes(numel: int, quantization: str) -> int:
    if quantization == "nvfp4":
        return numel // 2
    return numel


class TestGroupedTensor:
    """Tests for GroupedTensor construction, splitting, and in-place quantization.

    The storage-sharing tests compare raw ``data_ptr()`` values: each split or
    quantized tensor must be a zero-copy view at the expected byte offset
    inside the grouped tensor's single backing buffer.
    """

    # Fix: this was previously declared ``@staticmethod`` while taking ``cls``,
    # which only worked because pytest unwraps the staticmethod before calling
    # it with the class. ``@classmethod`` is the conventional, directly-callable
    # form for pytest's xunit-style setup_class hook.
    @classmethod
    def setup_class(cls) -> None:
        """Seed RNGs once per class so tensor contents are reproducible."""
        seed = 1234
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

    def test_basic_construction_all_same_shape(self) -> None:
        """Test GroupedTensor construction with all tensors having same shape"""
        num_tensors = 4
        shape = [(256, 512) for _ in range(num_tensors)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        assert grouped_tensor.num_tensors == num_tensors
        assert grouped_tensor.all_same_shape()
        assert grouped_tensor.all_same_first_dim()
        assert grouped_tensor.all_same_last_dim()
        # Logical shape flattens the group: first dims stack, last dim is shared.
        assert grouped_tensor.logical_shape == (num_tensors * 256, 512)
        assert grouped_tensor.get_common_first_dim() == 256
        assert grouped_tensor.get_common_last_dim() == 512
        assert grouped_tensor.has_data()

    def test_basic_construction_varying_first_dim(self) -> None:
        """Test GroupedTensor construction with varying first dimension"""
        num_tensors = 3
        shape = [(128, 512), (256, 512), (384, 512)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        assert grouped_tensor.num_tensors == num_tensors
        assert not grouped_tensor.all_same_shape()
        assert not grouped_tensor.all_same_first_dim()
        assert grouped_tensor.all_same_last_dim()
        assert grouped_tensor.get_common_last_dim() == shape[0][1]
        assert grouped_tensor.logical_shape == (
            sum(v for v, _ in shape),
            shape[0][1],
        )  # sum of first dims

    def test_split_into_quantized_tensors_no_quantization(self) -> None:
        """Test split_into_quantized_tensors for unquantized tensors"""
        num_tensors = 3
        shape = [(256, 512) for _ in range(num_tensors)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        # Get the original data pointer
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Split into tensors
        tensors = grouped_tensor.split_into_quantized_tensors()

        assert len(tensors) == num_tensors

        # Verify each tensor has correct shape and shares storage
        for i, tensor in enumerate(tensors):
            assert tensor.shape == shape[i]
            assert isinstance(tensor, torch.Tensor)
            assert not hasattr(tensor, "_data")  # Not a quantized tensor

            # Verify data pointer is within the original grouped tensor storage
            # The tensor should be a view of the original data
            assert tensor.data_ptr() >= original_data_ptr

            # Calculate expected offset: tensors are laid out contiguously,
            # so tensor i starts i * numel elements into the buffer.
            expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size()
            assert tensor.data_ptr() == original_data_ptr + expected_offset

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
        """Test split_into_quantized_tensors for quantized tensors"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get the original data pointer
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Split into tensors
        tensors = grouped_tensor.split_into_quantized_tensors()

        assert len(tensors) == num_tensors

        # Verify each tensor shares storage with the grouped tensor
        for i, tensor in enumerate(tensors):
            rowwise_data = _get_rowwise_data_tensor(tensor, quantization)
            assert rowwise_data is not None
            assert rowwise_data.data_ptr() >= original_data_ptr
            numel = shape[i][0] * shape[i][1]
            # Offsets are in bytes; NVFP4 packs two elements per byte.
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    def test_split_varying_shapes(self) -> None:
        """Test split_into_quantized_tensors with varying shapes"""
        num_tensors = 3
        shape = [(128, 512), (256, 512), (384, 512)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        original_data_ptr = grouped_tensor.data.data_ptr()
        tensors = grouped_tensor.split_into_quantized_tensors()

        assert len(tensors) == num_tensors

        # Verify shapes and storage: with unequal shapes the offsets accumulate.
        cumulative_offset = 0
        for i, tensor in enumerate(tensors):
            assert tensor.shape == shape[i]
            expected_offset = cumulative_offset * tensor.element_size()
            assert tensor.data_ptr() == original_data_ptr + expected_offset
            cumulative_offset += shape[i][0] * shape[i][1]

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_quantize_inplace(self, quantization: str) -> None:
        """Test that quantize is done in-place for all recipes"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get original data pointers before quantization
        original_data_ptr = grouped_tensor.data.data_ptr()
        original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr()
        # scale is only present for some recipes (e.g. delayed scaling).
        original_scale_ptr = (
            grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None
        )

        # Create input tensors
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Quantize in place
        quantized_tensors = grouped_tensor.quantize(input_tensors)

        # Verify data pointers haven't changed (in-place operation)
        assert grouped_tensor.data.data_ptr() == original_data_ptr
        assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr
        if original_scale_ptr is not None:
            assert grouped_tensor.scale.data_ptr() == original_scale_ptr

        # Verify returned tensors point to the same storage
        for i, qtensor in enumerate(quantized_tensors):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_quantize_varying_shapes(self, quantization: str) -> None:
        """Test quantize with varying shapes"""
        num_tensors = 3
        shape = [(256, 512), (512, 512), (768, 512)]
        quantizers = make_quantizer(quantization, num_tensors, shape)

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=quantizers,
            device="cuda",
        )

        # Get original data pointers
        original_data_ptr = grouped_tensor.data.data_ptr()

        # Create input tensors with varying shapes
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Quantize in place
        quantized_tensors = grouped_tensor.quantize(input_tensors)

        # Verify data pointer hasn't changed
        assert grouped_tensor.data.data_ptr() == original_data_ptr

        # Verify each tensor points to correct location
        cumulative_numel = 0
        for qtensor, tensor_shape in zip(quantized_tensors, shape):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset
            cumulative_numel += tensor_shape[0] * tensor_shape[1]

    @pytest.mark.parametrize("quantization", _quantization_params)
    def test_static_quantize_method(self, quantization: str) -> None:
        """Test the static quantize method"""
        num_tensors = 3
        shape = [(512, 512) for _ in range(num_tensors)]
        quantizers = make_quantizer(quantization, num_tensors, shape)

        # Create input tensors
        input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

        # Use static quantize method: allocates grouped storage and quantizes
        # the inputs into it in a single call.
        grouped_tensor = GroupedTensor.create_and_quantize(
            tensors=input_tensors,
            quantizer=quantizers,
            device="cuda",
        )

        # Verify the grouped tensor was created correctly
        assert grouped_tensor.num_tensors == num_tensors
        assert grouped_tensor.has_data()

        # Verify quantized_tensors were created and point to same storage
        assert grouped_tensor.quantized_tensors is not None
        assert len(grouped_tensor.quantized_tensors) == num_tensors

        original_data_ptr = grouped_tensor.data.data_ptr()
        for i, qtensor in enumerate(grouped_tensor.quantized_tensors):
            rowwise_data = _get_rowwise_data_tensor(qtensor, quantization)
            numel = shape[i][0] * shape[i][1]
            expected_offset = _rowwise_offset_bytes(i * numel, quantization)
            assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

    def test_clear(self) -> None:
        """Test clear method"""
        num_tensors = 3
        shape = [(256, 512) for _ in range(num_tensors)]

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=num_tensors,
            shape=shape,
            quantizer=None,
            device="cuda",
            dtype=torch.float32,
        )

        assert grouped_tensor.has_data()
        assert grouped_tensor.num_tensors == num_tensors

        grouped_tensor.clear()

        # clear() releases storage and resets metadata to an empty state.
        assert not grouped_tensor.has_data()
        assert grouped_tensor.num_tensors == 0
        assert grouped_tensor.data is None
        assert grouped_tensor.logical_shape == (0, 0)