worker.py 5.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
17
18
import json
from typing import AsyncGenerator, AsyncIterator
19
20
21

import uvloop
from common.parser import parse_vllm_args
22
from vllm.config import ModelConfig
23
from vllm.engine.arg_utils import AsyncEngineArgs
24
25
26
27
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args,
)
28
29
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
30
    ChatCompletionResponse,
31
    ChatCompletionStreamResponse,
32
33
34
35
    CompletionRequest,
    CompletionResponse,
    CompletionStreamResponse,
    ErrorResponse,
36
)
37
38
39
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
40

Neelay Shah's avatar
Neelay Shah committed
41
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
Neelay Shah's avatar
Neelay Shah committed
42

43

44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class VllmEngine:
    """Adapts a vLLM async engine to Dynamo streaming endpoints.

    Exposes OpenAI-compatible `chat/completions` and `completions` handlers
    that consume vLLM's SSE (`data: {...}`) stream and yield parsed JSON
    dicts, one per chunk.
    """

    def __init__(
        self, engine_client: AsyncIterator[EngineClient], model_config: ModelConfig
    ):
        """
        Args:
            engine_client: Live vLLM engine client used to run inference.
            model_config: Resolved model configuration from the engine.
        """
        self.engine_client = engine_client
        self.model_config = model_config

        # Ensure served_model_name matches the openai model name
        # Use --served-model-name to explicitly set this or it will fallback to --model
        models = OpenAIServingModels(
            engine_client=engine_client,
            model_config=model_config,
            base_model_paths=[
                BaseModelPath(
                    name=model_config.served_model_name,
                    model_path=model_config.model,
                )
            ],
        )

        self.chat_serving = OpenAIServingChat(
            engine_client=self.engine_client,
            model_config=self.model_config,
            models=models,
            response_role="assistant",
            request_logger=None,
            chat_template=None,
            chat_template_content_format="auto",
        )
        self.completion_serving = OpenAIServingCompletion(
            engine_client=self.engine_client,
            model_config=self.model_config,
            models=models,
            request_logger=None,
        )

    @staticmethod
    async def _stream_sse_json(stream):
        """Yield parsed JSON dicts from an SSE stream of `data: ...` lines.

        Stops (without yielding) at the `data: [DONE]` sentinel.
        """
        async for raw_response in stream:
            if raw_response.startswith("data: [DONE]"):
                break
            # BUG FIX: str.lstrip("data: ") strips the *character set*
            # {d, a, t, :, space}, not the literal prefix; removeprefix
            # drops exactly the leading "data: " and nothing else.
            yield json.loads(raw_response.removeprefix("data: "))

    @staticmethod
    def _raise_error_response(result: ErrorResponse):
        """Convert a vLLM ErrorResponse into a RuntimeError for the caller."""
        error = result.dict()
        raise RuntimeError(
            f"Error {error['code']}: {error['message']} "
            f"(type: {error['type']}, param: {error['param']})"
        )

    @dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
    async def generate_chat(self, request):
        """Stream chat-completion chunks for `request` as parsed dicts.

        Raises:
            RuntimeError: on an ErrorResponse, or if vLLM returns a
                non-streaming ChatCompletionResponse (unsupported).
            TypeError: on any other unexpected result type.
        """
        result = await self.chat_serving.create_chat_completion(request)

        if isinstance(result, AsyncGenerator):
            async for response in self._stream_sse_json(result):
                yield response

        # We should always be streaming so should never get here
        elif isinstance(result, ChatCompletionResponse):
            raise RuntimeError("ChatCompletionResponse support not implemented")

        elif isinstance(result, ErrorResponse):
            self._raise_error_response(result)

        else:
            raise TypeError(f"Unexpected response type: {type(result)}")

    @dynamo_endpoint(CompletionRequest, CompletionStreamResponse)
    async def generate_completions(self, request):
        """Stream completion chunks for `request` as parsed dicts.

        Raises:
            RuntimeError: on an ErrorResponse, or if vLLM returns a
                non-streaming CompletionResponse (unsupported).
            TypeError: on any other unexpected result type.
        """
        result = await self.completion_serving.create_completion(request)

        if isinstance(result, AsyncGenerator):
            async for response in self._stream_sse_json(result):
                yield response

        # We should always be streaming so should never get here
        elif isinstance(result, CompletionResponse):
            raise RuntimeError("CompletionResponse support not implemented")

        elif isinstance(result, ErrorResponse):
            self._raise_error_response(result)

        else:
            raise TypeError(f"Unexpected response type: {type(result)}")
129
130


Neelay Shah's avatar
Neelay Shah committed
131
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
    """
    Instantiate a `backend` component and serve the `generate` endpoint
    A `Component` can serve multiple endpoints
    """
    # Register this worker as the "vllm" component in the "dynamo" namespace.
    vllm_component = runtime.namespace("dynamo").component("vllm")
    await vllm_component.create_service()

    # Endpoint handles are created up front, before the engine comes up.
    chat_ep = vllm_component.endpoint("chat/completions")
    completions_ep = vllm_component.endpoint("completions")

    # Keep the engine client alive for the lifetime of both endpoints.
    async with build_async_engine_client_from_engine_args(engine_args) as client:
        engine = VllmEngine(client, await client.get_model_config())

        # Serve both OpenAI-style endpoints concurrently until shutdown.
        await asyncio.gather(
            chat_ep.serve_endpoint(engine.generate_chat),
            completions_ep.serve_endpoint(engine.generate_completions),
        )
151
152
153
154
155
156


if __name__ == "__main__":
    # Install uvloop as the asyncio event loop policy before asyncio.run.
    uvloop.install()
    # Parse vLLM engine arguments from the command line.
    engine_args = parse_vllm_args()
    # `worker` is wrapped by @dynamo_worker(), which supplies the runtime;
    # only the engine args are passed here.
    asyncio.run(worker(engine_args))