test_oot_registration.py 2.29 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import sys
import time

import torch
from openai import OpenAI, OpenAIError

from vllm import ModelRegistry
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port


class MyOPTForCausalLM(OPTForCausalLM):
    """Dummy OPT variant whose logits always favour index 0."""

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the parent's logits tensor overwritten to select index 0.

        The parent implementation is still invoked so the returned tensor
        keeps whatever shape/dtype/device it produces; its values are then
        discarded in place.
        """
        scores = super().compute_logits(hidden_states, sampling_metadata)
        # Wipe every score, then give column 0 the only non-zero value so
        # this dummy model always predicts the first token.
        scores.zero_()
        scores[:, 0] = 1.0
        return scores


def server_function(port):
    """Register the dummy model and run the OpenAI-compatible API server.

    Intended to be used as the target of a child process: it rewrites
    ``sys.argv`` and then executes the API server module as ``__main__``.

    Args:
        port: Port number passed to the server via ``--port``.
    """
    import runpy

    # register our dummy model
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)

    # The server module parses its options from sys.argv, so fake a
    # command line before running it.
    cli_options = ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
                   f"--dtype float32 --api-key token-abc123 --port {port}")
    sys.argv = ["placeholder.py", *cli_options.split()]

    runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')


def test_oot_registration_for_api_server():
    """Check that an out-of-tree model registration is used by the API server.

    Starts ``server_function`` in a child process, polls the chat completions
    endpoint until the server answers, and asserts that the dummy model
    produced nothing but ``<s>`` tokens.

    Raises:
        RuntimeError: If the server process exits before answering, or does
            not answer within ``MAX_SERVER_START_WAIT_S`` seconds.
    """
    port = get_open_port()
    ctx = torch.multiprocessing.get_context()
    server = ctx.Process(target=server_function, args=(port, ))
    server.start()
    MAX_SERVER_START_WAIT_S = 60
    client = OpenAI(
        base_url=f"http://localhost:{port}/v1",
        api_key="token-abc123",
    )
    # Monotonic clock: elapsed-time checks must not be affected by
    # wall-clock adjustments.
    start = time.monotonic()
    try:
        while True:
            try:
                completion = client.chat.completions.create(
                    model="facebook/opt-125m",
                    messages=[{
                        "role": "system",
                        "content": "You are a helpful assistant."
                    }, {
                        "role": "user",
                        "content": "Hello!"
                    }],
                    temperature=0,
                )
                break
            except OpenAIError as e:
                # Only connection failures mean "server not up yet";
                # anything else is a genuine error and is re-raised as is.
                if "Connection error" not in str(e):
                    raise
                # Fail fast instead of waiting out the whole timeout when
                # the server process has already died.
                if not server.is_alive():
                    raise RuntimeError(
                        "Server process exited before it became ready"
                    ) from e
                time.sleep(3)
                if time.monotonic() - start > MAX_SERVER_START_WAIT_S:
                    raise RuntimeError("Server did not start in time") from e
    finally:
        # Always tear the server down, including on timeout or unexpected
        # errors, so a failed test does not leave the process running.
        server.kill()
        server.join()
    generated_text = completion.choices[0].message.content
    # make sure only the first token is generated
    rest = generated_text.replace("<s>", "")
    assert rest == ""