test_utils.py 10.4 KB
Newer Older
1
import asyncio
2
3
import os
import socket
4
from functools import partial
5
from typing import AsyncIterator, Tuple
6

7
import pytest
8
import torch
9

10
from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
11
12
                        get_open_port, memory_profiling, merge_async_iterators,
                        supports_kw)
13

14
from .utils import error_on_warning, fork_new_process_for_each_test
15

16
17
18
19

@pytest.mark.asyncio
async def test_merge_async_iterators():
    """Cancelling the consumer of merge_async_iterators must stop its sources.

    Three endless async generators are merged and drained by a background
    task. That task is cancelled after half a second; afterwards every source
    generator is asked for one more item and is expected to answer with
    StopAsyncIteration.
    """

    async def mock_async_iterator(idx: int):
        # Endless producer that ends quietly once it receives a cancellation.
        try:
            while True:
                yield f"item from iterator {idx}"
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            print(f"iterator {idx} cancelled")

    iterators = [mock_async_iterator(i) for i in range(3)]
    # is_cancelled is an awaitable callable that always resolves to False.
    merged_iterator = merge_async_iterators(*iterators,
                                            is_cancelled=partial(asyncio.sleep,
                                                                 0,
                                                                 result=False))

    async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
        async for idx, output in generator:
            print(f"idx: {idx}, output: {output}")

    task = asyncio.create_task(stream_output(merged_iterator))
    await asyncio.sleep(0.5)
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    for iterator in iterators:
        try:
            # Can use anext() in python >= 3.10
            await asyncio.wait_for(iterator.__anext__(), 1)
        except StopAsyncIteration:
            # All iterators should be cancelled and print this message.
            print("Iterator was cancelled normally")
        except (Exception, asyncio.CancelledError) as e:
            raise AssertionError() from e
        else:
            # A source that still yields was never stopped; without this
            # branch the test would pass silently in that situation.
            raise AssertionError(
                "iterator produced an item after the merged stream was "
                "cancelled")

54
55
56
57
58
59
60
61
62
63

def test_deprecate_kwargs_always():
    """With is_deprecated=True, only `old_arg` should trigger a warning."""

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    # The deprecated keyword must emit a DeprecationWarning that names it.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    # The other keyword is expected to be warning-free (error_on_warning
    # presumably escalates a DeprecationWarning into a failure).
    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_never():
    """With is_deprecated=False, neither keyword should trigger a warning."""

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    # Both calls run under error_on_warning, which presumably escalates a
    # DeprecationWarning into a failure.
    with error_on_warning(DeprecationWarning):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_dynamic():
    """A callable is_deprecated should be consulted again on later calls.

    The lambda reads the local `is_deprecated` flag, so flipping the flag
    between calls is expected to switch the warning on and off.
    """
    is_deprecated = True

    @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    # Flag is True: the deprecated keyword warns, the other one does not.
    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)

    is_deprecated = False

    # Flag is False: neither keyword is expected to warn any more.
    with error_on_warning(DeprecationWarning):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_additional_message():
    """The additional_message text should show up in the emitted warning."""
    extra_text = "abcd"

    @deprecate_kwargs("old_arg",
                      is_deprecated=True,
                      additional_message=extra_text)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        return None

    # pytest.warns matches the pattern against the warning message.
    with pytest.warns(DeprecationWarning, match=extra_text):
        dummy(old_arg=1)
111
112
113
114
115
116
117
118
119
120
121
122


def test_get_open_port():
    """get_open_port should keep returning bindable ports with VLLM_PORT set.

    Three sockets are bound one after another while the earlier ones are
    still open, each to a port freshly obtained from get_open_port().
    """
    previous_port = os.environ.get("VLLM_PORT")
    os.environ["VLLM_PORT"] = "5678"
    try:
        # make sure we can get multiple ports, even if the env var is set
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
            s1.bind(("localhost", get_open_port()))
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
                s2.bind(("localhost", get_open_port()))
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
                    s3.bind(("localhost", get_open_port()))
    finally:
        # Put the environment back even when a bind fails, so the override
        # cannot leak into other tests running in the same process.
        if previous_port is None:
            os.environ.pop("VLLM_PORT", None)
        else:
            os.environ["VLLM_PORT"] = previous_port
123
124
125
126
127
128
129
130
131
132
133
134
135
136


# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    """Provide a FlexibleArgumentParser with four dash-separated options."""
    flex_parser = FlexibleArgumentParser()
    # Option restricted to a fixed set of values.
    flex_parser.add_argument(
        '--image-input-type',
        choices=['pixel_values', 'image_features'],
    )
    # Free-form string option.
    flex_parser.add_argument('--model-name')
    # Integer-valued option.
    flex_parser.add_argument('--batch-size', type=int)
    # Boolean switch.
    flex_parser.add_argument('--enable-feature', action='store_true')
    return flex_parser


137
138
139
140
@pytest.fixture
def parser_with_config():
    """Provide a FlexibleArgumentParser shaped like a `serve` command line.

    It takes two positionals (`serve`, `model_tag`) plus a `--config` option
    and several options whose values the config-file tests inspect.
    """
    parser = FlexibleArgumentParser()
    parser.add_argument('serve')
    parser.add_argument('model_tag')
    parser.add_argument('--served-model-name', type=str)
    parser.add_argument('--config', type=str)
    parser.add_argument('--port', type=int)
    parser.add_argument('--tensor-parallel-size', type=int)
    parser.add_argument('--trust-remote-code', action='store_true')
    parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean)
    return parser


151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
def test_underscore_to_dash(parser):
    """An underscore-spelled flag should parse like its dashed form."""
    argv = ['--image_input_type', 'pixel_values']
    parsed = parser.parse_args(argv)
    assert parsed.image_input_type == 'pixel_values'


def test_mixed_usage(parser):
    """Underscore and dash spellings may be mixed on one command line."""
    argv = [
        '--image_input_type',
        'image_features',
        '--model-name',
        'facebook/opt-125m',
    ]
    parsed = parser.parse_args(argv)
    assert parsed.image_input_type == 'image_features'
    assert parsed.model_name == 'facebook/opt-125m'


def test_with_equals_sign(parser):
    """The `--flag=value` form should work for both spellings."""
    argv = [
        '--image_input_type=pixel_values',
        '--model-name=facebook/opt-125m',
    ]
    parsed = parser.parse_args(argv)
    assert parsed.image_input_type == 'pixel_values'
    assert parsed.model_name == 'facebook/opt-125m'


def test_with_int_value(parser):
    """An int-typed option converts its value under either spelling."""
    # Underscore spelling first, then the dashed one, as separate parses.
    for flag in ('--batch_size', '--batch-size'):
        parsed = parser.parse_args([flag, '32'])
        assert parsed.batch_size == 32


def test_with_bool_flag(parser):
    """A store_true switch is set to True under either spelling."""
    # Underscore spelling first, then the dashed one, as separate parses.
    for flag in ('--enable_feature', '--enable-feature'):
        parsed = parser.parse_args([flag])
        assert parsed.enable_feature is True


def test_invalid_choice(parser):
    """A value outside the declared choices should end in SystemExit."""
    argv = ['--image_input_type', 'invalid_choice']
    with pytest.raises(SystemExit):
        parser.parse_args(argv)


def test_missing_required_argument(parser):
    """Omitting a required option should end in SystemExit."""
    parser.add_argument('--required-arg', required=True)
    empty_argv = []
    with pytest.raises(SystemExit):
        parser.parse_args(empty_argv)
195
196
197
198


def test_cli_override_to_config(parser_with_config):
    """Values given on the command line should win over the --config file.

    The override is checked with `--config` placed both before and after the
    explicit option.

    NOTE(review): the expected port 12312 presumably comes from
    ./data/test_config.yaml, which is not visible here - confirm.
    """
    # --config first, explicit option afterwards.
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--config', './data/test_config.yaml',
        '--tensor-parallel-size', '3'
    ])
    assert args.tensor_parallel_size == 3

    # Explicit option first, --config afterwards.
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
        './data/test_config.yaml'
    ])
    assert args.tensor_parallel_size == 3
    assert args.port == 12312

    # An explicit --port after --config replaces the 12312 seen above.
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
        './data/test_config.yaml', '--port', '666'
    ])
    assert args.tensor_parallel_size == 3
    assert args.port == 666
215
216
217
218


def test_config_args(parser_with_config):
    """With only --config given, option values should be filled in from it.

    NOTE(review): the expected values (tensor_parallel_size 2,
    trust_remote_code set, multi_step_stream_outputs unset/false) presumably
    mirror ./data/test_config.yaml, which is not visible here - confirm.
    """
    args = parser_with_config.parse_args(
        ['serve', 'mymodel', '--config', './data/test_config.yaml'])
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code
    assert not args.multi_step_stream_outputs
223
224
225
226


def test_config_file(parser_with_config):
    """Unusable --config values should be rejected with specific errors."""
    # A path that is expected not to exist.
    with pytest.raises(FileNotFoundError):
        parser_with_config.parse_args(
            ['serve', 'mymodel', '--config', 'test_config.yml'])

    # A .json path; presumably rejected because of its extension - confirm.
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            ['serve', 'mymodel', '--config', './data/test_config.json'])

    # --config immediately followed by another flag instead of a path.
    with pytest.raises(ValueError):
        parser_with_config.parse_args([
            'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
            '--batch-size', '32'
        ])
239
240
241
242
243
244


def test_no_model_tag(parser_with_config):
    """`serve` followed directly by --config should raise ValueError."""
    argv = ['serve', '--config', './data/test_config.yaml']
    with pytest.raises(ValueError):
        parser_with_config.parse_args(argv)
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274


# yapf: enable
@pytest.mark.parametrize(
    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
    [
        # A plain positional parameter, with and without requires_kw_only
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Positional-or-keyword with a default / keyword-only parameter
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # The names of *args / **kwargs themselves must NOT count as supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Unknown names, with and without a **kwargs catch-all present
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ])
# yapf: disable
def test_supports_kw(callable, kw_name, requires_kw_only,
                     allow_var_kwargs, is_supported):
    """Compare supports_kw's answer with the expected flag for each case."""
    verdict = supports_kw(
        callable=callable,
        kw_name=kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs,
    )
    assert verdict == is_supported
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312


@fork_new_process_for_each_test
def test_memory_profiling():
    """Check memory_profiling's torch-peak and non-torch increase figures.

    The test stages 512 MiB of raw CUDA memory and 512 MiB of torch weights
    before profiling, then a 1 GiB torch spike and 256 MiB of raw CUDA memory
    inside the profiled region, and expects the reported increases to be
    within 5% of 1 GiB and 256 MiB respectively.
    """
    # Fake out some model loading + inference memory usage to test profiling
    # Memory used by other processes will show up as cuda usage outside of torch
    from vllm.distributed.device_communicators.cuda_wrapper import (
        CudaRTLibrary)
    lib = CudaRTLibrary()
    # 512 MiB allocation outside of this instance
    handle1 = lib.cudaMalloc(512 * 1024 * 1024)
    handle2 = None

    try:
        # mem_get_info() returns (free, total); read it once so both numbers
        # describe the same instant.
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        baseline_memory_in_bytes = total_bytes - free_bytes

        # load weights
        weights = torch.randn(128,
                              1024,
                              1024,
                              device='cuda',
                              dtype=torch.float32)

        weights_memory_in_bytes = 128 * 1024 * 1024 * 4  # 512 MiB

        with memory_profiling(
                baseline_memory_in_bytes=baseline_memory_in_bytes,
                weights_memory_in_bytes=weights_memory_in_bytes) as result:
            # make a memory spike, 1 GiB
            spike = torch.randn(256,
                                1024,
                                1024,
                                device='cuda',
                                dtype=torch.float32)
            del spike

            # Add some extra non-torch memory 256 MiB (simulate NCCL)
            handle2 = lib.cudaMalloc(256 * 1024 * 1024)

        # Check that the memory usage is within 5% of the expected values
        non_torch_ratio = (result.non_torch_increase_in_bytes /
                           (256 * 1024 * 1024))
        torch_peak_ratio = (result.torch_peak_increase_in_bytes /
                            (1024 * 1024 * 1024))
        assert abs(non_torch_ratio - 1) <= 0.05
        assert abs(torch_peak_ratio - 1) <= 0.05
        del weights
    finally:
        # Release the raw CUDA allocations even when an assertion fails.
        lib.cudaFree(handle1)
        if handle2 is not None:
            lib.cudaFree(handle2)