# test_request_tracker.py: tests for RequestTracker (vllm.engine.async_llm_engine).
import pytest

from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput


@pytest.mark.asyncio
async def test_request_tracker():
    """Exercise RequestTracker's add / drain / abort / finish lifecycle.

    Each phase adds or aborts requests, waits on the tracker's new-request
    event, drains it with ``get_new_and_aborted_requests`` and checks both
    the drained lists and the per-request stream ``finished`` flags.
    """
    tracker = RequestTracker()

    # A single new request sets the event; draining returns it and clears
    # the event.
    stream_1 = tracker.add_request("1")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, aborted = tracker.get_new_and_aborted_requests()
    assert not tracker.new_requests_event.is_set()
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not aborted
    assert not stream_1.finished

    # Several pending requests are drained together, in insertion order.
    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, aborted = tracker.get_new_and_aborted_requests()
    assert not tracker.new_requests_event.is_set()
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
    assert not aborted
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique; a rejected duplicate must not set the
    # new-request event.
    with pytest.raises(KeyError):
        tracker.add_request("1")
    assert not tracker.new_requests_event.is_set()

    # Aborting an already-drained request reports it as aborted and
    # finishes its stream.
    tracker.abort_request("1")
    new, aborted = tracker.get_new_and_aborted_requests()
    assert len(aborted) == 1
    assert "1" in aborted
    assert not new
    assert stream_1.finished

    # A request aborted before it was drained is reported only as aborted,
    # never as new; the event is still set after the add + abort pair.
    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
    assert tracker.new_requests_event.is_set()
    await tracker.wait_for_new_requests()
    new, aborted = tracker.get_new_and_aborted_requests()
    assert len(aborted) == 1
    assert "4" in aborted
    assert not new
    assert stream_4.finished

    # A finished RequestOutput for request "2" finishes its stream without
    # it being reported as aborted; the pending request "5" is still
    # drained as new.
    stream_5 = tracker.add_request("5")
    assert tracker.new_requests_event.is_set()
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], [], finished=True))
    await tracker.wait_for_new_requests()
    new, aborted = tracker.get_new_and_aborted_requests()
    assert not tracker.new_requests_event.is_set()
    assert not aborted
    assert len(new) == 1
    assert new[0]["request_id"] == "5"
    assert stream_2.finished
    assert not stream_5.finished