# test_async_llm_engine.py
import asyncio
from dataclasses import dataclass

import pytest
5
import torch
6

7
from vllm import SamplingParams
8
from vllm.config import ParallelConfig
9
10
11
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine

from ..utils import wait_for_gpu_memory_to_clear
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26


@dataclass
class RequestOutput:
    """Minimal stand-in for an engine output, as emitted by MockEngine."""

    # Id of the request this output belongs to. The tests pass string ids
    # (e.g. ``generate("2")``), so this is annotated as ``str``.
    request_id: str
    # Whether the request has completed; MockEngine never sets this.
    finished: bool = False


class MockEngine:
    """Minimal stand-in for the engine wrapped by AsyncLLMEngine in tests.

    It counts how often it is stepped and how often requests are added or
    aborted. It reports at most one in-flight request, whose id is set with
    :meth:`generate` and cleared with :meth:`stop_generating`.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # Id of the request currently being "generated"; None means idle.
        self.request_id = None
        # Ugly, remove dependency when possible
        # NOTE(review): presumably read by AsyncLLMEngine; the positional
        # arguments (1, 1, False) are kept verbatim -- confirm their meaning
        # against ParallelConfig's signature.
        self.parallel_config = ParallelConfig(1, 1, False)

    async def step_async(self, virtual_engine):
        """Count one step and return an output for the active request, if any.

        Returns an empty list when no request is in flight, otherwise a
        single-element list holding a ``RequestOutput`` for that request.
        """
        # PP size is 1, ignore virtual engine
        self.step_calls += 1
        # Use the same ``is not None`` test as has_unfinished_requests() so
        # the two never disagree (e.g. for an empty-string request id).
        if self.request_id is None:
            return []
        return [RequestOutput(request_id=self.request_id)]

    async def process_model_inputs_async(self, *args, **kwargs):
        """No-op stub; all arguments are ignored."""

    async def stop_remote_worker_execution_loop_async(self):
        """No-op stub."""

    def generate(self, request_id):
        """Mark ``request_id`` as in flight so later steps report it."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the in-flight request so the engine reports no pending work."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count a synchronous add_request call; the request is discarded."""
        del kwargs  # Unused
        self.add_request_calls += 1
        print(f'Request calls: {self.add_request_calls}')

    async def add_request_async(self, **kwargs):
        """Async counterpart of :meth:`add_request` (counts, does not print)."""
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count an abort; the request id is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """Return True while a request id set via :meth:`generate` is active."""
        return self.request_id is not None

    def has_unfinished_requests_for_virtual_engine(self, virtual_engine):
        """Same as :meth:`has_unfinished_requests`; ``virtual_engine`` is ignored."""
        return self.request_id is not None

class MockAsyncLLMEngine(AsyncLLMEngine):
    """AsyncLLMEngine variant whose ``_init_engine`` builds a MockEngine.

    Everything else is inherited unchanged from AsyncLLMEngine.
    """

    def _init_engine(self, *args, **kwargs):
        # Whatever configuration the base class passes in is ignored: the
        # mock needs none, and a fresh MockEngine is created on every call.
        mock_engine = MockEngine()
        return mock_engine


@pytest.mark.asyncio
async def test_new_requests_event():
    """Check how often the background loop steps a MockEngine.

    Adds requests, toggles the mock's in-flight request with
    ``generate``/``stop_generating`` and asserts on the mock's call counters.
    The checks rely on short sleeps to let the background loop run, so some
    thresholds are deliberately loose (``>=``).
    """
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # Nothing has been added yet, so the engine must not have been stepped.
    assert engine.engine.step_calls == 0

    # Request "1" is never marked in flight on the mock, so exactly one step
    # is expected after it is added.
    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    # Mark request "2" in flight: the mock now reports unfinished work, so
    # the loop is expected to keep stepping.
    await engine.add_request("2", "", None)
    engine.engine.generate("2")
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3

    # Once the mock reports no unfinished work, stepping must stop.
    engine.engine.stop_generating()
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls == old_step_calls

    # A new request wakes the loop for exactly one more step.
    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # Separate instance (distinct name so the running engine above is not
    # shadowed) constructed with the Ray flags; its config accessors must
    # return something.
    ray_engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
    assert ray_engine.get_model_config() is not None
    assert ray_engine.get_tokenizer() is not None
    assert ray_engine.get_decoding_config() is not None


def test_asyncio_run():
    """Run two concurrent generations on a real engine via ``asyncio.run``.

    First calls ``wait_for_gpu_memory_to_clear`` over all visible CUDA
    devices (2 GiB threshold, 60 s timeout). It then builds an
    AsyncLLMEngine for ``facebook/opt-125m`` and gathers two requests,
    checking that each stream produced a final output.
    """
    wait_for_gpu_memory_to_clear(
        devices=list(range(torch.cuda.device_count())),
        threshold_bytes=2 * 2**30,
        timeout_s=60,
    )

    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    async def run(prompt: str):
        # Greedy decoding capped at 32 tokens; the prompt doubles as the
        # request id so the two concurrent requests are distinct.
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=32,
        )

        # Initialise so an empty output stream surfaces as a failed
        # assertion below rather than an UnboundLocalError here.
        final_output = None
        async for output in engine.generate(prompt,
                                            sampling_params,
                                            request_id=prompt):
            final_output = output
        return final_output

    async def generate():
        return await asyncio.gather(
            run("test0"),
            run("test1"),
        )

    results = asyncio.run(generate())
    assert len(results) == 2
    assert all(result is not None for result in results)