Commit 07a1a8a1 authored by Biswa Ranjan Panda's avatar Biswa Ranjan Panda
Browse files

feat: Add compound AI python SDK

parent 2d906fb4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
[project]
name = "compoundai"
version = "0.1.0"
description = "Compound AI Python SDK: BentoML-based CLI and serving layer with Nova/Triton distributed-runtime support"
readme = "README.md"
authors = [
{ name = "mabdulwahhab", email = "mabdulwahhab@nvidia.com" }
]
requires-python = "==3.10.*"
dependencies = [
"bentoml>=1.4.1",
"types-psutil==7.0.0.20250218",
]
[project.scripts]
compoundai = "compoundai.cli.cli:cli"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import click
import psutil
def create_bentoml_cli() -> click.Command:
    """Assemble the compoundai root CLI on top of BentoML's command groups.

    All heavy imports are deferred to call time so importing this module
    stays cheap and free of BentoML side effects.
    """
    from bentoml._internal.configuration import BENTOML_VERSION
    from bentoml._internal.context import server_context
    from bentoml_cli.bentos import bento_command
    from bentoml_cli.cloud import cloud_command
    from bentoml_cli.containerize import containerize_command
    from bentoml_cli.deployment import (
        deploy_command,
        deployment_command,
        develop_command,
    )
    from bentoml_cli.env import env_command
    from bentoml_cli.models import model_command
    from bentoml_cli.secret import secret_command
    from bentoml_cli.utils import BentoMLCommandGroup, get_entry_points
    from compoundai.cli.serve import serve_command
    from compoundai.cli.start import start_command

    server_context.service_type = "cli"

    context_settings = {"help_option_names": ("-h", "--help")}

    @click.group(cls=BentoMLCommandGroup, context_settings=context_settings)
    @click.version_option(BENTOML_VERSION, "-v", "--version")
    def bentoml_cli():  # TODO: to be renamed to something....
        """ """

    # Attach the top-level commands in display order.  Some groups are
    # flattened into the root via add_subcommands; the rest attach as-is.
    registrations = (
        (bentoml_cli.add_command, env_command),
        (bentoml_cli.add_command, cloud_command),
        (bentoml_cli.add_command, model_command),
        (bentoml_cli.add_subcommands, bento_command),
        (bentoml_cli.add_subcommands, start_command),
        (bentoml_cli.add_subcommands, serve_command),
        (bentoml_cli.add_command, containerize_command),
        (bentoml_cli.add_command, deploy_command),
        (bentoml_cli.add_command, develop_command),
        (bentoml_cli.add_command, deployment_command),
        (bentoml_cli.add_command, secret_command),
    )
    for attach, command in registrations:
        attach(command)

    # Load commands contributed by extensions through entry points.
    for ep in get_entry_points("bentoml.commands"):
        bentoml_cli.add_command(ep.load())

    if psutil.WINDOWS:
        import sys

        sys.stdout.reconfigure(encoding="utf-8")  # type: ignore

    return bentoml_cli
# Build the CLI eagerly so the console-script entry point declared in
# pyproject.toml ("compoundai = compoundai.cli.cli:cli") resolves to a
# ready click.Command at import time.
cli = create_bentoml_cli()

if __name__ == "__main__":
    cli()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
import os
import sys
import typing as t
import click
import rich
if t.TYPE_CHECKING:
    # Typing-only helpers for decorator signatures; not defined at runtime.
    P = t.ParamSpec("P")  # type: ignore
    F = t.Callable[P, t.Any]  # type: ignore

logger = logging.getLogger(__name__)

# Host used when '--development' is passed and no explicit '--host' is given.
DEFAULT_DEV_SERVER_HOST = "127.0.0.1"
def deprecated_option(*param_decls: str, **attrs: t.Any):
    """Mark a click option as deprecated and warn (on stderr) when it is used.

    Keyword extras (consumed here, not forwarded to click):
        deprecated: whether the warning is active (default True).
        current_behaviour: required text describing today's behaviour,
            interpolated into the warning message.

    All remaining ``attrs`` are forwarded to :func:`click.option`.

    Raises:
        ValueError: if ``current_behaviour`` is not provided.
    """
    deprecated = attrs.pop("deprecated", True)
    new_behaviour = attrs.pop("current_behaviour", None)
    if new_behaviour is None:
        # Explicit error instead of `assert`, which is stripped under -O.
        raise ValueError("current_behaviour is required")

    def show_deprecated_callback(
        ctx: click.Context, param: click.Parameter, value: t.Any
    ):
        if value is not param.default and deprecated:
            name = "'--%(name)s'" if attrs.get("is_flag", False) else "'%(name)s'"
            DEPRECATION_WARNING = f"[yellow]DeprecationWarning: The parameter {name} is deprecated and will be removed in the future. (Current behaviour: %(new_behaviour)s)[/]"
            rich.print(
                DEPRECATION_WARNING
                % {"name": param.name, "new_behaviour": new_behaviour},
                file=sys.stderr,
            )
        # click replaces the parameter value with the callback's return value;
        # without this `return` the option value would silently become None.
        return value

    def decorator(f: F[t.Any]) -> t.Callable[[F[t.Any]], click.Command]:  # type: ignore
        msg = attrs.pop("help", "")
        msg += " (Deprecated)" if msg else "(Deprecated)"
        attrs.setdefault("help", msg)
        attrs.setdefault("callback", show_deprecated_callback)
        return click.option(*param_decls, **attrs)(f)

    return decorator
def build_serve_command() -> click.Group:
    """Build the 'serve' click group with its HTTP serve command.

    Imports are deferred to call time to keep module import lightweight.
    """
    from bentoml._internal.log import configure_server_logging
    from bentoml_cli.env_manager import env_manager
    from bentoml_cli.utils import AliasCommand, BentoMLCommandGroup

    @click.group(name="serve", cls=BentoMLCommandGroup)
    def cli():
        pass

    @cli.command(aliases=["serve-http"], cls=AliasCommand)
    @click.argument("bento", type=click.STRING, default=".")
    @click.option(
        "--development",
        type=click.BOOL,
        help="Run the BentoServer in development mode",
        is_flag=True,
        default=False,
        show_default=True,
    )
    @deprecated_option(
        "--production",
        type=click.BOOL,
        help="Run BentoServer in production mode",
        current_behaviour="This is enabled by default. To run in development mode, use '--development'.",
        is_flag=True,
        default=True,
        show_default=False,
    )
    @click.option(
        "-p",
        "--port",
        type=click.INT,
        help="The port to listen on for the REST api server",
        envvar="BENTOML_PORT",
        show_envvar=True,
    )
    @click.option(
        "--host",
        type=click.STRING,
        help="The host to bind for the REST api server",
        envvar="BENTOML_HOST",
        show_envvar=True,
    )
    @click.option(
        "--api-workers",
        type=click.INT,
        help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode",
        envvar="BENTOML_API_WORKERS",
        show_envvar=True,
        hidden=True,
    )
    @click.option(
        "--timeout",
        type=click.INT,
        help="Specify the timeout (seconds) for API server and runners",
        envvar="BENTOML_TIMEOUT",
        hidden=True,
    )
    @click.option(
        "--backlog",
        type=click.INT,
        help="The maximum number of pending connections.",
        show_default=True,
        hidden=True,
    )
    @click.option(
        "--reload",
        type=click.BOOL,
        is_flag=True,
        help="Reload Service when code changes detected",
        default=False,
        show_default=True,
    )
    @click.option(
        "--working-dir",
        type=click.Path(),
        help="When loading from source code, specify the directory to find the Service instance",
        default=None,
        show_default=True,
    )
    @click.option(
        "--ssl-certfile",
        type=str,
        help="SSL certificate file",
        show_default=True,
        hidden=True,
    )
    @click.option(
        "--ssl-keyfile",
        type=str,
        help="SSL key file",
        show_default=True,
        hidden=True,
    )
    @click.option(
        "--ssl-keyfile-password",
        type=str,
        help="SSL keyfile password",
        show_default=True,
        hidden=True,
    )
    @click.option(
        "--ssl-version",
        type=int,
        help="SSL version to use (see stdlib 'ssl' module)",
        show_default=True,
        hidden=True,
    )
    @click.option(
        "--ssl-cert-reqs",
        type=int,
        help="Whether client certificate is required (see stdlib 'ssl' module)",
        show_default=True,
        hidden=True,
    )
    @click.option(
        "--ssl-ca-certs",
        type=str,
        help="CA certificates file",
        show_default=True,
        hidden=True,
    )
    @click.option(
        "--ssl-ciphers",
        type=str,
        help="Ciphers to use (see stdlib 'ssl' module)",
        show_default=True,
        hidden=True,
    )
    @click.option(
        "--timeout-keep-alive",
        type=int,
        help="Close Keep-Alive connections if no new data is received within this timeout.",
        hidden=True,
    )
    @click.option(
        "--timeout-graceful-shutdown",
        type=int,
        default=None,
        help="Maximum number of seconds to wait for graceful shutdown. After this timeout, the server will start terminating requests.",
        show_default=True,
        hidden=True,
    )
    @env_manager
    def serve(
        bento: str,
        development: bool,
        port: int,
        host: str,
        api_workers: int,
        timeout: int | None,
        backlog: int,
        reload: bool,
        working_dir: str | None,
        ssl_certfile: str | None,
        ssl_keyfile: str | None,
        ssl_keyfile_password: str | None,
        ssl_version: int | None,
        ssl_cert_reqs: int | None,
        ssl_ca_certs: str | None,
        ssl_ciphers: str | None,
        timeout_keep_alive: int | None,
        timeout_graceful_shutdown: int | None,
        **attrs: t.Any,
    ) -> None:
        """Start a HTTP BentoServer from a given 🍱
        \b
        BENTO is the serving target, it can be the import as:
        - the import path of a 'bentoml.Service' instance
        - a tag to a Bento in local Bento store
        - a folder containing a valid 'bentofile.yaml' build file with a 'service' field, which provides the import path of a 'bentoml.Service' instance
        - a path to a built Bento (for internal & debug use only)
        e.g.:
        \b
        Serve from a bentoml.Service instance source code (for development use only):
        'bentoml serve fraud_detector.py:svc'
        \b
        Serve from a Bento built in local store:
        'bentoml serve fraud_detector:4tht2icroji6zput3suqi5nl2'
        'bentoml serve fraud_detector:latest'
        \b
        Serve from a Bento directory:
        'bentoml serve ./fraud_detector_bento'
        \b
        If '--reload' is provided, BentoML will detect code and model store changes during development, and restarts the service automatically.
        \b
        The '--reload' flag will:
        - by default, all file changes under '--working-dir' (default to current directory) will trigger a restart
        - when specified, respect 'include' and 'exclude' under 'bentofile.yaml' as well as the '.bentoignore' file in '--working-dir', for code and file changes
        - all model store changes will also trigger a restart (new model saved or existing model removed)
        """
        from bentoml import Service
        from bentoml._internal.service.loader import load

        configure_server_logging()
        # Default the working directory: the bento path itself when it is a
        # directory, otherwise the current directory.
        if working_dir is None:
            if os.path.isdir(os.path.expanduser(bento)):
                working_dir = os.path.expanduser(bento)
            else:
                working_dir = "."
        if sys.path[0] != working_dir:
            sys.path.insert(0, working_dir)
        svc = load(bento_identifier=bento, working_dir=working_dir)
        if isinstance(svc, Service):
            # bentoml<1.2
            from bentoml.serving import serve_http_production

            if development:
                # Development mode: single API worker, loopback host default.
                serve_http_production(
                    bento,
                    working_dir=working_dir,
                    port=port,
                    host=DEFAULT_DEV_SERVER_HOST if not host else host,
                    backlog=backlog,
                    api_workers=1,
                    timeout=timeout,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile,
                    ssl_keyfile_password=ssl_keyfile_password,
                    ssl_version=ssl_version,
                    ssl_cert_reqs=ssl_cert_reqs,
                    ssl_ca_certs=ssl_ca_certs,
                    ssl_ciphers=ssl_ciphers,
                    reload=reload,
                    development_mode=True,
                    timeout_keep_alive=timeout_keep_alive,
                    timeout_graceful_shutdown=timeout_graceful_shutdown,
                )
            else:
                serve_http_production(
                    bento,
                    working_dir=working_dir,
                    port=port,
                    host=host,
                    api_workers=api_workers,
                    timeout=timeout,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile,
                    ssl_keyfile_password=ssl_keyfile_password,
                    ssl_version=ssl_version,
                    ssl_cert_reqs=ssl_cert_reqs,
                    ssl_ca_certs=ssl_ca_certs,
                    ssl_ciphers=ssl_ciphers,
                    reload=reload,
                    development_mode=False,
                    timeout_keep_alive=timeout_keep_alive,
                    timeout_graceful_shutdown=timeout_graceful_shutdown,
                )
        else:
            # bentoml>=1.2
            # from _bentoml_impl.server import serve_http
            from compoundai.cli.serving import serve_http  # type: ignore

            svc.inject_config()
            serve_http(
                bento,
                working_dir=working_dir,
                host=host,
                port=port,
                backlog=backlog,
                timeout=timeout,
                ssl_certfile=ssl_certfile,
                ssl_keyfile=ssl_keyfile,
                ssl_keyfile_password=ssl_keyfile_password,
                ssl_version=ssl_version,
                ssl_cert_reqs=ssl_cert_reqs,
                ssl_ca_certs=ssl_ca_certs,
                ssl_ciphers=ssl_ciphers,
                development_mode=development,
                reload=reload,
                timeout_keep_alive=timeout_keep_alive,
                timeout_graceful_shutdown=timeout_graceful_shutdown,
            )

    return cli
# Built once at import time; attached to the root CLI in compoundai.cli.cli.
serve_command = build_serve_command()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import inspect
import json
import logging
import os
import random
import string
import typing as t
from typing import Any
import click
from triton_distributed_rs import DistributedRuntime, triton_endpoint, triton_worker
logger = logging.getLogger("compoundai.serve.nova")


def generate_run_id():
    """Return a random 6-character run identifier (uppercase letters/digits)."""
    alphabet = string.ascii_uppercase + string.digits
    return "".join(random.choices(alphabet, k=6))
@click.command()
@click.argument("bento_identifier", type=click.STRING, required=False, default=".")
@click.option("--service-name", type=click.STRING, required=False, default="")
@click.option(
    "--runner-map",
    type=click.STRING,
    envvar="BENTOML_RUNNER_MAP",
    help="JSON string of runners map, default sets to envars `BENTOML_RUNNER_MAP`",
)
@click.option(
    "--worker-env", type=click.STRING, default=None, help="Environment variables"
)
@click.option(
    "--worker-id",
    required=False,
    type=click.INT,
    default=None,
    help="If set, start the server as a bare worker with the given worker ID. Otherwise start a standalone server with a supervisor process.",
)
def main(
    bento_identifier: str,
    service_name: str,
    runner_map: str | None,
    worker_env: str | None,
    worker_id: int | None,
) -> None:
    """Start a worker for the given service - either Nova or regular service"""
    from _bentoml_impl.loader import import_service
    from bentoml._internal.container import BentoMLContainer
    from bentoml._internal.context import server_context
    from bentoml._internal.log import configure_server_logging

    run_id = generate_run_id()

    # Debug aid: show the runner map handed down by the supervisor.
    # Use .get() — the variable is absent when --runner-map is given on the
    # command line (or not at all), and a direct lookup raised KeyError.
    print(f"BENTOML_RUNNER_MAP: {os.environ.get('BENTOML_RUNNER_MAP')}")

    # Import service first to check configuration
    service = import_service(bento_identifier)
    if service_name and service_name != service.name:
        service = service.find_dependent_by_name(service_name)

    # Handle worker environment if specified
    if worker_env:
        env_list: list[dict[str, t.Any]] = json.loads(worker_env)
        if worker_id is not None:
            # worker_id is 1-based ($(CIRCUS.WID)); the env list is 0-based.
            worker_key = worker_id - 1
            if worker_key >= len(env_list):
                raise IndexError(
                    f"Worker ID {worker_id} is out of range, "
                    f"the maximum worker ID is {len(env_list)}"
                )
            os.environ.update(env_list[worker_key])

    configure_server_logging()
    if runner_map:
        BentoMLContainer.remote_runner_mapping.set(
            t.cast(t.Dict[str, str], json.loads(runner_map))
        )

    # Check if Nova is enabled for this service.
    # NOTE(review): a non-Nova service currently falls through and returns
    # without starting anything — confirm whether a regular-service path is
    # intended here.
    if service.is_nova_component():
        if worker_id is not None:
            server_context.worker_index = worker_id
        class_instance = service.inner()

        @triton_worker()
        async def worker(runtime: DistributedRuntime):
            if service_name and service_name != service.name:
                server_context.service_type = "service"
            else:
                server_context.service_type = "entry_service"
            server_context.service_name = service.name

            # Get Nova configuration and create component
            namespace, component_name = service.nova_address()
            logger.info(
                f"[{run_id}] Registering component {namespace}/{component_name}"
            )
            component = runtime.namespace(namespace).component(component_name)
            try:
                # Create service first
                await component.create_service()
                logger.info(f"[{run_id}] Created {service.name} component")

                # Run startup hooks before setting up endpoints
                for name, member in vars(class_instance.__class__).items():
                    if callable(member) and getattr(
                        member, "__bentoml_startup_hook__", False
                    ):
                        logger.info(f"[{run_id}] Running startup hook: {name}")
                        result = getattr(class_instance, name)()
                        if inspect.isawaitable(result):
                            await result
                            logger.info(
                                f"[{run_id}] Completed async startup hook: {name}"
                            )
                        else:
                            logger.info(f"[{run_id}] Completed startup hook: {name}")

                # Set runtime on all dependencies
                for dep in service.dependencies.values():
                    dep.set_runtime(runtime)
                    logger.info(f"[{run_id}] Set runtime for dependency: {dep}")

                # Then register all Nova endpoints
                nova_endpoints = service.get_nova_endpoints()
                if not nova_endpoints:
                    error_msg = f"[{run_id}] FATAL ERROR: No Nova endpoints found in service {service.name}!"
                    logger.error(error_msg)
                    raise ValueError(error_msg)
                print(f"[{run_id}] Nova endpoints: {nova_endpoints}")
                for name, endpoint in nova_endpoints.items():
                    td_endpoint = component.endpoint(name)
                    logger.info(f"[{run_id}] Registering endpoint '{name}'")
                    # Bind an instance of inner to the endpoint
                    bound_method = endpoint.func.__get__(class_instance)
                    # Only pass request type for now, use Any for response
                    # TODO: Handle a triton_endpoint not having types
                    # TODO: Handle multiple endpoints in a single component
                    triton_wrapped_method = triton_endpoint(endpoint.request_type, Any)(
                        bound_method
                    )
                    result = await td_endpoint.serve_endpoint(triton_wrapped_method)
                    # WARNING: unreachable code :( because serve blocks
                    logger.info(f"[{run_id}] Result: {result}")
                    logger.info(f"[{run_id}] Registered endpoint '{name}'")
                logger.info(
                    f"[{run_id}] Started {service.name} instance with all endpoints registered"
                )
                logger.info(
                    f"[{run_id}] Available endpoints: {service.list_nova_endpoints()}"
                )
            except Exception as e:
                logger.error(f"[{run_id}] Error in Nova component setup: {str(e)}")
                raise

        asyncio.run(worker())
if __name__ == "__main__":
    # Allow running the worker module directly (python -m ...).
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import ipaddress
import json
import logging
import os
import pathlib
import platform
import shutil
import socket
import tempfile
import typing as t
from typing import Any, Dict, Optional, Protocol, TypeVar
from _bentoml_sdk import Service
from bentoml._internal.container import BentoMLContainer
from bentoml._internal.utils.circus import Server
from bentoml.exceptions import BentoMLConfigException
from circus.sockets import CircusSocket
from circus.watcher import Watcher
from simple_di import Provide, inject
if t.TYPE_CHECKING:
    from _bentoml_impl.server.allocator import ResourceAllocator


# Define a Protocol for services to ensure type safety
class ServiceProtocol(Protocol):
    """Structural type for the service objects scheduled by this module."""

    name: str  # service identifier, used for watcher and socket names
    inner: Any  # user-defined class backing the service
    models: list[Any]  # models to resolve before serving
    bento: Any  # built bento metadata, when serving from a bento

    def is_nova_component(self) -> bool:
        """Whether this service is served through the Nova worker module."""
        ...


# Use Protocol as the base for type alias
AnyService = TypeVar("AnyService", bound=ServiceProtocol)

# Platform detection drives the socket strategy chosen below.
POSIX = os.name == "posix"
WINDOWS = os.name == "nt"
IS_WSL = "microsoft-standard" in platform.release()

API_SERVER_NAME = "_bento_api_server"  # circus socket name for the entry server
MAX_AF_UNIX_PATH_LENGTH = 103  # unix-socket paths beyond this fail to bind

logger = logging.getLogger("bentoml.serve")
if POSIX and not IS_WSL:

    def _get_server_socket(
        service: ServiceProtocol,
        uds_path: str,
        port_stack: contextlib.ExitStack,
        backlog: int,
    ) -> tuple[str, CircusSocket]:
        """POSIX: bind the service to a unix domain socket under *uds_path*."""
        from bentoml._internal.utils.uri import path_to_uri
        from circus.sockets import CircusSocket

        socket_path = os.path.join(uds_path, f"{id(service)}.sock")
        # Paths at or above MAX_AF_UNIX_PATH_LENGTH cannot be bound.
        assert len(socket_path) < MAX_AF_UNIX_PATH_LENGTH
        return path_to_uri(socket_path), CircusSocket(
            name=service.name, path=socket_path, backlog=backlog
        )

elif WINDOWS or IS_WSL:

    def _get_server_socket(
        service: ServiceProtocol,
        uds_path: str,
        port_stack: contextlib.ExitStack,
        backlog: int,
    ) -> tuple[str, CircusSocket]:
        """Windows/WSL: use a reserved loopback TCP port instead of a unix socket."""
        from bentoml._internal.utils import reserve_free_port
        from circus.sockets import CircusSocket

        runner_port = port_stack.enter_context(reserve_free_port())
        runner_host = "127.0.0.1"
        return f"tcp://{runner_host}:{runner_port}", CircusSocket(
            name=service.name,
            host=runner_host,
            port=runner_port,
            backlog=backlog,
        )

else:

    def _get_server_socket(
        service: ServiceProtocol,
        uds_path: str,
        port_stack: contextlib.ExitStack,
        backlog: int,
    ) -> tuple[str, CircusSocket]:
        """Fallback for platforms with no supported socket strategy."""
        from bentoml.exceptions import BentoMLException

        raise BentoMLException("Unsupported platform")


# Module path of the standard BentoML service worker launched by watchers.
_SERVICE_WORKER_SCRIPT = "_bentoml_impl.worker.service"
def create_dependency_watcher(
    bento_identifier: str,
    svc: ServiceProtocol,
    uds_path: str,
    port_stack: contextlib.ExitStack,
    backlog: int,
    scheduler: ResourceAllocator,
    working_dir: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
) -> tuple[Watcher, CircusSocket, str]:
    """Build the circus watcher, socket and URI for one regular BentoML
    dependency service, backed by the standard service worker module.
    """
    from bentoml.serving import create_watcher

    worker_count, per_worker_env = scheduler.get_worker_env(svc)
    address, circus_socket = _get_server_socket(svc, uds_path, port_stack, backlog)

    worker_args = [
        "-m",
        _SERVICE_WORKER_SCRIPT,
        bento_identifier,
        "--service-name",
        svc.name,
        "--fd",
        f"$(circus.sockets.{svc.name})",
        "--worker-id",
        "$(CIRCUS.WID)",
    ]
    if per_worker_env:
        worker_args += ["--worker-env", json.dumps(per_worker_env)]

    dep_watcher = create_watcher(
        name=f"service_{svc.name}",
        args=worker_args,
        numprocesses=worker_count,
        working_dir=working_dir,
        env=env,
    )
    return dep_watcher, circus_socket, address
def create_nova_watcher(
    bento_identifier: str,
    svc: ServiceProtocol,
    uds_path: str,
    port_stack: contextlib.ExitStack,
    backlog: int,
    scheduler: ResourceAllocator,
    working_dir: Optional[str] = None,
    env: Optional[Dict[str, str]] = None,
) -> tuple[Watcher, CircusSocket, str]:
    """Build the circus watcher, socket and URI for a Nova service in the
    dependency graph; the worker runs the compoundai Nova worker module.
    """
    from bentoml.serving import create_watcher

    # Reserve a socket/URI for this service, then size its worker pool.
    address, circus_socket = _get_server_socket(svc, uds_path, port_stack, backlog)
    worker_count, per_worker_env = scheduler.get_worker_env(svc)

    # Nova services are launched through our worker module rather than the
    # stock BentoML service worker.
    worker_args = [
        "-m",
        "compoundai.cli.serve_nova",
        bento_identifier,
        "--service-name",
        svc.name,
        "--worker-id",
        "$(CIRCUS.WID)",
    ]
    if per_worker_env:
        worker_args += ["--worker-env", json.dumps(per_worker_env)]

    nova_watcher = create_watcher(
        name=f"nova_service_{svc.name}",
        args=worker_args,
        numprocesses=worker_count,
        working_dir=working_dir,
        env=env,  # the dependency map is injected later by serve_http
    )
    return nova_watcher, circus_socket, address
@inject
def server_on_deployment(
    svc: ServiceProtocol, result_file: str = Provide[BentoMLContainer.result_store_file]
) -> None:
    """Prepare *svc* for serving: resolve its models, run its deployment
    hooks, and remove any stale result file from a previous run.
    """
    # Resolve models before the server starts.
    bento = getattr(svc, "bento", None)
    if bento:
        for model in bento.info.all_models:
            model.to_model().resolve()
    elif hasattr(svc, "models"):
        for model in svc.models:
            model.resolve()

    # Invoke every callable on the inner class flagged as a deployment hook.
    if hasattr(svc, "inner"):
        inner = svc.inner
        for attr_name in dir(inner):
            candidate = getattr(inner, attr_name)
            is_hook = callable(candidate) and getattr(
                candidate, "__bentoml_deployment_hook__", False
            )
            if is_hook:
                candidate()

    # Clear a leftover result file so stale results are never reported.
    if os.path.exists(result_file):
        os.remove(result_file)
@inject(squeeze_none=True)
def serve_http(
    bento_identifier: str | AnyService,
    working_dir: str | None = None,
    host: str = Provide[BentoMLContainer.http.host],
    port: int = Provide[BentoMLContainer.http.port],
    backlog: int = Provide[BentoMLContainer.api_server_config.backlog],
    timeout: int | None = None,
    ssl_certfile: str | None = Provide[BentoMLContainer.ssl.certfile],
    ssl_keyfile: str | None = Provide[BentoMLContainer.ssl.keyfile],
    ssl_keyfile_password: str | None = Provide[BentoMLContainer.ssl.keyfile_password],
    ssl_version: int | None = Provide[BentoMLContainer.ssl.version],
    ssl_cert_reqs: int | None = Provide[BentoMLContainer.ssl.cert_reqs],
    ssl_ca_certs: str | None = Provide[BentoMLContainer.ssl.ca_certs],
    ssl_ciphers: str | None = Provide[BentoMLContainer.ssl.ciphers],
    bentoml_home: str = Provide[BentoMLContainer.bentoml_home],
    development_mode: bool = False,
    reload: bool = False,
    timeout_keep_alive: int | None = None,
    timeout_graceful_shutdown: int | None = None,
    dependency_map: dict[str, str] | None = None,
    service_name: str = "",
    threaded: bool = False,
) -> Server:
    """Serve *bento_identifier* over HTTP under a circus arbiter.

    One watcher is created per dependency service (Nova services run the
    Nova worker module, regular services the stock BentoML worker), plus a
    watcher for the entry service itself.  The resulting runner map is
    injected into every watcher's environment as BENTOML_RUNNER_MAP.
    Returns a handle to the running server.
    """
    from _bentoml_impl.loader import import_service, normalize_identifier
    from _bentoml_impl.server.allocator import ResourceAllocator
    from bentoml._internal.log import SERVER_LOGGING_CONFIG
    from bentoml._internal.utils import reserve_free_port
    from bentoml._internal.utils.analytics.usage_stats import track_serve
    from bentoml._internal.utils.circus import create_standalone_arbiter
    from bentoml.serving import (
        construct_ssl_args,
        construct_timeouts_args,
        create_watcher,
        ensure_prometheus_dir,
        make_reload_plugin,
    )
    from circus.sockets import CircusSocket

    bento_id: str = ""
    env = {"PROMETHEUS_MULTIPROC_DIR": ensure_prometheus_dir()}
    if isinstance(bento_identifier, Service):
        # A live Service instance was passed in-process.
        svc = bento_identifier
        bento_id = svc.import_string
        assert (
            working_dir is None
        ), "working_dir should not be set when passing a service in process"
        # use cwd
        bento_path = pathlib.Path(".")
    else:
        bento_id, bento_path = normalize_identifier(bento_identifier, working_dir)
        svc = import_service(bento_id, bento_path)

    watchers: list[Watcher] = []
    sockets: list[CircusSocket] = []
    allocator = ResourceAllocator()
    if dependency_map is None:
        dependency_map = {}

    # TODO: Only for testing, this will prevent any other dep services from getting started, relying entirely on configured deps in the runner-map
    standalone = False
    if service_name:
        print("Running in standalone mode")
        print(f"service_name: {service_name}")
        standalone = True

    if service_name and service_name != svc.name:
        svc = svc.find_dependent_by_name(service_name)
    num_workers, worker_envs = allocator.get_worker_env(svc)

    server_on_deployment(svc)
    uds_path = tempfile.mkdtemp(prefix="bentoml-uds-")
    try:
        # Dependency watchers are started only when serving the full graph
        # (not standalone, not development mode).
        if not service_name and not development_mode and not standalone:
            with contextlib.ExitStack() as port_stack:
                for name, dep_svc in svc.all_services().items():
                    if name == svc.name:
                        continue
                    if name in dependency_map:
                        continue
                    # Check if this is a Nova service
                    if (
                        hasattr(dep_svc, "is_nova_component")
                        and dep_svc.is_nova_component()
                    ):
                        new_watcher, new_socket, uri = create_nova_watcher(
                            bento_id,
                            dep_svc,
                            uds_path,
                            port_stack,
                            backlog,
                            allocator,
                            str(bento_path.absolute()),
                            env=env,
                        )
                    else:
                        # Regular BentoML service
                        new_watcher, new_socket, uri = create_dependency_watcher(
                            bento_id,
                            dep_svc,
                            uds_path,
                            port_stack,
                            backlog,
                            allocator,
                            str(bento_path.absolute()),
                            env=env,
                        )
                    watchers.append(new_watcher)
                    sockets.append(new_socket)
                    dependency_map[name] = uri
                    server_on_deployment(dep_svc)
                # reserve one more to avoid conflicts
                port_stack.enter_context(reserve_free_port())

        # Validate the host and pick the matching socket family.
        try:
            ipaddr = ipaddress.ip_address(host)
            if ipaddr.version == 4:
                family = socket.AF_INET
            elif ipaddr.version == 6:
                family = socket.AF_INET6
            else:
                raise BentoMLConfigException(
                    f"Unsupported host IP address version: {ipaddr.version}"
                )
        except ValueError as e:
            raise BentoMLConfigException(f"Invalid host IP address: {host}") from e

        # No HTTP listener socket is created for a Nova entry service.
        if not svc.is_nova_component():
            sockets.append(
                CircusSocket(
                    name=API_SERVER_NAME,
                    host=host,
                    port=port,
                    family=family,
                    backlog=backlog,
                )
            )
        if BentoMLContainer.ssl.enabled.get() and not ssl_certfile:
            raise BentoMLConfigException("ssl_certfile is required when ssl is enabled")

        ssl_args = construct_ssl_args(
            ssl_certfile=ssl_certfile,
            ssl_keyfile=ssl_keyfile,
            ssl_keyfile_password=ssl_keyfile_password,
            ssl_version=ssl_version,
            ssl_cert_reqs=ssl_cert_reqs,
            ssl_ca_certs=ssl_ca_certs,
            ssl_ciphers=ssl_ciphers,
        )
        timeouts_args = construct_timeouts_args(
            timeout_keep_alive=timeout_keep_alive,
            timeout_graceful_shutdown=timeout_graceful_shutdown,
        )
        timeout_args = ["--timeout", str(timeout)] if timeout else []

        # NOTE(review): these args use `bento_identifier` rather than the
        # normalized `bento_id`; confirm intent for the in-process Service
        # case, where `bento_identifier` is not a string.
        server_args = [
            "-m",
            _SERVICE_WORKER_SCRIPT,
            bento_identifier,
            "--fd",
            f"$(circus.sockets.{API_SERVER_NAME})",
            "--service-name",
            svc.name,
            "--backlog",
            str(backlog),
            "--worker-id",
            "$(CIRCUS.WID)",
            *ssl_args,
            *timeouts_args,
            *timeout_args,
        ]
        if worker_envs:
            server_args.extend(["--worker-env", json.dumps(worker_envs)])
        if development_mode:
            server_args.append("--development-mode")

        scheme = "https" if BentoMLContainer.ssl.enabled.get() else "http"

        # Check if this is a Nova service
        if hasattr(svc, "is_nova_component") and svc.is_nova_component():
            # Create Nova-specific watcher using existing socket
            args = [
                "-m",
                "compoundai.cli.serve_nova",  # Use our Nova worker module
                bento_identifier,
                "--service-name",
                svc.name,
                "--worker-id",
                "$(CIRCUS.WID)",
            ]
            watcher = create_watcher(
                name=f"nova_service_{svc.name}",
                args=args,
                numprocesses=num_workers,
                working_dir=str(bento_path.absolute()),
                close_child_stdin=not development_mode,
                env=env,  # Dependency map will be injected by serve_http
            )
            watchers.append(watcher)
            print(f"nova_service_{svc.name} entrypoint created")
        else:
            # Create regular BentoML service watcher
            watchers.append(
                create_watcher(
                    name="service",
                    args=server_args,
                    working_dir=str(bento_path.absolute()),
                    numprocesses=num_workers,
                    close_child_stdin=not development_mode,
                    env=env,
                )
            )

        log_host = "localhost" if host in ["0.0.0.0", "::"] else host
        dependency_map[svc.name] = f"{scheme}://{log_host}:{port}"

        # inject runner map now
        inject_env = {"BENTOML_RUNNER_MAP": json.dumps(dependency_map)}
        print(f"inject_env: {inject_env}")
        for watcher in watchers:
            if watcher.env is None:
                watcher.env = inject_env
            else:
                watcher.env.update(inject_env)

        arbiter_kwargs: dict[str, t.Any] = {
            "watchers": watchers,
            "sockets": sockets,
            "threaded": threaded,
        }
        if reload:
            reload_plugin = make_reload_plugin(str(bento_path.absolute()), bentoml_home)
            arbiter_kwargs["plugins"] = [reload_plugin]
        if development_mode:
            arbiter_kwargs["debug"] = True
            arbiter_kwargs["loggerconfig"] = SERVER_LOGGING_CONFIG

        arbiter = create_standalone_arbiter(**arbiter_kwargs)
        arbiter.exit_stack.enter_context(
            track_serve(svc, production=not development_mode)
        )
        # Remove the unix-socket directory when the arbiter shuts down.
        arbiter.exit_stack.callback(shutil.rmtree, uds_path, ignore_errors=True)
        arbiter.start(
            cb=lambda _: logger.info(  # type: ignore
                "Starting Nova Service %s (%s/%s) listening on %s://%s:%d (Press CTRL+C to quit)"
                if (hasattr(svc, "is_nova_component") and svc.is_nova_component())
                else 'Starting production %s BentoServer from "%s" (Press CTRL+C to quit)',
                *(
                    (svc.name, *svc.nova_address(), scheme, log_host, port)
                    if (hasattr(svc, "is_nova_component") and svc.is_nova_component())
                    else (scheme.upper(), bento_identifier)
                ),
            ),
        )
        return Server(url=f"{scheme}://{log_host}:{port}", arbiter=arbiter)
    except Exception:
        # Best-effort cleanup of the socket directory if startup failed.
        shutil.rmtree(uds_path, ignore_errors=True)
        raise
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import logging
import os
import sys
from typing import Optional
from urllib.parse import urlparse
import click
import rich
# Module-level logger for the start-command CLI.
logger = logging.getLogger(__name__)
# Demoted from a stray `print(...)` left over from development: importing a
# library module should not write to stdout.
logger.debug("compoundai start CLI module loaded")
def build_start_command() -> click.Group:
    """Build the ``start`` click group used to launch a single service.

    The group contains one ``start`` command which either boots a legacy
    (pre-1.2) Bento HTTP/runner server via ``bentoml.start``, or delegates to
    ``compoundai.cli.serving.serve_http`` for >=1.2 services.
    """
    from bentoml._internal.utils import add_experimental_docstring
    from bentoml_cli.utils import BentoMLCommandGroup

    @click.group(name="start", cls=BentoMLCommandGroup)
    def cli():
        pass

    @cli.command()
    @click.argument("bento", type=click.STRING, default=".")
    @click.option(
        "--service-name",
        type=click.STRING,
        required=False,
        default="",
        envvar="BENTOML_SERVE_SERVICE_NAME",
        help="specify the runner name to serve",
    )
    @click.option(
        "--depends",
        type=click.STRING,
        multiple=True,
        envvar="BENTOML_SERVE_DEPENDS",
        help="list of runners map",
    )
    @click.option(
        "--runner-map",
        type=click.STRING,
        envvar="BENTOML_SERVE_RUNNER_MAP",
        help="[Deprecated] use --depends instead. "
        "JSON string of runners map. For backward compatibility for yatai < 1.0.0",
    )
    @click.option(
        "--bind",
        type=click.STRING,
        help="[Deprecated] use --host and --port instead. "
        "Bind address for the server. For backward compatibility for yatai < 1.0.0",
        required=False,
    )
    @click.option(
        "--port",
        type=click.INT,
        help="The port to listen on for the REST api server",
        envvar="BENTOML_PORT",
        show_envvar=True,
    )
    @click.option(
        "--host",
        type=click.STRING,
        help="The host to bind for the REST api server [defaults: 127.0.0.1(dev), 0.0.0.0(production)]",
        # Fix: the env var must be registered via `envvar`; `show_envvar` is a
        # boolean flag (the original passed the env var name to show_envvar,
        # so BENTOML_HOST was never actually read).
        envvar="BENTOML_HOST",
        show_envvar=True,
    )
    @click.option(
        "--backlog",
        type=click.INT,
        help="The maximum number of pending connections.",
        show_envvar=True,
    )
    @click.option(
        "--api-workers",
        type=click.INT,
        help="Specify the number of API server workers to start. Default to number of available CPU cores in production mode",
        envvar="BENTOML_API_WORKERS",
    )
    @click.option(
        "--timeout",
        type=click.INT,
        help="Specify the timeout (seconds) for API server",
        envvar="BENTOML_TIMEOUT",
    )
    @click.option(
        "--working-dir",
        type=click.Path(),
        help="When loading from source code, specify the directory to find the Service instance",
        default=None,
        show_default=True,
    )
    @click.option("--ssl-certfile", type=str, help="SSL certificate file")
    @click.option("--ssl-keyfile", type=str, help="SSL key file")
    @click.option("--ssl-keyfile-password", type=str, help="SSL keyfile password")
    @click.option(
        "--ssl-version", type=int, help="SSL version to use (see stdlib 'ssl' module)"
    )
    @click.option(
        "--ssl-cert-reqs",
        type=int,
        help="Whether client certificate is required (see stdlib 'ssl' module)",
    )
    @click.option("--ssl-ca-certs", type=str, help="CA certificates file")
    @click.option(
        "--ssl-ciphers", type=str, help="Ciphers to use (see stdlib 'ssl' module)"
    )
    @click.option(
        "--timeout-keep-alive",
        type=int,
        help="Close Keep-Alive connections if no new data is received within this timeout.",
    )
    @click.option(
        "--timeout-graceful-shutdown",
        type=int,
        default=None,
        help="Maximum number of seconds to wait for graceful shutdown. After this timeout, the server will start terminating requests.",
    )
    @click.option(
        "--reload",
        is_flag=True,
        help="Reload Service when code changes detected",
        default=False,
    )
    @add_experimental_docstring
    def start(
        bento: str,
        service_name: str,
        depends: Optional[list[str]],
        runner_map: Optional[str],
        bind: Optional[str],
        port: Optional[int],
        host: Optional[str],
        backlog: Optional[int],
        working_dir: Optional[str],
        api_workers: Optional[int],
        timeout: Optional[int],
        ssl_certfile: Optional[str],
        ssl_keyfile: Optional[str],
        ssl_keyfile_password: Optional[str],
        ssl_version: Optional[int],
        ssl_cert_reqs: Optional[int],
        ssl_ca_certs: Optional[str],
        ssl_ciphers: Optional[str],
        timeout_keep_alive: Optional[int],
        timeout_graceful_shutdown: Optional[int],
        reload: bool = False,
    ) -> None:
        """
        Start a HTTP API server standalone. This will be used inside Yatai.
        """
        from bentoml import Service
        from bentoml._internal.service.loader import load

        # Default the working directory to the bento path when it is a
        # directory, otherwise the current directory.
        if working_dir is None:
            if os.path.isdir(os.path.expanduser(bento)):
                working_dir = os.path.expanduser(bento)
            else:
                working_dir = "."
        if sys.path[0] != working_dir:
            sys.path.insert(0, working_dir)
        # Build the runner map: --depends entries take precedence, then the
        # deprecated --runner-map JSON string, else empty.
        if depends:
            # Split on the FIRST '=' only so values may themselves contain
            # '=' (e.g. URLs with query strings). The previous maxsplit=2
            # produced 3-element rows and crashed dict() on such values.
            runner_map_dict = dict(s.split("=", maxsplit=1) for s in depends)
        elif runner_map:
            runner_map_dict = json.loads(runner_map)
        else:
            runner_map_dict = {}
        # Legacy --bind tcp://host:port overrides --host/--port.
        if bind is not None:
            parsed = urlparse(bind)
            assert parsed.scheme == "tcp"
            host = parsed.hostname or host
            port = parsed.port or port
        svc = load(bento, working_dir=working_dir)
        if isinstance(svc, Service):
            if reload:
                logger.warning("--reload does not work with legacy style services")
            # for <1.2 bentos
            if not service_name or service_name == svc.name:
                from bentoml.start import start_http_server

                for dep in depends or []:
                    rich.print(f"Using remote: {dep}")
                start_http_server(
                    bento,
                    runner_map=runner_map_dict,
                    working_dir=working_dir,
                    port=port,
                    host=host,
                    backlog=backlog,
                    api_workers=api_workers or 1,
                    timeout=timeout,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile,
                    ssl_keyfile_password=ssl_keyfile_password,
                    ssl_version=ssl_version,
                    ssl_cert_reqs=ssl_cert_reqs,
                    ssl_ca_certs=ssl_ca_certs,
                    ssl_ciphers=ssl_ciphers,
                    timeout_keep_alive=timeout_keep_alive,
                    timeout_graceful_shutdown=timeout_graceful_shutdown,
                )
            else:
                from bentoml.start import start_runner_server

                # NOTE: --bind was already parsed into host/port above; the
                # duplicated re-parse that used to live here was redundant.
                start_runner_server(
                    bento,
                    runner_name=service_name,
                    working_dir=working_dir,
                    timeout=timeout,
                    port=port,
                    host=host,
                    backlog=backlog,
                )
        else:
            # for >=1.2 bentos
            from compoundai.cli.serving import serve_http

            print(f"Starting service {service_name}")
            svc.inject_config()
            serve_http(
                bento,
                working_dir=working_dir,
                port=port,
                host=host,
                backlog=backlog,
                timeout=timeout,
                ssl_keyfile=ssl_keyfile,
                ssl_certfile=ssl_certfile,
                ssl_keyfile_password=ssl_keyfile_password,
                ssl_version=ssl_version,
                ssl_cert_reqs=ssl_cert_reqs,
                ssl_ca_certs=ssl_ca_certs,
                ssl_ciphers=ssl_ciphers,
                timeout_keep_alive=timeout_keep_alive,
                timeout_graceful_shutdown=timeout_graceful_shutdown,
                dependency_map=runner_map_dict,
                service_name=service_name,
                reload=reload,
            )

    return cli
start_command = build_start_command()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def main() -> None:
    """Package entry-point placeholder: print a greeting and return."""
    greeting = "Hello from compoundai!"
    print(greeting)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import typing as t
from functools import wraps
from typing import Any, get_type_hints
import bentoml
from pydantic import BaseModel
class NovaEndpoint:
"""Decorator class for Nova endpoints"""
def __init__(self, func: t.Callable, name: str | None = None):
self.func = func
self.name = name or func.__name__
self.is_nova_endpoint = True
# Extract request type from hints
hints = get_type_hints(func)
args = list(hints.items())
# Skip self/cls argument
if args[0][0] in ("self", "cls"):
args = args[1:]
# Get request type from first arg
self.request_type = args[0][1]
wraps(func)(self)
async def __call__(self, *args: t.Any, **kwargs: t.Any) -> Any:
# Validate request
if len(args) > 1 and issubclass(self.request_type, BaseModel):
args = list(args) # type: ignore
if isinstance(args[1], (str, dict)):
args[1] = self.request_type.parse_obj(args[1]) # type: ignore
# Convert Pydantic model to dict before passing to triton
if len(args) > 1 and isinstance(args[1], BaseModel):
args = list(args) # type: ignore
args[1] = args[1].model_dump() # type: ignore
return await self.func(*args, **kwargs)
def nova_endpoint(name: str | None = None) -> t.Callable[[t.Callable], NovaEndpoint]:
    """Decorator factory for Nova endpoints.

    Args:
        name: Optional name for the endpoint. Defaults to the function name.

    Example:
        @nova_endpoint()
        def my_endpoint(self, input: str) -> str:
            return input

        @nova_endpoint(name="custom_name")
        def another_endpoint(self, input: str) -> str:
            return input
    """

    def wrap(func: t.Callable) -> NovaEndpoint:
        # Wrap the target function in a NovaEndpoint descriptor object.
        return NovaEndpoint(func, name)

    return wrap
def nova_api(func: t.Callable) -> t.Callable:
    """Expose *func* as a regular BentoML API endpoint.

    Args:
        func: The function to expose.

    Returns:
        The function wrapped by :func:`bentoml.api`.
    """
    wrapped = bentoml.api(func)
    return wrapped
def async_onstart(func: t.Callable) -> t.Callable:
    """Register *func* as an async startup hook for the service."""
    # Tag the function so other tooling can recognize it as a startup hook,
    # then hand it to BentoML's on_startup registration.
    func.__bentoml_startup_hook__ = True
    return bentoml.on_startup(func)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Any, Dict, Optional, TypeVar
from _bentoml_sdk.service import Service
from _bentoml_sdk.service.dependency import Dependency
from compoundai.sdk.service import CompoundService
T = TypeVar("T")
class NovaClient:
    """Client for calling Nova endpoints with streaming support"""

    def __init__(self, service: CompoundService[Any]):
        # Service whose Nova endpoints this client fronts.
        self._service = service
        # Mapping of endpoint name -> NovaEndpoint, as registered on the service.
        self._endpoints = service.get_nova_endpoints()
        # Cache of per-endpoint async-generator functions, built lazily in
        # __getattr__ so an endpoint's stream function is created once.
        self._nova_clients: Dict[str, Any] = {}
        # Distributed runtime handle; None until injected externally (see
        # NovaDependency.set_runtime) or created lazily on first stream call.
        self._runtime = None

    def __getattr__(self, name: str) -> Any:
        # Invoked only for attribute names not found normally, i.e. endpoint
        # names. Unknown names surface the available endpoints in the error.
        if name not in self._endpoints:
            raise AttributeError(
                f"No Nova endpoint '{name}' found on service '{self._service.name}'. "
                f"Available endpoints: {list(self._endpoints.keys())}"
            )
        # For streaming endpoints, create/cache the stream function
        if name not in self._nova_clients:
            namespace, component_name = self._service.nova_address()

            # Create async generator function that uses Queue for streaming
            async def get_stream(*args, **kwargs):
                # Items flow from the worker task to the consumer through this
                # queue; None is the end-of-stream sentinel (also pushed on
                # error so the consumer loop always terminates).
                queue: asyncio.Queue = asyncio.Queue()
                if self._runtime is not None:
                    # Use existing runtime if available
                    async def stream_worker():
                        try:
                            client = (
                                await self._runtime.namespace(namespace)
                                .component(component_name)
                                .endpoint(name)
                                .client()
                            )
                            # TODO: Potentially model dump for a user here so they can pass around Pydantic models
                            stream = await client.generate(*args, **kwargs)
                            async for item in stream:
                                data = item.data()
                                print(f"Item data: {data}")
                                await queue.put(data)
                            # Signal normal end-of-stream to the consumer.
                            await queue.put(None)
                        except Exception:
                            # Unblock the consumer before propagating the error.
                            await queue.put(None)
                            raise
                else:
                    # Create nova worker if no runtime
                    from triton_distributed_rs import DistributedRuntime, triton_worker

                    # NOTE(review): triton_worker() presumably supplies the
                    # `runtime` argument when the decorated coroutine is
                    # invoked with no args below — confirm against its docs.
                    @triton_worker()
                    async def stream_worker(runtime: DistributedRuntime):
                        try:
                            # Store runtime for future use
                            self._runtime = runtime
                            client = (
                                await runtime.namespace(namespace)
                                .component(component_name)
                                .endpoint(name)
                                .client()
                            )
                            stream = await client.generate(*args, **kwargs)
                            async for item in stream:
                                data = item.data()
                                print(f"Item data: {data}")
                                await queue.put(data)
                            await queue.put(None)
                        except Exception:
                            await queue.put(None)
                            raise

                # Start worker task with error handling
                worker_task = asyncio.create_task(stream_worker())
                try:
                    # Yield items from queue until None received
                    while True:
                        item = await queue.get()
                        if item is None:
                            break
                        yield item
                finally:
                    # Await the worker so any exception it raised surfaces to
                    # the caller instead of being lost with the task.
                    try:
                        await worker_task
                    except Exception:
                        raise

            self._nova_clients[name] = get_stream
        return self._nova_clients[name]
class NovaDependency(Dependency[T]):
    """Dependency wrapper that is aware of Nova-enabled services.

    For Nova components, resolution yields a NovaClient; for everything else
    it defers to the stock BentoML Dependency behavior.
    """

    def __init__(
        self,
        on: Service[T] | None = None,
        url: str | None = None,
        deployment: str | None = None,
        cluster: str | None = None,
    ):
        super().__init__(on, url=url, deployment=deployment, cluster=cluster)
        # Lazily-constructed NovaClient (only for Nova-enabled targets).
        self._nova_client: Optional[NovaClient] = None
        # Distributed runtime handle, shared with the client once available.
        self._runtime = None

    def set_runtime(self, runtime: Any) -> None:
        """Attach the Nova runtime to this dependency (and its client, if built)."""
        self._runtime = runtime
        client = self._nova_client
        if client:
            client._runtime = runtime

    def get(self, *args: Any, **kwargs: Any) -> T | Any:
        """Resolve the dependency: a NovaClient for Nova components, else BentoML's default."""
        target = self.on
        if not (isinstance(target, CompoundService) and target.is_nova_component()):
            # Plain BentoML service: fall back to normal dependency resolution.
            return super().get(*args, **kwargs)
        if self._nova_client is None:
            self._nova_client = NovaClient(target)
            if self._runtime:
                self._nova_client._runtime = self._runtime
        return self._nova_client
def depends(
    on: Service[T] | None = None,
    *,
    url: str | None = None,
    deployment: str | None = None,
    cluster: str | None = None,
) -> NovaDependency[T]:
    """Create a Nova-aware dependency.

    When *on* is a Nova-enabled service, the returned dependency resolves to a
    client that can call Nova endpoints; otherwise it behaves exactly like a
    normal BentoML dependency.

    Args:
        on: The service to depend on.
        url: URL for a remote service.
        deployment: Deployment name.
        cluster: Cluster name.

    Raises:
        TypeError: if *on* is neither None nor a Service.
        AttributeError: later, when calling a non-existent Nova endpoint.
    """
    if on is None or isinstance(on, Service):
        return NovaDependency(on, url=url, deployment=deployment, cluster=cluster)
    raise TypeError("depends() expects a class decorated with @service()")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# wrapper over bento images to handle TritonDistributed base image
import bentoml

# Default image for Nova services: a BentoML PythonImage layered on the
# TritonDistributed vLLM base image.
NOVA_IMAGE = bentoml.images.PythonImage(base_image="triton-distributed:latest-vllm")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
from _bentoml_sdk import Service, ServiceConfig
from _bentoml_sdk.images import Image
from compoundai.sdk.decorators import NovaEndpoint
T = TypeVar("T", bound=object)
@dataclass
class NovaConfig:
    """Configuration for Nova components"""

    # Whether this service participates as a Nova component (off by default).
    enabled: bool = False
    # Component name; CompoundService falls back to the inner class __name__.
    name: str | None = None
    # Nova namespace; CompoundService falls back to "default".
    namespace: str | None = None
class CompoundService(Service[T]):
    """BentoML Service subclass that adds Nova component capabilities.

    On top of the stock Service behavior this class:
      * carries a NovaConfig describing whether/how the service is exposed as
        a Nova component, and
      * indexes every NovaEndpoint declared on the inner class so endpoints
        can be looked up by name at serve time.
    """

    def __init__(
        self,
        config: ServiceConfig,
        inner: type[T],
        image: Optional[Image] = None,
        envs: Optional[list[dict[str, Any]]] = None,
        nova_config: Optional[NovaConfig] = None,
    ):
        super().__init__(config=config, inner=inner, image=image, envs=envs or [])
        # Normalize the Nova configuration. Build a fresh NovaConfig when the
        # name must be defaulted, instead of assigning into the caller's
        # object (the previous in-place mutation of nova_config.name leaked
        # the default back to whoever owned that NovaConfig instance).
        if nova_config is None:
            self._nova_config = NovaConfig(name=inner.__name__, namespace="default")
        elif nova_config.name is None:
            self._nova_config = NovaConfig(
                enabled=nova_config.enabled,
                name=inner.__name__,
                namespace=nova_config.namespace,
            )
        else:
            self._nova_config = nova_config
        # Register Nova endpoints: dir() walks the MRO, so inherited endpoints
        # are picked up too. Keyed by the endpoint's declared name, which may
        # differ from the attribute name.
        self._nova_endpoints: Dict[str, NovaEndpoint] = {}
        for attr_name in dir(inner):
            value = getattr(inner, attr_name)
            if isinstance(value, NovaEndpoint):
                self._nova_endpoints[value.name] = value

    def is_nova_component(self) -> bool:
        """Check if this service is configured as a Nova component"""
        return self._nova_config.enabled

    def nova_address(self) -> Tuple[str, str]:
        """Return the (namespace, name) Nova address for this component.

        A nova:// entry for this service in the BENTOML_RUNNER_MAP environment
        variable (a JSON mapping of service name -> address) takes precedence;
        otherwise the address from the NovaConfig is used.

        Raises:
            ValueError: if the service is not Nova-enabled, or the runner map
                cannot be parsed.
        """
        if not self.is_nova_component():
            raise ValueError("Service is not configured as a Nova component")
        # Check if we have a runner map with Nova address
        runner_map = os.environ.get("BENTOML_RUNNER_MAP")
        if runner_map:
            try:
                runners = json.loads(runner_map)
                if self.name in runners:
                    address = runners[self.name]
                    if address.startswith("nova://"):
                        # Parse nova://namespace/name into (namespace, name)
                        _, path = address.split("://", 1)
                        namespace, name = path.split("/", 1)
                        print(
                            f"Resolved Nova address from runner map: {namespace}/{name}"
                        )
                        return (namespace, name)
            except (json.JSONDecodeError, ValueError) as e:
                raise ValueError(f"Failed to parse BENTOML_RUNNER_MAP: {str(e)}") from e
        # Fall back to the configured address; __init__ guarantees name is
        # set, but keep the defensive defaults from the original code.
        namespace = self._nova_config.namespace or "default"
        name = self._nova_config.name or self.inner.__name__
        print(f"Using default Nova address: {namespace}/{name}")
        return (namespace, name)

    def get_nova_endpoints(self) -> Dict[str, NovaEndpoint]:
        """Get all registered Nova endpoints"""
        return self._nova_endpoints

    def get_nova_endpoint(self, name: str) -> NovaEndpoint:
        """Get a specific Nova endpoint by name.

        Raises:
            ValueError: when no endpoint with that name is registered.
        """
        if name not in self._nova_endpoints:
            raise ValueError(f"No Nova endpoint found with name: {name}")
        return self._nova_endpoints[name]

    def list_nova_endpoints(self) -> List[str]:
        """List names of all registered Nova endpoints"""
        return list(self._nova_endpoints.keys())
# todo: add another function to bind an instance of the inner to the self within these methods
def service(
    inner: Optional[type[T]] = None,
    /,
    *,
    image: Optional[Image] = None,
    envs: Optional[list[dict[str, Any]]] = None,
    nova: Optional[Union[Dict[str, Any], NovaConfig]] = None,
    **kwargs: Any,
) -> Any:
    """Enhanced service decorator that supports Nova configuration.

    Usable both bare (``@service``) and with arguments (``@service(...)``).

    Args:
        inner: The class being decorated (supplied automatically in bare form).
        image: Optional container image for the service.
        envs: Optional list of environment variable specs.
        nova: Nova configuration, either as a NovaConfig object or dict with keys:
            - enabled: bool (default False, matching NovaConfig.enabled —
              the previous docstring incorrectly claimed True)
            - name: str (default: class name)
            - namespace: str (default: "default")
        **kwargs: Remaining BentoML service configuration (ServiceConfig fields).

    Raises:
        TypeError: if the decorator is applied to an already-decorated class.
    """
    config = kwargs
    # Accept either a ready-made NovaConfig or a plain dict of its fields.
    nova_config: Optional[NovaConfig] = None
    if nova is not None:
        nova_config = NovaConfig(**nova) if isinstance(nova, dict) else nova

    def decorator(inner: type[T]) -> CompoundService[T]:
        if isinstance(inner, Service):
            raise TypeError("service() decorator can only be applied once")
        return CompoundService(
            config=config,
            inner=inner,
            image=image,
            envs=envs or [],
            nova_config=nova_config,
        )

    # Bare usage gets the decorated class directly; parameterized usage
    # returns the decorator for click-style deferred application.
    return decorator(inner) if inner is not None else decorator
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment