Commit 3136b716 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat: save `dynamo serve` local state for planner and cleanups (#560)

parent 07f2f0ad
...@@ -28,8 +28,8 @@ from bentoml.exceptions import BentoMLConfigException ...@@ -28,8 +28,8 @@ from bentoml.exceptions import BentoMLConfigException
from simple_di import Provide, inject from simple_di import Provide, inject
NVIDIA_GPU = "nvidia.com/gpu" NVIDIA_GPU = "nvidia.com/gpu"
DISABLE_GPU_ALLOCATION_ENV = "DYNAMO_DISABLE_GPU_ALLOCATION" DYN_DISABLE_AUTO_GPU_ALLOCATION = "DYN_DISABLE_AUTO_GPU_ALLOCATION"
DYNAMO_DEPLOYMENT_ENV = "DYNAMO_DEPLOYMENT_ENV" DYN_DEPLOYMENT_ENV = "DYN_DEPLOYMENT_ENV"
class ResourceAllocator: class ResourceAllocator:
...@@ -45,7 +45,7 @@ class ResourceAllocator: ...@@ -45,7 +45,7 @@ class ResourceAllocator:
if count > self.remaining_gpus: if count > self.remaining_gpus:
warnings.warn( warnings.warn(
f"Requested {count} GPUs, but only {self.remaining_gpus} are remaining. " f"Requested {count} GPUs, but only {self.remaining_gpus} are remaining. "
f"Serving may fail due to inadequate GPUs. Set {DISABLE_GPU_ALLOCATION_ENV}=1 " f"Serving may fail due to inadequate GPUs. Set {DYN_DISABLE_AUTO_GPU_ALLOCATION}=1 "
"to disable automatic allocation and allocate GPUs manually.", "to disable automatic allocation and allocate GPUs manually.",
ResourceWarning, ResourceWarning,
stacklevel=3, stacklevel=3,
...@@ -117,8 +117,8 @@ class ResourceAllocator: ...@@ -117,8 +117,8 @@ class ResourceAllocator:
return num_workers, resource_envs return num_workers, resource_envs
else: # workers is a number else: # workers is a number
num_workers = workers num_workers = workers
if num_gpus and DISABLE_GPU_ALLOCATION_ENV not in os.environ: if num_gpus and DYN_DISABLE_AUTO_GPU_ALLOCATION not in os.environ:
if os.environ.get(DYNAMO_DEPLOYMENT_ENV): if os.environ.get(DYN_DEPLOYMENT_ENV):
# K8s replicas: Assumes DYNAMO_DEPLOYMENT_ENV is set # K8s replicas: Assumes DYNAMO_DEPLOYMENT_ENV is set
# each pod in replicaset will have separate GPU with same CUDA_VISIBLE_DEVICES # each pod in replicaset will have separate GPU with same CUDA_VISIBLE_DEVICES
assigned = self.assign_gpus(num_gpus) assigned = self.assign_gpus(num_gpus)
......
...@@ -189,6 +189,12 @@ def build_serve_command() -> click.Group: ...@@ -189,6 +189,12 @@ def build_serve_command() -> click.Group:
help="Print the final service configuration and exit without starting the server", help="Print the final service configuration and exit without starting the server",
default=False, default=False,
) )
@click.option(
"--enable-planner",
is_flag=True,
help="Save a snapshot of your service state to a file that allows planner to edit your deployment configuration",
default=False,
)
@click.pass_context @click.pass_context
def serve( def serve(
ctx: click.Context, ctx: click.Context,
...@@ -200,6 +206,7 @@ def build_serve_command() -> click.Group: ...@@ -200,6 +206,7 @@ def build_serve_command() -> click.Group:
host: str, host: str,
file: str | None, file: str | None,
working_dir: str | None, working_dir: str | None,
enable_planner: bool,
**attrs: t.Any, **attrs: t.Any,
) -> None: ) -> None:
"""Locally run connected Dynamo services. You can pass service-specific configuration options using --ServiceName.param=value format.""" """Locally run connected Dynamo services. You can pass service-specific configuration options using --ServiceName.param=value format."""
...@@ -270,6 +277,7 @@ def build_serve_command() -> click.Group: ...@@ -270,6 +277,7 @@ def build_serve_command() -> click.Group:
port=port, port=port,
dependency_map=runner_map_dict, dependency_map=runner_map_dict,
service_name=service_name, service_name=service_name,
enable_planner=enable_planner,
) )
return cli return cli
......
...@@ -44,7 +44,12 @@ from circus.watcher import Watcher ...@@ -44,7 +44,12 @@ from circus.watcher import Watcher
from simple_di import Provide, inject from simple_di import Provide, inject
from .allocator import ResourceAllocator from .allocator import ResourceAllocator
from .utils import path_to_uri, reserve_free_port from .utils import (
DYN_LOCAL_STATE_DIR,
path_to_uri,
reserve_free_port,
save_dynamo_state,
)
# Define a Protocol for services to ensure type safety # Define a Protocol for services to ensure type safety
...@@ -57,6 +62,9 @@ class ServiceProtocol(Protocol): ...@@ -57,6 +62,9 @@ class ServiceProtocol(Protocol):
def is_dynamo_component(self) -> bool: def is_dynamo_component(self) -> bool:
... ...
def dynamo_address(self) -> tuple[str, str]:
...
# Use Protocol as the base for type alias # Use Protocol as the base for type alias
AnyService = TypeVar("AnyService", bound=ServiceProtocol) AnyService = TypeVar("AnyService", bound=ServiceProtocol)
...@@ -115,7 +123,37 @@ else: ...@@ -115,7 +123,37 @@ else:
# WARNING: internal # WARNING: internal
_SERVICE_WORKER_SCRIPT = "_bentoml_impl.worker.service" _BENTO_WORKER_SCRIPT = "_bentoml_impl.worker.service"
_DYNAMO_WORKER_SCRIPT = "dynamo.sdk.cli.serve_dynamo"
def _get_dynamo_worker_script(bento_identifier: str, svc_name: str) -> list[str]:
args = [
"-m",
_DYNAMO_WORKER_SCRIPT,
bento_identifier,
"--service-name",
svc_name,
"--worker-id",
"$(CIRCUS.WID)",
]
return args
def _get_bento_worker_script(bento_identifier: str, svc_name: str) -> list[str]:
args = [
"-m",
_BENTO_WORKER_SCRIPT,
bento_identifier,
"--service-name",
svc_name,
"--fd",
f"$(circus.sockets.{svc_name})",
"--worker-id",
"$(CIRCUS.WID)",
]
return args
def create_dependency_watcher( def create_dependency_watcher(
...@@ -131,18 +169,7 @@ def create_dependency_watcher( ...@@ -131,18 +169,7 @@ def create_dependency_watcher(
num_workers, resource_envs = scheduler.get_resource_envs(svc) num_workers, resource_envs = scheduler.get_resource_envs(svc)
uri, socket = _get_server_socket(svc, uds_path, port_stack) uri, socket = _get_server_socket(svc, uds_path, port_stack)
args = [ args = _get_bento_worker_script(bento_identifier, svc.name)
"-m",
_SERVICE_WORKER_SCRIPT,
bento_identifier,
"--service-name",
svc.name,
"--fd",
f"$(circus.sockets.{svc.name})",
"--worker-id",
"$(CIRCUS.WID)",
]
if resource_envs: if resource_envs:
args.extend(["--worker-env", json.dumps(resource_envs)]) args.extend(["--worker-env", json.dumps(resource_envs)])
...@@ -168,23 +195,9 @@ def create_dynamo_watcher( ...@@ -168,23 +195,9 @@ def create_dynamo_watcher(
"""Create a watcher for a Dynamo service in the dependency graph""" """Create a watcher for a Dynamo service in the dependency graph"""
from bentoml.serving import create_watcher from bentoml.serving import create_watcher
# Get socket for this service
uri, socket = _get_server_socket(svc, uds_path, port_stack)
# Get worker configuration
num_workers, resource_envs = scheduler.get_resource_envs(svc) num_workers, resource_envs = scheduler.get_resource_envs(svc)
uri, socket = _get_server_socket(svc, uds_path, port_stack)
# Create Dynamo-specific worker args args = _get_dynamo_worker_script(bento_identifier, svc.name)
args = [
"-m",
"dynamo.sdk.cli.serve_dynamo", # Use our Dynamo worker module
bento_identifier,
"--service-name",
svc.name,
"--worker-id",
"$(CIRCUS.WID)",
]
if resource_envs: if resource_envs:
args.extend(["--worker-env", json.dumps(resource_envs)]) args.extend(["--worker-env", json.dumps(resource_envs)])
...@@ -209,41 +222,21 @@ def create_dynamo_watcher( ...@@ -209,41 +222,21 @@ def create_dynamo_watcher(
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
logger.warning(f"Failed to parse DYNAMO_SERVICE_ENVS: {e}") logger.warning(f"Failed to parse DYNAMO_SERVICE_ENVS: {e}")
# use namespace from the service
namespace, _ = svc.dynamo_address()
# Create the watcher with updated environment # Create the watcher with updated environment
watcher = create_watcher( watcher = create_watcher(
name=f"dynamo_service_{svc.name}", name=f"{namespace}_{svc.name}",
args=args, args=args,
numprocesses=num_workers, numprocesses=num_workers,
working_dir=working_dir, working_dir=working_dir,
env=worker_env, env=worker_env,
) )
return watcher, socket, uri logger.info(f"Created watcher for {svc.name}'s in the {namespace} namespace")
@inject
def server_on_deployment(
svc: ServiceProtocol, result_file: str = Provide[BentoMLContainer.result_store_file]
) -> None:
# Resolve models before server starts.
if hasattr(svc, "bento") and (bento := getattr(svc, "bento")):
for model in bento.info.all_models:
model.to_model().resolve()
elif hasattr(svc, "models"):
for model in svc.models:
model.resolve()
if hasattr(svc, "inner"): return watcher, socket, uri
inner = svc.inner
for name in dir(inner):
member = getattr(inner, name)
if callable(member) and getattr(
member, "__bentoml_deployment_hook__", False
):
member()
if os.path.exists(result_file):
os.remove(result_file)
@inject(squeeze_none=True) @inject(squeeze_none=True)
...@@ -254,6 +247,7 @@ def serve_http( ...@@ -254,6 +247,7 @@ def serve_http(
port: int = Provide[BentoMLContainer.http.port], port: int = Provide[BentoMLContainer.http.port],
dependency_map: dict[str, str] | None = None, dependency_map: dict[str, str] | None = None,
service_name: str = "", service_name: str = "",
enable_planner: bool = False,
) -> Server: ) -> Server:
# WARNING: internal # WARNING: internal
from _bentoml_impl.loader import load from _bentoml_impl.loader import load
...@@ -263,9 +257,14 @@ def serve_http( ...@@ -263,9 +257,14 @@ def serve_http(
from bentoml.serving import create_watcher from bentoml.serving import create_watcher
from circus.sockets import CircusSocket from circus.sockets import CircusSocket
from dynamo.sdk.lib.logging import configure_server_logging
from .allocator import ResourceAllocator from .allocator import ResourceAllocator
configure_server_logging()
bento_id: str = "" bento_id: str = ""
namespace: str = ""
env: dict[str, Any] = {} env: dict[str, Any] = {}
if isinstance(bento_identifier, Service): if isinstance(bento_identifier, Service):
svc = bento_identifier svc = bento_identifier
...@@ -296,7 +295,6 @@ def serve_http( ...@@ -296,7 +295,6 @@ def serve_http(
if service_name and service_name != svc.name: if service_name and service_name != svc.name:
svc = svc.find_dependent_by_name(service_name) svc = svc.find_dependent_by_name(service_name)
num_workers, resource_envs = allocator.get_resource_envs(svc) num_workers, resource_envs = allocator.get_resource_envs(svc)
server_on_deployment(svc)
uds_path = tempfile.mkdtemp(prefix="bentoml-uds-") uds_path = tempfile.mkdtemp(prefix="bentoml-uds-")
try: try:
if not service_name and not standalone: if not service_name and not standalone:
...@@ -321,6 +319,7 @@ def serve_http( ...@@ -321,6 +319,7 @@ def serve_http(
str(bento_path.absolute()), str(bento_path.absolute()),
env=env, env=env,
) )
namespace, _ = dep_svc.dynamo_address()
else: else:
# Regular BentoML service # Regular BentoML service
new_watcher, new_socket, uri = create_dependency_watcher( new_watcher, new_socket, uri = create_dependency_watcher(
...@@ -336,7 +335,6 @@ def serve_http( ...@@ -336,7 +335,6 @@ def serve_http(
watchers.append(new_watcher) watchers.append(new_watcher)
sockets.append(new_socket) sockets.append(new_socket)
dependency_map[name] = uri dependency_map[name] = uri
server_on_deployment(dep_svc)
# reserve one more to avoid conflicts # reserve one more to avoid conflicts
port_stack.enter_context(reserve_free_port()) port_stack.enter_context(reserve_free_port())
...@@ -365,7 +363,7 @@ def serve_http( ...@@ -365,7 +363,7 @@ def serve_http(
server_args = [ server_args = [
"-m", "-m",
_SERVICE_WORKER_SCRIPT, _BENTO_WORKER_SCRIPT,
bento_identifier, bento_identifier,
"--fd", "--fd",
f"$(circus.sockets.{API_SERVER_NAME})", f"$(circus.sockets.{API_SERVER_NAME})",
...@@ -374,6 +372,15 @@ def serve_http( ...@@ -374,6 +372,15 @@ def serve_http(
"--worker-id", "--worker-id",
"$(CIRCUS.WID)", "$(CIRCUS.WID)",
] ]
dynamo_args = [
"-m",
_DYNAMO_WORKER_SCRIPT,
bento_identifier,
"--service-name",
svc.name,
"--worker-id",
"$(CIRCUS.WID)",
]
if resource_envs: if resource_envs:
server_args.extend(["--worker-env", json.dumps(resource_envs)]) server_args.extend(["--worker-env", json.dumps(resource_envs)])
...@@ -381,20 +388,10 @@ def serve_http( ...@@ -381,20 +388,10 @@ def serve_http(
# Check if this is a Dynamo service # Check if this is a Dynamo service
if hasattr(svc, "is_dynamo_component") and svc.is_dynamo_component(): if hasattr(svc, "is_dynamo_component") and svc.is_dynamo_component():
# Create Dynamo-specific watcher using existing socket
args = [
"-m",
"dynamo.sdk.cli.serve_dynamo", # Use our Dynamo worker module
bento_identifier,
"--service-name",
svc.name,
"--worker-id",
"$(CIRCUS.WID)",
]
# resource_envs is the resource allocation (ie CUDA_VISIBLE_DEVICES) for each worker created by the allocator # resource_envs is the resource allocation (ie CUDA_VISIBLE_DEVICES) for each worker created by the allocator
# these resource_envs are passed to each individual worker's environment which is set in serve_dynamo # these resource_envs are passed to each individual worker's environment which is set in serve_dynamo
if resource_envs: if resource_envs:
args.extend(["--worker-env", json.dumps(resource_envs)]) dynamo_args.extend(["--worker-env", json.dumps(resource_envs)])
# env is the base bentoml environment variables. We make a copy and update it to add any service configurations and additional env vars # env is the base bentoml environment variables. We make a copy and update it to add any service configurations and additional env vars
worker_env = env.copy() if env else {} worker_env = env.copy() if env else {}
...@@ -416,16 +413,17 @@ def serve_http( ...@@ -416,16 +413,17 @@ def serve_http(
logger.warning(f"Failed to parse DYNAMO_SERVICE_ENVS: {e}") logger.warning(f"Failed to parse DYNAMO_SERVICE_ENVS: {e}")
watcher = create_watcher( watcher = create_watcher(
name=f"dynamo_service_{svc.name}", name=f"{namespace}_{svc.name}",
args=args, args=dynamo_args,
numprocesses=num_workers, numprocesses=num_workers,
working_dir=str(bento_path.absolute()), working_dir=str(bento_path.absolute()),
env=worker_env, # Dependency map will be injected by serve_http env=worker_env,
) )
watchers.append(watcher) watchers.append(watcher)
logger.info(f"dynamo_service_{svc.name} entrypoint created") logger.info(
f"Created watcher for {svc.name}'s in the {namespace} namespace"
)
else: else:
# Create regular BentoML service watcher
watchers.append( watchers.append(
create_watcher( create_watcher(
name="service", name="service",
...@@ -455,6 +453,32 @@ def serve_http( ...@@ -455,6 +453,32 @@ def serve_http(
arbiter = create_standalone_arbiter(**arbiter_kwargs) arbiter = create_standalone_arbiter(**arbiter_kwargs)
arbiter.exit_stack.callback(shutil.rmtree, uds_path, ignore_errors=True) arbiter.exit_stack.callback(shutil.rmtree, uds_path, ignore_errors=True)
if enable_planner:
arbiter.exit_stack.callback(
shutil.rmtree,
os.environ.get(
DYN_LOCAL_STATE_DIR, os.path.expanduser("~/.dynamo/state")
),
ignore_errors=True,
)
logger.warn(f"arbiter: {arbiter.endpoint}")
# save deployment state for planner
if not namespace:
raise ValueError("No namespace found for service")
save_dynamo_state(
namespace,
arbiter.endpoint,
components={
watcher.name: {
"watcher_name": watcher.name,
"cmd": watcher.cmd + " ".join(watcher.args),
}
for watcher in watchers
},
environment={
"DYNAMO_SERVICE_CONFIG": os.environ["DYNAMO_SERVICE_CONFIG"],
},
)
arbiter.start( arbiter.start(
cb=lambda _: logger.info( # type: ignore cb=lambda _: logger.info( # type: ignore
( (
......
...@@ -16,6 +16,8 @@ ...@@ -16,6 +16,8 @@
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES # Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import contextlib import contextlib
import json
import logging
import os import os
import pathlib import pathlib
import random import random
...@@ -26,6 +28,10 @@ import click ...@@ -26,6 +28,10 @@ import click
import psutil import psutil
from click import Command, Context from click import Command, Context
logger = logging.getLogger(__name__)
DYN_LOCAL_STATE_DIR = "DYN_LOCAL_STATE_DIR"
class DynamoCommandGroup(click.Group): class DynamoCommandGroup(click.Group):
"""Simplified version of BentoMLCommandGroup for Dynamo CLI""" """Simplified version of BentoMLCommandGroup for Dynamo CLI"""
...@@ -149,3 +155,30 @@ def path_to_uri(path: str) -> str: ...@@ -149,3 +155,30 @@ def path_to_uri(path: str) -> str:
if psutil.POSIX: if psutil.POSIX:
return pathlib.PurePosixPath(path).as_uri() return pathlib.PurePosixPath(path).as_uri()
raise ValueError("Unsupported OS") raise ValueError("Unsupported OS")
def save_dynamo_state(
namespace: str,
circus_endpoint: str,
components: dict[str, t.Any],
environment: dict[str, t.Any],
):
state_dir = os.environ.get(
DYN_LOCAL_STATE_DIR, os.path.expanduser("~/.dynamo/state")
)
os.makedirs(state_dir, exist_ok=True)
# create the state object
state = {
"namespace": namespace,
"circus_endpoint": circus_endpoint,
"components": components,
"environment": environment,
}
# save the state object to a file
state_file = os.path.join(state_dir, f"{namespace}.json")
with open(state_file, "w") as f:
json.dump(state, f)
logger.warning(f"Saved state to {state_file}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment