# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
import pytest
5
6
import torch

7
8
9
10
11
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    _merge_multimodal_embeddings,
)
from vllm.platforms import current_platform
12

13
14
DEVICE_TYPE = current_platform.device_type

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33

class ModuleWithBatchNorm(torch.nn.Module):
    """Minimal module whose only child is a two-feature ``BatchNorm1d``."""

    def __init__(self) -> None:
        super().__init__()
        # Stored as ``bn`` so the state-dict keys are ``bn.<name>``.
        self.bn = torch.nn.BatchNorm1d(num_features=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the batch-norm layer and return the result."""
        normalized = self.bn(x)
        return normalized


class ModuleWithNestedBatchNorm(torch.nn.Module):
    """Wrapper module that holds a ``ModuleWithBatchNorm`` as ``nested_mod``."""

    def __init__(self) -> None:
        super().__init__()
        # The attribute name is part of the state-dict key prefix
        # (``nested_mod.bn.<name>``).
        inner = ModuleWithBatchNorm()
        self.nested_mod = inner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate the forward pass to the nested module."""
        result = self.nested_mod(x)
        return result


34
@pytest.mark.cpu_test
def test_module_with_batchnorm_can_load():
    """Verify that AutoWeightsLoader copies BatchNorm running statistics."""
    # One forward pass gives the source module batchnorm stats that differ
    # from those of a freshly constructed module (asserted below).
    trained = ModuleWithBatchNorm()
    batch = torch.Tensor([[1, 2], [3, 4]])
    trained(batch)
    trained_bn = trained.bn

    fresh = ModuleWithBatchNorm()

    # Sanity check: nothing has been loaded into the fresh module yet.
    assert not (fresh.bn.running_mean == trained_bn.running_mean).all()
    assert not (fresh.bn.running_var == trained_bn.running_var).all()
    assert fresh.bn.num_batches_tracked.item() == 0

    # Stream the source state dict into the fresh module through the loader.
    def iter_trained_weights():
        for name, tensor in trained.state_dict().items():
            yield name, tensor

    AutoWeightsLoader(fresh).load_weights(iter_trained_weights())

    # Every tracked statistic must now match the source module.
    assert (fresh.bn.running_mean == trained_bn.running_mean).all()
    assert (fresh.bn.running_var == trained_bn.running_var).all()
    assert fresh.bn.num_batches_tracked.item() == 1


60
@pytest.mark.cpu_test
def test_module_with_child_containing_batchnorm_can_autoload():
    """Verify that BatchNorm stats of a nested child module are auto-loaded."""
    # One forward pass gives the source module batchnorm stats that differ
    # from those of a freshly constructed module (asserted below).
    trained = ModuleWithNestedBatchNorm()
    batch = torch.Tensor([[1, 2], [3, 4]])
    trained(batch)
    trained_bn = trained.nested_mod.bn

    fresh = ModuleWithNestedBatchNorm()

    # Sanity check: nothing has been loaded into the fresh module yet.
    fresh_bn = fresh.nested_mod.bn
    assert not (fresh_bn.running_mean == trained_bn.running_mean).all()
    assert not (fresh_bn.running_var == trained_bn.running_var).all()
    assert fresh_bn.num_batches_tracked.item() == 0

    # Stream the source state dict into the fresh module through the loader.
    def iter_trained_weights():
        for name, tensor in trained.state_dict().items():
            yield name, tensor

    AutoWeightsLoader(fresh).load_weights(iter_trained_weights())

    # Re-resolve the nested layer after loading, then compare every statistic.
    loaded_bn = fresh.nested_mod.bn
    assert (loaded_bn.running_mean == trained_bn.running_mean).all()
    assert (loaded_bn.running_var == trained_bn.running_var).all()
    assert loaded_bn.num_batches_tracked.item() == 1
90
91


92
@pytest.mark.cpu_test
def test_module_skip_prefix():
    """Verify that weights whose names match ``skip_prefixes`` are ignored."""
    # One forward pass gives the source module batchnorm stats that differ
    # from those of a freshly constructed module (asserted below).
    trained = ModuleWithNestedBatchNorm()
    batch = torch.Tensor([[1, 2], [3, 4]])
    trained(batch)
    trained_bn = trained.nested_mod.bn

    fresh = ModuleWithNestedBatchNorm()

    # Sanity check: nothing has been loaded into the fresh module yet.
    fresh_bn = fresh.nested_mod.bn
    assert not (fresh_bn.running_mean == trained_bn.running_mean).all()
    assert not (fresh_bn.running_var == trained_bn.running_var).all()
    assert fresh_bn.num_batches_tracked.item() == 0

    def iter_weights_with_extras():
        # Entries under "prefix." have no counterpart in the module; the
        # loader is expected to filter them out.
        extras = {
            "prefix.bn.weight": torch.Tensor([1, 2]),
            "prefix.bn.bias": torch.Tensor([3, 4]),
        }
        combined = {**trained.state_dict(), **extras}
        for name, tensor in combined.items():
            yield name, tensor

    skipping_loader = AutoWeightsLoader(fresh, skip_prefixes=["prefix."])
    skipping_loader.load_weights(iter_weights_with_extras())

    # Re-resolve the nested layer after loading, then compare every statistic.
    loaded_bn = fresh.nested_mod.bn
    assert (loaded_bn.running_mean == trained_bn.running_mean).all()
    assert (loaded_bn.running_var == trained_bn.running_var).all()
    assert loaded_bn.num_batches_tracked.item() == 1


129
@pytest.mark.cpu_test
def test_module_skip_substr():
    """Verify that weights whose names contain a ``skip_substrs`` entry are ignored."""
    # One forward pass gives the source module batchnorm stats that differ
    # from those of a freshly constructed module (asserted below).
    trained = ModuleWithNestedBatchNorm()
    batch = torch.Tensor([[1, 2], [3, 4]])
    trained(batch)
    trained_bn = trained.nested_mod.bn

    fresh = ModuleWithNestedBatchNorm()

    # Sanity check: nothing has been loaded into the fresh module yet.
    fresh_bn = fresh.nested_mod.bn
    assert not (fresh_bn.running_mean == trained_bn.running_mean).all()
    assert not (fresh_bn.running_var == trained_bn.running_var).all()
    assert fresh_bn.num_batches_tracked.item() == 0

    def iter_weights_with_extras():
        # Entries containing "substr." (at two nesting depths) have no
        # counterpart in the module; the loader is expected to filter them out.
        extras = {
            "nested_mod.0.substr.weight": torch.Tensor([1, 2]),
            "nested_mod.0.substr.bias": torch.Tensor([3, 4]),
            "nested_mod.substr.weight": torch.Tensor([1, 2]),
            "nested_mod.substr.bias": torch.Tensor([3, 4]),
        }
        combined = {**trained.state_dict(), **extras}
        for name, tensor in combined.items():
            yield name, tensor

    skipping_loader = AutoWeightsLoader(fresh, skip_substrs=["substr."])
    skipping_loader.load_weights(iter_weights_with_extras())

    # Re-resolve the nested layer after loading, then compare every statistic.
    loaded_bn = fresh.nested_mod.bn
    assert (loaded_bn.running_mean == trained_bn.running_mean).all()
    assert (loaded_bn.running_var == trained_bn.running_var).all()
    assert loaded_bn.num_batches_tracked.item() == 1
166
167
168
169
170
171
172
173
174
175
176
177
178


class raise_if_cuda_sync:
    """Context manager that runs its body with CUDA sync debug mode "error".

    On entry the current sync debug mode is saved and the mode is switched to
    ``"error"``; on exit the saved mode is restored.
    """

    def __enter__(self):
        # Remember the mode that was active so __exit__ can put it back.
        saved_mode = torch.cuda.get_sync_debug_mode()
        self.previous_debug_mode = saved_mode
        torch.cuda.set_sync_debug_mode("error")

    def __exit__(self, exception_type, exception_value, traceback):
        # Returns None, so an exception raised inside the body propagates.
        restore_to = self.previous_debug_mode
        torch.cuda.set_sync_debug_mode(restore_to)


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_merge_multimodal_embeddings_no_sync():
    """Merge multimodal embeddings while CUDA sync debug mode is "error"."""
    # Both embedding tensors are placed on device 0 of the current platform.
    on_device = {"dtype": torch.bfloat16, "device": f"{DEVICE_TYPE}:0"}
    text_embeds = torch.zeros([5, 10], **on_device)
    mm_embeds = [torch.ones([3, 10], **on_device)]

    # The mask is created on the CPU; it selects three of the five rows,
    # which equals the number of rows in the multimodal embedding tensor.
    mm_mask = torch.tensor([True, False, True, True, False], device="cpu")

    with raise_if_cuda_sync():
        _merge_multimodal_embeddings(text_embeds, mm_embeds, mm_mask)