"vscode:/vscode.git/clone" did not exist on "4ebb244b2af77af81a797569080c66c6015afb70"
Unverified Commit 4960dbb3 authored by William Arnold's avatar William Arnold Committed by GitHub
Browse files

feat: enable returning routed experts info up through sglang (#6137)


Signed-off-by: default avatarWilliam Arnold <7565007+Aphoh@users.noreply.github.com>
parent a1766e4a
......@@ -6,6 +6,7 @@ import logging
import time
from typing import Any, AsyncGenerator, Dict, Optional
import pybase64
import sglang as sgl
from dynamo._core import Context
......@@ -107,6 +108,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
trace_id = context.trace_id
sampling_params = self._build_sampling_params(request)
input_param = self._get_input_param(request)
return_routed_experts = getattr(
self.config.server_args, "enable_return_routed_experts", False
)
priority = (request.get("routing") or {}).get("priority")
if self.serving_mode == DisaggregationMode.DECODE:
......@@ -137,6 +141,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
**input_param,
sampling_params=sampling_params,
stream=True,
return_routed_experts=return_routed_experts,
bootstrap_host=bootstrap_info["bootstrap_host"],
bootstrap_port=bootstrap_info["bootstrap_port"],
bootstrap_room=bootstrap_info["bootstrap_room"],
......@@ -179,6 +184,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
image_data=image_data,
sampling_params=sampling_params,
stream=True,
return_routed_experts=return_routed_experts,
external_trace_header=trace_header,
rid=trace_id,
data_parallel_rank=dp_rank,
......@@ -242,6 +248,14 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Pass through disjoint token segments directly
out["token_ids"] = output_ids
routed_experts = res["meta_info"].get("routed_experts")
if routed_experts is not None:
# Base64-encode tensor bytes to match sglang's output format.
routed_experts = pybase64.b64encode(
routed_experts.numpy().tobytes()
).decode("utf-8")
# Internal transport field consumed by frontend nvext mapping.
out["disaggregated_params"] = {"routed_experts": routed_experts}
if finish_reason:
input_tokens = res["meta_info"]["prompt_tokens"]
completion_tokens = res["meta_info"]["completion_tokens"]
......@@ -316,6 +330,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"model": self.config.server_args.served_model_name,
"object": "chat.completion.chunk",
}
routed_experts = res["meta_info"].get("routed_experts")
if routed_experts is not None:
# Base64-encode tensor bytes to match sglang's output format.
routed_experts = pybase64.b64encode(
routed_experts.numpy().tobytes()
).decode("utf-8")
response["nvext"] = {"routed_experts": routed_experts}
if not context.is_stopped():
yield response
count = next_count
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Test script for routed expert info return.
Starts a Dynamo frontend + SGLang backend with --enable-return-routed-experts
and verifies that expert routing info appears in the response nvext.
Usage:
python test_sglang_expert_info.py
Requires etcd and nats running (see deploy/docker-compose.yml).
"""
import json
import os
import signal
import subprocess
import sys
import time
import numpy as np
import pybase64
import requests
# Configuration
MODEL = os.environ.get("MODEL_PATH", os.path.expanduser("~/proj/models/dsv2-lite-fp8"))
HOST = "127.0.0.1"
FRONTEND_PORT = int(os.environ.get("FRONTEND_PORT", "30080"))
SYSTEM_PORT = int(os.environ.get("DYN_SYSTEM_PORT", "9092"))
FRONTEND_URL = f"http://{HOST}:{FRONTEND_PORT}"
SYSTEM_URL = f"http://{HOST}:{SYSTEM_PORT}"
LOG_DIR = "/tmp/sglang_expert_info_test"
def start_frontend():
"""Start the Dynamo frontend."""
print("\nStarting Dynamo frontend...")
os.makedirs(LOG_DIR, exist_ok=True)
log = open(f"{LOG_DIR}/frontend.log", "w")
cmd = [sys.executable, "-m", "dynamo.frontend", "--http-port", str(FRONTEND_PORT)]
print(f" Command: {' '.join(cmd)}")
print(f" Logs: {LOG_DIR}/frontend.log")
process = subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)
max_wait = 30
start_time = time.time()
while time.time() - start_time < max_wait:
try:
resp = requests.get(f"{FRONTEND_URL}/health", timeout=1)
if resp.status_code == 200:
print(" Frontend is ready!")
return process
except requests.exceptions.RequestException:
pass
if process.poll() is not None:
print(" Frontend process died!")
sys.exit(1)
time.sleep(1)
print(" Frontend failed to start in time!")
process.kill()
sys.exit(1)
def start_sglang_backend():
"""Start the SGLang backend."""
print("\nStarting SGLang backend...")
log = open(f"{LOG_DIR}/backend.log", "w")
env = os.environ.copy()
env["DYN_SYSTEM_PORT"] = str(SYSTEM_PORT)
cmd = [
sys.executable,
"-m",
"dynamo.sglang",
"--model-path",
MODEL,
"--tp",
"1",
"--mem-fraction-static",
"0.8",
"--enable-return-routed-experts",
]
print(f" Command: {' '.join(cmd)}")
print(f" Logs: {LOG_DIR}/backend.log")
process = subprocess.Popen(cmd, env=env, stdout=log, stderr=subprocess.STDOUT)
max_wait = 300
start_time = time.time()
while time.time() - start_time < max_wait:
try:
resp = requests.get(f"{SYSTEM_URL}/health", timeout=1)
if resp.status_code == 200:
print(" Backend is ready!")
return process
except requests.exceptions.RequestException:
pass
if process.poll() is not None:
print(" Backend process died! Check logs:")
print(f" tail {LOG_DIR}/backend.log")
sys.exit(1)
time.sleep(2)
print(" Backend failed to start in time!")
process.kill()
sys.exit(1)
def validate_routed_experts(routed_experts):
"""Check that routed_experts is a base64-encoded string of int32 expert IDs."""
assert isinstance(
routed_experts, str
), f"Expected base64 string, got {type(routed_experts)}"
decoded = np.frombuffer(
pybase64.b64decode(routed_experts.encode("utf-8")), dtype=np.int32
)
assert len(decoded) > 0, "routed_experts decoded to empty array"
def test_completions_non_streaming():
"""Non-streaming completions should return routed_experts in nvext."""
print("\n--- test_completions_non_streaming ---")
resp = requests.post(
f"{FRONTEND_URL}/v1/completions",
json={
"model": MODEL,
"prompt": "Hello",
"max_tokens": 5,
"temperature": 0.0,
"stream": False,
},
timeout=30,
)
print(f" Status: {resp.status_code}")
data = resp.json()
print(f" Response keys: {list(data.keys())}")
assert resp.status_code == 200
assert "choices" in data
assert len(data["choices"]) > 0
nvext = data.get("nvext", {})
assert (
"routed_experts" in nvext
), f"Expected routed_experts in nvext, got keys: {list(nvext.keys())}"
validate_routed_experts(nvext["routed_experts"])
print(f" routed_experts shape: {len(nvext['routed_experts'])} layers")
print(" PASSED")
def test_completions_streaming():
"""Streaming completions should return routed_experts in final chunk's nvext."""
print("\n--- test_completions_streaming ---")
resp = requests.post(
f"{FRONTEND_URL}/v1/completions",
json={
"model": MODEL,
"prompt": "Hello",
"max_tokens": 5,
"temperature": 0.0,
"stream": True,
},
timeout=30,
stream=True,
)
print(f" Status: {resp.status_code}")
assert resp.status_code == 200
chunks = []
found_routed_experts = False
for line in resp.iter_lines():
line = line.decode("utf-8").strip()
if not line or not line.startswith("data: "):
continue
payload = line[len("data: ") :]
if payload == "[DONE]":
break
chunk = json.loads(payload)
chunks.append(chunk)
nvext = chunk.get("nvext", {})
if "routed_experts" in nvext:
found_routed_experts = True
validate_routed_experts(nvext["routed_experts"])
print(f" routed_experts shape: {len(nvext['routed_experts'])} layers")
print(f" Total chunks: {len(chunks)}")
assert len(chunks) > 0, "Expected at least one chunk"
assert found_routed_experts, "Expected routed_experts in at least one nvext chunk."
print(" PASSED")
def test_chat_completions_streaming():
"""Streaming chat completions should return routed_experts in final chunk's nvext."""
print("\n--- test_chat_completions_streaming ---")
resp = requests.post(
f"{FRONTEND_URL}/v1/chat/completions",
json={
"model": MODEL,
"messages": [{"role": "user", "content": "Hi"}],
"max_tokens": 5,
"temperature": 0.0,
"stream": True,
},
timeout=30,
stream=True,
)
print(f" Status: {resp.status_code}")
assert resp.status_code == 200
chunks = []
found_routed_experts = False
for line in resp.iter_lines():
line = line.decode("utf-8").strip()
if not line or not line.startswith("data: "):
continue
payload = line[len("data: ") :]
if payload == "[DONE]":
break
chunk = json.loads(payload)
chunks.append(chunk)
nvext = chunk.get("nvext", {})
if "routed_experts" in nvext:
found_routed_experts = True
validate_routed_experts(nvext["routed_experts"])
print(f" routed_experts shape: {len(nvext['routed_experts'])} layers")
print(f" Total chunks: {len(chunks)}")
assert len(chunks) > 0, "Expected at least one chunk"
assert found_routed_experts, "Expected routed_experts in at least one nvext chunk."
print(" PASSED")
def main():
frontend_process = None
backend_process = None
try:
frontend_process = start_frontend()
backend_process = start_sglang_backend()
time.sleep(2)
print("\n" + "=" * 60)
print("Running expert info tests")
print("=" * 60)
test_completions_non_streaming()
test_completions_streaming()
test_chat_completions_streaming()
print("\n" + "=" * 60)
print("All tests passed!")
print("=" * 60)
except KeyboardInterrupt:
print("\nInterrupted by user")
except Exception as e:
print(f"\nTest failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)
finally:
print("\nShutting down...")
for name, proc in [
("backend", backend_process),
("frontend", frontend_process),
]:
if proc:
print(f" Stopping {name}...")
proc.send_signal(signal.SIGTERM)
try:
proc.wait(timeout=10)
except subprocess.TimeoutExpired:
proc.kill()
print("Done")
if __name__ == "__main__":
main()
......@@ -437,6 +437,11 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
let routed_experts = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("routed_experts"))
.cloned();
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
......@@ -448,12 +453,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None
};
// Inject nvext if we have worker_id, token_ids, or timing
if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
// Inject nvext if we have worker_id, token_ids, timing, or routed experts.
if worker_id_info.is_some()
|| token_ids.is_some()
|| timing_info.is_some()
|| routed_experts.is_some()
{
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
routed_experts,
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
......
......@@ -331,6 +331,11 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
.as_ref()
.and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
let routed_experts = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("routed_experts"))
.cloned();
// Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
......@@ -342,12 +347,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
None
};
// Inject nvext if we have worker_id, token_ids, or timing
if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() {
// Inject nvext if we have worker_id, token_ids, timing, or routed experts.
if worker_id_info.is_some()
|| token_ids.is_some()
|| timing_info.is_some()
|| routed_experts.is_some()
{
let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(),
timing: timing_info,
token_ids: token_ids.clone(),
routed_experts,
};
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
......
......@@ -87,6 +87,10 @@ pub struct NvExtResponse {
/// Contains the tokenized prompt for reuse in Stage 2
#[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<u32>>,
/// Routed expert capture payload (SGLang-specific)
#[serde(skip_serializing_if = "Option::is_none")]
pub routed_experts: Option<serde_json::Value>,
}
/// NVIDIA LLM extensions to the OpenAI API
......@@ -135,7 +139,8 @@ pub struct NvExt {
/// Extra fields to be included in the response's nvext
/// This is a list of field names that should be populated in the response
/// Supported fields: "worker_id", "timing", which has a 1:1 mapping with the NvExtResponse names
/// Supported fields include "worker_id", "timing", "routed_experts",
/// which map to fields in NvExtResponse.
#[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment