# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import pytest
import torch

import transformer_engine.pytorch as te

# Modules that own their parameters directly; reset_parameters() on these
# materializes the weights themselves.
_core_modules = [
    te.LayerNorm,
    te.RMSNorm,
    te.Linear,
    te.LayerNormLinear,
    te.LayerNormMLP,
]

# Modules composed of the core modules above; they delegate parameter
# ownership to their submodules.
_composed_modules = [
    te.MultiheadAttention,
    te.TransformerLayer,
]

# Shared model dimensions used to build constructor arguments in the tests.
batch_size = 32
seq_length = 2048
num_heads = 16
head_dim = 64
dtype = torch.bfloat16

class TestDeferredInit:
    """Tests for deferred (device='meta') initialization of TE modules."""

    @staticmethod
    def get_module_args(module):
        """Build the positional args and kwargs needed to construct *module*
        on the meta device with the shared test dimensions."""
        hidden_size = num_heads * head_dim
        positional = [hidden_size]
        keyword = {"params_dtype": dtype, "device": "meta"}

        if module in (te.Linear, te.LayerNormLinear, te.LayerNormMLP):
            # Second positional argument is the output / FFN hidden size.
            positional.append(2 * hidden_size)
            keyword["bias"] = True
            if module is te.LayerNormMLP:
                keyword["seq_length"] = seq_length
        elif module is te.MultiheadAttention:
            positional.append(num_heads)
            keyword["fuse_qkv_params"] = True
        elif module is te.TransformerLayer:
            positional.extend((3 * hidden_size, num_heads))
            keyword["fuse_qkv_params"] = True
            keyword["seq_length"] = seq_length

        return tuple(positional), keyword

    @pytest.mark.parametrize("module_type", _core_modules + _composed_modules)
    def test_zero_memory_init(
        self,
        module_type: torch.nn.Module,
    ) -> None:
        """Test deferred initialization via device='meta'."""
        ctor_args, ctor_kwargs = TestDeferredInit.get_module_args(module_type)
        # Constructing on the meta device must not touch CUDA memory until
        # reset_parameters() is called later.
        module = module_type(*ctor_args, **ctor_kwargs)
        assert torch.cuda.memory_allocated(device=0) == 0.0, (
            f"Initializing {module_type.__name__} with device='meta' prematurely allocated "
            "memory on CUDA device"
        )
        del module

    @pytest.mark.parametrize("module_type", _core_modules)
    def test_reset_parameters(
        self,
        module_type: torch.nn.Module,
    ) -> None:
        """Test parameter reset for core modules that have been initialized with device='meta'."""
        ctor_args, ctor_kwargs = TestDeferredInit.get_module_args(module_type)
        module = module_type(*ctor_args, **ctor_kwargs)
        # Core modules own their parameters, so resetting them here should
        # materialize real storage on the CUDA device.
        with torch.no_grad():
            module.reset_parameters()
        assert torch.cuda.memory_allocated(device=0) > 0.0, (
            f"{module_type.__name__}.reset_parameters() failed to materialize parameters "
            "on CUDA device"
        )
        del module

    @pytest.mark.parametrize("module_type", _core_modules)
    def test_reset_parameters_doesnt_change_parameter_stats(
        self,
        module_type: torch.nn.Module,
    ) -> None:
        """Test for github issue #2528 and #2529 to ensure that reset_parameters() doesn't change
        the parameter mean and std"""
        ctor_args, ctor_kwargs = TestDeferredInit.get_module_args(module_type)
        ctor_kwargs["device"] = "cuda"
        module = module_type(*ctor_args, **ctor_kwargs)

        def snapshot_stats():
            # Record per-parameter mean/std so the distributions can be
            # compared before and after the reset.
            return {
                name: {"mean": param.mean(), "std": param.std()}
                for name, param in module.named_parameters()
            }

        stats_before = snapshot_stats()

        with torch.no_grad():
            module.reset_parameters()

        stats_after = snapshot_stats()

        for name, stats in stats_after.items():
            torch.testing.assert_close(
                stats["mean"],
                stats_before[name]["mean"],
                atol=1e-3,
                rtol=1e-3,
                msg=f"{name} mean changed after reset_parameters",
            )
            torch.testing.assert_close(
                stats["std"],
                stats_before[name]["std"],
                atol=1e-3,
                rtol=1e-3,
                msg=f"{name} std changed after reset_parameters",
            )

        del module