Unverified Commit b293b45b authored by mohammedabdulwahhab's avatar mohammedabdulwahhab Committed by GitHub
Browse files

feat: introduce abstract classes to dynamo services (#924)


Signed-off-by: default avatarmohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 568eb100
...@@ -21,8 +21,9 @@ from typing import Any ...@@ -21,8 +21,9 @@ from typing import Any
warnings.filterwarnings("ignore", category=UserWarning, message=".*pkg_resources.*") warnings.filterwarnings("ignore", category=UserWarning, message=".*pkg_resources.*")
# flake8: noqa: E402 # flake8: noqa: E402
from dynamo.sdk.core.decorators.endpoint import api, endpoint from dynamo.sdk.core.decorators.endpoint import abstract_endpoint, api, endpoint
from dynamo.sdk.core.lib import DYNAMO_IMAGE, depends, liveness, readiness, service from dynamo.sdk.core.lib import DYNAMO_IMAGE, depends, liveness, readiness, service
from dynamo.sdk.core.protocol.interface import AbstractService
from dynamo.sdk.lib.decorators import async_on_start, on_shutdown from dynamo.sdk.lib.decorators import async_on_start, on_shutdown
dynamo_context: dict[str, Any] = {} dynamo_context: dict[str, Any] = {}
...@@ -36,6 +37,8 @@ __all__ = [ ...@@ -36,6 +37,8 @@ __all__ = [
"endpoint", "endpoint",
"api", "api",
"service", "service",
"AbstractService",
"abstract_endpoint",
"liveness", "liveness",
"readiness", "readiness",
] ]
...@@ -231,6 +231,10 @@ def serve_dynamo_graph( ...@@ -231,6 +231,10 @@ def serve_dynamo_graph(
for name, dep_svc in svc.all_services().items(): for name, dep_svc in svc.all_services().items():
if name == svc.name or name in dependency_map: if name == svc.name or name in dependency_map:
continue continue
if not dep_svc.is_servable():
raise RuntimeError(
f"Service {dep_svc.name} is not servable. Please use link to override with a concrete implementation."
)
new_watcher, new_socket, uri = create_dynamo_watcher( new_watcher, new_socket, uri = create_dynamo_watcher(
dynamo_pipeline, dynamo_pipeline,
dep_svc, dep_svc,
......
...@@ -14,9 +14,20 @@ ...@@ -14,9 +14,20 @@
# limitations under the License. # limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES # Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import abc
import asyncio import asyncio
import typing as t
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Optional, TypeVar, get_type_hints from typing import (
Any,
Callable,
Dict,
List,
Optional,
Protocol,
TypeVar,
get_type_hints,
)
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
from dynamo.sdk.core.protocol.interface import ( from dynamo.sdk.core.protocol.interface import (
...@@ -28,6 +39,12 @@ from dynamo.sdk.core.protocol.interface import ( ...@@ -28,6 +39,12 @@ from dynamo.sdk.core.protocol.interface import (
T = TypeVar("T") T = TypeVar("T")
class AbstractDynamoEndpoint(Protocol):
"""Protocol for functions that can be marked as abstract dynamo endpoints."""
__is_abstract_dynamo__: bool
class DynamoEndpoint(DynamoEndpointInterface): class DynamoEndpoint(DynamoEndpointInterface):
""" """
Base class for dynamo endpoints Base class for dynamo endpoints
...@@ -72,6 +89,13 @@ class DynamoEndpoint(DynamoEndpointInterface): ...@@ -72,6 +89,13 @@ class DynamoEndpoint(DynamoEndpointInterface):
return self._transports return self._transports
# Decorator for abstract dynamo endpoints
def abstract_endpoint(func: t.Callable) -> t.Callable:
"""Mark an abstract endpoint in an interface."""
func.__is_abstract_dynamo__ = True # type: ignore
return abc.abstractmethod(func)
def endpoint( def endpoint(
name: Optional[str] = None, name: Optional[str] = None,
transports: Optional[List[DynamoTransport]] = None, transports: Optional[List[DynamoTransport]] = None,
......
...@@ -16,15 +16,18 @@ ...@@ -16,15 +16,18 @@
import logging import logging
import os import os
from typing import Any, Callable, Optional, Type, TypeVar from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union
from fastapi import FastAPI from fastapi import FastAPI
from dynamo.sdk.core.protocol.interface import ( from dynamo.sdk.core.protocol.interface import (
AbstractService,
DependencyInterface, DependencyInterface,
DeploymentTarget, DeploymentTarget,
DynamoConfig,
ServiceConfig, ServiceConfig,
ServiceInterface, ServiceInterface,
validate_dynamo_interfaces,
) )
G = TypeVar("G", bound=Callable[..., Any]) G = TypeVar("G", bound=Callable[..., Any])
...@@ -33,6 +36,10 @@ G = TypeVar("G", bound=Callable[..., Any]) ...@@ -33,6 +36,10 @@ G = TypeVar("G", bound=Callable[..., Any])
# this should be set to a concrete implementation of the DeploymentTarget interface # this should be set to a concrete implementation of the DeploymentTarget interface
_target: DeploymentTarget _target: DeploymentTarget
# Add global cache for abstract services
_abstract_service_cache: Dict[Type[AbstractService], ServiceInterface[Any]] = {}
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DYNAMO_IMAGE = os.getenv("DYNAMO_IMAGE", "dynamo:latest-vllm") DYNAMO_IMAGE = os.getenv("DYNAMO_IMAGE", "dynamo:latest-vllm")
...@@ -50,39 +57,93 @@ def get_target() -> DeploymentTarget: ...@@ -50,39 +57,93 @@ def get_target() -> DeploymentTarget:
return _target return _target
# Helper function to get or create service instance for AbstractService
def _get_or_create_abstract_service_instance(
abstract_service_cls: Type[AbstractService],
) -> ServiceInterface[Any]:
"""
Retrieves a service instance from cache or creates a new one
for the given AbstractService class.
"""
global _abstract_service_cache
if abstract_service_cls in _abstract_service_cache:
return _abstract_service_cache[abstract_service_cls]
# This placeholder service will be a singleton, and will be used for all dependencies that depend on this abstract service.
# The name for DynamoConfig will be the class name of the abstract service.
dynamo_config_for_abstract = DynamoConfig(enabled=True)
# Call the main service() decorator/function to create the service instance
# validate_dynamo_interfaces is False because validating an interface has implemented dynamo endpoints will obviously fail
service_instance = service(
abstract_service_cls,
dynamo=dynamo_config_for_abstract,
should_validate_dynamo_interfaces=False,
)
_abstract_service_cache[abstract_service_cls] = service_instance
return service_instance
def service( def service(
inner: Optional[Type[G]] = None, inner: Optional[Type[G]] = None,
/, /,
*, *,
app: Optional[FastAPI] = None, app: Optional[FastAPI] = None,
should_validate_dynamo_interfaces: bool = True,
system_app: Optional[FastAPI] = None, system_app: Optional[FastAPI] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Service decorator that's adapter-agnostic""" """Service decorator that's adapter-agnostic"""
config = ServiceConfig(**kwargs) config = ServiceConfig(**kwargs)
def decorator(inner: Type[G]) -> ServiceInterface[G]: def decorator(inner: Type[G]) -> ServiceInterface[G]:
# Ensures that all declared dynamo endpoints on the parent interfaces are implemented
if should_validate_dynamo_interfaces:
validate_dynamo_interfaces(inner)
provider = get_target() provider = get_target()
if inner is not None: if inner is not None:
config.dynamo.name = inner.__name__ config.dynamo.name = inner.__name__
return provider.create_service( service_instance = provider.create_service(
service_cls=inner, service_cls=inner,
config=config, config=config,
app=app, app=app,
system_app=system_app, system_app=system_app,
**kwargs, **kwargs,
) )
return service_instance
ret = decorator(inner) if inner is not None else decorator ret = decorator(inner) if inner is not None else decorator
return ret return ret
def depends( def depends(
on: Optional[ServiceInterface[G]] = None, **kwargs: Any on: Optional[Union[ServiceInterface[G], Type[AbstractService]]] = None,
**kwargs: Any,
) -> DependencyInterface[G]: ) -> DependencyInterface[G]:
"""Create a dependency using the current service provider""" """Create a dependency using the current service provider.
If 'on' is an AbstractService type, a placeholder service will be
created and used for the dependency.
"""
provider = get_target() provider = get_target()
return provider.create_dependency(on=on, **kwargs) actual_on_service: Optional[ServiceInterface[Any]] = None
if isinstance(on, type) and issubclass(on, AbstractService):
actual_on_service = _get_or_create_abstract_service_instance(on)
# The type of actual_on_service here would be ServiceInterface[NameOfAbstractClass]
# So, T would be NameOfAbstractClass.
return provider.create_dependency(on=actual_on_service, **kwargs)
elif isinstance(on, ServiceInterface):
# This handles both 'on=None' and 'on=SomeServiceInterfaceInstance'
# If 'on' is ServiceInterface[K], T could be K. If 'on' is None, T remains unbound here.
actual_on_service = on
return provider.create_dependency(on=actual_on_service, **kwargs)
else:
raise TypeError(
"depends() expects 'on' to be a ServiceInterface, an AbstractService type"
)
def liveness(func: G) -> G: def liveness(func: G) -> G:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES # Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import abc
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from enum import Enum, auto from enum import Enum, auto
...@@ -27,6 +28,12 @@ from .deployment import Env ...@@ -27,6 +28,12 @@ from .deployment import Env
T = TypeVar("T", bound=object) T = TypeVar("T", bound=object)
class AbstractService(abc.ABC):
"""Base class for Dynamo service interfaces."""
pass
class LeaseConfig(BaseModel): class LeaseConfig(BaseModel):
"""Configuration for custom dynamo leases""" """Configuration for custom dynamo leases"""
...@@ -146,10 +153,59 @@ class ServiceInterface(Generic[T], ABC): ...@@ -146,10 +153,59 @@ class ServiceInterface(Generic[T], ABC):
"""List names of all registered endpoints""" """List names of all registered endpoints"""
pass pass
@abstractmethod
def link(self, next_service: "ServiceInterface") -> "ServiceInterface": def link(self, next_service: "ServiceInterface") -> "ServiceInterface":
"""Link this service to another service, creating a pipeline""" """Link this service to another service, creating a pipeline.
pass
This method allows linking (injecting) a concrete service implementation by checking if there is a dependency that next_service implements/inherits from.
Args:
next_service: The concrete service implementation to link
Returns:
The next_service that was linked to this service
Raises:
ValueError: If no matching interface is found or if multiple matches are found
"""
if not isinstance(next_service, ServiceInterface):
raise ValueError(f"link must be passed a Service, got {type(next_service)}")
# Get all the deps of the service
inner_deps = [
(dep.on.inner, dep_key, dep)
for dep_key, dep in self.dependencies.items()
if dep.on is not None
] # type: ignore
# Get the inner class of the passed in service
curr_inner = next_service.inner
# Find deps that next_service implements/inherits from
matching_deps = []
for dep_inner, dep_key, original_dep in inner_deps:
if issubclass(curr_inner, dep_inner):
matching_deps.append((dep_inner, dep_key, original_dep))
if not matching_deps:
raise ValueError(
f"{curr_inner.__name__} does not fulfill any dependencies required by {self.name}"
)
if len(matching_deps) > 1:
dep_names = [dep_key for _, _, dep_key in matching_deps]
raise ValueError(
f"{curr_inner.__name__} fulfills multiple dependencies required by {self.name}: {dep_names}"
)
# Get the matching interface, dep_key, and original dependency
_, _, matching_dep = matching_deps[0]
# Let's hot swap the on of the existing dependency with the new service
matching_dep.on = next_service
# Record the link
LinkedServices.add((self, next_service))
return next_service
@abstractmethod @abstractmethod
def remove_unused_edges(self, used_edges: Set["ServiceInterface"]) -> None: def remove_unused_edges(self, used_edges: Set["ServiceInterface"]) -> None:
...@@ -194,6 +250,93 @@ class ServiceInterface(Generic[T], ABC): ...@@ -194,6 +250,93 @@ class ServiceInterface(Generic[T], ABC):
def dynamo_address(self) -> tuple[str, str]: def dynamo_address(self) -> tuple[str, str]:
raise NotImplementedError() raise NotImplementedError()
def is_servable(self) -> bool:
"""Check if this service is ready to be served.
A service is servable if:
1. It is not a subclass of AbstractService (concrete service)
2. If it is a subclass of AbstractService, all abstract methods are implemented
with @dynamo_endpoint decorators
"""
# If not a AbstractService, it's servable by default
if not issubclass(self.inner, AbstractService):
return True
# For AbstractService, check implementations
abstract_endpoints = _get_abstract_dynamo_endpoints(self.inner)
if (
not abstract_endpoints
): # No abstract endpoints to implement, so it's servable
return True
return all(
_check_dynamo_endpoint_implemented(self.inner, name)
for name in abstract_endpoints
) # type: ignore[return-value]
def _get_abstract_dynamo_endpoints(cls: type) -> Set[str]:
"""Get all abstract endpoint names from the class's MRO."""
return {
name
for base in cls.mro()
for name, val in base.__dict__.items()
if getattr(val, "__is_abstract_dynamo__", False)
}
def _check_dynamo_endpoint_implemented(cls: type, name: str) -> bool:
"""Check if an endpoint is properly implemented."""
impl = getattr(cls, name, None)
# Ensure the implementation is a callable DynamoEndpointInterface
return (
impl is not None
and callable(impl)
and isinstance(impl, DynamoEndpointInterface)
)
def validate_dynamo_interfaces(cls: type) -> None:
"""
Validate that *cls* fully implements every @abstract_endpoint
declared in its ancestors and that each implementation is
decorated with @dynamo_endpoint.
"""
required = _get_abstract_dynamo_endpoints(cls)
missing: List[str] = []
undecorated: List[str] = []
not_callable: List[Tuple[str, str]] = []
for name in required:
impl = getattr(cls, name, None)
if impl is None:
missing.append(name)
continue
if not callable(impl):
not_callable.append((name, type(impl).__name__))
continue
if not isinstance(impl, DynamoEndpointInterface):
undecorated.append(name)
problems = []
if missing:
problems.append(f"missing implementation(s): {', '.join(missing)}")
if undecorated:
problems.append(
f"method(s) not decorated with @endpoint: {', '.join(undecorated)}"
)
if not_callable:
problems.append(
", ".join(f"{n} must be callable, got {kind}" for n, kind in not_callable)
)
if problems:
raise TypeError(
f"{cls.__name__} violates Dynamo interface — " + "; ".join(problems)
)
class DeploymentTarget(ABC): class DeploymentTarget(ABC):
"""Interface for service provider implementations""" """Interface for service provider implementations"""
...@@ -226,6 +369,12 @@ class DependencyInterface(Generic[T], ABC): ...@@ -226,6 +369,12 @@ class DependencyInterface(Generic[T], ABC):
"""Get the service this dependency is on""" """Get the service this dependency is on"""
pass pass
@on.setter
@abstractmethod
def on(self, value: Optional[ServiceInterface[T]]) -> None:
"""Set the service this dependency is on"""
pass
@abstractmethod @abstractmethod
def get(self, *args: Any, **kwargs: Any) -> Any: def get(self, *args: Any, **kwargs: Any) -> Any:
"""Get the dependency client""" """Get the dependency client"""
......
...@@ -34,7 +34,6 @@ from dynamo.sdk.core.protocol.interface import ( ...@@ -34,7 +34,6 @@ from dynamo.sdk.core.protocol.interface import (
DeploymentTarget, DeploymentTarget,
DynamoEndpointInterface, DynamoEndpointInterface,
DynamoTransport, DynamoTransport,
LinkedServices,
ServiceConfig, ServiceConfig,
ServiceInterface, ServiceInterface,
) )
...@@ -131,8 +130,8 @@ class LocalService(ServiceMixin, ServiceInterface[T]): ...@@ -131,8 +130,8 @@ class LocalService(ServiceMixin, ServiceInterface[T]):
return list(self._endpoints.keys()) return list(self._endpoints.keys())
def link(self, next_service: "ServiceInterface") -> "ServiceInterface": def link(self, next_service: "ServiceInterface") -> "ServiceInterface":
LinkedServices.add((self, next_service)) # Call the base implementation which handles AbstractService dependencies
return next_service return super().link(next_service)
def remove_unused_edges(self, used_edges: Set["ServiceInterface"]) -> None: def remove_unused_edges(self, used_edges: Set["ServiceInterface"]) -> None:
current_deps = dict(self._dependencies) current_deps = dict(self._dependencies)
...@@ -167,6 +166,10 @@ class LocalDependency(DependencyInterface[T]): ...@@ -167,6 +166,10 @@ class LocalDependency(DependencyInterface[T]):
def on(self) -> Optional[ServiceInterface[T]]: def on(self) -> Optional[ServiceInterface[T]]:
return self._on_service return self._on_service
@on.setter
def on(self, value: Optional[ServiceInterface[T]]) -> None:
self._on_service = value
def get(self, *args: Any, **kwargs: Any) -> Any: def get(self, *args: Any, **kwargs: Any) -> Any:
# Return a client that can communicate with the service # Return a client that can communicate with the service
# through the circus socket # through the circus socket
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dynamo.sdk import (
DYNAMO_IMAGE,
AbstractService,
abstract_endpoint,
api,
depends,
endpoint,
service,
)
logger = logging.getLogger(__name__)
class ChatRequest(BaseModel):
text: str
"""
Pipeline Architecture:
Users/Clients (HTTP)
┌─────────────┐
│ Frontend │ HTTP API endpoint (/v1/chat/completions)
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Router │ Routes requests to appropriate worker
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Worker │ Generates text using LLM
└─────────────┘
"""
class WorkerInterface(AbstractService):
"""Interface for LLM workers."""
@abstract_endpoint # enforces that the service implements the method, but also that it is properly decorated
async def generate(self, request: ChatRequest):
pass
class RouterInterface(AbstractService):
"""Interface for request routers."""
@abstract_endpoint
async def generate(self, request: ChatRequest):
pass
@service(
dynamo={"namespace": "llm-hello-world"},
image=DYNAMO_IMAGE,
)
class VllmWorker(WorkerInterface):
@endpoint()
async def generate(self, request: ChatRequest):
# Convert to Spongebob case (randomly capitalize letters)
for token in request.text.split():
spongebob_token = "".join(
c.upper() if random.random() < 0.5 else c.lower() for c in token
)
yield spongebob_token
@service(
dynamo={"namespace": "llm-hello-world"},
image=DYNAMO_IMAGE,
)
class TRTLLMWorker(WorkerInterface):
@endpoint()
async def generate(self, request: ChatRequest):
# Convert to SHOUTING case
for token in request.text.split():
yield token.upper()
@service(
dynamo={"namespace": "llm-hello-world"},
image=DYNAMO_IMAGE,
)
class SlowRouter(RouterInterface):
worker = depends(WorkerInterface) # Will be overridden by link()
@endpoint()
async def generate(self, request: ChatRequest):
print("Routing slow")
async for response in self.worker.generate(request.model_dump_json()):
await asyncio.sleep(1) # Simulate slow routing with a 1-second delay
yield response
@service(
dynamo={"namespace": "llm-hello-world"},
image=DYNAMO_IMAGE,
)
class FastRouter(RouterInterface):
worker = depends(WorkerInterface) # Will be overridden by link()
@endpoint()
async def generate(self, request: ChatRequest):
print("Routing fast")
async for response in self.worker.generate(request.model_dump_json()):
await asyncio.sleep(0.1) # Simulate fast routing with a 0.1-second delay
yield response
app = FastAPI()
@service(
dynamo={"namespace": "llm-hello-world"},
image=DYNAMO_IMAGE,
app=app,
)
class Frontend:
router = depends(RouterInterface) # Will be overridden by link()
@api()
async def generate(self, request: ChatRequest):
print(f"Received request: {request}")
async def content_generator():
async for response in self.router.generate(request.model_dump_json()):
print(f"Received response: {response}")
# Format as SSE
yield f"data: {response}\n\n"
return StreamingResponse(
content_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# Mix and match pipelines (Tests)
# Frontend.link(SlowRouter).link(TRTLLMWorker) # type: ignore[attr-defined]
# slow_pipeline = Frontend.link(SlowRouter).link(VllmWorker) # type: ignore[attr-defined]
Frontend.link(FastRouter).link(VllmWorker) # type: ignore[attr-defined]
"""
Example usage:
fast_pipeline = Frontend.link(FastRouter).link(TRTLLMWorker)
# slow_pipeline = Frontend.link(SlowRouter).link(VllmWorker)
# mixed_pipeline = Frontend.link(FastRouter).link(VllmWorker)
# Basic setup with VLLM worker and slow router
The interface-based design allows for:
1. Easy swapping of implementations (VLLM vs TRT-LLM)
2. Different routing strategies (slow vs fast)
3. Type safety through interface contracts
"""
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
class ChatRequest(BaseModel):
text: str
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment