test_executor.py 3.77 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import asyncio
import os
6
from typing import Any, Callable, Optional, Union
7
8
9
10
11
12

import pytest

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
13
from vllm.executor.uniproc_executor import UniProcExecutor
14
from vllm.sampling_params import SamplingParams
15
16
import os
from ..utils import models_path_prefix
17
import vllm.envs as envs
18
19
20
21
22
23


class Mock:
    """Empty placeholder type that is not an executor class.

    Passed as ``distributed_executor_backend`` by the type-checking test in
    this module, which expects the engine to reject it with ``ValueError``.
    """


24
class CustomUniExecutor(UniProcExecutor):
    """``UniProcExecutor`` subclass that leaves a marker file when used.

    The tests in this module check for the ``.marker`` file in the current
    working directory to confirm that this custom executor was invoked.
    """

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict] = None) -> list[Any]:
        """Create ``.marker`` and then delegate to the parent implementation.

        Args:
            method: Forwarded unchanged to ``UniProcExecutor.collective_rpc``.
            timeout: Forwarded unchanged to the parent call.
            args: Forwarded unchanged to the parent call.
            kwargs: Forwarded unchanged to the parent call.

        Returns:
            Whatever the parent ``collective_rpc`` returns.
        """
        # Create (or truncate) the marker file in the current working
        # directory; its existence is the only signal the tests look for.
        with open(".marker", "w") as _marker_file:
            pass
        return super().collective_rpc(method, timeout, args, kwargs)
35
36


37
# Alias of CustomUniExecutor: the async engine test passes this name as its
# distributed_executor_backend, so both tests exercise the same class.
CustomUniExecutorAsync = CustomUniExecutor
38

zhuwenwen's avatar
zhuwenwen committed
39
@pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "distilbert/distilgpt2")])
def test_custom_executor_type_checking(model):
    """Both engine flavours must reject a non-executor backend class.

    ``Mock`` is passed as ``distributed_executor_backend``; building the
    engine (sync and async) is expected to raise ``ValueError``.
    """
    engine_flavours = (
        (EngineArgs, LLMEngine),
        (AsyncEngineArgs, AsyncLLMEngine),
    )
    for args_cls, engine_cls in engine_flavours:
        # Argument construction stays inside the ``raises`` scope so that a
        # ValueError from either step satisfies the expectation.
        with pytest.raises(ValueError):
            bad_args = args_cls(model=model,
                                distributed_executor_backend=Mock)
            engine_cls.from_engine_args(bad_args)


zhuwenwen's avatar
zhuwenwen committed
51
@pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "distilbert/distilgpt2")])
def test_custom_executor(model, tmp_path):
    """Run one engine step with ``CustomUniExecutor`` and check its marker.

    The test runs inside ``tmp_path`` so the ``.marker`` file written by
    ``CustomUniExecutor.collective_rpc`` lands in an isolated directory; the
    original working directory is always restored afterwards.
    """
    original_dir = os.path.abspath(".")
    os.chdir(tmp_path)
    marker = ".marker"
    try:
        # A fresh tmp_path must not contain the marker yet.
        assert not os.path.exists(marker)

        args = EngineArgs(
            model=model,
            distributed_executor_backend=CustomUniExecutor,
            enforce_eager=True,  # reduce test time
            block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA else 16,
        )
        llm_engine = LLMEngine.from_engine_args(args)
        one_token = SamplingParams(max_tokens=1)

        llm_engine.add_request("0", "foo", one_token)
        llm_engine.step()

        # The custom executor must have been used at least once.
        assert os.path.exists(marker)
    finally:
        os.chdir(original_dir)


75
@pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "distilbert/distilgpt2")])
def test_custom_executor_async(model, tmp_path):
    """Drive the async engine with ``CustomUniExecutorAsync`` and check its marker.

    The test runs inside ``tmp_path`` so the ``.marker`` file written by the
    custom executor lands in an isolated directory; the original working
    directory is always restored afterwards.
    """
    original_dir = os.path.abspath(".")
    os.chdir(tmp_path)
    marker = ".marker"
    try:
        # A fresh tmp_path must not contain the marker yet.
        assert not os.path.exists(marker)

        args = AsyncEngineArgs(
            model=model,
            distributed_executor_backend=CustomUniExecutorAsync,
            enforce_eager=True,  # reduce test time
            block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA else 16,
        )
        async_engine = AsyncLLMEngine.from_engine_args(args)
        one_token = SamplingParams(max_tokens=1)

        async def _drain_request():
            # Submit a single request and consume its output stream to the
            # end; the outputs themselves are not inspected.
            outputs = await async_engine.add_request("0", "foo", one_token)
            async for _ in outputs:
                pass

        asyncio.run(_drain_request())

        # The custom executor must have been used at least once.
        assert os.path.exists(marker)
    finally:
        os.chdir(original_dir)
101
102


103
@pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "distilbert/distilgpt2")])
def test_respect_ray(model):
    """An explicit ``"ray"`` backend must be honoured.

    Even for TP=1 and PP=1, if users specify ray, we should use ray.
    Users might do this if they want to manage the resources using ray.
    """
    llm_engine = LLMEngine.from_engine_args(
        EngineArgs(
            model=model,
            distributed_executor_backend="ray",
            enforce_eager=True,  # reduce test time
            block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA else 16,
        ))
    assert llm_engine.model_executor.uses_ray