test_oot_registration.py 1.36 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
from ...utils import VLLM_PATH, RemoteOpenAIServer

# Chat template handed to the server via --chat-template. The assertion runs
# at import time, so a missing template fails the module immediately.
# (VLLM_PATH is presumably the repository root; confirm in ...utils.)
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists(), (
    f"chat template not found: {chatml_jinja_path}")


def run_and_test_dummy_opt_api_server(model, tp=1):
    """Start an OpenAI-compatible server for ``model`` and smoke-test chat.

    The model is registered through the plugin. This helper checks that the
    server can be launched with it and can answer a single chat request.

    Args:
        model: Model name or path. It is passed to ``RemoteOpenAIServer`` and
            also used as the ``model`` field of the chat request.
        tp: Tensor-parallel size, forwarded to the server as ``-tp``.

    Raises:
        AssertionError: If the response carries no message content, or if
            anything other than ``"<s>"`` was generated.
    """
    # Server CLI arguments. The small GPU memory share and
    # "--load-format dummy" (presumably random/dummy weights) keep the
    # smoke test cheap.
    server_args = [
        "--gpu-memory-utilization",
        "0.10",
        "--dtype",
        "float32",
        "--chat-template",
        str(chatml_jinja_path),
        "--load-format",
        "dummy",
        "-tp",
        f"{tp}",
    ]

    with RemoteOpenAIServer(model, server_args) as server:
        client = server.get_client()
        completion = client.chat.completions.create(
            model=model,
            messages=[{
                "role": "system",
                "content": "You are a helpful assistant."
            }, {
                "role": "user",
                "content": "Hello!"
            }],
            temperature=0,
        )

        generated_text = completion.choices[0].message.content
        assert generated_text is not None, (
            "chat completion returned no message content")
        # Make sure only the first token is generated: once every "<s>" is
        # stripped, nothing may remain.
        rest = generated_text.replace("<s>", "")
        assert rest == "", f"unexpected generated text: {generated_text!r}"


def test_oot_registration_for_api_server(dummy_opt_path: str):
    """Check the plugin-registered dummy OPT model end to end via the API server.

    ``dummy_opt_path`` is a pytest fixture defined outside this file
    (presumably in a conftest). The helper runs with its default ``tp=1``.
    """
    run_and_test_dummy_opt_api_server(model=dummy_opt_path)