Commit aacc5d76 authored by Biswa Panda, committed by GitHub

feat: onboard dynamo-sdk basic and kv-router examples (#20)


Co-authored-by: Neelay Shah <neelays@nvidia.com>
parent 2ee29443
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import uuid
from typing import AsyncIterator

import bentoml

from sdk_kv_router.router import Router
from sdk_kv_router.worker import VllmEngine

with bentoml.importing():
    from transformers import AutoTokenizer
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.entrypoints.openai.protocol import ChatCompletionRequest
    from vllm.outputs import RequestOutput
    from vllm.transformers_utils.tokenizer import AnyTokenizer

    from common.chat_processor import ChatProcessor, ProcessMixIn
    from common.protocol import MyRequestOutput, Tokens, vLLMGenerateRequest

from dynemo.sdk import depends, dynemo_context, dynemo_endpoint, service


@service(
    dynemo={
        "enabled": True,
        "namespace": "dynemo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Processor(ProcessMixIn):
    """
    vLLM pre- and post-processing
    """

    workers = depends(VllmEngine)
    router = depends(Router)

    def __init__(self):
        model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
        self.engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=model,
            enable_prefix_caching=True,
            block_size=64,
            max_model_len=16384,
        )
        self.model_config = self.engine_args.create_model_config()
        self.tokenizer = self._create_tokenizer()
        self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)

    def _create_tokenizer(self) -> AnyTokenizer:
        """Create a tokenizer using engine arguments, similar to vLLM's approach"""
        model_path = self.engine_args.model

        # Create the base tokenizer with vLLM's typical settings
        base_tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            padding_side="left",
            truncation_side="left",
            use_fast=True,  # vLLM typically uses the fast tokenizer for efficiency
        )
        return base_tokenizer

    async def generate_responses(
        self, engine_generator
    ) -> AsyncIterator[RequestOutput]:
        async for resp in engine_generator:
            # Deserialize the response from the engine, rebuilding the
            # correct vLLM objects for each field
            output = MyRequestOutput.model_validate_json(resp.data())
            yield RequestOutput(
                request_id=output.request_id,
                prompt=output.prompt,
                prompt_token_ids=output.prompt_token_ids,
                prompt_logprobs=output.prompt_logprobs,
                outputs=output.outputs,
                finished=output.finished,
                metrics=output.metrics,
            )

    @dynemo_endpoint()
    async def generate(self, raw_request: ChatCompletionRequest):
        request_id = str(uuid.uuid4())
        (
            request,
            conversation,
            prompt,
            engine_prompt,
            sampling_params,
        ) = await self._parse_raw_request(raw_request)

        # Ask the KV router for a worker; it yields a single worker id
        worker_id = None
        async for worker in self.router.generate(
            Tokens(tokens=engine_prompt["prompt_token_ids"]).model_dump_json()
        ):
            worker_id = worker
            break

        runtime = dynemo_context["runtime"]
        comp_ns, comp_name = VllmEngine.dynemo_address()  # type: ignore
        worker_client = (
            await runtime.namespace(comp_ns)
            .component(comp_name)
            .endpoint("generate")
            .client()
        )

        if worker_id == "":
            # No router pick: let the runtime select a worker
            engine_generator = await worker_client.generate(
                vLLMGenerateRequest(
                    engine_prompt=engine_prompt,
                    sampling_params=sampling_params,
                    request_id=request_id,
                ).model_dump_json()
            )
        else:
            # Route directly to the worker selected by the KV router
            engine_generator = await worker_client.direct(
                vLLMGenerateRequest(
                    engine_prompt=engine_prompt,
                    sampling_params=sampling_params,
                    request_id=request_id,
                ).model_dump_json(),
                uuid.UUID(worker_id).int,
            )

        output = self.generate_responses(engine_generator)
        async for response in await self._stream_response(
            request, output, request_id, conversation
        ):
            yield response
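For reference, a minimal client-side sketch of calling this service through the same runtime client pattern the code above uses; the component name "Processor" and the request payload shape are illustrative assumptions, not part of this commit:

# Hypothetical client; mirrors the namespace/component/endpoint/client chain above.
import json

async def call_processor(runtime):
    client = (
        await runtime.namespace("dynemo")
        .component("Processor")  # assumed component name
        .endpoint("generate")
        .client()
    )
    request = {
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    }
    async for chunk in await client.generate(json.dumps(request)):
        print(chunk.data())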
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from enum import Enum

import bentoml

from common.protocol import Tokens
from dynemo.sdk import async_onstart, dynemo_context, dynemo_endpoint, service

with bentoml.importing():
    from dynemo.runtime import KvRouter

WorkerId = str


class RoutingStrategy(Enum):
    PREFIX = "prefix"
    ROUND_ROBIN = "round_robin"
    RANDOM = "random"


@service(
    dynemo={
        "enabled": True,
        "namespace": "dynemo",
    },
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
class Router:
    """
    Request handler for the generate endpoint
    """

    def __init__(self):
        self.model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
        self.routing_strategy = RoutingStrategy.PREFIX
        self.runtime = dynemo_context["runtime"]
        self.min_workers = 1

    @async_onstart
    async def init_engine(self):
        workers_client = (
            await self.runtime.namespace("dynemo")
            .component("VllmEngine")
            .endpoint("generate")
            .client()
        )

        # Block until at least min_workers VllmEngine endpoints are live
        wait_task = workers_client.wait_for_endpoints()
        await asyncio.sleep(1)
        while not wait_task.done():
            print("Waiting for workers to be ready...")
            await asyncio.sleep(5)
        wait_task.result()

        while len(workers_client.endpoint_ids()) < self.min_workers:
            print(
                f"Waiting for more workers... Current: {len(workers_client.endpoint_ids())}, Required: {self.min_workers}"
            )
            await asyncio.sleep(5)

        # Listen for KV events published under the model's component
        kv_listener = self.runtime.namespace("dynemo").component(self.model_name)
        await kv_listener.create_service()
        self.router = KvRouter(self.runtime, kv_listener)

    @dynemo_endpoint()
    async def generate(self, request: Tokens):
        lora_id = 0
        worker_id = ""
        if self.routing_strategy == RoutingStrategy.PREFIX:
            try:
                worker_id = await self.router.schedule(request.tokens, lora_id)
            except Exception as e:
                if "No worker found" in str(e):
                    # An empty worker_id tells the caller to fall back to the
                    # runtime's default worker selection
                    worker_id = ""
                else:
                    print(f"Error during worker selection: {e}")
            print(f"Scheduling to worker_id: {worker_id}")
            yield worker_id
        else:
            # TODO: Do we implement round_robin and random here,
            # or skip this router and enable them directly in preprocessing?
            raise NotImplementedError(
                f"Routing strategy {self.routing_strategy} not implemented"
            )
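The TODO above leaves ROUND_ROBIN and RANDOM unimplemented. A hypothetical sketch of what a round-robin branch could look like, assuming the workers_client from init_engine were kept on self (the code above does not do this):

# Hypothetical only: cycle through known endpoint ids in order.
# Assumes self.workers_client was stored in init_engine, which it is not above.
import itertools

def _round_robin_worker(self) -> str:
    if not hasattr(self, "_rr_cycle"):
        self._rr_cycle = itertools.cycle(self.workers_client.endpoint_ids())
    return str(next(self._rr_cycle))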
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
from typing import Optional

import bentoml

with bentoml.importing():
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.multiprocessing.client import MQLLMEngineClient
    from vllm.logger import logger as vllm_logger
    from vllm.sampling_params import RequestOutputKind

    from common.base_engine import BaseVllmEngine
    from common.protocol import MyRequestOutput, vLLMGenerateRequest

from dynemo.llm import KvMetricsPublisher
from dynemo.sdk import (
    async_onstart,
    dynemo_context,
    dynemo_endpoint,
    server_context,
    service,
)

lease_id = None

## TODO: metrics_publisher.create_endpoint(worker_component),


@service(
    dynemo={
        "enabled": True,
        "namespace": "dynemo",
    },
    resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
    workers=1,
)
class VllmEngine(BaseVllmEngine):
    """
    vLLM inference engine
    """

    def __init__(self):
        model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
        self.engine_args = AsyncEngineArgs(
            model=model,
            gpu_memory_utilization=0.8,
            enable_prefix_caching=True,
            block_size=64,
            max_model_len=16384,
        )

        # Expose this worker's lease id and KV event routing info to vLLM
        VLLM_WORKER_ID = dynemo_context["endpoints"][0].lease_id()
        os.environ["VLLM_WORKER_ID"] = str(VLLM_WORKER_ID)
        os.environ["VLLM_KV_NAMESPACE"] = "dynemo"
        os.environ["VLLM_KV_COMPONENT"] = "vllm"
        vllm_logger.info(f"Generate endpoint ID: {VLLM_WORKER_ID}")

        # Pin each replica to its own GPU (worker_index is 1-based)
        os.environ["CUDA_VISIBLE_DEVICES"] = f"{server_context.worker_index - 1}"

        self.metrics_publisher = KvMetricsPublisher()
        self.engine_client: Optional[MQLLMEngineClient] = None
        super().__init__(self.engine_args)

    async def create_metrics_publisher_endpoint(self):
        component = dynemo_context["component"]
        await self.metrics_publisher.create_endpoint(component)

    @async_onstart
    async def init_engine(self):
        if self.engine_client is None:
            await super().initialize()
            print("vLLM worker initialized")

        assert self.engine_client is not None, "engine_client was not initialized"
        self.engine_client.set_metrics_publisher(self.metrics_publisher)

        # Publish an initial metrics snapshot, then expose the metrics endpoint
        self.metrics_publisher.publish(0, 1024, 0, 1024)
        task = asyncio.create_task(self.create_metrics_publisher_endpoint())
        task.add_done_callback(lambda _: print("metrics publisher endpoint created"))

    @dynemo_endpoint()
    async def generate(self, request: vLLMGenerateRequest):
        sampling_params = request.sampling_params
        # The Rust HTTP frontend requires delta streaming
        sampling_params.output_kind = RequestOutputKind.DELTA

        async for response in self.engine_client.generate(  # type: ignore
            request.engine_prompt, sampling_params, request.request_id
        ):
            # MyRequestOutput takes care of serializing the response, since
            # vLLM's RequestOutput is not serializable by default
            resp = MyRequestOutput(
                request_id=response.request_id,
                prompt=response.prompt,
                prompt_token_ids=response.prompt_token_ids,
                prompt_logprobs=response.prompt_logprobs,
                outputs=response.outputs,
                finished=response.finished,
            ).model_dump_json()
            yield resp
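A note on RequestOutputKind.DELTA above: each streamed RequestOutput carries only the tokens generated since the previous one, which is what lets the Rust HTTP frontend stream incrementally. A minimal consumer sketch, assuming the same resp.data() wrapper the Processor uses:

# Illustrative only: reassemble DELTA chunks into the full completion text.
from common.protocol import MyRequestOutput

async def collect_text(engine_generator) -> str:
    text = ""
    async for resp in engine_generator:
        out = MyRequestOutput.model_validate_json(resp.data())
        for completion in out.outputs:
            text += completion.text  # DELTA: only the newly generated piece
    return text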
@@ -95,7 +95,11 @@ xfail_strict = true
log_cli_level = "INFO"
filterwarnings = [
    "error",
    "ignore:.*pkg_resources.*:DeprecationWarning",
    "ignore:.*multipart.*:PendingDeprecationWarning"
]
# NOTE: Can also manually mark tests with @pytest.mark.asyncio
asyncio_mode = "auto"
markers = [
@@ -134,5 +138,10 @@ check_untyped_defs = true
[[tool.mypy.overrides]]
# Skip mypy analysis on internal dependencies of vllm and bentoml
module = ["vllm.*", "bentoml.*", "fs.*", "_bentoml_sdk.*"]
follow_imports = "skip"
ignore_missing_imports = true
# declare namespace packages
[tool.setuptools]
namespace-packages = ["dynemo", "dynemo.sdk"]
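For context on the namespace-packages declaration: it tells setuptools that dynemo is a namespace package without a regular top-level __init__.py, so the dynemo.sdk shipped here can coexist with dynemo.runtime and dynemo.llm from another distribution. A quick illustrative check, assuming the dynemo distributions are installed:

# Namespace packages report a composite __path__ spanning all installed parts.
import dynemo
print(list(dynemo.__path__))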