"vscode:/vscode.git/clone" did not exist on "57454300fe30ff365ee307c424ff40e8d2035fa5"
Unverified Commit 5b457b70 authored by Michael Feil's avatar Michael Feil Committed by GitHub
Browse files

feat: python add abi compatibility for cross-platform builds + add a unit test...


feat: python add abi compatibility for cross-platform builds + add a unit test to HttpServer (#3044)
Signed-off-by: default avatarmichaelfeil <me@michaelfeil.eu>
Signed-off-by: default avatarMichael Feil <63565275+michaelfeil@users.noreply.github.com>
Signed-off-by: default avatarroot <root@michaelfeil2-dev-pod-b200-0.michaelfeil2-dev-pod-b200.baseten.svc.cluster.local>
Signed-off-by: default avatarroot <root@michaelfeildns-dev-pod-h100-0.michaelfeildns-dev-pod-h100.baseten.svc.cluster.local>
Co-authored-by: default avatarroot <root@michaelfeil2-dev-pod-b200-0.michaelfeil2-dev-pod-b200.baseten.svc.cluster.local>
Co-authored-by: default avatarroot <root@michaelfeildns-dev-pod-h100-0.michaelfeildns-dev-pod-h100.baseten.svc.cluster.local>
parent 50cdae5f
...@@ -53,13 +53,14 @@ tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } ...@@ -53,13 +53,14 @@ tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
uuid = { version = "1.17", features = ["v4", "serde"] } uuid = { version = "1.17", features = ["v4", "serde"] }
# "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so) # "extension-module" tells pyo3 we want to build an extension module (skips linking against libpython.so)
# "abi3-py39" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.9 # "abi3-py310" tells pyo3 (and maturin) to build using the stable ABI with minimum Python version 3.10, which is the minimum version in pyproject.toml
pyo3 = { version = "0.23.4", default-features = false, features = [ pyo3 = { version = "0.23.4", default-features = false, features = [
"macros", "macros",
"experimental-async", "experimental-async",
"experimental-inspect", "experimental-inspect",
"extension-module", "extension-module",
"py-clone", "py-clone",
# "abi3-py310" # TODO: Add abi feature in follow-up, since docker build can be simplified.
] } ] }
pyo3-async-runtimes = { version = "0.23.0", default-features = false, features = [ pyo3-async-runtimes = { version = "0.23.0", default-features = false, features = [
......
...@@ -92,7 +92,7 @@ def parse_args(): ...@@ -92,7 +92,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"model", "model",
nargs="?", # Make it optional for argparse, we'll validate manually nargs="?", # Make it optional for argparse, we'll validate manually
help="Path to the model (e.g., Qwen/Qwen3-0.6B).\n" "Required unless out=dyn.", help="Path to the model (e.g., Qwen/Qwen3-0.6B).\nRequired unless out=dyn.",
) )
# Parse the arguments that were not 'in=' or 'out=' # Parse the arguments that were not 'in=' or 'out='
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
use std::sync::Arc; use std::sync::Arc;
use pyo3::{exceptions::PyException, prelude::*}; use pyo3::prelude::*;
use crate::{CancellationToken, engine::*, to_pyerr}; use crate::{CancellationToken, engine::*, to_pyerr};
...@@ -102,31 +102,6 @@ impl HttpService { ...@@ -102,31 +102,6 @@ impl HttpService {
} }
} }
/// Python Exception for HTTP errors
///
/// Declared as a Python class that extends the built-in `Exception`
/// (`extends=PyException`); it stores the two fields below.
#[pyclass(extends=PyException)]
pub struct HttpError {
/// HTTP status code carried by the error.
code: u16,
/// Message text carried by the error.
message: String,
}
#[pymethods]
impl HttpError {
#[new]
pub fn new(code: u16, message: String) -> Self {
HttpError { code, message }
}
#[getter]
fn code(&self) -> u16 {
self.code
}
#[getter]
fn message(&self) -> &str {
&self.message
}
}
#[pyclass] #[pyclass]
#[derive(Clone)] #[derive(Clone)]
pub struct HttpAsyncEngine(pub PythonAsyncEngine); pub struct HttpAsyncEngine(pub PythonAsyncEngine);
...@@ -172,18 +147,23 @@ where ...@@ -172,18 +147,23 @@ where
Err(e) => { Err(e) => {
if let Some(py_err) = e.downcast_ref::<PyErr>() { if let Some(py_err) = e.downcast_ref::<PyErr>() {
Python::with_gil(|py| { Python::with_gil(|py| {
if let Ok(http_error_instance) = py_err let err_val = py_err.clone_ref(py).into_value(py);
.clone_ref(py) let bound_err = err_val.bind(py);
.into_value(py)
.extract::<PyRef<HttpError>>(py) // check: Py03 exceptions cannot be cross-compiled, so we duck-type by name
// and fields.
if let Ok(type_name) = bound_err.get_type().name()
&& type_name.to_string().contains("HttpError")
&& let (Ok(code), Ok(message)) =
(bound_err.getattr("code"), bound_err.getattr("message"))
&& let (Ok(code), Ok(message)) =
(code.extract::<u16>(), message.extract::<String>())
{ {
Err(http_error::HttpError { // SSE panics if there are carriage returns or newlines
code: http_error_instance.code, let message = message.replace(['\r', '\n'], "");
message: http_error_instance.message.clone(), return Err(http_error::HttpError { code, message })?;
})?
} else {
Err(error!("Python Error: {}", py_err.to_string()))
} }
Err(error!("Python Error: {}", py_err.to_string()))
}) })
} else { } else {
Err(e) Err(e)
......
...@@ -166,7 +166,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -166,7 +166,6 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::ZmqKvEventPublisherConfig>()?; m.add_class::<llm::kv::ZmqKvEventPublisherConfig>()?;
m.add_class::<llm::kv::KvRecorder>()?; m.add_class::<llm::kv::KvRecorder>()?;
m.add_class::<http::HttpService>()?; m.add_class::<http::HttpService>()?;
m.add_class::<http::HttpError>()?;
m.add_class::<http::HttpAsyncEngine>()?; m.add_class::<http::HttpAsyncEngine>()?;
m.add_class::<context::Context>()?; m.add_class::<context::Context>()?;
m.add_class::<ModelType>()?; m.add_class::<ModelType>()?;
...@@ -179,6 +178,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -179,6 +178,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::KvPushRouter>()?; m.add_class::<llm::kv::KvPushRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?; m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<RouterMode>()?; m.add_class::<RouterMode>()?;
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
m.add_class::<planner::VirtualConnectorCoordinator>()?; m.add_class::<planner::VirtualConnectorCoordinator>()?;
m.add_class::<planner::VirtualConnectorClient>()?; m.add_class::<planner::VirtualConnectorClient>()?;
m.add_class::<planner::PlannerDecision>()?; m.add_class::<planner::PlannerDecision>()?;
...@@ -574,6 +574,11 @@ impl DistributedRuntime { ...@@ -574,6 +574,11 @@ impl DistributedRuntime {
fn event_loop(&self) -> PyObject { fn event_loop(&self) -> PyObject {
self.event_loop.clone() self.event_loop.clone()
} }
/// Creates a child cancellation token from the underlying runtime and wraps
/// it in the Python-facing `CancellationToken`.
fn child_token(&self) -> CancellationToken {
    CancellationToken {
        inner: self.inner.runtime().child_token(),
    }
}
} }
// Bind a TCP port and return a socket held until dropped. // Bind a TCP port and return a socket held until dropped.
......
...@@ -53,6 +53,26 @@ class DistributedRuntime: ...@@ -53,6 +53,26 @@ class DistributedRuntime:
""" """
... ...
def child_token(self) -> CancellationToken:
    """
    Return a child cancellation token that can be handed to async tasks.
    """
    ...
class CancellationToken:
    def cancel(self) -> None:
        """
        Cancel this token together with all of its child tokens.
        """
        ...

    async def cancelled(self) -> None:
        """
        Wait until this token has been cancelled.
        """
        ...
class Namespace: class Namespace:
""" """
A namespace is a collection of components A namespace is a collection of components
...@@ -704,13 +724,6 @@ class HttpService: ...@@ -704,13 +724,6 @@ class HttpService:
... ...
class HttpError:
    """
    Error raised from within the HTTP service.
    """

    ...
class HttpAsyncEngine: class HttpAsyncEngine:
""" """
An async engine for a distributed Dynamo http service. This is an extension of the An async engine for a distributed Dynamo http service. This is an extension of the
......
...@@ -20,7 +20,6 @@ from dynamo._core import EngineType ...@@ -20,7 +20,6 @@ from dynamo._core import EngineType
from dynamo._core import EntrypointArgs as EntrypointArgs from dynamo._core import EntrypointArgs as EntrypointArgs
from dynamo._core import ForwardPassMetrics as ForwardPassMetrics from dynamo._core import ForwardPassMetrics as ForwardPassMetrics
from dynamo._core import HttpAsyncEngine as HttpAsyncEngine from dynamo._core import HttpAsyncEngine as HttpAsyncEngine
from dynamo._core import HttpError as HttpError
from dynamo._core import HttpService as HttpService from dynamo._core import HttpService as HttpService
from dynamo._core import KvEventPublisher as KvEventPublisher from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer from dynamo._core import KvIndexer as KvIndexer
...@@ -47,3 +46,5 @@ from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for ...@@ -47,3 +46,5 @@ from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for
from dynamo._core import make_engine from dynamo._core import make_engine
from dynamo._core import register_llm as register_llm from dynamo._core import register_llm as register_llm
from dynamo._core import run_input from dynamo._core import run_input
from .exceptions import HttpError
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
class HttpError(Exception):
    """Exception carrying an HTTP status code and a human-readable message.

    Args:
        code: HTTP status code; must be an ``int`` in the range 0-599.
        message: Non-empty error message of at most 8192 characters.

    Raises:
        ValueError: If ``code`` or ``message`` fails the validation above.
    """

    def __init__(self, code: int, message: str):
        if not (isinstance(code, int) and 0 <= code < 600):
            raise ValueError("HTTP status code must be an integer between 0 and 599")
        # An empty message is rejected as well, so the error text states both
        # bounds instead of only the upper one.
        if not (isinstance(message, str) and 0 < len(message) <= 8192):
            raise ValueError(
                "HTTP error message must be a non-empty string of length <= 8192"
            )
        self.code = code
        self.message = message
        super().__init__(f"HTTP {code}: {message}")
...@@ -24,3 +24,9 @@ def test_bindings_install(): ...@@ -24,3 +24,9 @@ def test_bindings_install():
# Placeholder to avoid unused import errors or removal by linters # Placeholder to avoid unused import errors or removal by linters
assert tdr assert tdr
def test_version():
    """The extension module exposes a version string that begins with a digit."""
    from dynamo._core import __version__ as core_version

    leading_char = core_version[0]
    assert leading_char.isdigit()  # semver should start with a digit
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
from dynamo.llm import HttpError
pytestmark = pytest.mark.pre_merge
def test_raise_http_error():
    """HttpError can be caught both as itself and as a generic Exception."""
    # Caught by its own type.
    with pytest.raises(HttpError):
        raise HttpError(404, "Not Found")

    # Caught through the Exception base class.
    with pytest.raises(Exception):
        raise HttpError(500, "Internal Server Error")
def test_invalid_http_error_code():
    """Constructing HttpError with an out-of-range status code raises ValueError."""
    with pytest.raises(ValueError):
        HttpError(1700, "Invalid Code")
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This test verifies that the HTTP server can be started and responds correctly to requests.
import asyncio
import json
import time
from typing import AsyncGenerator, Dict
import aiohttp
import pytest
from dynamo.llm import HttpAsyncEngine, HttpError, HttpService
from dynamo.runtime import DistributedRuntime
MSG_CONTAINS_ERROR = "This message contains an 400error."
MSG_CONTAINS_INTERNAL_ERROR = "This message contains an internal server error."
pytestmark = pytest.mark.pre_merge
class MockHttpEngine:
    """Mock engine that either streams a canned completion or raises an error."""

    def __init__(self, model_name: str = "test_model"):
        self.model_name = model_name

    async def generate(self, request: Dict, context) -> AsyncGenerator[Dict, None]:
        """
        Stream a mock chat completion, one character per chunk.

        Raises HttpError (code 400) when the first user message contains
        MSG_CONTAINS_ERROR, and ValueError when it contains
        MSG_CONTAINS_INTERNAL_ERROR (both compared case-insensitively).
        Returns without yielding if the context is already stopped.
        """
        # Content of the first message whose role is "user" ("" when absent).
        user_message = next(
            (
                entry.get("content", "")
                for entry in request.get("messages", [])
                if entry.get("role") == "user"
            ),
            "",
        )

        # verifies that cancellation is propagated
        if context.is_stopped():
            print(f"Request {context.id()} was cancelled before starting.")
            return

        lowered = user_message.lower()
        if MSG_CONTAINS_ERROR.lower() in lowered:
            raise HttpError(code=400, message=MSG_CONTAINS_ERROR)
        if MSG_CONTAINS_INTERNAL_ERROR.lower() in lowered:
            raise ValueError("Simulated internal error")

        # Stream a mock response
        created = int(time.time())
        response_text = "This is a mock response."
        last_index = len(response_text) - 1
        for index, char in enumerate(response_text):
            choice = {
                "index": 0,
                "delta": {"content": char},
                # Only the final chunk carries a finish reason.
                "finish_reason": "stop" if index == last_index else None,
            }
            yield {
                "id": f"chatcmpl-{context.id()}",
                "object": "chat.completion.chunk",
                "created": created,
                "model": self.model_name,
                "choices": [choice],
            }
            await asyncio.sleep(0.01)
@pytest.fixture(scope="function", autouse=False)
async def http_server(runtime: DistributedRuntime):
    """Fixture to start a mock HTTP server using HttpService, contributed by Baseten.

    Yields:
        A ``(base_url, model_name)`` tuple for the running server.

    Raises:
        ValueError: If the server task has already failed by the time startup
            is signalled.
    """
    port = 8008
    model_name = "test_model"
    start_done = asyncio.Event()
    child_token = runtime.child_token()

    async def worker():
        """The server worker task."""
        try:
            loop = asyncio.get_running_loop()
            python_engine = MockHttpEngine(model_name)
            engine = HttpAsyncEngine(python_engine.generate, loop)
            service = HttpService(port=port)
            service.add_chat_completions_model(model_name, engine)
            service.enable_endpoint("chat", True)
            shutdown_signal = service.run(child_token)
            print("Starting service on port", port)
            start_done.set()
            await shutdown_signal
        except Exception as e:
            print("Server encountered an error:", e)
            # Unblock the fixture even on failure so it does not wait for the timeout.
            start_done.set()
            # Chain explicitly so the original traceback is kept as the cause.
            raise ValueError(f"Server failed to start: {e}") from e
        finally:
            child_token.cancel()

    server_task = asyncio.create_task(worker())
    await asyncio.wait_for(start_done.wait(), timeout=30.0)
    if server_task.done() and server_task.exception():
        raise ValueError(f"Server task failed to start {server_task.exception()}")

    yield f"http://localhost:{port}", model_name

    # Teardown: Cancel the server task if it's still running
    child_token.cancel()
    await asyncio.sleep(0.1)  # Give some time for graceful shutdown
    if not server_task.done():
        server_task.cancel()
        try:
            # Await cancellation to ensure proper cleanup for up to 10s
            await asyncio.wait_for(server_task, timeout=10.0)
        except asyncio.CancelledError:
            print("Server task cancelled during teardown.")
@pytest.mark.asyncio
@pytest.mark.forked
async def test_chat_completion_success(http_server):
    """Streams a chat completion and checks the reassembled response text."""
    base_url, model_name = http_server
    endpoint = f"{base_url}/v1/chat/completions"
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": "Hello, this is a test."}],
        "stream": True,
    }
    sse_prefix = b"data: "
    pieces = []

    client_timeout = aiohttp.ClientTimeout(total=5)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.post(endpoint, json=payload) as response:
            response.raise_for_status()
            async for line in response.content:
                # Only server-sent-event data lines carry chunks.
                if not line.startswith(sse_prefix):
                    continue
                chunk_data = line[len(sse_prefix) :]
                if chunk_data.strip() == b"[DONE]":
                    break
                choices = json.loads(chunk_data)["choices"]
                if choices and choices[0]["delta"] and choices[0]["delta"].get("content"):
                    pieces.append(choices[0]["delta"]["content"])

    assert "".join(pieces) == "This is a mock response."
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "msg_to_code",
    [
        (MSG_CONTAINS_ERROR, 500),  # TODO: should be 400, but currently 500
        # Placeholder for future internal error test
        (MSG_CONTAINS_INTERNAL_ERROR, 500),
    ],
)
@pytest.mark.forked
async def test_chat_completion_http_error(http_server, msg_to_code: tuple[str, int]):
    """Checks the status code and error body returned for error-triggering messages."""
    base_url, model_name = http_server
    user_content, expected_status = msg_to_code
    endpoint = f"{base_url}/v1/chat/completions"
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": user_content}],
    }

    client_timeout = aiohttp.ClientTimeout(total=10)
    async with aiohttp.ClientSession(timeout=client_timeout) as session:
        async with session.post(endpoint, json=payload) as response:
            assert response.status == expected_status
            error_json = await response.json()

    if user_content == MSG_CONTAINS_ERROR:
        assert MSG_CONTAINS_ERROR in str(error_json)
    elif user_content == MSG_CONTAINS_INTERNAL_ERROR:
        assert "a python exception was caught" in str(error_json).lower()
...@@ -37,10 +37,14 @@ pytestmark = pytest.mark.pre_merge ...@@ -37,10 +37,14 @@ pytestmark = pytest.mark.pre_merge
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
async def distributed_runtime(): async def distributed_runtime():
"""TODO: This should not use scope='module' as DistributedRuntime has singleton requirements.
and blocks any tests with DistributedRuntime(loop, True) from running in the same process, or any forked process.
"""
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
return DistributedRuntime(loop, False) return DistributedRuntime(loop, False)
# TODO: enable pytest.mark.forked + scope='function' runtime.
async def test_radix_tree_binding(distributed_runtime): async def test_radix_tree_binding(distributed_runtime):
"""Test RadixTree binding directly with store event and find matches""" """Test RadixTree binding directly with store event and find matches"""
import json import json
...@@ -153,6 +157,7 @@ async def test_event_handler(distributed_runtime): ...@@ -153,6 +157,7 @@ async def test_event_handler(distributed_runtime):
), f"Scores still present after {(retry+1)*0.5}s: {scores.scores}" ), f"Scores still present after {(retry+1)*0.5}s: {scores.scores}"
# TODO: enable pytest.mark.forked + scope='function' runtime.
async def test_approx_kv_indexer(distributed_runtime): async def test_approx_kv_indexer(distributed_runtime):
kv_block_size = 32 kv_block_size = 32
namespace = "kv_test" namespace = "kv_test"
...@@ -210,6 +215,7 @@ class EventPublisher: ...@@ -210,6 +215,7 @@ class EventPublisher:
self.event_id_counter += 1 self.event_id_counter += 1
# TODO: enable pytest.mark.forked + scope='function' runtime.
async def test_metrics_aggregator(distributed_runtime): async def test_metrics_aggregator(distributed_runtime):
namespace = "kv_test" namespace = "kv_test"
component = "metrics" component = "metrics"
......
...@@ -5,7 +5,11 @@ from typing import Optional ...@@ -5,7 +5,11 @@ from typing import Optional
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import torch
try:
import torch
except ImportError:
pass
try: try:
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
...@@ -75,7 +79,7 @@ def make_request( ...@@ -75,7 +79,7 @@ def make_request(
) )
def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig: def make_kv_cache_config(block_size: int, num_blocks: int) -> "KVCacheConfig":
return KVCacheConfig( return KVCacheConfig(
num_blocks=num_blocks, num_blocks=num_blocks,
tensors={}, tensors={},
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment