Unverified Commit 6d46288c authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: rename dynamo decorator (#1133)

parent b520bf44
...@@ -37,7 +37,7 @@ The code for the pipeline looks like this: ...@@ -37,7 +37,7 @@ The code for the pipeline looks like this:
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_endpoint, service, dynamo_api from dynamo.sdk import DYNAMO_IMAGE, depends, endpoint, service, api
class RequestType(BaseModel): class RequestType(BaseModel):
...@@ -52,7 +52,7 @@ class ResponseType(BaseModel): ...@@ -52,7 +52,7 @@ class ResponseType(BaseModel):
dynamo={"namespace": "inference"}, dynamo={"namespace": "inference"},
) )
class Backend: class Backend:
@dynamo_endpoint() @endpoint()
async def generate(self, req: RequestType): async def generate(self, req: RequestType):
text = f"{req.text}-back" text = f"{req.text}-back"
for token in text.split(): for token in text.split():
...@@ -65,7 +65,7 @@ class Backend: ...@@ -65,7 +65,7 @@ class Backend:
class Middle: class Middle:
backend = depends(Backend) backend = depends(Backend)
@dynamo_endpoint() @endpoint()
async def generate(self, req: RequestType): async def generate(self, req: RequestType):
text = f"{req.text}-mid" text = f"{req.text}-mid"
next_request = RequestType(text=text).model_dump_json() next_request = RequestType(text=text).model_dump_json()
...@@ -83,7 +83,7 @@ app = FastAPI(title="Hello World!") ...@@ -83,7 +83,7 @@ app = FastAPI(title="Hello World!")
class Frontend: class Frontend:
middle = depends(Middle) middle = depends(Middle)
@dynamo_api() @api()
async def generate(self, request: RequestType): async def generate(self, request: RequestType):
async def content_generator(): async def content_generator():
async for response in self.middle.generate(request.model_dump_json()): async for response in self.middle.generate(request.model_dump_json()):
......
...@@ -74,7 +74,7 @@ class ServiceA: ...@@ -74,7 +74,7 @@ class ServiceA:
await self.engine.shutdown() await self.engine.shutdown()
print("ServiceA engine shut down") print("ServiceA engine shut down")
@dynamo_endpoint() @endpoint()
async def generate(self, request: ChatCompletionRequest): async def generate(self, request: ChatCompletionRequest):
# Call dependent service # Call dependent service
processed_request = await self.service_b.preprocess(request) processed_request = await self.service_b.preprocess(request)
...@@ -89,8 +89,8 @@ Dynamo follows a class-based architecture similar to BentoML making it intuitive ...@@ -89,8 +89,8 @@ Dynamo follows a class-based architecture similar to BentoML making it intuitive
1. Class attributes for dependencies using `depends()` 1. Class attributes for dependencies using `depends()`
2. An `__init__` method for standard initialization 2. An `__init__` method for standard initialization
3. Optional lifecycle hooks like `@async_on_start` and `@async_on_shutdown` 3. Optional lifecycle hooks like `@async_on_start` and `@async_on_shutdown`
4. Endpoints defined with `@dynamo_endpoint()`. Optionally, an endpoint can be given a name 4. Endpoints defined with `@endpoint()`. Optionally, an endpoint can be given a name
via `@dynamo_endpoint("my_endpoint_name")`, but otherwise will default to the name of the via `@endpoint("my_endpoint_name")`, but otherwise will default to the name of the
function being decorated if omitted. function being decorated if omitted.
This approach provides a clean separation of concerns and makes the service structure easy to understand. This approach provides a clean separation of concerns and makes the service structure easy to understand.
......
...@@ -17,7 +17,7 @@ from typing import Any ...@@ -17,7 +17,7 @@ from typing import Any
from bentoml import on_shutdown as async_on_shutdown from bentoml import on_shutdown as async_on_shutdown
from dynamo.sdk.core.decorators.endpoint import dynamo_api, dynamo_endpoint from dynamo.sdk.core.decorators.endpoint import api, endpoint
from dynamo.sdk.core.lib import DYNAMO_IMAGE, depends, service from dynamo.sdk.core.lib import DYNAMO_IMAGE, depends, service
from dynamo.sdk.lib.decorators import async_on_start from dynamo.sdk.lib.decorators import async_on_start
...@@ -29,7 +29,7 @@ __all__ = [ ...@@ -29,7 +29,7 @@ __all__ = [
"async_on_start", "async_on_start",
"depends", "depends",
"dynamo_context", "dynamo_context",
"dynamo_endpoint", "endpoint",
"dynamo_api", "api",
"service", "service",
] ]
...@@ -214,7 +214,7 @@ def main( ...@@ -214,7 +214,7 @@ def main(
for name, endpoint in dynamo_endpoints.items(): for name, endpoint in dynamo_endpoints.items():
bound_method = endpoint.func.__get__(class_instance) bound_method = endpoint.func.__get__(class_instance)
# Only pass request type for now, use Any for response # Only pass request type for now, use Any for response
# TODO: Handle a dynamo_endpoint not having types # TODO: Handle an endpoint not having types
# TODO: Handle multiple endpoints in a single component # TODO: Handle multiple endpoints in a single component
dynamo_wrapped_method = dynamo_endpoint(endpoint.request_type, Any)( dynamo_wrapped_method = dynamo_endpoint(endpoint.request_type, Any)(
bound_method bound_method
......
...@@ -31,7 +31,7 @@ T = TypeVar("T") ...@@ -31,7 +31,7 @@ T = TypeVar("T")
class DynamoEndpoint(DynamoEndpointInterface): class DynamoEndpoint(DynamoEndpointInterface):
""" """
Base class for dynamo endpoints Base class for dynamo endpoints
Dynamo endpoints are methods decorated with @dynamo_endpoint. Dynamo endpoints are methods decorated with @endpoint.
""" """
def __init__( def __init__(
...@@ -72,7 +72,7 @@ class DynamoEndpoint(DynamoEndpointInterface): ...@@ -72,7 +72,7 @@ class DynamoEndpoint(DynamoEndpointInterface):
return self._transports return self._transports
def dynamo_endpoint( def endpoint(
name: Optional[str] = None, name: Optional[str] = None,
transports: Optional[List[DynamoTransport]] = None, transports: Optional[List[DynamoTransport]] = None,
**kwargs, **kwargs,
...@@ -85,7 +85,7 @@ def dynamo_endpoint( ...@@ -85,7 +85,7 @@ def dynamo_endpoint(
return decorator return decorator
def dynamo_api( def api(
name: Optional[str] = None, name: Optional[str] = None,
**kwargs, **kwargs,
) -> Callable[[Callable], DynamoEndpoint]: ) -> Callable[[Callable], DynamoEndpoint]:
......
...@@ -63,7 +63,7 @@ class DynamoEndpoint: ...@@ -63,7 +63,7 @@ class DynamoEndpoint:
return await self.func(*args, **kwargs) return await self.func(*args, **kwargs)
def dynamo_endpoint( def endpoint(
name: str | None = None, name: str | None = None,
is_api: bool = False, is_api: bool = False,
) -> t.Callable[[t.Callable], DynamoEndpoint]: ) -> t.Callable[[t.Callable], DynamoEndpoint]:
...@@ -74,11 +74,11 @@ def dynamo_endpoint( ...@@ -74,11 +74,11 @@ def dynamo_endpoint(
is_api: Whether to expose the endpoint as an API. Defaults to False. is_api: Whether to expose the endpoint as an API. Defaults to False.
Example: Example:
@dynamo_endpoint() @endpoint()
def my_endpoint(self, input: str) -> str: def my_endpoint(self, input: str) -> str:
return input return input
@dynamo_endpoint(name="custom_name") @endpoint(name="custom_name")
def another_endpoint(self, input: str) -> str: def another_endpoint(self, input: str) -> str:
return input return input
""" """
......
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.sdk import depends, dynamo_endpoint, service from dynamo.sdk import depends, endpoint, service
from dynamo.sdk.core.protocol.interface import DynamoTransport from dynamo.sdk.core.protocol.interface import DynamoTransport
""" """
...@@ -68,7 +68,7 @@ class Backend: ...@@ -68,7 +68,7 @@ class Backend:
def __init__(self) -> None: def __init__(self) -> None:
print("Starting backend") print("Starting backend")
@dynamo_endpoint() @endpoint()
async def generate(self, req: RequestType): async def generate(self, req: RequestType):
"""Generate tokens.""" """Generate tokens."""
req_text = req.text req_text = req.text
...@@ -77,7 +77,7 @@ class Backend: ...@@ -77,7 +77,7 @@ class Backend:
for token in text.split(): for token in text.split():
yield f"Backend: {token}" yield f"Backend: {token}"
@dynamo_endpoint() @endpoint()
async def generate_v2(self, req: RequestType): async def generate_v2(self, req: RequestType):
"""Generate tokens.""" """Generate tokens."""
req_text = req.text req_text = req.text
...@@ -98,7 +98,7 @@ class Backend2: ...@@ -98,7 +98,7 @@ class Backend2:
def __init__(self) -> None: def __init__(self) -> None:
print("Starting backend2") print("Starting backend2")
@dynamo_endpoint() @endpoint()
async def generate(self, req: RequestType): async def generate(self, req: RequestType):
"""Forward requests to backend.""" """Forward requests to backend."""
...@@ -121,7 +121,7 @@ class Middle: ...@@ -121,7 +121,7 @@ class Middle:
def __init__(self) -> None: def __init__(self) -> None:
print("Starting middle") print("Starting middle")
@dynamo_endpoint() @endpoint()
async def generate(self, req: RequestType): async def generate(self, req: RequestType):
"""Forward requests to backend.""" """Forward requests to backend."""
req_text = req.text req_text = req.text
...@@ -155,7 +155,7 @@ class Frontend: ...@@ -155,7 +155,7 @@ class Frontend:
def __init__(self) -> None: def __init__(self) -> None:
print("Starting frontend") print("Starting frontend")
@dynamo_endpoint(transports=[DynamoTransport.HTTP]) @endpoint(transports=[DynamoTransport.HTTP])
async def generate(self, request: RequestType): async def generate(self, request: RequestType):
"""Stream results from the pipeline.""" """Stream results from the pipeline."""
print(f"Frontend received: {request.text}") print(f"Frontend received: {request.text}")
......
...@@ -16,7 +16,7 @@ see the [Dynamo Serve Guide](../docs/guides/dynamo_serve.md). ...@@ -16,7 +16,7 @@ see the [Dynamo Serve Guide](../docs/guides/dynamo_serve.md).
When deploying a python-based worker with `dynamo serve` or `dynamo deploy`, it is When deploying a python-based worker with `dynamo serve` or `dynamo deploy`, it is
a Python class based definition that requires a few key decorators to get going: a Python class based definition that requires a few key decorators to get going:
- `@service`: used to define a worker class - `@service`: used to define a worker class
- `@dynamo_endpoint`: marks methods that can be called by other workers or clients - `@endpoint`: marks methods that can be called by other workers or clients
For more detailed information on these concepts, see the For more detailed information on these concepts, see the
[Dynamo SDK Docs](../deploy/sdk/docs/sdk/README.md). [Dynamo SDK Docs](../deploy/sdk/docs/sdk/README.md).
...@@ -26,7 +26,7 @@ For more detailed information on these concepts, see the ...@@ -26,7 +26,7 @@ For more detailed information on these concepts, see the
Here is the rough outline of what a worker may look like in its simplest form: Here is the rough outline of what a worker may look like in its simplest form:
```python ```python
from dynamo.sdk import dynamo_endpoint, service from dynamo.sdk import endpoint, service
@service( @service(
dynamo={ dynamo={
...@@ -37,7 +37,7 @@ class YourWorker: ...@@ -37,7 +37,7 @@ class YourWorker:
# Worker implementation # Worker implementation
# ... # ...
@dynamo_endpoint() @endpoint()
async def your_endpoint(self, request: RequestType) -> AsyncIterator[ResponseType]: async def your_endpoint(self, request: RequestType) -> AsyncIterator[ResponseType]:
# Endpoint Implementation # Endpoint Implementation
pass pass
...@@ -48,7 +48,7 @@ When addressing this worker's endpoint with the `namespace/component/endpoint` s ...@@ -48,7 +48,7 @@ When addressing this worker's endpoint with the `namespace/component/endpoint` s
based on the definitions above, it would be: `your_namespace/YourWorker/your_endpoint`: based on the definitions above, it would be: `your_namespace/YourWorker/your_endpoint`:
- `namespace="your_namespace"`: Defined in the `@service` decorator - `namespace="your_namespace"`: Defined in the `@service` decorator
- `component="YourWorker"`: Defined by the Python Class name - `component="YourWorker"`: Defined by the Python Class name
- `endpoint="your_endpoint"`: Defined by the `@dynamo_endpoint` decorator, or by default the name of the function being decorated. - `endpoint="your_endpoint"`: Defined by the `@endpoint` decorator, or by default the name of the function being decorated.
For more details about service configuration, resource management, and dynamo endpoints, For more details about service configuration, resource management, and dynamo endpoints,
see the [Dynamo SDK Docs](../deploy/sdk/docs/README.md). see the [Dynamo SDK Docs](../deploy/sdk/docs/README.md).
...@@ -79,7 +79,7 @@ Chat Completions objects, such as the ones specified in the OpenAI API. For exam ...@@ -79,7 +79,7 @@ Chat Completions objects, such as the ones specified in the OpenAI API. For exam
from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.protocol import ChatCompletionRequest
class YourLLMWorker: class YourLLMWorker:
@dynamo_endpoint(name="my_chat_completions_endpoint") @endpoint(name="my_chat_completions_endpoint")
async def generate(self, request: ChatCompletionRequest): async def generate(self, request: ChatCompletionRequest):
# Endpoint Implementation # Endpoint Implementation
pass pass
...@@ -95,7 +95,7 @@ via custom RequestType/ResponseType definitions: ...@@ -95,7 +95,7 @@ via custom RequestType/ResponseType definitions:
# This can be run standalone with `dynamo serve basic_worker:YourWorker` # This can be run standalone with `dynamo serve basic_worker:YourWorker`
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.sdk import dynamo_endpoint, service from dynamo.sdk import endpoint, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -114,7 +114,7 @@ class YourWorker: ...@@ -114,7 +114,7 @@ class YourWorker:
def __init__(self) -> None: def __init__(self) -> None:
logger.info("Starting worker...") logger.info("Starting worker...")
@dynamo_endpoint() @endpoint()
async def generate(self, request: RequestType): async def generate(self, request: RequestType):
"""Generate tokens and stream them back""" """Generate tokens and stream them back"""
logger.info(f"Worker endpoint received: {request.text}") logger.info(f"Worker endpoint received: {request.text}")
...@@ -204,7 +204,7 @@ import random ...@@ -204,7 +204,7 @@ import random
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.llm import KvMetricsPublisher from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import dynamo_endpoint, service, dynamo_context from dynamo.sdk import endpoint, service, dynamo_context
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -269,7 +269,7 @@ class YourWorker: ...@@ -269,7 +269,7 @@ class YourWorker:
self.gpu_prefix_cache_hit_rate, self.gpu_prefix_cache_hit_rate,
) )
@dynamo_endpoint() @endpoint()
async def generate(self, request: RequestType): async def generate(self, request: RequestType):
"""Generate tokens, update KV Cache metrics, and stream the tokens back""" """Generate tokens, update KV Cache metrics, and stream the tokens back"""
# Increment the number of active requests on receiving one # Increment the number of active requests on receiving one
...@@ -384,7 +384,7 @@ class Router: ...@@ -384,7 +384,7 @@ class Router:
return best_worker_id return best_worker_id
@dynamo_endpoint() @endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]: async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
try: try:
# lora_id is a placeholder for lora support, but not used in this example # lora_id is a placeholder for lora support, but not used in this example
...@@ -563,7 +563,7 @@ class DecodeWorker: ...@@ -563,7 +563,7 @@ class DecodeWorker:
.endpoint("generate") .endpoint("generate")
.client() .client()
@dynamo_endpoint() @endpoint()
async def generate(self, request): async def generate(self, request):
if self.do_remote_prefill: if self.do_remote_prefill:
# Forward the request to the prefill worker # Forward the request to the prefill worker
...@@ -580,7 +580,7 @@ class PrefillWorker: ...@@ -580,7 +580,7 @@ class PrefillWorker:
def __init__(self): def __init__(self):
# ... # ...
@dynamo_endpoint() @endpoint()
async def generate(self, request): async def generate(self, request):
# ... framework-specific prefill logic ... # ... framework-specific prefill logic ...
``` ```
...@@ -612,7 +612,7 @@ For more information on Disaggregated Serving, see the ...@@ -612,7 +612,7 @@ For more information on Disaggregated Serving, see the
2. **Async Operations**: Use async/await for I/O operations: 2. **Async Operations**: Use async/await for I/O operations:
```python ```python
@dynamo_endpoint() @endpoint()
async def generate(self, request): async def generate(self, request):
# Use async operations for better performance # Use async operations for better performance
result = await self.some_async_operation() result = await self.some_async_operation()
......
...@@ -17,7 +17,7 @@ For example, the deployment configuration `examples/llm/configs/disagg.yaml` hav ...@@ -17,7 +17,7 @@ For example, the deployment configuration `examples/llm/configs/disagg.yaml` hav
- `Processor`: When a new request arrives, `Processor` applies the chat template and performs the tokenization. Then, it routes the request to the `VllmWorker`. - `Processor`: When a new request arrives, `Processor` applies the chat template and performs the tokenization. Then, it routes the request to the `VllmWorker`.
- `VllmWorker` and `PrefillWorker`: Perform the actual decode and prefill computation. - `VllmWorker` and `PrefillWorker`: Perform the actual decode and prefill computation.
Since the four workers are deployed in different processes, each of them has its own `DistributedRuntime`. Within their own `DistributedRuntime`, they all have their own `Namespace`s named `dynamo`. Then, under their own `dynamo` namespace, they have their own `Component`s named `Frontend/Processor/VllmWorker/PrefillWorker`. Lastly, for the `Endpoint`, `Frontend` has no `Endpoints`, `Processor` and `VllmWorker` each have a `generate` endpoint, and `PrefillWorker` has a placeholder `mock` endpoint. Their `DistributedRuntime`s and `Namespace`s are set in the `@service` decorators in `examples/llm/components/<frontend/processor/worker/prefill_worker>.py`. Their `Component`s are set by their name in `/deploy/dynamo/sdk/src/dynamo/sdk/cli/serve_dynamo.py`. Their `Endpoint`s are set by the `@dynamo_endpoint` decorators in `examples/llm/components/<frontend/processor/worker/prefill_worker>.py`. Since the four workers are deployed in different processes, each of them has its own `DistributedRuntime`. Within their own `DistributedRuntime`, they all have their own `Namespace`s named `dynamo`. Then, under their own `dynamo` namespace, they have their own `Component`s named `Frontend/Processor/VllmWorker/PrefillWorker`. Lastly, for the `Endpoint`, `Frontend` has no `Endpoints`, `Processor` and `VllmWorker` each have a `generate` endpoint, and `PrefillWorker` has a placeholder `mock` endpoint. Their `DistributedRuntime`s and `Namespace`s are set in the `@service` decorators in `examples/llm/components/<frontend/processor/worker/prefill_worker>.py`. Their `Component`s are set by their name in `/deploy/dynamo/sdk/src/dynamo/sdk/cli/serve_dynamo.py`. Their `Endpoint`s are set by the `@endpoint` decorators in `examples/llm/components/<frontend/processor/worker/prefill_worker>.py`.
## Initialization ## Initialization
......
...@@ -22,7 +22,7 @@ from components.utils import GeneralRequest ...@@ -22,7 +22,7 @@ from components.utils import GeneralRequest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_api, service from dynamo.sdk import DYNAMO_IMAGE, api, depends, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -45,7 +45,7 @@ class Frontend: ...@@ -45,7 +45,7 @@ class Frontend:
logger.debug(f"Received signal {signum}, shutting down...") logger.debug(f"Received signal {signum}, shutting down...")
sys.exit(0) sys.exit(0)
@dynamo_api() @api()
async def generate(self, prompt, request_id): # from request body keys async def generate(self, prompt, request_id): # from request body keys
"""Stream results from the pipeline.""" """Stream results from the pipeline."""
logger.info(f"Received: {prompt=},{request_id=}") logger.info(f"Received: {prompt=},{request_id=}")
......
...@@ -21,7 +21,7 @@ from typing import AsyncIterator ...@@ -21,7 +21,7 @@ from typing import AsyncIterator
from components.utils import check_required_workers from components.utils import check_required_workers
from components.worker import DummyWorker from components.worker import DummyWorker
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
WorkerId = str WorkerId = str
...@@ -92,7 +92,7 @@ class Router: ...@@ -92,7 +92,7 @@ class Router:
# A dummy hit rate checking endpoint # A dummy hit rate checking endpoint
# The actual worker selection is based on custom cost function # The actual worker selection is based on custom cost function
# See details at examples/llm/components/kv_router.py # See details at examples/llm/components/kv_router.py
@dynamo_endpoint() @endpoint()
async def check_hit_rate(self, request_prompt: str) -> AsyncIterator[WorkerId]: async def check_hit_rate(self, request_prompt: str) -> AsyncIterator[WorkerId]:
max_id, max_hit_rate = self._cost_function(request_prompt) max_id, max_hit_rate = self._cost_function(request_prompt)
yield f"{max_id}_{max_hit_rate}" yield f"{max_id}_{max_hit_rate}"
...@@ -23,7 +23,7 @@ import sys ...@@ -23,7 +23,7 @@ import sys
from components.utils import NixlMetadataStore, PrefillQueue, RemotePrefillRequest from components.utils import NixlMetadataStore, PrefillQueue, RemotePrefillRequest
from vllm.distributed.device_communicators.nixl import NixlMetadata from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -102,6 +102,6 @@ class PrefillWorker: ...@@ -102,6 +102,6 @@ class PrefillWorker:
print("Prefill invoked and will read KV cache from worker and write it back") print("Prefill invoked and will read KV cache from worker and write it back")
yield "prefill invoked" yield "prefill invoked"
@dynamo_endpoint() @endpoint()
async def mock(self, req: RemotePrefillRequest): async def mock(self, req: RemotePrefillRequest):
yield f"mock_response: {req}" yield f"mock_response: {req}"
...@@ -21,7 +21,7 @@ from components.utils import GeneralRequest, GeneralResponse, check_required_wor ...@@ -21,7 +21,7 @@ from components.utils import GeneralRequest, GeneralResponse, check_required_wor
from components.worker import DummyWorker from components.worker import DummyWorker
from dynamo._core import Client from dynamo._core import Client
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.dependency import DynamoClient from dynamo.sdk.lib.dependency import DynamoClient
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -96,7 +96,7 @@ class Processor(Protocol): ...@@ -96,7 +96,7 @@ class Processor(Protocol):
async for resp in engine_generator: async for resp in engine_generator:
yield GeneralResponse.model_validate_json(resp.data()) yield GeneralResponse.model_validate_json(resp.data())
@dynamo_endpoint() @endpoint()
async def processor_generate(self, raw_request: GeneralRequest): async def processor_generate(self, raw_request: GeneralRequest):
async for response in self._generate(raw_request): async for response in self._generate(raw_request):
yield response.model_dump_json() yield response.model_dump_json()
...@@ -26,7 +26,7 @@ from components.utils import ( ...@@ -26,7 +26,7 @@ from components.utils import (
) )
from vllm.distributed.device_communicators.nixl import NixlMetadata from vllm.distributed.device_communicators.nixl import NixlMetadata
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -86,7 +86,7 @@ class DummyWorker: ...@@ -86,7 +86,7 @@ class DummyWorker:
return callback return callback
@dynamo_endpoint() @endpoint()
async def worker_generate(self, request: GeneralRequest): async def worker_generate(self, request: GeneralRequest):
# TODO: consider prefix hit when deciding prefill locally or remotely # TODO: consider prefix hit when deciding prefill locally or remotely
......
...@@ -19,7 +19,7 @@ from fastapi.responses import StreamingResponse ...@@ -19,7 +19,7 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_api, dynamo_endpoint, service from dynamo.sdk import DYNAMO_IMAGE, api, depends, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -70,7 +70,7 @@ class Backend: ...@@ -70,7 +70,7 @@ class Backend:
self.message = config.get("Backend", {}).get("message", "back") self.message = config.get("Backend", {}).get("message", "back")
logger.info(f"Backend config message: {self.message}") logger.info(f"Backend config message: {self.message}")
@dynamo_endpoint() @endpoint()
async def generate(self, req: RequestType): async def generate(self, req: RequestType):
"""Generate tokens.""" """Generate tokens."""
req_text = req.text req_text = req.text
...@@ -93,7 +93,7 @@ class Middle: ...@@ -93,7 +93,7 @@ class Middle:
self.message = config.get("Middle", {}).get("message", "mid") self.message = config.get("Middle", {}).get("message", "mid")
logger.info(f"Middle config message: {self.message}") logger.info(f"Middle config message: {self.message}")
@dynamo_endpoint() @endpoint()
async def generate(self, req: RequestType): async def generate(self, req: RequestType):
"""Forward requests to backend.""" """Forward requests to backend."""
req_text = req.text req_text = req.text
...@@ -125,8 +125,8 @@ class Frontend: ...@@ -125,8 +125,8 @@ class Frontend:
logger.info(f"Frontend config message: {self.message}") logger.info(f"Frontend config message: {self.message}")
logger.info(f"Frontend config port: {self.port}") logger.info(f"Frontend config port: {self.port}")
# alternative syntax: @dynamo_endpoint(transports=[DynamoTransport.HTTP]) # alternative syntax: @endpoint(transports=[DynamoTransport.HTTP])
@dynamo_api() @api()
async def generate(self, request: RequestType): async def generate(self, request: RequestType):
"""Stream results from the pipeline.""" """Stream results from the pipeline."""
logger.info(f"Frontend received: {request.text}") logger.info(f"Frontend received: {request.text}")
......
...@@ -26,7 +26,7 @@ from utils.protocol import Tokens ...@@ -26,7 +26,7 @@ from utils.protocol import Tokens
from utils.vllm import RouterType from utils.vllm import RouterType
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
WorkerId = str WorkerId = str
...@@ -247,7 +247,7 @@ class Router: ...@@ -247,7 +247,7 @@ class Router:
) )
return best_worker_id, kv_load[best_worker_id] return best_worker_id, kv_load[best_worker_id]
@dynamo_endpoint() @endpoint()
async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float]]: async def generate(self, request: Tokens) -> AsyncIterator[Tuple[WorkerId, float]]:
metrics = await self.metrics_aggregator.get_metrics() metrics = await self.metrics_aggregator.get_metrics()
......
...@@ -21,7 +21,7 @@ from pydantic import BaseModel ...@@ -21,7 +21,7 @@ from pydantic import BaseModel
from components.planner import start_planner # type: ignore[attr-defined] from components.planner import start_planner # type: ignore[attr-defined]
from dynamo.planner.defaults import PlannerDefaults from dynamo.planner.defaults import PlannerDefaults
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
from dynamo.sdk.core.protocol.interface import ComponentType from dynamo.sdk.core.protocol.interface import ComponentType
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
from dynamo.sdk.lib.image import DYNAMO_IMAGE from dynamo.sdk.lib.image import DYNAMO_IMAGE
...@@ -109,7 +109,7 @@ class Planner: ...@@ -109,7 +109,7 @@ class Planner:
await start_planner(self.runtime, self.args) await start_planner(self.runtime, self.args)
logger.info("Planner started") logger.info("Planner started")
@dynamo_endpoint() @endpoint()
async def generate(self, request: RequestType): async def generate(self, request: RequestType):
"""Dummy endpoint to satisfy that each component has an endpoint""" """Dummy endpoint to satisfy that each component has an endpoint"""
yield "mock endpoint" yield "mock endpoint"
...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.api_server import ( ...@@ -30,7 +30,7 @@ from vllm.entrypoints.openai.api_server import (
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from dynamo.sdk import async_on_start, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, dynamo_context, endpoint, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -206,6 +206,6 @@ class PrefillWorker: ...@@ -206,6 +206,6 @@ class PrefillWorker:
): ):
yield yield
@dynamo_endpoint() @endpoint()
async def mock(self, req: RequestType): async def mock(self, req: RequestType):
yield f"mock_response: {req}" yield f"mock_response: {req}"
...@@ -33,7 +33,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer ...@@ -33,7 +33,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
from dynamo.llm import KvMetricsAggregator from dynamo.llm import KvMetricsAggregator
from dynamo.runtime import EtcdKvCache from dynamo.runtime import EtcdKvCache
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -173,16 +173,16 @@ class Processor(ProcessMixIn): ...@@ -173,16 +173,16 @@ class Processor(ProcessMixIn):
async def _get_kv_load(self): async def _get_kv_load(self):
metrics = await self.metrics_aggregator.get_metrics() metrics = await self.metrics_aggregator.get_metrics()
kv_load = {} kv_load = {}
for endpoint in metrics.endpoints: for end_point in metrics.endpoints:
worker_id = endpoint.worker_id worker_id = end_point.worker_id
kv_load[worker_id] = getattr(endpoint, "gpu_cache_usage_perc", 0.0) kv_load[worker_id] = getattr(end_point, "gpu_cache_usage_perc", 0.0)
return kv_load return kv_load
async def _get_pending_requests(self): async def _get_pending_requests(self):
metrics = await self.metrics_aggregator.get_metrics() metrics = await self.metrics_aggregator.get_metrics()
pending_requests = {} pending_requests = {}
for endpoint in metrics.endpoints: for end_point in metrics.endpoints:
worker_id = endpoint.worker_id worker_id = end_point.worker_id
pending_requests[worker_id] = getattr(endpoint, "num_requests_waiting", 0) pending_requests[worker_id] = getattr(end_point, "num_requests_waiting", 0)
return pending_requests return pending_requests
...@@ -327,12 +327,12 @@ class Processor(ProcessMixIn): ...@@ -327,12 +327,12 @@ class Processor(ProcessMixIn):
f"Request type {request_type} not implemented" f"Request type {request_type} not implemented"
) )
@dynamo_endpoint(name="chat/completions") @endpoint(name="chat/completions")
async def chat_completions(self, raw_request: ChatCompletionRequest): async def chat_completions(self, raw_request: ChatCompletionRequest):
async for response in self._generate(raw_request, RequestType.CHAT): async for response in self._generate(raw_request, RequestType.CHAT):
yield response yield response
# @dynamo_endpoint() # @endpoint()
# async def completions(self, raw_request: CompletionRequest): # async def completions(self, raw_request: CompletionRequest):
# async for response in self._generate(raw_request, RequestType.COMPLETION): # async for response in self._generate(raw_request, RequestType.COMPLETION):
# yield response # yield response
...@@ -32,7 +32,7 @@ from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest ...@@ -32,7 +32,7 @@ from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest
from vllm.sampling_params import RequestOutputKind from vllm.sampling_params import RequestOutputKind
from dynamo.llm import KvMetricsPublisher from dynamo.llm import KvMetricsPublisher
from dynamo.sdk import async_on_start, depends, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_on_start, depends, dynamo_context, endpoint, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -183,7 +183,7 @@ class VllmWorker: ...@@ -183,7 +183,7 @@ class VllmWorker:
return callback return callback
# TODO: use the same child lease for metrics publisher endpoint and generate endpoint # TODO: use the same child lease for metrics publisher endpoint and generate endpoint
@dynamo_endpoint() @endpoint()
async def generate(self, request: vLLMGenerateRequest): async def generate(self, request: vLLMGenerateRequest):
# TODO: consider prefix hit when deciding prefill locally or remotely # TODO: consider prefix hit when deciding prefill locally or remotely
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment