"dynamo.code-workspace" did not exist on "437d8e37771e6854595486f73d2441d129d95cde"
Unverified Commit 16310b26 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

refactor: refactor dynamo serve part-1/N (#788)


Co-authored-by: default avatarishandhanani <ishandhanani@gmail.com>
parent dbdbd5e5
...@@ -35,7 +35,7 @@ The code for the pipeline looks like this: ...@@ -35,7 +35,7 @@ The code for the pipeline looks like this:
```python ```python
# filename: pipeline.py # filename: pipeline.py
from dynamo.sdk import service, dynamo_endpoint, depends, api from dynamo.sdk import service, dynamo_endpoint, depends
from pydantic import BaseModel from pydantic import BaseModel
class RequestType(BaseModel): class RequestType(BaseModel):
...@@ -93,7 +93,7 @@ dynamo serve pipeline:Frontend ...@@ -93,7 +93,7 @@ dynamo serve pipeline:Frontend
Once it's up and running, you can make a request to the pipeline using Once it's up and running, you can make a request to the pipeline using
```bash ```bash
curl -X POST http://localhost:3000/generate \ curl -X POST http://localhost:8000/generate \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{"text": "federer"}' -d '{"text": "federer"}'
``` ```
......
...@@ -15,11 +15,10 @@ ...@@ -15,11 +15,10 @@
from typing import Any from typing import Any
from bentoml import api # type: ignore
from bentoml import on_shutdown as async_on_shutdown from bentoml import on_shutdown as async_on_shutdown
from bentoml._internal.context import server_context # type: ignore from bentoml._internal.context import server_context # type: ignore
from dynamo.sdk.lib.decorators import async_on_start, dynamo_api, dynamo_endpoint from dynamo.sdk.lib.decorators import async_on_start, dynamo_endpoint
from dynamo.sdk.lib.dependency import depends from dynamo.sdk.lib.dependency import depends
from dynamo.sdk.lib.image import DYNAMO_IMAGE from dynamo.sdk.lib.image import DYNAMO_IMAGE
from dynamo.sdk.lib.service import service from dynamo.sdk.lib.service import service
...@@ -28,11 +27,9 @@ dynamo_context: dict[str, Any] = {} ...@@ -28,11 +27,9 @@ dynamo_context: dict[str, Any] = {}
__all__ = [ __all__ = [
"DYNAMO_IMAGE", "DYNAMO_IMAGE",
"api",
"async_on_shutdown", "async_on_shutdown",
"async_on_start", "async_on_start",
"depends", "depends",
"dynamo_api",
"dynamo_context", "dynamo_context",
"dynamo_endpoint", "dynamo_endpoint",
"server_context", "server_context",
......
...@@ -17,120 +17,160 @@ ...@@ -17,120 +17,160 @@
from __future__ import annotations from __future__ import annotations
import logging
import os import os
import warnings
from typing import Any from typing import Any
from _bentoml_sdk import Service from _bentoml_sdk import Service
from bentoml._internal.configuration.containers import BentoMLContainer from simple_di import inject
from bentoml._internal.resource import system_resources
from bentoml.exceptions import BentoMLConfigException
from simple_di import Provide, inject
NVIDIA_GPU = "nvidia.com/gpu" # Import our own resource module
from dynamo.sdk.lib.resource import NVIDIA_GPU, GPUManager, system_resources
logger = logging.getLogger(__name__)
# Constants
DYN_DISABLE_AUTO_GPU_ALLOCATION = "DYN_DISABLE_AUTO_GPU_ALLOCATION" DYN_DISABLE_AUTO_GPU_ALLOCATION = "DYN_DISABLE_AUTO_GPU_ALLOCATION"
DYN_DEPLOYMENT_ENV = "DYN_DEPLOYMENT_ENV" DYN_DEPLOYMENT_ENV = "DYN_DEPLOYMENT_ENV"
def format_memory_gb(memory_bytes: float) -> str:
    """Format a byte count as a gigabyte string with one decimal place.

    Args:
        memory_bytes: Memory size in bytes

    Returns:
        Human-readable string such as ``"1.5GB"``
    """
    gigabytes = memory_bytes / (1024 ** 3)
    return f"{gigabytes:.1f}GB"
class ResourceAllocator: class ResourceAllocator:
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize the resource allocator."""
self.system_resources = system_resources() self.system_resources = system_resources()
self.gpu_manager = GPUManager()
self.remaining_gpus = len(self.system_resources[NVIDIA_GPU]) self.remaining_gpus = len(self.system_resources[NVIDIA_GPU])
# For compatibility with the old implementation
self._available_gpus: list[tuple[float, float]] = [ self._available_gpus: list[tuple[float, float]] = [
(1.0, 1.0) # each item is (remaining, unit) (1.0, 1.0) # each item is (remaining, unit)
for _ in range(self.remaining_gpus) for _ in range(self.remaining_gpus)
] ]
def assign_gpus(self, count: float) -> list[int]: def assign_gpus(self, count: float) -> list[int]:
if count > self.remaining_gpus: """
warnings.warn( Assign GPUs for use.
f"Requested {count} GPUs, but only {self.remaining_gpus} are remaining. "
f"Serving may fail due to inadequate GPUs. Set {DYN_DISABLE_AUTO_GPU_ALLOCATION}=1 " Args:
"to disable automatic allocation and allocate GPUs manually.", count: Number of GPUs to assign (can be fractional)
ResourceWarning,
stacklevel=3, Returns:
) List of GPU indices that were assigned
self.remaining_gpus = int(max(0, self.remaining_gpus - count)) """
if count < 1: # a fractional GPU # Use our GPU manager's assign_gpus method
try: return self.gpu_manager.assign_gpus(count)
# try to find the GPU used with the same fragment
gpu = next( def get_gpu_stats(self) -> list[dict[str, Any]]:
i """Get detailed statistics for all GPUs."""
for i, v in enumerate(self._available_gpus) return self.gpu_manager.get_gpu_stats()
if v[0] > 0 and v[1] == count
)
except StopIteration:
try:
gpu = next(
i for i, v in enumerate(self._available_gpus) if v[0] == 1.0
)
except StopIteration:
gpu = len(self._available_gpus)
self._available_gpus.append((1.0, count))
remaining, _ = self._available_gpus[gpu]
if (remaining := remaining - count) < count:
# can't assign to the next one, mark it as zero.
self._available_gpus[gpu] = (0.0, count)
else:
self._available_gpus[gpu] = (remaining, count)
return [gpu]
else: # allocate n GPUs, n is a positive integer
if int(count) != count:
raise BentoMLConfigException(
"Float GPUs larger than 1 is not supported"
)
count = int(count)
unassigned = [
gpu
for gpu, value in enumerate(self._available_gpus)
if value[0] > 0 and value[1] == 1.0
]
if len(unassigned) < count:
warnings.warn(
f"Not enough GPUs to be assigned, {count} is requested",
ResourceWarning,
)
for _ in range(count - len(unassigned)):
unassigned.append(len(self._available_gpus))
self._available_gpus.append((1.0, 1.0))
for gpu in unassigned[:count]:
self._available_gpus[gpu] = (0.0, 1.0)
return unassigned[:count]
@inject @inject
def get_resource_envs( def get_resource_envs(
self, self,
service: Service[Any], service: Service[Any],
services: dict[str, Any] = Provide[BentoMLContainer.config.services],
) -> tuple[int, list[dict[str, str]]]: ) -> tuple[int, list[dict[str, str]]]:
"""
Get resource environment variables for a service.
Args:
service: The service to get resource environment variables for
Returns:
Tuple of (number of workers, list of environment variables dictionaries)
"""
logger.info(f"Getting resource envs for service {service.name}")
services = service.get_service_configs()
if service.name not in services:
logger.warning(f"No service configs found for {service.name}")
return 1, [] # Default to 1 worker, no special resources
config = services[service.name] config = services[service.name]
logger.debug(f"Using config for {service.name}: {config}")
num_gpus = 0 num_gpus = 0
num_workers = 1 num_workers = 1
resource_envs: list[dict[str, str]] = [] resource_envs: list[dict[str, str]] = []
# Check if service requires GPUs
if "gpu" in (config.get("resources") or {}): if "gpu" in (config.get("resources") or {}):
num_gpus = config["resources"]["gpu"] # type: ignore num_gpus = config["resources"]["gpu"] # type: ignore
logger.info(f"GPU requirement found: {num_gpus}")
# Check if we have enough GPUs
available_gpus = self.gpu_manager.get_available_gpus()
if num_gpus > len(available_gpus):
logger.warning(
f"Requested {num_gpus} GPUs, but only {len(available_gpus)} are available. "
f"Service may fail due to inadequate GPU resources."
)
# Determine number of workers
if config.get("workers"): if config.get("workers"):
if (workers := config["workers"]) == "cpu_count": num_workers = config["workers"]
num_workers = int(self.system_resources["cpu"]) logger.info(f"Using configured worker count: {num_workers}")
# don't assign gpus to workers
return num_workers, resource_envs # Handle GPU allocation
else: # workers is a number
num_workers = workers
if num_gpus and DYN_DISABLE_AUTO_GPU_ALLOCATION not in os.environ: if num_gpus and DYN_DISABLE_AUTO_GPU_ALLOCATION not in os.environ:
logger.info("GPU allocation enabled")
if os.environ.get(DYN_DEPLOYMENT_ENV): if os.environ.get(DYN_DEPLOYMENT_ENV):
logger.info("K8s deployment detected")
# K8s replicas: Assumes DYNAMO_DEPLOYMENT_ENV is set # K8s replicas: Assumes DYNAMO_DEPLOYMENT_ENV is set
# each pod in replicaset will have separate GPU with same CUDA_VISIBLE_DEVICES # each pod in replicaset will have separate GPU with same CUDA_VISIBLE_DEVICES
assigned = self.assign_gpus(num_gpus) assigned = self.assign_gpus(num_gpus)
resource_envs = [ logger.info(f"Assigned GPUs for K8s: {assigned}")
{"CUDA_VISIBLE_DEVICES": ",".join(map(str, assigned))}
for _ in range(num_workers) # Generate environment variables for each worker
]
else:
# local deployment where we split all available GPUs across workers
for _ in range(num_workers): for _ in range(num_workers):
env_vars = {"CUDA_VISIBLE_DEVICES": ",".join(map(str, assigned))}
resource_envs.append(env_vars)
else:
logger.info("Local deployment detected")
# Local deployment where we split all available GPUs across workers
for worker_id in range(num_workers):
assigned = self.assign_gpus(num_gpus) assigned = self.assign_gpus(num_gpus)
resource_envs.append( logger.info(f"Assigned GPUs for worker {worker_id}: {assigned}")
{"CUDA_VISIBLE_DEVICES": ",".join(map(str, assigned))}
# Generate environment variables for this worker
env_vars = {"CUDA_VISIBLE_DEVICES": ",".join(map(str, assigned))}
# If we have comprehensive GPU stats, log them
try:
gpu_stats = [
stat
for stat in self.get_gpu_stats()
if stat["index"] in assigned
]
for stat in gpu_stats:
logger.info(
f"GPU {stat['index']} ({stat['name']}): "
f"Memory: {format_memory_gb(stat['free_memory'])} free / "
f"{format_memory_gb(stat['total_memory'])} total, "
f"Utilization: {stat['gpu_utilization']}%, "
f"Temperature: {stat['temperature']}°C"
)
except Exception as e:
logger.debug(f"Failed to get GPU stats: {e}")
resource_envs.append(env_vars)
logger.info(
f"Final resource allocation - workers: {num_workers}, envs: {resource_envs}"
) )
return num_workers, resource_envs return num_workers, resource_envs
def reset_allocations(self):
"""Reset all GPU allocations."""
self.gpu_manager.reset_allocations()
# Reset legacy tracking
self._available_gpus = [(1.0, 1.0) for _ in range(self.remaining_gpus)]
...@@ -117,7 +117,7 @@ class Bento(BaseBento): ...@@ -117,7 +117,7 @@ class Bento(BaseBento):
build_config.envs.extend(svc.envs) build_config.envs.extend(svc.envs)
build_config.labels.update(svc.labels) build_config.labels.update(svc.labels)
if svc.image is not None: if svc.image is not None:
image = svc.image image = Image(base_image=svc.image)
if not disable_image: if not disable_image:
image = populate_image_from_build_config(image, build_config, build_ctx) image = populate_image_from_build_config(image, build_config, build_ctx)
build_config = build_config.with_defaults() build_config = build_config.with_defaults()
......
This diff is collapsed.
# SPDX-FileCopyrightText: Copyright (c) 2020 Atalaya Tech. Inc
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
# Once planner v1 goes live - this will be be full of more granular APIs
from __future__ import annotations
import contextlib
import os
import pathlib
import shlex
import sys
from dataclasses import dataclass
from typing import Any, Callable
import psutil
from circus.arbiter import Arbiter as _Arbiter
from circus.sockets import CircusSocket
from circus.watcher import Watcher
from .utils import ServiceProtocol
class Arbiter(_Arbiter):
    """Circus arbiter extended with an ExitStack for resource cleanup."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.exit_stack = contextlib.ExitStack()

    def start(self, cb: Callable[[Any], Any] | None = None) -> None:
        """Enter the cleanup stack, then start the arbiter.

        Any exception raised during arbiter startup is re-raised here.
        """
        self.exit_stack.__enter__()
        future = super().start(cb)
        error = future.exception()
        if error is not None:
            raise error

    def stop(self) -> None:
        """Unwind registered cleanups, then delegate to the base stop."""
        self.exit_stack.__exit__(None, None, None)
        return super().stop()
@dataclass
class CircusRunner:
    """Thin context-manager wrapper around an Arbiter's lifecycle."""

    arbiter: Arbiter

    def stop(self) -> None:
        """Stop the wrapped arbiter."""
        self.arbiter.stop()

    @property
    def running(self) -> bool:
        """Whether the wrapped arbiter is currently running."""
        return self.arbiter.running

    def __enter__(self) -> CircusRunner:
        # Nothing to acquire; the arbiter is managed externally.
        return self

    def __exit__(self, *_: Any) -> None:
        # Leaving the context always stops the arbiter.
        self.stop()
MAX_AF_UNIX_PATH_LENGTH = 103
def create_circus_watcher(
    name: str,
    args: list[str],
    *,
    cmd: str = sys.executable,
    use_sockets: bool = True,
    **kwargs: Any,
) -> Watcher:
    """Build a circus ``Watcher`` with Dynamo's default process settings.

    Args:
        name: Watcher name.
        args: Command-line arguments for the watched process.
        cmd: Executable to launch; defaults to the current interpreter.
        use_sockets: Whether the watcher uses circus-managed sockets.
        **kwargs: Extra options forwarded verbatim to ``Watcher``.

    Returns:
        Configured ``Watcher`` instance.
    """
    # Quote the command on POSIX so interpreter paths with spaces survive.
    command = shlex.quote(cmd) if psutil.POSIX else cmd
    return Watcher(
        name=name,
        cmd=command,
        args=args,
        copy_env=True,
        stop_children=True,
        use_sockets=use_sockets,
        graceful_timeout=86400,
        respawn=False,  # TODO: revisit whether watchers should respawn
        **kwargs,
    )
def create_arbiter(
    watchers: list[Watcher], *, threaded: bool = False, **kwargs: Any
) -> Arbiter:
    """Create an Arbiter bound to localhost control and pubsub endpoints.

    Ports default to 41234 (control) and 52345 (pubsub) and may be
    overridden with the DYN_CIRCUS_ENDPOINT_PORT / DYN_CIRCUS_PUBSUB_PORT
    environment variables.

    NOTE(review): ``threaded`` is accepted but not forwarded here — confirm
    whether it should be passed through to the Arbiter.
    """
    control_port = int(os.environ.get("DYN_CIRCUS_ENDPOINT_PORT", "41234"))
    publish_port = int(os.environ.get("DYN_CIRCUS_PUBSUB_PORT", "52345"))
    delay = kwargs.pop("check_delay", 10)
    return Arbiter(
        watchers,
        endpoint=f"tcp://127.0.0.1:{control_port}",
        pubsub_endpoint=f"tcp://127.0.0.1:{publish_port}",
        check_delay=delay,
        **kwargs,
    )
def path_to_uri(path: str) -> str:
    """Convert an absolute POSIX path to a ``file://`` URI.

    Args:
        path: Absolute filesystem path to convert.

    Returns:
        Percent-quoted ``file://`` URI string.

    Raises:
        ValueError: If *path* is not absolute (``as_uri`` requirement).
    """
    posix_path = pathlib.PurePosixPath(path)
    return posix_path.as_uri()
def _get_server_socket(
    service: ServiceProtocol,
    uds_path: str,
) -> tuple[str, CircusSocket]:
    """Create a Unix Domain Socket for a service.

    The socket file name is derived from ``id(service)`` so each service
    object gets a distinct path under *uds_path*.

    Args:
        service: The service to create a socket for
        uds_path: Base directory for Unix Domain Sockets

    Returns:
        Tuple of (socket URI, CircusSocket object)

    Raises:
        AssertionError: If socket path exceeds maximum length
            (AF_UNIX socket paths are limited to MAX_AF_UNIX_PATH_LENGTH bytes)
    """
    socket_path = os.path.join(uds_path, f"{id(service)}.sock")
    # NOTE(review): asserts are stripped under `python -O`, which would
    # silently drop this length check — consider an explicit raise.
    assert (
        len(socket_path) < MAX_AF_UNIX_PATH_LENGTH
    ), f"Socket path '{socket_path}' exceeds maximum length of {MAX_AF_UNIX_PATH_LENGTH}"
    return path_to_uri(socket_path), CircusSocket(name=service.name, path=socket_path)
...@@ -36,7 +36,7 @@ from rich.console import Console ...@@ -36,7 +36,7 @@ from rich.console import Console
from rich.syntax import Syntax from rich.syntax import Syntax
from simple_di import Provide, inject from simple_di import Provide, inject
from dynamo.sdk.lib.bento import Bento from dynamo.sdk.cli.bento_util import Bento
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from bentoml._internal.bento import BentoStore from bentoml._internal.bento import BentoStore
......
...@@ -47,7 +47,7 @@ def serve( ...@@ -47,7 +47,7 @@ def serve(
service_name: str = typer.Option( service_name: str = typer.Option(
"", "",
help="Only serve the specified service. Don't serve any dependencies of this service.", help="Only serve the specified service. Don't serve any dependencies of this service.",
envvar="DYNAMO_SERVE_SERVICE_NAME", envvar="DYNAMO_SERVICE_NAME",
), ),
depends: List[str] = typer.Option( depends: List[str] = typer.Option(
[], [],
...@@ -92,8 +92,7 @@ def serve( ...@@ -92,8 +92,7 @@ def serve(
""" """
# Warning: internal # Warning: internal
from bentoml._internal.service.loader import load from dynamo.sdk.lib.loader import find_and_load_service
from dynamo.sdk.lib.logging import configure_server_logging from dynamo.sdk.lib.logging import configure_server_logging
from dynamo.sdk.lib.service import LinkedServices from dynamo.sdk.lib.service import LinkedServices
...@@ -138,11 +137,12 @@ def serve( ...@@ -138,11 +137,12 @@ def serve(
if sys.path[0] != working_dir_str: if sys.path[0] != working_dir_str:
sys.path.insert(0, working_dir_str) sys.path.insert(0, working_dir_str)
svc = load(bento_identifier=dynamo_pipeline, working_dir=working_dir_str) svc = find_and_load_service(dynamo_pipeline, working_dir=working_dir)
logger.info(f"Loaded service: {svc.name}")
logger.info("Dependencies: %s", [dep.on.name for dep in svc.dependencies.values()])
LinkedServices.remove_unused_edges() LinkedServices.remove_unused_edges()
from dynamo.sdk.cli.serving import serve_http # type: ignore from dynamo.sdk.cli.serving import serve_dynamo_graph # type: ignore
svc.inject_config() svc.inject_config()
...@@ -155,11 +155,11 @@ def serve( ...@@ -155,11 +155,11 @@ def serve(
) )
) )
serve_http( serve_dynamo_graph(
dynamo_pipeline, dynamo_pipeline,
working_dir=working_dir_str, working_dir=working_dir_str,
host=host, # host=host,
port=port, # port=port,
dependency_map=runner_map_dict, dependency_map=runner_map_dict,
service_name=service_name, service_name=service_name,
enable_planner=enable_planner, enable_planner=enable_planner,
......
...@@ -22,11 +22,16 @@ import inspect ...@@ -22,11 +22,16 @@ import inspect
import json import json
import logging import logging
import os import os
import signal
import sys
import time
import typing as t import typing as t
from typing import Any from typing import Any
import click import click
import uvicorn
import uvloop import uvloop
from fastapi.responses import StreamingResponse
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
from dynamo.sdk import dynamo_context from dynamo.sdk import dynamo_context
...@@ -35,6 +40,69 @@ from dynamo.sdk.lib.service import LinkedServices ...@@ -35,6 +40,69 @@ from dynamo.sdk.lib.service import LinkedServices
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def add_fastapi_routes(app, service, class_instance):
    """
    Add FastAPI routes for Dynamo endpoints marked with is_api=True.

    Args:
        app: FastAPI app instance
        service: Dynamo service instance
        class_instance: Instance of the service class

    Returns:
        List of route paths that were registered on the app.
    """
    added_routes = []
    for name, endpoint in service.get_dynamo_endpoints().items():
        if not endpoint.is_api:
            continue
        # Normalize the endpoint name into a leading-slash route path.
        path = name if name.startswith("/") else f"/{name}"
        # Bind the endpoint function to the service instance.
        bound_method = endpoint.func.__get__(class_instance)
        # (Async) generator endpoints are served as streaming responses.
        streaming = inspect.isasyncgenfunction(
            bound_method
        ) or inspect.isgeneratorfunction(bound_method)
        if streaming:
            logger.info(f"Registering streaming endpoint {path}")
            app.add_api_route(
                path,
                bound_method,
                methods=["POST"],
                response_class=StreamingResponse,
            )
        else:
            logger.info(f"Registering regular endpoint {path}")
            app.add_api_route(
                path,
                bound_method,
                methods=["POST"],
            )
        added_routes.append(path)
        logger.info(f"Added API route {path} to FastAPI app")
    return added_routes
class GracefulExit(SystemExit):
    """Raised to request a graceful shutdown of the process."""
def setup_signal_handlers():
    """Install SIGINT/SIGTERM/SIGQUIT handlers that raise GracefulExit."""

    def _handle(sig, frame):
        logger.info(f"Received signal {sig}, initiating graceful shutdown")
        raise GracefulExit(0)

    # One handler for every shutdown-related signal.
    for signum in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(signum, _handle)
@click.command() @click.command()
@click.argument("bento_identifier", type=click.STRING, required=False, default=".") @click.argument("bento_identifier", type=click.STRING, required=False, default=".")
@click.option("--service-name", type=click.STRING, required=False, default="") @click.option("--service-name", type=click.STRING, required=False, default="")
...@@ -68,6 +136,10 @@ def main( ...@@ -68,6 +136,10 @@ def main(
from dynamo.sdk.lib.logging import configure_server_logging from dynamo.sdk.lib.logging import configure_server_logging
# Setup signal handlers for graceful shutdown
setup_signal_handlers()
run_id = service_name
dynamo_context["service_name"] = service_name dynamo_context["service_name"] = service_name
dynamo_context["runner_map"] = runner_map dynamo_context["runner_map"] = runner_map
dynamo_context["worker_id"] = worker_id dynamo_context["worker_id"] = worker_id
...@@ -168,7 +240,7 @@ def main( ...@@ -168,7 +240,7 @@ def main(
# Run startup hooks before setting up endpoints # Run startup hooks before setting up endpoints
for name, member in vars(class_instance.__class__).items(): for name, member in vars(class_instance.__class__).items():
if callable(member) and getattr( if callable(member) and getattr(
member, "__bentoml_startup_hook__", False member, "__dynamo_startup_hook__", False
): ):
logger.debug(f"Running startup hook: {name}") logger.debug(f"Running startup hook: {name}")
result = getattr(class_instance, name)() result = getattr(class_instance, name)()
...@@ -188,13 +260,75 @@ def main( ...@@ -188,13 +260,75 @@ def main(
logger.info(f"Serving {service.name} with lease: {lease.id()}") logger.info(f"Serving {service.name} with lease: {lease.id()}")
result = await endpoints[0].serve_endpoint(twm[0], lease) result = await endpoints[0].serve_endpoint(twm[0], lease)
except GracefulExit:
logger.info(f"[{run_id}] Gracefully shutting down {service.name}")
# Add any specific cleanup needed
return None
except Exception as e: except Exception as e:
logger.error(f"Error in Dynamo component setup: {str(e)}") logger.error(f"Error in Dynamo component setup: {str(e)}")
raise raise
# if the service has a FastAPI app, add the worker as an event handler
def web_worker():
try:
if not service.app:
return
# Create the class instance
class_instance = service.inner()
# TODO: init hooks
# Add API routes to the FastAPI app
added_routes = add_fastapi_routes(service.app, service, class_instance)
if added_routes:
# Configure uvicorn with graceful shutdown
config = uvicorn.Config(
service.app, host="0.0.0.0", port=8000, log_level="info"
)
server = uvicorn.Server(config)
# Start the server with graceful shutdown handling
logger.info(
f"Starting FastAPI server on 0.0.0.0:8000 with routes: {added_routes}"
)
server.run()
else:
logger.warning("No API routes found, not starting FastAPI server")
# Keep the process running until interrupted
logger.info("Service is running, press Ctrl+C to stop")
while True:
try:
# Sleep in small increments to respond to signals quickly
time.sleep(0.1)
except (KeyboardInterrupt, GracefulExit):
logger.info("Gracefully shutting down FastAPI process")
break
except GracefulExit:
logger.info("Gracefully shutting down FastAPI service")
except Exception as e:
logger.error(f"Error in web worker: {str(e)}")
raise
try:
uvloop.install() uvloop.install()
if service.app:
web_worker()
else:
asyncio.run(worker()) asyncio.run(worker())
except GracefulExit:
logger.info("Exiting gracefully")
sys.exit(0)
except KeyboardInterrupt:
logger.info("Interrupted, shutting down gracefully")
sys.exit(0)
if __name__ == "__main__": if __name__ == "__main__":
try:
main() main()
except (GracefulExit, KeyboardInterrupt):
logger.info("Exiting gracefully")
sys.exit(0)
except Exception as e:
logger.error(f"Error in main: {str(e)}")
sys.exit(1)
...@@ -18,112 +18,40 @@ ...@@ -18,112 +18,40 @@
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
import ipaddress
import json import json
import logging import logging
import os import os
import pathlib import pathlib
import platform
import shutil import shutil
import socket
import tempfile import tempfile
import typing as t from typing import Any, Dict, Optional, TypeVar
from typing import Any, Dict, Optional, Protocol, TypeVar
# WARNING: internal # TODO: WARNING: internal but only for type checking in the deploy path i believe
from _bentoml_sdk import Service from _bentoml_sdk import Service
# WARNING: internal
from bentoml._internal.container import BentoMLContainer
# WARNING: internal
from bentoml._internal.utils.circus import Server
from bentoml.exceptions import BentoMLConfigException
from circus.sockets import CircusSocket from circus.sockets import CircusSocket
from circus.watcher import Watcher from circus.watcher import Watcher
from simple_di import Provide, inject from simple_di import inject
from dynamo.sdk.cli.circus import CircusRunner
from .allocator import ResourceAllocator from .allocator import ResourceAllocator
from .circus import _get_server_socket
from .utils import ( from .utils import (
DYN_LOCAL_STATE_DIR, DYN_LOCAL_STATE_DIR,
path_to_uri, ServiceProtocol,
reserve_free_port, reserve_free_port,
save_dynamo_state, save_dynamo_state,
) )
# WARNING: internal
# Define a Protocol for services to ensure type safety
class ServiceProtocol(Protocol):
name: str
inner: Any
models: list[Any]
bento: Any
def is_dynamo_component(self) -> bool:
...
def dynamo_address(self) -> tuple[str, str]:
...
# Use Protocol as the base for type alias # Use Protocol as the base for type alias
AnyService = TypeVar("AnyService", bound=ServiceProtocol) AnyService = TypeVar("AnyService", bound=ServiceProtocol)
POSIX = os.name == "posix"
WINDOWS = os.name == "nt"
IS_WSL = "microsoft-standard" in platform.release()
API_SERVER_NAME = "_bento_api_server"
MAX_AF_UNIX_PATH_LENGTH = 103
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if POSIX and not IS_WSL:
def _get_server_socket(
service: ServiceProtocol,
uds_path: str,
port_stack: contextlib.ExitStack,
) -> tuple[str, CircusSocket]:
from circus.sockets import CircusSocket
socket_path = os.path.join(uds_path, f"{id(service)}.sock")
assert len(socket_path) < MAX_AF_UNIX_PATH_LENGTH
return path_to_uri(socket_path), CircusSocket(
name=service.name, path=socket_path
)
elif WINDOWS or IS_WSL:
def _get_server_socket(
service: ServiceProtocol,
uds_path: str,
port_stack: contextlib.ExitStack,
) -> tuple[str, CircusSocket]:
from circus.sockets import CircusSocket
runner_port = port_stack.enter_context(reserve_free_port())
runner_host = "127.0.0.1"
return f"tcp://{runner_host}:{runner_port}", CircusSocket(
name=service.name,
host=runner_host,
port=runner_port,
)
else:
def _get_server_socket(
service: ServiceProtocol,
uds_path: str,
port_stack: contextlib.ExitStack,
) -> tuple[str, CircusSocket]:
from bentoml.exceptions import BentoMLException
raise BentoMLException("Unsupported platform")
# WARNING: internal
_BENTO_WORKER_SCRIPT = "_bentoml_impl.worker.service"
_DYNAMO_WORKER_SCRIPT = "dynamo.sdk.cli.serve_dynamo" _DYNAMO_WORKER_SCRIPT = "dynamo.sdk.cli.serve_dynamo"
...@@ -140,63 +68,19 @@ def _get_dynamo_worker_script(bento_identifier: str, svc_name: str) -> list[str] ...@@ -140,63 +68,19 @@ def _get_dynamo_worker_script(bento_identifier: str, svc_name: str) -> list[str]
return args return args
def _get_bento_worker_script(bento_identifier: str, svc_name: str) -> list[str]:
args = [
"-m",
_BENTO_WORKER_SCRIPT,
bento_identifier,
"--service-name",
svc_name,
"--fd",
f"$(circus.sockets.{svc_name})",
"--worker-id",
"$(CIRCUS.WID)",
]
return args
def create_dependency_watcher(
bento_identifier: str,
svc: ServiceProtocol,
uds_path: str,
port_stack: contextlib.ExitStack,
scheduler: ResourceAllocator,
working_dir: Optional[str] = None,
env: Optional[Dict[str, str]] = None,
) -> tuple[Watcher, CircusSocket, str]:
from bentoml.serving import create_watcher
num_workers, resource_envs = scheduler.get_resource_envs(svc)
uri, socket = _get_server_socket(svc, uds_path, port_stack)
args = _get_bento_worker_script(bento_identifier, svc.name)
if resource_envs:
args.extend(["--worker-env", json.dumps(resource_envs)])
watcher = create_watcher(
name=f"service_{svc.name}",
args=args,
numprocesses=num_workers,
working_dir=working_dir,
env=env,
)
return watcher, socket, uri
def create_dynamo_watcher( def create_dynamo_watcher(
bento_identifier: str, bento_identifier: str,
svc: ServiceProtocol, svc: ServiceProtocol,
uds_path: str, uds_path: str,
port_stack: contextlib.ExitStack,
scheduler: ResourceAllocator, scheduler: ResourceAllocator,
working_dir: Optional[str] = None, working_dir: Optional[str] = None,
env: Optional[Dict[str, str]] = None, env: Optional[Dict[str, str]] = None,
) -> tuple[Watcher, CircusSocket, str]: ) -> tuple[Watcher, CircusSocket, str]:
"""Create a watcher for a Dynamo service in the dependency graph""" """Create a watcher for a Dynamo service in the dependency graph"""
from bentoml.serving import create_watcher from dynamo.sdk.cli.circus import create_circus_watcher
num_workers, resource_envs = scheduler.get_resource_envs(svc) num_workers, resource_envs = scheduler.get_resource_envs(svc)
uri, socket = _get_server_socket(svc, uds_path, port_stack) uri, socket = _get_server_socket(svc, uds_path)
args = _get_dynamo_worker_script(bento_identifier, svc.name) args = _get_dynamo_worker_script(bento_identifier, svc.name)
if resource_envs: if resource_envs:
args.extend(["--worker-env", json.dumps(resource_envs)]) args.extend(["--worker-env", json.dumps(resource_envs)])
...@@ -226,7 +110,7 @@ def create_dynamo_watcher( ...@@ -226,7 +110,7 @@ def create_dynamo_watcher(
namespace, _ = svc.dynamo_address() namespace, _ = svc.dynamo_address()
# Create the watcher with updated environment # Create the watcher with updated environment
watcher = create_watcher( watcher = create_circus_watcher(
name=f"{namespace}_{svc.name}", name=f"{namespace}_{svc.name}",
args=args, args=args,
numprocesses=num_workers, numprocesses=num_workers,
...@@ -240,23 +124,15 @@ def create_dynamo_watcher( ...@@ -240,23 +124,15 @@ def create_dynamo_watcher(
@inject(squeeze_none=True) @inject(squeeze_none=True)
def serve_http( def serve_dynamo_graph(
bento_identifier: str | AnyService, bento_identifier: str | AnyService,
working_dir: str | None = None, working_dir: str | None = None,
host: str = Provide[BentoMLContainer.http.host],
port: int = Provide[BentoMLContainer.http.port],
dependency_map: dict[str, str] | None = None, dependency_map: dict[str, str] | None = None,
service_name: str = "", service_name: str = "",
enable_planner: bool = False, enable_planner: bool = False,
) -> Server: ) -> CircusRunner:
# WARNING: internal from dynamo.sdk.cli.circus import create_arbiter, create_circus_watcher
from _bentoml_impl.loader import load from dynamo.sdk.lib.loader import find_and_load_service
# WARNING: internal
from bentoml._internal.utils.circus import create_standalone_arbiter
from bentoml.serving import create_watcher
from circus.sockets import CircusSocket
from dynamo.sdk.lib.logging import configure_server_logging from dynamo.sdk.lib.logging import configure_server_logging
from .allocator import ResourceAllocator from .allocator import ResourceAllocator
...@@ -275,7 +151,7 @@ def serve_http( ...@@ -275,7 +151,7 @@ def serve_http(
# use cwd # use cwd
bento_path = pathlib.Path(".") bento_path = pathlib.Path(".")
else: else:
svc = load(bento_identifier, working_dir) svc = find_and_load_service(bento_identifier, working_dir)
bento_id = str(bento_identifier) bento_id = str(bento_identifier)
bento_path = pathlib.Path(working_dir or ".") bento_path = pathlib.Path(working_dir or ".")
...@@ -294,7 +170,7 @@ def serve_http( ...@@ -294,7 +170,7 @@ def serve_http(
if service_name and service_name != svc.name: if service_name and service_name != svc.name:
svc = svc.find_dependent_by_name(service_name) svc = svc.find_dependent_by_name(service_name)
num_workers, resource_envs = allocator.get_resource_envs(svc) num_workers, resource_envs = allocator.get_resource_envs(svc)
uds_path = tempfile.mkdtemp(prefix="bentoml-uds-") uds_path = tempfile.mkdtemp(prefix="dynamo-uds-")
try: try:
if not service_name and not standalone: if not service_name and not standalone:
with contextlib.ExitStack() as port_stack: with contextlib.ExitStack() as port_stack:
...@@ -303,74 +179,28 @@ def serve_http( ...@@ -303,74 +179,28 @@ def serve_http(
continue continue
if name in dependency_map: if name in dependency_map:
continue continue
if not (
# Check if this is a Dynamo service
if (
hasattr(dep_svc, "is_dynamo_component") hasattr(dep_svc, "is_dynamo_component")
and dep_svc.is_dynamo_component() and dep_svc.is_dynamo_component()
): ):
raise RuntimeError(
f"Service {dep_svc.name} is not a Dynamo component"
)
new_watcher, new_socket, uri = create_dynamo_watcher( new_watcher, new_socket, uri = create_dynamo_watcher(
bento_id, bento_id,
dep_svc, dep_svc,
uds_path, uds_path,
port_stack,
allocator, allocator,
str(bento_path.absolute()), str(bento_path.absolute()),
env=env, env=env,
) )
namespace, _ = dep_svc.dynamo_address() namespace, _ = dep_svc.dynamo_address()
else:
# Regular BentoML service
new_watcher, new_socket, uri = create_dependency_watcher(
bento_id,
dep_svc,
uds_path,
port_stack,
allocator,
str(bento_path.absolute()),
env=env,
)
watchers.append(new_watcher) watchers.append(new_watcher)
sockets.append(new_socket) sockets.append(new_socket)
dependency_map[name] = uri dependency_map[name] = uri
# reserve one more to avoid conflicts # reserve one more to avoid conflicts
port_stack.enter_context(reserve_free_port()) port_stack.enter_context(reserve_free_port())
try:
ipaddr = ipaddress.ip_address(host)
if ipaddr.version == 4:
family = socket.AF_INET
elif ipaddr.version == 6:
family = socket.AF_INET6
else:
raise BentoMLConfigException(
f"Unsupported host IP address version: {ipaddr.version}"
)
except ValueError as e:
raise BentoMLConfigException(f"Invalid host IP address: {host}") from e
if not svc.is_dynamo_component():
sockets.append(
CircusSocket(
name=API_SERVER_NAME,
host=host,
port=port,
family=family,
)
)
server_args = [
"-m",
_BENTO_WORKER_SCRIPT,
bento_identifier,
"--fd",
f"$(circus.sockets.{API_SERVER_NAME})",
"--service-name",
svc.name,
"--worker-id",
"$(CIRCUS.WID)",
]
dynamo_args = [ dynamo_args = [
"-m", "-m",
_DYNAMO_WORKER_SCRIPT, _DYNAMO_WORKER_SCRIPT,
...@@ -380,12 +210,7 @@ def serve_http( ...@@ -380,12 +210,7 @@ def serve_http(
"--worker-id", "--worker-id",
"$(CIRCUS.WID)", "$(CIRCUS.WID)",
] ]
if resource_envs:
server_args.extend(["--worker-env", json.dumps(resource_envs)])
scheme = "http"
# Check if this is a Dynamo service
if hasattr(svc, "is_dynamo_component") and svc.is_dynamo_component(): if hasattr(svc, "is_dynamo_component") and svc.is_dynamo_component():
# resource_envs is the resource allocation (ie CUDA_VISIBLE_DEVICES) for each worker created by the allocator # resource_envs is the resource allocation (ie CUDA_VISIBLE_DEVICES) for each worker created by the allocator
# these resource_envs are passed to each individual worker's environment which is set in serve_dynamo # these resource_envs are passed to each individual worker's environment which is set in serve_dynamo
...@@ -411,7 +236,7 @@ def serve_http( ...@@ -411,7 +236,7 @@ def serve_http(
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning(f"Failed to parse DYNAMO_SERVICE_ENVS: {e}") logger.warning(f"Failed to parse DYNAMO_SERVICE_ENVS: {e}")
watcher = create_watcher( watcher = create_circus_watcher(
name=f"{namespace}_{svc.name}", name=f"{namespace}_{svc.name}",
args=dynamo_args, args=dynamo_args,
numprocesses=num_workers, numprocesses=num_workers,
...@@ -422,20 +247,6 @@ def serve_http( ...@@ -422,20 +247,6 @@ def serve_http(
logger.info( logger.info(
f"Created watcher for {svc.name} with {num_workers} workers in the {namespace} namespace" f"Created watcher for {svc.name} with {num_workers} workers in the {namespace} namespace"
) )
else:
watchers.append(
create_watcher(
name="service",
args=server_args,
working_dir=str(bento_path.absolute()),
numprocesses=num_workers,
env=env,
)
)
logger.info(f"Created watcher for service with {num_workers} workers")
log_host = "localhost" if host in ["0.0.0.0", "::"] else host
dependency_map[svc.name] = f"{scheme}://{log_host}:{port}"
# inject runner map now # inject runner map now
inject_env = {"BENTOML_RUNNER_MAP": json.dumps(dependency_map)} inject_env = {"BENTOML_RUNNER_MAP": json.dumps(dependency_map)}
...@@ -446,12 +257,12 @@ def serve_http( ...@@ -446,12 +257,12 @@ def serve_http(
else: else:
watcher.env.update(inject_env) watcher.env.update(inject_env)
arbiter_kwargs: dict[str, t.Any] = { arbiter_kwargs: dict[str, Any] = {
"watchers": watchers, "watchers": watchers,
"sockets": sockets, "sockets": sockets,
} }
arbiter = create_standalone_arbiter(**arbiter_kwargs) arbiter = create_arbiter(**arbiter_kwargs)
arbiter.exit_stack.callback(shutil.rmtree, uds_path, ignore_errors=True) arbiter.exit_stack.callback(shutil.rmtree, uds_path, ignore_errors=True)
if enable_planner: if enable_planner:
arbiter.exit_stack.callback( arbiter.exit_stack.callback(
...@@ -499,7 +310,7 @@ def serve_http( ...@@ -499,7 +310,7 @@ def serve_http(
), ),
), ),
) )
return Server(url=f"{scheme}://{log_host}:{port}", arbiter=arbiter) return CircusRunner(arbiter=arbiter)
except Exception: except Exception:
shutil.rmtree(uds_path, ignore_errors=True) shutil.rmtree(uds_path, ignore_errors=True)
raise raise
...@@ -25,10 +25,9 @@ import os ...@@ -25,10 +25,9 @@ import os
import pathlib import pathlib
import random import random
import socket import socket
import typing as t from typing import Any, DefaultDict, Dict, Iterator, Optional, Protocol, TextIO, Union
import click import click
import psutil
import yaml import yaml
from click import Command, Context from click import Command, Context
...@@ -41,10 +40,24 @@ logger = logging.getLogger(__name__) ...@@ -41,10 +40,24 @@ logger = logging.getLogger(__name__)
DYN_LOCAL_STATE_DIR = "DYN_LOCAL_STATE_DIR" DYN_LOCAL_STATE_DIR = "DYN_LOCAL_STATE_DIR"
# Define a Protocol for services to ensure type safety
class ServiceProtocol(Protocol):
name: str
inner: Any
models: list[Any]
bento: Any
def is_dynamo_component(self) -> bool:
...
def dynamo_address(self) -> tuple[str, str]:
...
class DynamoCommandGroup(click.Group): class DynamoCommandGroup(click.Group):
"""Simplified version of BentoMLCommandGroup for Dynamo CLI""" """Simplified version of BentoMLCommandGroup for Dynamo CLI"""
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
self.aliases = kwargs.pop("aliases", []) self.aliases = kwargs.pop("aliases", [])
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._commands: dict[str, list[str]] = {} self._commands: dict[str, list[str]] = {}
...@@ -101,26 +114,19 @@ class DynamoCommandGroup(click.Group): ...@@ -101,26 +114,19 @@ class DynamoCommandGroup(click.Group):
def reserve_free_port( def reserve_free_port(
host: str = "localhost", host: str = "localhost",
port: int | None = None, port: int | None = None,
prefix: t.Optional[str] = None, prefix: Optional[str] = None,
max_retry: int = 50, max_retry: int = 50,
enable_so_reuseport: bool = False, enable_so_reuseport: bool = False,
) -> t.Iterator[int]: ) -> Iterator[int]:
""" """
detect free port and reserve until exit the context detect free port and reserve until exit the context
""" """
import psutil
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if enable_so_reuseport: if enable_so_reuseport:
if psutil.WINDOWS:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
elif psutil.MACOS or psutil.FREEBSD:
sock.setsockopt(socket.SOL_SOCKET, 0x10000, 1) # SO_REUSEPORT_LB
else:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0: if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
raise RuntimeError("Failed to set SO_REUSEPORT.") from None raise RuntimeError("Failed to set SO_REUSEPORT.") from None
if prefix is not None: if prefix is not None:
prefix_num = int(prefix) * 10 ** (5 - len(prefix)) prefix_num = int(prefix) * 10 ** (5 - len(prefix))
suffix_range = min(65535 - prefix_num, 10 ** (5 - len(prefix))) suffix_range = min(65535 - prefix_num, 10 ** (5 - len(prefix)))
...@@ -147,29 +153,11 @@ def reserve_free_port( ...@@ -147,29 +153,11 @@ def reserve_free_port(
sock.close() sock.close()
def path_to_uri(path: str) -> str:
"""
Convert a path to a URI.
Args:
path: Path to convert to URI.
Returns:
URI string. (quoted, absolute)
"""
path = os.path.abspath(path)
if psutil.WINDOWS:
return pathlib.PureWindowsPath(path).as_uri()
if psutil.POSIX:
return pathlib.PurePosixPath(path).as_uri()
raise ValueError("Unsupported OS")
def save_dynamo_state( def save_dynamo_state(
namespace: str, namespace: str,
circus_endpoint: str, circus_endpoint: str,
components: dict[str, t.Any], components: dict[str, Any],
environment: dict[str, t.Any], environment: dict[str, Any],
): ):
state_dir = os.environ.get( state_dir = os.environ.get(
DYN_LOCAL_STATE_DIR, os.path.expanduser("~/.dynamo/state") DYN_LOCAL_STATE_DIR, os.path.expanduser("~/.dynamo/state")
...@@ -192,7 +180,7 @@ def save_dynamo_state( ...@@ -192,7 +180,7 @@ def save_dynamo_state(
logger.warning(f"Saved state to {state_file}") logger.warning(f"Saved state to {state_file}")
def _parse_service_arg(arg_name: str, arg_value: str) -> tuple[str, str, t.Any]: def _parse_service_arg(arg_name: str, arg_value: str) -> tuple[str, str, Any]:
"""Parse a single CLI argument into service name, key, and value.""" """Parse a single CLI argument into service name, key, and value."""
parts = arg_name.split(".") parts = arg_name.split(".")
...@@ -205,7 +193,7 @@ def _parse_service_arg(arg_name: str, arg_value: str) -> tuple[str, str, t.Any]: ...@@ -205,7 +193,7 @@ def _parse_service_arg(arg_name: str, arg_value: str) -> tuple[str, str, t.Any]:
and nested_keys[0] == "ServiceArgs" and nested_keys[0] == "ServiceArgs"
and nested_keys[1] == "envs" and nested_keys[1] == "envs"
): ):
value: t.Union[str, int, float, bool, dict, list] = arg_value value: Union[str, int, float, bool, dict, list] = arg_value
else: else:
# Parse value based on type for non-env vars # Parse value based on type for non-env vars
try: try:
...@@ -228,12 +216,10 @@ def _parse_service_arg(arg_name: str, arg_value: str) -> tuple[str, str, t.Any]: ...@@ -228,12 +216,10 @@ def _parse_service_arg(arg_name: str, arg_value: str) -> tuple[str, str, t.Any]:
return service, nested_keys[0], result return service, nested_keys[0], result
def _parse_service_args(args: list[str]) -> t.Dict[str, t.Any]: def _parse_service_args(args: list[str]) -> Dict[str, Any]:
service_configs: t.DefaultDict[str, t.Dict[str, t.Any]] = collections.defaultdict( service_configs: DefaultDict[str, Dict[str, Any]] = collections.defaultdict(dict)
dict
)
def deep_update(d: dict, key: str, value: t.Any): def deep_update(d: dict, key: str, value: Any):
""" """
Recursively updates nested dictionaries. We use this to process arguments like Recursively updates nested dictionaries. We use this to process arguments like
...@@ -283,9 +269,9 @@ def _parse_service_args(args: list[str]) -> t.Dict[str, t.Any]: ...@@ -283,9 +269,9 @@ def _parse_service_args(args: list[str]) -> t.Dict[str, t.Any]:
def resolve_service_config( def resolve_service_config(
config_file: pathlib.Path | t.TextIO | None = None, config_file: pathlib.Path | TextIO | None = None,
args: list[str] | None = None, args: list[str] | None = None,
) -> dict[str, dict[str, t.Any]]: ) -> dict[str, dict[str, Any]]:
"""Resolve service configuration from file and command line arguments. """Resolve service configuration from file and command line arguments.
Args: Args:
...@@ -295,7 +281,7 @@ def resolve_service_config( ...@@ -295,7 +281,7 @@ def resolve_service_config(
Returns: Returns:
Dictionary mapping service names to their configurations Dictionary mapping service names to their configurations
""" """
service_configs: dict[str, dict[str, t.Any]] = {} service_configs: dict[str, dict[str, Any]] = {}
# Check for deployment config first # Check for deployment config first
if "DYN_DEPLOYMENT_CONFIG" in os.environ: if "DYN_DEPLOYMENT_CONFIG" in os.environ:
......
...@@ -19,18 +19,17 @@ import typing as t ...@@ -19,18 +19,17 @@ import typing as t
from functools import wraps from functools import wraps
from typing import Any, get_type_hints from typing import Any, get_type_hints
import bentoml
from pydantic import BaseModel from pydantic import BaseModel
class DynamoEndpoint: class DynamoEndpoint:
"""Decorator class for Dynamo endpoints""" """Decorator class for Dynamo endpoints"""
def __init__(self, func: t.Callable, name: str | None = None): def __init__(self, func: t.Callable, name: str | None = None, is_api: bool = False):
self.func = func self.func = func
self.name = name or func.__name__ self.name = name or func.__name__
self.is_dynamo_endpoint = True self.is_dynamo_endpoint = True
self.is_api = is_api
# Extract request type from hints # Extract request type from hints
hints = get_type_hints(func) hints = get_type_hints(func)
args = list(hints.items()) args = list(hints.items())
...@@ -60,11 +59,13 @@ class DynamoEndpoint: ...@@ -60,11 +59,13 @@ class DynamoEndpoint:
def dynamo_endpoint( def dynamo_endpoint(
name: str | None = None, name: str | None = None,
is_api: bool = False,
) -> t.Callable[[t.Callable], DynamoEndpoint]: ) -> t.Callable[[t.Callable], DynamoEndpoint]:
"""Decorator for Dynamo endpoints. """Decorator for Dynamo endpoints.
Args: Args:
name: Optional name for the endpoint. Defaults to function name. name: Optional name for the endpoint. Defaults to function name.
is_api: Whether to expose the endpoint as an API. Defaults to False.
Example: Example:
@dynamo_endpoint() @dynamo_endpoint()
...@@ -77,25 +78,13 @@ def dynamo_endpoint( ...@@ -77,25 +78,13 @@ def dynamo_endpoint(
""" """
def decorator(func: t.Callable) -> DynamoEndpoint: def decorator(func: t.Callable) -> DynamoEndpoint:
return DynamoEndpoint(func, name) return DynamoEndpoint(func, name, is_api)
return decorator return decorator
def dynamo_api(func: t.Callable) -> t.Callable:
"""Decorator for BentoML API endpoints.
Args:
func: The function to be decorated.
Returns:
The decorated function.
"""
return bentoml.api(func)
def async_on_start(func: t.Callable) -> t.Callable: def async_on_start(func: t.Callable) -> t.Callable:
"""Decorator for async onstart functions.""" """Decorator for async onstart functions."""
# Mark the function as a startup hook # Mark the function as a startup hook
setattr(func, "__bentoml_startup_hook__", True) setattr(func, "__dynamo_startup_hook__", True)
return bentoml.on_startup(func) return func
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from http import HTTPStatus
class DynamoException(Exception):
    """Root of the Dynamo SDK exception hierarchy.

    Any subclass that declares its own ``error_code`` is automatically
    registered in ``error_mapping``, so an HTTP status can later be mapped
    back to the exception type that represents it.
    """

    # Default status used when a subclass does not override it.
    error_code = HTTPStatus.INTERNAL_SERVER_ERROR
    # Registry of status -> exception class, filled in by __init_subclass__.
    error_mapping: dict[HTTPStatus, type[DynamoException]] = {}

    def __init_subclass__(cls) -> None:
        # Only register subclasses that define error_code in their OWN
        # namespace (vars(cls)), not ones merely inheriting the default.
        if "error_code" in vars(cls):
            cls.error_mapping[cls.error_code] = cls

    def __init__(self, message: str, error_code: HTTPStatus | None = None):
        """Store *message* and the effective status code on the instance."""
        super().__init__(message)
        self.message = message
        # Fall back to the class-level code when none is supplied.
        self.error_code = error_code or self.error_code
...@@ -13,13 +13,6 @@ ...@@ -13,13 +13,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# wrapper over bento images to handle Dynamo base image
import os import os
import bentoml DYNAMO_IMAGE = os.getenv("DYNAMO_IMAGE", "dynamo:latest-vllm")
# TODO: "dynamo:latest-vllm-dev" image will not be available to image builder in k8s
# so We'd consider publishing the base image for releases to public nvcr.io registry.
image_name = os.getenv("DYNAMO_IMAGE", "dynamo:latest-vllm-dev")
DYNAMO_IMAGE = bentoml.images.Image(base_image=image_name)
# SPDX-FileCopyrightText: Copyright (c) 2020 Atalaya Tech. Inc
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
from __future__ import annotations
import importlib
import logging
import os
import sys
from typing import Optional, TypeVar
from dynamo.sdk.lib.service import DynamoService
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=object)
def find_and_load_service(
    import_str: str,
    working_dir: Optional[str] = None,
) -> DynamoService:
    """Locate and import a DynamoService instance from an import string.

    The string has the form ``"module[:attribute]"`` or
    ``"path/to/file.py[:attribute]"``; when no attribute is given, the
    module's single root service (one no other service depends on) is
    selected. For the duration of the import, *working_dir* (default: cwd)
    is made the current directory and prepended to ``sys.path``; both are
    restored afterwards.

    Args:
        import_str: Import string, e.g. "graphs:disagg:Frontend" or
            "./path/to/service.py:MyService".
        working_dir: Optional base directory for the import.

    Returns:
        The loaded DynamoService instance.

    Raises:
        ImportError: If the module cannot be imported.
        ValueError: If no service (or more than one root service) is found.
    """
    logger.info(f"Loading service from import string: {import_str}")
    logger.info(f"Working directory: {working_dir or os.getcwd()}")

    added_to_sys_path = False
    original_cwd = None

    if working_dir is None:
        working_dir = os.getcwd()
    else:
        # Remember where we were so the finally-block can restore it.
        original_cwd = os.getcwd()
        working_dir = os.path.realpath(os.path.expanduser(working_dir))
        logger.info(f"Changing working directory to: {working_dir}")
        os.chdir(working_dir)

    if working_dir not in sys.path:
        logger.info(f"Adding {working_dir} to sys.path")
        sys.path.insert(0, working_dir)
        added_to_sys_path = True

    try:
        return _do_import(import_str, working_dir)
    finally:
        # Undo the sys.path / cwd mutations even when the import fails.
        if added_to_sys_path and working_dir:
            logger.info(f"Removing {working_dir} from sys.path")
            sys.path.remove(working_dir)
        if original_cwd is not None:
            logger.info(f"Restoring working directory to: {original_cwd}")
            os.chdir(original_cwd)
def _do_import(import_str: str, working_dir: str) -> DynamoService:
    """Internal function to handle the actual import logic.

    Resolves ``import_str`` ("module[:attrs]" or "file.py[:attrs]") to a
    DynamoService. With no attribute part, scans the module for the single
    root service (one that is not a dependency of any other service).

    Raises:
        ValueError: empty module part, missing attribute, no/ambiguous root
            service, or resolved object is not a DynamoService.
        ImportError: file outside working_dir, non-.py extension, or the
            underlying module import fails.
    """
    import_path, _, attrs_str = import_str.partition(":")
    logger.info(f"Parsed import string - path: {import_path}, attributes: {attrs_str}")

    if not import_path:
        raise ValueError(
            f'Invalid import string "{import_str}", must be in format '
            '"<module>:<attribute>" or "<module>"'
        )

    # Handle file path vs module name imports
    if os.path.isfile(import_path):
        logger.info(f"Importing from file path: {import_path}")
        import_path = os.path.realpath(import_path)
        # Only allow files under working_dir so the sys.path entry added by
        # the caller is guaranteed to cover the constructed module name.
        if not import_path.startswith(working_dir):
            raise ImportError(
                f'Module "{import_path}" not found in working directory "{working_dir}"'
            )
        file_name, ext = os.path.splitext(import_path)
        if ext != ".py":
            raise ImportError(
                f'Invalid module extension "{ext}", only ".py" files are supported'
            )
        # Build module name from path components, walking up while the parent
        # directory is a package (has __init__.py) and we have not reached
        # working_dir.
        module_parts = []
        path = file_name
        while True:
            path, name = os.path.split(path)
            module_parts.append(name)
            if (
                not os.path.exists(os.path.join(path, "__init__.py"))
                or path == working_dir
            ):
                break
        # Parts were collected leaf-first; reverse into dotted order.
        module_name = ".".join(module_parts[::-1])
        logger.info(f"Constructed module name from path: {module_name}")
    else:
        logger.info(f"Importing from module name: {import_path}")
        module_name = import_path

    try:
        logger.info(f"Attempting to import module: {module_name}")
        module = importlib.import_module(module_name)
    except ImportError as e:
        # NOTE(review): re-raising without "from e" drops the original
        # traceback chain; consider chaining for easier debugging.
        raise ImportError(f'Failed to import module "{module_name}": {e}')

    # If no specific attribute given, find the root service
    if not attrs_str:
        logger.info("No attributes specified, searching for root service")
        services = [
            (name, obj)
            for name, obj in module.__dict__.items()
            if isinstance(obj, DynamoService)
        ]
        logger.info(f"Found {len(services)} DynamoService instances")

        if not services:
            raise ValueError(
                f"No DynamoService instances found in module '{module_name}'"
            )

        # Find root services (those that aren't dependencies of other services)
        dependents = set()
        for _, svc in services:
            for dep in svc.dependencies.values():
                if dep.on is not None:
                    dependents.add(dep.on)

        root_services = [(n, s) for n, s in services if s not in dependents]
        logger.info(f"Found {len(root_services)} root services")

        if not root_services:
            raise ValueError(
                f"No root DynamoService found in module '{module_name}'. "
                "All services are dependencies of other services."
            )

        if len(root_services) > 1:
            names = [n for n, _ in root_services]
            raise ValueError(
                f"Multiple root services found in module '{module_name}': {names}. "
                "Please specify which service to use with '<module>:<service_name>'"
            )

        _, instance = root_services[0]
        logger.info(f"Selected root service: {instance}")
    else:
        # Navigate through dot-separated attributes; when the current object
        # is itself a DynamoService, each attr is looked up as one of its
        # declared dependencies rather than a plain attribute.
        logger.info(f"Navigating attributes: {attrs_str}")
        instance = module
        for attr in attrs_str.split("."):
            try:
                if isinstance(instance, DynamoService):
                    logger.info(f"Following dependency link: {attr}")
                    instance = instance.dependencies[attr].on
                else:
                    logger.info(f"Getting attribute: {attr}")
                    instance = getattr(instance, attr)
            except (AttributeError, KeyError):
                raise ValueError(f'Attribute "{attr}" not found in "{module_name}"')

    if not isinstance(instance, DynamoService):
        raise ValueError(
            f'Object "{attrs_str}" in module "{module_name}" is not a DynamoService'
        )

    # Set import string for debugging/logging
    if not hasattr(instance, "_import_str"):
        import_str_val = f"{module_name}:{attrs_str}" if attrs_str else module_name
        logger.info(f"Setting _import_str to: {import_str_val}")
        # object.__setattr__ bypasses any custom __setattr__ on the service.
        object.__setattr__(instance, "_import_str", import_str_val)

    return instance
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: this should be used for planner as well and should leverage proper nvml bindings
from __future__ import annotations
import logging
import typing as t
from dataclasses import dataclass
import psutil
# pynvml is an optional dependency: when it is missing, GPU discovery is
# disabled and GPUManager degrades gracefully (see _init_nvml below).
try:
    import pynvml

    PYNVML_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    # ModuleNotFoundError is a subclass of ImportError; listed for clarity.
    PYNVML_AVAILABLE = False

logger = logging.getLogger(__name__)

# Constants
# Kubernetes-style resource key for NVIDIA GPUs.
NVIDIA_GPU = "nvidia.com/gpu"
class ResourceError(Exception):
    """Raised for failures in GPU/compute resource discovery or allocation."""
@dataclass
class GPUProcess:
    """Information about a process running on a GPU."""

    pid: int
    used_memory: int  # in bytes
    name: str = ""  # resolved in __post_init__; stays "" when lookup fails

    def __post_init__(self):
        """Get process name if available."""
        # Best-effort lookup via psutil: the process may have exited or be
        # owned by another user, in which case the default name is kept.
        try:
            self.name = psutil.Process(self.pid).name()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass
class GPUInfo:
    """Snapshot of a single GPU device's identity and runtime state."""

    def __init__(self, index: int, total_memory: int, name: str, uuid: str):
        """Record device identity; dynamic fields start at their defaults."""
        # Static identity of the device.
        self.index = index
        self.name = name
        self.uuid = uuid
        self.total_memory = total_memory  # in bytes
        # Dynamic state, refreshed by GPUManager.
        self.available = True  # cleared when the GPU is reserved/in use
        self.temperature = 0  # in Celsius
        self.utilization = 0  # in percent (0-100)
        self.processes: list[GPUProcess] = []

    def __repr__(self) -> str:
        mem_mb = self.total_memory / 1024 / 1024
        return (
            f"GPUInfo(index={self.index}, name='{self.name}', "
            f"total_memory={mem_mb:.0f}MB, available={self.available})"
        )
class GPUManager:
"""
Manages GPU resources using NVML.
This class provides methods to:
- Discover available GPUs
- Query GPU properties and status
- Track GPU processes
- Allocate and release GPUs
- Generate CUDA_VISIBLE_DEVICES environment variables
"""
    def __init__(self):
        """Initialize the GPU manager."""
        # Discovered devices; populated by _discover_gpus() via _init_nvml().
        self.gpus: list[GPUInfo] = []
        # True only after pynvml.nvmlInit() succeeds (see _init_nvml).
        self._initialized = False
        # List to track fractional GPU allocations
        # Each item is (gpu_index, fraction_used, fraction_size)
        # E.g. (0, 0.5, 0.5) means GPU 0 has 0.5 used with fraction size of 0.5
        self._gpu_fractions: list[tuple[int, float, float]] = []
        # Must run last: discovery expects the attributes above to exist.
        self._init_nvml()
    def _init_nvml(self):
        """Initialize NVML and discover GPUs.

        Leaves ``self._initialized`` False (and logs a warning) when pynvml
        is not installed or the NVML library/driver cannot be loaded.
        """
        if not PYNVML_AVAILABLE:
            logger.warning("PyNVML not available. GPU functionality will be limited.")
            return
        try:
            pynvml.nvmlInit()
            self._initialized = True
            self._discover_gpus()
        except (
            pynvml.NVMLError_LibraryNotFound,
            pynvml.NVMLError_DriverNotLoaded,
            OSError,
        ) as e:
            # Any of these during init/discovery means NVML is unusable;
            # reset the flag so later queries take their cheap early-return.
            logger.warning(f"Failed to initialize NVML: {e}")
            self._initialized = False
def __del__(self):
"""Clean up NVML."""
if self._initialized:
try:
pynvml.nvmlShutdown()
except Exception: # pylint: disable=broad-except
pass
    def _discover_gpus(self):
        """Discover available GPUs and their properties.

        Rebuilds ``self.gpus`` with one GPUInfo per device reported by NVML.
        Temperature, utilization, and process lists are optional extras: a
        failure to read any of them is logged at debug level and the device
        is still recorded. No-op when NVML was never initialized.
        """
        if not self._initialized:
            return
        try:
            device_count = pynvml.nvmlDeviceGetCount()
            self.gpus = []
            for i in range(device_count):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                name = pynvml.nvmlDeviceGetName(handle)
                memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                uuid = pynvml.nvmlDeviceGetUUID(handle)
                gpu_info = GPUInfo(
                    index=i, total_memory=memory_info.total, name=name, uuid=uuid
                )
                # Get additional GPU information if available
                try:
                    gpu_info.temperature = pynvml.nvmlDeviceGetTemperature(
                        handle, pynvml.NVML_TEMPERATURE_GPU
                    )
                except pynvml.NVMLError:
                    logger.debug(f"Could not get temperature for GPU {i}")
                try:
                    utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
                    gpu_info.utilization = utilization.gpu
                except pynvml.NVMLError:
                    logger.debug(f"Could not get utilization for GPU {i}")
                # Get processes running on GPU
                try:
                    processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
                    gpu_info.processes = [
                        GPUProcess(pid=p.pid, used_memory=p.usedGpuMemory)
                        for p in processes
                    ]
                except pynvml.NVMLError:
                    logger.debug(f"Could not get processes for GPU {i}")
                self.gpus.append(gpu_info)
            logger.info(f"Discovered {len(self.gpus)} GPUs")
        except pynvml.NVMLError as e:
            # A failure here can leave self.gpus partially populated.
            logger.warning(f"Error discovering GPUs: {e}")
    def update_gpu_stats(self):
        """Update GPU statistics (utilization, memory, temperature, etc.).

        Refreshes each already-discovered GPUInfo in place. Per-metric
        failures are silently skipped (the previous value is kept); a
        failure to even get the device handle logs a warning and moves on
        to the next GPU. No-op when NVML was never initialized.
        """
        if not self._initialized:
            return
        for gpu in self.gpus:
            try:
                handle = pynvml.nvmlDeviceGetHandleByIndex(gpu.index)
                # Update memory info
                memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
                gpu.total_memory = memory_info.total
                # Update temperature
                try:
                    gpu.temperature = pynvml.nvmlDeviceGetTemperature(
                        handle, pynvml.NVML_TEMPERATURE_GPU
                    )
                except pynvml.NVMLError:
                    pass
                # Update utilization
                try:
                    utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
                    gpu.utilization = utilization.gpu
                except pynvml.NVMLError:
                    pass
                # Update processes
                try:
                    processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
                    gpu.processes = [
                        GPUProcess(pid=p.pid, used_memory=p.usedGpuMemory)
                        for p in processes
                    ]
                except pynvml.NVMLError:
                    pass
            except pynvml.NVMLError as e:
                logger.warning(f"Error updating GPU {gpu.index} stats: {e}")
    def get_gpu_count(self) -> int:
        """Return the number of GPUs discovered (0 when NVML is unavailable)."""
        return len(self.gpus)
def get_available_gpus(self) -> list[int]:
"""Return a list of available GPU indices."""
return [gpu.index for gpu in self.gpus if gpu.available]
def get_gpu_memory(self, index: int) -> tuple[int, int]:
"""
Return (total memory, free memory) in bytes for a specific GPU.
Args:
index: GPU index
Returns:
Tuple of (total memory, free memory) in bytes
"""
if not self._initialized or index >= len(self.gpus):
return (0, 0)
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
memory_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
return (memory_info.total, memory_info.free)
except pynvml.NVMLError as e:
logger.warning(f"Error getting GPU memory for GPU {index}: {e}")
return (0, 0)
def get_gpu_utilization(self, index: int) -> int:
"""
Return GPU utilization percentage for a specific GPU.
Args:
index: GPU index
Returns:
GPU utilization percentage (0-100)
"""
if not self._initialized or index >= len(self.gpus):
return 0
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
utilization = pynvml.nvmlDeviceGetUtilizationRates(handle)
return utilization.gpu # Returns GPU utilization percentage (0-100)
except pynvml.NVMLError as e:
logger.warning(f"Error getting GPU utilization for GPU {index}: {e}")
return 0
def get_gpu_temperature(self, index: int) -> int:
"""
Return GPU temperature for a specific GPU.
Args:
index: GPU index
Returns:
GPU temperature in Celsius
"""
if not self._initialized or index >= len(self.gpus):
return 0
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
return pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
except pynvml.NVMLError as e:
logger.warning(f"Error getting GPU temperature for GPU {index}: {e}")
return 0
def get_gpu_processes(self, index: int) -> list[GPUProcess]:
"""
Return processes running on a specific GPU.
Args:
index: GPU index
Returns:
List of processes running on the GPU
"""
if not self._initialized or index >= len(self.gpus):
return []
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
return [
GPUProcess(pid=p.pid, used_memory=p.usedGpuMemory) for p in processes
]
except pynvml.NVMLError as e:
logger.warning(f"Error getting GPU processes for GPU {index}: {e}")
return []
def assign_gpus(self, count: float) -> list[int]:
    """
    Assign GPUs for use. It can handle fractional GPU requests.

    Args:
        count: Number of GPUs to assign (can be fractional)
    Returns:
        List of GPU indices that were assigned
    Raises:
        ResourceError: If ``count`` is a non-integer value greater than 1.
    """
    available_gpus = self.get_available_gpus()
    if count > len(available_gpus):
        # Over-subscription is allowed but logged; the service may still
        # fail later for lack of GPU resources.
        logger.warning(
            f"Requested {count} GPUs, but only {len(available_gpus)} are available. "
            "Service may fail due to inadequate GPU resources."
        )
    # Handle fractional GPU allocation
    # _gpu_fractions holds (gpu_index, used_fraction, fraction_size)
    # tuples; fractions are only co-located on a GPU when they request
    # the exact same fraction size.
    if count < 1:
        # Try to find a GPU with the same fraction size
        try:
            # Find a GPU where we've already used the same fraction size
            gpu_idx, used_fraction = next(
                (idx, used)
                for idx, used, frac_size in self._gpu_fractions
                if frac_size == count and used < 1.0
            )
            # Update the usage for this GPU
            for i, (idx, used, frac_size) in enumerate(self._gpu_fractions):
                if idx == gpu_idx and frac_size == count:
                    new_used = used + count
                    if new_used > 1.0:
                        new_used = 1.0  # Cap at 1.0
                    self._gpu_fractions[i] = (idx, new_used, frac_size)
                    break
            # NOTE(review): fractional assignments never flip
            # gpu.available, so the device stays visible to
            # get_available_gpus() — confirm this is intended.
            return [gpu_idx]
        except StopIteration:
            # No existing fraction of this size, find a free GPU
            if available_gpus:
                gpu_idx = available_gpus[0]
                self._gpu_fractions.append((gpu_idx, count, count))
                return [gpu_idx]
            else:
                # No available GPUs, return the first GPU (or log warning)
                if self.gpus:
                    logger.warning("No available GPUs, using GPU 0 by default")
                    self._gpu_fractions.append((0, count, count))
                    return [0]
                else:
                    logger.error("No GPUs available for allocation")
                    return []
    # Integer GPU allocation
    if count >= 1:
        if int(count) != count:
            raise ResourceError(
                "Fractional GPU count greater than 1 is not supported"
            )
        count_int = int(count)
        # Take the first `count_int` free GPUs; may be fewer than
        # requested (the shortfall was warned about above).
        assigned_gpus = available_gpus[:count_int]
        # Mark these GPUs as fully used
        for gpu_idx in assigned_gpus:
            # Check if this GPU is already in _gpu_fractions
            if not any(idx == gpu_idx for idx, _, _ in self._gpu_fractions):
                self._gpu_fractions.append((gpu_idx, 1.0, 1.0))
            else:
                # Update the existing entry
                for i, (idx, _, frac_size) in enumerate(self._gpu_fractions):
                    if idx == gpu_idx:
                        self._gpu_fractions[i] = (idx, 1.0, frac_size)
            # Mark this GPU as unavailable for future requests
            for gpu in self.gpus:
                if gpu.index == gpu_idx:
                    gpu.available = False
        return assigned_gpus
    return []
def get_best_gpu_for_memory(self, required_memory: int) -> int:
    """
    Return the index of the GPU with the most available memory that meets the requirement.

    Args:
        required_memory: Required memory in bytes
    Returns:
        GPU index, or -1 if no suitable GPU was found
    """
    if not self._initialized:
        return -1
    # Scan every available device and keep the one with the largest
    # free memory that still exceeds the requirement.
    best_index, best_free = -1, 0
    for device in self.gpus:
        if not device.available:
            continue
        free_bytes = self.get_gpu_memory(device.index)[1]
        if free_bytes > required_memory and free_bytes > best_free:
            best_index, best_free = device.index, free_bytes
    return best_index
def reset_allocations(self):
    """Reset all GPU allocations, marking every device as free."""
    # Forget all fractional bookkeeping records.
    self._gpu_fractions = []
    # Every GPU becomes assignable again.
    for device in self.gpus:
        device.available = True
def get_gpu_stats(self) -> list[dict[str, t.Any]]:
    """
    Get detailed statistics for all GPUs.

    Returns:
        List of dictionaries with GPU statistics (memory, utilization,
        temperature, processes, availability) — one entry per GPU.
    """
    # Refresh the cached per-GPU data before reporting.
    self.update_gpu_stats()
    stats = []
    for gpu in self.gpus:
        total, free = self.get_gpu_memory(gpu.index)
        used = total - free
        # Guard against total == 0 when computing the percentage.
        mem_pct = (used / total * 100) if total > 0 else 0
        proc_entries = [
            {"pid": p.pid, "name": p.name, "used_memory": p.used_memory}
            for p in gpu.processes
        ]
        stats.append(
            {
                "index": gpu.index,
                "name": gpu.name,
                "uuid": gpu.uuid,
                "total_memory": total,
                "free_memory": free,
                "used_memory": used,
                "memory_utilization": mem_pct,
                "gpu_utilization": gpu.utilization,
                "temperature": gpu.temperature,
                "process_count": len(gpu.processes),
                "processes": proc_entries,
                "available": gpu.available,
            }
        )
    return stats
def system_resources() -> dict[str, t.Any]:
    """
    Get available system resources.

    Returns:
        Dictionary mapping resource names to availability. Only the
        ``NVIDIA_GPU`` key (list of available GPU indices) is populated
        here; no 'cpu' entry is added by this function.
    """
    resources = {}
    # Get GPU resources
    # NOTE(review): GPUManager() is constructed fresh on every call —
    # confirm its initialization is cheap/idempotent enough for that.
    gpu_manager = GPUManager()
    resources[NVIDIA_GPU] = gpu_manager.get_available_gpus()
    return resources
...@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union ...@@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
from _bentoml_sdk import Service, ServiceConfig from _bentoml_sdk import Service, ServiceConfig
from _bentoml_sdk.images import Image from _bentoml_sdk.images import Image
from _bentoml_sdk.service.config import validate from _bentoml_sdk.service.config import validate
from fastapi import FastAPI
from dynamo.sdk.lib.decorators import DynamoEndpoint from dynamo.sdk.lib.decorators import DynamoEndpoint
...@@ -86,9 +87,11 @@ class DynamoService(Service[T]): ...@@ -86,9 +87,11 @@ class DynamoService(Service[T]):
image: Optional[Image] = None, image: Optional[Image] = None,
envs: Optional[list[dict[str, Any]]] = None, envs: Optional[list[dict[str, Any]]] = None,
dynamo_config: Optional[DynamoConfig] = None, dynamo_config: Optional[DynamoConfig] = None,
app: Optional[FastAPI] = None,
): ):
service_name = inner.__name__ service_name = inner.__name__
service_args = self._get_service_args(service_name) service_args = self._get_service_args(service_name)
self.app = app
if service_args: if service_args:
# Validate and merge service args with existing config # Validate and merge service args with existing config
...@@ -224,14 +227,91 @@ class DynamoService(Service[T]): ...@@ -224,14 +227,91 @@ class DynamoService(Service[T]):
} }
os.environ["DYNAMO_SERVICE_ENVS"] = json.dumps(envs_config) os.environ["DYNAMO_SERVICE_ENVS"] = json.dumps(envs_config)
def inject_config(self) -> None:
    """Inject configuration from environment into service configs.

    This reads from DYNAMO_SERVICE_CONFIG environment variable and merges
    the configuration with any existing service config. Malformed or
    missing configuration is logged and ignored, never raised.
    """
    # Get service configs from environment
    service_config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
    if not service_config_str:
        logger.debug("No DYNAMO_SERVICE_CONFIG found in environment")
        return
    try:
        service_configs = json.loads(service_config_str)
        logger.debug(f"Loaded service configs: {service_configs}")
    except json.JSONDecodeError as e:
        # Bad JSON must not crash startup — log and bail out.
        logger.error(f"Failed to parse DYNAMO_SERVICE_CONFIG: {e}")
        return
    # Store the entire config at class level
    # NOTE(review): this is shared across all DynamoService instances in
    # the process; the last inject_config() call wins — confirm intended.
    if not hasattr(DynamoService, "_global_service_configs"):
        setattr(DynamoService, "_global_service_configs", {})
    DynamoService._global_service_configs = service_configs
    # Process ServiceArgs for all services
    all_services = self.all_services()
    logger.debug(f"Processing configs for services: {list(all_services.keys())}")
    for name, svc in all_services.items():
        if name in service_configs:
            svc_config = service_configs[name]
            # Extract ServiceArgs if present
            if "ServiceArgs" in svc_config:
                logger.debug(
                    f"Found ServiceArgs for {name}: {svc_config['ServiceArgs']}"
                )
                # object.__setattr__ bypasses any __setattr__ guard on the
                # service object before assigning normally below.
                if not hasattr(svc, "_service_args"):
                    object.__setattr__(svc, "_service_args", {})
                svc._service_args = svc_config["ServiceArgs"]
            else:
                logger.debug(f"No ServiceArgs found for {name}")
        # Set default config
        # Every service ends up with _service_args, even without config.
        if not hasattr(svc, "_service_args"):
            object.__setattr__(svc, "_service_args", {"workers": 1})
def get_service_configs(self) -> Dict[str, Dict[str, Any]]:
    """Get the service configurations for resource allocation.

    Returns:
        Dict mapping service names to their configs
    """
    # Get all services in the dependency chain
    all_services = self.all_services()
    configs: Dict[str, Dict[str, Any]] = {}
    # Without a globally injected config there is nothing to build.
    if hasattr(DynamoService, "_global_service_configs"):
        global_configs = DynamoService._global_service_configs
        for name, svc in all_services.items():
            # Layer the config: defaults, then the service's own args,
            # then globally injected ServiceArgs (highest precedence).
            merged: Dict[str, Any] = {"workers": 1}
            if hasattr(svc, "_service_args"):
                merged.update(svc._service_args)
            svc_entry = global_configs.get(name, {})
            if "ServiceArgs" in svc_entry:
                merged.update(svc_entry["ServiceArgs"])
            configs[name] = merged
            logger.debug(f"Built config for {name}: {merged}")
    return configs
def service( def service(
inner: Optional[type[T]] = None, inner: Optional[type[T]] = None,
/, /,
*, *,
image: Optional[Image] = None, image: Optional[str] = None,
envs: Optional[list[dict[str, Any]]] = None, envs: Optional[list[dict[str, Any]]] = None,
dynamo: Optional[Union[Dict[str, Any], DynamoConfig]] = None, dynamo: Optional[Union[Dict[str, Any], DynamoConfig]] = None,
app: Optional[FastAPI] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Enhanced service decorator that supports Dynamo configuration """Enhanced service decorator that supports Dynamo configuration
...@@ -262,6 +342,7 @@ def service( ...@@ -262,6 +342,7 @@ def service(
image=image, image=image,
envs=envs or [], envs=envs or [],
dynamo_config=dynamo_config, dynamo_config=dynamo_config,
app=app,
) )
return decorator(inner) if inner is not None else decorator return decorator(inner) if inner is not None else decorator
...@@ -13,13 +13,15 @@ ...@@ -13,13 +13,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# This is a simple example of a pipeline that uses Dynamo to deploy a backend, middle, and frontend service. Use this to test # This is a simple example of a pipeline that uses Dynamo to deploy a backend, middle, and frontend service.
# changes made to CLI, SDK, etc # Use this to test changes made to CLI, SDK, etc
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.sdk import api, depends, dynamo_endpoint, service from dynamo.sdk import depends, dynamo_endpoint, service
""" """
Pipeline Architecture: Pipeline Architecture:
...@@ -54,6 +56,9 @@ class ResponseType(BaseModel): ...@@ -54,6 +56,9 @@ class ResponseType(BaseModel):
GPU_ENABLED = False GPU_ENABLED = False
app = FastAPI(title="Hello World!")
@service( @service(
resources={"cpu": "1"}, resources={"cpu": "1"},
traffic={"timeout": 30}, traffic={"timeout": 30},
...@@ -130,7 +135,12 @@ class Middle: ...@@ -130,7 +135,12 @@ class Middle:
yield f"Frontend: {back_resp}" yield f"Frontend: {back_resp}"
@service(resources={"cpu": "1"}, traffic={"timeout": 60}) @service(
resources={"cpu": "1"},
traffic={"timeout": 60},
dynamo={"enabled": True, "namespace": "inference"},
app=app,
)
class Frontend: class Frontend:
middle = depends(Middle) middle = depends(Middle)
backend = depends(Backend) backend = depends(Backend)
...@@ -138,13 +148,13 @@ class Frontend: ...@@ -138,13 +148,13 @@ class Frontend:
def __init__(self) -> None: def __init__(self) -> None:
print("Starting frontend") print("Starting frontend")
@api @dynamo_endpoint(is_api=True)
async def generate(self, text): async def generate(self, request: RequestType):
"""Stream results from the pipeline.""" """Stream results from the pipeline."""
print(f"Frontend received: {text}") print(f"Frontend received: {request.text}")
print(f"Frontend received type: {type(text)}")
txt = RequestType(text=text) async def content_generator():
print(f"Frontend sending: {type(txt)}") async for response in self.middle.generate(request.model_dump_json()):
async for mid_resp in self.middle.generate(txt.model_dump_json()): yield f"Frontend: {response}"
print(f"Frontend received mid_resp: {mid_resp}")
yield f"Frontend: {mid_resp}" return StreamingResponse(content_generator())
...@@ -94,7 +94,7 @@ async def test_pipeline(setup_and_teardown): ...@@ -94,7 +94,7 @@ async def test_pipeline(setup_and_teardown):
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
"http://localhost:3000/generate", "http://localhost:8000/generate",
json={"text": "federer-is-the-greatest-tennis-player-of-all-time"}, json={"text": "federer-is-the-greatest-tennis-player-of-all-time"},
headers={"accept": "text/event-stream"}, headers={"accept": "text/event-stream"},
) as resp: ) as resp:
......
...@@ -79,7 +79,7 @@ The `dynamo serve` command deploys the entire service graph, automatically handl ...@@ -79,7 +79,7 @@ The `dynamo serve` command deploys the entire service graph, automatically handl
```bash ```bash
curl -X 'POST' \ curl -X 'POST' \
'http://localhost:3000/generate' \ 'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \ -H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
-d '{ -d '{
...@@ -173,10 +173,10 @@ Once you create the Dynamo deployment, a pod prefixed with `yatai-dynamonim-imag ...@@ -173,10 +173,10 @@ Once you create the Dynamo deployment, a pod prefixed with `yatai-dynamonim-imag
```bash ```bash
# Forward the service port to localhost # Forward the service port to localhost
kubectl -n ${KUBE_NS} port-forward svc/${HELM_RELEASE}-frontend 3000:3000 kubectl -n ${KUBE_NS} port-forward svc/${HELM_RELEASE}-frontend 8000:8000
# Test the API endpoint # Test the API endpoint
curl -X 'POST' 'http://localhost:3000/generate' \ curl -X 'POST' 'http://localhost:8000/generate' \
-H 'accept: text/event-stream' \ -H 'accept: text/event-stream' \
-H 'Content-Type: application/json' \ -H 'Content-Type: application/json' \
-d '{"text": "test"}' -d '{"text": "test"}'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment