"runtime/rust/python-wheel/pyproject.toml" did not exist on "df90e29eba6aaf62f8fa67b8e9092d8a81c25856"
Commit b8120504 authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

feat(sdk): add initial graph structure for prebuilt components (#130)


Co-authored-by: default avatarBiswa Panda <biswa.panda@gmail.com>
parent 4f7f4b40
...@@ -24,6 +24,7 @@ import typing as t ...@@ -24,6 +24,7 @@ import typing as t
import click import click
import rich import rich
import yaml
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
P = t.ParamSpec("P") # type: ignore P = t.ParamSpec("P") # type: ignore
...@@ -126,6 +127,12 @@ def build_serve_command() -> click.Group: ...@@ -126,6 +127,12 @@ def build_serve_command() -> click.Group:
cls=AliasCommand, cls=AliasCommand,
) )
@click.argument("bento", type=click.STRING, default=".") @click.argument("bento", type=click.STRING, default=".")
@click.option(
"-f",
"--file",
type=click.Path(exists=True),
help="Path to YAML config file for service configuration",
)
@click.option( @click.option(
"--development", "--development",
type=click.BOOL, type=click.BOOL,
...@@ -266,6 +273,7 @@ def build_serve_command() -> click.Group: ...@@ -266,6 +273,7 @@ def build_serve_command() -> click.Group:
development: bool, development: bool,
port: int, port: int,
host: str, host: str,
file: str | None,
api_workers: int, api_workers: int,
timeout: int | None, timeout: int | None,
backlog: int, backlog: int,
...@@ -321,8 +329,25 @@ def build_serve_command() -> click.Group: ...@@ -321,8 +329,25 @@ def build_serve_command() -> click.Group:
from bentoml import Service from bentoml import Service
from bentoml._internal.service.loader import load from bentoml._internal.service.loader import load
from dynamo.sdk.lib.service import LinkedServices
# Process service-specific options # Process service-specific options
service_configs = _parse_service_args(ctx.args) service_configs: t.Optional[t.Dict[str, t.Any]] = _parse_service_args(ctx.args)
# Load and merge config file if provided
if file:
with open(file) as f:
yaml_configs = yaml.safe_load(f)
# Initialize service_configs as empty dict if it's None
service_configs = dict(service_configs or {})
# Convert nested YAML structure to flat dict with dot notation
for service, configs in yaml_configs.items():
for key, value in configs.items():
if service not in service_configs:
service_configs[service] = {}
service_configs[service][key] = value
# print("service_configs", service_configs)
# Set environment variable with service configuration # Set environment variable with service configuration
if service_configs: if service_configs:
...@@ -337,6 +362,8 @@ def build_serve_command() -> click.Group: ...@@ -337,6 +362,8 @@ def build_serve_command() -> click.Group:
if sys.path[0] != working_dir: if sys.path[0] != working_dir:
sys.path.insert(0, working_dir) sys.path.insert(0, working_dir)
svc = load(bento_identifier=bento, working_dir=working_dir) svc = load(bento_identifier=bento, working_dir=working_dir)
LinkedServices.remove_unused_edges()
if isinstance(svc, Service): if isinstance(svc, Service):
# bentoml<1.2 # bentoml<1.2
from bentoml.serving import serve_http_production from bentoml.serving import serve_http_production
......
...@@ -29,6 +29,7 @@ import click ...@@ -29,6 +29,7 @@ import click
from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker from dynamo.runtime import DistributedRuntime, dynamo_endpoint, dynamo_worker
from dynamo.sdk import dynamo_context from dynamo.sdk import dynamo_context
from dynamo.sdk.lib.service import LinkedServices
logger = logging.getLogger("dynamo.sdk.serve.dynamo") logger = logging.getLogger("dynamo.sdk.serve.dynamo")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
...@@ -99,6 +100,8 @@ def main( ...@@ -99,6 +100,8 @@ def main(
t.cast(t.Dict[str, str], json.loads(runner_map)) t.cast(t.Dict[str, str], json.loads(runner_map))
) )
# TODO: test this with a deep chain of services
LinkedServices.remove_unused_edges()
# Check if Dynamo is enabled for this service # Check if Dynamo is enabled for this service
if service.is_dynamo_component(): if service.is_dynamo_component():
if worker_id is not None: if worker_id is not None:
......
...@@ -16,17 +16,45 @@ from __future__ import annotations ...@@ -16,17 +16,45 @@ from __future__ import annotations
import json import json
import os import os
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
from _bentoml_sdk import Service, ServiceConfig from _bentoml_sdk import Service, ServiceConfig
from _bentoml_sdk.images import Image from _bentoml_sdk.images import Image
from _bentoml_sdk.service.config import validate
from dynamo.sdk.lib.decorators import DynamoEndpoint from dynamo.sdk.lib.decorators import DynamoEndpoint
T = TypeVar("T", bound=object) T = TypeVar("T", bound=object)
class RuntimeLinkedServices:
    """Tracks the service-to-service links created at graph-build time.

    ``edges`` maps a source service node to the set of *inner* classes
    (``dest.inner``) of the services it was linked to.  The inner class —
    not the service wrapper — is stored because
    ``DynamoService.remove_unused_edges`` compares each dependency's
    ``dep.on.inner`` against this set when pruning.
    """

    def __init__(self) -> None:
        # Fix: values hold the targets' inner classes, not DynamoService
        # wrappers, so the value type is Set[Any] rather than
        # Set[DynamoService] as previously annotated.
        self.edges: Dict[DynamoService, Set[Any]] = defaultdict(set)

    def add(self, edge: Tuple[DynamoService, DynamoService]) -> None:
        """Record a directed link ``src -> dest``."""
        src, dest = edge
        self.edges[src].add(dest.inner)
        # Also register the dest node itself (with an empty edge set) so that
        # remove_unused_edges() prunes never-linked dependencies from it too.
        # (Previously done implicitly via defaultdict auto-vivification.)
        if dest not in self.edges:
            self.edges[dest] = set()

    def remove_unused_edges(self) -> None:
        """Ask every tracked service to drop dependencies that were never linked.

        Idempotent: a second call re-prunes already-pruned dependency maps.
        No-op when no links were ever registered.
        """
        if not self.edges:
            return
        for service, linked_inners in self.edges.items():
            service.remove_unused_edges(used_edges=linked_inners)


# Process-wide registry used by DynamoService.link() and the serve path.
LinkedServices = RuntimeLinkedServices()
@dataclass @dataclass
class DynamoConfig: class DynamoConfig:
"""Configuration for Dynamo components""" """Configuration for Dynamo components"""
...@@ -47,6 +75,15 @@ class DynamoService(Service[T]): ...@@ -47,6 +75,15 @@ class DynamoService(Service[T]):
envs: Optional[list[dict[str, Any]]] = None, envs: Optional[list[dict[str, Any]]] = None,
dynamo_config: Optional[DynamoConfig] = None, dynamo_config: Optional[DynamoConfig] = None,
): ):
service_name = inner.__name__
service_args = self._get_service_args(service_name)
if service_args:
# Validate and merge service args with existing config
validated_args = validate(service_args)
config.update(validated_args)
self._remove_service_args(service_name)
super().__init__(config=config, inner=inner, image=image, envs=envs or []) super().__init__(config=config, inner=inner, image=image, envs=envs or [])
# Initialize Dynamo configuration # Initialize Dynamo configuration
...@@ -65,6 +102,17 @@ class DynamoService(Service[T]): ...@@ -65,6 +102,17 @@ class DynamoService(Service[T]):
if isinstance(value, DynamoEndpoint): if isinstance(value, DynamoEndpoint):
self._dynamo_endpoints[value.name] = value self._dynamo_endpoints[value.name] = value
self._linked_services: List[DynamoService] = [] # Track linked services
def _get_service_args(self, service_name: str) -> Optional[dict]:
"""Get ServiceArgs from environment config if specified"""
config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
if config_str:
config = json.loads(config_str)
service_config = config.get(service_name, {})
return service_config.get("ServiceArgs")
return None
def is_dynamo_component(self) -> bool: def is_dynamo_component(self) -> bool:
"""Check if this service is configured as a Dynamo component""" """Check if this service is configured as a Dynamo component"""
return self._dynamo_config.enabled return self._dynamo_config.enabled
...@@ -111,7 +159,27 @@ class DynamoService(Service[T]): ...@@ -111,7 +159,27 @@ class DynamoService(Service[T]):
"""List names of all registered Dynamo endpoints""" """List names of all registered Dynamo endpoints"""
return list(self._dynamo_endpoints.keys()) return list(self._dynamo_endpoints.keys())
# todo: add another function to bind an instance of the inner to the self within these methods
def remove_unused_edges(self, used_edges: Set[DynamoService]):
    """Drop every dependency whose target inner class is not in *used_edges*."""
    # Snapshot first: we must not delete from the mapping while iterating it.
    snapshot = dict(self.dependencies)
    for dep_name, dependency in snapshot.items():
        if dependency.on.inner not in used_edges:
            del self.dependencies[dep_name]
def link(self, next_service: DynamoService):
    """Link this service to another service, creating a pipeline.

    Records *next_service* locally and registers the directed edge in the
    global ``LinkedServices`` registry so never-linked dependencies can be
    pruned later.  Returns the target so calls can be chained fluently,
    e.g. ``Frontend.link(Middle).link(Backend)``.
    """
    self._linked_services.append(next_service)
    # Register the src -> dest edge in the process-wide registry.
    LinkedServices.add((self, next_service))
    # Return the target, not self, so chains read left-to-right.
    return next_service
def _remove_service_args(self, service_name: str):
"""Remove ServiceArgs from the environment config after using them"""
config_str = os.environ.get("DYNAMO_SERVICE_CONFIG")
if config_str:
config = json.loads(config_str)
if service_name in config and "ServiceArgs" in config[service_name]:
del config[service_name]["ServiceArgs"]
os.environ["DYNAMO_SERVICE_CONFIG"] = json.dumps(config)
def service( def service(
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Linking syntax example: wires Frontend -> Middle and prints each
# service's dependency map afterwards.
from dynamo.sdk.tests.pipeline import Backend, Frontend, Middle

# print("INITIAL DEPENDENCIES")
# print("Frontend dependencies", Frontend.dependencies)
# print("Middle dependencies", Middle.dependencies)
# print("Backend dependencies", Backend.dependencies)
# print("\n\n\n")

print()
# link() registers the edge and returns its target, so this fluent chain
# ends by calling build() on Middle.
Frontend.link(Middle).build()
# Dump each service's dependency map for manual inspection.
print("Frontend dependencies", Frontend.dependencies)
print("Middle dependencies", Middle.dependencies)
print("Backend dependencies", Backend.dependencies)
...@@ -16,12 +16,10 @@ ...@@ -16,12 +16,10 @@
# This is a simple example of a pipeline that uses Dynamo to deploy a backend, middle, and frontend service. Use this to test # This is a simple example of a pipeline that uses Dynamo to deploy a backend, middle, and frontend service. Use this to test
# changes made to CLI, SDK, etc # changes made to CLI, SDK, etc
import os
from pydantic import BaseModel from pydantic import BaseModel
from dynamo.sdk import api, depends, dynamo_endpoint, service from dynamo.sdk import api, depends, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
""" """
Pipeline Architecture: Pipeline Architecture:
...@@ -56,25 +54,14 @@ class ResponseType(BaseModel): ...@@ -56,25 +54,14 @@ class ResponseType(BaseModel):
GPU_ENABLED = False GPU_ENABLED = False
class FrontendConfig(BaseModel):
model: str
temperature: float = 0.7
max_tokens: int = 1024
stream: bool = True
class MiddleConfig(BaseModel):
bias: float
@service( @service(
resources={"cpu": "2"}, resources={"cpu": "1"},
traffic={"timeout": 30}, traffic={"timeout": 30},
dynamo={ dynamo={
"enabled": True, "enabled": True,
"namespace": "inference", "namespace": "inference",
}, },
workers=3, workers=1,
) )
class Backend: class Backend:
def __init__(self) -> None: def __init__(self) -> None:
...@@ -95,74 +82,52 @@ class Backend: ...@@ -95,74 +82,52 @@ class Backend:
traffic={"timeout": 30}, traffic={"timeout": 30},
dynamo={"enabled": True, "namespace": "inference"}, dynamo={"enabled": True, "namespace": "inference"},
) )
class Backend2:
    """Example service used to exercise multi-edge dependency graphs."""

    backend = depends(Backend)

    def __init__(self) -> None:
        print("Starting middle2")

    @dynamo_endpoint()
    async def generate(self, req: RequestType):
        """Forward requests to backend."""
        incoming = req.text
        print(f"Middle2 received: {incoming}")
        outgoing = f"{incoming}-mid2"
        serialized = RequestType(text=outgoing).model_dump_json()
        # NOTE(review): the serialized request is only printed, never sent to
        # self.backend, and nothing is yielded — presumably a stub; confirm.
        print(serialized)
@service(
resources={"cpu": "1"},
traffic={"timeout": 30},
dynamo={"enabled": True, "namespace": "inference"},
)
class Middle: class Middle:
backend = depends(Backend) backend = depends(Backend)
backend2 = depends(Backend2)
def __init__(self) -> None: def __init__(self) -> None:
print("Starting middle") print("Starting middle")
config = ServiceConfig.get_instance()
middle_config = MiddleConfig(**config.get("Middle", {}))
print(f"bias: {middle_config.bias}")
if GPU_ENABLED:
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import FlexibleArgumentParser
try:
os.environ["VLLM_LOG_LEVEL"] = "DEBUG"
# Get VLLM args using new pattern
vllm_args = config.as_args("Middle", prefix="vllm_")
print(f"VLLM args to parse: {vllm_args}")
# Create and use parser
parser = FlexibleArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(vllm_args)
self.engine_args = AsyncEngineArgs.from_cli_args(args)
self.engine = AsyncLLMEngine.from_engine_args(self.engine_args)
except ImportError:
print("VLLM imports not available, skipping engine arg parsing")
except Exception as e:
print(f"Error parsing VLLM args: {e}")
@dynamo_endpoint() @dynamo_endpoint()
async def generate(self, req: RequestType): async def generate(self, req: RequestType):
"""Forward requests to backend.""" """Forward requests to backend."""
req_text = req.text req_text = req.text
print(f"Middle received: {req_text}") print(f"Middle received: {req_text}")
text = f"{req_text}-mid" text = f"{req_text}-mid"
next_request = RequestType(text=text).model_dump_json() for token in text.split():
async for response in self.backend.generate(next_request): yield f"Mid: {token}"
print(f"Middle received response: {response}")
yield f"Middle: {response}"
@service(resources={"cpu": "1"}, traffic={"timeout": 60}) # Regular HTTP API @service(resources={"cpu": "1"}, traffic={"timeout": 60})
class Frontend: class Frontend:
middle = depends(Middle) middle = depends(Middle)
backend = depends(Backend)
def __init__(self) -> None: def __init__(self) -> None:
print("Starting frontend") print("Starting frontend")
self.config = ServiceConfig.get_instance()
frontend_config = FrontendConfig(**self.config.get("Frontend", {}))
print(
f"Frontend initialized with model={frontend_config.model}, "
f"temp={frontend_config.temperature}, max_tokens={frontend_config.max_tokens}"
)
# Get all configs for a service (new dict pattern)
all_frontend_configs = self.config.get("Frontend", {})
print(f"All Frontend configs: {all_frontend_configs}")
# Check other service configs (new dict pattern)
if self.config.get("Middle", {}).get("special_mode") == "fast":
print("Using Middle service in fast mode")
@api @api
async def generate(self, text): async def generate(self, text):
...@@ -171,5 +136,11 @@ class Frontend: ...@@ -171,5 +136,11 @@ class Frontend:
print(f"Frontend received type: {type(text)}") print(f"Frontend received type: {type(text)}")
txt = RequestType(text=text) txt = RequestType(text=text)
print(f"Frontend sending: {type(txt)}") print(f"Frontend sending: {type(txt)}")
async for response in self.middle.generate(txt.model_dump_json()): if self.backend:
yield f"Frontend: {response}" async for back_resp in self.backend.generate(txt.model_dump_json()):
print(f"Frontend received back_resp: {back_resp}")
yield f"Frontend: {back_resp}"
else:
async for mid_resp in self.middle.generate(txt.model_dump_json()):
print(f"Frontend received mid_resp: {mid_resp}")
yield f"Frontend: {mid_resp}"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo-init.VllmWorker.generate
port: 8000
VllmWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
conditional-disagg: true
remote-prefill: true
PrefillWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
cuda-visible-device-offset: 1
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from disaggregated.frontend import Frontend
from disaggregated.kv_router import Router
from disaggregated.processor import Processor
from disaggregated.worker import VllmWorker

# example 2 and 3: kv aware routing + worker
# kv.yaml
# Each link() registers a directed edge and returns its target, so the
# deployed graph reads left-to-right: Frontend -> Processor -> Router -> VllmWorker.
Frontend.link(Processor).link(Router).link(VllmWorker)

# example 4 and 5: only disag - issue with endpoint (probably because of routerless)
# disag.yaml
# Frontend.link(VllmWorker).link(PrefillWorker)

# example 6: disag with kv
# kv_with_disag.yaml
# Frontend.link(Processor).link(Router).link(VllmWorker).link(PrefillWorker)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
from disaggregated.processor import Processor
from disaggregated.worker import VllmWorker
from pydantic import BaseModel
from dynamo.sdk import depends, service
from dynamo.sdk.lib.config import ServiceConfig
class FrontendConfig(BaseModel):
    """Schema for the 'Frontend' section of the shared service config."""

    model: str  # model identifier registered with llmctl
    endpoint: str  # dynamo endpoint the chat model is routed to
    port: int = 8080  # HTTP port passed via "-p" (configs may override, e.g. 8000)
@service(
    resources={"cpu": "10", "memory": "20Gi"},
    workers=1,
)
# todo this should be called ApiServer
class Frontend:
    """HTTP entry point: re-registers the configured chat model with llmctl,
    then runs the HTTP front-end process."""

    worker = depends(VllmWorker)
    processor = depends(Processor)

    def __init__(self):
        # Pull and validate the 'Frontend' section of the shared config.
        config = ServiceConfig.get_instance()
        frontend_config = FrontendConfig(**config.get("Frontend", {}))
        os.environ["TRT_LOG"] = "DEBUG"
        # Drop any stale registration for this model before re-adding it.
        # NOTE(review): return codes are never checked (no check=True), so
        # llmctl failures are silent — confirm best-effort is intended.
        subprocess.run(
            ["llmctl", "http", "remove", "chat-models", frontend_config.model]
        )
        subprocess.run(
            [
                "llmctl",
                "http",
                "add",
                "chat-models",
                frontend_config.model,
                frontend_config.endpoint,
            ]
        )
        # subprocess.run blocks until the command exits.
        # NOTE(review): command is ["http", "-p", ...] — presumably the dynamo
        # `http` frontend binary; verify it is not missing a leading "llmctl".
        subprocess.run(
            ["http", "-p", str(frontend_config.port)], stdout=None, stderr=None
        )
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo-init.Processor.chat/completions
port: 8000
Processor:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64
max-model-len: 16384
router: kv
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
VllmWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
enable-prefix-caching: true
router: kv
remote-prefill: true
tensor-parallel-size: 1
ServiceArgs:
workers: 2
envs:
- CUDA_VISIBLE_DEVICES: '0,1'
...@@ -20,11 +20,12 @@ import random ...@@ -20,11 +20,12 @@ import random
from argparse import Namespace from argparse import Namespace
from typing import AsyncIterator from typing import AsyncIterator
from disaggregated.worker import VllmWorker
from utils.protocol import Tokens from utils.protocol import Tokens
from vllm.logger import logger as vllm_logger from vllm.logger import logger as vllm_logger
from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores from dynamo.llm import AggregatedMetrics, KvIndexer, KvMetricsAggregator, OverlapScores
from dynamo.sdk import async_onstart, dynamo_context, dynamo_endpoint, service from dynamo.sdk import async_onstart, depends, dynamo_context, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig from dynamo.sdk.lib.config import ServiceConfig
WorkerId = str WorkerId = str
...@@ -76,10 +77,11 @@ class Router: ...@@ -76,10 +77,11 @@ class Router:
Request handler for the generate endpoint Request handler for the generate endpoint
""" """
worker = depends(VllmWorker)
def __init__(self): def __init__(self):
vllm_logger.info("Initializing Custom Router") vllm_logger.info("Initializing Custom Router")
self.args = parse_args(self.__class__.__name__, "") self.args = parse_args(self.__class__.__name__, "")
print("[ROUTER] args = ", self.args)
@async_onstart @async_onstart
async def async_init(self): async def async_init(self):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frontend:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
endpoint: dynamo-init.Processor.chat/completions
port: 8000
Processor:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
block-size: 64
max-model-len: 16384
router: kv
Router:
model-name: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
min-workers: 1
VllmWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
conditional-disagg: true
tensor-parallel-size: 1
router: kv
enable-prefix-caching: true
# TODO - set all of these but model as default
PrefillWorker:
model: deepseek-ai/DeepSeek-R1-Distill-Llama-8B
enforce-eager: true
kv-transfer-config: '{"kv_connector":"DynamoNixlConnector"}'
block-size: 64
max-model-len: 16384
max-num-batched-tokens: 16384
cuda-visible-device-offset: 1
...@@ -60,7 +60,6 @@ class Processor(ProcessMixIn): ...@@ -60,7 +60,6 @@ class Processor(ProcessMixIn):
class_name = self.__class__.__name__ class_name = self.__class__.__name__
self.engine_args = parse_vllm_args(class_name, "") self.engine_args = parse_vllm_args(class_name, "")
self.model_config = self.engine_args.create_model_config() self.model_config = self.engine_args.create_model_config()
print(f"[Processor] self.engine_args: {self.engine_args}")
self.tokenizer = self._create_tokenizer(self.engine_args) self.tokenizer = self._create_tokenizer(self.engine_args)
self.chat_processor = ChatProcessor(self.tokenizer, self.model_config) self.chat_processor = ChatProcessor(self.tokenizer, self.model_config)
self.completions_processor = CompletionsProcessor( self.completions_processor = CompletionsProcessor(
...@@ -191,7 +190,7 @@ class Processor(ProcessMixIn): ...@@ -191,7 +190,7 @@ class Processor(ProcessMixIn):
f"Request type {request_type} not implemented" f"Request type {request_type} not implemented"
) )
@dynamo_endpoint() @dynamo_endpoint(name="chat/completions")
async def chat_completions(self, raw_request: ChatCompletionRequest): async def chat_completions(self, raw_request: ChatCompletionRequest):
async for response in self._generate(raw_request, RequestType.CHAT): async for response in self._generate(raw_request, RequestType.CHAT):
yield response yield response
......
...@@ -22,9 +22,7 @@ from dynamo.sdk.lib.config import ServiceConfig ...@@ -22,9 +22,7 @@ from dynamo.sdk.lib.config import ServiceConfig
def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs: def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
config = ServiceConfig.get_instance() config = ServiceConfig.get_instance()
print(f"[DEBUG] config: {config}")
vllm_args = config.as_args(service_name, prefix=prefix) vllm_args = config.as_args(service_name, prefix=prefix)
print(f"[DEBUG] service_name: {service_name}, vllm_args: {vllm_args}")
parser = FlexibleArgumentParser() parser = FlexibleArgumentParser()
parser.add_argument( parser.add_argument(
"--router", "--router",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment