Commit cc086811 authored by ishandhanani, committed by GitHub
Browse files

feat(sdk): pass in CLI args when running `serve` (#78)

parent 30c5a79f
......@@ -15,6 +15,8 @@
from __future__ import annotations
import collections
import json
import logging
import os
import sys
......@@ -60,6 +62,52 @@ def deprecated_option(*param_decls: str, **attrs: t.Any):
return decorator
def _parse_service_arg(arg: str) -> tuple[str, str, t.Any] | None:
"""Parse a single CLI argument into service name, key, and value."""
if not (arg.startswith("--") and "=" in arg):
return None
# Remove leading dashes
param = arg[2:]
key_path, value_str = param.split("=", 1)
if "." not in key_path:
return None
service, key = key_path.split(".", 1)
# Parse value based on type
try:
# Try as JSON for complex types
value = json.loads(value_str)
except json.JSONDecodeError:
# Handle basic types
if value_str.isdigit():
value = int(value_str)
elif value_str.replace(".", "", 1).isdigit() and value_str.count(".") <= 1:
value = float(value_str)
elif value_str.lower() in ("true", "false"):
value = value_str.lower() == "true"
else:
value = value_str
return service, key, value
def _parse_service_args(args: list[str]) -> t.Dict[str, t.Any] | None:
    """Collect service-specific options from a list of CLI arguments.

    Each argument is handed to ``_parse_service_arg``; arguments it does not
    recognise are skipped. Recognised ones are grouped as
    ``{service: {key: value}}``, and a key given more than once keeps the
    value from its last occurrence.

    Returns:
        The grouping as a ``collections.defaultdict`` (empty when nothing
        matched).
    """
    grouped: t.DefaultDict[str, t.Dict[str, t.Any]] = collections.defaultdict(dict)

    for parsed in map(_parse_service_arg, args):
        if not parsed:
            continue
        service, key, value = parsed
        grouped[service][key] = value

    return grouped
def build_serve_command() -> click.Group:
from bentoml._internal.log import configure_server_logging
from bentoml_cli.env_manager import env_manager
......@@ -69,7 +117,14 @@ def build_serve_command() -> click.Group:
def cli():
pass
@cli.command(aliases=["serve-http"], cls=AliasCommand)
@cli.command(
context_settings=dict(
ignore_unknown_options=True,
allow_extra_args=True,
),
aliases=["serve-http"],
cls=AliasCommand,
)
@click.argument("bento", type=click.STRING, default=".")
@click.option(
"--development",
......@@ -203,8 +258,10 @@ def build_serve_command() -> click.Group:
show_default=True,
hidden=True,
)
@click.pass_context
@env_manager
def serve(
ctx: click.Context,
bento: str,
development: bool,
port: int,
......@@ -227,6 +284,9 @@ def build_serve_command() -> click.Group:
) -> None:
"""Start a HTTP BentoServer from a given 🍱
\b
You can also pass service-specific configuration options using --ServiceName.param=value format.
\b
BENTO is the serving target, it can be the import as:
- the import path of a 'bentoml.Service' instance
......@@ -261,6 +321,13 @@ def build_serve_command() -> click.Group:
from bentoml import Service
from bentoml._internal.service.loader import load
# Process service-specific options
service_configs = _parse_service_args(ctx.args)
# Set environment variable with service configuration
if service_configs:
os.environ["DYNAMO_SERVICE_CONFIG"] = json.dumps(service_configs)
configure_server_logging()
if working_dir is None:
if os.path.isdir(os.path.expanduser(bento)):
......
......@@ -188,13 +188,18 @@ def create_dynamo_watcher(
if worker_envs:
args.extend(["--worker-env", json.dumps(worker_envs)])
# Update env to include ServiceConfig
worker_env = env.copy() if env else {}
if "DYNAMO_SERVICE_CONFIG" in os.environ:
worker_env["DYNAMO_SERVICE_CONFIG"] = os.environ["DYNAMO_SERVICE_CONFIG"]
# Create the watcher with dependency map in environment
watcher = create_watcher(
name=f"dynamo_service_{svc.name}",
args=args,
numprocesses=num_workers,
working_dir=working_dir,
env=env, # Dependency map will be injected by serve_http
env=worker_env, # Use updated environment
)
return watcher, socket, uri
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
class ServiceConfig(dict):
    """Configuration store that inherits from dict for simpler access patterns.

    The mapping has the shape ``{service_name: {key: value}}`` and is loaded
    from the JSON text held in the ``DYNAMO_SERVICE_CONFIG`` environment
    variable.
    """

    # Cached instance handed out by get_instance(); built on first use.
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the cached instance, loading it from the environment once."""
        if cls._instance is None:
            cls._instance = cls._load_from_env()
        return cls._instance

    @classmethod
    def _load_from_env(cls):
        """Build an instance from ``DYNAMO_SERVICE_CONFIG``.

        Returns:
            An instance holding the decoded JSON object, or an empty instance
            when the variable is unset, empty, not valid JSON, or valid JSON
            that is not an object.
        """
        configs = {}
        env_config = os.environ.get("DYNAMO_SERVICE_CONFIG")
        if env_config:
            try:
                loaded = json.loads(env_config)
            except json.JSONDecodeError:
                print("Failed to parse DYNAMO_SERVICE_CONFIG")
            else:
                # Valid JSON is not necessarily an object; a list or scalar
                # cannot initialise the dict and would raise TypeError below.
                if isinstance(loaded, dict):
                    configs = loaded
                else:
                    print("DYNAMO_SERVICE_CONFIG must be a JSON object; ignoring it")
        return cls(configs)  # Initialize dict subclass with configs

    def require(self, service_name, key):
        """Return ``self[service_name][key]``.

        Raises:
            ValueError: If the service or the key is missing.
        """
        if service_name not in self or key not in self[service_name]:
            raise ValueError(f"{service_name}.{key} must be specified in configuration")
        return self[service_name][key]

    def as_args(self, service_name, prefix=""):
        """Render a service's config entries as a flat list of CLI arguments.

        Args:
            service_name: Service whose entries are rendered.
            prefix: When non-empty, only keys starting with it are included
                and the prefix is stripped from the emitted option name.

        Returns:
            A list such as ``["--key", "value", "--flag"]``; empty when the
            service has no entries. ``True`` becomes a bare ``--key``,
            ``False`` is omitted, and any other value is emitted as
            ``--key`` followed by ``str(value)``.
        """
        if service_name not in self:
            return []

        args = []
        for key, value in self[service_name].items():
            if prefix and not key.startswith(prefix):
                continue

            # Every key reaching this point starts with prefix (or prefix is
            # empty), so slicing by its length strips exactly the prefix.
            arg_key = key[len(prefix):]

            # Convert to CLI format
            if isinstance(value, bool):
                if value:
                    args.append(f"--{arg_key}")
            else:
                args.extend([f"--{arg_key}", str(value)])

        return args
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import subprocess
import time
import pytest
pytestmark = pytest.mark.pre_merge
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown():
    """Start nats-server, etcd and the pipeline under test for this module.

    The three processes are launched with ``subprocess.Popen`` and the fixture
    then sleeps for five seconds before yielding to the tests. Afterwards every
    process that was started is terminated and waited on, most recently
    started first.
    """
    # Track what was actually launched so that a failure part-way through
    # setup (e.g. a missing binary) still stops the processes started so far.
    started = []
    try:
        # Setup code
        started.append(subprocess.Popen(["nats-server", "-js"]))
        started.append(subprocess.Popen(["etcd"]))
        print("Setting up resources")
        started.append(
            subprocess.Popen(
                [
                    "dynamo-sdk",
                    "serve",
                    "pipeline:Frontend",
                    "--working-dir",
                    "deploy/dynamo/sdk/src/dynamo/sdk/tests",
                    "--Frontend.model=qwentastic",
                    "--Middle.bias=0.5",
                ]
            )
        )

        time.sleep(5)

        yield
    finally:
        # Teardown code
        print("Tearing down resources")
        for proc in reversed(started):
            proc.terminate()
            proc.wait()
async def test_pipeline():
    """POST a prompt to ``/generate`` and check the reply passed through the pipeline.

    The request is attempted up to five times, one second apart. Any exception,
    including a failed assertion, triggers a retry; the exception from the
    final attempt is re-raised.
    """
    import asyncio

    import aiohttp

    max_retries = 5
    final_attempt = max_retries - 1

    for attempt in range(max_retries):
        try:
            async with aiohttp.ClientSession() as session:
                request = session.post(
                    "http://localhost:3000/generate",
                    json={"text": "federer-is-the-greatest-tennis-player-of-all-time"},
                    headers={"accept": "text/event-stream"},
                )
                async with request as resp:
                    assert resp.status == 200
                    text = await resp.text()
                    assert (
                        "federer-is-the-greatest-tennis-player-of-all-time-mid-back"
                        in text
                    )
            # Success: no further attempts needed.
            return
        except Exception:
            if attempt == final_attempt:
                raise
            print(f"Attempt {attempt + 1} failed, retrying...")
            await asyncio.sleep(1)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This is a simple example of a pipeline that uses Dynamo to deploy a backend, middle, and frontend service. Use this to test
# changes made to CLI, SDK, etc
import os
from pydantic import BaseModel
from dynamo.sdk import api, depends, dynamo_endpoint, service
from dynamo.sdk.lib.config import ServiceConfig
"""
Pipeline Architecture:
Users/Clients (HTTP)
┌─────────────┐
│ Frontend │ HTTP API endpoint (/generate)
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Middle │
└─────────────┘
│ dynamo/runtime
┌─────────────┐
│ Backend │
└─────────────┘
"""
class RequestType(BaseModel):
    # Request payload used by the generate() endpoints in this file: each
    # stage builds a RequestType and sends its JSON dump to the next stage.
    # text: the string carried through the pipeline.
    text: str
class ResponseType(BaseModel):
    # Response payload model with a single string field.
    # NOTE(review): not referenced by the services visible in this file.
    text: str
GPU_ENABLED = False
class FrontendConfig(BaseModel):
    # Options for the Frontend service; Frontend.__init__ builds this from the
    # "Frontend" section of ServiceConfig.
    # model has no default, so it must be supplied in that section.
    model: str
    # The remaining fields fall back to these defaults when not supplied.
    temperature: float = 0.7
    max_tokens: int = 1024
    stream: bool = True
class MiddleConfig(BaseModel):
    # Options for the Middle service; Middle.__init__ builds this from the
    # "Middle" section of ServiceConfig.
    # bias has no default, so it must be supplied in that section.
    bias: float
@service(
    resources={"cpu": "2"},
    traffic={"timeout": 30},
    dynamo={"enabled": True, "namespace": "inference"},
    workers=3,
)
class Backend:
    # Last stage of the pipeline: appends "-back" to the incoming text and
    # streams the result back one whitespace-separated token at a time.

    def __init__(self) -> None:
        print("Starting backend")

    @dynamo_endpoint()
    async def generate(self, req: RequestType):
        """Generate tokens."""
        req_text = req.text
        print(f"Backend received: {req_text}")

        text = f"{req_text}-back"
        tokens = text.split()
        for token in tokens:
            yield f"Backend: {token}"
@service(
    resources={"cpu": "2"},
    traffic={"timeout": 30},
    dynamo={"enabled": True, "namespace": "inference"},
)
class Middle:
    # Middle stage of the pipeline: appends "-mid" to the incoming text,
    # forwards it to Backend and relays each streamed response.

    backend = depends(Backend)

    def __init__(self) -> None:
        """Load the "Middle" config section and, if GPU_ENABLED, set up a vLLM engine."""
        print("Starting middle")
        config = ServiceConfig.get_instance()
        # MiddleConfig.bias has no default, so the "Middle" section must
        # provide it for this construction to succeed.
        middle_config = MiddleConfig(**config.get("Middle", {}))
        print(f"bias: {middle_config.bias}")

        if GPU_ENABLED:
            try:
                # Imported inside the try so that a missing vllm install is
                # reported by the ImportError handler below instead of
                # aborting service start-up.
                from vllm.engine.arg_utils import AsyncEngineArgs
                from vllm.engine.async_llm_engine import AsyncLLMEngine
                from vllm.utils import FlexibleArgumentParser

                os.environ["VLLM_LOG_LEVEL"] = "DEBUG"

                # Get VLLM args using new pattern
                vllm_args = config.as_args("Middle", prefix="vllm_")
                print(f"VLLM args to parse: {vllm_args}")

                # Create and use parser
                parser = FlexibleArgumentParser()
                parser = AsyncEngineArgs.add_cli_args(parser)
                args = parser.parse_args(vllm_args)
                self.engine_args = AsyncEngineArgs.from_cli_args(args)
                self.engine = AsyncLLMEngine.from_engine_args(self.engine_args)
            except ImportError:
                print("VLLM imports not available, skipping engine arg parsing")
            except Exception as e:
                # Best-effort: engine setup problems are reported, not raised.
                print(f"Error parsing VLLM args: {e}")

    @dynamo_endpoint()
    async def generate(self, req: RequestType):
        """Forward requests to backend."""
        req_text = req.text
        print(f"Middle received: {req_text}")
        text = f"{req_text}-mid"
        next_request = RequestType(text=text).model_dump_json()
        async for response in self.backend.generate(next_request):
            print(f"Middle received response: {response}")
            yield f"Middle: {response}"
@service(resources={"cpu": "1"}, traffic={"timeout": 60})  # Regular HTTP API
class Frontend:
    # Entry stage of the pipeline: receives text over the HTTP API, wraps it
    # in a RequestType and relays what the Middle service streams back.

    middle = depends(Middle)

    def __init__(self) -> None:
        """Load the service configuration and print the Frontend settings."""
        print("Starting frontend")
        self.config = ServiceConfig.get_instance()

        # Get all configs for a service (new dict pattern)
        all_frontend_configs = self.config.get("Frontend", {})
        frontend_config = FrontendConfig(**all_frontend_configs)
        print(
            f"Frontend initialized with model={frontend_config.model}, "
            f"temp={frontend_config.temperature}, max_tokens={frontend_config.max_tokens}"
        )
        print(f"All Frontend configs: {all_frontend_configs}")

        # Check other service configs (new dict pattern)
        middle_section = self.config.get("Middle", {})
        special_mode = middle_section.get("special_mode")
        if special_mode == "fast":
            print("Using Middle service in fast mode")

    @api
    async def generate(self, text):
        """Stream results from the pipeline."""
        print(f"Frontend received: {text}")
        print(f"Frontend received type: {type(text)}")
        txt = RequestType(text=text)
        print(f"Frontend sending: {type(txt)}")

        payload = txt.model_dump_json()
        async for response in self.middle.generate(payload):
            yield f"Frontend: {response}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment