# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Stateless schemathesis tests generated from the OpenAPI schema that a
running vLLM OpenAI-compatible server publishes at ``/openapi.json``."""

import json
from http import HTTPStatus
from typing import Final

import pytest
import schemathesis
from httpx import URL
from hypothesis import settings
from schemathesis import GenerationConfig
from schemathesis.checks import not_a_server_error
from schemathesis.internal.checks import CheckContext
from schemathesis.models import Case
from schemathesis.transports.responses import GenericResponse

from vllm.platforms import current_platform

from ...utils import RemoteOpenAIServer

# Opt in to schemathesis' experimental OpenAPI 3.1 handling.
schemathesis.experimental.OPEN_API_3_1.enable()

MODEL_NAME = "HuggingFaceTB/SmolVLM-256M-Instruct"
# Per-prompt image limit passed to the server via --limit-mm-per-prompt.
MAXIMUM_IMAGES = 2

# ROCm runs get a 3x larger timeout budget.
# NOTE(review): presumably because ROCm CI is slower — confirm.
_ROCM_TIMEOUT_MULTIPLIER = 3 if current_platform.is_rocm() else 1
DEFAULT_TIMEOUT_SECONDS: Final[int] = 10 * _ROCM_TIMEOUT_MULTIPLIER
LONG_TIMEOUT_SECONDS: Final[int] = 60 * _ROCM_TIMEOUT_MULTIPLIER


@pytest.fixture(scope="module")
def server():
    """Launch one vLLM OpenAI-compatible server shared by this module's tests.

    Yields:
        The running ``RemoteOpenAIServer`` for ``MODEL_NAME``. Its context
        manager is exited once the module-scoped fixture is finalized.
    """
    args = [
        "--runner",
        "generate",
        "--max-model-len",
        "2048",
        "--max-num-seqs",
        "5",
        "--enforce-eager",
        "--trust-remote-code",
        "--limit-mm-per-prompt",
        # The flag takes a JSON mapping of modality -> max items per prompt.
        json.dumps({"image": MAXIMUM_IMAGES}),
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.fixture(scope="module")
def get_schema(server):
    """Load the OpenAPI schema served by the running ``server`` fixture.

    Case generation is configured with ``allow_x00=False`` so generated
    strings never contain NUL bytes.
    """
    openapi_url = f"{server.url_root}/openapi.json"
    no_nul_config = GenerationConfig(allow_x00=False)
    return schemathesis.openapi.from_uri(
        openapi_url,
        generation_config=no_nul_config,
    )


# Schema object backed by the ``get_schema`` fixture; its ``parametrize`` /
# ``override`` decorators drive the stateless test in this module.
schema = schemathesis.from_pytest_fixture("get_schema")


@schemathesis.hook
def before_generate_case(context: schemathesis.hooks.HookContext, strategy):
    """Drop generated cases that the server is known not to accept.

    Wraps the case-generation ``strategy`` of the current API operation with
    ``strategy.filter(no_invalid_types)``; see ``no_invalid_types`` for the
    exact rules.
    """
    op = context.operation
    assert op is not None
    # The operation is fixed for this strategy, so decide once whether the
    # /tokenize-specific rule applies instead of re-checking per message.
    is_tokenize = op.method.lower() == "post" and op.path == "/tokenize"

    def no_invalid_types(case: schemathesis.models.Case):
        """
        This filter skips test cases with invalid data that schemathesis
        incorrectly generates due to permissive schema configurations.

        1. Skips `POST /tokenize` endpoint cases with `"type": "file"` in
           message content, which isn't implemented.

        2. Skips tool_calls whose `"type"` is anything other than
           `"function"` (e.g. the `"type": "custom"` that schemathesis
           generates), and tool_calls that carry a `"custom"` key.

        3. Skips requests whose `structured_outputs.grammar` is an empty
           string, which breaks EBNF grammar parsing on the server.

        Returns True to keep the case, False to drop it.

        Example test cases that are skipped:
        curl -X POST -H 'Content-Type: application/json' \
            -d '{"messages": [{"content": [{"file": {}, "type": "file"}], "role": "user"}]}' \
            http://localhost:8000/tokenize

        curl -X POST -H 'Content-Type: application/json' \
            -d '{"messages": [{"role": "assistant", "tool_calls": [{"custom": {"input": "", "name": ""}, "id": "", "type": "custom"}]}]}' \
            http://localhost:8000/v1/chat/completions
        """  # noqa: E501
        body = getattr(case, "body", None)
        if not isinstance(body, dict):
            # Nothing to inspect (no body, or a non-JSON-object body).
            return True

        messages = body.get("messages")
        if isinstance(messages, list):
            for message in messages:
                if not isinstance(message, dict):
                    continue

                # Check for invalid file type in tokenize endpoint
                if is_tokenize:
                    content = message.get("content", [])
                    if isinstance(content, list) and any(
                        isinstance(item, dict) and item.get("type") == "file"
                        for item in content
                    ):
                        return False

                # Check for invalid tool_calls with non-function types
                tool_calls = message.get("tool_calls", [])
                if isinstance(tool_calls, list) and any(
                    isinstance(tool_call, dict)
                    and (
                        tool_call.get("type") != "function"
                        or "custom" in tool_call
                    )
                    for tool_call in tool_calls
                ):
                    return False

        # Sometimes structured_outputs.grammar is generated to be empty,
        # causing a server error in EBNF grammar parsing:
        # https://github.com/vllm-project/vllm/pull/22587#issuecomment-3195253421
        structured_outputs = body.get("structured_outputs", {})
        grammar = (
            structured_outputs.get("grammar")
            if isinstance(structured_outputs, dict)
            else None
        )
        # Allow None (will be handled as no grammar) but skip empty strings.
        if grammar == "":
            return False

        return True

    return strategy.filter(no_invalid_types)


def customized_not_a_server_error(
    ctx: CheckContext, response: GenericResponse, case: Case
) -> bool | None:
    """Run schemathesis' ``not_a_server_error`` check, tolerating some 501s.

    If the wrapped check raises and the response is ``501 Not Implemented``
    from ``/v1/chat/completions`` or ``/v1/chat/completions/render``, the
    failure is suppressed and True is returned. Any other failure is
    re-raised unchanged. If the check passes, its result is returned.
    """
    # Endpoints allowed to answer 501 Not Implemented without failing.
    allowed_501_paths = ("/v1/chat/completions/render", "/v1/chat/completions")
    try:
        return not_a_server_error(ctx, response, case)
    except Exception:
        request_path = URL(response.request.url).path
        is_tolerated = (
            request_path in allowed_501_paths
            and response.status_code == HTTPStatus.NOT_IMPLEMENTED.value
        )
        if not is_tolerated:
            raise
        return True


@schema.parametrize()
@schema.override(headers={"Content-Type": "application/json"})
@settings(deadline=LONG_TIMEOUT_SECONDS * 1000, max_examples=50)
def test_openapi_stateless(case: Case):
    """Send one schemathesis-generated request and validate the response.

    Stateful endpoints (responses API, weight transfer) are skipped. The
    chat/completions/messages POST endpoints get ``LONG_TIMEOUT_SECONDS``;
    everything else gets ``DEFAULT_TIMEOUT_SECONDS``. The stock
    ``not_a_server_error`` check is swapped for
    ``customized_not_a_server_error``.
    """
    path = case.operation.path

    # Skip responses API as it is meant to be stateful.
    if path.startswith("/v1/responses"):
        return

    # Skip weight transfer endpoints as they require special setup
    # (weight_transfer_config) and are meant to be stateful.
    if path in ("/init_weight_transfer_engine", "/update_weights"):
        return

    key = (case.operation.method.upper(), path)
    timeout = {
        # requires a longer timeout
        ("POST", "/v1/chat/completions"): LONG_TIMEOUT_SECONDS,
        ("POST", "/v1/completions"): LONG_TIMEOUT_SECONDS,
        ("POST", "/v1/messages"): LONG_TIMEOUT_SECONDS,
    }.get(key, DEFAULT_TIMEOUT_SECONDS)

    # No need to verify SSL certificate for localhost
    case.call_and_validate(
        verify=False,
        timeout=timeout,
        additional_checks=(customized_not_a_server_error,),
        excluded_checks=(not_a_server_error,),
    )