# test_oot_registration.py
import sys
import time

import torch
from openai import OpenAI, OpenAIError

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port

from ...utils import VLLM_PATH, RemoteOpenAIServer

# Chat template handed to the API server through ``--chat-template`` in
# ``server_function``; fail at import time if the template file is missing.
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
assert chatml_jinja_path.exists()


class MyOPTForCausalLM(OPTForCausalLM):
    """Dummy OPT variant whose logits always favour the first token.

    Only ``compute_logits`` is overridden; everything else is inherited from
    ``OPTForCausalLM``.
    """

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the parent's logits overwritten so column 0 always wins.

        Args:
            hidden_states: Forwarded unchanged to the parent implementation.
            sampling_metadata: Forwarded unchanged to the parent
                implementation.

        Returns:
            The tensor produced by the parent, modified in place so that
            every entry is 0 except column 0, which is 1.
        """
        # this dummy model always predicts the first token
        logits = super().compute_logits(hidden_states, sampling_metadata)
        # NOTE(review): the parent may hand back None on workers that do not
        # gather logits (e.g. non-driver tensor-parallel ranks) -- confirm.
        # Guarding keeps the override from crashing there; when a tensor is
        # returned the behaviour is unchanged.
        if logits is not None:
            logits.zero_()
            logits[:, 0] += 1.0
        return logits


def server_function(port: int):
    """Register the dummy model and run the OpenAI API server module.

    Used as the target of the child process started by the test, so the
    registration takes effect inside the process that serves requests.

    Args:
        port: Port number passed to the server through ``--port``.
    """
    import runpy

    # register our dummy model
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)

    # The server module reads its options from sys.argv, so fake a command
    # line before executing it as ``__main__``.
    cli_args = [
        "--model",
        "facebook/opt-125m",
        "--gpu-memory-utilization",
        "0.10",
        "--dtype",
        "float32",
        "--api-key",
        "token-abc123",
        "--port",
        str(port),
        "--chat-template",
        str(chatml_jinja_path),
    ]
    sys.argv = ["placeholder.py", *cli_args]

    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')


def test_oot_registration_for_api_server():
    """Start the API server in a child process and query the dummy model.

    The server is launched via ``server_function`` on a free port. The test
    retries a chat completion until the server accepts connections, then
    checks that some text was generated.

    Raises:
        RuntimeError: If the server process exits before answering, or does
            not accept connections within
            ``RemoteOpenAIServer.MAX_START_WAIT_S`` seconds.
        OpenAIError: Any client error other than a connection error.
    """
    port = get_open_port()
    ctx = torch.multiprocessing.get_context()
    server = ctx.Process(target=server_function, args=(port, ))
    server.start()

    # Seconds to wait for the child to exit after terminate() before
    # escalating to kill().
    shutdown_grace_s = 30

    try:
        client = OpenAI(
            base_url=f"http://localhost:{port}/v1",
            api_key="token-abc123",
        )
        # Monotonic clock: elapsed time must not be affected by wall-clock
        # adjustments.
        start = time.monotonic()
        while True:
            try:
                completion = client.chat.completions.create(
                    model="facebook/opt-125m",
                    messages=[{
                        "role": "system",
                        "content": "You are a helpful assistant."
                    }, {
                        "role": "user",
                        "content": "Hello!"
                    }],
                    temperature=0,
                )
                break
            except OpenAIError as e:
                # Only connection errors mean "server not up yet"; anything
                # else is a real failure and propagates unchanged.
                if "Connection error" not in str(e):
                    raise
                # A dead server will never come up: fail now instead of
                # polling until the start-up timeout expires.
                if not server.is_alive():
                    msg = "Server exited unexpectedly"
                    raise RuntimeError(msg) from e
                time.sleep(3)
                if (time.monotonic() - start >
                        RemoteOpenAIServer.MAX_START_WAIT_S):
                    msg = "Server did not start in time"
                    raise RuntimeError(msg) from e
    finally:
        # Stop the child and reap it so no stray process outlives the test.
        server.terminate()
        server.join(timeout=shutdown_grace_s)
        if server.is_alive():
            server.kill()
            server.join()

    generated_text = completion.choices[0].message.content
    assert generated_text is not None
    # make sure only the first token is generated
    # TODO(youkaichao): Fix the test with plugin
    rest = generated_text.replace("<s>", "")  # noqa
    # assert rest == ""