Unverified Commit b293b45b authored by mohammedabdulwahhab's avatar mohammedabdulwahhab Committed by GitHub
Browse files

feat: introduce abstract classes to dynamo services (#924)


Signed-off-by: mohammedabdulwahhab <furkhan324@berkeley.edu>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent 568eb100
......@@ -21,8 +21,9 @@ from typing import Any
warnings.filterwarnings("ignore", category=UserWarning, message=".*pkg_resources.*")
# flake8: noqa: E402
from dynamo.sdk.core.decorators.endpoint import api, endpoint
from dynamo.sdk.core.decorators.endpoint import abstract_endpoint, api, endpoint
from dynamo.sdk.core.lib import DYNAMO_IMAGE, depends, liveness, readiness, service
from dynamo.sdk.core.protocol.interface import AbstractService
from dynamo.sdk.lib.decorators import async_on_start, on_shutdown
dynamo_context: dict[str, Any] = {}
......@@ -36,6 +37,8 @@ __all__ = [
"endpoint",
"api",
"service",
"AbstractService",
"abstract_endpoint",
"liveness",
"readiness",
]
......@@ -231,6 +231,10 @@ def serve_dynamo_graph(
for name, dep_svc in svc.all_services().items():
if name == svc.name or name in dependency_map:
continue
if not dep_svc.is_servable():
raise RuntimeError(
f"Service {dep_svc.name} is not servable. Please use link to override with a concrete implementation."
)
new_watcher, new_socket, uri = create_dynamo_watcher(
dynamo_pipeline,
dep_svc,
......
......@@ -14,9 +14,20 @@
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import abc
import asyncio
import typing as t
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, TypeVar, get_type_hints
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Protocol,
TypeVar,
get_type_hints,
)
from dynamo.runtime import DistributedRuntime
from dynamo.sdk.core.protocol.interface import (
......@@ -28,6 +39,12 @@ from dynamo.sdk.core.protocol.interface import (
T = TypeVar("T")
class AbstractDynamoEndpoint(Protocol):
    """Structural type for callables flagged as abstract dynamo endpoints."""

    # Marker flag; the abstract_endpoint decorator sets it to True on the
    # function it wraps.
    __is_abstract_dynamo__: bool
class DynamoEndpoint(DynamoEndpointInterface):
"""
Base class for dynamo endpoints
......@@ -72,6 +89,13 @@ class DynamoEndpoint(DynamoEndpointInterface):
return self._transports
# Decorator for abstract dynamo endpoints
def abstract_endpoint(func: t.Callable) -> t.Callable:
    """Flag *func* as an abstract dynamo endpoint declared on an interface.

    The function is tagged with ``__is_abstract_dynamo__ = True`` and then
    passed through ``abc.abstractmethod``; the result of that call is returned.
    """
    setattr(func, "__is_abstract_dynamo__", True)
    marked = abc.abstractmethod(func)
    return marked
def endpoint(
name: Optional[str] = None,
transports: Optional[List[DynamoTransport]] = None,
......
......@@ -16,15 +16,18 @@
import logging
import os
from typing import Any, Callable, Optional, Type, TypeVar
from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union
from fastapi import FastAPI
from dynamo.sdk.core.protocol.interface import (
AbstractService,
DependencyInterface,
DeploymentTarget,
DynamoConfig,
ServiceConfig,
ServiceInterface,
validate_dynamo_interfaces,
)
G = TypeVar("G", bound=Callable[..., Any])
......@@ -33,6 +36,10 @@ G = TypeVar("G", bound=Callable[..., Any])
# this should be set to a concrete implementation of the DeploymentTarget interface
_target: DeploymentTarget
# Module-level cache: maps each AbstractService subclass to the placeholder
# service created for it, so repeated lookups reuse the same instance.
_abstract_service_cache: Dict[Type[AbstractService], ServiceInterface[Any]] = {}
logger = logging.getLogger(__name__)
DYNAMO_IMAGE = os.getenv("DYNAMO_IMAGE", "dynamo:latest-vllm")
......@@ -50,39 +57,93 @@ def get_target() -> DeploymentTarget:
return _target
# Helper function to get or create service instance for AbstractService
def _get_or_create_abstract_service_instance(
    abstract_service_cls: Type[AbstractService],
) -> ServiceInterface[Any]:
    """Return the placeholder service for *abstract_service_cls*, building it on first use.

    The placeholder is stored in the module-level cache, so every dependency
    declared on the same abstract service shares one instance.
    """
    try:
        return _abstract_service_cache[abstract_service_cls]
    except KeyError:
        pass  # first request for this interface; build the placeholder below

    # Interface validation is switched off: an interface by definition has not
    # implemented its own abstract endpoints, so validating it would fail.
    placeholder = service(
        abstract_service_cls,
        dynamo=DynamoConfig(enabled=True),
        should_validate_dynamo_interfaces=False,
    )
    _abstract_service_cache[abstract_service_cls] = placeholder
    return placeholder
def service(
    inner: Optional[Type[G]] = None,
    /,
    *,
    app: Optional[FastAPI] = None,
    should_validate_dynamo_interfaces: bool = True,
    system_app: Optional[FastAPI] = None,
    **kwargs: Any,
) -> Any:
    """Service decorator that's adapter-agnostic.

    Usable bare (``@service``), with arguments (``@service(...)``), or as a
    plain call ``service(cls, ...)``.

    Args:
        inner: The class to wrap. When omitted, a decorator is returned.
        app: Optional FastAPI application forwarded to the deployment target.
        should_validate_dynamo_interfaces: When True, the class is checked
            with ``validate_dynamo_interfaces`` before the service is created.
        system_app: Optional FastAPI application forwarded to the deployment
            target.
        **kwargs: Used to build the ``ServiceConfig`` and also forwarded to
            ``create_service``.

    Returns:
        The service created by the current deployment target, or a decorator
        producing one when ``inner`` is not given.
    """
    config = ServiceConfig(**kwargs)

    def decorator(inner: Type[G]) -> ServiceInterface[G]:
        # Ensures that all declared dynamo endpoints on the parent interfaces
        # are implemented.
        if should_validate_dynamo_interfaces:
            validate_dynamo_interfaces(inner)

        provider = get_target()
        if inner is not None:
            config.dynamo.name = inner.__name__

        service_instance = provider.create_service(
            service_cls=inner,
            config=config,
            app=app,
            system_app=system_app,
            **kwargs,
        )
        return service_instance

    return decorator(inner) if inner is not None else decorator
def depends(
    on: Optional[Union[ServiceInterface[G], Type[AbstractService]]] = None,
    **kwargs: Any,
) -> DependencyInterface[G]:
    """Create a dependency using the current service provider.

    Args:
        on: What the dependency points at. May be a service instance, an
            ``AbstractService`` subclass (a shared placeholder service is
            created or reused for it), or ``None``.
        **kwargs: Forwarded unchanged to the provider's ``create_dependency``.

    Returns:
        The dependency object built by the current deployment target.

    Raises:
        TypeError: If ``on`` is none of the accepted kinds.
    """
    provider = get_target()

    # ``None`` is the declared default; pass it straight through instead of
    # rejecting it in the type check below.
    if on is None:
        return provider.create_dependency(on=None, **kwargs)

    if isinstance(on, type) and issubclass(on, AbstractService):
        # Depend on the shared placeholder for this interface.
        placeholder = _get_or_create_abstract_service_instance(on)
        return provider.create_dependency(on=placeholder, **kwargs)

    if isinstance(on, ServiceInterface):
        return provider.create_dependency(on=on, **kwargs)

    raise TypeError(
        "depends() expects 'on' to be a ServiceInterface, an AbstractService type, or None"
    )
def liveness(func: G) -> G:
......
......@@ -14,6 +14,7 @@
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import abc
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum, auto
......@@ -27,6 +28,12 @@ from .deployment import Env
T = TypeVar("T", bound=object)
class AbstractService(abc.ABC):
    """Common ancestor for Dynamo service interface declarations."""
class LeaseConfig(BaseModel):
"""Configuration for custom dynamo leases"""
......@@ -146,10 +153,59 @@ class ServiceInterface(Generic[T], ABC):
"""List names of all registered endpoints"""
pass
@abstractmethod
def link(self, next_service: "ServiceInterface") -> "ServiceInterface":
    """Link this service to another service, creating a pipeline.

    Injects a concrete service implementation by finding the single
    dependency of this service whose target class ``next_service`` implements
    or inherits from, and pointing that dependency at ``next_service``.

    Args:
        next_service: The concrete service implementation to link.

    Returns:
        The ``next_service`` that was linked to this service.

    Raises:
        ValueError: If ``next_service`` is not a service, if it fulfills no
            dependency of this service, or if it fulfills more than one.
    """
    if not isinstance(next_service, ServiceInterface):
        raise ValueError(f"link must be passed a Service, got {type(next_service)}")

    # Inner class of the service being linked in.
    curr_inner = next_service.inner

    # Collect (key, dependency) pairs whose target class next_service satisfies.
    matching_deps = []
    for dep_key, dep in self.dependencies.items():  # type: ignore
        if dep.on is not None and issubclass(curr_inner, dep.on.inner):
            matching_deps.append((dep_key, dep))

    if not matching_deps:
        raise ValueError(
            f"{curr_inner.__name__} does not fulfill any dependencies required by {self.name}"
        )

    if len(matching_deps) > 1:
        # Report the dependency keys, not the dependency objects.
        dep_names = [dep_key for dep_key, _ in matching_deps]
        raise ValueError(
            f"{curr_inner.__name__} fulfills multiple dependencies required by {self.name}: {dep_names}"
        )

    _, matching_dep = matching_deps[0]

    # Hot swap the target of the existing dependency with the new service.
    matching_dep.on = next_service

    # Record the link.
    LinkedServices.add((self, next_service))

    return next_service
@abstractmethod
def remove_unused_edges(self, used_edges: Set["ServiceInterface"]) -> None:
......@@ -194,6 +250,93 @@ class ServiceInterface(Generic[T], ABC):
def dynamo_address(self) -> tuple[str, str]:
    """Return a pair of strings addressing this service.

    This base implementation always raises ``NotImplementedError``.
    NOTE(review): presumably concrete services override this; confirm the
    meaning of the two strings against an implementation.
    """
    raise NotImplementedError()
def is_servable(self) -> bool:
    """Report whether this service can be served.

    A service whose inner class does not derive from ``AbstractService`` is
    always servable. Otherwise it is servable only when every abstract
    endpoint found on the inner class's MRO passes
    ``_check_dynamo_endpoint_implemented`` (vacuously true when there are
    none).
    """
    inner_cls = self.inner

    # Classes outside the AbstractService hierarchy need no further checks.
    if not issubclass(inner_cls, AbstractService):
        return True

    for endpoint_name in _get_abstract_dynamo_endpoints(inner_cls):
        if not _check_dynamo_endpoint_implemented(inner_cls, endpoint_name):
            return False
    return True
def _get_abstract_dynamo_endpoints(cls: type) -> Set[str]:
    """Collect the names of attributes flagged as abstract endpoints anywhere in *cls*'s MRO."""
    found: Set[str] = set()
    for klass in cls.mro():
        for attr_name, attr_value in klass.__dict__.items():
            # A truthy __is_abstract_dynamo__ attribute is the marker.
            if getattr(attr_value, "__is_abstract_dynamo__", False):
                found.add(attr_name)
    return found
def _check_dynamo_endpoint_implemented(cls: type, name: str) -> bool:
    """Tell whether attribute *name* on *cls* is a callable ``DynamoEndpointInterface``."""
    candidate = getattr(cls, name, None)
    # Missing or non-callable attributes can never count as an implementation.
    if candidate is None or not callable(candidate):
        return False
    return isinstance(candidate, DynamoEndpointInterface)
def validate_dynamo_interfaces(cls: type) -> None:
    """Validate that *cls* implements every ``@abstract_endpoint`` it inherits.

    Each abstract endpoint name found on the MRO of *cls* must resolve to an
    attribute that exists, is callable, and is a ``DynamoEndpointInterface``
    (i.e. was decorated with ``@endpoint``).

    Args:
        cls: The class to check.

    Raises:
        TypeError: If any endpoint is missing, not callable, or undecorated.
            All problems are reported in a single message.
    """
    missing: List[str] = []
    undecorated: List[str] = []
    not_callable: List[Tuple[str, str]] = []

    # Sort the names: the helper returns a set, whose iteration order can vary
    # between runs, which would make the error message nondeterministic.
    for name in sorted(_get_abstract_dynamo_endpoints(cls)):
        impl = getattr(cls, name, None)
        if impl is None:
            missing.append(name)
        elif not callable(impl):
            not_callable.append((name, type(impl).__name__))
        elif not isinstance(impl, DynamoEndpointInterface):
            undecorated.append(name)

    problems = []
    if missing:
        problems.append(f"missing implementation(s): {', '.join(missing)}")
    if undecorated:
        problems.append(
            f"method(s) not decorated with @endpoint: {', '.join(undecorated)}"
        )
    if not_callable:
        problems.append(
            ", ".join(f"{n} must be callable, got {kind}" for n, kind in not_callable)
        )

    if problems:
        raise TypeError(
            f"{cls.__name__} violates Dynamo interface — " + "; ".join(problems)
        )
class DeploymentTarget(ABC):
"""Interface for service provider implementations"""
......@@ -226,6 +369,12 @@ class DependencyInterface(Generic[T], ABC):
"""Get the service this dependency is on"""
pass
@on.setter
@abstractmethod
def on(self, value: Optional[ServiceInterface[T]]) -> None:
    """Point this dependency at *value*; the body here is empty and the setter is declared abstract."""
    ...
@abstractmethod
def get(self, *args: Any, **kwargs: Any) -> Any:
"""Get the dependency client"""
......
......@@ -34,7 +34,6 @@ from dynamo.sdk.core.protocol.interface import (
DeploymentTarget,
DynamoEndpointInterface,
DynamoTransport,
LinkedServices,
ServiceConfig,
ServiceInterface,
)
......@@ -131,8 +130,8 @@ class LocalService(ServiceMixin, ServiceInterface[T]):
return list(self._endpoints.keys())
def link(self, next_service: "ServiceInterface") -> "ServiceInterface":
    """Link *next_service* to this service and return it.

    Delegates entirely to the base-class ``link``, which handles
    AbstractService dependencies.
    """
    return super().link(next_service)
def remove_unused_edges(self, used_edges: Set["ServiceInterface"]) -> None:
current_deps = dict(self._dependencies)
......@@ -167,6 +166,10 @@ class LocalDependency(DependencyInterface[T]):
def on(self) -> Optional[ServiceInterface[T]]:
    """Return the service currently stored in ``_on_service`` (may be None)."""
    return self._on_service
@on.setter
def on(self, value: Optional[ServiceInterface[T]]) -> None:
    """Store *value* in ``_on_service`` as the service this dependency targets."""
    self._on_service = value
def get(self, *args: Any, **kwargs: Any) -> Any:
# Return a client that can communicate with the service
# through the circus socket
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import random
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from dynamo.sdk import (
DYNAMO_IMAGE,
AbstractService,
abstract_endpoint,
api,
depends,
endpoint,
service,
)
logger = logging.getLogger(__name__)
# Request payload shared by every endpoint in this example.
class ChatRequest(BaseModel):
    # Raw input text; the example workers split it on whitespace.
    text: str
"""
Pipeline Architecture:
Users/Clients (HTTP)
┌─────────────┐
│ Frontend │ HTTP API endpoint (/v1/chat/completions)
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Router │ Routes requests to appropriate worker
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Worker │ Generates text using LLM
└─────────────┘
"""
class WorkerInterface(AbstractService):
    """Contract that LLM worker services implement."""

    # @abstract_endpoint enforces that the service implements the method, and
    # also that the implementation is properly decorated.
    @abstract_endpoint
    async def generate(self, request: ChatRequest):
        ...
class RouterInterface(AbstractService):
    """Contract that request-router services implement."""

    @abstract_endpoint
    async def generate(self, request: ChatRequest):
        ...
@service(
    dynamo={"namespace": "llm-hello-world"},
    image=DYNAMO_IMAGE,
)
class VllmWorker(WorkerInterface):
    @endpoint()
    async def generate(self, request: ChatRequest):
        # Emit each whitespace-separated word in "Spongebob case": every
        # character is upper- or lower-cased on an independent coin flip.
        for word in request.text.split():
            mixed_chars = []
            for ch in word:
                mixed_chars.append(ch.upper() if random.random() < 0.5 else ch.lower())
            yield "".join(mixed_chars)
@service(
    dynamo={"namespace": "llm-hello-world"},
    image=DYNAMO_IMAGE,
)
class TRTLLMWorker(WorkerInterface):
    @endpoint()
    async def generate(self, request: ChatRequest):
        # Emit each whitespace-separated word in SHOUTING case.
        words = request.text.split()
        for word in words:
            shouted = word.upper()
            yield shouted
@service(
    dynamo={"namespace": "llm-hello-world"},
    image=DYNAMO_IMAGE,
)
class SlowRouter(RouterInterface):
    # Declared against the interface; link() substitutes a concrete worker.
    worker = depends(WorkerInterface)

    @endpoint()
    async def generate(self, request: ChatRequest):
        print("Routing slow")
        stream = self.worker.generate(request.model_dump_json())
        async for chunk in stream:
            # Simulate slow routing: wait one second before relaying each chunk.
            await asyncio.sleep(1)
            yield chunk
@service(
    dynamo={"namespace": "llm-hello-world"},
    image=DYNAMO_IMAGE,
)
class FastRouter(RouterInterface):
    # Declared against the interface; link() substitutes a concrete worker.
    worker = depends(WorkerInterface)

    @endpoint()
    async def generate(self, request: ChatRequest):
        print("Routing fast")
        stream = self.worker.generate(request.model_dump_json())
        async for chunk in stream:
            # Simulate fast routing: wait 0.1 seconds before relaying each chunk.
            await asyncio.sleep(0.1)
            yield chunk
app = FastAPI()


@service(
    dynamo={"namespace": "llm-hello-world"},
    image=DYNAMO_IMAGE,
    app=app,
)
class Frontend:
    # Declared against the interface; link() substitutes a concrete router.
    router = depends(RouterInterface)

    @api()
    async def generate(self, request: ChatRequest):
        print(f"Received request: {request}")

        async def sse_events():
            # Relay each router response as one Server-Sent Events frame.
            async for chunk in self.router.generate(request.model_dump_json()):
                print(f"Received response: {chunk}")
                yield f"data: {chunk}\n\n"

        sse_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
        return StreamingResponse(
            sse_events(),
            media_type="text/event-stream",
            headers=sse_headers,
        )
# Mix and match pipelines (tests). Exactly one wiring is active; each link()
# call replaces an interface dependency with the concrete service passed in.
# Alternatives kept for reference:
# Frontend.link(SlowRouter).link(TRTLLMWorker) # type: ignore[attr-defined]
# slow_pipeline = Frontend.link(SlowRouter).link(VllmWorker) # type: ignore[attr-defined]
Frontend.link(FastRouter).link(VllmWorker) # type: ignore[attr-defined]
"""
Example usage:
# Fast router in front of the TRT-LLM worker
fast_pipeline = Frontend.link(FastRouter).link(TRTLLMWorker)
# Basic setup with VLLM worker and slow router
# slow_pipeline = Frontend.link(SlowRouter).link(VllmWorker)
# Fast router in front of the VLLM worker
# mixed_pipeline = Frontend.link(FastRouter).link(VllmWorker)
The interface-based design allows for:
1. Easy swapping of implementations (VLLM vs TRT-LLM)
2. Different routing strategies (slow vs fast)
3. Type safety through interface contracts
"""
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
# Request model with a single text field.
class ChatRequest(BaseModel):
    # Raw input text.
    text: str
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment