"lib/runtime/src/runnable.rs" did not exist on "ffc6dde1f0c6a45ac2ed72e91139949992c9c55d"
Commit b92834c8 authored by Neelay Shah's avatar Neelay Shah Committed by GitHub
Browse files

chore: removing outdated examples (#202)

parent fd79234f
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import copy
import json
import uvloop
from common.protocol import (
DisaggChatCompletionRequest,
DisaggChatCompletionStreamResponse,
DisaggCompletionStreamResponse,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import CompletionRequest, DisaggregatedParams
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
logger.set_level("debug")
class Router:
def __init__(
self,
ctx_chat_client,
gen_chat_client,
ctx_completion_client,
gen_completion_client,
):
self.ctx_chat_client = ctx_chat_client
self.gen_chat_client = gen_chat_client
self.ctx_completion_client = ctx_completion_client
self.gen_completion_client = gen_completion_client
logger.info("INITIALIZED ROUTER")
async def _get_ctx_resp(self, request, ctx_client):
logger.debug(f"Received request {request}")
request.max_completion_tokens = 1
request.disaggregated_params = DisaggregatedParams(request_type="context_only")
logger.debug(f"[router] Sending request to context server: {request}")
ctx_resp = [
resp
async for resp in await ctx_client.round_robin(request.model_dump_json())
]
if len(ctx_resp) > 1:
raise ValueError(
"Context server returned more than one response. This is currently not supported in disaggregated server."
)
logger.debug(
f"[router] received response from context server: {ctx_resp[0].data()}"
)
return ctx_resp[0].data()
# TODO (shreyasm): The only reason we cant further combine the two methods below is
# because the disagg params are in different locations.
# Disagg params should be in under the choices field in the response object.
# This is the case for completions but not for chat.
@dynamo_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
async def generate_completion(self, request):
# These settings are needed to satisfy request checks.
request.skip_special_tokens = False
request.add_special_tokens = False
request.spaces_between_special_tokens = False
gen_req = copy.deepcopy(request)
ctx_resp = await self._get_ctx_resp(request, self.ctx_completion_client)
ctx_resp_obj = DisaggCompletionStreamResponse.model_validate(ctx_resp)
gen_req.disaggregated_params = DisaggregatedParams.model_validate(
ctx_resp_obj.choices[0].disaggregated_params
)
gen_req.disaggregated_params.request_type = "generation_only"
if request.stream:
yield json.loads(
ctx_resp_obj.model_dump_json(
exclude_unset=True, exclude={"disaggregated_params"}
)
)
logger.debug(f"[router] Sending request to generation server: {gen_req}")
async for response in await self.gen_completion_client.round_robin(
gen_req.model_dump_json()
):
logger.debug(
f"[router] Received response from generation server: {response.data()}"
)
gen_resp_obj = DisaggCompletionStreamResponse.model_validate(
response.data()
)
yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
@dynamo_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
async def generate_chat(self, request):
# These settings are needed to satisfy request checks.
request.skip_special_tokens = False
request.add_special_tokens = False
request.spaces_between_special_tokens = False
gen_req = copy.deepcopy(request)
ctx_resp = await self._get_ctx_resp(request, self.ctx_chat_client)
ctx_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(ctx_resp)
gen_req.disaggregated_params = DisaggregatedParams.model_validate(
ctx_resp_obj.disaggregated_params
)
gen_req.disaggregated_params.request_type = "generation_only"
if request.stream:
yield json.loads(
ctx_resp_obj.model_dump_json(
exclude_unset=True, exclude={"disaggregated_params"}
)
)
logger.debug(f"[router] Sending request to generation server: {gen_req}")
async for response in await self.gen_chat_client.round_robin(
gen_req.model_dump_json()
):
logger.debug(
f"[router] Received response from generation server: {response.data()}"
)
gen_resp_obj = DisaggChatCompletionStreamResponse.model_validate_json(
response.data()
)
yield json.loads(gen_resp_obj.model_dump_json(exclude_unset=True))
@dynamo_worker()
async def worker(runtime: DistributedRuntime):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
component = runtime.namespace("dynamo").component("router")
await component.create_service()
ctx_completion_client = (
await runtime.namespace("dynamo")
.component("tensorrt-llm-ctx")
.endpoint("completions")
.client()
)
gen_completion_client = (
await runtime.namespace("dynamo")
.component("tensorrt-llm-gen")
.endpoint("completions")
.client()
)
ctx_chat_client = (
await runtime.namespace("dynamo")
.component("tensorrt-llm-ctx")
.endpoint("chat/completions")
.client()
)
gen_chat_client = (
await runtime.namespace("dynamo")
.component("tensorrt-llm-gen")
.endpoint("chat/completions")
.client()
)
completions_endpoint = component.endpoint("completions")
chat_endpoint = component.endpoint("chat/completions")
router = Router(
ctx_chat_client, gen_chat_client, ctx_completion_client, gen_completion_client
)
await asyncio.gather(
completions_endpoint.serve_endpoint(router.generate_completion),
chat_endpoint.serve_endpoint(router.generate_chat),
)
if __name__ == "__main__":
uvloop.install()
asyncio.run(worker())
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import os
import signal
import uvloop
from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
from common.disagg_processor import ChatProcessor, parse_chat_message_content
from common.parser import LLMAPIConfig, parse_tensorrt_llm_args
from common.processor import merge_promises
from common.protocol import (
DisaggChatCompletionRequest,
DisaggChatCompletionStreamResponse,
DisaggCompletionStreamResponse,
DisaggregatedTypeConverter,
)
from mpi4py.futures import MPICommExecutor
from mpi4py.MPI import COMM_WORLD
from tensorrt_llm._utils import set_mpi_comm
from tensorrt_llm.executor import CppExecutorError
from tensorrt_llm.llmapi import MpiCommSession
from tensorrt_llm.llmapi.disagg_utils import (
CtxGenServerConfig,
DisaggServerConfig,
parse_disagg_config_file,
split_world_comm,
)
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import CompletionRequest
from dynamo.llm import KvMetricsPublisher
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
logger.set_level("debug")
def update_args_from_disagg_config(
engine_config: LLMAPIConfig, server_config: CtxGenServerConfig
):
# Overwrite the LLM API config with the disaggregated config
# Allows for different configs for context and generation servers
engine_config.extra_args.update(**server_config.other_args)
engine_config.update_sub_configs(server_config.other_args)
return engine_config
class TensorrtLLMEngine(BaseTensorrtLLMEngine):
"""
Request handler for the generate endpoint
"""
def __init__(
self,
trt_llm_engine_config: TensorrtLLMEngineConfig,
disagg_config: DisaggServerConfig,
instance_idx: int,
sub_comm,
):
self.disagg_config = disagg_config
self.instance_idx = instance_idx
self.server_config: CtxGenServerConfig = disagg_config.server_configs[
instance_idx
]
engine_config = update_args_from_disagg_config(
trt_llm_engine_config.engine_config, self.server_config
)
trt_llm_engine_config.engine_config = engine_config
# needed for disagg
self._mpi_session = MpiCommSession(sub_comm, n_workers=sub_comm.Get_size())
trt_llm_engine_config.engine_config.extra_args[
"_mpi_session"
] = self._mpi_session
super().__init__(trt_llm_engine_config)
@dynamo_endpoint(DisaggChatCompletionRequest, DisaggChatCompletionStreamResponse)
async def generate_chat(self, request):
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
# Check if there are any errors in the error queue.
if self._error_queue.qsize() > 0:
error = self._error_queue.get()
raise error
logger.debug(f"Received request: {request}")
chat_processor = ChatProcessor(self._model, self._tokenizer, request)
self._ongoing_request_count += 1
try:
conversation = []
for message in request.messages:
conversation.extend(parse_chat_message_content(message))
tool_dicts = (
None
if request.tools is None
else [tool.model_dump() for tool in request.tools]
)
prompt: str = self._tokenizer.apply_chat_template(
conversation=conversation,
tokenize=False,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
chat_template=request.chat_template,
**(request.chat_template_kwargs or {}),
)
sampling_params = request.to_sampling_params()
disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
request.disaggregated_params
)
)
final_result = None
async for result in self._llm_engine.generate_async(
prompt,
sampling_params,
streaming=request.stream,
disaggregated_params=disaggregated_params,
):
final_result = result
logger.debug(f"Generated result: {result}")
if self.server_config.type == "ctx":
disaggregated_response = chat_processor.get_chat_stream_response(
request.id,
result,
first_iteration=True,
)
disaggregated_response.disaggregated_params = (
DisaggregatedTypeConverter.to_oai_disaggregated_params(
result.outputs[0].disaggregated_params
)
)
yield disaggregated_response.model_dump_json()
else:
yield chat_processor.get_chat_stream_response(
request.id,
result,
first_iteration=False,
).model_dump_json(
exclude_unset=True, exclude={"disaggregated_params"}
)
if request.stream_options and request.stream_options.include_usage:
yield chat_processor.create_final_stream_response(
request.id,
final_result,
).model_dump_json(exclude_unset=True, exclude={"disaggregated_params"})
except CppExecutorError:
# If internal executor error is raised, shutdown the server
signal.raise_signal(signal.SIGINT)
except Exception as e:
raise RuntimeError("Failed to generate: " + str(e))
# Start the publishing threads with first request submission
self._stats_loop = asyncio.get_running_loop()
if (
self.publish_kv_cache_events_thread
and not self.publish_kv_cache_events_thread.is_alive()
):
self.publish_kv_cache_events_thread.start()
if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
self.publish_stats_thread.start()
self._ongoing_request_count -= 1
@dynamo_endpoint(CompletionRequest, DisaggCompletionStreamResponse)
async def generate_completions(self, request):
logger.debug(f"[worker] worker_id: {self._worker_id} received request")
if self._llm_engine is None:
raise RuntimeError("Engine not initialized")
# Check if there are any errors in the error queue.
if self._error_queue.qsize() > 0:
error = self._error_queue.get()
raise error
self._ongoing_request_count += 1
logger.debug(f"[worker] Received completions request: {request}")
if not isinstance(request.prompt, str):
# Check if it's a list and contains integers
if isinstance(request.prompt, list) and len(request.prompt) == 1:
request.prompt = request.prompt[0]
elif not isinstance(request.prompt, list) or not all(
isinstance(x, int) for x in request.prompt
):
raise ValueError(
"Disaggregated server currently only supports single string prompt or list of integers in request"
)
sampling_params = request.to_sampling_params()
llm_disaggregated_params = (
DisaggregatedTypeConverter.to_llm_disaggregated_params(
request.disaggregated_params
)
)
# only 1 prompt is supported for now
promise = self._llm_engine.generate_async(
request.prompt,
sampling_params,
streaming=request.stream,
disaggregated_params=llm_disaggregated_params,
)
generator = merge_promises([promise])
num_choices = 1 if request.n is None else request.n
if request.stream:
response_generator = self.completions_processor.create_completion_generator(
request, generator, num_choices
)
async for response in response_generator:
yield json.loads(response)
else:
raise RuntimeError("Non-streaming is not supported")
# Start the publishing threads with first request submission
if (
self.publish_kv_cache_events_thread
and not self.publish_kv_cache_events_thread.is_alive()
):
# [NOTE:] TRTLLM needs the stats to be collected on the same loop as the request handler.
self._stats_loop = asyncio.get_running_loop()
self.publish_kv_cache_events_thread.set_loop(self._stats_loop)
self.publish_kv_cache_events_thread.start()
if self.publish_stats_thread and not self.publish_stats_thread.is_alive():
self._stats_loop = asyncio.get_running_loop()
self.publish_stats_thread.set_loop(self._stats_loop)
self.publish_stats_thread.start()
self._ongoing_request_count -= 1
@dynamo_worker()
async def worker(
runtime: DistributedRuntime,
engine_config: LLMAPIConfig,
disagg_config: DisaggServerConfig,
instance_idx: int,
sub_comm,
publish_stats: bool,
publish_kv_cache_events: bool,
):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
server_type = disagg_config.server_configs[instance_idx].type
logger.info(f"Starting {server_type} server")
namespace_str = "dynamo"
component_str = f"tensorrt-llm-{server_type}"
component = runtime.namespace(namespace_str).component(component_str)
await component.create_service()
completions_endpoint = component.endpoint("completions")
chat_endpoint = component.endpoint("chat/completions")
if server_type == "gen":
if publish_stats:
logger.warning("Stats can only be published for ctx server")
publish_stats = False
if publish_kv_cache_events:
logger.warning("KV cache events can only be published for ctx server")
publish_kv_cache_events = False
trt_llm_engine_config = TensorrtLLMEngineConfig(
namespace_str=namespace_str,
component_str=component_str,
engine_config=engine_config,
publish_stats=publish_stats,
publish_kv_cache_events=publish_kv_cache_events,
)
# NOTE: Current implementation adds two endpoints. We can refactor this code to expose only one endpoint.
# and handle both completions and chat in the same endpoint.
# Currently, we are using completions endpoint lease id as worker id.
# I believe this might cause some issues using smart routing with chat completions endpoint.
trt_llm_engine_config.worker_id = completions_endpoint.lease_id()
if publish_stats:
trt_llm_engine_config.kv_metrics_publisher = KvMetricsPublisher()
engine = TensorrtLLMEngine(
trt_llm_engine_config,
disagg_config,
instance_idx,
sub_comm,
)
coros = [
completions_endpoint.serve_endpoint(engine.generate_completions),
chat_endpoint.serve_endpoint(engine.generate_chat),
]
if publish_stats:
coros.append(
trt_llm_engine_config.kv_metrics_publisher.create_endpoint(component)
)
await asyncio.gather(*coros)
if __name__ == "__main__":
uvloop.install()
args, engine_config = parse_tensorrt_llm_args()
if args.llmapi_disaggregated_config is None or not os.path.exists(
args.llmapi_disaggregated_config
):
raise ValueError(
"llmapi_disaggregated_config file does not exist or not provided"
)
disagg_config: DisaggServerConfig = parse_disagg_config_file(
args.llmapi_disaggregated_config
)
logger.info(f"Parsed disaggregated config: {disagg_config}")
is_leader, instance_idx, sub_comm = split_world_comm(disagg_config.server_configs)
os.environ["TRTLLM_USE_MPI_KVCACHE"] = "1"
set_mpi_comm(sub_comm)
logger.info(f"is_leader: {is_leader}, instance_idx: {instance_idx}")
if is_leader:
asyncio.run(
worker(
engine_config,
disagg_config,
instance_idx,
sub_comm,
args.publish_stats,
args.publish_kv_cache_events,
)
)
else:
with MPICommExecutor(sub_comm) as executor:
if not is_leader and executor is not None:
raise RuntimeError(f"rank{COMM_WORLD} should not have executor")
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# KV Aware Routing
This document describes how to use the KV aware routing feature in Dynamo with TensorRT LLM disaggregated serving.
The KV Router is a component that aggregates KV Events from all the workers and maintains a prefix tree of the cached tokens. It makes decisions on which worker to route requests to based on the length of the prefix match and the load on the workers.
## KV Aware Routing with Disaggregated Serving
Follow the instructions in the [README](../README.md) to setup the environment for [disaggregated serving](../README.md#disaggregated-deployment).
All of the steps remain the same except launching the [workers and the router](../README.md#workers).
### 1. Workers
To launch the workers and the router, run the following command:
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
mpirun --allow-run-as-root --oversubscribe -n 5 python3 -m disaggregated.worker --publish-stats --publish-kv-cache-events --engine_args llm_api_config.yaml -c disaggregated/llmapi_disaggregated_configs/single_node_kv_aware_config.yaml 1>disagg_workers.log 2>&1 &
```
Note the extra arguments `--publish-stats` and `--publish-kv-cache-events` to publish the stats and kv cache events from the workers for effective routing.
The config file [single_node_kv_aware_config.yaml](disaggregated/llmapi_disaggregated_configs/single_node_kv_aware_config.yaml) specifies extra configuration for the LLM execution engine to support stats and kv cache events collection. These configurations are:
1. `enable_iter_perf_stats` in `pytorch_backend_config` to enable the iteration performance stats collection.
2. `event_buffer_max_size` in `kv_cache_config` to specify the maximum number of events that can be stored in the buffer.
3. `enable_block_reuse` in `kv_cache_config` to enable the block reuse feature for improved performance.
Note: The configuration also specifies 4 context servers and 1 generation server.
### 2. Router
To launch the router, run the following command:
```bash
cd /workspace/examples/python_rs/llm/tensorrt_llm/
python3 -m disaggregated.kv_router --engine_args llm_api_config.yaml 1>kv_router.log 2>&1 &
```
The router will route the incoming requests to the appropriate context server based on the stats and kv cache events.
### 3. Send Requests
Follow the instructions in the [README](../README.md#send-requests) to send requests to the [HTTP server](../README.md#http-server).
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# In the case of disaggregated deployment, this config will apply to each server
# and will be overwritten by the disaggregated config file
model_name: "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model_path: null
tensor_parallel_size: 1
moe_expert_parallel_size: 1
enable_attention_dp: false
max_num_tokens: 10240
max_batch_size: 16
trust_remote_code: true
backend: pytorch
kv_cache_config:
free_gpu_memory_fraction: 0.95
# Uncomment to enable kv cache event collection
#event_buffer_max_size: 1024
#enable_block_reuse: true
pytorch_backend_config:
enable_overlap_scheduler: false
use_cuda_graph: false
# Uncomment to enable iter perf stats
#enable_iter_perf_stats: true
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
IMPORTANT:
- This is only supposed to be used by dynamo-run launcher.
- It is part of bring-your-own-engine python feature in dynamo-run.
"""
import sys
from pathlib import Path
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
)
from dynamo.runtime import dynamo_endpoint
# Add the project root to the Python path
project_root = str(Path(__file__).parents[1]) # Go up to trtllm directory
if project_root not in sys.path:
sys.path.append(project_root)
from common.base_engine import ( # noqa: E402
BaseTensorrtLLMEngine,
TensorrtLLMEngineConfig,
)
from common.generators import chat_generator # noqa: E402
from common.parser import parse_dynamo_run_args # noqa: E402
logger.set_level("info")
class DynamoTRTLLMEngine(BaseTensorrtLLMEngine):
"""
Request handler for the generate endpoint
"""
def __init__(self, trt_llm_engine_config: TensorrtLLMEngineConfig):
super().__init__(trt_llm_engine_config)
engine = None # Global variable to store the engine instance. This is initialized in the main function.
def init_global_engine(args, engine_config):
global engine
logger.debug(f"Received args: {args}")
logger.info(f"Initializing global engine with engine config: {engine_config}")
trt_llm_engine_config = TensorrtLLMEngineConfig(
engine_config=engine_config,
)
engine = DynamoTRTLLMEngine(trt_llm_engine_config)
@dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(request):
async for response in chat_generator(engine, request):
yield response
if __name__ == "__main__":
args, engine_config = parse_dynamo_run_args()
init_global_engine(args, engine_config)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
from pathlib import Path
import uvloop
# Add the project root to the Python path
project_root = str(Path(__file__).parents[1]) # Go up to trtllm directory
if project_root not in sys.path:
sys.path.append(project_root)
from common.parser import parse_tensorrt_llm_args # noqa: E402
from .worker import trtllm_worker # noqa: E402
if __name__ == "__main__":
uvloop.install()
args, engine_config = parse_tensorrt_llm_args()
asyncio.run(trtllm_worker(engine_config))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from common.base_engine import BaseTensorrtLLMEngine, TensorrtLLMEngineConfig
from common.generators import chat_generator, completion_generator
from common.parser import LLMAPIConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
CompletionRequest,
CompletionStreamResponse,
)
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
logger.set_level("debug")
class TensorrtLLMEngine(BaseTensorrtLLMEngine):
"""
Request handler for the generate endpoint
"""
def __init__(self, trt_llm_engine_config: TensorrtLLMEngineConfig):
super().__init__(trt_llm_engine_config)
@dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate_chat(self, request):
async for response in chat_generator(self, request):
yield response
@dynamo_endpoint(CompletionRequest, CompletionStreamResponse)
async def generate_completion(self, request):
async for response in completion_generator(self, request):
yield response
@dynamo_worker()
async def trtllm_worker(runtime: DistributedRuntime, engine_config: LLMAPIConfig):
"""
Instantiate a `backend` component and serve the `generate` endpoint
A `Component` can serve multiple endpoints
"""
namespace_str = "dynamo"
component_str = "tensorrt-llm"
component = runtime.namespace(namespace_str).component(component_str)
await component.create_service()
completions_endpoint = component.endpoint("completions")
chat_completions_endpoint = component.endpoint("chat/completions")
trt_llm_engine_config = TensorrtLLMEngineConfig(
namespace_str=namespace_str,
component_str=component_str,
engine_config=engine_config,
)
engine = TensorrtLLMEngine(trt_llm_engine_config)
await asyncio.gather(
completions_endpoint.serve_endpoint(engine.generate_completion),
chat_completions_endpoint.serve_endpoint(engine.generate_chat),
)
<!--
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
> **NOTE**: This example is based on an internal NVIDIA library that will soon be publicly released. The example won't work until the official release.
## Prerequisites
Start required services (etcd and NATS):
Option A: Using [Docker Compose](/deploy/docker-compose.yml) (Recommended)
```bash
docker compose -f deploy/docker-compose.yml up -d
```
Option B: Manual Setup
- [NATS.io](https://docs.nats.io/running-a-nats-service/introduction/installation) server with [Jetstream](https://docs.nats.io/nats-concepts/jetstream)
- example: `nats-server -js --trace`
- [etcd](https://etcd.io) server
- follow instructions in [etcd installation](https://etcd.io/docs/v3.5/install/) to start an `etcd-server` locally
## Build docker
```
./container/build.sh
```
## Run container
```
./container/run.sh -it
```
All of the commands below are run inside the same container.
## Run deployment
This figure shows an overview of the major components to deploy:
```
+----------------+
+------| prefill worker |-------+
notify | | (optional) | |
finished | +----------------+ | pull
v v
+------+ +-----------+ +------------------+ push +---------------+
| HTTP |----->| processor |----->| decode/monolith |------------>| prefill queue |
| |<-----| |<-----| worker | (if disagg) | (optional) |
+------+ +-----------+ +------------------+ +---------------+
| ^ |
query best | | return | publish kv events
worker | | worker_id v
| | +------------------+
| +---------| kv-router |
+------------->| (optional) |
+------------------+
```
Add model to dynamo and start http server.
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B dynamo-init.process.chat/completions
TRT_LOG=DEBUG http --port 8181
```
### Processor
Processor routes the requests to the (decode) workers. Three scheduling strategies are supported: 1. random, 2. round-robin, 3. kv (see [Kv Router](#kv-router)).
```
# Processor must take the same args as the (decoder) worker
# This is temporary until we communicate the ModelDeploymentCard over etcd
RUST_LOG=info python3 processor.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--block-size 64 \
--max-model-len 16384 \
--router <random/round-robin/kv>
```
Alternatively, the processor can be bypassed by directly hitting the worker endpoints:
```
llmctl http add chat-models deepseek-ai/DeepSeek-R1-Distill-Llama-8B dynamo-init.vllm.generate
# monolithic
CUDA_VISIBLE_DEVICES=0 python3 routerless/worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager
# disaggregated
CUDA_VISIBLE_DEVICES=0 python routerless/prefill_worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}'
CUDA_VISIBLE_DEVICES=1 python3 routerless/worker.py \
--remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}'
```
### Kv Router
The KV Router is a component that aggregates KV Events from all the workers and maintains
a prefix tree of the cached tokens. It makes decisions on which worker to route requests
to based on the length of the prefix match and the load on the workers.
There are three steps needed to enable the kv router:
1. Use `--router kv` in the processor.
2. Use `--router kv` and `--enable-prefix-caching` in all the (decode) workers.
3. Launch the kv router in a separate terminal.
```
RUST_LOG=info python3 kv_router.py \
--model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--block-size 64 \
--min-workers 1
```
where `--min-workers` is the number of (decode) workers.
You can choose only the prefix strategy for now:
- `prefix`: Route requests to the worker that has the longest prefix match.
### Disaggregated Router
The disaggregated router determines whether a request should be send to a
remote prefill engine or a local prefill engine for prefilling based on the
prefill length. If kv router is enabled, the disaggregated router will use
the absolute prefill length (actual prefill length - prefix hit length) to make
the decision.
When prefilling locally, the vllm scheduler will prioritize
prefill request and pause any ongoing decode requests.
To enable the disaggregated router, add the following commands in the decode workers:
```
python worker.py \
...
--conditional-disagg \
--max-local-prefill-length <length>
```
### Worker
#### Monolithic
Only kv router is supported for monolithic deployment.
```
CUDA_VISIBLE_DEVICES=0 python3 worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--block-size 64 \
--max-model-len 16384 \
<optional kv router args: --router kv --enable-prefix-caching>
```
#### Disaggregated
Kv router and disaggregated router are supported and can be turned on/off individually.
```
# start prefill worker in one terminal
# Note: prefix caching is not supported in the prefill for now
CUDA_VISIBLE_DEVICES=0 python3 prefill_worker.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}' \
--block-size 64 \
--max-num-batched-tokens 16384 \
--max-model-len 16384
# start decode worker in another terminal
CUDA_VISIBLE_DEVICES=1 python3 worker.py \
--remote-prefill \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--enforce-eager \
--tensor-parallel-size 1 \
--kv-transfer-config '{"kv_connector":"DynamoNixlConnector"}' \
--block-size 64 \
--max-num-batched-tokens 16384 \
--max-model-len 16384 \
<optional kv router args: --router kv --enable-prefix-caching>
<optional disaggregated router args: --conditional-disagg --max-local-prefill-length <length>>
```
### Multi-Node Deployment
For multi-node deployment, etcd, nats, processor, and kv router
are only required on the head node. The only components that need
to be deployed on all nodes are the workers.
Set the following environment variables on each node before running the workers:
```bash
export NATS_SERVER="nats://<nats-server-host>:<nats-server-port>"
export ETCD_ENDPOINTS="http://<etcd-server-host>:<etcd-server-port>"
```
### Common Issues
If torch GLOO backend is complaining about file name too long, set
```
export GLOO_SOCKET_IFNAME=lo
```
## Client
In another terminal:
```
# this test request has around 200 tokens isl
curl localhost:8181/v1/chat/completions -H "Content-Type: application/json" -d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [
{
"role": "user",
"content": "In the heart of Eldoria, an ancient land of boundless magic and mysterious creatures, lies the long-forgotten city of Aeloria. Once a beacon of knowledge and power, Aeloria was buried beneath the shifting sands of time, lost to the world for centuries. You are an intrepid explorer, known for your unparalleled curiosity and courage, who has stumbled upon an ancient map hinting at ests that Aeloria holds a secret so profound that it has the potential to reshape the very fabric of reality. Your journey will take you through treacherous deserts, enchanted forests, and across perilous mountain ranges. Your Task: Character Background: Develop a detailed background for your character. Describe their motivations for seeking out Aeloria, their skills and weaknesses, and any personal connections to the ancient city or its legends. Are they driven by a quest for knowledge, a search for lost familt clue is hidden."
}
],
"stream":false,
"max_tokens": 30
}'
```
## Run genai-perf
`genai-perf` is a tool for profiling and benchmarking LLM servers. It is already installed in the container. For more details, please refer to the [genai-perf README](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/perf_analyzer/genai-perf/README.html).
```
genai-perf profile \
-m deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--url localhost:8181 \
--endpoint-type chat \
--streaming \
--service-kind openai \
--endpoint v1/chat/completions \
--warmup-request-count 10 \
--random-seed 123 \
--synthetic-input-tokens-stddev 0 \
--output-tokens-stddev 0 \
--tokenizer deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--synthetic-input-tokens-mean 3000 \
--output-tokens-mean 150 \
--extra-inputs min_tokens:150 \
--extra-inputs max_tokens:150 \
--profile-export-file my_profile_export.json \
--artifact-dir artifacts/ \
--concurrency 10 \
--request-count 40 \
-- -v \
--async
```
## Close deployment
Kill all python processes and clean up metadata files:
```
pkill -9 -f python
```
## TODOs, limitations, known issues
- [ ] Add etcd for discovery
- [ ] Multi-node deployment support
- [ ] Enable chunked prefill
- [ ] Process many remote prefill in one iteration
- [ ] Support recompute preemption
- [ ] Make sure decode does not preempt blocks before xfer finishes
- [ ] Layer wise transfer
- [ ] Non blocking send in prefill (cache manager should check xfer status)
- [ ] Test under load
- [ ] Support pp > 1
- [ ] Check why adding extra seed input is crashing vllm with remote prefill
- [ ] Unified worker for both prefill and decode
- [x] Support mixed tp
- [x] Require sending two parallel requests to start decode for the first time
- [x] Concurrency > 2 is not working
- [x] Parse cmdline args
- [x] Manual nixl example with tp1
- [x] Zero copy
- [x] Conditional remote prefill
- [x] Manual example with tp > 1
- [x] Run on dynamo distributed runtime
- [x] add oai http endpoint
- [x] Sample only on decode, do note return remote prefill response
- [x] Check if all transfers finished before moving to decode
- [x] Enable async output processing - could be working
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm.logger import logger as vllm_logger
class PyDisaggregatedRouter:
def __init__(
self,
runtime,
served_model_name,
max_local_prefill_length=1000,
):
self.runtime = runtime
self.served_model_name = served_model_name
self.max_local_prefill_length = max_local_prefill_length
def prefill_remote(self, prompt_length: int, prefix_hit_rate: float):
absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
vllm_logger.info(
f"Remote prefill: {absolute_prefill_length > self.max_local_prefill_length} (prefill length: {absolute_prefill_length}/{prompt_length})"
)
return absolute_prefill_length > self.max_local_prefill_length
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import random
from argparse import Namespace
from typing import AsyncIterator
import uvloop
from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
WorkerId = str
class CustomRouter:
"""
Request handler for the generate endpoint
"""
def __init__(
self,
workers_client,
indexer: KvIndexer,
metrics_aggregator: KvMetricsAggregator,
):
vllm_logger.info("Initializing Custom Router")
self.indexer = indexer
self.metrics_aggregator = metrics_aggregator
self.workers_client = workers_client
def _cost_function(
self,
scores: OverlapScores | None,
metrics: AggregatedMetrics | None,
token_length: int,
):
worker_scores = {}
if scores:
for worker_id, score in scores.scores.items():
# score is number of matching blocks we multiply by block_size to get tokens
# and compare to token_length. The larger the cache hit the better
worker_scores[worker_id] = (
score * self.indexer.block_size() / token_length
)
worker_metrics = {}
# pull metrics for each worker
max_waiting = 0.0
if metrics:
for endpoint in metrics.endpoints:
worker_id = endpoint.worker_id
worker_metrics[worker_id] = {
"gpu_cache_usage_perc": endpoint.gpu_cache_usage_perc
if hasattr(endpoint, "gpu_cache_usage_perc")
else 0.0,
"num_requests_waiting": endpoint.num_requests_waiting
if hasattr(endpoint, "num_requests_waiting")
else 0.0,
"gpu_prefix_cache_hit_rate": endpoint.gpu_prefix_cache_hit_rate
if hasattr(endpoint, "gpu_prefix_cache_hit_rate")
else 0.0,
}
max_waiting = max(
max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
)
# Get all worker IDs from the client. This is needed because scores / metrics may not have values for all workers
# and we want all workers to be considered in the logit calculation
worker_ids = self.workers_client.endpoint_ids()
worker_logits = {}
for worker_id in worker_ids:
# Use default values if worker not in scores or metrics
score = worker_scores.get(worker_id, 0.0)
metrics_dict = worker_metrics.get(
worker_id,
{
"gpu_cache_usage_perc": 0.0,
"num_requests_waiting": 0.0,
"gpu_prefix_cache_hit_rate": 0.0,
},
)
normalized_waiting = (
metrics_dict["num_requests_waiting"] / max_waiting
if max_waiting > 0
else 0.0
)
# Have 1 metric that weights towards cache hit
# 2 metrics that penalize overloaded worker and queuing
worker_logits[worker_id] = (
2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
)
vllm_logger.info(
f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
)
if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
return ""
# Select the worker with the highest logit
if worker_logits:
max_logit = max(worker_logits.values())
best_workers = [
wid for wid, logit in worker_logits.items() if logit == max_logit
]
best_worker_id = random.choice(best_workers)
else:
best_worker_id = ""
# Log the metrics for the selected worker
if best_worker_id:
vllm_logger.info(
f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
)
vllm_logger.info(
f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
)
metrics_dict = worker_metrics.get(best_worker_id, {})
vllm_logger.info(
f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
)
vllm_logger.info(
f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
)
vllm_logger.info(
f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
)
return best_worker_id, worker_scores.get(best_worker_id, 0.0)
@dynamo_endpoint(Tokens, WorkerId)
async def generate(self, request) -> AsyncIterator[WorkerId]:
lora_id = 0
worker_id = ""
try:
scores = await self.indexer.find_matches_for_request(
request.tokens, lora_id
)
except Exception as e:
scores = {}
vllm_logger.exception(f"Error finding matches: {e}")
token_length = len(request.tokens)
metrics = await self.metrics_aggregator.get_metrics()
schedule_result = self._cost_function(scores, metrics, token_length)
if schedule_result == "":
worker_id = ""
prefix_hit_rate = 0.0
else:
worker_id, prefix_hit_rate = schedule_result
vllm_logger.info(
f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
yield f"{worker_id}_{prefix_hit_rate}"
@dynamo_worker()
async def worker(runtime: DistributedRuntime, args: Namespace):
"""
Set up the worker clients.
Serve the dynamo-init.router.generate endpoint.
"""
workers_client = (
await runtime.namespace("dynamo-init")
.component("vllm")
.endpoint("generate")
.client()
)
while len(workers_client.endpoint_ids()) < args.min_workers:
vllm_logger.info(
f"Waiting for more workers... Current: {len(workers_client.endpoint_ids())}, Required: {args.min_workers}"
)
await asyncio.sleep(5)
vllm_logger.info(
f"Required number of workers ({args.min_workers}) are ready:\n"
+ "\n".join(f"id: {id}" for id in workers_client.endpoint_ids())
)
kv_listener = runtime.namespace("dynamo-init").component("vllm")
await kv_listener.create_service()
router_component = runtime.namespace("dynamo-init").component("router")
await router_component.create_service()
endpoint = router_component.endpoint("generate")
indexer = KvIndexer(kv_listener, args.block_size)
metrics_aggregator = KvMetricsAggregator(kv_listener)
await endpoint.serve_endpoint(
CustomRouter(workers_client, indexer, metrics_aggregator).generate
)
if __name__ == "__main__":
uvloop.install()
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"--min-workers",
type=int,
default=1,
help="Minimum number of workers required before proceeding",
)
parser.add_argument(
"--model-name",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model that is being served",
)
# TODO: Read block size
parser.add_argument(
"--block-size",
type=int,
default=64,
help="KV block size",
)
parser.add_argument(
"--custom-router",
type=bool,
default=False,
help="Whether to use custom router or not",
)
args = parser.parse_args()
asyncio.run(worker(args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import uvloop
from utils.nixl import NixlMetadataStore
from utils.prefill_queue import PrefillQueue
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.logger import logger as vllm_logger
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.runtime import DistributedRuntime, dynamo_worker
class RequestHandler:
def __init__(self, engine_client, metadata_store):
self.engine_client = engine_client
self._metadata_store = metadata_store
self._loaded_metadata = set()
print("RequestHandler initialized")
async def generate(self, request: RemotePrefillRequest):
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if request.engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
print(
f"Loaded nixl metadata from engine {request.engine_id} into engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(request.engine_id)
async for _ in self.engine_client.generate(
request_id=request.request_id,
prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
# TODO: we don't need it now, but will need it after the queue is integrated to the runtime
component = runtime.namespace("dynamo-init").component("prefill")
await component.create_service()
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
metadata = engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata)
# TODO: move this to prefill_queue.py
prefill_queue_nats_server = os.getenv("NATS_SERVER", "nats://localhost:4222")
prefill_queue_stream_name = (
engine_args.served_model_name
if engine_args.served_model_name is not None
else "vllm"
)
vllm_logger.info(
"Prefill queue: %s:%s", prefill_queue_nats_server, prefill_queue_stream_name
)
request_handler = RequestHandler(engine_client, metadata_store)
# TODO: integrate prefill_queue to a dynamo endpoint
async with PrefillQueue.get_instance(
nats_server=prefill_queue_nats_server,
stream_name=prefill_queue_stream_name,
) as prefill_queue:
while True:
# TODO: this might add a small overhead to pull prefill from nats
# need to test and check how much overhead it is
prefill_request = await prefill_queue.dequeue_prefill_request()
if prefill_request is not None:
vllm_logger.debug(
"Dequeued prefill request: %s", prefill_request.request_id
)
async for _ in request_handler.generate(prefill_request):
pass
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
if engine_args.disable_async_output_proc is not True:
print("Async output processing is not supported yet, setting to True")
engine_args.disable_async_output_proc = True
if engine_args.enforce_eager is not True:
print("Prefill must be done eagerly, setting to True")
engine_args.enforce_eager = True
asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import uuid
from enum import Enum
from typing import AsyncIterator, Tuple, Union
import uvloop
from transformers import AutoTokenizer
from utils.chat_processor import ChatProcessor, CompletionsProcessor, ProcessMixIn
from utils.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
CompletionRequest,
CompletionStreamResponse,
)
from vllm.logger import logger as vllm_logger
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.runtime import Client, DistributedRuntime, dynamo_endpoint, dynamo_worker
class RequestType(Enum):
CHAT = "chat"
COMPLETION = "completion"
class Processor(ProcessMixIn):
"""
vLLM pre and post processing
"""
def __init__(
self,
engine_args: AsyncEngineArgs,
router_client: Client,
workers_client: Client,
):
self.engine_args = engine_args
self.model_config = self.engine_args.create_model_config()
self.tokenizer = self._create_tokenizer(engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
self.tokenizer, self.model_config
)
self.router_client = router_client
self.workers_client = workers_client
self.router_mode = engine_args.router
def _create_tokenizer(self, engine_args: AsyncEngineArgs) -> AnyTokenizer:
"""Create a TokenizerGroup using engine arguments similar to VLLM's approach"""
model_path = engine_args.model
# Create the base tokenizer with VLLM's typical settings
base_tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
padding_side="left",
truncation_side="left",
use_fast=True, # VLLM might use the fast tokenizer for efficiency
)
return base_tokenizer
async def _generate(
self,
raw_request: Union[CompletionRequest, ChatCompletionRequest],
request_type: RequestType,
):
request_id = str(uuid.uuid4())
vllm_logger.debug(f"Got raw request: {raw_request}")
(
request,
conversation,
prompt,
engine_prompt,
sampling_params,
) = await self._parse_raw_request(raw_request)
if self.router_mode == "kv":
worker_id_generator: AsyncIterator = await self.router_client.generate(
Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
)
route_response = (
await worker_id_generator.__anext__()
) # only one worker id is returned
worker_id, prefix_hit_rate = route_response.data().split("_")
prefix_hit_rate = float(prefix_hit_rate)
vllm_logger.info(
f"Worker ID: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
)
if worker_id == "":
engine_generator = await self.workers_client.random(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json()
)
else:
engine_generator = await self.workers_client.direct(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
prefix_hit_rate=prefix_hit_rate,
).model_dump_json(),
int(worker_id),
)
elif self.router_mode == "random":
engine_generator = await self.workers_client.random(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
elif self.router_mode == "round-robin":
engine_generator = await self.workers_client.round_robin(
vLLMGenerateRequest(
engine_prompt=engine_prompt,
sampling_params=sampling_params,
request_id=request_id,
).model_dump_json()
)
output = self._generate_responses(engine_generator, request_type)
async for response in await self._stream_response(
request, output, request_id, conversation
):
yield response
async def _generate_responses(
self, engine_generator: AsyncIterator[RequestOutput], request_type: RequestType
) -> AsyncIterator[Union[RequestOutput, Tuple[int, RequestOutput]]]:
prompt_idx = 0
async for resp in engine_generator:
# Deserialize the response from the engine
# Creates correct vLLM objects for each field
output = MyRequestOutput.model_validate_json(resp.data())
# OpenAIServingChat.chat_completion_stream_generator() method expects a RequestOutput object
request_output = RequestOutput(
request_id=output.request_id,
prompt=output.prompt,
prompt_token_ids=output.prompt_token_ids,
prompt_logprobs=output.prompt_logprobs,
outputs=output.outputs,
finished=output.finished,
metrics=output.metrics,
)
if request_type == RequestType.CHAT:
# For chat requests, yield the request_output directly.
yield request_output
elif request_type == RequestType.COMPLETION:
# Completion requests can have multiple prompts and stream generator requires the prompt index
yield (prompt_idx, request_output)
else:
raise NotImplementedError(
f"Request type {request_type} not implemented"
)
@dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate_chat(self, raw_request):
async for response in self._generate(raw_request, RequestType.CHAT):
yield response
@dynamo_endpoint(CompletionRequest, CompletionStreamResponse)
async def generate_completions(self, raw_request):
async for response in self._generate(raw_request, RequestType.COMPLETION):
yield response
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
"""
Set up clients to the router and workers.
Serve the dynamo-init.process.chat/completions endpoint.
"""
workers_client = (
await runtime.namespace("dynamo-init")
.component("vllm")
.endpoint("generate")
.client()
)
router_client = (
await runtime.namespace("dynamo-init")
.component("router")
.endpoint("generate")
.client()
)
preprocess_component = runtime.namespace("dynamo-init").component("process")
await preprocess_component.create_service()
chat_endpoint = preprocess_component.endpoint("chat/completions")
completions_endpoint = preprocess_component.endpoint("completions")
processor = Processor(engine_args, router_client, workers_client)
await asyncio.gather(
chat_endpoint.serve_endpoint(processor.generate_chat),
completions_endpoint.serve_endpoint(processor.generate_completions),
)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import msgspec
from vllm.sampling_params import SamplingParams
class Request(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True,
):
"""The request data of one remote prefill output of a request.
Args:
request_id: The unique ID of the request.
prompt: The prompt string of the request.
"""
request_id: str
prompt: str
sampling_params: SamplingParams
do_remote_prefill: bool = False
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import msgspec
import uvloop
from utils.nixl import NixlMetadataStore
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.runtime import DistributedRuntime, dynamo_worker
class RequestHandler:
def __init__(self, engine_client, metadata_store):
self.engine_client = engine_client
self._metadata_store = metadata_store
self._loaded_metadata = set()
print("RequestHandler initialized")
async def generate(self, raw_request: str):
request: RemotePrefillRequest = msgspec.json.decode(
raw_request.encode("utf-8"), type=RemotePrefillRequest
)
sampling_params = request.sampling_params
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
remote_prefill_params = RemotePrefillParams(
is_remote_decode=True,
decode_block_ids=request.block_ids,
decode_engine_id=request.engine_id,
)
# TODO check if metadata has changed
# and reload - currently only loading once
if request.engine_id not in self._loaded_metadata:
remote_metadata = await self._metadata_store.get(request.engine_id)
await self.engine_client.add_remote_nixl_metadata(remote_metadata)
print(
f"Loaded nixl metadata from engine {request.engine_id} into engine {self.engine_client.nixl_metadata.engine_id}"
)
self._loaded_metadata.add(request.engine_id)
async for _ in self.engine_client.generate(
request_id=request.request_id,
prompt=TokensPrompt(prompt_token_ids=request.prompt_token_ids),
sampling_params=sampling_params,
remote_prefill_params=remote_prefill_params,
):
yield
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("dynamo-init").component("prefill")
await component.create_service()
endpoint = component.endpoint("generate")
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
metadata = engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata)
await endpoint.serve_endpoint(
RequestHandler(engine_client, metadata_store).generate
)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
if engine_args.disable_async_output_proc is not True:
print("Async output processing is not supported yet, setting to True")
engine_args.disable_async_output_proc = True
if engine_args.enforce_eager is not True:
print("Prefill must be done eagerly, setting to True")
engine_args.enforce_eager = True
asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import msgspec
import uvloop
from utils.nixl import NixlMetadataStore
from utils.vllm import parse_vllm_args
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.multiprocessing.client import EngineClient
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args,
)
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionStreamResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
class RequestHandler:
def __init__(
self,
model_name: str,
engine_client: EngineClient,
prefill_client,
do_remote_prefill: bool,
):
self.model_name = model_name
self.engine_client = engine_client
self.prefill_client = prefill_client
self.openai_serving_chat = None
self.initialized = False
self.do_remote_prefill = (
do_remote_prefill # TODO: this should be decided by the algorithm
)
print("RequestHandler initialized")
async def init(self):
models = OpenAIServingModels(
engine_client=self.engine_client,
model_config=await self.engine_client.get_model_config(),
base_model_paths=[
BaseModelPath(
name=self.model_name,
model_path=self.model_name,
)
],
)
self.openai_serving_chat = OpenAIServingChat(
engine_client=self.engine_client,
model_config=await self.engine_client.get_model_config(),
models=models,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
self.initialized = True
def get_remote_prefill_request_callback(self):
async def callback(request: RemotePrefillRequest):
json_request = msgspec.json.encode(request).decode("utf-8")
self.prefill_client.round_robin(json_request)
return callback
@dynamo_endpoint(ChatCompletionRequest, ChatCompletionStreamResponse)
async def generate(self, request):
if not self.initialized:
await self.init()
assert self.openai_serving_chat is not None
request.model = "vllm"
if self.do_remote_prefill:
remote_prefill_params = RemotePrefillParams(
is_remote_prefill=True,
remote_prefill_request_callback=self.get_remote_prefill_request_callback(),
)
else:
remote_prefill_params = None
async for raw_response in await self.openai_serving_chat.create_chat_completion(
request,
remote_prefill_params=remote_prefill_params,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
@dynamo_worker()
async def worker(runtime: DistributedRuntime, engine_args: AsyncEngineArgs):
component = runtime.namespace("dynamo-init").component("vllm")
await component.create_service()
endpoint = component.endpoint("generate")
prefill_client = (
await runtime.namespace("dynamo-init")
.component("prefill")
.endpoint("generate")
.client()
)
async with build_async_engine_client_from_engine_args(engine_args) as engine_client:
# This should be replaced with etcd
if engine_args.remote_prefill:
metadata = engine_client.nixl_metadata
metadata_store = NixlMetadataStore("dynamo-init", runtime)
await metadata_store.put(metadata.engine_id, metadata)
await endpoint.serve_endpoint(
RequestHandler(
model_name="vllm",
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=True,
).generate
)
else:
await endpoint.serve_endpoint(
RequestHandler(
model_name="vllm",
engine_client=engine_client,
prefill_client=prefill_client,
do_remote_prefill=False,
).generate
)
if __name__ == "__main__":
uvloop.install()
engine_args = parse_vllm_args()
if engine_args.remote_prefill:
if engine_args.enable_chunked_prefill is not False:
print("Chunked prefill is not supported yet, setting to False")
engine_args.enable_chunked_prefill = False
if engine_args.preemption_mode != "swap":
print("Preemption mode is not supported yet, setting to swap")
engine_args.preemption_mode = "swap"
if engine_args.pipeline_parallel_size != 1:
print("Pipeline parallel size is not supported yet, setting to 1")
engine_args.pipeline_parallel_size = 1
asyncio.run(worker(engine_args))
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
from typing import AsyncIterator, List, Optional, Protocol, Union, runtime_checkable
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
CompletionRequest,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import RequestPrompt
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import AnyTokenizer
@runtime_checkable
class ProcessMixInRequired(Protocol):
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
class ProcessMixIn(ProcessMixInRequired):
"""
Mixin for pre and post processing for vLLM
Requires engine_args, engine_client, processor, model_config to be initialized
"""
engine_args: AsyncEngineArgs
chat_processor: "ChatProcessor | None"
completions_processor: "CompletionsProcessor | None"
model_config: ModelConfig
def __init__(self):
pass
def _get_processor(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
# Determine the processor type based on the request structure
return (
self.chat_processor
if isinstance(raw_request, ChatCompletionRequest)
else self.completions_processor
)
async def _parse_raw_request(
self, raw_request: Union[CompletionRequest, ChatCompletionRequest]
):
processor = self._get_processor(raw_request)
if processor is None:
raise RuntimeError("Processor has not been initialized")
request = processor.parse_raw_request(raw_request)
preprocess_result = await processor.preprocess(raw_request)
default_max_tokens = self.model_config.max_model_len - len(
preprocess_result.engine_prompt["prompt_token_ids"]
)
default_sampling_params = self.model_config.get_diff_sampling_param()
sampling_params = request.to_sampling_params(
default_max_tokens,
self.model_config.logits_processor_pattern,
default_sampling_params,
)
return (
request,
preprocess_result.conversation,
preprocess_result.request_prompt,
preprocess_result.engine_prompt,
sampling_params,
)
async def _stream_response(self, request, generator, request_id, conversation):
processor = self._get_processor(request)
if processor is None:
raise RuntimeError("processor has not been initialized")
return processor.stream_response(
request,
generator,
request_id,
conversation,
)
class PreprocessResult:
def __init__(
self,
conversation: Optional[ConversationMessage],
request_prompt: RequestPrompt,
engine_prompt: TokensPrompt,
):
self.conversation = conversation
self.request_prompt = request_prompt
self.engine_prompt = engine_prompt
class ChatProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingChat(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
response_role="assistant",
chat_template=None,
chat_template_content_format="auto",
)
def parse_raw_request(
self, raw_request: ChatCompletionRequest
) -> ChatCompletionRequest:
return ChatCompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: ChatCompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=request.chat_template or self.tokenizer.chat_template,
chat_template_content_format=self.openai_serving.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
continue_final_message=request.continue_final_message,
tool_dicts=None,
documents=request.documents,
chat_template_kwargs=request.chat_template_kwargs,
tool_parser=self.openai_serving.tool_parser,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(conversation[0], request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: ChatCompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: List,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.chat_completion_stream_generator(
request,
result_generator,
request_id,
request.model,
conversation,
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
class CompletionsProcessor:
def __init__(self, tokenizer: AnyTokenizer, model_config: ModelConfig):
self.tokenizer = tokenizer
self.model_config = model_config
self.openai_serving = OpenAIServingCompletion(
engine_client=None,
model_config=model_config,
models=None,
request_logger=None,
)
def parse_raw_request(self, raw_request: CompletionRequest) -> CompletionRequest:
return CompletionRequest.parse_obj(raw_request)
async def preprocess(self, raw_request: CompletionRequest) -> PreprocessResult:
request = self.parse_raw_request(raw_request)
(
request_prompts,
engine_prompts,
) = await self.openai_serving._preprocess_completion(
request,
self.tokenizer,
input_or_inputs=request.prompt,
truncate_prompt_tokens=request.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
return PreprocessResult(None, request_prompts[0], engine_prompts[0])
async def stream_response(
self,
request: CompletionRequest,
result_generator: AsyncIterator,
request_id: str,
conversation: Optional[List[ConversationMessage]] = None,
):
request_metadata = RequestResponseMetadata(request_id=request_id)
if not request.stream:
raise ValueError("Only streaming responses are supported")
async for raw_response in self.openai_serving.completion_stream_generator(
request,
result_generator,
request_id,
int(time.time()), # created_time
request.model,
1, # num_prompts
self.tokenizer,
request_metadata,
):
if raw_response.startswith("data: [DONE]"):
break
response = json.loads(raw_response.lstrip("data: "))
yield response
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment