Unverified Commit d675d221 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: decouple dynamo sdk to support multiple deployment targets (#905)

parent 5d5235bc
...@@ -17,7 +17,6 @@ from __future__ import annotations ...@@ -17,7 +17,6 @@ from __future__ import annotations
import json import json
import logging import logging
import os import os
from collections import defaultdict
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
...@@ -28,6 +27,7 @@ from _bentoml_sdk.images import Image ...@@ -28,6 +27,7 @@ from _bentoml_sdk.images import Image
from _bentoml_sdk.service.config import validate from _bentoml_sdk.service.config import validate
from fastapi import FastAPI from fastapi import FastAPI
from dynamo.sdk.core.protocol.interface import DynamoTransport, LinkedServices
from dynamo.sdk.lib.decorators import DynamoEndpoint from dynamo.sdk.lib.decorators import DynamoEndpoint
T = TypeVar("T", bound=object) T = TypeVar("T", bound=object)
...@@ -45,32 +45,6 @@ class ComponentType(str, Enum): ...@@ -45,32 +45,6 @@ class ComponentType(str, Enum):
# etc. # etc.
class RuntimeLinkedServices:
    """Track the links declared between services at runtime.

    ``edges`` maps every service that has taken part in a link (as source or
    destination) to the set of values recorded for its outgoing links.
    """

    def __init__(self) -> None:
        # NOTE(review): the annotation says the sets hold DynamoService
        # objects, but add() stores ``dest.inner`` -- confirm which one the
        # per-service cleanup expects and align the annotation.
        self.edges: "Dict[DynamoService, Set[DynamoService]]" = defaultdict(set)

    def add(self, edge: "Tuple[DynamoService, DynamoService]") -> None:
        """Record a ``(src, dest)`` link.

        ``dest.inner`` is added to the set kept for ``src``. ``dest`` itself
        is also registered as a key (with an empty set if it has no outgoing
        links yet) so that it is visited by :meth:`remove_unused_edges`.
        """
        src, dest = edge
        self.edges[src].add(dest.inner)
        # Track the dest node as well so we can clean it up later. Done
        # explicitly rather than through a bare ``self.edges[dest]`` lookup,
        # which only works by defaultdict side effect and reads as dead code.
        self.edges.setdefault(dest, set())

    def remove_unused_edges(self) -> None:
        """Ask every tracked service to drop the edges it does not use.

        Calls ``service.remove_unused_edges(used_edges=...)`` once per
        tracked service, passing the set recorded for it. Does nothing when
        no link has been recorded.
        """
        if not self.edges:
            return
        # Iterate over a snapshot: ``edges`` is a defaultdict, so a lookup of
        # an untracked service made while a callback runs would insert a key
        # and make iteration over the live dict raise RuntimeError.
        for service, used in list(self.edges.items()):
            service.remove_unused_edges(used_edges=used)


# Module-level registry instance.
LinkedServices = RuntimeLinkedServices()
@dataclass @dataclass
class DynamoConfig: class DynamoConfig:
"""Configuration for Dynamo components""" """Configuration for Dynamo components"""
...@@ -152,7 +126,9 @@ class DynamoService(Service[T]): ...@@ -152,7 +126,9 @@ class DynamoService(Service[T]):
value = getattr(inner, field) value = getattr(inner, field)
if isinstance(value, DynamoEndpoint): if isinstance(value, DynamoEndpoint):
self._dynamo_endpoints[value.name] = value self._dynamo_endpoints[value.name] = value
if getattr(value, "is_api", False): if DynamoTransport.HTTP in getattr(
value, "_transports", [DynamoTransport.DEFAULT]
):
# Ensure endpoint path starts with '/' # Ensure endpoint path starts with '/'
path = ( path = (
value.name if value.name.startswith("/") else f"/{value.name}" value.name if value.name.startswith("/") else f"/{value.name}"
...@@ -174,15 +150,8 @@ class DynamoService(Service[T]): ...@@ -174,15 +150,8 @@ class DynamoService(Service[T]):
return service_config.get("ServiceArgs") return service_config.get("ServiceArgs")
return None return None
def is_dynamo_component(self) -> bool:
"""Check if this service is configured as a Dynamo component"""
return self._dynamo_config.enabled
def dynamo_address(self) -> Tuple[Optional[str], Optional[str]]: def dynamo_address(self) -> Tuple[Optional[str], Optional[str]]:
"""Get the Dynamo address for this component in namespace/name format""" """Get the Dynamo address for this component in namespace/name format"""
if not self.is_dynamo_component():
raise ValueError("Service is not configured as a Dynamo component")
# Check if we have a runner map with Dynamo address # Check if we have a runner map with Dynamo address
runner_map = os.environ.get("BENTOML_RUNNER_MAP") runner_map = os.environ.get("BENTOML_RUNNER_MAP")
if runner_map: if runner_map:
......
...@@ -17,11 +17,11 @@ ...@@ -17,11 +17,11 @@
# Use this to test changes made to CLI, SDK, etc # Use this to test changes made to CLI, SDK, etc
from fastapi import FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.sdk import depends, dynamo_endpoint, service from dynamo.sdk import depends, dynamo_endpoint, service
from dynamo.sdk.core.protocol.interface import DynamoTransport
""" """
Pipeline Architecture: Pipeline Architecture:
...@@ -56,14 +56,10 @@ class ResponseType(BaseModel): ...@@ -56,14 +56,10 @@ class ResponseType(BaseModel):
GPU_ENABLED = False GPU_ENABLED = False
app = FastAPI(title="Hello World!")
@service( @service(
resources={"cpu": "1"}, resources={"cpu": "1"},
traffic={"timeout": 30}, traffic={"timeout": 30},
dynamo={ dynamo={
"enabled": True,
"namespace": "inference", "namespace": "inference",
}, },
workers=1, workers=1,
...@@ -94,7 +90,7 @@ class Backend: ...@@ -94,7 +90,7 @@ class Backend:
@service( @service(
resources={"cpu": "2"}, resources={"cpu": "2"},
traffic={"timeout": 30}, traffic={"timeout": 30},
dynamo={"enabled": True, "namespace": "inference"}, dynamo={"namespace": "inference"},
) )
class Backend2: class Backend2:
backend = depends(Backend) backend = depends(Backend)
...@@ -116,7 +112,7 @@ class Backend2: ...@@ -116,7 +112,7 @@ class Backend2:
@service( @service(
resources={"cpu": "1"}, resources={"cpu": "1"},
traffic={"timeout": 30}, traffic={"timeout": 30},
dynamo={"enabled": True, "namespace": "inference"}, dynamo={"namespace": "inference"},
) )
class Middle: class Middle:
backend = depends(Backend) backend = depends(Backend)
...@@ -150,8 +146,7 @@ class Middle: ...@@ -150,8 +146,7 @@ class Middle:
@service( @service(
resources={"cpu": "1"}, resources={"cpu": "1"},
traffic={"timeout": 60}, traffic={"timeout": 60},
dynamo={"enabled": True, "namespace": "inference"}, dynamo={"namespace": "inference"},
app=app,
) )
class Frontend: class Frontend:
middle = depends(Middle) middle = depends(Middle)
...@@ -160,7 +155,7 @@ class Frontend: ...@@ -160,7 +155,7 @@ class Frontend:
def __init__(self) -> None: def __init__(self) -> None:
print("Starting frontend") print("Starting frontend")
@dynamo_endpoint(is_api=True) @dynamo_endpoint(transports=[DynamoTransport.HTTP])
async def generate(self, request: RequestType): async def generate(self, request: RequestType):
"""Stream results from the pipeline.""" """Stream results from the pipeline."""
print(f"Frontend received: {request.text}") print(f"Frontend received: {request.text}")
......
...@@ -105,8 +105,8 @@ async def test_pipeline(setup_and_teardown): ...@@ -105,8 +105,8 @@ async def test_pipeline(setup_and_teardown):
in text in text
) )
break break
except Exception: except Exception as e:
if attempt == max_retries - 1: if attempt == max_retries - 1:
raise raise
print(f"Attempt {attempt + 1} failed, retrying...") print(f"Attempt {attempt + 1} failed, retrying... {e}")
await asyncio.sleep(3) await asyncio.sleep(3)
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import pytest import pytest
from dynamo.sdk.lib.service import LinkedServices from dynamo.sdk.core.protocol.interface import LinkedServices
pytestmark = pytest.mark.pre_merge pytestmark = pytest.mark.pre_merge
...@@ -38,8 +38,8 @@ def test_remove_backend2(): ...@@ -38,8 +38,8 @@ def test_remove_backend2():
LinkedServices.remove_unused_edges() LinkedServices.remove_unused_edges()
# Final state assertions after linking and cleanup # Final state assertions after linking and cleanup
assert set(Frontend.dependencies.keys()) == {"middle"}
assert Frontend.dependencies["middle"].on == Middle assert Frontend.dependencies["middle"].on == Middle
assert set(Frontend.dependencies.keys()) == {"middle"}
assert set(Middle.dependencies.keys()) == {"backend"} assert set(Middle.dependencies.keys()) == {"backend"}
assert Middle.dependencies["backend"].on == Backend assert Middle.dependencies["backend"].on == Backend
......
...@@ -30,7 +30,6 @@ from dynamo.sdk import dynamo_endpoint, service ...@@ -30,7 +30,6 @@ from dynamo.sdk import dynamo_endpoint, service
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "your_namespace", "namespace": "your_namespace",
}, },
) )
...@@ -108,7 +107,6 @@ class ResponseType(BaseModel): ...@@ -108,7 +107,6 @@ class ResponseType(BaseModel):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "your_namespace", "namespace": "your_namespace",
} }
) )
...@@ -218,7 +216,6 @@ class ResponseType(BaseModel): ...@@ -218,7 +216,6 @@ class ResponseType(BaseModel):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "your_namespace", "namespace": "your_namespace",
} }
) )
...@@ -312,7 +309,6 @@ in your class implementation: ...@@ -312,7 +309,6 @@ in your class implementation:
```python ```python
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "your_namespace", "namespace": "your_namespace",
}, },
) )
...@@ -340,7 +336,6 @@ your own custom metrics and use them in your cost function: ...@@ -340,7 +336,6 @@ your own custom metrics and use them in your cost function:
```python ```python
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "your_namespace", "namespace": "your_namespace",
}, },
) )
...@@ -551,7 +546,6 @@ disaggregation, the DecodeWorker could just always do the Prefill step as well. ...@@ -551,7 +546,6 @@ disaggregation, the DecodeWorker could just always do the Prefill step as well.
```python ```python
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "your_namespace", "namespace": "your_namespace",
}, },
) )
...@@ -579,7 +573,6 @@ class DecodeWorker: ...@@ -579,7 +573,6 @@ class DecodeWorker:
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "your_namespace", "namespace": "your_namespace",
}, },
) )
......
...@@ -22,8 +22,7 @@ from components.utils import GeneralRequest ...@@ -22,8 +22,7 @@ from components.utils import GeneralRequest
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from dynamo.sdk import depends, dynamo_endpoint, service from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_api, service
from dynamo.sdk.lib.image import DYNAMO_IMAGE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -31,7 +30,7 @@ app = FastAPI(title="Hello World LLM") ...@@ -31,7 +30,7 @@ app = FastAPI(title="Hello World LLM")
@service( @service(
dynamo={"enabled": True, "namespace": "dynamo-demo"}, dynamo={"namespace": "dynamo-demo"},
image=DYNAMO_IMAGE, image=DYNAMO_IMAGE,
app=app, app=app,
) )
...@@ -46,7 +45,7 @@ class Frontend: ...@@ -46,7 +45,7 @@ class Frontend:
logger.debug(f"Received signal {signum}, shutting down...") logger.debug(f"Received signal {signum}, shutting down...")
sys.exit(0) sys.exit(0)
@dynamo_endpoint(is_api=True) @dynamo_api()
async def generate(self, prompt, request_id): # from request body keys async def generate(self, prompt, request_id): # from request body keys
"""Stream results from the pipeline.""" """Stream results from the pipeline."""
logger.info(f"Received: {prompt=},{request_id=}") logger.info(f"Received: {prompt=},{request_id=}")
......
...@@ -30,7 +30,6 @@ logger = logging.getLogger(__name__) ...@@ -30,7 +30,6 @@ logger = logging.getLogger(__name__)
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo-demo", "namespace": "dynamo-demo",
}, },
resources={"cpu": "10", "memory": "20Gi"}, resources={"cpu": "10", "memory": "20Gi"},
......
...@@ -14,17 +14,18 @@ ...@@ -14,17 +14,18 @@
# limitations under the License. # limitations under the License.
import logging import logging
import os
from fastapi import FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_endpoint, service from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_api, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
""" """
Pipeline Architecture: Pipeline Architecture:
...@@ -57,7 +58,6 @@ class ResponseType(BaseModel): ...@@ -57,7 +58,6 @@ class ResponseType(BaseModel):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "inference", "namespace": "inference",
}, },
image=DYNAMO_IMAGE, image=DYNAMO_IMAGE,
...@@ -76,11 +76,11 @@ class Backend: ...@@ -76,11 +76,11 @@ class Backend:
logger.info(f"Backend received: {req_text}") logger.info(f"Backend received: {req_text}")
text = f"{req_text}-{self.message}" text = f"{req_text}-{self.message}"
for token in text.split(): for token in text.split():
yield f"Backend: {token}" yield f"[process_id:{os.getpid()}] Backend: {token}"
@service( @service(
dynamo={"enabled": True, "namespace": "inference"}, dynamo={"namespace": "inference"},
image=DYNAMO_IMAGE, image=DYNAMO_IMAGE,
) )
class Middle: class Middle:
...@@ -101,16 +101,12 @@ class Middle: ...@@ -101,16 +101,12 @@ class Middle:
next_request = RequestType(text=text).model_dump_json() next_request = RequestType(text=text).model_dump_json()
async for response in self.backend.generate(next_request): async for response in self.backend.generate(next_request):
logger.info(f"Middle received response: {response}") logger.info(f"Middle received response: {response}")
yield f"Middle: {response}" yield f"[process_id:{os.getpid()}] Middle: {response}"
app = FastAPI(title="Hello World!")
@service( @service(
dynamo={"enabled": True, "namespace": "inference"}, dynamo={"namespace": "inference"},
image=DYNAMO_IMAGE, image=DYNAMO_IMAGE,
app=app,
) )
class Frontend: class Frontend:
"""A simple frontend HTTP API that forwards requests to the dynamo graph.""" """A simple frontend HTTP API that forwards requests to the dynamo graph."""
...@@ -128,13 +124,14 @@ class Frontend: ...@@ -128,13 +124,14 @@ class Frontend:
logger.info(f"Frontend config message: {self.message}") logger.info(f"Frontend config message: {self.message}")
logger.info(f"Frontend config port: {self.port}") logger.info(f"Frontend config port: {self.port}")
@dynamo_endpoint(is_api=True) # alternative syntax: @dynamo_endpoint(transports=[DynamoTransport.HTTP])
@dynamo_api()
async def generate(self, request: RequestType): async def generate(self, request: RequestType):
"""Stream results from the pipeline.""" """Stream results from the pipeline."""
logger.info(f"Frontend received: {request.text}") logger.info(f"Frontend received: {request.text}")
async def content_generator(): async def content_generator():
async for response in self.middle.generate(request.model_dump_json()): async for response in self.middle.generate(request.model_dump_json()):
yield f"Frontend: {response}" yield f"[process_id:{os.getpid()}] Frontend: {response}"
return StreamingResponse(content_generator()) return StreamingResponse(content_generator())
...@@ -20,7 +20,6 @@ from pathlib import Path ...@@ -20,7 +20,6 @@ from pathlib import Path
from components.planner_service import Planner from components.planner_service import Planner
from components.processor import Processor from components.processor import Processor
from components.worker import VllmWorker from components.worker import VllmWorker
from fastapi import FastAPI
from pydantic import BaseModel from pydantic import BaseModel
from dynamo import sdk from dynamo import sdk
...@@ -52,13 +51,11 @@ class FrontendConfig(BaseModel): ...@@ -52,13 +51,11 @@ class FrontendConfig(BaseModel):
# todo this should be called ApiServer # todo this should be called ApiServer
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"cpu": "10", "memory": "20Gi"}, resources={"cpu": "10", "memory": "20Gi"},
workers=1, workers=1,
image=DYNAMO_IMAGE, image=DYNAMO_IMAGE,
app=FastAPI(title="LLM Example"),
) )
class Frontend: class Frontend:
planner = depends(Planner) planner = depends(Planner)
...@@ -71,7 +68,6 @@ class Frontend: ...@@ -71,7 +68,6 @@ class Frontend:
frontend_config = FrontendConfig(**config.get("Frontend", {})) frontend_config = FrontendConfig(**config.get("Frontend", {}))
self.frontend_config = frontend_config self.frontend_config = frontend_config
self.process = None self.process = None
self.setup_model() self.setup_model()
self.start_http_server() self.start_http_server()
......
...@@ -76,7 +76,6 @@ def parse_args(service_name, prefix) -> Namespace: ...@@ -76,7 +76,6 @@ def parse_args(service_name, prefix) -> Namespace:
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"cpu": "10", "memory": "20Gi"}, resources={"cpu": "10", "memory": "20Gi"},
......
...@@ -33,7 +33,6 @@ class RequestType(BaseModel): ...@@ -33,7 +33,6 @@ class RequestType(BaseModel):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
"component_type": "planner", "component_type": "planner",
}, },
......
...@@ -41,7 +41,6 @@ class RequestType(BaseModel): ...@@ -41,7 +41,6 @@ class RequestType(BaseModel):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
......
...@@ -45,7 +45,6 @@ class RequestType(Enum): ...@@ -45,7 +45,6 @@ class RequestType(Enum):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"cpu": "10", "memory": "20Gi"}, resources={"cpu": "10", "memory": "20Gi"},
......
...@@ -39,7 +39,6 @@ logger = logging.getLogger(__name__) ...@@ -39,7 +39,6 @@ logger = logging.getLogger(__name__)
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
......
...@@ -31,7 +31,6 @@ logger = logging.getLogger(__name__) ...@@ -31,7 +31,6 @@ logger = logging.getLogger(__name__)
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
......
...@@ -20,14 +20,13 @@ from fastapi import FastAPI ...@@ -20,14 +20,13 @@ from fastapi import FastAPI
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from utils.protocol import MultiModalRequest from utils.protocol import MultiModalRequest
from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_endpoint, service from dynamo.sdk import DYNAMO_IMAGE, depends, dynamo_api, service
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"cpu": "10", "memory": "20Gi"}, resources={"cpu": "10", "memory": "20Gi"},
...@@ -38,7 +37,7 @@ logger = logging.getLogger(__name__) ...@@ -38,7 +37,7 @@ logger = logging.getLogger(__name__)
class Frontend: class Frontend:
processor = depends(Processor) processor = depends(Processor)
@dynamo_endpoint(is_api=True) @dynamo_api()
async def generate(self, request: MultiModalRequest): async def generate(self, request: MultiModalRequest):
async def content_generator(): async def content_generator():
async for response in self.processor.generate(request.model_dump_json()): async for response in self.processor.generate(request.model_dump_json()):
......
...@@ -45,7 +45,6 @@ class RequestType(BaseModel): ...@@ -45,7 +45,6 @@ class RequestType(BaseModel):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
......
...@@ -43,7 +43,6 @@ class RequestType(Enum): ...@@ -43,7 +43,6 @@ class RequestType(Enum):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"cpu": "10", "memory": "20Gi"}, resources={"cpu": "10", "memory": "20Gi"},
......
...@@ -48,7 +48,6 @@ logger = logging.getLogger(__name__) ...@@ -48,7 +48,6 @@ logger = logging.getLogger(__name__)
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
resources={"gpu": 1, "cpu": "10", "memory": "20Gi"}, resources={"gpu": 1, "cpu": "10", "memory": "20Gi"},
......
...@@ -49,7 +49,6 @@ class FrontendConfig(BaseModel): ...@@ -49,7 +49,6 @@ class FrontendConfig(BaseModel):
@service( @service(
dynamo={ dynamo={
"enabled": True,
"namespace": "dynamo", "namespace": "dynamo",
}, },
workers=1, workers=1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment