Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
9412bdff
Unverified
Commit
9412bdff
authored
Apr 10, 2026
by
William Arnold
Committed by
GitHub
Apr 10, 2026
Browse files
feat: generic tokenizer_manager passthrough route for RL training (#6836)
Signed-off-by:
William Arnold
<
warnold@nvidia.com
>
parent
6b75d6b0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
431 additions
and
2 deletions
+431
-2
components/src/dynamo/sglang/backend_args.py
components/src/dynamo/sglang/backend_args.py
+8
-0
components/src/dynamo/sglang/request_handlers/__init__.py
components/src/dynamo/sglang/request_handlers/__init__.py
+2
-1
components/src/dynamo/sglang/request_handlers/handler_base.py
...onents/src/dynamo/sglang/request_handlers/handler_base.py
+106
-1
components/src/dynamo/sglang/tests/test_sglang_rl_mixin.py
components/src/dynamo/sglang/tests/test_sglang_rl_mixin.py
+315
-0
No files found.
components/src/dynamo/sglang/backend_args.py
View file @
9412bdff
...
...
@@ -101,6 +101,13 @@ class DynamoSGLangArgGroup(ArgGroup):
default
=
False
,
help
=
"Run as video generation worker for video generation (T2V/I2V)."
,
)
add_negatable_bool_argument
(
g
,
flag_name
=
"--enable-rl"
,
env_var
=
"DYN_SGL_ENABLE_RL"
,
default
=
False
,
help
=
"Enable RL training support. Registers the call_tokenizer_manager engine route for generic tokenizer_manager passthrough."
,
)
class
DynamoSGLangConfig
(
ConfigBase
):
...
...
@@ -117,6 +124,7 @@ class DynamoSGLangConfig(ConfigBase):
disagg_config_key
:
Optional
[
str
]
=
None
video_generation_worker
:
bool
enable_rl
:
bool
def
validate
(
self
)
->
None
:
if
not
isinstance
(
self
.
embedding_transfer_mode
,
EmbeddingTransferMode
):
...
...
components/src/dynamo/sglang/request_handlers/__init__.py
View file @
9412bdff
...
...
@@ -5,7 +5,7 @@
from
.embedding
import
EmbeddingWorkerHandler
# Base handlers
from
.handler_base
import
BaseGenerativeHandler
,
BaseWorkerHandler
from
.handler_base
import
BaseGenerativeHandler
,
BaseWorkerHandler
,
RLMixin
# Image diffusion handlers
from
.image_diffusion
import
ImageDiffusionWorkerHandler
...
...
@@ -27,6 +27,7 @@ __all__ = [
# Base handlers
"BaseGenerativeHandler"
,
"BaseWorkerHandler"
,
"RLMixin"
,
# LLM handlers
"DecodeWorkerHandler"
,
"DiffusionWorkerHandler"
,
...
...
components/src/dynamo/sglang/request_handlers/handler_base.py
View file @
9412bdff
...
...
@@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
dataclasses
import
importlib
import
inspect
import
json
import
logging
...
...
@@ -131,7 +133,106 @@ class BaseGenerativeHandler(ABC, Generic[RequestT, ResponseT]):
pass
class
BaseWorkerHandler
(
BaseGenerativeHandler
[
RequestT
,
ResponseT
]):
class RLMixin:
    """Mixin providing generic tokenizer_manager passthrough for RL training.

    Requires the host class to have ``self.engine`` with a
    ``tokenizer_manager`` attribute.

    NOTE(review): the method name and the ``io_struct`` class name are taken
    straight from the request body, so any attribute of the tokenizer manager
    and any name in ``sglang.srt.managers.io_struct`` can be reached.
    Presumably this route is only exposed to trusted RL trainers -- confirm.
    """

    # Provided by BaseWorkerHandler. Quoted so the annotation is not evaluated
    # while the class body executes.
    engine: "sgl.Engine"

    def _resolve_arg(self, arg: Any) -> Any:
        """Resolve a single argument from the generic call body.

        If ``arg`` is a dict with exactly one key starting with
        ``"io_struct."``, treat it as a typed constructor: import
        ``sglang.srt.managers.io_struct``, look up the class named after the
        prefix and construct it with the nested kwargs. Otherwise return the
        value as-is.

        Args:
            arg: A plain value or ``{"io_struct.ClassName": {kwargs}}``.

        Returns:
            The constructed object for a typed constructor, else ``arg``.

        Raises:
            AttributeError: If the named class does not exist in the module.
        """
        prefix = "io_struct."
        if isinstance(arg, dict) and len(arg) == 1:
            key = next(iter(arg))
            if isinstance(key, str) and key.startswith(prefix):
                class_name = key[len(prefix):]
                module = importlib.import_module("sglang.srt.managers.io_struct")
                cls = getattr(module, class_name)
                return cls(**arg[key])
        return arg

    def _normalize_result(self, result: Any) -> dict:
        """Convert a tokenizer_manager method return value to a JSON-safe dict.

        Mapping applied, in order:

        * ``None`` -> ``{"status": "ok"}``
        * 2-tuple -> ``{"success", "message"}``
        * 3-tuple -> ``{"success", "message", "num_paused_requests"}``
        * list -> ``{"result": [...]}`` with dataclass instances expanded
        * dataclass instance -> ``dataclasses.asdict(result)``
        * dict -> returned unchanged (same object)
        * str/int/float/bool -> ``{"result": result}``
        * anything else (including other tuple sizes) -> ``{"result": str(result)}``
        """

        def is_dataclass_instance(value: Any) -> bool:
            # is_dataclass() is also true for dataclass *types*; asdict() only
            # accepts instances.
            return dataclasses.is_dataclass(value) and not isinstance(value, type)

        if result is None:
            return {"status": "ok"}
        if isinstance(result, tuple):
            if len(result) == 2:
                return {"success": result[0], "message": result[1]}
            if len(result) == 3:
                return {
                    "success": result[0],
                    "message": result[1],
                    "num_paused_requests": result[2],
                }
        if isinstance(result, list):
            return {
                "result": [
                    dataclasses.asdict(item) if is_dataclass_instance(item) else item
                    for item in result
                ]
            }
        if is_dataclass_instance(result):
            return dataclasses.asdict(result)
        if isinstance(result, dict):
            return result
        if isinstance(result, (str, int, float, bool)):
            return {"result": result}
        return {"result": str(result)}

    async def call_tokenizer_manager(self, body: dict) -> dict:
        """Generic passthrough to any tokenizer_manager method.

        Body format::

            {
                "method": "method_name",
                "args": [arg1, arg2, ...],
                "kwargs": {"key": value, ...}
            }

        Each element in args/kwargs is either a plain value or a typed
        constructor ``{"io_struct.ClassName": {kwargs}}``. ``args`` and
        ``kwargs`` may be omitted or ``null``.

        Args:
            body: The request body described above.

        Returns:
            The method's return value passed through ``_normalize_result``.

        Raises:
            KeyError: If ``body`` has no ``"method"`` entry.
            AttributeError: If the tokenizer manager has no such attribute.
            TypeError: If the named attribute is not callable.
        """
        method_name = body["method"]

        # A JSON null is treated the same as an omitted field.
        raw_args = body.get("args")
        if raw_args is None:
            raw_args = []
        raw_kwargs = body.get("kwargs")
        if raw_kwargs is None:
            raw_kwargs = {}

        args = [self._resolve_arg(a) for a in raw_args]
        kwargs = {k: self._resolve_arg(v) for k, v in raw_kwargs.items()}

        tm = self.engine.tokenizer_manager

        # Ensure the handle_loop task is running so communicator responses
        # are received. Several tokenizer_manager methods call this
        # internally, but not all of them (e.g. flush_cache does not).
        if hasattr(tm, "auto_create_handle_loop"):
            tm.auto_create_handle_loop()

        try:
            method = getattr(tm, method_name)
        except AttributeError:
            raise AttributeError(
                f"tokenizer_manager has no method {method_name!r}"
            ) from None
        if not callable(method):
            raise TypeError(f"tokenizer_manager.{method_name} is not callable")

        # Only await when the call actually produced an awaitable, so plain
        # synchronous methods can be invoked through the same route.
        result = method(*args, **kwargs)
        if inspect.isawaitable(result):
            result = await result
        return self._normalize_result(result)

    def register_rl_engine_routes(self, runtime) -> None:
        """Register RL-specific engine routes.

        Args:
            runtime: The DistributedRuntime instance to register routes on.
        """
        runtime.register_engine_route(
            "call_tokenizer_manager", self.call_tokenizer_manager
        )
class
BaseWorkerHandler
(
RLMixin
,
BaseGenerativeHandler
[
RequestT
,
ResponseT
]):
"""Abstract base class for SGLang LLM worker handlers.
Extends BaseGenerativeHandler with LLM-specific functionality:
...
...
@@ -406,6 +507,10 @@ class BaseWorkerHandler(BaseGenerativeHandler[RequestT, ResponseT]):
runtime
.
register_engine_route
(
"update_weight_version"
,
self
.
update_weight_version
)
if
getattr
(
self
.
config
,
"dynamo_args"
,
None
)
and
getattr
(
self
.
config
.
dynamo_args
,
"enable_rl"
,
False
):
self
.
register_rl_engine_routes
(
runtime
)
@
abstractmethod
def
generate
(
self
,
request
:
RequestT
,
context
:
Context
)
->
AsyncIterator
[
ResponseT
]:
...
...
components/src/dynamo/sglang/tests/test_sglang_rl_mixin.py
0 → 100644
View file @
9412bdff
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for RLMixin generic tokenizer_manager passthrough."""
import
dataclasses
import
sys
import
types
from
types
import
SimpleNamespace
from
unittest.mock
import
AsyncMock
,
MagicMock
,
patch
import
pytest
from
dynamo.sglang.request_handlers.handler_base
import
BaseWorkerHandler
# Module-level pytest markers: every test collected from this file is tagged
# with the unit, sglang, gpu_0 and pre_merge marks.
pytestmark = [
    pytest.mark.unit,
    pytest.mark.sglang,
    pytest.mark.gpu_0,
    pytest.mark.pre_merge,
]
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _stub_sglang_io_struct(monkeypatch):
    """Swap in an empty ``sglang.srt.managers.io_struct`` module for each test.

    Keeps the unit tests independent from CUDA-only sglang imports. The stub
    is registered in ``sys.modules`` through ``monkeypatch`` and is yielded to
    the test.
    """
    module_name = "sglang.srt.managers.io_struct"
    stub_module = types.ModuleType(module_name)
    monkeypatch.setitem(sys.modules, module_name, stub_module)
    yield stub_module
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------
class _TestWorkerHandler(BaseWorkerHandler):
    """Smallest concrete ``BaseWorkerHandler`` usable in these tests."""

    async def generate(self, request, context):
        """Yield a single empty dict, ignoring ``request`` and ``context``."""
        empty_response = {}
        yield empty_response
def _make_handler() -> _TestWorkerHandler:
    """Build a handler without running ``__init__``.

    The returned instance only has ``engine`` set: a namespace whose
    ``tokenizer_manager`` exposes a mocked ``auto_create_handle_loop``.
    """
    fake_tokenizer_manager = SimpleNamespace(auto_create_handle_loop=MagicMock())
    fake_engine = SimpleNamespace(tokenizer_manager=fake_tokenizer_manager)

    instance = _TestWorkerHandler.__new__(_TestWorkerHandler)
    instance.engine = fake_engine
    return instance
# ---------------------------------------------------------------------------
# _resolve_arg
# ---------------------------------------------------------------------------
class TestResolveArg:
    """Checks for ``_resolve_arg``: plain values pass through, typed dicts construct."""

    def setup_method(self):
        self.handler = _make_handler()

    def test_plain_string(self):
        resolved = self.handler._resolve_arg("hello")
        assert resolved == "hello"

    def test_plain_int(self):
        resolved = self.handler._resolve_arg(42)
        assert resolved == 42

    def test_plain_none(self):
        resolved = self.handler._resolve_arg(None)
        assert resolved is None

    def test_plain_list(self):
        resolved = self.handler._resolve_arg([1, 2, 3])
        assert resolved == [1, 2, 3]

    def test_plain_dict_multiple_keys(self):
        payload = {"a": 1, "b": 2}
        resolved = self.handler._resolve_arg(payload)
        assert resolved == payload

    def test_plain_dict_single_key_no_prefix(self):
        payload = {"some_key": {"x": 1}}
        resolved = self.handler._resolve_arg(payload)
        assert resolved == payload

    def test_io_struct_constructor(self):
        """A dict with one key starting with 'io_struct.' constructs the class."""
        fake_cls = MagicMock(return_value="constructed_instance")
        fake_module = MagicMock()
        fake_module.MyReqInput = fake_cls
        payload = {"io_struct.MyReqInput": {"addr": "1.2.3.4", "port": 1234}}

        with patch("importlib.import_module", return_value=fake_module) as import_spy:
            resolved = self.handler._resolve_arg(payload)

        import_spy.assert_called_once_with("sglang.srt.managers.io_struct")
        fake_cls.assert_called_once_with(addr="1.2.3.4", port=1234)
        assert resolved == "constructed_instance"

    def test_io_struct_empty_kwargs(self):
        """Constructor with empty kwargs."""
        fake_cls = MagicMock(return_value="empty_instance")
        fake_module = MagicMock()
        fake_module.PauseGenerationReqInput = fake_cls
        payload = {"io_struct.PauseGenerationReqInput": {}}

        with patch("importlib.import_module", return_value=fake_module):
            resolved = self.handler._resolve_arg(payload)

        fake_cls.assert_called_once_with()
        assert resolved == "empty_instance"
# ---------------------------------------------------------------------------
# _normalize_result
# ---------------------------------------------------------------------------
class TestNormalizeResult:
    """Checks for ``_normalize_result`` across each supported return shape."""

    def setup_method(self):
        self.handler = _make_handler()

    def test_none(self):
        normalized = self.handler._normalize_result(None)
        assert normalized == {"status": "ok"}

    def test_tuple_2(self):
        normalized = self.handler._normalize_result((True, "done"))
        expected = {"success": True, "message": "done"}
        assert normalized == expected

    def test_tuple_2_failure(self):
        normalized = self.handler._normalize_result((False, "error msg"))
        expected = {"success": False, "message": "error msg"}
        assert normalized == expected

    def test_tuple_3(self):
        normalized = self.handler._normalize_result((True, "ok", 5))
        expected = {"success": True, "message": "ok", "num_paused_requests": 5}
        assert normalized == expected

    def test_dict_passthrough(self):
        payload = {"foo": "bar", "count": 3}
        # Identity, not just equality: dicts are returned untouched.
        assert self.handler._normalize_result(payload) is payload

    def test_dataclass(self):
        @dataclasses.dataclass
        class FakeResult:
            success: bool
            nodes_pinned: int

        instance = FakeResult(success=True, nodes_pinned=10)
        normalized = self.handler._normalize_result(instance)
        assert normalized == {"success": True, "nodes_pinned": 10}

    def test_list_of_dataclasses(self):
        @dataclasses.dataclass
        class LoadInfo:
            dp_rank: int
            num_reqs: int

        entries = [LoadInfo(dp_rank=0, num_reqs=5), LoadInfo(dp_rank=1, num_reqs=3)]
        expected = {
            "result": [
                {"dp_rank": 0, "num_reqs": 5},
                {"dp_rank": 1, "num_reqs": 3},
            ]
        }
        assert self.handler._normalize_result(entries) == expected

    def test_list_of_plain_values(self):
        normalized = self.handler._normalize_result([1, "two", 3])
        assert normalized == {"result": [1, "two", 3]}

    def test_list_mixed(self):
        @dataclasses.dataclass
        class Info:
            val: int

        entries = [Info(val=1), "plain", 42]
        expected = {"result": [{"val": 1}, "plain", 42]}
        assert self.handler._normalize_result(entries) == expected

    def test_other_value(self):
        normalize = self.handler._normalize_result
        assert normalize(42) == {"result": 42}
        assert normalize("text") == {"result": "text"}

    def test_non_serializable_falls_back_to_str(self):
        opaque = object()
        normalized = self.handler._normalize_result(opaque)
        assert normalized == {"result": str(opaque)}
# ---------------------------------------------------------------------------
# call_tokenizer_manager
# ---------------------------------------------------------------------------
class TestCallTokenizerManager:
    """End-to-end checks for ``call_tokenizer_manager`` against a fake manager."""

    def setup_method(self):
        self.handler = _make_handler()

    @pytest.mark.asyncio
    async def test_method_only(self):
        """Calling with just 'method', no args/kwargs."""
        tm = self.handler.engine.tokenizer_manager
        tm.flush_cache = AsyncMock(return_value=None)

        response = await self.handler.call_tokenizer_manager({"method": "flush_cache"})

        tm.flush_cache.assert_awaited_once_with()
        assert response == {"status": "ok"}

    @pytest.mark.asyncio
    async def test_with_plain_args(self):
        """Plain value args are passed through."""
        tm = self.handler.engine.tokenizer_manager
        tm.some_method = AsyncMock(return_value=(True, "ok"))
        request_body = {"method": "some_method", "args": ["arg1", 42]}

        response = await self.handler.call_tokenizer_manager(request_body)

        tm.some_method.assert_awaited_once_with("arg1", 42)
        assert response == {"success": True, "message": "ok"}

    @pytest.mark.asyncio
    async def test_with_kwargs(self):
        """kwargs including null are passed through."""
        tm = self.handler.engine.tokenizer_manager
        tm.some_method = AsyncMock(return_value=(True, "done"))
        request_body = {
            "method": "some_method",
            "args": ["positional"],
            "kwargs": {"request": None},
        }

        response = await self.handler.call_tokenizer_manager(request_body)

        tm.some_method.assert_awaited_once_with("positional", request=None)
        assert response == {"success": True, "message": "done"}

    @pytest.mark.asyncio
    async def test_with_io_struct_arg(self):
        """io_struct constructor args are resolved before calling."""
        built_request = MagicMock()
        fake_cls = MagicMock(return_value=built_request)
        fake_module = MagicMock()
        fake_module.InitWeightsUpdateGroupReqInput = fake_cls

        tm = self.handler.engine.tokenizer_manager
        tm.init_weights_update_group = AsyncMock(
            return_value=(True, "group initialized")
        )

        ctor_kwargs = {
            "master_address": "1.2.3.4",
            "master_port": 1234,
            "rank_offset": 0,
            "world_size": 4,
        }
        request_body = {
            "method": "init_weights_update_group",
            "args": [{"io_struct.InitWeightsUpdateGroupReqInput": ctor_kwargs}],
            "kwargs": {"request": None},
        }

        with patch("importlib.import_module", return_value=fake_module):
            response = await self.handler.call_tokenizer_manager(request_body)

        fake_cls.assert_called_once_with(
            master_address="1.2.3.4", master_port=1234, rank_offset=0, world_size=4
        )
        tm.init_weights_update_group.assert_awaited_once_with(
            built_request, request=None
        )
        assert response == {"success": True, "message": "group initialized"}

    @pytest.mark.asyncio
    async def test_tuple_3_result(self):
        """3-tuple results include num_paused_requests."""
        tm = self.handler.engine.tokenizer_manager
        tm.update_weights_from_disk = AsyncMock(return_value=(True, "updated", 3))
        request_body = {
            "method": "update_weights_from_disk",
            "args": ["req_obj"],
            "kwargs": {"request": None},
        }

        response = await self.handler.call_tokenizer_manager(request_body)

        expected = {"success": True, "message": "updated", "num_paused_requests": 3}
        assert response == expected
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment