test_executor.py 3.94 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
import asyncio
import os
5
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
6
7
8

import pytest

9
from vllm.config import LoadFormat
10
11
12
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
13
from vllm.executor.uniproc_executor import UniProcExecutor
14
from vllm.sampling_params import SamplingParams
15
16
import os
from ..utils import models_path_prefix
17

18
19
20
21
from ..conftest import MODEL_WEIGHTS_S3_BUCKET

# Load format handed to the engine args by the tests in this module.
RUNAI_STREAMER_LOAD_FORMAT = LoadFormat.RUNAI_STREAMER

22
23
24
25
26

class Mock:
    """Empty placeholder class passed as a bogus executor backend."""


27
class CustomUniExecutor(UniProcExecutor):
    """``UniProcExecutor`` subclass that records when an RPC goes through it."""

    def collective_rpc(self,
                       method: Union[str, Callable],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict] = None) -> List[Any]:
        """Write ``.marker`` in the working directory, then delegate.

        The arguments are forwarded unchanged to the parent implementation
        and its result is returned as-is.
        """
        # Leave an (empty) marker file behind as evidence that this
        # override was actually invoked.
        with open(".marker", "w"):
            pass
        return super().collective_rpc(method, timeout, args, kwargs)
38
39


40
# Alias of CustomUniExecutor; the async engine test passes the executor
# under this name.
CustomUniExecutorAsync = CustomUniExecutor
41

42
43
@pytest.mark.parametrize("model",
                         [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"])
def test_custom_executor_type_checking(model):
    """Expect ``ValueError`` when ``Mock`` is given as the executor backend.

    The check is performed for both the synchronous ``LLMEngine`` and the
    ``AsyncLLMEngine`` construction paths.
    """
    with pytest.raises(ValueError):
        engine_args = EngineArgs(model=model,
                                 load_format=RUNAI_STREAMER_LOAD_FORMAT,
                                 distributed_executor_backend=Mock)
        LLMEngine.from_engine_args(engine_args)
    with pytest.raises(ValueError):
        # Use the same load format as the synchronous case (and as every
        # other test in this module) so that both paths are configured
        # identically apart from the engine type.
        engine_args = AsyncEngineArgs(model=model,
                                      load_format=RUNAI_STREAMER_LOAD_FORMAT,
                                      distributed_executor_backend=Mock)
        AsyncLLMEngine.from_engine_args(engine_args)


56
57
@pytest.mark.parametrize(
    "model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"])
def test_custom_executor(model, tmp_path):
    """Run one engine step with ``CustomUniExecutor`` and expect ``.marker``.

    The test runs inside ``tmp_path`` so the marker file is created there;
    the previous working directory is restored afterwards.
    """
    marker = ".marker"
    previous_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        # Nothing has run yet, so the marker must not exist.
        assert not os.path.exists(marker)

        llm_engine = LLMEngine.from_engine_args(
            EngineArgs(
                model=model,
                load_format=RUNAI_STREAMER_LOAD_FORMAT,
                distributed_executor_backend=CustomUniExecutor,
                enforce_eager=True,  # reduce test time
            ))
        llm_engine.add_request("0", "foo", SamplingParams(max_tokens=1))
        llm_engine.step()

        # The custom executor writes the marker when its RPC hook runs.
        assert os.path.exists(marker)
    finally:
        os.chdir(previous_dir)


81
82
@pytest.mark.parametrize(
    "model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"])
def test_custom_executor_async(model, tmp_path):
    """Drive one request through ``AsyncLLMEngine`` and expect ``.marker``.

    The test runs inside ``tmp_path`` so the marker file is created there;
    the previous working directory is restored afterwards.
    """
    marker = ".marker"
    previous_dir = os.getcwd()
    os.chdir(tmp_path)
    try:
        # Nothing has run yet, so the marker must not exist.
        assert not os.path.exists(marker)

        async_engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model=model,
                load_format=RUNAI_STREAMER_LOAD_FORMAT,
                distributed_executor_backend=CustomUniExecutorAsync,
                enforce_eager=True,  # reduce test time
            ))
        params = SamplingParams(max_tokens=1)

        async def drain_request():
            # Submit a single request and consume its output stream fully.
            outputs = await async_engine.add_request("0", "foo", params)
            async for _ in outputs:
                pass

        asyncio.run(drain_request())

        # The custom executor writes the marker when its RPC hook runs.
        assert os.path.exists(marker)
    finally:
        os.chdir(previous_dir)
108
109


110
111
@pytest.mark.parametrize(
    "model", [f"{MODEL_WEIGHTS_S3_BUCKET}/distilbert/distilgpt2"])
def test_respect_ray(model):
    """An explicit ``"ray"`` backend request must result in a ray executor.

    Even for TP=1 and PP=1, if users specify ray, we should use ray; users
    might do this if they want to manage the resources using ray.
    """
    llm_engine = LLMEngine.from_engine_args(
        EngineArgs(
            model=model,
            load_format=RUNAI_STREAMER_LOAD_FORMAT,
            distributed_executor_backend="ray",
            enforce_eager=True,  # reduce test time
        ))
    assert llm_engine.model_executor.uses_ray