"vscode:/vscode.git/clone" did not exist on "d4f0f17b025487e78525ba13133a161b20538d41"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import multiprocessing
from typing import Callable, Union

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext


async def generate(
        client: MQLLMEngineClient,
        request_id: str,
        num_tokens: int,
        return_output: bool = False) -> Union[RequestOutput, tuple[int, str]]:
    """Stream one fixed-prompt request through ``client`` and summarize it.

    Args:
        client: Engine client whose ``generate`` async iterator is consumed.
        request_id: Identifier forwarded to the engine; echoed back in the
            tuple result so callers can match results to requests.
        num_tokens: Used as ``max_tokens`` in the sampling parameters.
        return_output: When true, return the last streamed output instead
            of the ``(count, request_id)`` tuple.

    Returns:
        If ``return_output`` is set, the last item yielded by
        ``client.generate`` (``None`` if it yielded nothing). Otherwise a
        ``(count, request_id)`` tuple where ``count`` is the number of items
        streamed back.
    """
    final_output = None
    count = 0

    sampling_params = SamplingParams(max_tokens=num_tokens, temperature=0)
    async for out in client.generate(
            request_id=request_id,
            prompt="Hello my name is Robert and",
            sampling_params=sampling_params):

        count += 1
        final_output = out
        # Yield to the event loop so other coroutines scheduled alongside
        # this one get a chance to run between streamed outputs.
        await asyncio.sleep(0.)

    if return_output:
        return final_output

    # Hand the number of streamed outputs back to the caller, which is
    # responsible for checking it against the expected token count
    # (presumably one streamed item per generated token -- confirm).
    return count, request_id


def run_normal(engine_args: AsyncEngineArgs, ipc_path: str):
    """Construct an ``MQLLMEngine`` from ``engine_args`` and start it.

    Args:
        engine_args: Engine arguments handed to
            ``MQLLMEngine.from_engine_args``.
        ipc_path: IPC path handed to the engine.
    """
    # Build the engine with an unknown usage context, then start it.
    MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path,
    ).start()


class RemoteMQLLMEngine:
    """Context manager that runs an engine function in a spawned process.

    The child process is started as soon as the object is constructed and
    is killed (and reaped) when the ``with`` block exits.
    """

    def __init__(self,
                 engine_args: AsyncEngineArgs,
                 ipc_path: str,
                 run_fn: Callable = run_normal) -> None:
        """Spawn a process running ``run_fn(engine_args, ipc_path)``.

        Args:
            engine_args: Engine arguments; passed to ``run_fn`` and later
                used by ``make_client`` to build the engine config.
            ipc_path: IPC path shared by the engine process and the client.
            run_fn: Callable executed in the child process. Defaults to
                ``run_normal``.
        """
        self.engine_args = engine_args
        self.ipc_path = ipc_path
        # Use the "spawn" start method explicitly rather than the platform
        # default.
        context = multiprocessing.get_context("spawn")
        self.proc = context.Process(target=run_fn,
                                    args=(engine_args, ipc_path))
        self.proc.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.kill()
        # Reap the killed child so it does not linger as a zombie and so
        # the process is really gone before the caller continues.
        self.proc.join()

    async def make_client(self) -> MQLLMEngineClient:
        """Create a client for the engine process and wait until it is set up.

        ``client.setup()`` is retried for as long as it raises
        ``TimeoutError`` and the engine process is still alive.

        Returns:
            The client, after ``setup()`` has completed.

        Raises:
            AssertionError: If ``setup()`` timed out and the engine process
                is no longer alive.
        """
        engine_config = self.engine_args.create_engine_config()
        client = MQLLMEngineClient(self.ipc_path, engine_config, self.proc.pid)

        while True:
            try:
                await client.setup()
            except TimeoutError:
                # Raise explicitly instead of using ``assert`` so the check
                # survives ``python -O``; otherwise a dead engine would make
                # this loop spin forever. The exception type is unchanged.
                if not self.proc.is_alive():
                    raise AssertionError(
                        "engine process exited before the client "
                        "finished setup")
            else:
                return client