Commit b8120504 authored by ishandhanani, committed by GitHub

feat(sdk): add initial graph structure for prebuilt components (#130)


Co-authored-by: Biswa Panda <biswa.panda@gmail.com>
parent 4f7f4b40
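
At a high level, this change lets prebuilt components be composed into a graph with a chainable link() call: edges are recorded in a module-level LinkedServices registry, and depends() edges that are not part of the linked graph are pruned when the serve entrypoints call remove_unused_edges(). A minimal sketch of the intended usage, based on the example pipeline added in this diff:

    from dynamo.sdk.tests.pipeline import Backend, Frontend, Middle

    # link() returns the next service, so calls chain left to right
    Frontend.link(Middle).link(Backend)
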
......@@ -24,6 +24,7 @@ import typing as t
import click
import rich
import yaml
if t.TYPE_CHECKING:
P = t.ParamSpec("P") # type: ignore
......@@ -126,6 +127,12 @@ def build_serve_command() -> click.Group:
cls=AliasCommand,
)
@click.argument("bento", type=click.STRING, default=".")
@click.option(
"-f",
"--file",
type=click.Path(exists=True),
help="Path to YAML config file for service configuration",
)
@click.option(
"--development",
type=click.BOOL,
......@@ -266,6 +273,7 @@ def build_serve_command() -> click.Group:
development: bool,
port: int,
host: str,
file: str | None,
api_workers: int,
timeout: int | None,
backlog: int,
......@@ -321,8 +329,25 @@ def build_serve_command() -> click.Group:
from bentoml import Service
from bentoml._internal.service.loader import load
from dynamo.sdk.lib.service import LinkedServices
# Process service-specific options
service_configs = _parse_service_args(ctx.args)
service_configs: t.Optional[t.Dict[str, t.Any]] = _parse_service_args(ctx.args)
# Load and merge config file if provided
if file:
with open(file) as f:
yaml_configs = yaml.safe_load(f)
# Initialize service_configs as empty dict if it's None
service_configs = dict(service_configs or {})
# Convert nested YAML structure to flat dict with dot notation
for service, configs in yaml_configs.items():
for key, value in configs.items():
if service not in service_configs:
service_configs[service] = {}
service_configs[service][key] = value
# print("service_configs", service_configs)
# Set environment variable with service configuration
if service_configs:
......@@ -337,6 +362,8 @@ def build_serve_command() -> click.Group:
if sys.path[0] != working_dir:
sys.path.insert(0, working_dir)
svc = load(bento_identifier=bento, working_dir=working_dir)
LinkedServices.remove_unused_edges()
if isinstance(svc, Service):
# bentoml<1.2
from bentoml.serving import serve_http_production
......
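
For reference, a sketch of how the new --file option feeds per-service configuration through to workers (the YAML shape mirrors the config files added later in this diff; the key names here are illustrative):

    import json
    import yaml

    # configs.yaml
    # Frontend:
    #   port: 8000
    # VllmWorker:
    #   block-size: 64

    with open("configs.yaml") as f:
        yaml_configs = yaml.safe_load(f)

    # Flatten the nested YAML into {service: {key: value}}; the serve command
    # above then exposes the merged mapping to workers via the
    # DYNAMO_SERVICE_CONFIG environment variable.
    service_configs: dict = {}
    for service, configs in yaml_configs.items():
        service_configs.setdefault(service, {}).update(configs)

    # os.environ["DYNAMO_SERVICE_CONFIG"] = json.dumps(service_configs)
    print(json.dumps(service_configs))
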
......@@ -29,6 +29,7 @@ import click
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
from dynamo.sdk import dynamo_context
from dynamo.sdk.lib.service import LinkedServices
logger = logging.getLogger("dynamo.sdk.serve.dynamo")
logger.setLevel(logging.INFO)
......@@ -99,6 +100,8 @@ def main(
t.cast(t.Dict[str, str], json.loads(runner_map))
)
# TODO: test this with a deep chain of services
LinkedServices.remove_unused_edges()
# Check if Dynamo is enabled for this service
if service.is_dynamo_component():
if worker_id is not None:
......
......@@ -16,17 +16,45 @@ from __future__ import annotations
import json
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
from _bentoml_sdk import Service, ServiceConfig
from _bentoml_sdk.images import Image
from _bentoml_sdk.service.config import validate
from dynamo.sdk.lib.decorators import DynamoEndpoint
T = TypeVar("T", bound=object)
class RuntimeLinkedServices:
"""
A class to track the linked services in the runtime.
"""
def __init__(self) -> None:
self.edges: Dict[DynamoService, Set[DynamoService]] = defaultdict(set)
def add(self, edge: Tuple[DynamoService, DynamoService]):
src, dest = edge
self.edges[src].add(dest.inner)
# track the dest node as well so we can clean it up later
self.edges[dest]
def remove_unused_edges(self):
# this method is idempotent
if not self.edges:
return
# remove edges that are not in the current service
for u, vertices in self.edges.items():
u.remove_unused_edges(used_edges=vertices)
LinkedServices = RuntimeLinkedServices()
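# A minimal usage sketch (service names illustrative): linking records an edge
# from the source service to the target's inner class, and remove_unused_edges()
# later asks every source service to drop depends() entries that were never linked.
#
#   LinkedServices.add((Frontend, Middle))   # normally done by DynamoService.link
#   LinkedServices.remove_unused_edges()     # prunes Frontend's unlinked deps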
@dataclass
class DynamoConfig:
"""Configuration for Dynamo components"""
......@@ -47,6 +75,15 @@ class DynamoService(Service[T]):
envs: Optional[list[dict[str, Any]]] = None,
dynamo_config: Optional[DynamoConfig] = None,
):
service_name = inner.__name__
service_args = self._get_service_args(service_name)
if service_args:
# Validate and merge service args with existing config
validated_args = validate(service_args)
config.update(validated_args)
self._remove_service_args(service_name)
super().__init__(config=config, inner=inner, image=image, envs=envs or [])
# Initialize Dynamo configuration
......@@ -65,6 +102,17 @@ class DynamoService(Service[T]):
if isinstance(value, DynamoEndpoint):
self._dynamo_endpoints[value.name] = value
self._linked_services: List[DynamoService] = [] # Track linked services
def _get_service_args(self, service_name: str) -> Optional[dict]:
"""Get ServiceArgs from environment config if specified"""
config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
if config_str:
config = json.loads(config_str)
service_config = config.get(service_name, {})
return service_config.get("ServiceArgs")
return None
def is_dynamo_component(self) -> bool:
"""Check if this service is configured as a Dynamo component"""
return self._dynamo_config.enabled
......@@ -111,7 +159,27 @@ class DynamoService(Service[T]):
"""List names of all registered Dynamo endpoints"""
return list(self._dynamo_endpoints.keys())
# todo: add another function to bind an instance of the inner to the self within these methods
def remove_unused_edges(self, used_edges: Set[DynamoService]):
"""Remove a dependancy from the current service based on the key"""
current_deps = dict(self.dependencies)
for dep_key, dep_value in current_deps.items():
if dep_value.on.inner not in used_edges:
del self.dependencies[dep_key]
def link(self, next_service: DynamoService):
"""Link this service to another service, creating a pipeline."""
self._linked_services.append(next_service)
LinkedServices.add((self, next_service))
return next_service
def _remove_service_args(self, service_name: str):
"""Remove ServiceArgs from the environment config after using them"""
config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
if config_str:
config = json.loads(config_str)
if service_name in config and "ServiceArgs" in config[service_name]:
del config[service_name]["ServiceArgs"]
os.environ["DYNAMO_SERVICE_CONFIG"] = json.dumps(config)
def service(
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# linking syntax example
from dynamo.sdk.tests.pipeline import Backend, Frontend, Middle
# print("INITIAL DEPENDENCIES")
# print("Frontend dependencies", Frontend.dependencies)
# print("Middle dependencies", Middle.dependencies)
# print("Backend dependencies", Backend.dependencies)
# print("\n\n\n")
print()
Frontend.link(Middle).build()
print("Frontend dependencies", Frontend.dependencies)
print("Middle dependencies", Middle.dependencies)
print("Backend dependencies", Backend.dependencies)
......@@ -16,12 +16,10 @@
# This is a simple example of a pipeline that uses Dynamo to deploy a backend, middle, and frontend service. Use this to test
# changes made to CLI, SDK, etc
import os
from pydantic import BaseModel
from dynamo.sdk import api, depends, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
"""
Pipeline Architecture:
......@@ -56,25 +54,14 @@ class ResponseType(BaseModel):
GPU_ENABLED = False
class FrontendConfig(BaseModel):
model: str
temperature: float = 0.7
max_tokens: int = 1024
stream: bool = True
class MiddleConfig(BaseModel):
bias: float
@service(
resources={"cpu": "2"},
resources={"cpu": "1"},
traffic={"timeout": 30},
dynamo={
"enabled": True,
"namespace": "inference",
},
workers=3,
workers=1,
)
class Backend:
def __init__(self) -> None:
......@@ -95,74 +82,52 @@ class Backend:
traffic={"timeout": 30},
dynamo={"enabled": True, "namespace": "inference"},
)
class Backend2:
backend = depends(Backend)
def __init__(self) -> None:
print("Starting middle2")
@dynamo_endpoint()
async def generate(self, req: RequestType):
"""Forward requests to backend."""
req_text = req.text
print(f"Middle2 received: {req_text}")
text = f"{req_text}-mid2"
next_request = RequestType(text=text).model_dump_json()
print(next_request)
@service(
resources={"cpu": "1"},
traffic={"timeout": 30},
dynamo={"enabled": True, "namespace": "inference"},
)
class Middle:
backend = depends(Backend)
backend2 = depends(Backend2)
def __init__(self) -> None:
print("Starting middle")
config = ServiceConfig.get_instance()
middle_config = MiddleConfig(**config.get("Middle", {}))
print(f"bias: {middle_config.bias}")
if GPU_ENABLED:
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import FlexibleArgumentParser
try:
os.environ["VLLM_LOG_LEVEL"] = "DEBUG"
# Get VLLM args using new pattern
vllm_args = config.as_args("Middle", prefix="vllm_")
print(f"VLLM args to parse: {vllm_args}")
# Create and use parser
parser = FlexibleArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args)
self.engine_args = AsyncEngineArgs.from_cli_args(args)
self.engine = AsyncLLMEngine.from_engine_args(self.engine_args)
except ImportError:
print("VLLM imports not available, skipping engine arg parsing")
except Exception as e:
print(f"Error parsing VLLM args: {e}")
@dynamo_endpoint()
async def generate(self, req: RequestType):
"""Forward requests to backend."""
req_text = req.text
print(f"Middle received: {req_text}")
text = f"{req_text}-mid"
next_request = RequestType(text=text).model_dump_json()
async for response in self.backend.generate(next_request):
print(f"Middle received response: {response}")
yield f"Middle: {response}"
for token in text.split():
yield f"Mid: {token}"
@service(resources={"cpu": "1"}, traffic={"timeout": 60}) # Regular HTTP API
@service(resources={"cpu": "1"}, traffic={"timeout": 60})
class Frontend:
middle = depends(Middle)
backend = depends(Backend)
def __init__(self) -> None:
print("Starting frontend")
self.config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**self.config.get("Frontend", {}))
print(
f"Frontend initialized with model={frontend_config.model}, "
f"temp={frontend_config.temperature}, max_tokens={frontend_config.max_tokens}"
)
# Get all configs for a service (new dict pattern)
all_frontend_configs = self.config.get("Frontend", {})
print(f"All Frontend configs: {all_frontend_configs}")
# Check other service configs (new dict pattern)
if self.config.get("Middle", {}).get("special_mode") == "fast":
print("Using Middle service in fast mode")
@api
async def generate(self, text):
......@@ -171,5 +136,11 @@ class Frontend:
print(f"Frontend received type: {type(text)}")
txt = RequestType(text=text)
print(f"Frontend sending: {type(txt)}")
async for response in self.middle.generate(txt.model_dump_json()):
yield f"Frontend: {response}"
if self.backend:
async for back_resp in self.backend.generate(txt.model_dump_json()):
print(f"Frontend received back_resp: {back_resp}")
yield f"Frontend: {back_resp}"
else:
async for mid_resp in self.middle.generate(txt.model_dump_json()):
print(f"Frontend received mid_resp: {mid_resp}")
yield f"Frontend: {mid_resp}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo-init.VllmWorker.generate
port: 8000
VllmWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
conditional-disagg: true
remote-prefill: true
PrefillWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
cuda-visible-device-offset: 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from disaggregated.frontend import Frontend
from disaggregated.kv_router import Router
from disaggregated.processor import Processor
from disaggregated.worker import VllmWorker
# example 2 and 3: kv aware routing + worker
# kv.yaml
Frontend.link(Processor).link(Router).link(VllmWorker)
# example 4 and 5: only disag - issue with endpoint (probably because of routerless)
# disag.yaml
# Frontend.link(VllmWorker).link(PrefillWorker)
# example 6: disag with kv
# kv_with_disag.yaml
# Frontend.link(Processor).link(Router).link(VllmWorker).link(PrefillWorker)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from disaggregated.processor import Processor
from disaggregated.worker import VllmWorker
from pydantic import BaseModel
from dynamo.sdk import depends, service
from dynamo.sdk.lib.config import ServiceConfig
class FrontendConfig(BaseModel):
model: str
endpoint: str
port: int = 8080
@service(
resources={"cpu": "10", "memory": "20Gi"},
workers=1,
)
# todo this should be called ApiServer
class Frontend:
worker = depends(VllmWorker)
processor = depends(Processor)
def __init__(self):
config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**config.get("Frontend", {}))
os.environ["TRT_LOG"] = "DEBUG"
subprocess.run(
["llmctl", "http", "remove", "chat-models", frontend_config.model]
)
subprocess.run(
[
"llmctl",
"http",
"add",
"chat-models",
frontend_config.model,
frontend_config.endpoint,
]
)
subprocess.run(
["http", "-p", str(frontend_config.port)], stdout=None, stderr=None
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo-init.Processor.chat/completions
port: 8000
Processor:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64
max-model-len: 16384
router: kv
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
VllmWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
enable-prefix-caching: true
router: kv
remote-prefill: true
tensor-parallel-size: 1
ServiceArgs:
workers: 2
envs:
- CUDA_VISIBLE_DEVICES: '0,1'
......@@ -20,11 +20,12 @@ import random
from argparse import Namespace
from typing import AsyncIterator
from disaggregated.worker import VllmWorker
from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_onstart, dynamo_context, dynamo_endpoint, service
from dynamo.sdk import async_onstart, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
WorkerId = str
......@@ -76,10 +77,11 @@ class Router:
Request handler for the generate endpoint
"""
worker = depends(VllmWorker)
def __init__(self):
vllm_logger.info("Initializing Custom Router")
self.args = parse_args(self.__class__.__name__, "")
print("[ROUTER] args = ", self.args)
@async_onstart
async def async_init(self):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo-init.Processor.chat/completions
port: 8000
Processor:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64
max-model-len: 16384
router: kv
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
VllmWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
conditional-disagg: true
tensor-parallel-size: 1
router: kv
enable-prefix-caching: true
# TODO - set all of these but model as default
PrefillWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
cuda-visible-device-offset: 1
......@@ -60,7 +60,6 @@ class Processor(ProcessMixIn):
class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config()
print(f"[Processor] self.engine_args: {self.engine_args}")
self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor(
......@@ -191,7 +190,7 @@ class Processor(ProcessMixIn):
f"Request type {request_type} not implemented"
)
@dynamo_endpoint()
@dynamo_endpoint(name="chat/completions")
async def chat_completions(self, raw_request: ChatCompletionRequest):
async for response in self._generate(raw_request, RequestType.CHAT):
yield response
......
......@@ -22,9 +22,7 @@ from dynamo.sdk.lib.config import ServiceConfig
def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
config = ServiceConfig.get_instance()
print(f"[DEBUG] config: {config}")
vllm_args = config.as_args(service_name, prefix=prefix)
print(f"[DEBUG] service_name: {service_name}, vllm_args: {vllm_args}")
parser = FlexibleArgumentParser()
parser.add_argument(
"--router",
......