import asyncio
import os
import sys
from contextlib import ExitStack
from pathlib import Path

import pytest
from vllm import SamplingParams
from vllm.inputs import PromptType

from vllm_omni.entrypoints.async_omni import AsyncOmni, ClientRequestState

# Use "spawn" for worker processes; "fork" is unsafe once CUDA is initialized.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

SEED = 42

stage_config = str(Path(__file__).parent / "stage_configs" / "qwen3_omni_thinker_ci.yaml")
model = "Qwen/Qwen3-Omni-30B-A3B-Instruct"


async def generate(
    engine: AsyncOmni,
    request_id: str,
    prompt: PromptType,
    max_tokens: int,
) -> tuple[int, str]:
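    """Stream one request through the engine and return (num_tokens, request_id).

    Token counts are taken from stage 0 (the thinker stage in this CI config).
    """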
    # Ensure generation doesn't complete too fast for the cancellation test.
    await asyncio.sleep(0.2)
    thinker_sampling_params = SamplingParams(
        temperature=0.4,
        top_p=0.9,
        top_k=1,  # Greedy decoding: top_k=1 makes sampling deterministic.
        max_tokens=max_tokens,
        repetition_penalty=1.05,
        stop_token_ids=[151645],  # Qwen EOS token <|im_end|>
        seed=SEED,
    )

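    # One SamplingParams per stage; only the thinker stage is exercised here.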
    sampling_params_list = [
        thinker_sampling_params,
    ]
    count = 0
    async for omni_output in engine.generate(
        prompt=prompt,
        request_id=request_id,
        sampling_params_list=sampling_params_list,
        output_modalities=["text"],
    ):
        stage_id = omni_output.stage_id
        out = omni_output.request_output
        if stage_id == 0:
            # token_ids accumulate across chunks, so the final chunk gives the total.
            count = sum(len(output.token_ids) for output in out.outputs)

        # Yield to the event loop so other tasks (e.g. aborts) can run.
        await asyncio.sleep(0.0)

    return count, request_id


@pytest.mark.asyncio
async def test_abort():
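    """Abort one in-flight request and verify concurrent requests finish unaffected."""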
    with ExitStack() as after:
        # Avoid SHM IPC in tests to prevent /dev/shm exhaustion and SIGBUS.
        engine = AsyncOmni(
            model=model,
            stage_configs_path=stage_config,
            shm_threshold_bytes=sys.maxsize,
        )
        after.callback(engine.shutdown)

        # Keep token counts modest to reduce flakiness on slow test hardware.
        NUM_REQUESTS = 3
        NUM_EXPECTED_TOKENS = 64
        NUM_EXPECTED_TOKENS_LONG = 256
        REQUEST_INDICES_TO_ABORT = [1]

        prompt = "Hello my name is Robert and "

        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for idx, request_id in enumerate(request_ids):
            max_tokens = NUM_EXPECTED_TOKENS_LONG if idx in REQUEST_INDICES_TO_ABORT else NUM_EXPECTED_TOKENS
            tasks.append(asyncio.create_task(generate(engine, request_id, prompt, max_tokens)))

        # The API server cancels request tasks when clients disconnect; abort
        # explicitly in the engine as well so aborted requests can't hang.
        for idx in REQUEST_INDICES_TO_ABORT:
            tasks[idx].cancel()
            await engine.abort(request_ids[idx])
            await asyncio.sleep(0.1)

        # Check each task: aborted ones must be cancelled, the rest unaffected.
        for idx, task in enumerate(tasks):
            # Confirm that it was actually canceled.
            if idx in REQUEST_INDICES_TO_ABORT:
                with pytest.raises((asyncio.CancelledError, GeneratorExit)):
                    await asyncio.wait_for(task, timeout=60)
            else:
                # Otherwise, make sure the request was not impacted.
                num_generated_tokens, request_id = await asyncio.wait_for(task, timeout=180)
                expected_tokens = NUM_EXPECTED_TOKENS
                assert num_generated_tokens == expected_tokens, (
                    f"{request_id} generated {num_generated_tokens} but expected {expected_tokens}"
                )

        # Confirm the engine accepts a new generation that reuses the aborted request ID.
        request_id = f"request-{REQUEST_INDICES_TO_ABORT[0]}"
        num_generated_tokens, request_id = await generate(engine, request_id, prompt, NUM_EXPECTED_TOKENS)
        assert num_generated_tokens == NUM_EXPECTED_TOKENS
    # Give engine worker processes a moment to shut down cleanly.
    await asyncio.sleep(5)


@pytest.mark.asyncio
async def test_build_and_log_summary(monkeypatch):
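    """Capture per-request metrics via a patched ClientRequestState and sanity-check the built summary."""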
    from vllm_omni.entrypoints.utils import get_final_stage_id_for_e2e

    RealCRS = ClientRequestState
    capture_metrics: dict[str, ClientRequestState] = {}

    class MockCRS(RealCRS):
        def __init__(self, request_id: str):
            super().__init__(request_id)
            capture_metrics[request_id] = self

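    # Patch both import sites so whichever module constructs the request state uses the mock.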
    monkeypatch.setattr("vllm_omni.entrypoints.async_omni.ClientRequestState", MockCRS)
    monkeypatch.setattr("vllm_omni.entrypoints.client_request_state.ClientRequestState", MockCRS)

    with ExitStack() as after:
        # Avoid SHM IPC in tests to prevent /dev/shm exhaustion and SIGBUS.
        engine = AsyncOmni(
            model=model,
            stage_configs_path=stage_config,
            shm_threshold_bytes=sys.maxsize,
        )
        after.callback(engine.shutdown)
        prompt = "Hello my name is Robert and "
        NUM_EXPECTED_TOKENS = 64
        NUM_REQUESTS = 3
        request_ids = [f"request-{i}" for i in range(NUM_REQUESTS)]

        # Create concurrent requests.
        tasks: list[asyncio.Task] = []
        for request_id in request_ids:
            tasks.append(asyncio.create_task(generate(engine, request_id, prompt, NUM_EXPECTED_TOKENS)))

        # Await each request, then validate its logged summary.
        output_modalities = ["text"]
        final_stage_id_for_e2e = get_final_stage_id_for_e2e(
            output_modalities, engine.output_modalities, engine.stage_list
        )
        for idx, task in enumerate(tasks):
            await task
            summary = capture_metrics[request_ids[idx]].metrics.build_and_log_summary(final_stage_id_for_e2e)

            # Check that total tokens matches sum of stage tokens.
            assert summary["e2e_total_tokens"] == sum(stage["tokens"] for stage in summary["stages"])
            # E2E time includes inter-stage overhead, so it is at least the sum of stage times.
            assert summary["e2e_total_time_ms"] >= sum(stage["total_time_ms"] for stage in summary["stages"])