Unverified Commit c0ed76da authored by ishandhanani's avatar ishandhanani Committed by GitHub
Browse files

test: add kv routing test for sglang (#2424)


Signed-off-by: default avatarishandhanani <82981111+ishandhanani@users.noreply.github.com>
Co-authored-by: default avatarNeelay Shah <neelays@nvidia.com>
parent c3ecaf6c
...@@ -141,6 +141,9 @@ cd $DYNAMO_HOME/components/backends/sglang ...@@ -141,6 +141,9 @@ cd $DYNAMO_HOME/components/backends/sglang
### Aggregated Serving with KV Routing ### Aggregated Serving with KV Routing
> [!NOTE]
> Until sglang releases a version newer than v0.5.0rc0, you must install sglang from source to use KV routing. To do so, run `git clone https://github.com/sgl-project/sglang.git && cd sglang && uv pip install -e "python[all]"`. We will update this section once sglang releases a newer version.
```bash ```bash
cd $DYNAMO_HOME/components/backends/sglang cd $DYNAMO_HOME/components/backends/sglang
./launch/agg_router.sh ./launch/agg_router.sh
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
# Setup cleanup trap # Setup cleanup trap
cleanup() { cleanup() {
echo "Cleaning up background processes..." echo "Cleaning up background processes..."
kill $DYNAMO_PID 2>/dev/null || true kill $DYNAMO_PID $WORKER_PID 2>/dev/null || true
wait $DYNAMO_PID 2>/dev/null || true wait $DYNAMO_PID $WORKER_PID 2>/dev/null || true
echo "Cleanup complete." echo "Cleanup complete."
} }
trap cleanup EXIT INT TERM trap cleanup EXIT INT TERM
...@@ -26,4 +26,14 @@ python3 -m dynamo.sglang.worker \ ...@@ -26,4 +26,14 @@ python3 -m dynamo.sglang.worker \
--tp 1 \ --tp 1 \
--trust-remote-code \ --trust-remote-code \
--skip-tokenizer-init \ --skip-tokenizer-init \
--kv-events-config '{"publisher": "zmq", "topic": "kv-events"}' --kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5557"}' &
WORKER_PID=$!
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.sglang.worker \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--served-model-name deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--page-size 16 \
--tp 1 \
--trust-remote-code \
--skip-tokenizer-init \
--kv-events-config '{"publisher":"zmq","topic":"kv-events","endpoint":"tcp://*:5558"}'
\ No newline at end of file
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
import json
import logging import logging
import random import random
import signal import signal
...@@ -357,11 +358,18 @@ async def init( ...@@ -357,11 +358,18 @@ async def init(
handler.setup_metrics() handler.setup_metrics()
# Set up ZMQ kv event publisher # Set up ZMQ kv event publisher
zmq_config = ZmqKvEventPublisherConfig( if server_args.kv_events_config:
worker_id=endpoint.lease_id(), kv_events = json.loads(server_args.kv_events_config)
kv_block_size=server_args.page_size, ep = kv_events.get("endpoint")
) zmq_ep = ep.replace("*", get_ip()) if ep else None
_ = ZmqKvEventPublisher(component=component, config=zmq_config)
zmq_config = ZmqKvEventPublisherConfig(
worker_id=endpoint.lease_id(),
kv_block_size=server_args.page_size,
zmq_endpoint=zmq_ep,
)
logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
_ = ZmqKvEventPublisher(component=component, config=zmq_config)
tasks = [endpoint.serve_endpoint(handler.generate)] tasks = [endpoint.serve_endpoint(handler.generate)]
......
...@@ -95,3 +95,11 @@ via ```./container/build.sh --framework X``` and run via ...@@ -95,3 +95,11 @@ via ```./container/build.sh --framework X``` and run via
The tests will automatically use a local cache at `~/.cache/huggingface` to avoid The tests will automatically use a local cache at `~/.cache/huggingface` to avoid
repeated downloads of model files. This cache is shared across test runs to improve performance. repeated downloads of model files. This cache is shared across test runs to improve performance.
## Running tests locally outside of a container
To run tests outside of the development container, ensure that you have properly set up your environment and installed the following dependencies in your `venv`:
```bash
uv pip install pytest-mypy
uv pip install pytest-asyncio
```
\ No newline at end of file
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import logging import logging
import os import os
import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List from typing import Any, List
...@@ -14,6 +15,31 @@ from tests.utils.managed_process import ManagedProcess ...@@ -14,6 +15,31 @@ from tests.utils.managed_process import ManagedProcess
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def validate_log_patterns(log_file, patterns):
    """Check that every expected regex pattern occurs in a log file.

    Args:
        log_file: Path of the log file to inspect.
        patterns: Iterable of regular expressions; each one must match
            somewhere in the file content. One-shot iterables (for example
            generators) are supported because the patterns are consumed in
            a single pass.

    Returns:
        True when every pattern is found.

    Raises:
        AssertionError: If the log file does not exist, if the path is a
            directory, or if one or more patterns have no match. In the
            last case the message lists the missing patterns and includes
            the tail of the log for debugging.
    """
    if not os.path.exists(log_file):
        raise AssertionError(f"Log file not found: {log_file}")
    # A directory passes the existence check but cannot be read as a log;
    # report it as a validation failure rather than leaking an OSError.
    if os.path.isdir(log_file):
        raise AssertionError(f"Log path is a directory, not a file: {log_file}")

    # errors="ignore" drops undecodable bytes so stray binary output in the
    # log does not abort validation.
    with open(log_file, "r", encoding="utf-8", errors="ignore") as f:
        content = f.read()

    # Walk the patterns exactly once: iterating twice (compile, then match)
    # would silently validate nothing when given a generator.
    missing = [pattern for pattern in patterns if not re.search(pattern, content)]

    if missing:
        # Include the last 1000 characters of the log for debugging; slicing
        # already copes with content shorter than that.
        sample = content[-1000:]
        raise AssertionError(
            f"Missing expected log patterns: {missing}\n\nLog sample:\n{sample}"
        )

    return True
@dataclass @dataclass
class SGLangConfig: class SGLangConfig:
"""Configuration for SGLang test scenarios""" """Configuration for SGLang test scenarios"""
...@@ -28,7 +54,9 @@ class SGLangProcess(ManagedProcess): ...@@ -28,7 +54,9 @@ class SGLangProcess(ManagedProcess):
def __init__(self, script_name, request): def __init__(self, script_name, request):
self.port = 8000 self.port = 8000
sglang_dir = "/workspace/components/backends/sglang" sglang_dir = os.environ.get(
"SGLANG_DIR", "/workspace/components/backends/sglang"
)
script_path = os.path.join(sglang_dir, "launch", script_name) script_path = os.path.join(sglang_dir, "launch", script_name)
# Verify script exists # Verify script exists
...@@ -38,8 +66,17 @@ class SGLangProcess(ManagedProcess): ...@@ -38,8 +66,17 @@ class SGLangProcess(ManagedProcess):
# Make script executable and run it # Make script executable and run it
command = ["bash", script_path] command = ["bash", script_path]
# Focus kv-router logs for kv_events run
env = os.environ.copy()
if script_name == "agg_router.sh":
env.setdefault(
"DYN_LOG",
"dynamo_llm::kv_router::publisher=trace,dynamo_llm::kv_router::scheduler=info",
)
super().__init__( super().__init__(
command=command, command=command,
env=env,
timeout=900, timeout=900,
display_output=True, display_output=True,
working_dir=sglang_dir, working_dir=sglang_dir,
...@@ -50,7 +87,6 @@ class SGLangProcess(ManagedProcess): ...@@ -50,7 +87,6 @@ class SGLangProcess(ManagedProcess):
delayed_start=60, # Give SGLang more time to fully start delayed_start=60, # Give SGLang more time to fully start
terminate_existing=False, terminate_existing=False,
stragglers=[], # Don't kill any stragglers automatically stragglers=[], # Don't kill any stragglers automatically
log_dir=request.node.name,
) )
def _check_models_api(self, response): def _check_models_api(self, response):
...@@ -72,6 +108,9 @@ sglang_configs = { ...@@ -72,6 +108,9 @@ sglang_configs = {
"disaggregated": SGLangConfig( "disaggregated": SGLangConfig(
script_name="disagg.sh", marks=[pytest.mark.gpu_2], name="disaggregated" script_name="disagg.sh", marks=[pytest.mark.gpu_2], name="disaggregated"
), ),
"kv_events": SGLangConfig(
script_name="agg_router.sh", marks=[pytest.mark.gpu_2], name="kv_events"
),
} }
...@@ -79,6 +118,7 @@ sglang_configs = { ...@@ -79,6 +118,7 @@ sglang_configs = {
params=[ params=[
pytest.param("aggregated", marks=[pytest.mark.gpu_1]), pytest.param("aggregated", marks=[pytest.mark.gpu_1]),
pytest.param("disaggregated", marks=[pytest.mark.gpu_2]), pytest.param("disaggregated", marks=[pytest.mark.gpu_2]),
pytest.param("kv_events", marks=[pytest.mark.gpu_2]),
] ]
) )
def sglang_config_test(request): def sglang_config_test(request):
...@@ -104,28 +144,50 @@ def test_sglang_deployment(request, runtime_services, sglang_config_test): ...@@ -104,28 +144,50 @@ def test_sglang_deployment(request, runtime_services, sglang_config_test):
with SGLangProcess(config.script_name, request) as server: with SGLangProcess(config.script_name, request) as server:
# Test chat completions # Test chat completions
response = requests.post( prompts = [
f"http://localhost:{server.port}/v1/chat/completions", "why is roger federer the best tennis player of all time?",
json={ "why is novak djokovic not the best tennis player of all time?",
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", "why is rafa nadal a sneaky good grass court player?",
"messages": [ "explain the difference between federer and nadal's backhand.",
{ "who is the most clutch tennis player in history?",
"role": "user", ]
"content": "Why is Roger Federer the best tennis player of all time?", responses = []
} for prompt in prompts:
], response = requests.post(
"max_tokens": 50, f"http://localhost:{server.port}/v1/chat/completions",
}, json={
timeout=120, "model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
) "messages": [
{
assert response.status_code == 200 "role": "user",
result = response.json() "content": prompt,
assert "choices" in result }
assert len(result["choices"]) > 0 ],
content = result["choices"][0]["message"]["content"] "max_tokens": 50,
assert len(content) > 0 },
logger.info(f"SGLang {config.name} response: {content}") timeout=120,
)
assert response.status_code == 200
result = response.json()
assert "choices" in result
assert len(result["choices"]) > 0
content = result["choices"][0]["message"]["content"]
assert len(content) > 0
responses.append(content)
logger.info(f"SGLang {config.name} response: {content}")
# For kv_events (KV routing path), assert KV publisher/scheduler log lines appear
if config.name == "kv_events":
log_file = os.path.join(server.log_dir, "bash.log.txt")
assert os.path.exists(log_file), f"Log file not found: {log_file}"
patterns = [
r"ZMQ listener .* received batch with \d+ events \(seq=\d+\)",
r"Event processor for worker_id \d+ processing event: Stored\(",
r"Selected worker: \d+, logit: ",
]
validate_log_patterns(log_file, patterns)
# Test completions endpoint for disaggregated only # Test completions endpoint for disaggregated only
if config.name == "disaggregated": if config.name == "disaggregated":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment