# test_async_llm_engine.py
import asyncio
from dataclasses import dataclass

import pytest

from vllm.engine.async_llm_engine import AsyncLLMEngine


@dataclass
class RequestOutput:
    """Minimal stand-in for an engine output record produced by the mock.

    Attributes:
        request_id: Identifier of the request this output belongs to.
        finished: Whether the request is complete; defaults to ``False``.
    """
    # Request ids used throughout this file are strings (e.g. "2"), so the
    # annotation is ``str`` rather than ``int``.
    request_id: str
    finished: bool = False


class MockEngine:
    """Mock engine that records how often each entry point is called.

    The counters (``step_calls``, ``add_request_calls``,
    ``abort_request_calls``) are incremented by the corresponding methods.
    ``request_id`` holds the id of the request the mock pretends to be
    generating, or ``None`` when it is idle.
    """

    def __init__(self):
        self.step_calls = 0
        self.add_request_calls = 0
        self.abort_request_calls = 0
        # Id of the request currently being "generated"; None means idle.
        self.request_id = None

    async def step_async(self):
        """Count one step and return the outputs for the active request.

        Returns:
            A single-element list holding a ``RequestOutput`` for the active
            request id, or an empty list when no (truthy) id is set.
        """
        self.step_calls += 1
        return [RequestOutput(
            request_id=self.request_id)] if self.request_id else []

    async def encode_request_async(self, *args, **kwargs):
        """Accept and ignore all arguments; returns ``None``."""
        pass

    def generate(self, request_id):
        """Mark ``request_id`` as the request currently being generated."""
        self.request_id = request_id

    def stop_generating(self):
        """Clear the active request so the mock reports no unfinished work."""
        self.request_id = None

    def add_request(self, **kwargs):
        """Count an add-request call; the keyword arguments are ignored."""
        del kwargs  # Unused
        self.add_request_calls += 1

    async def add_request_async(self, **kwargs):
        """Async variant of ``add_request``; bumps the same counter."""
        del kwargs  # Unused
        self.add_request_calls += 1

    def abort_request(self, request_id):
        """Count an abort call; ``request_id`` is ignored."""
        del request_id  # Unused
        self.abort_request_calls += 1

    def has_unfinished_requests(self):
        """Return ``True`` while a request id is set (i.e. not ``None``)."""
        return self.request_id is not None


class MockAsyncLLMEngine(AsyncLLMEngine):
    """``AsyncLLMEngine`` subclass whose ``_init_engine`` yields a mock.

    NOTE(review): presumably the base class calls ``_init_engine`` to build
    the object exposed as ``.engine`` -- confirm against ``AsyncLLMEngine``.
    """

    def _init_engine(self, *args, **kwargs):
        """Return a fresh ``MockEngine``; all arguments are ignored."""
        mock_engine = MockEngine()
        return mock_engine


@pytest.mark.asyncio
async def test_new_requests_event():
    """Check when the background loop steps the (mock) engine.

    The test adds requests and toggles the mock's "generating" state, then
    asserts on the mock's ``step_calls`` / ``add_request_calls`` counters
    after short sleeps that let the background loop run.
    """
    engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
    engine.start_background_loop()
    await asyncio.sleep(0.01)
    # Nothing has been submitted yet, so no step is expected.
    assert engine.engine.step_calls == 0

    # First request: the mock is not "generating", and exactly one step is
    # expected after the request is added.
    await engine.add_request("1", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 1
    assert engine.engine.step_calls == 1

    # Second request, with the mock reporting unfinished work: the step
    # counter is expected to keep growing.
    await engine.add_request("2", "", None)
    engine.engine.generate("2")
    # Yield control to the event loop twice without actually sleeping.
    await asyncio.sleep(0)
    await asyncio.sleep(0)
    assert engine.engine.add_request_calls == 2
    assert engine.engine.step_calls >= 2
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls >= 3
    engine.engine.stop_generating()
    # Once the mock is idle again, the step counter must stop changing.
    await asyncio.sleep(0.001)
    old_step_calls = engine.engine.step_calls
    await asyncio.sleep(0.001)
    assert engine.engine.step_calls == old_step_calls

    # Third request while idle: exactly one more step, and no further steps
    # after waiting again.
    await engine.add_request("3", "", None)
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1
    await asyncio.sleep(0.01)
    assert engine.engine.add_request_calls == 3
    assert engine.engine.step_calls == old_step_calls + 1

    # A second engine built with both ray flags enabled must still expose a
    # tokenizer.
    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
    assert engine.get_tokenizer() is not None