Unverified commit 9412bdff, authored by William Arnold, committed by GitHub
Browse files

feat: generic tokenizer_manager passthrough route for RL training (#6836)


Signed-off-by: William Arnold <warnold@nvidia.com>
parent 6b75d6b0
......@@ -101,6 +101,13 @@ class DynamoSGLangArgGroup(ArgGroup):
default=False,
help="Run as video generation worker for video generation (T2V/I2V).",
)
add_negatable_bool_argument(
g,
flag_name="--enable-rl",
env_var="DYN_SGL_ENABLE_RL",
default=False,
help="Enable RL training support. Registers the call_tokenizer_manager engine route for generic tokenizer_manager passthrough.",
)
class DynamoSGLangConfig(ConfigBase):
......@@ -117,6 +124,7 @@ class DynamoSGLangConfig(ConfigBase):
disagg_config_key: Optional[str] = None
video_generation_worker: bool
enable_rl: bool
def validate(self) -> None:
if not isinstance(self.embedding_transfer_mode, EmbeddingTransferMode):
......
......@@ -5,7 +5,7 @@
from .embedding import EmbeddingWorkerHandler
# Base handlers
from .handler_base import BaseGenerativeHandler, BaseWorkerHandler
from .handler_base import BaseGenerativeHandler, BaseWorkerHandler, RLMixin
# Image diffusion handlers
from .image_diffusion import ImageDiffusionWorkerHandler
......@@ -27,6 +27,7 @@ __all__ = [
# Base handlers
"BaseGenerativeHandler",
"BaseWorkerHandler",
"RLMixin",
# LLM handlers
"DecodeWorkerHandler",
"DiffusionWorkerHandler",
......
......@@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import dataclasses
import importlib
import inspect
import json
import logging
......@@ -131,7 +133,106 @@ class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]):
pass
class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
class RLMixin:
    """Mixin providing generic tokenizer_manager passthrough for RL training.

    Requires the host class to have ``self.engine`` with a
    ``tokenizer_manager`` attribute.
    """

    # Quoted so the annotation is never evaluated at class-creation time.
    engine: "sgl.Engine"  # provided by BaseWorkerHandler

    def _resolve_arg(self, arg: Any) -> Any:
        """Resolve a single argument from the generic call body.

        If ``arg`` is a dict with exactly one key starting with
        ``"io_struct."``, treat it as a typed constructor: look the class up
        in ``sglang.srt.managers.io_struct`` and construct it with the nested
        kwargs. Any other value is returned unchanged.

        Args:
            arg: A plain value, or ``{"io_struct.ClassName": {kwargs}}``.

        Returns:
            The constructed object, or ``arg`` itself.

        Raises:
            ValueError: If the class name is not a public Python identifier.
            AttributeError: If the module has no attribute with that name.
        """
        prefix = "io_struct."
        if not (isinstance(arg, dict) and len(arg) == 1):
            return arg
        key = next(iter(arg))
        if not (isinstance(key, str) and key.startswith(prefix)):
            return arg

        class_name = key[len(prefix):]
        # The name comes from the request body, so refuse anything that is
        # not a plain public identifier (empty, dotted, private, dunder).
        if not class_name.isidentifier() or class_name.startswith("_"):
            raise ValueError(f"Invalid io_struct class name: {class_name!r}")

        module = importlib.import_module("sglang.srt.managers.io_struct")
        # NOTE(review): any public attribute of the module can be called
        # here, not only request dataclasses — consider an allow-list.
        cls = getattr(module, class_name)
        return cls(**arg[key])

    def _normalize_result(self, result: Any) -> dict:
        """Convert a tokenizer_manager method return value to a dict.

        Mapping:
            * ``None`` -> ``{"status": "ok"}``
            * 2-tuple -> ``success`` / ``message``
            * 3-tuple -> ``success`` / ``message`` / ``num_paused_requests``
            * list -> ``{"result": [...]}`` with dataclass items expanded
            * dataclass instance -> ``dataclasses.asdict(result)``
            * dict -> returned as-is (same object)
            * str / int / float / bool -> ``{"result": value}``
            * anything else (including other tuple lengths) ->
              ``{"result": str(value)}``
        """
        if result is None:
            return {"status": "ok"}
        if isinstance(result, tuple):
            if len(result) == 2:
                return {"success": result[0], "message": result[1]}
            if len(result) == 3:
                return {
                    "success": result[0],
                    "message": result[1],
                    "num_paused_requests": result[2],
                }
            # Other tuple lengths fall through to the str() fallback below.
        if isinstance(result, list):
            return {
                "result": [
                    dataclasses.asdict(item)
                    if dataclasses.is_dataclass(item) and not isinstance(item, type)
                    else item
                    for item in result
                ]
            }
        # is_dataclass() is also true for dataclass *types*; only instances
        # can be passed to asdict().
        if dataclasses.is_dataclass(result) and not isinstance(result, type):
            return dataclasses.asdict(result)
        if isinstance(result, dict):
            return result
        if isinstance(result, (str, int, float, bool)):
            return {"result": result}
        return {"result": str(result)}

    async def call_tokenizer_manager(self, body: dict) -> dict:
        """Generic passthrough to any tokenizer_manager method.

        Body format::

            {
                "method": "method_name",
                "args": [arg1, arg2, ...],
                "kwargs": {"key": value, ...}
            }

        Each element in args/kwargs is either a plain value or a typed
        constructor ``{"io_struct.ClassName": {kwargs}}``. ``args`` and
        ``kwargs`` may be omitted or ``null``.

        Args:
            body: The decoded request body described above.

        Returns:
            The method's return value passed through ``_normalize_result``.

        Raises:
            KeyError: If ``body`` has no ``"method"`` entry.
            ValueError: If the method name is not a string or is a dunder.
            AttributeError: If tokenizer_manager has no such attribute.
            TypeError: If the resolved attribute is not callable.
        """
        method_name = body["method"]
        # Dunder names would expose object internals (e.g. ``__class__``)
        # rather than a tokenizer_manager operation.
        if not isinstance(method_name, str) or method_name.startswith("__"):
            raise ValueError(f"Invalid tokenizer_manager method: {method_name!r}")

        # ``or`` also covers an explicit JSON null for either field.
        raw_args = body.get("args") or []
        raw_kwargs = body.get("kwargs") or {}

        args = [self._resolve_arg(a) for a in raw_args]
        kwargs = {k: self._resolve_arg(v) for k, v in raw_kwargs.items()}

        tm = self.engine.tokenizer_manager

        # Ensure the handle_loop task is running so communicator responses
        # are received. Several tokenizer_manager methods call this
        # internally, but not all of them (e.g. flush_cache does not).
        if hasattr(tm, "auto_create_handle_loop"):
            tm.auto_create_handle_loop()

        method = getattr(tm, method_name)
        if not callable(method):
            raise TypeError(
                f"tokenizer_manager attribute {method_name!r} is not callable"
            )

        # Only await when the call actually produced an awaitable: awaiting
        # the plain return value of a synchronous method raises TypeError.
        result = method(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return self._normalize_result(result)

    def register_rl_engine_routes(self, runtime) -> None:
        """Register RL-specific engine routes.

        Args:
            runtime: The DistributedRuntime instance to register routes on.
        """
        runtime.register_engine_route(
            "call_tokenizer_manager", self.call_tokenizer_manager
        )
class BaseWorkerHandler(RLMixin, BaseGenerativeHandler[RequestT, ResponseT]):
"""Abstract base class for SGLang LLM worker handlers.
Extends BaseGenerativeHandler with LLM-specific functionality:
......@@ -406,6 +507,10 @@ class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
runtime.register_engine_route(
"update_weight_version", self.update_weight_version
)
if getattr(self.config, "dynamo_args", None) and getattr(
self.config.dynamo_args, "enable_rl", False
):
self.register_rl_engine_routes(runtime)
@abstractmethod
def generate(self, request: RequestT, context: Context) -> AsyncIterator[ResponseT]:
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for RLMixin generic tokenizer_manager passthrough."""
import dataclasses
import sys
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
# Module-level marker list: pytest applies every marker in ``pytestmark``
# to each test collected from this file.
pytestmark = [
    pytest.mark.unit,
    pytest.mark.sglang,
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
]
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _stub_sglang_io_struct(monkeypatch):
    """Install an empty stand-in for ``sglang.srt.managers.io_struct``.

    Keeps the unit tests independent from CUDA-only sglang imports. The stub
    is registered in ``sys.modules`` through ``monkeypatch`` for the duration
    of each test and is yielded so a test can attach attributes to it.
    """
    module_name = "sglang.srt.managers.io_struct"
    stub = types.ModuleType(module_name)
    monkeypatch.setitem(sys.modules, module_name, stub)
    yield stub
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
class _TestWorkerHandler(BaseWorkerHandler):
    """Minimal concrete ``BaseWorkerHandler`` used only by these tests."""

    async def generate(self, request, context):
        # Placeholder implementation of ``generate``: yields one empty dict.
        yield {}
def _make_handler() -> _TestWorkerHandler:
    """Build a handler without running ``__init__``.

    Only ``engine`` is populated: a namespace whose ``tokenizer_manager``
    exposes a mocked ``auto_create_handle_loop``. Tests attach any further
    tokenizer_manager methods they need.
    """
    tokenizer_manager = SimpleNamespace(auto_create_handle_loop=MagicMock())
    handler = _TestWorkerHandler.__new__(_TestWorkerHandler)
    handler.engine = SimpleNamespace(tokenizer_manager=tokenizer_manager)
    return handler
# ---------------------------------------------------------------------------
# _resolve_arg
# ---------------------------------------------------------------------------
class TestResolveArg:
    """Tests for ``_resolve_arg``: plain passthrough and io_struct construction."""

    def setup_method(self):
        # Fresh handler per test so mocks never leak between cases.
        self.handler = _make_handler()

    def test_plain_string(self):
        resolved = self.handler._resolve_arg("hello")
        assert resolved == "hello"

    def test_plain_int(self):
        resolved = self.handler._resolve_arg(42)
        assert resolved == 42

    def test_plain_none(self):
        resolved = self.handler._resolve_arg(None)
        assert resolved is None

    def test_plain_list(self):
        resolved = self.handler._resolve_arg([1, 2, 3])
        assert resolved == [1, 2, 3]

    def test_plain_dict_multiple_keys(self):
        payload = {"a": 1, "b": 2}
        resolved = self.handler._resolve_arg(payload)
        assert resolved == payload

    def test_plain_dict_single_key_no_prefix(self):
        payload = {"some_key": {"x": 1}}
        resolved = self.handler._resolve_arg(payload)
        assert resolved == payload

    def test_io_struct_constructor(self):
        """A dict with one key starting with 'io_struct.' constructs the class."""
        fake_cls = MagicMock(return_value="constructed_instance")
        fake_module = MagicMock()
        fake_module.MyReqInput = fake_cls
        payload = {"io_struct.MyReqInput": {"addr": "1.2.3.4", "port": 1234}}

        with patch("importlib.import_module", return_value=fake_module) as import_spy:
            resolved = self.handler._resolve_arg(payload)

        import_spy.assert_called_once_with("sglang.srt.managers.io_struct")
        fake_cls.assert_called_once_with(addr="1.2.3.4", port=1234)
        assert resolved == "constructed_instance"

    def test_io_struct_empty_kwargs(self):
        """Constructor with empty kwargs."""
        fake_cls = MagicMock(return_value="empty_instance")
        fake_module = MagicMock()
        fake_module.PauseGenerationReqInput = fake_cls
        payload = {"io_struct.PauseGenerationReqInput": {}}

        with patch("importlib.import_module", return_value=fake_module):
            resolved = self.handler._resolve_arg(payload)

        fake_cls.assert_called_once_with()
        assert resolved == "empty_instance"
# ---------------------------------------------------------------------------
# _normalize_result
# ---------------------------------------------------------------------------
class TestNormalizeResult:
    """Tests for ``_normalize_result`` across every supported return shape."""

    def setup_method(self):
        # Fresh handler per test.
        self.handler = _make_handler()

    def test_none(self):
        normalized = self.handler._normalize_result(None)
        assert normalized == {"status": "ok"}

    def test_tuple_2(self):
        normalized = self.handler._normalize_result((True, "done"))
        assert normalized == {"success": True, "message": "done"}

    def test_tuple_2_failure(self):
        normalized = self.handler._normalize_result((False, "error msg"))
        assert normalized == {"success": False, "message": "error msg"}

    def test_tuple_3(self):
        normalized = self.handler._normalize_result((True, "ok", 5))
        expected = {"success": True, "message": "ok", "num_paused_requests": 5}
        assert normalized == expected

    def test_dict_passthrough(self):
        # Identity, not just equality: dicts must be returned untouched.
        payload = {"foo": "bar", "count": 3}
        assert self.handler._normalize_result(payload) is payload

    def test_dataclass(self):
        @dataclasses.dataclass
        class FakeResult:
            success: bool
            nodes_pinned: int

        normalized = self.handler._normalize_result(
            FakeResult(success=True, nodes_pinned=10)
        )
        assert normalized == {"success": True, "nodes_pinned": 10}

    def test_list_of_dataclasses(self):
        @dataclasses.dataclass
        class LoadInfo:
            dp_rank: int
            num_reqs: int

        loads = [LoadInfo(dp_rank=0, num_reqs=5), LoadInfo(dp_rank=1, num_reqs=3)]
        expected = {
            "result": [
                {"dp_rank": 0, "num_reqs": 5},
                {"dp_rank": 1, "num_reqs": 3},
            ]
        }
        assert self.handler._normalize_result(loads) == expected

    def test_list_of_plain_values(self):
        normalized = self.handler._normalize_result([1, "two", 3])
        assert normalized == {"result": [1, "two", 3]}

    def test_list_mixed(self):
        @dataclasses.dataclass
        class Info:
            val: int

        mixed = [Info(val=1), "plain", 42]
        expected = {"result": [{"val": 1}, "plain", 42]}
        assert self.handler._normalize_result(mixed) == expected

    def test_other_value(self):
        normalize = self.handler._normalize_result
        assert normalize(42) == {"result": 42}
        assert normalize("text") == {"result": "text"}

    def test_non_serializable_falls_back_to_str(self):
        opaque = object()
        normalized = self.handler._normalize_result(opaque)
        assert normalized == {"result": str(opaque)}
# ---------------------------------------------------------------------------
# call_tokenizer_manager
# ---------------------------------------------------------------------------
class TestCallTokenizerManager:
    """End-to-end tests for ``call_tokenizer_manager`` against mocked methods."""

    def setup_method(self):
        # Fresh handler (and fresh tokenizer_manager namespace) per test.
        self.handler = _make_handler()

    @pytest.mark.asyncio
    async def test_method_only(self):
        """Calling with just 'method', no args/kwargs."""
        tm = self.handler.engine.tokenizer_manager
        tm.flush_cache = AsyncMock(return_value=None)

        response = await self.handler.call_tokenizer_manager({"method": "flush_cache"})

        tm.flush_cache.assert_awaited_once_with()
        assert response == {"status": "ok"}

    @pytest.mark.asyncio
    async def test_with_plain_args(self):
        """Plain value args are passed through."""
        tm = self.handler.engine.tokenizer_manager
        tm.some_method = AsyncMock(return_value=(True, "ok"))
        body = {"method": "some_method", "args": ["arg1", 42]}

        response = await self.handler.call_tokenizer_manager(body)

        tm.some_method.assert_awaited_once_with("arg1", 42)
        assert response == {"success": True, "message": "ok"}

    @pytest.mark.asyncio
    async def test_with_kwargs(self):
        """kwargs including null are passed through."""
        tm = self.handler.engine.tokenizer_manager
        tm.some_method = AsyncMock(return_value=(True, "done"))
        body = {
            "method": "some_method",
            "args": ["positional"],
            "kwargs": {"request": None},
        }

        response = await self.handler.call_tokenizer_manager(body)

        tm.some_method.assert_awaited_once_with("positional", request=None)
        assert response == {"success": True, "message": "done"}

    @pytest.mark.asyncio
    async def test_with_io_struct_arg(self):
        """io_struct constructor args are resolved before calling."""
        built_request = MagicMock()
        fake_cls = MagicMock(return_value=built_request)
        fake_module = MagicMock()
        fake_module.InitWeightsUpdateGroupReqInput = fake_cls

        tm = self.handler.engine.tokenizer_manager
        tm.init_weights_update_group = AsyncMock(
            return_value=(True, "group initialized")
        )

        ctor_kwargs = {
            "master_address": "1.2.3.4",
            "master_port": 1234,
            "rank_offset": 0,
            "world_size": 4,
        }
        body = {
            "method": "init_weights_update_group",
            "args": [{"io_struct.InitWeightsUpdateGroupReqInput": ctor_kwargs}],
            "kwargs": {"request": None},
        }

        with patch("importlib.import_module", return_value=fake_module):
            response = await self.handler.call_tokenizer_manager(body)

        fake_cls.assert_called_once_with(
            master_address="1.2.3.4", master_port=1234, rank_offset=0, world_size=4
        )
        tm.init_weights_update_group.assert_awaited_once_with(
            built_request, request=None
        )
        assert response == {"success": True, "message": "group initialized"}

    @pytest.mark.asyncio
    async def test_tuple_3_result(self):
        """3-tuple results include num_paused_requests."""
        tm = self.handler.engine.tokenizer_manager
        tm.update_weights_from_disk = AsyncMock(return_value=(True, "updated", 3))
        body = {
            "method": "update_weights_from_disk",
            "args": ["req_obj"],
            "kwargs": {"request": None},
        }

        response = await self.handler.call_tokenizer_manager(body)

        expected = {"success": True, "message": "updated", "num_paused_requests": 3}
        assert response == expected
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment