test_utils.py 24.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
# ruff: noqa
4

5
import hashlib
6
import json
7
import os
8
import pickle
9
import socket
10
11
import tempfile
from pathlib import Path
12
from unittest.mock import patch
13

14
import pytest
15
import torch
16
import yaml
17
import zmq
18
from transformers import AutoTokenizer
19
from vllm_test_utils.monitor import monitor
20

21
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
22
from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
23

24
from vllm.utils import (
25
26
27
28
29
30
31
32
33
34
35
36
    FlexibleArgumentParser,
    bind_kv_cache,
    get_open_port,
    get_tcp_uri,
    join_host_port,
    make_zmq_path,
    make_zmq_socket,
    sha256,
    split_host_port,
    split_zmq_path,
    unique_filepath,
)
37
38
39
40
41
from vllm.utils.torch_utils import (
    common_broadcastable_dtype,
    current_stream,
    is_lossless_cast,
)
42

43
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
44
from ..utils import create_new_process_for_each_test, flat_product
45

46

47
48
49
50
51
52
53
54
55
56
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
    """Check that ``get_open_port`` keeps returning bindable ports.

    ``VLLM_PORT`` is set for the duration of the test, and every returned
    port is bound while the previously returned ones are still held open.
    """
    host = "localhost"
    with monkeypatch.context() as env:
        env.setenv("VLLM_PORT", "5678")
        # All three sockets stay open together, so each call has to hand out
        # a port that can still be bound even though the env var is set.
        with (
            socket.socket(socket.AF_INET, socket.SOCK_STREAM) as first,
            socket.socket(socket.AF_INET, socket.SOCK_STREAM) as second,
            socket.socket(socket.AF_INET, socket.SOCK_STREAM) as third,
        ):
            for sock in (first, second, third):
                sock.bind((host, get_open_port()))
57
58
59
60
61
62


# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    """Return a ``FlexibleArgumentParser`` preloaded with the test options."""
    # (flags, keyword arguments) for each option, registered in this order.
    option_specs = [
        (
            ("--image-input-type",),
            {"choices": ["pixel_values", "image_features"]},
        ),
        (("--model-name",), {}),
        (("--batch-size",), {"type": int}),
        (("--enable-feature",), {"action": "store_true"}),
        (("--hf-overrides",), {"type": json.loads}),
        (("-O", "--compilation-config"), {"type": json.loads}),
    ]
    flexible_parser = FlexibleArgumentParser()
    for flags, options in option_specs:
        flexible_parser.add_argument(*flags, **options)
    return flexible_parser


74
75
76
@pytest.fixture
def parser_with_config():
    """Return a ``FlexibleArgumentParser`` shaped like the ``serve`` CLI.

    It has a ``serve`` positional, an optional ``model_tag`` positional and
    the options that the config-file tests read back.
    """
    serve_parser = FlexibleArgumentParser()

    # Positionals first, in the order they appear on the command line.
    serve_parser.add_argument("serve")
    serve_parser.add_argument("model_tag", nargs="?")

    # Typed options.
    for flag, value_type in (
        ("--model", str),
        ("--served-model-name", str),
        ("--config", str),
        ("--port", int),
        ("--tensor-parallel-size", int),
    ):
        serve_parser.add_argument(flag, type=value_type)

    serve_parser.add_argument("--trust-remote-code", action="store_true")
    return serve_parser


88
def test_underscore_to_dash(parser):
    """An option spelled with underscores is accepted for a dashed option."""
    argv = ["--image_input_type", "pixel_values"]
    namespace = parser.parse_args(argv)
    assert namespace.image_input_type == "pixel_values"
91
92
93


def test_mixed_usage(parser):
    """Underscore and dash spellings can be mixed on one command line."""
    argv = [
        "--image_input_type",
        "image_features",
        "--model-name",
        "facebook/opt-125m",
    ]
    namespace = parser.parse_args(argv)
    assert namespace.image_input_type == "image_features"
    assert namespace.model_name == "facebook/opt-125m"
99
100
101
102


def test_with_equals_sign(parser):
    """``--option=value`` syntax works for both spellings of an option."""
    argv = [
        "--image_input_type=pixel_values",
        "--model-name=facebook/opt-125m",
    ]
    namespace = parser.parse_args(argv)
    assert namespace.image_input_type == "pixel_values"
    assert namespace.model_name == "facebook/opt-125m"
107
108
109


def test_with_int_value(parser):
    """An ``int``-typed option is converted for either spelling of the flag."""
    for flag in ("--batch_size", "--batch-size"):
        namespace = parser.parse_args([flag, "32"])
        assert namespace.batch_size == 32


def test_with_bool_flag(parser):
    """A ``store_true`` flag is set by either spelling of the flag."""
    for flag in ("--enable_feature", "--enable-feature"):
        namespace = parser.parse_args([flag])
        assert namespace.enable_feature is True


def test_invalid_choice(parser):
    """A value outside the declared ``choices`` makes the parser exit."""
    bad_argv = ["--image_input_type", "invalid_choice"]
    with pytest.raises(SystemExit):
        parser.parse_args(bad_argv)
126
127
128


def test_missing_required_argument(parser):
    """Omitting an option declared ``required=True`` makes the parser exit."""
    parser.add_argument("--required-arg", required=True)
    empty_argv = []
    with pytest.raises(SystemExit):
        parser.parse_args(empty_argv)
132
133


134
def test_cli_override_to_config(parser_with_config, cli_config_file):
    """Options given on the command line win over values from ``--config``.

    The override is checked with ``--config`` placed before and after the
    overriding option; a value that is not overridden (``port``) is still
    taken from the config file.
    """
    command = ["serve", "mymodel"]
    config_opt = ["--config", cli_config_file]
    tp_opt = ["--tensor-parallel-size", "3"]

    # --config first, override second
    parsed = parser_with_config.parse_args([*command, *config_opt, *tp_opt])
    assert parsed.tensor_parallel_size == 3

    # override first, --config second; port comes from the config file
    parsed = parser_with_config.parse_args([*command, *tp_opt, *config_opt])
    assert parsed.tensor_parallel_size == 3
    assert parsed.port == 12312

    # both tensor-parallel-size and port overridden on the command line
    parsed = parser_with_config.parse_args(
        [*command, *tp_opt, *config_opt, "--port", "666"]
    )
    assert parsed.tensor_parallel_size == 3
    assert parsed.port == 666
158
159


160
def test_config_args(parser_with_config, cli_config_file):
    """Values from the ``--config`` file show up on the parsed namespace."""
    argv = ["serve", "mymodel", "--config", cli_config_file]
    parsed = parser_with_config.parse_args(argv)
    assert parsed.tensor_parallel_size == 2
    assert parsed.trust_remote_code
166
167
168
169


def test_config_file(parser_with_config):
    """Invalid ``--config`` usages raise the expected exception types."""
    command = ["serve", "mymodel"]

    # A YAML path that does not exist.
    missing_yaml = [*command, "--config", "test_config.yml"]
    with pytest.raises(FileNotFoundError):
        parser_with_config.parse_args(missing_yaml)

    # A ``.json`` path is rejected with ValueError.
    json_config = [*command, "--config", "./data/test_config.json"]
    with pytest.raises(ValueError):
        parser_with_config.parse_args(json_config)

    # ``--config`` immediately followed by another option, i.e. no file path.
    config_without_path = [
        *command,
        "--tensor-parallel-size",
        "3",
        "--config",
        "--batch-size",
        "32",
    ]
    with pytest.raises(ValueError):
        parser_with_config.parse_args(config_without_path)
191
192


193
def test_no_model_tag(parser_with_config, cli_config_file):
    """``serve`` with this config file and no model tag raises ``ValueError``."""
    argv = ["serve", "--config", cli_config_file]
    with pytest.raises(ValueError):
        parser_with_config.parse_args(argv)
196
197


198
199
200
201
202
def test_dict_args(parser):
    """Dotted ``--option.key value`` arguments are collected into dicts.

    The command line covers nesting, ``=`` syntax, underscore/dash spellings
    of the option name, value type detection, values containing ``-`` and
    ``.``, the ``-O<n>`` shorthand and list appends via ``+``.
    """
    argv = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        # Test nesting
        "--hf-overrides.key2.key3",
        "val2",
        "--hf-overrides.key2.key4",
        "val3",
        # Test compile config and compilation mode
        "-O.use_inductor=true",
        "-O.backend",
        "custom",
        "-O1",
        # Test = sign
        "--hf-overrides.key5=val4",
        # Test underscore to dash conversion
        "--hf_overrides.key_6",
        "val5",
        "--hf_overrides.key-7.key_8",
        "val6",
        # Test data type detection
        "--hf_overrides.key9",
        "100",
        "--hf_overrides.key10",
        "100.0",
        "--hf_overrides.key11",
        "true",
        "--hf_overrides.key12.key13",
        "null",
        # Test '-' and '.' in value
        "--hf_overrides.key14.key15",
        "-minus.and.dot",
        # Test array values
        "-O.custom_ops+",
        "-quant_fp8",
        "-O.custom_ops+=+silu_mul,-rms_norm",
    ]

    expected_hf_overrides = {
        "key1": "val1",
        "key2": {"key3": "val2", "key4": "val3"},
        "key5": "val4",
        "key_6": "val5",
        "key-7": {"key_8": "val6"},
        "key9": 100,
        "key10": 100.0,
        "key11": True,
        "key12": {"key13": None},
        "key14": {"key15": "-minus.and.dot"},
    }
    expected_compilation_config = {
        "mode": 1,
        "use_inductor": True,
        "backend": "custom",
        "custom_ops": ["-quant_fp8", "+silu_mul", "-rms_norm"],
    }

    namespace = parser.parse_args(argv)
    assert namespace.model_name == "something.something"
    assert namespace.hf_overrides == expected_hf_overrides
    assert namespace.compilation_config == expected_compilation_config


def test_duplicate_dict_args(caplog_vllm, parser):
    """Repeated dict keys keep the last value and log one duplicate warning."""
    argv = [
        "--model-name=something.something",
        "--hf-overrides.key1",
        "val1",
        "--hf-overrides.key1",
        "val2",
        "-O1",
        "-O.mode",
        "2",
        "-O3",
    ]

    namespace = parser.parse_args(argv)

    # The last occurrence of each key wins.
    assert namespace.hf_overrides == {"key1": "val2"}
    assert namespace.compilation_config == {"mode": 3}

    # Exactly one log record, naming both duplicated keys.
    assert len(caplog_vllm.records) == 1
    logged_text = caplog_vllm.text
    assert "duplicate" in logged_text
    assert "--hf-overrides.key1" in logged_text
    assert "-O.mode" in logged_text
290
291


292
@create_new_process_for_each_test()
def test_memory_profiling():
    """Run ``memory_profiling`` over a scripted sequence of CUDA allocations.

    The script fakes model loading plus inference: a 512 MiB allocation made
    directly through the CUDA runtime before the baseline snapshot, 512 MiB
    of torch "weights", a 1 GiB torch spike inside the profiled region, and
    a further 256 MiB CUDA-runtime allocation inside the profiled region.
    """
    from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

    mib = 1024 * 1024

    cudart = CudaRTLibrary()
    # 512 MiB allocated through the CUDA runtime, i.e. outside of torch;
    # this stands in for memory used outside of this instance.
    outside_handle = cudart.cudaMalloc(512 * mib)

    baseline_snapshot = MemorySnapshot()

    # "load weights": 128 * 1024 * 1024 float32 values
    weights = torch.randn(128, 1024, 1024, device="cuda", dtype=torch.float32)
    weights_memory = 128 * 1024 * 1024 * 4  # 512 MiB

    def measure_current_non_torch():
        # Device memory in use minus what torch has reserved.
        free_bytes, total_bytes = torch.cuda.mem_get_info()
        used_bytes = total_bytes - free_bytes
        torch_bytes = torch.cuda.memory_reserved()
        return used_bytes - torch_bytes

    with (
        memory_profiling(
            baseline_snapshot=baseline_snapshot, weights_memory=weights_memory
        ) as result,
        monitor(measure_current_non_torch) as monitored_values,
    ):
        # make a memory spike, 1 GiB
        spike = torch.randn(256, 1024, 1024, device="cuda", dtype=torch.float32)
        del spike

        # Add some extra non-torch memory 256 MiB (simulate NCCL)
        extra_handle = cudart.cudaMalloc(256 * mib)

    # This is an analytic value and therefore exact: the only non-torch
    # growth inside the monitored region is the 256 MiB allocation.
    samples = monitored_values.values
    measured_diff = samples[-1] - samples[0]
    assert measured_diff == 256 * mib

    # The profiler's own figure is only required to be within 5% of the
    # expected value: the CUDA runtime cannot be controlled at byte
    # granularity, which causes a small error (<10 MiB in practice).
    non_torch_ratio = result.non_torch_increase / (256 * mib)  # noqa
    assert abs(non_torch_ratio - 1) <= 0.05
    assert result.torch_peak_increase == 1024 * mib

    del weights
    cudart.cudaFree(outside_handle)
    cudart.cudaFree(extra_handle)
345
346


347
348
349
350
def test_bind_kv_cache():
    """Each layer receives the cache tensor at the matching list position."""
    from vllm.attention import Attention

    num_layers = 4
    layer_names = [f"layers.{idx}.self_attn" for idx in range(num_layers)]
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = [torch.zeros((1,)) for _ in range(num_layers)]

    bind_kv_cache(ctx, [kv_cache])

    # Identity (not equality) is checked: the very same tensors are bound.
    for idx, name in enumerate(layer_names):
        assert ctx[name].kv_cache[0] is kv_cache[idx]

368

369
370
371
372
def test_bind_kv_cache_kv_sharing():
    """Layers listed as sharing a cache are bound to their target's tensor."""
    from vllm.attention import Attention

    num_layers = 4
    layer_names = [f"layers.{idx}.self_attn" for idx in range(num_layers)]
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = [torch.zeros((1,)) for _ in range(num_layers)]

    # sharing layer -> layer whose cache it reuses
    shared_kv_cache_layers = {
        "layers.2.self_attn": "layers.1.self_attn",
        "layers.3.self_attn": "layers.0.self_attn",
    }
    bind_kv_cache(ctx, [kv_cache], shared_kv_cache_layers)

    # Layers 0 and 1 keep their own tensors; 2 and 3 point at 1 and 0.
    expected_cache_index = (0, 1, 1, 0)
    for name, cache_idx in zip(layer_names, expected_cache_index):
        assert ctx[name].kv_cache[0] is kv_cache[cache_idx]

394

395
396
397
398
399
def test_bind_kv_cache_non_attention():
    """Non-contiguous layer indices are bound to caches in index order."""
    from vllm.attention import Attention

    # example from Jamba PP=2
    layer_names = ["model.layers.20.attn", "model.layers.28.attn"]
    ctx = {name: Attention(32, 128, 0.1) for name in layer_names}
    kv_cache = [torch.zeros((1,)) for _ in layer_names]

    bind_kv_cache(ctx, [kv_cache])

    for position, name in enumerate(layer_names):
        assert ctx[name].kv_cache[0] is kv_cache[position]
410
411
412


def test_bind_kv_cache_pp():
    """With ``pipeline_parallel_size=2`` a layer is bound to both outer entries."""
    # this test runs with 1 GPU, but we simulate 2 GPUs
    device_count_target = "vllm.utils.torch_utils.cuda_device_count_stateless"
    with patch(device_count_target, lambda: 2):
        parallel_config = ParallelConfig(pipeline_parallel_size=2)
        cfg = VllmConfig(parallel_config=parallel_config)

    with set_current_vllm_config(cfg):
        from vllm.attention import Attention

        layer_name = "layers.0.self_attn"
        ctx = {layer_name: Attention(32, 128, 0.1)}
        kv_cache = [[torch.zeros((1,))], [torch.zeros((1,))]]
        bind_kv_cache(ctx, kv_cache)

        # Entry i of the layer's kv_cache is the first tensor of outer list i.
        bound = ctx[layer_name].kv_cache
        assert bound[0] is kv_cache[0][0]
        assert bound[1] is kv_cache[1][0]
426
427


428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
@pytest.mark.parametrize(
    ("src_dtype", "tgt_dtype", "expected_result"),
    [
        # Different precision_levels
        (torch.bool, torch.int8, True),
        (torch.bool, torch.float16, True),
        (torch.bool, torch.complex32, True),
        (torch.int64, torch.bool, False),
        (torch.int64, torch.float16, True),
        (torch.int64, torch.complex32, True),
        (torch.float64, torch.bool, False),
        (torch.float64, torch.int8, False),
        (torch.float64, torch.complex32, True),
        (torch.complex128, torch.bool, False),
        (torch.complex128, torch.int8, False),
        (torch.complex128, torch.float16, False),
        # precision_level=0
        (torch.bool, torch.bool, True),
        # precision_level=1
        (torch.int8, torch.int16, True),
        (torch.int16, torch.int8, False),
        (torch.uint8, torch.int8, False),
        (torch.int8, torch.uint8, False),
        # precision_level=2
        (torch.float16, torch.float32, True),
        (torch.float32, torch.float16, False),
        (torch.bfloat16, torch.float32, True),
        (torch.float32, torch.bfloat16, False),
        # precision_level=3
        (torch.complex32, torch.complex64, True),
        (torch.complex64, torch.complex32, False),
    ],
)
def test_is_lossless_cast(src_dtype, tgt_dtype, expected_result):
    """``is_lossless_cast`` returns the expected verdict for each dtype pair."""
    verdict = is_lossless_cast(src_dtype, tgt_dtype)
    assert verdict == expected_result


@pytest.mark.parametrize(
    ("dtypes", "expected_result"),
    [
        ([torch.bool], torch.bool),
        ([torch.bool, torch.int8], torch.int8),
        ([torch.bool, torch.int8, torch.float16], torch.float16),
        ([torch.bool, torch.int8, torch.float16, torch.complex32], torch.complex32),  # noqa: E501
    ],
)
def test_common_broadcastable_dtype(dtypes, expected_result):
    """``common_broadcastable_dtype`` returns the expected dtype for each list."""
    chosen_dtype = common_broadcastable_dtype(dtypes)
    assert chosen_dtype == expected_result


478
479
480
def test_model_specification(
    parser_with_config, cli_config_file, cli_config_file_with_model
):
    """Check how ``serve`` resolves the model from CLI, config and ``--model``."""
    parse = parser_with_config.parse_args
    config_with_model = ["--config", cli_config_file_with_model]

    # Test model in CLI takes precedence over config
    parsed = parse(["serve", "cli-model", *config_with_model])
    assert parsed.model_tag == "cli-model"
    assert parsed.served_model_name == "mymodel"

    # Test model from config file works
    parsed = parse(["serve", *config_with_model])
    assert parsed.model == "config-model"
    assert parsed.served_model_name == "mymodel"

    # Test no model specified anywhere raises error
    with pytest.raises(ValueError, match="No model specified!"):
        parse(["serve", "--config", cli_config_file])

    # Test using --model option raises error
    # with pytest.raises(
    #         ValueError,
    #         match=
    #     ("With `vllm serve`, you should provide the model as a positional "
    #      "argument or in a config file instead of via the `--model` option."),
    # ):
    #     parser_with_config.parse_args(['serve', '--model', 'my-model'])

    # Test using --model option back-compatibility
    # (when back-compatibility ends, the above test should be uncommented
    # and the below test should be removed)
    space_separated = [
        "serve",
        "--tensor-parallel-size",
        "2",
        "--model",
        "my-model",
        "--trust-remote-code",
        "--port",
        "8001",
    ]
    equals_separated = [
        "serve",
        "--tensor-parallel-size=2",
        "--model=my-model",
        "--trust-remote-code",
        "--port=8001",
    ]
    # Same expectations for both ways of attaching a value to an option.
    for argv in (space_separated, equals_separated):
        parsed = parse(argv)
        assert parsed.model is None
        assert parsed.tensor_parallel_size == 2
        assert parsed.trust_remote_code is True
        assert parsed.port == 8001

    # Test other config values are preserved
    parsed = parse(["serve", "cli-model", *config_with_model])
    assert parsed.tensor_parallel_size == 2
    assert parsed.trust_remote_code is True
    assert parsed.port == 12312


560
@pytest.mark.parametrize("input", [(), ("abc",), (None,), (None, bool, [1, 2, 3])])
def test_sha256(input: tuple):
    """``sha256`` equals SHA-256 of the pickled input and is repeatable."""
    digest = sha256(input)

    # Basic shape of the result: a non-empty bytes object.
    assert digest is not None
    assert isinstance(digest, bytes)
    assert digest != b""

    # Reference value computed by hand from the pickled input.
    pickled = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
    reference = hashlib.sha256(pickled).digest()
    assert digest == reference

    # hashing again, returns the same value
    assert digest == sha256(input)

    # hashing different input, returns different value
    extended_input = input + (1,)
    assert digest != sha256(extended_input)
575
576
577
578
579
580
581
582
583


@pytest.mark.parametrize(
    "path,expected",
    [
        ("ipc://some_path", ("ipc", "some_path", "")),
        ("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
        ("tcp://[::1]:5555", ("tcp", "::1", "5555")),  # IPv6 address
        ("inproc://some_identifier", ("inproc", "some_identifier", "")),
    ],
)
def test_split_zmq_path(path, expected):
    """``split_zmq_path`` returns the expected 3-tuple for each valid path."""
    parts = split_zmq_path(path)
    assert parts == expected


@pytest.mark.parametrize(
    "invalid_path",
    [
        "invalid_path",  # Missing scheme
        "tcp://127.0.0.1",  # Missing port
        "tcp://[::1]",  # Missing port for IPv6
        "tcp://:5555",  # Missing host
    ],
)
def test_split_zmq_path_invalid(invalid_path):
    """``split_zmq_path`` raises ``ValueError`` for each malformed path."""
    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        split_zmq_path(invalid_path)


def test_make_zmq_socket_ipv6():
    """``make_zmq_socket`` enables the ZMQ ``IPV6`` option for an IPv6 path.

    The test is skipped when the host cannot create an ``AF_INET6`` socket.
    """
    # Check if IPv6 is supported by trying to create an IPv6 socket
    try:
        socket.socket(socket.AF_INET6, socket.SOCK_STREAM).close()
    except OSError:  # socket.error is a legacy alias of OSError
        pytest.skip("IPv6 is not supported on this system")

    ctx = zmq.Context()
    zsock = None
    try:
        ipv6_path = "tcp://[::]:5555"  # IPv6 unspecified ("any") address
        socket_type = zmq.REP  # Example socket type

        # Create the socket
        zsock = make_zmq_socket(ctx, ipv6_path, socket_type)

        # Verify that the IPV6 option is set
        assert zsock.getsockopt(zmq.IPV6) == 1, (
            "IPV6 option should be enabled for IPv6 addresses"
        )
    finally:
        # Clean up even when socket creation or the assertion fails.
        if zsock is not None:
            zsock.close()
        # destroy() closes any socket still attached to the context before
        # terminating it; a plain term() would block on such a socket.
        ctx.destroy(linger=0)
627
628
629
630
631


def test_make_zmq_path():
    """``make_zmq_path`` brackets an IPv6 host and leaves an IPv4 host as is."""
    ipv4_path = make_zmq_path("tcp", "127.0.0.1", "5555")
    ipv6_path = make_zmq_path("tcp", "::1", "5555")
    assert ipv4_path == "tcp://127.0.0.1:5555"
    assert ipv6_path == "tcp://[::1]:5555"
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672


def test_get_tcp_uri():
    """``get_tcp_uri`` brackets an IPv6 host and leaves an IPv4 host as is."""
    ipv4_uri = get_tcp_uri("127.0.0.1", 5555)
    ipv6_uri = get_tcp_uri("::1", 5555)
    assert ipv4_uri == "tcp://127.0.0.1:5555"
    assert ipv6_uri == "tcp://[::1]:5555"


def test_split_host_port():
    """``split_host_port`` parses valid addresses and rejects malformed ones.

    The calls inside ``pytest.raises`` are deliberately not wrapped in
    ``assert``: only the raised exception matters, and an ``assert`` on the
    return value would turn a missing exception into an unrelated
    ``AssertionError`` whenever the returned value is falsy.
    """
    # valid ipv4
    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
    # invalid ipv4
    with pytest.raises(ValueError):
        # multiple colons
        split_host_port("127.0.0.1::5555")
    with pytest.raises(ValueError):
        # trailing colon
        split_host_port("127.0.0.1:5555:")
    with pytest.raises(ValueError):
        # no colon
        split_host_port("127.0.0.15555")
    with pytest.raises(ValueError):
        # non-integer port
        split_host_port("127.0.0.1:5555a")

    # valid ipv6
    assert split_host_port("[::1]:5555") == ("::1", 5555)
    # invalid ipv6
    with pytest.raises(ValueError):
        # multiple colons
        split_host_port("[::1]::5555")
    with pytest.raises(IndexError):
        # no colon (this case raises IndexError rather than ValueError)
        split_host_port("[::1]5555")
    with pytest.raises(ValueError):
        # non-integer port
        split_host_port("[::1]:5555a")


def test_join_host_port():
    """``join_host_port`` brackets an IPv6 host and leaves an IPv4 host as is."""
    ipv4_joined = join_host_port("127.0.0.1", 5555)
    ipv6_joined = join_host_port("::1", 5555)
    assert ipv4_joined == "127.0.0.1:5555"
    assert ipv6_joined == "[::1]:5555"
673
674
675
676
677
678


def test_convert_ids_list_to_tokens():
    """``convert_ids_list_to_tokens`` yields readable text for each token id.

    The tokenizer's own ``convert_ids_to_tokens`` returns ``"Ġworld"`` for
    the third token, whereas the helper under test returns ``" world"``.
    """
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    token_ids = tokenizer.encode("Hello, world!")
    # token_ids = [9707, 11, 1879, 0]

    raw_tokens = tokenizer.convert_ids_to_tokens(token_ids)
    assert raw_tokens == ["Hello", ",", "Ġworld", "!"]

    readable_tokens = convert_ids_list_to_tokens(tokenizer, token_ids)
    assert readable_tokens == ["Hello", ",", " world", "!"]
682
683
684
685


def test_current_stream_multithread():
    """``current_stream`` in the main thread ignores a child thread's stream.

    A child thread enters ``torch.cuda.stream(child_stream)`` and stays there
    while the main thread calls ``current_stream()``; the result must be the
    main thread's default stream, not the child's stream.

    Skipped when CUDA is not available.
    """
    import threading

    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    main_default_stream = torch.cuda.current_stream()
    child_stream = torch.cuda.Stream()

    thread_stream_ready = threading.Event()
    thread_can_exit = threading.Event()

    def child_thread_func():
        # Hold the stream context open until the main thread is done checking.
        with torch.cuda.stream(child_stream):
            thread_stream_ready.set()
            thread_can_exit.wait(timeout=10)

    child_thread = threading.Thread(target=child_thread_func)
    child_thread.start()

    try:
        assert thread_stream_ready.wait(timeout=5), (
            "Child thread failed to enter stream context in time"
        )

        main_current_stream = current_stream()

        assert main_current_stream != child_stream, (
            "Main thread's current_stream was contaminated by child thread"
        )
        assert main_current_stream == main_default_stream, (
            "Main thread's current_stream is not the default stream"
        )
    finally:
        # Release the child unconditionally. If this only happened on the
        # success path, a failed assertion would leave the child waiting for
        # its 10 s timeout, the 5 s join below would expire, and pytest.fail
        # would replace the real assertion error.
        thread_can_exit.set()

        # Ensure child thread exits properly
        child_thread.join(timeout=5)
        if child_thread.is_alive():
            pytest.fail("Child thread failed to exit properly")
726
727
728
729
730
731
732
733


def test_load_config_file(tmp_path):
    """``load_config_file`` turns a YAML mapping into a flat argument list.

    In the expected output a ``True`` value becomes a bare flag, a list
    becomes the flag followed by its items, and scalars become strings.
    """
    settings = {
        "enable-logging": True,
        "list-arg": ["item1", "item2"],
        "port": 12323,
        "tensor-parallel-size": 4,
    }

    # Serialise the settings to a temporary YAML file.
    yaml_path = tmp_path / "config.yaml"
    with open(yaml_path, "w") as handle:
        yaml.dump(settings, handle)

    flexible_parser = FlexibleArgumentParser()
    processed_args = flexible_parser.load_config_file(str(yaml_path))

    expected_args = [
        "--enable-logging",
        "--list-arg",
        "item1",
        "item2",
        "--port",
        "12323",
        "--tensor-parallel-size",
        "4",
    ]
    assert processed_args == expected_args

    os.remove(str(yaml_path))
763
764
765
766
767
768
769
770
771
772
773
774


def test_unique_filepath():
    """``unique_filepath`` returns a new, not-yet-existing path on every call.

    Each returned path is written to before the next call, so ten calls must
    produce ten distinct paths and ten ``.txt`` files in the directory.
    """
    # TemporaryDirectory removes the directory and its files on exit, even
    # when an assertion fails; mkdtemp() would leave them behind.
    with tempfile.TemporaryDirectory() as temp_dir:
        base_dir = Path(temp_dir)

        def path_fn(i):
            return base_dir / f"file_{i}.txt"

        paths = set()
        for _ in range(10):
            path = unique_filepath(path_fn)
            path.write_text("test")
            paths.add(path)

        assert len(paths) == 10
        assert len(list(base_dir.glob("*.txt"))) == 10
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796


def test_flat_product():
    """``flat_product`` matches a plain product and flattens tuple items."""
    # Check regular itertools.product behavior
    plain_pairs = list(flat_product([1, 2, 3], ["a", "b"]))
    expected_pairs = [
        (1, "a"),
        (1, "b"),
        (2, "a"),
        (2, "b"),
        (3, "a"),
        (3, "b"),
    ]
    assert plain_pairs == expected_pairs

    # check that the tuples get flattened
    flattened = list(flat_product([(1, 2), (3, 4)], ["a", "b"], [(5, 6)]))
    expected_flattened = [
        (1, 2, "a", 5, 6),
        (1, 2, "b", 5, 6),
        (3, 4, "a", 5, 6),
        (3, 4, "b", 5, 6),
    ]
    assert flattened == expected_flattened