test_deferred_init.py 2.76 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Type

import pytest
import torch
import torch.distributed as dist

import transformer_engine.pytorch as te

# Modules that own their parameters directly; both tests below are parametrized
# over these.
_core_modules = [
    te.LayerNorm, te.RMSNorm,
    te.Linear, te.LayerNormLinear, te.LayerNormMLP,
]

# Higher-level modules; only test_zero_memory_init is parametrized over these.
_composed_modules = [te.MultiheadAttention, te.TransformerLayer]

# Problem sizes shared by every parametrized case.
seq_length = 2048
num_heads, head_dim = 16, 64  # get_module_args derives hidden_size = num_heads * head_dim
dtype = torch.bfloat16  # passed to every module as params_dtype
batch_size = 32  # NOTE(review): not referenced anywhere else in this file

class TestDeferredInit:
    """Tests for deferred parameter initialization via ``device='meta'``."""

    @staticmethod
    def get_module_args(module):
        """Build constructor arguments for ``module`` on the meta device.

        Every module gets ``hidden_size`` as its first positional argument, plus
        ``params_dtype=dtype`` and ``device='meta'``. Module-specific positional
        and keyword arguments are appended below.

        Args:
            module: A Transformer Engine module class from ``_core_modules`` or
                ``_composed_modules``.

        Returns:
            An ``(args, kwargs)`` pair suitable for ``module(*args, **kwargs)``.
        """
        hidden_size = num_heads * head_dim
        args = (hidden_size,)
        kwargs = {
            'params_dtype': dtype,
            'device': 'meta'
        }
        if module in (te.Linear, te.LayerNormLinear, te.LayerNormMLP):
            ffn_hidden_size = 2 * hidden_size
            args += (ffn_hidden_size, )
            kwargs['bias'] = True
            if module is te.LayerNormMLP:
                kwargs['seq_length'] = seq_length
        elif module is te.MultiheadAttention:
            args += (num_heads, )
            kwargs['fuse_qkv_params'] = True
        elif module is te.TransformerLayer:
            args += (3 * hidden_size, num_heads)
            kwargs['fuse_qkv_params'] = True
            kwargs['seq_length'] = seq_length

        return args, kwargs

    @pytest.mark.parametrize("module_type", _core_modules+_composed_modules)
    def test_zero_memory_init(
        self,
        module_type: Type[torch.nn.Module],
    ) -> None:
        """Test deferred initialization via device='meta'."""
        # Constructing on the meta device should not allocate any memory on the CUDA
        # device until reset_parameters() is called later. Compare against a baseline
        # instead of zero: memory still held by earlier tests in the same pytest
        # session must not cause a spurious failure here.
        args, kwargs = TestDeferredInit.get_module_args(module_type)
        allocated_before = torch.cuda.memory_allocated(device=0)
        module = module_type(*args, **kwargs)
        allocated_after = torch.cuda.memory_allocated(device=0)
        assert allocated_after == allocated_before, (
            f"Initializing {module_type.__name__} with device='meta' prematurely allocated "
            "memory on CUDA device"
        )
        del module

    @pytest.mark.parametrize("module_type", _core_modules)
    def test_reset_parameters(
        self,
        module_type: Type[torch.nn.Module],
    ) -> None:
        """Test parameter reset for core modules that have been initialized with device='meta'."""
        # Core modules own their own parameters, so calling reset_parameters() here should
        # materialize them on the CUDA device. The baseline is taken after construction
        # so that any growth is attributable to reset_parameters() alone; an absolute
        # "> 0" check would pass vacuously if earlier tests left memory allocated.
        args, kwargs = TestDeferredInit.get_module_args(module_type)
        module = module_type(*args, **kwargs)
        allocated_before = torch.cuda.memory_allocated(device=0)
        with torch.no_grad():
            module.reset_parameters()
        assert torch.cuda.memory_allocated(device=0) > allocated_before, (
            f"{module_type.__name__}.reset_parameters() failed to materialize parameters "
            "on CUDA device"
        )
        del module