Unverified Commit 75a69cd3 authored by ptarasiewiczNV's avatar ptarasiewiczNV Committed by GitHub
Browse files

feat: add vLLM V1 PD disagg example (#1013)

parent 4fd4d53d
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# vLLM Deployment Examples
This directory contains examples for deploying vLLM models in both aggregated and disaggregated configurations.
## Prerequisites
1. Install vLLM:
```bash
# Note: Currently requires installation from main branch
# From vLLM 0.8.6 onwards, you can install directly from wheel
git clone https://github.com/vllm-project/vllm.git
VLLM_USE_PRECOMPILED=1 uv pip install --editable ./vllm/
```
2. Start required services:
```bash
docker compose -f deploy/metrics/docker-compose.yml up -d
```
## Running the Server
### Aggregated Deployment
```bash
cd examples/vllm_v1
dynamo serve graphs.agg:Frontend -f configs/agg.yaml
```
### Disaggregated Deployment
```bash
cd examples/vllm_v1
dynamo serve graphs.disagg:Frontend -f configs/disagg.yaml
```
## Testing the API
Send a test request using curl:
```bash
curl localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"prompt": "In the heart of Eldoria...",
"stream": false,
"max_tokens": 30
}'
```
For more detailed explenations, refer to the main [LLM examples README](../llm/README.md).
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import subprocess
from pathlib import Path
from components.simple_load_balancer import SimpleLoadBalancer
from fastapi import FastAPI
from pydantic import BaseModel
import dynamo.sdk as sdk
from dynamo.sdk import depends, service
from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__)
def get_dynamo_run_binary():
"""Find the dynamo-run binary path in SDK or fallback to 'dynamo-run' command."""
sdk_path = Path(sdk.__file__)
binary_path = sdk_path.parent / "cli/bin/dynamo-run"
if not binary_path.exists():
return "dynamo-run"
else:
return str(binary_path)
class FrontendConfig(BaseModel):
"""Configuration for the Frontend service including model and HTTP server settings."""
served_model_name: str
endpoint: str
port: int = 8080
# TODO: move these to common for all LLMs once we adopt dynamo-run
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
workers=1,
image=DYNAMO_IMAGE,
app=FastAPI(title="LLM Example"),
)
class Frontend:
worker = depends(SimpleLoadBalancer)
def __init__(self):
"""Initialize Frontend service with HTTP server and model configuration."""
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
self.frontend_config = frontend_config
self.process = None
self.start_ingress_and_processor()
def start_ingress_and_processor(self):
"""Starting dynamo-run based ingress and processor"""
logger.info(
f"Starting HTTP server and processor on port {self.frontend_config.port}"
)
dynamo_run_binary = get_dynamo_run_binary()
endpoint = f"dyn://{self.frontend_config.endpoint}"
logger.info(
f"Starting HTTP server and processor on port {self.frontend_config.port}"
)
logger.info(f"Endpoint: {endpoint}")
self.process = subprocess.Popen(
[
dynamo_run_binary,
"in=http",
f"out={endpoint}",
"--http-port",
str(self.frontend_config.port),
],
stdout=None,
stderr=None,
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import uuid
from typing import AsyncGenerator, Optional
from components.worker import VllmDecodeWorker, VllmPrefillWorker
from utils.args import parse_vllm_args
from utils.protocol import MyRequestOutput, PreprocessedRequest, vLLMGenerateRequest
from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams
from dynamo.llm import ModelType, register_llm
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service
logger = logging.getLogger(__name__)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
class SimpleLoadBalancer:
prefill_worker = depends(VllmPrefillWorker)
decode_worker = depends(VllmDecodeWorker)
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
model_config = self.engine_args.create_model_config()
self.default_sampling_params = model_config.get_diff_sampling_param()
self.enable_disagg = self.engine_args.enable_disagg
@async_on_start
async def async_init(self):
runtime = dynamo_context["runtime"]
logger.info("Registering LLM for discovery")
comp_ns, comp_name = SimpleLoadBalancer.dynamo_address() # type: ignore
endpoint_name = "generate"
for served_model_name in self.engine_args.served_model_name:
logger.info(
f"Registering endpoint {endpoint_name} with model {self.engine_args.model} and served_model_name {served_model_name}"
)
endpoint = (
runtime.namespace(comp_ns).component(comp_name).endpoint(endpoint_name)
)
await register_llm(
ModelType.Backend,
endpoint,
self.engine_args.model,
served_model_name,
)
comp_ns, comp_name = VllmDecodeWorker.dynamo_address() # type: ignore
self.decode_worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
comp_ns, comp_name = VllmPrefillWorker.dynamo_address() # type: ignore
self.prefill_worker_client = (
await runtime.namespace(comp_ns)
.component(comp_name)
.endpoint("generate")
.client()
)
logger.info("SimpleLoadBalancer has been initialized")
async def send_request_to_prefill(
self, request: vLLMGenerateRequest
) -> MyRequestOutput:
logger.debug("Sending request to prefill")
prefill_request = copy.deepcopy(request)
extra_args = prefill_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
prefill_request.sampling_params.extra_args = extra_args
prefill_request.sampling_params.max_tokens = 1
prefill_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", prefill_request.model_dump_json())
async for prefill_response in await self.prefill_worker_client.round_robin(
prefill_request.model_dump_json()
):
return MyRequestOutput.model_validate_json(prefill_response.data())
async def send_request_to_decode(
self,
request: vLLMGenerateRequest,
prefill_response: Optional[MyRequestOutput] = None,
) -> AsyncGenerator[MyRequestOutput, None]:
logger.debug("Sending request to decode")
decode_request = copy.deepcopy(request)
if prefill_response:
extra_args = decode_request.sampling_params.extra_args or {}
extra_args["kv_transfer_params"] = prefill_response.kv_transfer_params
decode_request.sampling_params.extra_args = extra_args
logger.debug("Decode request: %s", decode_request.model_dump_json())
async for decode_response in await self.decode_worker_client.round_robin(
decode_request.model_dump_json()
):
yield MyRequestOutput.model_validate_json(decode_response.data())
@dynamo_endpoint()
async def generate(self, request: PreprocessedRequest):
logger.debug(
"Processor received completion request: %s", request.model_dump_json()
)
vllm_request = self._create_vllm_request(request)
logger.debug("VLLM request: %s", vllm_request.model_dump_json())
if self.enable_disagg:
prefill_response = await self.send_request_to_prefill(vllm_request)
logger.debug("Prefill response: %s", prefill_response.model_dump_json())
else:
prefill_response = None
gen = self.send_request_to_decode(vllm_request, prefill_response)
async for res in self._stream_response(gen):
yield res
def _create_vllm_request(self, request: PreprocessedRequest) -> vLLMGenerateRequest:
request_id = str(uuid.uuid4().hex)
prompt = TokensPrompt(prompt_token_ids=request.token_ids)
sampling_params = SamplingParams(**self.default_sampling_params)
for key, value in request.sampling_options.model_dump().items():
if not value:
continue
if hasattr(sampling_params, key):
setattr(sampling_params, key, value)
max_tokens = request.stop_conditions.max_tokens
if max_tokens:
sampling_params.max_tokens = max_tokens
return vLLMGenerateRequest(
prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
)
async def _stream_response(self, gen: AsyncGenerator[MyRequestOutput, None]):
num_output_tokens_so_far = 0
async for res in gen:
logger.debug("Decode response: %s", res.model_dump_json())
# res is our MyRequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
import signal
import socket
from typing import Optional
from utils.args import parse_vllm_args
from utils.protocol import MyRequestOutput, vLLMGenerateRequest
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from dynamo.sdk import async_on_start, dynamo_endpoint, service
logger = logging.getLogger(__name__)
class VllmBaseWorker:
def __init__(self):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
signal.signal(signal.SIGTERM, self.shutdown_vllm_engine)
signal.signal(signal.SIGINT, self.shutdown_vllm_engine)
self.set_side_channel_port()
async def async_init(self):
self._engine_context = build_async_engine_client_from_engine_args(
self.engine_args
)
if self._engine_context is not None:
self.engine_client = await self._engine_context.__aenter__()
else:
raise RuntimeError("Failed to initialize engine client")
logger.info("VllmWorker has been initialized")
def shutdown_vllm_engine(self, signum, frame):
"""Shutdown the background loop"""
logger.info(f"Received signal {signum}, shutting down")
loop = asyncio.get_event_loop()
try:
self.engine_client.close()
logger.info("VllmWorker shutdown complete")
except Exception as e:
logger.error(f"Error during shutdown: {e}")
finally:
loop.stop()
@dynamo_endpoint()
async def generate(self, request: vLLMGenerateRequest):
gen = self.engine_client.generate(
prompt=request.prompt,
sampling_params=request.sampling_params,
request_id=request.request_id,
)
async for response in gen:
yield MyRequestOutput(
request_id=response.request_id,
prompt=response.prompt,
prompt_token_ids=response.prompt_token_ids,
prompt_logprobs=response.prompt_logprobs,
outputs=response.outputs,
finished=response.finished,
metrics=response.metrics,
kv_transfer_params=response.kv_transfer_params,
).model_dump_json()
def set_side_channel_port(self, port: Optional[int] = None):
"""vLLM V1 NixlConnector creates a side channel to exchange metadata with other NIXL connectors.
This sets the port number for the side channel.
"""
if port is None:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0)) # Bind to a free port provided by the host.
port = s.getsockname()[1] # Get the port number assigned.
logger.debug("Setting VLLM_NIXL_SIDE_CHANNEL_PORT to %s", port)
os.environ["VLLM_NIXL_SIDE_CHANNEL_PORT"] = str(port)
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmPrefillWorker(VllmBaseWorker):
@async_on_start
async def async_init(self):
await super().async_init()
logger.info("VllmPrefillWorker has been initialized")
@service(
dynamo={
"enabled": True,
"namespace": "dynamo",
},
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
workers=1,
)
class VllmDecodeWorker(VllmBaseWorker):
@async_on_start
async def async_init(self):
await super().async_init()
logger.info("VllmDecodeWorker has been initialized")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
Frontend:
endpoint: dynamo.SimpleLoadBalancer.generate_agg
port: 8000
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
SimpleLoadBalancer:
enable_disagg: false
common-configs: [model, served_model_name]
VllmDecodeWorker:
enforce-eager: true
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, served_model_name]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Common:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
kv-transfer-config: '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
Frontend:
endpoint: dynamo.SimpleLoadBalancer.generate_disagg
port: 8000
served_model_name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
SimpleLoadBalancer:
enable_disagg: true
common-configs: [model, kv-transfer-config, served_model_name]
VllmPrefillWorker:
enforce-eager: true
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, kv-transfer-config, served_model_name]
VllmDecodeWorker:
enforce-eager: true
ServiceArgs:
workers: 1
resources:
gpu: 1
common-configs: [model, kv-transfer-config, served_model_name]
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.frontend import Frontend
from components.simple_load_balancer import SimpleLoadBalancer
from components.worker import VllmDecodeWorker
load_balancer = Frontend.link(SimpleLoadBalancer)
load_balancer.link(VllmDecodeWorker)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from components.frontend import Frontend
from components.simple_load_balancer import SimpleLoadBalancer
from components.worker import VllmDecodeWorker, VllmPrefillWorker
load_balancer = Frontend.link(SimpleLoadBalancer)
load_balancer.link(VllmPrefillWorker)
load_balancer.link(VllmDecodeWorker)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: rename to avoid ambiguity with vllm package
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import FlexibleArgumentParser
from dynamo.sdk.lib.config import ServiceConfig
def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
config = ServiceConfig.get_instance()
vllm_args = config.as_args(service_name, prefix=prefix)
parser = FlexibleArgumentParser()
parser.add_argument(
"--enable-disagg", action="store_true", help="Enable disaggregation"
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args.enable_disagg = args.enable_disagg
return engine_args
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, List, Optional
import msgspec
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic_core import core_schema
from typing_extensions import NotRequired
from vllm.inputs import TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import CompletionOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import PromptLogprobs, RequestMetrics
TokenIdType = int
# TODO: move these to common for all LLMs once we adopt dynamo-run
# derived from lib/llm/src/protocols/common/preprocessor.rs
class StopConditions(BaseModel):
max_tokens: Optional[int] = None
stop: Optional[List[str]] = None
stop_token_ids_hidden: Optional[List[TokenIdType]] = None
min_tokens: Optional[int] = None
ignore_eos: Optional[bool] = None
class SamplingOptions(BaseModel):
n: Optional[int] = None
best_of: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
repetition_penalty: Optional[float] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
min_p: Optional[float] = None
use_beam_search: Optional[bool] = None
length_penalty: Optional[float] = None
seed: Optional[int] = None
class PreprocessedRequest(BaseModel):
token_ids: List[TokenIdType]
stop_conditions: StopConditions
sampling_options: SamplingOptions
eos_token_ids: List[TokenIdType] = Field(default_factory=list)
mdc_sum: Optional[str] = None
annotations: List[str] = Field(default_factory=list)
# Hack to override the type of multi_modal_data in TokensPrompt
# as pydantic doesn't understand generic types
# TokensPrompt is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/inputs/data.py#L38
# multi_modal_data is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L103
# ModalityData is defined here: https://github.com/vllm-project/vllm/blob/main/vllm/multimodal/inputs.py#L80
class PatchedTokensPrompt(TokensPrompt):
multi_modal_data: NotRequired[Optional[Any]] # type: ignore
# Monkey-patch the SamplingParams and KVTransferParams types to add a dummy core schema so pydantic can validate them
# Sampling params is a mspspec struct
# SamplingParams is defined here: https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/sampling_params.py#L88
SamplingParams.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
LoRARequest.__get_pydantic_core_schema__ = classmethod(
lambda cls, source, handler: core_schema.any_schema()
)
class vLLMGenerateRequest(BaseModel):
"""
Serializable class of all the fields vLLM engine requires for inference
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
prompt: PatchedTokensPrompt
sampling_params: SamplingParams
request_id: str
@field_validator("sampling_params", mode="before")
@classmethod
def parse_sampling_params(cls, v: Any) -> SamplingParams:
if isinstance(v, str):
v = json.loads(v)
if isinstance(v, dict):
return SamplingParams(**v)
return v
model_config = ConfigDict(
json_encoders={SamplingParams: lambda v: msgspec.json.encode(v)}
)
class MyRequestOutput(BaseModel):
"""
RequestOutput from vLLM is not serializable by default
https://github.com/vllm-project/vllm/blob/a4c402a756fa3213caf9d2cde0e4ceb2d57727f2/vllm/outputs.py#L85
This class is used to serialize the RequestOutput and any recursively defined types
We can do this because PromptLogprobs, RequestMetrics, and CompletionOutput are all serializable dataclasses
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
request_id: str
prompt: Optional[str] = None
prompt_token_ids: Optional[List[int]] = None
prompt_logprobs: Optional[PromptLogprobs] = None
outputs: List[CompletionOutput]
finished: bool
metrics: Optional[RequestMetrics] = None
kv_transfer_params: Optional[dict[str, Any]] = None
# lora_request: Optional[LoRARequest] = None
# encoder_prompt: Optional[str] = None
# encoder_prompt_token_ids: Optional[List[int]] = None
# num_cached_tokens: Optional[int] = None
# multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment