Unverified commit 9e9240db, authored by ptarasiewiczNV and committed by GitHub
Browse files

fix(omni): fix disaggregated Qwen2.5-Omni serving pipeline (#8301)


Signed-off-by: Piotr Tarasiewicz <ptarasiewicz@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 9293abb8
...@@ -13,9 +13,11 @@ from typing import Any, AsyncGenerator ...@@ -13,9 +13,11 @@ from typing import Any, AsyncGenerator
import yaml import yaml
from vllm_omni.distributed.omni_connectors import initialize_orchestrator_connectors from vllm_omni.distributed.omni_connectors import initialize_orchestrator_connectors
from vllm_omni.engine.orchestrator import build_engine_core_request_from_tokens
from vllm_omni.entrypoints.async_omni import AsyncOmni from vllm_omni.entrypoints.async_omni import AsyncOmni
from vllm_omni.entrypoints.stage_utils import serialize_obj, shm_write_bytes from vllm_omni.entrypoints.stage_utils import serialize_obj, shm_write_bytes
from vllm_omni.entrypoints.utils import load_stage_configs_from_yaml from vllm_omni.entrypoints.utils import load_stage_configs_from_yaml
from vllm_omni.inputs.data import OmniTokensPrompt
from dynamo import prometheus_names from dynamo import prometheus_names
from dynamo.llm import ModelType from dynamo.llm import ModelType
...@@ -64,7 +66,6 @@ class OmniStageWorker: ...@@ -64,7 +66,6 @@ class OmniStageWorker:
self.engine = engine self.engine = engine
self.stage_id = stage_id self.stage_id = stage_id
self.connectors = connectors # {(from_stage, to_stage): vllm_omni connector} self.connectors = connectors # {(from_stage, to_stage): vllm_omni connector}
self.final_output: bool = getattr(stage_config, "final_output", False)
self._output_modalities = output_modalities or [] self._output_modalities = output_modalities or []
self._default_video_fps = default_video_fps self._default_video_fps = default_video_fps
self.stage_config = stage_config self.stage_config = stage_config
...@@ -116,12 +117,28 @@ class OmniStageWorker: ...@@ -116,12 +117,28 @@ class OmniStageWorker:
if isinstance(prompt, list) and len(prompt) == 1: if isinstance(prompt, list) and len(prompt) == 1:
prompt = prompt[0] prompt = prompt[0]
else: else:
# No processor: use the most recent fetched stage output directly. # No processor: check if the upstream output has the
prompt = stage_list[-1].engine_outputs[0] # structure needed to build an OmniEngineCoreRequest
# (e.g. code2wav receiving token_ids from talker).
# Otherwise fall back to passing the raw data directly.
upstream = stage_list[-1].engine_outputs[0]
if hasattr(upstream, "outputs") and upstream.outputs:
try:
prompt = self._build_engine_core_request_from_upstream(
stage_list, request_id, sampling_params_list_override
)
except RuntimeError as e:
yield {"error": str(e), "finished": True}
return
else:
prompt = upstream
elif req.request_id is not None: elif req.request_id is not None:
# Stage 0 via router: raw request forwarded with request_id — parse it. # Stage 0 via router: raw request forwarded with request_id — parse it.
parsed = parse_omni_request( parsed = await parse_omni_request(
request, self._output_modalities, self._default_video_fps request,
self._output_modalities,
self._default_video_fps,
tokenizer_getter=self.engine.get_tokenizer,
) )
prompt = parsed["engine_inputs"] prompt = parsed["engine_inputs"]
original_prompt = parsed["original_prompt"] original_prompt = parsed["original_prompt"]
...@@ -142,7 +159,7 @@ class OmniStageWorker: ...@@ -142,7 +159,7 @@ class OmniStageWorker:
try: try:
async for chunk in self.engine.generate( async for chunk in self.engine.generate(
prompt, request_id, sampling_params_list=sp prompt, request_id=request_id, sampling_params_list=sp
): ):
last_result = chunk last_result = chunk
except Exception as e: except Exception as e:
...@@ -157,45 +174,43 @@ class OmniStageWorker: ...@@ -157,45 +174,43 @@ class OmniStageWorker:
return return
# --- Write output --- # --- Write output ---
if not self.final_output: # Check for a downstream connector first, regardless of final_output.
from_s, to_s = _connector_key(self.stage_id, self.stage_id + 1) # In vllm-omni's native mode, multiple stages can set final_output=True
connector = self.connectors.get((from_s, to_s)) # (meaning "produces user-visible output"). In Dynamo's disaggregated
if connector is not None: # mode the actual pipeline topology — connector edges from the YAML —
try: # determines whether output should go to a connector or to SHM.
ok, _, metadata = connector.put( # type: ignore[arg-type] from_s, to_s = _connector_key(self.stage_id, self.stage_id + 1)
from_s, to_s, request_id, last_result connector = self.connectors.get((from_s, to_s))
) if connector is not None:
except Exception as e: try:
logger.error( ok, _, metadata = connector.put( # type: ignore[arg-type]
"Stage %d: connector.put() raised %s: %s", from_s, to_s, request_id, last_result
self.stage_id, )
type(e).__name__, except Exception as e:
e, logger.error(
exc_info=True, "Stage %d: connector.put() raised %s: %s",
) self.stage_id,
yield {"error": f"connector.put() raised: {e}", "finished": True} type(e).__name__,
return e,
if not ok: exc_info=True,
yield {"error": "connector.put() failed", "finished": True} )
return yield {"error": f"connector.put() raised: {e}", "finished": True}
out: dict = {
"original_prompt": original_prompt,
"stage_connector_refs": {
**{str(k): v for k, v in stage_connector_refs.items()},
str(self.stage_id): metadata,
},
"finished": True,
}
if sampling_params_list_override is not None:
out["sampling_params_list"] = sampling_params_list_override
yield out
return return
logger.warning( if not ok:
"Stage %d: no connector found for edge (%s→%s), falling through to SHM", yield {"error": "connector.put() failed", "finished": True}
self.stage_id, return
from_s, out: dict = {
to_s, "original_prompt": original_prompt,
) "stage_connector_refs": {
**{str(k): v for k, v in stage_connector_refs.items()},
str(self.stage_id): metadata,
},
"finished": True,
}
if sampling_params_list_override is not None:
out["sampling_params_list"] = sampling_params_list_override
yield out
return
# Final stage → router: write output to shared memory and return the SHM handle. # Final stage → router: write output to shared memory and return the SHM handle.
# The router reads it back via shm_deserialize() to format the response. # The router reads it back via shm_deserialize() to format the response.
...@@ -207,6 +222,60 @@ class OmniStageWorker: ...@@ -207,6 +222,60 @@ class OmniStageWorker:
shm_meta = shm_write_bytes(serialize_obj(last_result), name=request_id) shm_meta = shm_write_bytes(serialize_obj(last_result), name=request_id)
yield {"shm_meta": shm_meta, "finished": True} yield {"shm_meta": shm_meta, "finished": True}
def _build_engine_core_request_from_upstream(
    self,
    stage_list: list[_Proxy],
    request_id: str,
    sampling_params_list_override: dict | None,
):
    """Build an OmniEngineCoreRequest from the upstream stage output.

    Used for stages without a custom processor (e.g. code2wav). Mirrors
    what the native orchestrator does via ``build_engine_core_request_from_tokens``
    and ``_forward_to_next_stage``. Building an ``EngineCoreRequest``
    bypasses ``InputProcessor.process_inputs()`` which would fail for
    non-autoregressive stages (``worker_type: generation``) with
    "This model does not support generation".

    Args:
        stage_list: Fetched upstream stage outputs. Only the last entry is
            read, via ``engine_outputs[0].outputs[0].token_ids``.
        request_id: Identifier passed to the request builder.
        sampling_params_list_override: Optional override forwarded, together
            with ``self.stage_config``, to ``_build_sampling_params``.

    Returns:
        The request produced by ``build_engine_core_request_from_tokens``,
        already registered with ``self.engine.engine.output_processors[0]``.

    Raises:
        RuntimeError: If the upstream output does not have the expected
            structure (missing/empty containers, missing attributes, or a
            ``None``/non-iterable value where a sequence is required).
    """
    try:
        # engine_outputs[0]: first (and only) RequestOutput — Dynamo
        # processes one request at a time per stage.
        # outputs[0]: first CompletionOutput (n=1 sampling).
        # Matches native orchestrator's process_engine_inputs pattern.
        upstream = stage_list[-1].engine_outputs[0]
        # Copy inside the guarded section: a None or non-iterable
        # token_ids (or a None container being subscripted) raises
        # TypeError, which must surface as the documented RuntimeError so
        # the caller's ``except RuntimeError`` can report it.
        token_ids = list(upstream.outputs[0].token_ids)
    except (IndexError, AttributeError, TypeError) as e:
        raise RuntimeError(
            f"Stage {self.stage_id}: cannot extract token_ids from "
            f"upstream output: {e}"
        ) from e

    tokens_prompt = OmniTokensPrompt(prompt_token_ids=token_ids)
    sp_list = _build_sampling_params(
        self.stage_config, sampling_params_list_override
    )
    # Only the first sampling-params entry is used; None when the list is empty.
    params = sp_list[0] if sp_list else None
    prompt = build_engine_core_request_from_tokens(
        request_id=request_id,
        prompt=tokens_prompt,
        params=params,
    )

    # Pre-built EngineCoreRequests skip the output processor registration
    # in _build_add_request_message (the isinstance(prompt, EngineCoreRequest)
    # branch bypasses that block). Register manually so that the engine's
    # output processor can match the response back to this request.
    prompt.external_req_id = prompt.request_id
    self.engine.engine.output_processors[0].add_request(
        request=prompt,
        prompt=None,
        parent_req=None,
        request_index=0,
        queue=None,
    )
    return prompt
def _fetch_stage_inputs( def _fetch_stage_inputs(
self, stage_connector_refs: dict[int, Any], request_id: str self, stage_connector_refs: dict[int, Any], request_id: str
) -> list[_Proxy]: ) -> list[_Proxy]:
......
...@@ -15,10 +15,12 @@ from pydantic import BaseModel, ConfigDict, model_validator ...@@ -15,10 +15,12 @@ from pydantic import BaseModel, ConfigDict, model_validator
class StageEngine(Protocol): class StageEngine(Protocol):
"""Any engine that can generate outputs for a single pipeline stage. """Any engine that can generate outputs for a single pipeline stage.
Matches AsyncOmni.generate() signature — the only vllm_omni engine Matches AsyncOmni — the only vllm_omni engine with a consistent async
with a consistent async generator interface for both LLM and diffusion. generator interface for both LLM and diffusion.
""" """
engine: Any # AsyncOmniEngine — exposes output_processors for registration
def generate( def generate(
self, self,
prompt: Any, prompt: Any,
...@@ -28,6 +30,10 @@ class StageEngine(Protocol): ...@@ -28,6 +30,10 @@ class StageEngine(Protocol):
) -> AsyncGenerator[Any, None]: ) -> AsyncGenerator[Any, None]:
... ...
def get_tokenizer(self) -> Any:
"""Return the tokenizer (may be async — callers should await)."""
...
class StageOutput(BaseModel): class StageOutput(BaseModel):
"""Validated output dict from a stage worker. """Validated output dict from a stage worker.
......
...@@ -37,11 +37,19 @@ def build_original_prompt(request: dict, nvext: dict, height: int, width: int) - ...@@ -37,11 +37,19 @@ def build_original_prompt(request: dict, nvext: dict, height: int, width: int) -
return prompt return prompt
def parse_omni_request( async def parse_omni_request(
request: dict, output_modalities: list, default_video_fps: int = 16 request: dict,
output_modalities: list,
default_video_fps: int = 16,
tokenizer_getter=None,
) -> dict: ) -> dict:
"""Parse a raw frontend request into engine_inputs, original_prompt, sampling_params_list. """Parse a raw frontend request into engine_inputs, original_prompt, sampling_params_list.
Args:
tokenizer_getter: async callable returning a tokenizer (e.g. engine.get_tokenizer).
When provided, chat requests are formatted through the model's chat template
so the thinker receives the same prompt as native ``vllm serve --omni``.
Returns: Returns:
engine_inputs: text prompt (str or OmniTextPrompt) for the stage 0 engine engine_inputs: text prompt (str or OmniTextPrompt) for the stage 0 engine
original_prompt: rich prompt dict with geometry/params for processor functions original_prompt: rich prompt dict with geometry/params for processor functions
...@@ -74,6 +82,22 @@ def parse_omni_request( ...@@ -74,6 +82,22 @@ def parse_omni_request(
(m.get("content", "") for m in reversed(messages) if m.get("role") == "user"), (m.get("content", "") for m in reversed(messages) if m.get("role") == "user"),
request.get("prompt", ""), request.get("prompt", ""),
) )
# Apply chat template when a tokenizer is available. The native
# OpenAI API server applies the template before the engine sees it;
# without it the thinker receives bare text instead of the full
# chat-formatted prompt.
if messages and tokenizer_getter is not None:
try:
tokenizer = await tokenizer_getter()
text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
except Exception:
logging.getLogger(__name__).debug(
"Chat template not available, using raw text"
)
return { return {
"engine_inputs": text, "engine_inputs": text,
"original_prompt": {"prompt": text}, "original_prompt": {"prompt": text},
......
...@@ -24,12 +24,17 @@ pytestmark = [ ...@@ -24,12 +24,17 @@ pytestmark = [
class _MockEngine: class _MockEngine:
engine = None
def generate(self, prompt, request_id="", *, sampling_params_list=None): def generate(self, prompt, request_id="", *, sampling_params_list=None):
async def _gen(): async def _gen():
yield {} yield {}
return _gen() return _gen()
def get_tokenizer(self):
return None
def test_stage_engine_protocol_satisfied(): def test_stage_engine_protocol_satisfied():
assert isinstance(_MockEngine(), StageEngine) assert isinstance(_MockEngine(), StageEngine)
......
...@@ -264,36 +264,39 @@ class TestParseOmniRequest: ...@@ -264,36 +264,39 @@ class TestParseOmniRequest:
"""parse_omni_request: original_prompt only has prompt/negative_prompt, """parse_omni_request: original_prompt only has prompt/negative_prompt,
geometry goes into sampling_params_list dict.""" geometry goes into sampling_params_list dict."""
def test_image_sampling_params_has_geometry(self): @pytest.mark.asyncio
async def test_image_sampling_params_has_geometry(self):
request = { request = {
"prompt": "a sunset", "prompt": "a sunset",
"size": "512x512", "size": "512x512",
"output_modalities": ["image"], "output_modalities": ["image"],
} }
result = parse_omni_request(request, ["image"]) result = await parse_omni_request(request, ["image"])
sp = result["sampling_params_list"] sp = result["sampling_params_list"]
assert sp["height"] == 512 assert sp["height"] == 512
assert sp["width"] == 512 assert sp["width"] == 512
def test_image_original_prompt_no_geometry(self): @pytest.mark.asyncio
async def test_image_original_prompt_no_geometry(self):
request = { request = {
"prompt": "a sunset", "prompt": "a sunset",
"size": "512x512", "size": "512x512",
"output_modalities": ["image"], "output_modalities": ["image"],
} }
result = parse_omni_request(request, ["image"]) result = await parse_omni_request(request, ["image"])
op = result["original_prompt"] op = result["original_prompt"]
assert op["prompt"] == "a sunset" assert op["prompt"] == "a sunset"
assert "height" not in op assert "height" not in op
assert "width" not in op assert "width" not in op
def test_nvext_params_go_into_sampling_params_not_prompt(self): @pytest.mark.asyncio
async def test_nvext_params_go_into_sampling_params_not_prompt(self):
request = { request = {
"prompt": "x", "prompt": "x",
"size": "512x512", "size": "512x512",
"nvext": {"num_inference_steps": 30, "guidance_scale": 4.0}, "nvext": {"num_inference_steps": 30, "guidance_scale": 4.0},
} }
result = parse_omni_request(request, ["image"]) result = await parse_omni_request(request, ["image"])
sp = result["sampling_params_list"] sp = result["sampling_params_list"]
assert sp["num_inference_steps"] == 30 assert sp["num_inference_steps"] == 30
assert sp["guidance_scale"] == 4.0 assert sp["guidance_scale"] == 4.0
......
...@@ -26,7 +26,9 @@ pytestmark = [ ...@@ -26,7 +26,9 @@ pytestmark = [
class _MockEngine: class _MockEngine:
"""Satisfies StageEngine Protocol — matches AsyncOmni.generate() signature.""" """Satisfies StageEngine Protocol — matches AsyncOmni interface."""
engine = None # satisfies StageEngine.engine
def __init__(self, output=None): def __init__(self, output=None):
self.received_prompt = None self.received_prompt = None
...@@ -44,6 +46,9 @@ class _MockEngine: ...@@ -44,6 +46,9 @@ class _MockEngine:
return _gen() return _gen()
async def get_tokenizer(self):
return None
class _ErrorEngine: class _ErrorEngine:
def generate(self, prompt, request_id="", *, sampling_params_list=None): def generate(self, prompt, request_id="", *, sampling_params_list=None):
...@@ -128,6 +133,50 @@ async def test_stage_connector_refs_input_path(): ...@@ -128,6 +133,50 @@ async def test_stage_connector_refs_input_path():
assert chunks[0]["original_prompt"] == {"prompt": "hello"} assert chunks[0]["original_prompt"] == {"prompt": "hello"}
@pytest.mark.asyncio
async def test_stage_connector_refs_builds_engine_core_request():
    """Stage N>0 without processor: upstream with .outputs builds OmniEngineCoreRequest."""
    stub_engine = _MockEngine()

    # Stand-in for a real RequestOutput: exposes .outputs[0].token_ids.
    completion = SimpleNamespace(token_ids=[100, 200, 300])
    upstream_output = SimpleNamespace(outputs=[completion], prompt_token_ids=[1, 2])

    # Inbound edge hands back the raw object (not an {"engine_inputs": ...} dict).
    inbound = MagicMock()
    inbound.get.return_value = upstream_output

    # Outbound edge accepts the stage result and returns a connector ref.
    outbound = MagicMock()
    outbound.put.return_value = (True, 0, {"name": "ref1"})

    sampling_defaults = {"temperature": 0.9, "max_tokens": 100}
    edges = {("0", "1"): inbound, ("1", "2"): outbound}
    worker = _make_worker(
        engine=stub_engine,
        connectors=edges,
        stage_id=1,
        stage_config=_make_stage_config(default_sampling_params=sampling_defaults),
    )

    # The worker registers the request via engine.engine.output_processors[0].
    stub_engine.engine = MagicMock()

    request = {
        "request_id": "req-ecr",
        "original_prompt": {"prompt": "hello"},
        "stage_connector_refs": {"0": {"name": "ref0"}},
    }

    # Drain the generator; the yielded chunks themselves are not inspected here.
    async for _chunk in worker.generate(request, _MockContext()):
        pass

    # The engine must have been handed a token-based request, not the raw upstream object.
    received = stub_engine.received_prompt
    assert hasattr(received, "prompt_token_ids")
    assert received.prompt_token_ids == [100, 200, 300]

    # Manual output-processor registration must have happened exactly once.
    stub_engine.engine.output_processors[0].add_request.assert_called_once()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stage_connector_refs_with_processor(): async def test_stage_connector_refs_with_processor():
"""Stage N>0 with processor: processor receives stage_list built from connector output.""" """Stage N>0 with processor: processor receives stage_list built from connector output."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment