Unverified commit 4960dbb3, authored by William Arnold and committed by GitHub
Browse files

feat: enable returning routed experts info up through sglang (#6137)


Signed-off-by: William Arnold <7565007+Aphoh@users.noreply.github.com>
parent a1766e4a
...@@ -6,6 +6,7 @@ import logging ...@@ -6,6 +6,7 @@ import logging
import time import time
from typing import Any, AsyncGenerator, Dict, Optional from typing import Any, AsyncGenerator, Dict, Optional
import pybase64
import sglang as sgl import sglang as sgl
from dynamo._core import Context from dynamo._core import Context
...@@ -107,6 +108,9 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -107,6 +108,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
trace_id = context.trace_id trace_id = context.trace_id
sampling_params = self._build_sampling_params(request) sampling_params = self._build_sampling_params(request)
input_param = self._get_input_param(request) input_param = self._get_input_param(request)
return_routed_experts = getattr(
self.config.server_args, "enable_return_routed_experts", False
)
priority = (request.get("routing") or {}).get("priority") priority = (request.get("routing") or {}).get("priority")
if self.serving_mode == DisaggregationMode.DECODE: if self.serving_mode == DisaggregationMode.DECODE:
...@@ -137,6 +141,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -137,6 +141,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
**input_param, **input_param,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
return_routed_experts=return_routed_experts,
bootstrap_host=bootstrap_info["bootstrap_host"], bootstrap_host=bootstrap_info["bootstrap_host"],
bootstrap_port=bootstrap_info["bootstrap_port"], bootstrap_port=bootstrap_info["bootstrap_port"],
bootstrap_room=bootstrap_info["bootstrap_room"], bootstrap_room=bootstrap_info["bootstrap_room"],
...@@ -179,6 +184,7 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -179,6 +184,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
image_data=image_data, image_data=image_data,
sampling_params=sampling_params, sampling_params=sampling_params,
stream=True, stream=True,
return_routed_experts=return_routed_experts,
external_trace_header=trace_header, external_trace_header=trace_header,
rid=trace_id, rid=trace_id,
data_parallel_rank=dp_rank, data_parallel_rank=dp_rank,
...@@ -242,6 +248,14 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -242,6 +248,14 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# Pass through disjoint token segments directly # Pass through disjoint token segments directly
out["token_ids"] = output_ids out["token_ids"] = output_ids
routed_experts = res["meta_info"].get("routed_experts")
if routed_experts is not None:
# Base64-encode tensor bytes to match sglang's output format.
routed_experts = pybase64.b64encode(
routed_experts.numpy().tobytes()
).decode("utf-8")
# Internal transport field consumed by frontend nvext mapping.
out["disaggregated_params"] = {"routed_experts": routed_experts}
if finish_reason: if finish_reason:
input_tokens = res["meta_info"]["prompt_tokens"] input_tokens = res["meta_info"]["prompt_tokens"]
completion_tokens = res["meta_info"]["completion_tokens"] completion_tokens = res["meta_info"]["completion_tokens"]
...@@ -316,6 +330,13 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -316,6 +330,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"model": self.config.server_args.served_model_name, "model": self.config.server_args.served_model_name,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
} }
routed_experts = res["meta_info"].get("routed_experts")
if routed_experts is not None:
# Base64-encode tensor bytes to match sglang's output format.
routed_experts = pybase64.b64encode(
routed_experts.numpy().tobytes()
).decode("utf-8")
response["nvext"] = {"routed_experts": routed_experts}
if not context.is_stopped(): if not context.is_stopped():
yield response yield response
count = next_count count = next_count
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Test script for routed expert info return.
Starts a Dynamo frontend + SGLang backend with --enable-return-routed-experts
and verifies that expert routing info appears in the response nvext.
Usage:
python test_sglang_expert_info.py
Requires etcd and nats running (see deploy/docker-compose.yml).
"""
import json
import os
import signal
import subprocess
import sys
import time
import numpy as np
import pybase64
import requests
# Test configuration.
# MODEL_PATH, FRONTEND_PORT and DYN_SYSTEM_PORT can be overridden through the
# environment; the remaining values are fixed.
HOST = "127.0.0.1"
LOG_DIR = "/tmp/sglang_expert_info_test"
MODEL = os.getenv("MODEL_PATH", os.path.expanduser("~/proj/models/dsv2-lite-fp8"))
FRONTEND_PORT = int(os.getenv("FRONTEND_PORT", "30080"))
SYSTEM_PORT = int(os.getenv("DYN_SYSTEM_PORT", "9092"))
# Base URLs polled for readiness and used by the test requests.
FRONTEND_URL = "http://{}:{}".format(HOST, FRONTEND_PORT)
SYSTEM_URL = "http://{}:{}".format(HOST, SYSTEM_PORT)
def start_frontend():
    """Start the Dynamo frontend and wait until its health endpoint answers.

    The frontend's stdout/stderr are redirected to ``{LOG_DIR}/frontend.log``.

    Returns:
        The ``subprocess.Popen`` handle of the running frontend.

    Exits the interpreter with status 1 (via ``sys.exit``) if the process
    terminates early or does not report healthy within 30 seconds.
    """
    print("\nStarting Dynamo frontend...")
    os.makedirs(LOG_DIR, exist_ok=True)
    cmd = [sys.executable, "-m", "dynamo.frontend", "--http-port", str(FRONTEND_PORT)]
    print(f" Command: {' '.join(cmd)}")
    print(f" Logs: {LOG_DIR}/frontend.log")
    # The child gets its own copy of the log descriptor, so the parent's
    # handle can be closed as soon as the process has been spawned.
    with open(f"{LOG_DIR}/frontend.log", "w") as log:
        process = subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)
    max_wait = 30
    # A monotonic clock keeps the timeout immune to wall-clock adjustments.
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        try:
            resp = requests.get(f"{FRONTEND_URL}/health", timeout=1)
            if resp.status_code == 200:
                print(" Frontend is ready!")
                return process
        except requests.exceptions.RequestException:
            # Not listening yet; keep polling until the deadline.
            pass
        if process.poll() is not None:
            print(" Frontend process died!")
            sys.exit(1)
        time.sleep(1)
    print(" Frontend failed to start in time!")
    process.kill()
    # Reap the killed child so it does not linger as a zombie.
    process.wait()
    sys.exit(1)
def start_sglang_backend():
    """Start the SGLang backend and wait until its health endpoint answers.

    The backend is launched with ``--enable-return-routed-experts`` and with
    ``DYN_SYSTEM_PORT`` set in its environment; its stdout/stderr are
    redirected to ``{LOG_DIR}/backend.log``.

    Returns:
        The ``subprocess.Popen`` handle of the running backend.

    Exits the interpreter with status 1 (via ``sys.exit``) if the process
    terminates early or does not report healthy within 300 seconds.
    """
    print("\nStarting SGLang backend...")
    # Create the log directory here as well so this function does not depend
    # on start_frontend() having been called first.
    os.makedirs(LOG_DIR, exist_ok=True)
    env = os.environ.copy()
    env["DYN_SYSTEM_PORT"] = str(SYSTEM_PORT)
    cmd = [
        sys.executable,
        "-m",
        "dynamo.sglang",
        "--model-path",
        MODEL,
        "--tp",
        "1",
        "--mem-fraction-static",
        "0.8",
        "--enable-return-routed-experts",
    ]
    print(f" Command: {' '.join(cmd)}")
    print(f" Logs: {LOG_DIR}/backend.log")
    # The child gets its own copy of the log descriptor, so the parent's
    # handle can be closed as soon as the process has been spawned.
    with open(f"{LOG_DIR}/backend.log", "w") as log:
        process = subprocess.Popen(cmd, env=env, stdout=log, stderr=subprocess.STDOUT)
    max_wait = 300
    # A monotonic clock keeps the timeout immune to wall-clock adjustments.
    deadline = time.monotonic() + max_wait
    while time.monotonic() < deadline:
        try:
            resp = requests.get(f"{SYSTEM_URL}/health", timeout=1)
            if resp.status_code == 200:
                print(" Backend is ready!")
                return process
        except requests.exceptions.RequestException:
            # Not listening yet; keep polling until the deadline.
            pass
        if process.poll() is not None:
            print(" Backend process died! Check logs:")
            print(f" tail {LOG_DIR}/backend.log")
            sys.exit(1)
        time.sleep(2)
    print(" Backend failed to start in time!")
    process.kill()
    # Reap the killed child so it does not linger as a zombie.
    process.wait()
    sys.exit(1)
def validate_routed_experts(routed_experts):
    """Check that routed_experts is a base64-encoded string of int32 expert IDs.

    Args:
        routed_experts: The ``routed_experts`` value taken from a response's
            ``nvext`` object.

    Raises:
        AssertionError: If the value is not a string, if the decoded byte
            count is not a whole number of int32 items, or if it decodes to
            an empty array.
        binascii.Error: If the string is not well-formed base64.
    """
    assert isinstance(
        routed_experts, str
    ), f"Expected base64 string, got {type(routed_experts)}"
    # validate=True rejects characters outside the base64 alphabet instead of
    # silently discarding them, so a corrupted payload cannot slip through.
    raw = pybase64.b64decode(routed_experts.encode("utf-8"), validate=True)
    item_size = np.dtype(np.int32).itemsize
    assert (
        len(raw) % item_size == 0
    ), f"Decoded payload is {len(raw)} bytes, not a multiple of {item_size}"
    decoded = np.frombuffer(raw, dtype=np.int32)
    assert len(decoded) > 0, "routed_experts decoded to empty array"
def test_completions_non_streaming():
    """Non-streaming completions should return routed_experts in nvext.

    Posts a short prompt to ``/v1/completions`` with ``stream`` disabled and
    asserts that the response carries a valid ``nvext.routed_experts`` payload.
    """
    print("\n--- test_completions_non_streaming ---")
    resp = requests.post(
        f"{FRONTEND_URL}/v1/completions",
        json={
            "model": MODEL,
            "prompt": "Hello",
            "max_tokens": 5,
            "temperature": 0.0,
            "stream": False,
        },
        timeout=30,
    )
    print(f" Status: {resp.status_code}")
    # Check the status before parsing so an error response with a non-JSON
    # body fails on this assertion rather than inside resp.json().
    assert resp.status_code == 200, f"Unexpected response body: {resp.text[:200]}"
    data = resp.json()
    print(f" Response keys: {list(data.keys())}")
    assert "choices" in data
    assert len(data["choices"]) > 0
    # "nvext" may be absent or null; treat both as an empty mapping.
    nvext = data.get("nvext") or {}
    assert (
        "routed_experts" in nvext
    ), f"Expected routed_experts in nvext, got keys: {list(nvext.keys())}"
    validate_routed_experts(nvext["routed_experts"])
    # The value is a base64 string, so its length is a character count.
    print(f" routed_experts payload: {len(nvext['routed_experts'])} base64 chars")
    print(" PASSED")
def test_completions_streaming():
    """Streaming completions should return routed_experts in some chunk's nvext.

    Posts a short prompt to ``/v1/completions`` with ``stream`` enabled, reads
    the ``data: `` lines until ``[DONE]``, and asserts that at least one chunk
    carries a valid ``nvext.routed_experts`` payload.
    """
    print("\n--- test_completions_streaming ---")
    chunks = []
    found_routed_experts = False
    # The context manager releases the streamed connection even when the loop
    # stops early at "[DONE]" or an assertion fails.
    with requests.post(
        f"{FRONTEND_URL}/v1/completions",
        json={
            "model": MODEL,
            "prompt": "Hello",
            "max_tokens": 5,
            "temperature": 0.0,
            "stream": True,
        },
        timeout=30,
        stream=True,
    ) as resp:
        print(f" Status: {resp.status_code}")
        assert resp.status_code == 200
        for line in resp.iter_lines():
            line = line.decode("utf-8").strip()
            if not line or not line.startswith("data: "):
                continue
            payload = line[len("data: ") :]
            if payload == "[DONE]":
                break
            chunk = json.loads(payload)
            chunks.append(chunk)
            # "nvext" may be absent or null; treat both as an empty mapping.
            nvext = chunk.get("nvext") or {}
            if "routed_experts" in nvext:
                found_routed_experts = True
                validate_routed_experts(nvext["routed_experts"])
                # The value is a base64 string, so its length is a character count.
                print(
                    f" routed_experts payload: {len(nvext['routed_experts'])} base64 chars"
                )
    print(f" Total chunks: {len(chunks)}")
    assert len(chunks) > 0, "Expected at least one chunk"
    assert found_routed_experts, "Expected routed_experts in at least one nvext chunk."
    print(" PASSED")
def test_chat_completions_streaming():
    """Streaming chat completions should return routed_experts in some chunk's nvext.

    Posts a short message to ``/v1/chat/completions`` with ``stream`` enabled,
    reads the ``data: `` lines until ``[DONE]``, and asserts that at least one
    chunk carries a valid ``nvext.routed_experts`` payload.
    """
    print("\n--- test_chat_completions_streaming ---")
    chunks = []
    found_routed_experts = False
    # The context manager releases the streamed connection even when the loop
    # stops early at "[DONE]" or an assertion fails.
    with requests.post(
        f"{FRONTEND_URL}/v1/chat/completions",
        json={
            "model": MODEL,
            "messages": [{"role": "user", "content": "Hi"}],
            "max_tokens": 5,
            "temperature": 0.0,
            "stream": True,
        },
        timeout=30,
        stream=True,
    ) as resp:
        print(f" Status: {resp.status_code}")
        assert resp.status_code == 200
        for line in resp.iter_lines():
            line = line.decode("utf-8").strip()
            if not line or not line.startswith("data: "):
                continue
            payload = line[len("data: ") :]
            if payload == "[DONE]":
                break
            chunk = json.loads(payload)
            chunks.append(chunk)
            # "nvext" may be absent or null; treat both as an empty mapping.
            nvext = chunk.get("nvext") or {}
            if "routed_experts" in nvext:
                found_routed_experts = True
                validate_routed_experts(nvext["routed_experts"])
                # The value is a base64 string, so its length is a character count.
                print(
                    f" routed_experts payload: {len(nvext['routed_experts'])} base64 chars"
                )
    print(f" Total chunks: {len(chunks)}")
    assert len(chunks) > 0, "Expected at least one chunk"
    assert found_routed_experts, "Expected routed_experts in at least one nvext chunk."
    print(" PASSED")
def main():
    """Launch frontend and backend, run the three tests, then shut both down.

    Any exception raised while starting the services or running the tests is
    printed with its traceback and turns into exit status 1. A
    KeyboardInterrupt is reported and falls through to the shutdown step.
    Both child processes are always sent SIGTERM on the way out, and killed
    if they have not exited after 10 seconds.
    """
    frontend_process = None
    backend_process = None
    try:
        frontend_process = start_frontend()
        backend_process = start_sglang_backend()
        time.sleep(2)
        print("\n" + "=" * 60)
        print("Running expert info tests")
        print("=" * 60)
        test_completions_non_streaming()
        test_completions_streaming()
        test_chat_completions_streaming()
        print("\n" + "=" * 60)
        print("All tests passed!")
        print("=" * 60)
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"\nTest failed: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)
    finally:
        print("\nShutting down...")
        # Stop the backend before the frontend.
        for name, proc in [
            ("backend", backend_process),
            ("frontend", frontend_process),
        ]:
            if proc:
                print(f" Stopping {name}...")
                proc.send_signal(signal.SIGTERM)
                try:
                    proc.wait(timeout=10)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    # Reap the killed child so it does not linger as a zombie.
                    proc.wait()
        print("Done")


if __name__ == "__main__":
    main()
...@@ -437,6 +437,11 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -437,6 +437,11 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
.as_ref() .as_ref()
.and_then(|params| params.get("token_ids")) .and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok()); .and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
let routed_experts = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("routed_experts"))
.cloned();
// Get timing info if this is the final response (has finish_reason) // Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() { let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
...@@ -448,12 +453,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes ...@@ -448,12 +453,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateChatCompletionStreamRes
None None
}; };
// Inject nvext if we have worker_id, token_ids, or timing // Inject nvext if we have worker_id, token_ids, timing, or routed experts.
if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() { if worker_id_info.is_some()
|| token_ids.is_some()
|| timing_info.is_some()
|| routed_experts.is_some()
{
let nvext_response = NvExtResponse { let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(), worker_id: worker_id_info.clone(),
timing: timing_info, timing: timing_info,
token_ids: token_ids.clone(), token_ids: token_ids.clone(),
routed_experts,
}; };
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
......
...@@ -331,6 +331,11 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -331,6 +331,11 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
.as_ref() .as_ref()
.and_then(|params| params.get("token_ids")) .and_then(|params| params.get("token_ids"))
.and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok()); .and_then(|v| serde_json::from_value::<Vec<u32>>(v.clone()).ok());
let routed_experts = delta
.disaggregated_params
.as_ref()
.and_then(|params| params.get("routed_experts"))
.cloned();
// Get timing info if this is the final response (has finish_reason) // Get timing info if this is the final response (has finish_reason)
let timing_info: Option<TimingInfo> = if finish_reason.is_some() { let timing_info: Option<TimingInfo> = if finish_reason.is_some() {
...@@ -342,12 +347,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for ...@@ -342,12 +347,17 @@ impl crate::protocols::openai::DeltaGeneratorExt<NvCreateCompletionResponse> for
None None
}; };
// Inject nvext if we have worker_id, token_ids, or timing // Inject nvext if we have worker_id, token_ids, timing, or routed experts.
if worker_id_info.is_some() || token_ids.is_some() || timing_info.is_some() { if worker_id_info.is_some()
|| token_ids.is_some()
|| timing_info.is_some()
|| routed_experts.is_some()
{
let nvext_response = NvExtResponse { let nvext_response = NvExtResponse {
worker_id: worker_id_info.clone(), worker_id: worker_id_info.clone(),
timing: timing_info, timing: timing_info,
token_ids: token_ids.clone(), token_ids: token_ids.clone(),
routed_experts,
}; };
if let Ok(nvext_json) = serde_json::to_value(&nvext_response) { if let Ok(nvext_json) = serde_json::to_value(&nvext_response) {
......
...@@ -87,6 +87,10 @@ pub struct NvExtResponse { ...@@ -87,6 +87,10 @@ pub struct NvExtResponse {
/// Contains the tokenized prompt for reuse in Stage 2 /// Contains the tokenized prompt for reuse in Stage 2
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub token_ids: Option<Vec<u32>>, pub token_ids: Option<Vec<u32>>,
/// Routed expert capture payload (SGLang-specific)
#[serde(skip_serializing_if = "Option::is_none")]
pub routed_experts: Option<serde_json::Value>,
} }
/// NVIDIA LLM extensions to the OpenAI API /// NVIDIA LLM extensions to the OpenAI API
...@@ -135,7 +139,8 @@ pub struct NvExt { ...@@ -135,7 +139,8 @@ pub struct NvExt {
/// Extra fields to be included in the response's nvext /// Extra fields to be included in the response's nvext
/// This is a list of field names that should be populated in the response /// This is a list of field names that should be populated in the response
/// Supported fields: "worker_id", "timing", which has a 1:1 mapping with the NvExtResponse names /// Supported fields include "worker_id", "timing", "routed_experts",
/// which map to fields in NvExtResponse.
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
#[builder(default, setter(strip_option))] #[builder(default, setter(strip_option))]
pub extra_fields: Option<Vec<String>>, pub extra_fields: Option<Vec<String>>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment