test_utils.py 31.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# ruff: noqa
4

5
import asyncio
6
import hashlib
7
import json
8
import logging
9
import pickle
10
import socket
11
from collections.abc import AsyncIterator
12
from unittest.mock import patch
13

14
import pytest
15
import torch
16
import zmq
17
from transformers import AutoTokenizer
18
from vllm_test_utils.monitor import monitor
19

20
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
21
22
from vllm.transformers_utils.detokenizer_utils import (
    convert_ids_list_to_tokens)
23
24
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                        MemorySnapshot, PlaceholderModule, StoreBoolean,
25
                        bind_kv_cache, common_broadcastable_dtype,
26
27
28
                        current_stream, deprecate_kwargs, get_open_port,
                        get_tcp_uri, is_lossless_cast, join_host_port,
                        make_zmq_path, make_zmq_socket, memory_profiling,
29
30
                        merge_async_iterators, sha256, split_host_port,
                        split_zmq_path, supports_kw, swap_dict_values)
31

32
from .utils import create_new_process_for_each_test, error_on_warning
33

34
35
36
37

@pytest.mark.asyncio
async def test_merge_async_iterators():
    """Cancelling a consumer of merge_async_iterators must stop every source.

    Three endless async generators are merged and consumed for ~0.5 s, then
    the consuming task is cancelled. Afterwards every source generator must
    be finished: its next ``__anext__`` has to raise ``StopAsyncIteration``.
    A source that still yields an item means it was leaked and fails the test.
    """

    async def mock_async_iterator(idx: int):
        try:
            while True:
                yield f"item from iterator {idx}"
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            print(f"iterator {idx} cancelled")

    iterators = [mock_async_iterator(i) for i in range(3)]
    merged_iterator = merge_async_iterators(*iterators)

    async def stream_output(generator: AsyncIterator[tuple[int, str]]):
        async for idx, output in generator:
            print(f"idx: {idx}, output: {output}")

    task = asyncio.create_task(stream_output(merged_iterator))
    await asyncio.sleep(0.5)
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    for iterator in iterators:
        try:
            # Can use anext() in python >= 3.10
            await asyncio.wait_for(iterator.__anext__(), 1)
        except StopAsyncIteration:
            # All iterators should be cancelled and print this message.
            print("Iterator was cancelled normally")
        except (Exception, asyncio.CancelledError) as e:
            raise AssertionError() from e
        else:
            # The source produced another item, so cancelling the merged
            # iterator did not shut it down.
            raise AssertionError(
                "Source iterator is still producing items after the merged "
                "iterator was cancelled")

69
70
71
72
73
74
75
76
77
78

def test_deprecate_kwargs_always():
    """With ``is_deprecated=True``, passing the old kwarg always warns.

    Calls that only use the new kwarg must stay warning-free.
    """

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_never():
    """With ``is_deprecated=False``, neither kwarg emits a DeprecationWarning."""

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with error_on_warning(DeprecationWarning):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_dynamic():
    """A callable ``is_deprecated`` is re-evaluated on every call.

    Flipping the closed-over flag to False turns the old-kwarg warning off.
    """
    is_deprecated = True

    # The lambda reads ``is_deprecated`` at call time (late binding), which is
    # what lets the reassignment below change the decorator's behaviour.
    @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)

    is_deprecated = False

    with error_on_warning(DeprecationWarning):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_additional_message():
    """The ``additional_message`` text shows up in the emitted warning."""

    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
    def dummy(*, old_arg: object = None, new_arg: object = None):
        return None

    expect_warning = pytest.warns(DeprecationWarning, match="abcd")
    with expect_warning:
        dummy(old_arg=1)
126
127


128
129
130
131
132
133
134
135
136
137
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
    """get_open_port keeps returning bindable ports while VLLM_PORT is set."""
    with monkeypatch.context() as mp:
        mp.setenv("VLLM_PORT", "5678")
        # Keep every socket bound until the end so each subsequent call has
        # to find a port that is not already taken by an earlier one.
        held = []
        try:
            for _ in range(3):
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                held.append(sock)
                sock.bind(("localhost", get_open_port()))
        finally:
            # Close in reverse order of creation.
            for sock in reversed(held):
                sock.close()
138
139
140
141
142
143
144
145
146
147
148


# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    """A FlexibleArgumentParser with a handful of representative flags.

    ``--hf-overrides`` and ``-O/--compilation-config`` take JSON values so the
    dotted-key (``--flag.key value``) syntax can be exercised.
    """
    parser = FlexibleArgumentParser()
    parser.add_argument('--image-input-type',
                        choices=['pixel_values', 'image_features'])
    parser.add_argument('--model-name')
    parser.add_argument('--batch-size', type=int)
    parser.add_argument('--enable-feature', action='store_true')
    parser.add_argument('--hf-overrides', type=json.loads)
    parser.add_argument('-O', '--compilation-config', type=json.loads)
    return parser


154
155
156
157
@pytest.fixture
def parser_with_config():
    """A FlexibleArgumentParser shaped like ``vllm serve`` with ``--config``.

    It takes a positional ``serve`` subcommand word, an optional positional
    model tag and a few options that the config-file tests read back.
    """
    parser = FlexibleArgumentParser()
    parser.add_argument('serve')
    parser.add_argument('model_tag', nargs='?')
    parser.add_argument('--model', type=str)
    parser.add_argument('--served-model-name', type=str)
    parser.add_argument('--config', type=str)
    parser.add_argument('--port', type=int)
    parser.add_argument('--tensor-parallel-size', type=int)
    parser.add_argument('--trust-remote-code', action='store_true')
    parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean)
    return parser


169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def test_underscore_to_dash(parser):
    """An underscore spelling of a dashed flag is accepted."""
    parsed = parser.parse_args(['--image_input_type', 'pixel_values'])
    value = parsed.image_input_type
    assert value == 'pixel_values'


def test_mixed_usage(parser):
    """Underscore and dash flag spellings can be mixed on one command line."""
    argv = [
        '--image_input_type', 'image_features',
        '--model-name', 'facebook/opt-125m',
    ]
    ns = parser.parse_args(argv)
    expected = {
        'image_input_type': 'image_features',
        'model_name': 'facebook/opt-125m',
    }
    for attr, value in expected.items():
        assert getattr(ns, attr) == value


def test_with_equals_sign(parser):
    """``--flag=value`` works for both underscore and dash spellings."""
    argv = ['--image_input_type=pixel_values', '--model-name=facebook/opt-125m']
    ns = parser.parse_args(argv)
    expected = {
        'image_input_type': 'pixel_values',
        'model_name': 'facebook/opt-125m',
    }
    for attr, value in expected.items():
        assert getattr(ns, attr) == value


def test_with_int_value(parser):
    """An ``int``-typed option parses under either flag spelling."""
    for spelling in ('--batch_size', '--batch-size'):
        ns = parser.parse_args([spelling, '32'])
        assert ns.batch_size == 32


def test_with_bool_flag(parser):
    """A ``store_true`` flag is set under either flag spelling."""
    for spelling in ('--enable_feature', '--enable-feature'):
        ns = parser.parse_args([spelling])
        assert ns.enable_feature is True


def test_invalid_choice(parser):
    """A value outside the declared ``choices`` makes the parser exit."""
    bad_argv = ['--image_input_type', 'invalid_choice']
    with pytest.raises(SystemExit):
        parser.parse_args(bad_argv)


def test_missing_required_argument(parser):
    """Omitting a ``required=True`` option makes the parser exit."""
    parser.add_argument('--required-arg', required=True)
    empty_argv = []
    with pytest.raises(SystemExit):
        parser.parse_args(empty_argv)
213
214


215
def test_cli_override_to_config(parser_with_config, cli_config_file):
    """Explicit CLI flags win over ``--config`` values, wherever they appear.

    Values not given on the CLI (``--port``) still come from the config file.
    """
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--config', cli_config_file,
        '--tensor-parallel-size', '3'
    ])
    assert args.tensor_parallel_size == 3
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
        cli_config_file
    ])
    assert args.tensor_parallel_size == 3
    assert args.port == 12312
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
        cli_config_file, '--port', '666'
    ])
    assert args.tensor_parallel_size == 3
    assert args.port == 666
233
234


235
def test_config_args(parser_with_config, cli_config_file):
    """Values from ``--config`` populate the parsed namespace."""
    args = parser_with_config.parse_args(
        ['serve', 'mymodel', '--config', cli_config_file])
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code
    assert not args.multi_step_stream_outputs
241
242
243
244


def test_config_file(parser_with_config):
    """Invalid ``--config`` usages are rejected.

    Cases: a missing file raises FileNotFoundError; a ``.json`` path raises
    ValueError (presumably because only YAML configs are accepted — verify
    against FlexibleArgumentParser); ``--config`` immediately followed by
    another flag raises ValueError.
    """
    with pytest.raises(FileNotFoundError):
        parser_with_config.parse_args(
            ['serve', 'mymodel', '--config', 'test_config.yml'])

    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            ['serve', 'mymodel', '--config', './data/test_config.json'])

    with pytest.raises(ValueError):
        parser_with_config.parse_args([
            'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
            '--batch-size', '32'
        ])
257
258


259
def test_no_model_tag(parser_with_config, cli_config_file):
    """``serve`` with neither a positional model nor one in the config fails."""
    with pytest.raises(ValueError):
        parser_with_config.parse_args(['serve', '--config', cli_config_file])
262
263


264
265
266
267
268
def test_dict_args(parser):
    """Dotted CLI flags are folded into nested dicts with typed values.

    Covers:
    - nesting;
    - ``=`` syntax;
    - underscore spellings of the flag name (keys keep their own spelling);
    - type detection of values (ints, floats, booleans, null);
    - values containing ``-`` and ``.``;
    - the ``-O<level>`` shorthand;
    - the ``+`` / ``+=`` syntax for list values.
    """
    args = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        # Test nesting
        "--hf-overrides.key2.key3",
        "val2",
        "--hf-overrides.key2.key4",
        "val3",
        # Test compile config and compilation level
        "-O.use_inductor=true",
        "-O.backend",
        "custom",
        "-O1",
        # Test = sign
        "--hf-overrides.key5=val4",
        # Test underscore to dash conversion
        "--hf_overrides.key_6",
        "val5",
        "--hf_overrides.key-7.key_8",
        "val6",
        # Test data type detection
        "--hf_overrides.key9",
        "100",
        "--hf_overrides.key10",
        "100.0",
        "--hf_overrides.key11",
        "true",
        "--hf_overrides.key12.key13",
        "null",
        # Test '-' and '.' in value
        "--hf_overrides.key14.key15",
        "-minus.and.dot",
        # Test array values
        "-O.custom_ops+",
        "-quant_fp8",
        "-O.custom_ops+=+silu_mul,-rms_norm",
    ]
    parsed_args = parser.parse_args(args)
    assert parsed_args.model_name == "something.something"
    assert parsed_args.hf_overrides == {
        "key1": "val1",
        "key2": {
            "key3": "val2",
            "key4": "val3",
        },
        "key5": "val4",
        "key_6": "val5",
        "key-7": {
            "key_8": "val6",
        },
        "key9": 100,
        "key10": 100.0,
        "key11": True,
        "key12": {
            "key13": None,
        },
        "key14": {
            "key15": "-minus.and.dot",
        }
    }
    assert parsed_args.compilation_config == {
        "level": 1,
        "use_inductor": True,
        "backend": "custom",
        "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
    }


def test_duplicate_dict_args(caplog_vllm, parser):
    """Repeated dotted keys keep the last value and log one duplicate warning."""
    cli = [
        "--model-name=something.something",
        # the same nested key given twice
        "--hf-overrides.key1", "val1",
        "--hf-overrides.key1", "val2",
        # the compilation level given three different ways
        "-O1",
        "-O.level", "2",
        "-O3",
    ]
    parsed = parser.parse_args(cli)

    # The last occurrence of each key wins.
    assert parsed.hf_overrides == {"key1": "val2"}
    assert parsed.compilation_config == {"level": 3}

    # A single log record that names both duplicated flags.
    assert len(caplog_vllm.records) == 1
    log_text = caplog_vllm.text
    for needle in ("duplicate", "--hf-overrides.key1", "-O.level"):
        assert needle in log_text
356
357


358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
# yapf: enable
@pytest.mark.parametrize(
    "func,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
    [
        # Tests for positional argument support
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Tests for positional or keyword / keyword only
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # Tests to make sure the names of variadic params are NOT supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Tests for if we allow var kwargs to add support
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ])
# yapf: disable
def test_supports_kw(func, kw_name, requires_kw_only, allow_var_kwargs,
                     is_supported):
    """supports_kw decides whether ``func`` can be called with ``kw_name=...``.

    The parametrized argument is named ``func`` rather than ``callable`` so it
    does not shadow the builtin; it is still passed to supports_kw through that
    function's own ``callable=`` keyword.
    """
    assert supports_kw(
        callable=func,
        kw_name=kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs,
    ) == is_supported
386
387


388
@create_new_process_for_each_test()
def test_memory_profiling():
    """memory_profiling reports torch peak usage and non-torch growth.

    Fakes model loading and inference:
    - 512 MiB of raw cudaMalloc memory before the baseline;
    - 512 MiB of torch "weights";
    - inside the profiled region, a 1 GiB transient torch spike plus 256 MiB
      of extra raw cudaMalloc memory, simulating NCCL.
    """
    MiB = 1024 * 1024
    # Memory used by other processes will show up as cuda usage outside of torch
    from vllm.distributed.device_communicators.cuda_wrapper import (
        CudaRTLibrary)
    lib = CudaRTLibrary()
    # 512 MiB allocation outside of this instance
    handle1 = lib.cudaMalloc(512 * MiB)

    baseline_snapshot = MemorySnapshot()

    # load weights
    weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)
    weights_memory = 128 * MiB * 4  # 512 MiB of float32

    def measure_current_non_torch():
        # Device memory in use minus what torch's allocator has reserved.
        free, total = torch.cuda.mem_get_info()
        current_used = total - free
        current_torch = torch.cuda.memory_reserved()
        current_non_torch = current_used - current_torch
        return current_non_torch

    # Nested form of ``with A as result, B as monitored_values`` (same
    # enter/exit order).
    with memory_profiling(baseline_snapshot=baseline_snapshot,
                          weights_memory=weights_memory) as result:
        with monitor(measure_current_non_torch) as monitored_values:
            # make a memory spike, 1 GiB
            spike = torch.randn(256, 1024, 1024, device='cuda',
                                dtype=torch.float32)
            del spike

            # Add some extra non-torch memory 256 MiB (simulate NCCL)
            handle2 = lib.cudaMalloc(256 * MiB)

    # this is an analytic value, it is exact,
    # we only have 256 MiB non-torch memory increase
    measured_diff = monitored_values.values[-1] - monitored_values.values[0]
    assert measured_diff == 256 * MiB

    # Check that the memory usage is within 5% of the expected values.
    # The 5% tolerance is caused by the cuda runtime:
    # we cannot control the cuda runtime at the granularity of bytes,
    # which causes a small error (<10 MiB in practice)
    non_torch_ratio = result.non_torch_increase / (256 * MiB)  # noqa
    assert abs(non_torch_ratio - 1) <= 0.05
    assert result.torch_peak_increase == 1024 * MiB

    del weights
    lib.cudaFree(handle1)
    lib.cudaFree(handle2)
438
439


440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
def test_bind_kv_cache():
    """Each attention layer is bound to the KV cache tensor at its index."""
    from vllm.attention import Attention

    num_layers = 4
    ctx = {
        f'layers.{i}.self_attn': Attention(32, 128, 0.1)
        for i in range(num_layers)
    }
    kv_cache = [torch.zeros((1, )) for _ in range(num_layers)]
    bind_kv_cache(ctx, [kv_cache])
    for i in range(num_layers):
        assert ctx[f'layers.{i}.self_attn'].kv_cache[0] is kv_cache[i]

461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
def test_bind_kv_cache_kv_sharing():
    """Layers in the sharing map reuse the KV cache of their target layer."""
    from vllm.attention import Attention

    names = [f'layers.{i}.self_attn' for i in range(4)]
    ctx = {name: Attention(32, 128, 0.1) for name in names}
    kv_cache = [torch.zeros((1, )) for _ in names]
    shared_kv_cache_layers = {
        'layers.2.self_attn': 'layers.1.self_attn',
        'layers.3.self_attn': 'layers.0.self_attn'
    }
    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)

    # layer index -> index of the kv_cache tensor that layer must end up with
    expected_source = {0: 0, 1: 1, 2: 1, 3: 0}
    for layer_idx, cache_idx in expected_source.items():
        assert ctx[names[layer_idx]].kv_cache[0] is kv_cache[cache_idx]

486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
def test_bind_kv_cache_non_attention():
    """Caches bind in dict order even when the layer indices are sparse."""
    from vllm.attention import Attention

    # example from Jamba PP=2
    layer_names = ('model.layers.20.attn', 'model.layers.28.attn')
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = [torch.zeros((1, )) for _ in layer_names]
    bind_kv_cache(ctx, [kv_cache])
    for position, name in enumerate(layer_names):
        assert ctx[name].kv_cache[0] is kv_cache[position]


503
def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch):
    """ENCODER layers keep their own kv_cache; the other layers share one.

    Both the ENCODER_DECODER (cross-attention) layer and the DECODER layer are
    bound to the single tensor passed in.
    """
    # V1 TESTS: ENCODER_DECODER is not supported on V1 yet.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "0")

        from vllm.attention import Attention, AttentionType

        # example from bart
        ctx = {
            'encoder.layers.0.self_attn.attn':
                Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
            'decoder.layers.0.encoder_attn.attn':
                Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
            'decoder.layers.0.self_attn.attn':
                Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
        }

        kv_cache = [
            torch.zeros((1, )),
        ]
        encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache

        bind_kv_cache(ctx, [kv_cache])
        assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
        assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0]
        assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0]
529
530
531


def test_bind_kv_cache_pp():
    """With pipeline_parallel_size=2, a layer gets one cache per outer list.

    ``kv_cache`` is a list of per-layer lists. Presumably there is one outer
    entry per pipeline virtual engine — confirm against bind_kv_cache.
    """
    with patch("vllm.utils.cuda_device_count_stateless", lambda: 2):
        # this test runs with 1 GPU, but we simulate 2 GPUs
        cfg = VllmConfig(
            parallel_config=ParallelConfig(pipeline_parallel_size=2))
    with set_current_vllm_config(cfg):
        from vllm.attention import Attention

        ctx = {
            'layers.0.self_attn': Attention(32, 128, 0.1),
        }
        kv_cache = [
            [torch.zeros((1, ))],
            [torch.zeros((1, ))]
        ]
        bind_kv_cache(ctx, kv_cache)
        assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0]
        assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]


551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
class TestLRUCache(LRUCache):
    """LRUCache subclass that counts how often ``_on_remove`` is invoked.

    test_lru_cache uses it to check eviction and removal bookkeeping.
    """

    # The ``Test`` prefix would otherwise make pytest try to collect this
    # helper as a test class. It has a constructor, so pytest would only emit
    # a PytestCollectionWarning.
    __test__ = False

    def _on_remove(self, key, value):
        # The counter is created lazily on the first removal, so the base
        # class constructor does not need to be overridden.
        self._remove_counter = getattr(self, "_remove_counter", 0) + 1


def test_lru_cache():
    """Exercise TestLRUCache (capacity 3) through its method and dict APIs.

    Checks:
    - eviction order;
    - the ``_on_remove`` counter kept by TestLRUCache;
    - hit statistics from ``stat()`` (cumulative) and ``stat(delta=True)``,
      which the assertions show is relative to the previous delta call.

    The sequence is order-dependent: every assertion relies on the cache state
    and statistics left by the lines above it.
    """
    cache = TestLRUCache(3)
    assert cache.stat() == CacheInfo(hits=0, total=0)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)

    # --- put()/get()/pop() API ---
    cache.put(1, 1)
    assert len(cache) == 1

    # Re-inserting an existing key does not grow the cache.
    cache.put(1, 1)
    assert len(cache) == 1

    cache.put(2, 2)
    assert len(cache) == 2

    cache.put(3, 3)
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    # Over capacity: the least recently used key (1) is evicted.
    cache.put(4, 4)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1

    # get() on a present key is a hit.
    assert cache.get(2) == 2
    assert cache.stat() == CacheInfo(hits=1, total=1)
    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)

    # Indexing also counts; the delta covers only the lookup since the last
    # delta call.
    assert cache[2] == 2
    assert cache.stat() == CacheInfo(hits=2, total=2)
    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)

    # 2 was just read, so 3 is now the least recently used and gets evicted.
    cache.put(5, 5)
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    # pop() of a present key returns it and fires _on_remove.
    assert cache.pop(5) == 5
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    # get() of a missing key returns None and counts as a miss.
    assert cache.get(-1) is None
    assert cache.stat() == CacheInfo(hits=2, total=3)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)

    # pop()/get() of a missing key neither raise nor fire _on_remove.
    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.get(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.put(6, 6)
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache

    # remove_oldest() drops the least recently used entry (4).
    cache.remove_oldest()
    assert len(cache) == 2
    assert set(cache.cache) == {2, 6}
    assert cache._remove_counter == 4

    # clear() fires _on_remove for each remaining entry and resets the stats.
    cache.clear()
    assert len(cache) == 0
    assert cache._remove_counter == 6
    assert cache.stat() == CacheInfo(hits=0, total=0)
    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)

    # --- Same scenario through the dict-style API ---
    cache._remove_counter = 0

    cache[1] = 1
    assert len(cache) == 1

    cache[1] = 1
    assert len(cache) == 1

    cache[2] = 2
    assert len(cache) == 2

    cache[3] = 3
    assert len(cache) == 3
    assert set(cache.cache) == {1, 2, 3}

    cache[4] = 4
    assert len(cache) == 3
    assert set(cache.cache) == {2, 3, 4}
    assert cache._remove_counter == 1
    assert cache[2] == 2

    cache[5] = 5
    assert set(cache.cache) == {2, 4, 5}
    assert cache._remove_counter == 2

    # __delitem__ fires _on_remove just like pop().
    del cache[5]
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache.pop(10)
    assert len(cache) == 2
    assert set(cache.cache) == {2, 4}
    assert cache._remove_counter == 3

    cache[6] = 6
    assert len(cache) == 3
    assert set(cache.cache) == {2, 4, 6}
    assert 2 in cache
    assert 4 in cache
    assert 6 in cache


674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
# yapf: disable
@pytest.mark.parametrize(
    "src_dtype,tgt_dtype,expected_result",
    [
        # Casts across precision levels
        (torch.bool, torch.int8, True),
        (torch.bool, torch.float16, True),
        (torch.bool, torch.complex32, True),
        (torch.int64, torch.bool, False),
        (torch.int64, torch.float16, True),
        (torch.int64, torch.complex32, True),
        (torch.float64, torch.bool, False),
        (torch.float64, torch.int8, False),
        (torch.float64, torch.complex32, True),
        (torch.complex128, torch.bool, False),
        (torch.complex128, torch.int8, False),
        (torch.complex128, torch.float16, False),
        # Within precision_level=0
        (torch.bool, torch.bool, True),
        # Within precision_level=1
        (torch.int8, torch.int16, True),
        (torch.int16, torch.int8, False),
        (torch.uint8, torch.int8, False),
        (torch.int8, torch.uint8, False),
        # Within precision_level=2
        (torch.float16, torch.float32, True),
        (torch.float32, torch.float16, False),
        (torch.bfloat16, torch.float32, True),
        (torch.float32, torch.bfloat16, False),
        # Within precision_level=3
        (torch.complex32, torch.complex64, True),
        (torch.complex64, torch.complex32, False),
    ],
)
# yapf: enable
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
    """is_lossless_cast matches the expectation for each src -> tgt pair."""
    lossless = is_lossless_cast(src_dtype, tgt_dtype)
    assert lossless == expected_result


# yapf: disable
@pytest.mark.parametrize(
    "dtypes,expected_result",
    [
        ([torch.bool], torch.bool),
        ([torch.bool, torch.int8], torch.int8),
        ([torch.bool, torch.int8, torch.float16], torch.float16),
        ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32),  # noqa: E501
    ],
)
# yapf: enable
def test_common_broadcastable_dtype(dtypes, expected_result):
    """common_broadcastable_dtype returns the expected dtype for each list."""
    common = common_broadcastable_dtype(dtypes)
    assert common == expected_result


728
729
730
731
def test_placeholder_module_error_handling():
    """Using a PlaceholderModule, or an attribute placeholder, raises.

    The error is ModuleNotFoundError and is deferred. Creating the
    placeholder, calling repr()/str() on it and requesting
    ``placeholder_attr`` all succeed. Converting, calling or reading an
    attribute raises.
    """
    placeholder = PlaceholderModule("placeholder_1234")

    def build_ctx():
        return pytest.raises(ModuleNotFoundError, match="No module named")

    with build_ctx():
        int(placeholder)

    with build_ctx():
        placeholder()

    with build_ctx():
        _ = placeholder.some_attr

    with build_ctx():
        # Test conflict with internal __name attribute
        _ = placeholder.name

    # OK to print the placeholder or use it in a f-string
    _ = repr(placeholder)
    _ = str(placeholder)

    # No error yet; only error when it is used downstream
    placeholder_attr = placeholder.placeholder_attr("attr")

    with build_ctx():
        int(placeholder_attr)

    with build_ctx():
        placeholder_attr()

    with build_ctx():
        _ = placeholder_attr.some_attr

    with build_ctx():
        # Test conflict with internal __module attribute
        _ = placeholder_attr.module
766
767


768
# yapf: disable
@pytest.mark.parametrize(
    "obj,key1,key2",
    [
        # Both keys exist
        ({1: "a", 2: "b"}, 1, 2),
        # Only one key exists
        ({1: "a", 2: "b"}, 1, 3),
        # Neither key exists
        ({1: "a", 2: "b"}, 3, 4),
    ])
# yapf: enable
def test_swap_dict_values(obj, key1, key2):
    """swap_dict_values exchanges the entries stored under two keys.

    A value present under one key moves to the other key. If a key was
    absent originally, the other key must be absent after the swap.
    """
    original_obj = obj.copy()
    swap_dict_values(obj, key1, key2)
    if key1 in original_obj:
        assert obj[key2] == original_obj[key1]
    else:
        assert key2 not in obj
    if key2 in original_obj:
        assert obj[key1] == original_obj[key2]
    else:
        assert key1 not in obj
791

792

793
def test_model_specification(parser_with_config, cli_config_file,
                             cli_config_file_with_model):
    """Where the model may be specified for ``vllm serve``.

    - A positional model tag takes precedence over one from ``--config``.
    - A model given only in the config file is accepted.
    - Having no model anywhere is an error.
    - Passing the model via ``--model`` is an error.
    """
    # Test model in CLI takes precedence over config
    args = parser_with_config.parse_args(
        ['serve', 'cli-model', '--config', cli_config_file_with_model])
    assert args.model_tag == 'cli-model'
    assert args.served_model_name == 'mymodel'

    # Test model from config file works
    args = parser_with_config.parse_args([
        'serve',
        '--config',
        cli_config_file_with_model,
    ])
    assert args.model == 'config-model'
    assert args.served_model_name == 'mymodel'

    # Test no model specified anywhere raises error
    with pytest.raises(ValueError, match="No model specified!"):
        parser_with_config.parse_args(['serve', '--config', cli_config_file])

    # Test using --model option raises error
    with pytest.raises(
            ValueError,
            match=
        ("With `vllm serve`, you should provide the model as a positional "
         "argument or in a config file instead of via the `--model` option."),
    ):
        parser_with_config.parse_args(['serve', '--model', 'my-model'])

    # Test other config values are preserved
    args = parser_with_config.parse_args([
        'serve',
        'cli-model',
        '--config',
        cli_config_file_with_model,
    ])
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code is True
    assert args.multi_step_stream_outputs is False
    assert args.port == 12312


836
@pytest.mark.parametrize("input_value", [(), ("abc", ), (None, ),
                                         (None, bool, [1, 2, 3])])
def test_sha256(input_value: tuple):
    """sha256() is the big-endian int of SHA-256 over the pickled input.

    The result must be deterministic and must change when the input changes.

    The former unused ``output`` parametrization only re-ran identical cases,
    so it has been dropped. Locals no longer shadow the ``input``/``hash``/
    ``bytes`` builtins.
    """
    digest = sha256(input_value)
    assert digest is not None
    assert isinstance(digest, int)
    assert digest != 0

    payload = pickle.dumps(input_value, protocol=pickle.HIGHEST_PROTOCOL)
    assert digest == int.from_bytes(hashlib.sha256(payload).digest(),
                                    byteorder="big")

    # hashing again, returns the same value
    assert digest == sha256(input_value)

    # hashing different input, returns different value
    assert digest != sha256(input_value + (1, ))
854
855
856
857
858
859
860
861
862


@pytest.mark.parametrize(
    "path,expected",
    [
        ("ipc://some_path", ("ipc", "some_path", "")),
        ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
        ("tcp://[::1]:5555", ("tcp", "::1", "5555")),  # IPv6 address
        ("inproc://some_identifier", ("inproc", "some_identifier", "")),
    ])
def test_split_zmq_path(path, expected):
    """split_zmq_path returns (scheme, host, port); port is "" when absent."""
    assert split_zmq_path(path) == expected


@pytest.mark.parametrize(
    "invalid_path",
    [
        "invalid_path",  # Missing scheme
        "tcp://127.0.0.1",  # Missing port
        "tcp://[::1]",  # Missing port for IPv6
        "tcp://:5555",  # Missing host
    ])
def test_split_zmq_path_invalid(invalid_path):
    """Malformed ZMQ paths are rejected with ValueError."""
    with pytest.raises(ValueError):
        split_zmq_path(invalid_path)


def test_make_zmq_socket_ipv6():
    """make_zmq_socket turns on ZMQ_IPV6 for a bracketed IPv6 TCP path."""
    # Check if IPv6 is supported by trying to create an IPv6 socket.
    # (socket.error is an alias of OSError.)
    try:
        probe = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
        probe.close()
    except OSError:
        pytest.skip("IPv6 is not supported on this system")

    ctx = zmq.Context()
    # "::" is the IPv6 unspecified (any-interface) address, not loopback.
    ipv6_path = "tcp://[::]:5555"
    socket_type = zmq.REP  # Example socket type

    try:
        # Create the socket
        zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
        try:
            # Verify that the IPV6 option is set
            assert zsock.getsockopt(zmq.IPV6) == 1, (
                "IPV6 option should be enabled for IPv6 addresses")
        finally:
            # Close even when the assertion fails, so the port is released
            # and ctx.term() below does not wait on an open socket.
            zsock.close()
    finally:
        ctx.term()
903
904
905
906
907


def test_make_zmq_path():
    """make_zmq_path joins scheme/host/port and brackets IPv6 hosts."""
    cases = [
        (("tcp", "127.0.0.1", "5555"), "tcp://127.0.0.1:5555"),
        (("tcp", "::1", "5555"), "tcp://[::1]:5555"),
    ]
    for call_args, expected in cases:
        assert make_zmq_path(*call_args) == expected
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948


def test_get_tcp_uri():
    """get_tcp_uri builds a tcp:// URI and brackets IPv6 hosts."""
    cases = [
        (("127.0.0.1", 5555), "tcp://127.0.0.1:5555"),
        (("::1", 5555), "tcp://[::1]:5555"),
    ]
    for (host, port), expected in cases:
        assert get_tcp_uri(host, port) == expected


def test_split_host_port():
    """split_host_port parses ``host:port`` and ``[v6host]:port`` strictly.

    Inside ``pytest.raises`` the call is made bare: wrapping it in ``assert``
    added nothing and could only turn a "DID NOT RAISE" failure into a
    confusing AssertionError.
    """
    # valid ipv4
    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
    # invalid ipv4
    with pytest.raises(ValueError):
        # multiple colons
        split_host_port("127.0.0.1::5555")
    with pytest.raises(ValueError):
        # trailing colon
        split_host_port("127.0.0.1:5555:")
    with pytest.raises(ValueError):
        # no colon
        split_host_port("127.0.0.15555")
    with pytest.raises(ValueError):
        # non-integer port
        split_host_port("127.0.0.1:5555a")

    # valid ipv6
    assert split_host_port("[::1]:5555") == ("::1", 5555)
    # invalid ipv6
    with pytest.raises(ValueError):
        # multiple colons
        split_host_port("[::1]::5555")
    with pytest.raises(IndexError):
        # no colon
        # NOTE(review): pins current behaviour (IndexError, not ValueError).
        split_host_port("[::1]5555")
    with pytest.raises(ValueError):
        # non-integer port
        split_host_port("[::1]:5555a")


def test_join_host_port():
    """join_host_port emits ``host:port`` and brackets IPv6 hosts."""
    cases = [
        (("127.0.0.1", 5555), "127.0.0.1:5555"),
        (("::1", 5555), "[::1]:5555"),
    ]
    for (host, port), expected in cases:
        assert join_host_port(host, port) == expected
949
950
951
952
953
954
955
956
957
958
959


def test_convert_ids_list_to_tokens():
    """convert_ids_list_to_tokens turns the raw 'Ġ' marker into a space.

    The raw tokenizer tokens for the same ids keep the marker.
    """
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    ids = tok.encode("Hello, world!")  # expected ids: [9707, 11, 1879, 0]

    raw_tokens = tok.convert_ids_to_tokens(ids)
    assert raw_tokens == ['Hello', ',', 'Ġworld', '!']

    decoded_tokens = convert_ids_list_to_tokens(tok, ids)
    assert decoded_tokens == ['Hello', ',', ' world', '!']
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997


def test_current_stream_multithread():
    """current_stream() in the main thread ignores another thread's stream.

    While a child thread sits inside ``torch.cuda.stream(child_stream)``, the
    main thread's current_stream() must still be its default stream.
    """
    import threading
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    main_default_stream = torch.cuda.current_stream()
    child_stream = torch.cuda.Stream()

    thread_stream_ready = threading.Event()
    thread_can_exit = threading.Event()

    def child_thread_func():
        with torch.cuda.stream(child_stream):
            thread_stream_ready.set()
            thread_can_exit.wait(timeout=10)

    child_thread = threading.Thread(target=child_thread_func)
    child_thread.start()

    try:
        assert thread_stream_ready.wait(
            timeout=5), "Child thread failed to enter stream context in time"

        main_current_stream = current_stream()

        assert main_current_stream != child_stream, (
            "Main thread's current_stream was contaminated by child thread")
        assert main_current_stream == main_default_stream, (
            "Main thread's current_stream is not the default stream")
    finally:
        # Release the child unconditionally. If an assertion above failed and
        # the child were left waiting (up to 10 s), the 5 s join would time
        # out and pytest.fail would hide the real assertion error.
        thread_can_exit.set()
        # Ensure child thread exits properly
        child_thread.join(timeout=5)
        if child_thread.is_alive():
            pytest.fail("Child thread failed to exit properly")