Unverified Commit 1eab75d2 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix: bugfix - dynamo serve merge issue and service config fixes (#1036)


Co-authored-by: default avatarGraham King <grahamk@nvidia.com>
Co-authored-by: default avatarhongkuan <hongkuanz@nvidia.com>
Co-authored-by: default avatarUbuntu <ubuntu@crusoe-prod--inst-2wjuoekvfq72mlpdrcugujrtgfp.us-east1-a.compute.internal>
parent f42a09af
......@@ -116,8 +116,8 @@ class Bento(BaseBento):
)
# build_config.envs.extend(svc.envs)
# build_config.labels.update(svc.labels)
# if svc.image is not None:
# image = Image(base_image=svc.image)
if svc.image is not None:
image = Image(base_image=svc.image)
if not disable_image:
image = populate_image_from_build_config(image, build_config, build_ctx)
build_config = build_config.with_defaults()
......@@ -215,7 +215,6 @@ class Bento(BaseBento):
f.write(get_default_svc_readme(svc, version))
else:
f.write(build_config.description)
if image is None:
bento_info = BentoInfo(
tag=tag,
......@@ -248,10 +247,11 @@ class Bento(BaseBento):
schema=svc.schema() if not is_legacy else {},
)
else:
svc = svc.get_bentoml_service()
services = [
BentoServiceInfo.from_service(s) for s in svc.all_services().values()
BentoServiceInfo.from_service(s.get_bentoml_service())
for s in svc.all_services().values()
]
svc = svc.get_bentoml_service()
bento_info = BentoInfoV2(
tag=tag,
service=svc, # type: ignore # attrs converters do not typecheck
......
......@@ -189,21 +189,9 @@ def main(
component = runtime.namespace(namespace).component(component_name)
try:
# if a custom lease is specified we need to create the service with that lease
lease = None
if service._dynamo_config.custom_lease:
lease = await component.create_service_with_custom_lease(
ttl=service._dynamo_config.custom_lease.ttl
)
lease_id = lease.id()
dynamo_context["lease"] = lease
logger.info(
f"Created {service.name} component with custom lease id {lease_id}"
)
else:
# Create service first
await component.create_service()
logger.info(f"Created {service.name} component")
# Create service first
await component.create_service()
logger.info(f"Created {service.name} component")
# Set runtime on all dependencies
for dep in service.dependencies.values():
......
......@@ -367,5 +367,5 @@ def configure_target_environment(target: TargetEnum):
target = LocalDeploymentTarget()
else:
raise ValueError(f"Invalid target: {target}")
logger.info(f"Setting deployment target to {target}")
logger.debug(f"Setting deployment target to {target}")
set_target(target)
......@@ -105,7 +105,7 @@ class ServiceInterface(Generic[T], ABC):
"""Remove unused dependencies"""
pass
# @abstractmethod
@abstractmethod
def inject_config(self) -> None:
"""Inject configuration from environment into service configs"""
pass
......@@ -117,7 +117,7 @@ class ServiceInterface(Generic[T], ABC):
return {}
# @property
# @abstractmethod
@abstractmethod
def get_service_configs(self) -> Dict[str, ServiceConfig]:
"""Get all services"""
return {}
......@@ -159,24 +159,23 @@ class LeaseConfig:
ttl: int = 1 # seconds
class ComponentType(str, Enum):
"""Types of Dynamo components"""
PLANNER = "planner"
@dataclass
class DynamoConfig:
"""Configuration for Dynamo components"""
def __init__(
self,
enabled: bool = False,
name: Optional[str] = None,
namespace: Optional[str] = None,
custom_lease: Optional[LeaseConfig] = None,
**kwargs,
):
self.enabled = enabled
self.name = name
self.namespace = namespace
self.custom_lease = custom_lease
# Store any additional configuration options
for key, value in kwargs.items():
setattr(self, key, value)
enabled: bool = True
name: str | None = None
namespace: str | None = None
custom_lease: LeaseConfig | None = None
component_type: ComponentType | None = (
None # Indicates if this is a meta/system component
)
class DeploymentTarget(ABC):
......
......@@ -15,6 +15,7 @@
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Set, Type, TypeVar
from _bentoml_sdk import Service as BentoService
......@@ -32,6 +33,7 @@ from dynamo.sdk.core.protocol.interface import (
ServiceConfig,
ServiceInterface,
)
from dynamo.sdk.core.runner.common import ServiceMixin
T = TypeVar("T", bound=object)
......@@ -61,35 +63,75 @@ class BentoEndpoint(DynamoEndpoint):
return self._transports
class BentoMLService(ServiceInterface[T]):
class BentoServiceAdapter(ServiceMixin, ServiceInterface[T]):
"""BentoML adapter implementing the ServiceInterface"""
def __init__(
self,
bentoml_service: BentoService,
service_cls: Type[T],
config: ServiceConfig,
dynamo_config: Optional[DynamoConfig] = None,
app: Optional[FastAPI] = None,
**kwargs,
):
self._bentoml_service = bentoml_service
name = bentoml_service.inner.__name__
name = service_cls.__name__
self._dynamo_config = dynamo_config or DynamoConfig(
name=name, namespace="default"
)
image = kwargs.get("image")
envs = kwargs.get("envs", [])
self.image = image
# Get service args from environment if available
service_args = self._get_service_args(name)
if service_args:
# Update config with service args
for key, value in service_args.items():
if key not in config:
config[key] = value
# Extract and apply specific args if needed
if "workers" in service_args:
config["workers"] = service_args["workers"]
if "envs" in service_args and envs:
envs.extend(service_args["envs"])
elif "envs" in service_args:
envs = service_args["envs"]
# Initialize BentoML service
self._bentoml_service = BentoService(
config=config,
inner=service_cls,
image=image,
envs=envs or [],
)
self._endpoints: Dict[str, BentoEndpoint] = {}
if not app:
self.app = FastAPI(title=name)
else:
self.app = app
self._dependencies: Dict[str, "DependencyInterface"] = {}
self._bentoml_service.config["dynamo"] = asdict(self._dynamo_config)
self._api_endpoints: list[str] = []
# Map BentoML endpoints to our generic interface
for field_name in dir(bentoml_service.inner):
field = getattr(bentoml_service.inner, field_name)
for field_name in dir(service_cls):
field = getattr(service_cls, field_name)
if isinstance(field, DynamoEndpoint):
self._endpoints[field.name] = BentoEndpoint(
field, field.name, field.transports
)
if DynamoTransport.HTTP in field.transports:
# Ensure endpoint path starts with '/'
path = (
field.name if field.name.startswith("/") else f"/{field.name}"
)
self._api_endpoints.append(path)
if isinstance(field, DependencyInterface):
self._dependencies[field_name] = field
# If any API endpoints exist, mark service as HTTP-exposed and list endpoints
if self._api_endpoints:
self._bentoml_service.config["http_exposed"] = True
self._bentoml_service.config["api_endpoints"] = self._api_endpoints.copy()
@property
def dependencies(self) -> dict[str, "DependencyInterface"]:
......@@ -137,7 +179,6 @@ class BentoMLService(ServiceInterface[T]):
instance = self.inner()
return instance
# TODO: add attribution to bentoml
def find_dependent_by_name(self, name: str) -> "ServiceInterface":
"""Find dynamo service by name"""
return self.all_services()[name]
......@@ -159,7 +200,7 @@ class BentoMLDependency(DependencyInterface[T]):
def __init__(
self,
bentoml_dependency: BentoDependency,
on_service: Optional[BentoMLService[T]] = None,
on_service: Optional[BentoServiceAdapter[T]] = None,
):
self._bentoml_dependency = bentoml_dependency
self._on_service = on_service
......@@ -214,20 +255,15 @@ class BentoDeploymentTarget(DeploymentTarget):
app: Optional[FastAPI] = None,
**kwargs,
) -> ServiceInterface[T]:
# Create BentoML service
image = kwargs.get("image")
envs = kwargs.get("envs", [])
bentoml_service = BentoService(
"""Create a BentoServiceAdapter with the given parameters"""
return BentoServiceAdapter(
service_cls=service_cls,
config=config,
inner=service_cls,
image=image,
envs=envs,
dynamo_config=dynamo_config,
app=app,
**kwargs,
)
# Wrap in our adapter
return BentoMLService(bentoml_service, dynamo_config, app)
def create_dependency(
self, on: Optional[ServiceInterface[T]] = None, **kwargs
) -> DependencyInterface[T]:
......@@ -237,7 +273,7 @@ class BentoDeploymentTarget(DeploymentTarget):
# Get the underlying BentoML service if available
bentoml_service = None
if on is not None and isinstance(on, BentoMLService):
if on is not None and isinstance(on, BentoServiceAdapter):
# this is underlying bentoml service
bentoml_service = on.get_bentoml_service()
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import json
import logging
import os
from typing import Any, ClassVar, Dict, Optional
logger = logging.getLogger(__name__)
class ServiceMixin:
"""Mixin for Dynamo services to inject configuration from environment."""
# Class variable to store service configurations
_global_service_configs: ClassVar[Dict[str, Dict[str, Any]]] = {}
def all_services(self) -> Dict[str, Any]:
"""Return all services in the dependency chain."""
raise NotImplementedError("")
def inject_config(self) -> None:
"""Inject configuration from environment into service configs.
This reads from DYNAMO_SERVICE_CONFIG environment variable and merges
the configuration with any existing service config.
"""
# Get service configs from environment
service_config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
if not service_config_str:
logger.debug("No DYNAMO_SERVICE_CONFIG found in environment")
return
try:
service_configs = json.loads(service_config_str)
logger.debug(f"Loaded service configs: {service_configs}")
except json.JSONDecodeError as e:
logger.error(f"Failed to parse DYNAMO_SERVICE_CONFIG: {e}")
return
cls = self.__class__
# Store the entire config at class level
if not hasattr(cls, "_global_service_configs"):
setattr(cls, "_global_service_configs", {})
cls._global_service_configs = service_configs
# Process ServiceArgs for all services
all_services = self.all_services()
logger.debug(f"Processing configs for services: {list(all_services.keys())}")
for name, svc in all_services.items():
if name in service_configs:
svc_config = service_configs[name]
# Extract ServiceArgs if present
if "ServiceArgs" in svc_config:
logger.debug(
f"Found ServiceArgs for {name}: {svc_config['ServiceArgs']}"
)
if not hasattr(svc, "_service_args"):
object.__setattr__(svc, "_service_args", {})
svc._service_args = svc_config["ServiceArgs"]
else:
logger.debug(f"No ServiceArgs found for {name}")
# Set default config
if not hasattr(svc, "_service_args"):
object.__setattr__(svc, "_service_args", {"workers": 1})
def get_service_configs(self) -> Dict[str, Dict[str, Any]]:
"""Get the service configurations for resource allocation.
Returns:
Dict mapping service names to their configs
"""
# Get all services in the dependency chain
all_services = self.all_services()
result = {}
# If we have global configs, use them to build service configs
cls = self.__class__
if hasattr(cls, "_global_service_configs"):
for name, svc in all_services.items():
# Start with default config
config = {"workers": 1}
# If service has specific args, use them
if hasattr(svc, "_service_args"):
config.update(svc._service_args)
# If there are global configs for this service, get ServiceArgs
if name in cls._global_service_configs:
svc_config = cls._global_service_configs[name]
if "ServiceArgs" in svc_config:
config.update(svc_config["ServiceArgs"])
result[name] = config
logger.debug(f"Built config for {name}: {config}")
return result
def _remove_service_args(self, service_name: str):
"""Remove ServiceArgs from the environment config after using them, preserving envs"""
logger.debug(f"Removing service args for {service_name}")
config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
if config_str:
config = json.loads(config_str)
if service_name in config and "ServiceArgs" in config[service_name]:
# Save envs to separate env var before removing ServiceArgs
service_args = config[service_name]["ServiceArgs"]
if "envs" in service_args:
service_envs = os.environ.get("DYNAMO_SERVICE_ENVS", "{}")
envs_config = json.loads(service_envs)
if service_name not in envs_config:
envs_config[service_name] = {}
envs_config[service_name]["ServiceArgs"] = {
"envs": service_args["envs"]
}
os.environ["DYNAMO_SERVICE_ENVS"] = json.dumps(envs_config)
def _get_service_args(self, service_name: str) -> Optional[dict]:
"""Get ServiceArgs from environment config if specified"""
config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
if config_str:
config = json.loads(config_str)
service_config = config.get(service_name, {})
return service_config.get("ServiceArgs")
return None
......@@ -38,6 +38,7 @@ from dynamo.sdk.core.protocol.interface import (
ServiceConfig,
ServiceInterface,
)
from dynamo.sdk.core.runner.common import ServiceMixin
logger = logging.getLogger(__name__)
......@@ -61,7 +62,7 @@ class LocalEndpoint(DynamoEndpoint):
return self._name
class LocalService(ServiceInterface[T]):
class LocalService(ServiceMixin, ServiceInterface[T]):
"""Circus implementation of the ServiceInterface"""
def __init__(
......
......@@ -438,7 +438,7 @@ async def worker(runtime: DistributedRuntime):
# 3. Attach request handler
#
await endpoint.serve_endpoint(RequestHandler(engine).generate, None)
await endpoint.serve_endpoint(RequestHandler(engine).generate)
class RequestHandler:
......
......@@ -14,7 +14,6 @@
# limitations under the License.
import logging
import os
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
......@@ -60,6 +59,8 @@ class ResponseType(BaseModel):
dynamo={
"namespace": "inference",
},
resource={"cpu": 1, "memory": "500Mi"},
workers=2,
image=DYNAMO_IMAGE,
)
class Backend:
......@@ -76,7 +77,7 @@ class Backend:
logger.info(f"Backend received: {req_text}")
text = f"{req_text}-{self.message}"
for token in text.split():
yield f"[process_id:{os.getpid()}] Backend: {token}"
yield f"Backend: {token}"
@service(
......@@ -101,7 +102,7 @@ class Middle:
next_request = RequestType(text=text).model_dump_json()
async for response in self.backend.generate(next_request):
logger.info(f"Middle received response: {response}")
yield f"[process_id:{os.getpid()}] Middle: {response}"
yield f"Middle: {response}"
@service(
......@@ -132,6 +133,6 @@ class Frontend:
async def content_generator():
async for response in self.middle.generate(request.model_dump_json()):
yield f"[process_id:{os.getpid()}] Frontend: {response}"
yield f"Frontend: {response}"
return StreamingResponse(content_generator())
......@@ -134,7 +134,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(RequestHandler(engine_client).generate, None)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
def cmd_line_args():
......
......@@ -154,7 +154,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(
RequestHandler(engine_client, default_sampling_params).generate, None
RequestHandler(engine_client, default_sampling_params).generate
)
......
......@@ -108,7 +108,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(RequestHandler(engine_client).generate, None)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
def cmd_line_args():
......
......@@ -125,7 +125,7 @@ async def init(runtime: DistributedRuntime, config: Config):
# the server will gracefully shutdown (i.e., keep opened TCP streams finishes)
# after the lease is revoked
await endpoint.serve_endpoint(RequestHandler(engine_client).generate, None)
await endpoint.serve_endpoint(RequestHandler(engine_client).generate)
def cmd_line_args():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment