# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import sys
from unittest.mock import patch

from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

from ..utils import models_path_prefix


def test_mp_reducer():
    """
    Test that _reduce_config reducer is registered when AsyncLLM is instantiated
    without transformers_modules. This is a regression test for
    https://github.com/vllm-project/vllm/pull/18640.

    ``multiprocessing.reducer.register`` is patched while the engine is built,
    so every call made through it is recorded on the mock. The test then checks
    that one recorded call registered a callable reducer for ``VllmConfig``.
    """
    # Ensure transformers_modules is not in sys.modules, so AsyncLLM is
    # constructed along the code path where that package has not been imported.
    sys.modules.pop("transformers_modules", None)

    with patch("multiprocessing.reducer.register") as mock_register:
        engine_args = AsyncEngineArgs(
            model=os.path.join(models_path_prefix, "facebook/opt-125m"),
            max_model_len=32,
            gpu_memory_utilization=0.1,
            disable_log_stats=True,
        )

        async_llm = AsyncLLM.from_engine_args(
            engine_args,
            start_engine_loop=False,
        )

        # Always shut the engine down, even when an assertion below fails, so
        # it is not left running into the tests that execute after this one.
        try:
            assert mock_register.called, (
                "multiprocessing.reducer.register should have been called"
            )

            # Each recorded call is expected to look like register(cls, reducer);
            # collect the reducers that were registered for VllmConfig itself.
            vllm_config_reducers = [
                recorded.args[1]
                for recorded in mock_register.call_args_list
                if len(recorded.args) >= 2 and recorded.args[0] is VllmConfig
            ]

            assert vllm_config_reducers, (
                "VllmConfig should have been registered to multiprocessing.reducer"
            )
            assert callable(vllm_config_reducers[0]), (
                "Reducer function should be callable"
            )
        finally:
            async_llm.shutdown()