Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
80d8aa19
"launch/vscode:/vscode.git/clone" did not exist on "da83f820ebe7e2f353c559d18fba2d3ec4ce01a3"
Unverified
Commit
80d8aa19
authored
Aug 18, 2025
by
ishandhanani
Committed by
GitHub
Aug 18, 2025
Browse files
feat(sglang): unify entry point for SGLang backend architecture (#2493)
parent
28400714
Changes
41
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
624 additions
and
985 deletions
+624
-985
components/backends/sglang/src/dynamo/sglang/common/__init__.py
...ents/backends/sglang/src/dynamo/sglang/common/__init__.py
+0
-38
components/backends/sglang/src/dynamo/sglang/common/base_handlers.py
...backends/sglang/src/dynamo/sglang/common/base_handlers.py
+0
-62
components/backends/sglang/src/dynamo/sglang/common/sgl_utils.py
...nts/backends/sglang/src/dynamo/sglang/common/sgl_utils.py
+0
-102
components/backends/sglang/src/dynamo/sglang/decode_worker/__init__.py
...ckends/sglang/src/dynamo/sglang/decode_worker/__init__.py
+4
-0
components/backends/sglang/src/dynamo/sglang/decode_worker/__main__.py
...ckends/sglang/src/dynamo/sglang/decode_worker/__main__.py
+10
-1
components/backends/sglang/src/dynamo/sglang/decode_worker/main.py
...s/backends/sglang/src/dynamo/sglang/decode_worker/main.py
+0
-94
components/backends/sglang/src/dynamo/sglang/main.py
components/backends/sglang/src/dynamo/sglang/main.py
+146
-0
components/backends/sglang/src/dynamo/sglang/protocol.py
components/backends/sglang/src/dynamo/sglang/protocol.py
+1
-16
components/backends/sglang/src/dynamo/sglang/publisher.py
components/backends/sglang/src/dynamo/sglang/publisher.py
+139
-0
components/backends/sglang/src/dynamo/sglang/register.py
components/backends/sglang/src/dynamo/sglang/register.py
+74
-0
components/backends/sglang/src/dynamo/sglang/request_handlers/__init__.py
...nds/sglang/src/dynamo/sglang/request_handlers/__init__.py
+14
-0
components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py
...lang/src/dynamo/sglang/request_handlers/decode_handler.py
+109
-0
components/backends/sglang/src/dynamo/sglang/request_handlers/handler_base.py
...sglang/src/dynamo/sglang/request_handlers/handler_base.py
+36
-0
components/backends/sglang/src/dynamo/sglang/request_handlers/prefill_handler.py
...ang/src/dynamo/sglang/request_handlers/prefill_handler.py
+74
-0
components/backends/sglang/src/dynamo/sglang/utils/clear_namespace.py
...ackends/sglang/src/dynamo/sglang/utils/clear_namespace.py
+1
-16
components/backends/sglang/src/dynamo/sglang/utils/sgl_http_server.py
...ackends/sglang/src/dynamo/sglang/utils/sgl_http_server.py
+0
-207
components/backends/sglang/src/dynamo/sglang/worker/__init__.py
...ents/backends/sglang/src/dynamo/sglang/worker/__init__.py
+4
-0
components/backends/sglang/src/dynamo/sglang/worker/__main__.py
...ents/backends/sglang/src/dynamo/sglang/worker/__main__.py
+11
-1
components/backends/sglang/src/dynamo/sglang/worker/main.py
components/backends/sglang/src/dynamo/sglang/worker/main.py
+0
-447
container/Dockerfile.sglang
container/Dockerfile.sglang
+1
-1
No files found.
components/backends/sglang/src/dynamo/sglang/common/__init__.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Base handlers
from
.base_handlers
import
BaseWorkerHandler
# Protocol types
from
.protocol
import
(
DisaggPreprocessedRequest
,
PreprocessedRequest
,
SamplingOptions
,
StopConditions
,
TokenIdType
,
)
# Utilities
from
.sgl_utils
import
(
graceful_shutdown
,
parse_sglang_args_inc
,
reserve_free_port
,
setup_native_endpoints
,
)
__all__
=
[
# Protocol types
"DisaggPreprocessedRequest"
,
"PreprocessedRequest"
,
"SamplingOptions"
,
"StopConditions"
,
"TokenIdType"
,
# Utilities
"parse_sglang_args_inc"
,
"reserve_free_port"
,
"graceful_shutdown"
,
"setup_native_endpoints"
,
# Base handlers
"BaseWorkerHandler"
,
]
components/backends/sglang/src/dynamo/sglang/common/base_handlers.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Optional
import
sglang
as
sgl
from
sglang.srt.server_args
import
ServerArgs
class
BaseWorkerHandler
(
ABC
):
"""
Abstract base class for sglang request handlers. We use this to implement native sglang endpoints for
workers
"""
@
abstractmethod
def
__init__
(
self
,
engine
:
sgl
.
Engine
,
server_args
:
ServerArgs
,
component
,
decode_client
:
Optional
[
Any
]
=
None
,
):
self
.
engine
=
engine
self
.
server_args
=
server_args
self
.
component
=
component
@
abstractmethod
async
def
generate
(
self
,
request
):
"""Generate tokens from the engine"""
...
async
def
flush_cache
(
self
,
request
:
dict
):
"""Flush KV cache for each worker"""
_
=
request
await
self
.
engine
.
tokenizer_manager
.
flush_cache
()
yield
True
async
def
start_expert_distribution_record
(
self
,
request
:
dict
):
"""
Start recording expert distribution.
"""
_
=
request
await
self
.
engine
.
tokenizer_manager
.
start_expert_distribution_record
()
yield
True
async
def
stop_expert_distribution_record
(
self
,
request
:
dict
):
"""
Stop recording expert distribution.
"""
_
=
request
await
self
.
engine
.
tokenizer_manager
.
stop_expert_distribution_record
()
yield
True
async
def
dump_expert_distribution_record
(
self
,
request
:
dict
):
"""
Dumps the expert distribution record to the directory specified in the environment variable `SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR`.
"""
_
=
request
await
self
.
engine
.
tokenizer_manager
.
dump_expert_distribution_record
()
yield
True
components/backends/sglang/src/dynamo/sglang/common/sgl_utils.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
contextlib
import
logging
import
socket
from
argparse
import
Namespace
from
sglang.srt.server_args
import
ServerArgs
class
SkipTokenizerInitError
(
RuntimeError
):
def
__init__
(
self
):
super
().
__init__
(
"--skip-tokenizer-init flag is required"
)
def
parse_sglang_args_inc
(
args
:
list
[
str
])
->
ServerArgs
:
# Currently we only support Dynamo doing the tokenization, so we must give
# sglang the skip-tokenizer-init flag. We don't default it because this is temporary.
# Allow the --version and --help flags through.
temp_need_tok
=
[
"--skip-tokenizer-init"
,
"--version"
,
"--help"
,
"-h"
]
if
not
any
(
w
in
args
for
w
in
temp_need_tok
):
raise
SkipTokenizerInitError
()
parser
=
argparse
.
ArgumentParser
()
bootstrap_port
=
_reserve_disaggregation_bootstrap_port
()
ServerArgs
.
add_cli_args
(
parser
)
parsed_args
=
parser
.
parse_args
(
args
)
if
not
any
(
arg
.
startswith
(
"--disaggregation-bootstrap-port"
)
for
arg
in
args
):
args_dict
=
vars
(
parsed_args
)
args_dict
[
"disaggregation_bootstrap_port"
]
=
bootstrap_port
parsed_args
=
Namespace
(
**
args_dict
)
return
ServerArgs
.
from_cli_args
(
parsed_args
)
@
contextlib
.
contextmanager
def
reserve_free_port
(
host
:
str
=
"localhost"
):
"""
Find and reserve a free port until context exits.
"""
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
try
:
sock
.
bind
((
host
,
0
))
_
,
port
=
sock
.
getsockname
()
yield
port
finally
:
sock
.
close
()
def
_reserve_disaggregation_bootstrap_port
():
"""
Each worker requires a unique port for disaggregation_bootstrap_port.
We use an existing utility function that reserves a free port on your
machine to avoid collisions.
"""
with
reserve_free_port
()
as
port
:
return
port
async
def
graceful_shutdown
(
runtime
):
logging
.
info
(
"Received shutdown signal, shutting down DistributedRuntime"
)
runtime
.
shutdown
()
logging
.
info
(
"DistributedRuntime shutdown complete"
)
def
setup_native_endpoints
(
server_args
,
component
,
handler
):
"""Setup sgl native endpoints"""
# flush cache
flush_endpoint
=
component
.
endpoint
(
"flush_cache"
)
tasks
=
[]
tasks
.
append
(
flush_endpoint
.
serve_endpoint
(
handler
.
flush_cache
))
# expert distribution endpoints
if
server_args
.
expert_distribution_recorder_mode
is
not
None
:
start_expert_distribution_endpoint
=
component
.
endpoint
(
"start_expert_distribution_record"
)
stop_expert_distribution_endpoint
=
component
.
endpoint
(
"stop_expert_distribution_record"
)
dump_expert_distribution_endpoint
=
component
.
endpoint
(
"dump_expert_distribution_record"
)
tasks
.
append
(
start_expert_distribution_endpoint
.
serve_endpoint
(
handler
.
start_expert_distribution_record
)
)
tasks
.
append
(
stop_expert_distribution_endpoint
.
serve_endpoint
(
handler
.
stop_expert_distribution_record
)
)
tasks
.
append
(
dump_expert_distribution_endpoint
.
serve_endpoint
(
handler
.
dump_expert_distribution_record
)
)
return
tasks
components/backends/sglang/src/dynamo/sglang/decode_worker/__init__.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# This module is deprecated. Use `python3 -m dynamo.sglang` instead.
components/backends/sglang/src/dynamo/sglang/decode_worker/__main__.py
View file @
80d8aa19
...
@@ -2,7 +2,16 @@
...
@@ -2,7 +2,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
dynamo.sglang.decode_worker.main
import
main
import
logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.main
import
main
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
configure_dynamo_logging
()
logging
.
warning
(
"DEPRECATION WARNING: `python3 -m dynamo.sglang.decode_worker` is deprecated and will be removed in dynamo v0.5.0."
"Use `python3 -m dynamo.sglang` instead."
,
)
main
()
main
()
components/backends/sglang/src/dynamo/sglang/decode_worker/main.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
asyncio
import
logging
import
signal
import
sys
import
msgspec
import
sglang
as
sgl
import
uvloop
from
sglang.srt.server_args
import
ServerArgs
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.common
import
(
BaseWorkerHandler
,
graceful_shutdown
,
parse_sglang_args_inc
,
setup_native_endpoints
,
)
configure_dynamo_logging
()
class
DecodeRequestHandler
(
BaseWorkerHandler
):
def
__init__
(
self
,
engine
:
sgl
.
Engine
,
server_args
:
ServerArgs
,
component
):
super
().
__init__
(
engine
,
server_args
,
component
)
logging
.
info
(
"Decode request handler initialized"
)
async
def
generate
(
self
,
request
:
str
):
req
=
msgspec
.
json
.
decode
(
request
,
type
=
dict
)
results
=
await
self
.
engine
.
async_generate
(
input_ids
=
req
[
"request"
][
"token_ids"
]
if
req
[
"request"
][
"batch_token_ids"
]
is
None
else
req
[
"request"
][
"batch_token_ids"
],
sampling_params
=
req
[
"sampling_params"
],
stream
=
True
,
bootstrap_host
=
req
[
"bootstrap_host"
],
bootstrap_port
=
req
[
"bootstrap_port"
],
bootstrap_room
=
req
[
"bootstrap_room"
],
)
async
for
result
in
results
:
yield
result
@
dynamo_worker
(
static
=
False
)
async
def
worker
(
runtime
:
DistributedRuntime
):
# Set up signal handler for graceful shutdown
loop
=
asyncio
.
get_running_loop
()
def
signal_handler
():
# Schedule the shutdown coroutine instead of calling it directly
asyncio
.
create_task
(
graceful_shutdown
(
runtime
))
for
sig
in
(
signal
.
SIGTERM
,
signal
.
SIGINT
):
loop
.
add_signal_handler
(
sig
,
signal_handler
)
logging
.
info
(
"Signal handlers set up for graceful shutdown"
)
server_args
=
parse_sglang_args_inc
(
sys
.
argv
[
1
:])
await
init
(
runtime
,
server_args
)
async
def
init
(
runtime
:
DistributedRuntime
,
server_args
:
ServerArgs
):
"""Initialize decode worker"""
engine
=
sgl
.
Engine
(
server_args
=
server_args
)
component
=
runtime
.
namespace
(
"dynamo"
).
component
(
"decode"
)
await
component
.
create_service
()
handler
=
DecodeRequestHandler
(
engine
,
server_args
,
component
)
gen_endpoint
=
component
.
endpoint
(
"generate"
)
tasks
=
[
gen_endpoint
.
serve_endpoint
(
handler
.
generate
)]
tasks
.
extend
(
setup_native_endpoints
(
server_args
,
component
,
handler
))
await
asyncio
.
gather
(
*
tasks
)
def
main
():
uvloop
.
install
()
asyncio
.
run
(
worker
())
if
__name__
==
"__main__"
:
main
()
components/backends/sglang/src/dynamo/sglang/main.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
json
import
logging
import
signal
import
sys
import
sglang
as
sgl
import
uvloop
from
sglang.srt.utils
import
get_ip
from
dynamo.llm
import
ZmqKvEventPublisher
,
ZmqKvEventPublisherConfig
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.args
import
Config
,
DisaggregationMode
,
parse_args
from
dynamo.sglang.publisher
import
setup_sgl_metrics
from
dynamo.sglang.register
import
register_llm_with_runtime_config
from
dynamo.sglang.request_handlers
import
DecodeWorkerHandler
,
PrefillWorkerHandler
configure_dynamo_logging
()
@
dynamo_worker
(
static
=
False
)
async
def
worker
(
runtime
:
DistributedRuntime
):
loop
=
asyncio
.
get_running_loop
()
def
signal_handler
():
asyncio
.
create_task
(
graceful_shutdown
(
runtime
))
for
sig
in
(
signal
.
SIGTERM
,
signal
.
SIGINT
):
loop
.
add_signal_handler
(
sig
,
signal_handler
)
logging
.
info
(
"Signal handlers will trigger a graceful shutdown of the runtime"
)
config
=
parse_args
(
sys
.
argv
[
1
:])
if
config
.
serving_mode
!=
DisaggregationMode
.
PREFILL
:
await
init
(
runtime
,
config
)
else
:
await
init_prefill
(
runtime
,
config
)
async
def
init
(
runtime
:
DistributedRuntime
,
config
:
Config
):
server_args
,
dynamo_args
=
config
.
server_args
,
config
.
dynamo_args
engine
=
sgl
.
Engine
(
server_args
=
server_args
)
component
=
runtime
.
namespace
(
dynamo_args
.
namespace
).
component
(
dynamo_args
.
component
)
await
component
.
create_service
()
generate_endpoint
=
component
.
endpoint
(
dynamo_args
.
endpoint
)
# TODO: think about implementing DisaggregationStrategy for P->D
# TODO: implement a `next` field in the config to dynamically set the next client
prefill_client
=
None
if
config
.
serving_mode
==
DisaggregationMode
.
DECODE
:
logging
.
info
(
"Initializing prefill client"
)
prefill_client
=
(
await
runtime
.
namespace
(
dynamo_args
.
namespace
)
.
component
(
"prefill"
)
.
endpoint
(
"generate"
)
.
client
()
)
publisher
,
metrics_task
=
await
setup_sgl_metrics
(
engine
,
component
)
kv_publisher
=
None
if
server_args
.
kv_events_config
:
kv_events
=
json
.
loads
(
server_args
.
kv_events_config
)
ep
=
kv_events
.
get
(
"endpoint"
)
zmq_ep
=
ep
.
replace
(
"*"
,
get_ip
())
if
ep
else
None
zmq_config
=
ZmqKvEventPublisherConfig
(
worker_id
=
generate_endpoint
.
lease_id
(),
kv_block_size
=
server_args
.
page_size
,
zmq_endpoint
=
zmq_ep
,
)
logging
.
info
(
f
"Setting up ZMQ kv event publisher at
{
zmq_ep
}
"
)
kv_publisher
=
ZmqKvEventPublisher
(
component
=
component
,
config
=
zmq_config
)
handler
=
DecodeWorkerHandler
(
component
,
engine
,
config
,
publisher
,
kv_publisher
,
prefill_client
)
await
register_llm_with_runtime_config
(
engine
,
generate_endpoint
,
server_args
,
dynamo_args
.
migration_limit
)
try
:
# TODO: add in native endpoints
await
asyncio
.
gather
(
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
graceful_shutdown
=
False
),
)
except
Exception
as
e
:
logging
.
error
(
f
"Failed to serve endpoints:
{
e
}
"
)
raise
finally
:
metrics_task
.
cancel
()
try
:
await
metrics_task
except
asyncio
.
CancelledError
:
logging
.
info
(
"Metrics task succesfully cancelled"
)
pass
handler
.
cleanup
()
async
def
init_prefill
(
runtime
:
DistributedRuntime
,
config
:
Config
):
server_args
,
dynamo_args
=
config
.
server_args
,
config
.
dynamo_args
engine
=
sgl
.
Engine
(
server_args
=
server_args
)
component
=
runtime
.
namespace
(
dynamo_args
.
namespace
).
component
(
dynamo_args
.
component
)
await
component
.
create_service
()
generate_endpoint
=
component
.
endpoint
(
dynamo_args
.
endpoint
)
handler
=
PrefillWorkerHandler
(
component
,
engine
,
config
)
tasks
=
[
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
graceful_shutdown
=
True
)]
try
:
await
asyncio
.
gather
(
*
tasks
)
except
Exception
as
e
:
logging
.
error
(
f
"Failed to serve endpoints:
{
e
}
"
)
raise
finally
:
handler
.
cleanup
()
async
def
graceful_shutdown
(
runtime
):
logging
.
info
(
"Received shutdown signal, shutting down DistributedRuntime"
)
runtime
.
shutdown
()
logging
.
info
(
"DistributedRuntime shutdown complete"
)
def
main
():
uvloop
.
run
(
worker
())
if
__name__
==
"__main__"
:
main
()
components/backends/sglang/src/dynamo/sglang/
common/
protocol.py
→
components/backends/sglang/src/dynamo/sglang/protocol.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
...
@@ -58,7 +46,4 @@ class PreprocessedRequest(BaseModel):
...
@@ -58,7 +46,4 @@ class PreprocessedRequest(BaseModel):
class
DisaggPreprocessedRequest
(
BaseModel
):
class
DisaggPreprocessedRequest
(
BaseModel
):
request
:
PreprocessedRequest
request
:
PreprocessedRequest
sampling_params
:
dict
sampling_params
:
dict
bootstrap_host
:
Union
[
str
,
List
[
str
]]
bootstrap_port
:
Union
[
int
,
List
[
int
]]
bootstrap_room
:
Union
[
int
,
List
[
int
]]
data_parallel_rank
:
Optional
[
int
]
=
None
data_parallel_rank
:
Optional
[
int
]
=
None
components/backends/sglang/src/dynamo/sglang/publisher.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
from
typing
import
Optional
import
sglang
as
sgl
import
zmq
import
zmq.asyncio
from
sglang.srt.utils
import
get_zmq_socket
from
dynamo.llm
import
(
ForwardPassMetrics
,
KvStats
,
SpecDecodeStats
,
WorkerMetricsPublisher
,
WorkerStats
,
)
from
dynamo.runtime
import
Component
class
DynamoSglangStatPublisher
:
"""
Handles SGLang metrics reception and publishing.
"""
def
__init__
(
self
,
engine
:
sgl
.
Engine
,
component
:
Component
)
->
None
:
self
.
engine
=
engine
self
.
inner
=
WorkerMetricsPublisher
()
self
.
inner
.
create_endpoint
(
component
)
# Set default values (can be overridden later if needed)
self
.
request_total_slots
=
1024
self
.
dp_rank
=
0
self
.
num_gpu_block
=
1024
# ZMQ setup for receiving scheduler metrics
self
.
_ctx
=
zmq
.
asyncio
.
Context
()
# type: ignore
self
.
_sock
=
get_zmq_socket
(
self
.
_ctx
,
zmq
.
PULL
,
self
.
engine
.
port_args
.
metrics_ipc_name
,
True
# type: ignore
)
async
def
run
(
self
)
->
None
:
"""Main loop to receive scheduler metrics and publish them"""
while
True
:
try
:
kv_metrics
=
await
self
.
_sock
.
recv_pyobj
()
# type: ignore
self
.
record_values
(
request_active_slots
=
kv_metrics
.
request_active_slots
,
request_total_slots
=
kv_metrics
.
request_total_slots
,
kv_active_blocks
=
kv_metrics
.
kv_active_blocks
,
kv_total_blocks
=
kv_metrics
.
kv_total_blocks
,
num_requests_waiting
=
kv_metrics
.
num_requests_waiting
,
gpu_cache_usage_perc
=
kv_metrics
.
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
=
kv_metrics
.
gpu_prefix_cache_hit_rate
,
data_parallel_rank
=
kv_metrics
.
data_parallel_rank
,
)
except
Exception
:
logging
.
exception
(
"Failed to receive or publish SGLang scheduler metrics"
)
def
init_publish
(
self
)
->
None
:
worker_stats
=
WorkerStats
(
request_active_slots
=
0
,
request_total_slots
=
self
.
request_total_slots
,
num_requests_waiting
=
0
,
data_parallel_rank
=
self
.
dp_rank
,
)
kv_stats
=
KvStats
(
kv_active_blocks
=
0
,
kv_total_blocks
=
self
.
num_gpu_block
,
gpu_cache_usage_perc
=
0.0
,
gpu_prefix_cache_hit_rate
=
0.0
,
)
metrics
=
ForwardPassMetrics
(
worker_stats
=
worker_stats
,
kv_stats
=
kv_stats
,
spec_decode_stats
=
None
,
)
logging
.
info
(
"Sending dummy metrics to initialize"
)
self
.
inner
.
publish
(
metrics
)
def
record
(
self
,
worker_stats
:
WorkerStats
,
kv_stats
:
KvStats
,
spec_decode_stats
:
Optional
[
SpecDecodeStats
]
=
None
,
)
->
None
:
metrics
=
ForwardPassMetrics
(
worker_stats
=
worker_stats
,
kv_stats
=
kv_stats
,
spec_decode_stats
=
spec_decode_stats
,
)
self
.
inner
.
publish
(
metrics
)
def
record_values
(
self
,
request_active_slots
:
int
,
request_total_slots
:
int
,
kv_active_blocks
:
int
,
kv_total_blocks
:
int
,
num_requests_waiting
:
int
,
gpu_cache_usage_perc
:
float
,
gpu_prefix_cache_hit_rate
:
float
,
data_parallel_rank
:
Optional
[
int
]
=
None
,
spec_decode_stats
:
Optional
[
SpecDecodeStats
]
=
None
,
)
->
None
:
worker_stats
=
WorkerStats
(
request_active_slots
=
request_active_slots
,
request_total_slots
=
request_total_slots
,
num_requests_waiting
=
num_requests_waiting
,
data_parallel_rank
=
data_parallel_rank
if
data_parallel_rank
is
not
None
else
self
.
dp_rank
,
)
kv_stats
=
KvStats
(
kv_active_blocks
=
kv_active_blocks
,
kv_total_blocks
=
kv_total_blocks
,
gpu_cache_usage_perc
=
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
=
gpu_prefix_cache_hit_rate
,
)
self
.
record
(
worker_stats
,
kv_stats
,
spec_decode_stats
)
async
def
setup_sgl_metrics
(
engine
:
sgl
.
Engine
,
component
:
Component
,
)
->
tuple
[
DynamoSglangStatPublisher
,
asyncio
.
Task
]:
"""
Convenience bootstrap: create endpoint, publish an initial update, and start the metrics loop.
"""
publisher
=
DynamoSglangStatPublisher
(
engine
,
component
)
publisher
.
init_publish
()
task
=
asyncio
.
create_task
(
publisher
.
run
())
logging
.
info
(
"SGLang metrics loop started"
)
return
publisher
,
task
components/backends/sglang/src/dynamo/sglang/register.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
from
typing
import
Optional
import
sglang
as
sgl
from
sglang.srt.server_args
import
ServerArgs
from
dynamo._core
import
Endpoint
from
dynamo.llm
import
ModelRuntimeConfig
,
ModelType
,
register_llm
async
def
register_llm_with_runtime_config
(
engine
:
sgl
.
Engine
,
endpoint
:
Endpoint
,
server_args
:
ServerArgs
,
migration_limit
:
int
,
):
"""Register LLM with runtime config"""
runtime_config
=
await
_get_runtime_config
(
engine
)
try
:
await
register_llm
(
ModelType
.
Backend
,
endpoint
,
server_args
.
model_path
,
server_args
.
served_model_name
,
kv_cache_block_size
=
server_args
.
page_size
,
migration_limit
=
migration_limit
,
runtime_config
=
runtime_config
,
)
except
Exception
as
e
:
logging
.
error
(
f
"Failed to register with runtime config:
{
e
}
"
)
return
None
async
def
_get_runtime_config
(
engine
:
sgl
.
Engine
)
->
Optional
[
ModelRuntimeConfig
]:
"""Get runtime config from SGLang engine"""
try
:
# Try to check if the engine has a scheduler attribute with the computed values
if
hasattr
(
engine
,
"scheduler_info"
)
and
engine
.
scheduler_info
is
not
None
:
runtime_config
=
ModelRuntimeConfig
()
# Get max_total_num_tokens from scheduler_info
if
"max_total_num_tokens"
in
engine
.
scheduler_info
:
max_total_tokens
=
engine
.
scheduler_info
[
"max_total_num_tokens"
]
if
max_total_tokens
and
hasattr
(
engine
.
tokenizer_manager
,
"server_args"
):
page_size
=
engine
.
tokenizer_manager
.
server_args
.
page_size
if
page_size
:
runtime_config
.
total_kv_blocks
=
(
max_total_tokens
+
page_size
-
1
)
//
page_size
logging
.
info
(
f
"Got total KV blocks from scheduler:
{
runtime_config
.
total_kv_blocks
}
"
f
"(max_total_tokens=
{
max_total_tokens
}
, page_size=
{
page_size
}
)"
)
# Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
return
runtime_config
# If scheduler approach doesn't work, log and return None to indicate we'll skip runtime config
logging
.
warning
(
"Could not access runtime config from SGLang engine. "
"The engine may compute these values internally after initialization. "
"Proceeding without runtime config - SGLang will use its internal defaults."
)
return
None
except
Exception
as
e
:
logging
.
warning
(
f
"Failed to get runtime config:
{
e
}
. Proceeding without it."
)
return
None
components/backends/sglang/src/dynamo/sglang/request_handlers/__init__.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
.decode_handler
import
DecodeWorkerHandler
# Base handlers
from
.handler_base
import
BaseWorkerHandler
from
.prefill_handler
import
PrefillWorkerHandler
__all__
=
[
"BaseWorkerHandler"
,
"DecodeWorkerHandler"
,
"PrefillWorkerHandler"
,
]
components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
import
sglang
as
sgl
from
dynamo._core
import
Client
,
Component
from
dynamo.llm
import
WorkerMetricsPublisher
,
ZmqKvEventPublisher
from
dynamo.sglang.args
import
Config
,
DisaggregationMode
from
dynamo.sglang.protocol
import
DisaggPreprocessedRequest
from
dynamo.sglang.request_handlers.handler_base
import
BaseWorkerHandler
class
DecodeWorkerHandler
(
BaseWorkerHandler
):
def
__init__
(
self
,
component
:
Component
,
engine
:
sgl
.
Engine
,
config
:
Config
,
metrics_publisher
:
WorkerMetricsPublisher
,
kv_publisher
:
ZmqKvEventPublisher
=
None
,
prefill_client
:
Client
=
None
,
):
super
().
__init__
(
component
,
engine
,
config
,
metrics_publisher
,
kv_publisher
,
prefill_client
)
if
self
.
serving_mode
==
DisaggregationMode
.
DECODE
:
if
self
.
prefill_client
is
None
:
raise
ValueError
(
"prefill_client must be provided when serving_mode is decode"
)
self
.
prefill_client
=
prefill_client
logging
.
info
(
"Decode worker handler initialized"
)
logging
.
info
(
"Worker handler initialized"
)
def
cleanup
(
self
):
self
.
engine
.
shutdown
()
logging
.
info
(
"Engine shutdown"
)
super
().
cleanup
()
def
_build_sampling_params
(
self
,
request
:
dict
)
->
dict
:
sampling_params
=
{}
if
request
[
"sampling_options"
][
"temperature"
]:
sampling_params
[
"temperature"
]
=
request
[
"sampling_options"
][
"temperature"
]
if
request
[
"sampling_options"
][
"top_p"
]:
sampling_params
[
"top_p"
]
=
request
[
"sampling_options"
][
"top_p"
]
if
request
[
"sampling_options"
][
"top_k"
]:
sampling_params
[
"top_k"
]
=
request
[
"sampling_options"
][
"top_k"
]
sampling_params
[
"max_new_tokens"
]
=
request
[
"stop_conditions"
][
"max_tokens"
]
if
request
[
"stop_conditions"
][
"ignore_eos"
]:
sampling_params
[
"ignore_eos"
]
=
request
[
"stop_conditions"
][
"ignore_eos"
]
return
sampling_params
async
def
generate
(
self
,
request
:
str
):
sampling_params
=
self
.
_build_sampling_params
(
request
)
if
self
.
serving_mode
==
DisaggregationMode
.
DECODE
:
# request the bootstrap info from the target prefill worker
prefill_stream
=
await
self
.
prefill_client
.
generate
(
DisaggPreprocessedRequest
(
request
=
request
,
sampling_params
=
sampling_params
,
).
model_dump_json
()
)
bootstrap_info
=
None
async
for
info
in
prefill_stream
:
bootstrap_info
=
info
.
data
()
break
if
not
bootstrap_info
:
raise
RuntimeError
(
"No bootstrap info received from prefill worker"
)
decode
=
await
self
.
engine
.
async_generate
(
input_ids
=
request
[
"token_ids"
],
sampling_params
=
sampling_params
,
stream
=
True
,
bootstrap_host
=
bootstrap_info
[
"bootstrap_host"
],
bootstrap_port
=
bootstrap_info
[
"bootstrap_port"
],
bootstrap_room
=
bootstrap_info
[
"bootstrap_room"
],
)
async
for
out
in
self
.
_process_stream
(
decode
):
yield
out
else
:
agg
=
await
self
.
engine
.
async_generate
(
input_ids
=
request
[
"token_ids"
],
sampling_params
=
sampling_params
,
stream
=
True
,
)
async
for
out
in
self
.
_process_stream
(
agg
):
yield
out
async
def
_process_stream
(
self
,
stream_source
):
num_output_tokens_so_far
=
0
async
for
res
in
stream_source
:
finish_reason
=
res
[
"meta_info"
][
"finish_reason"
]
if
finish_reason
:
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
]}
else
:
next_total_toks
=
len
(
res
[
"output_ids"
])
out
=
{
"token_ids"
:
res
[
"output_ids"
][
num_output_tokens_so_far
:]}
num_output_tokens_so_far
=
next_total_toks
yield
out
components/backends/sglang/src/dynamo/sglang/request_handlers/handler_base.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
import
sglang
as
sgl
from
dynamo._core
import
Client
,
Component
from
dynamo.llm
import
WorkerMetricsPublisher
,
ZmqKvEventPublisher
from
dynamo.sglang.args
import
Config
class
BaseWorkerHandler
(
ABC
):
def
__init__
(
self
,
component
:
Component
,
engine
:
sgl
.
Engine
,
config
:
Config
,
metrics_publisher
:
WorkerMetricsPublisher
=
None
,
kv_publisher
:
ZmqKvEventPublisher
=
None
,
prefill_client
:
Client
=
None
,
):
self
.
component
=
component
self
.
engine
=
engine
self
.
config
=
config
self
.
metrics_publisher
=
metrics_publisher
self
.
kv_publisher
=
kv_publisher
self
.
prefill_client
=
prefill_client
self
.
serving_mode
=
config
.
serving_mode
@
abstractmethod
async
def
generate
(
self
,
request
:
str
):
pass
def
cleanup
(
self
):
pass
components/backends/sglang/src/dynamo/sglang/request_handlers/prefill_handler.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
import
random
import
socket
import
msgspec
import
sglang
as
sgl
from
sglang.srt.utils
import
get_ip
from
dynamo._core
import
Component
from
dynamo.sglang.args
import
Config
from
dynamo.sglang.request_handlers.handler_base
import
BaseWorkerHandler
class
PrefillWorkerHandler
(
BaseWorkerHandler
):
def
__init__
(
self
,
component
:
Component
,
engine
:
sgl
.
Engine
,
config
:
Config
):
self
.
engine
=
engine
self
.
bootstrap_host
,
self
.
bootstrap_port
=
self
.
_get_bootstrap_info
()
super
().
__init__
(
component
,
engine
,
config
,
None
,
None
,
None
)
logging
.
info
(
f
"Prefill worker handler initialized - bootstrap host:
{
self
.
bootstrap_host
}
, bootstrap port:
{
self
.
bootstrap_port
}
"
)
def
_generate_bootstrap_room
(
self
):
return
random
.
randint
(
0
,
2
**
63
-
1
)
def
cleanup
(
self
):
self
.
engine
.
shutdown
()
logging
.
info
(
"Prefill engine shutdown"
)
super
().
cleanup
()
def
_get_bootstrap_info
(
self
):
"""Bootstrap info from tokenizer manager"""
inner_tm
=
self
.
engine
.
tokenizer_manager
bootstrap_port
=
inner_tm
.
server_args
.
disaggregation_bootstrap_port
if
inner_tm
.
server_args
.
dist_init_addr
:
bootstrap_host
=
socket
.
gethostbyname
(
inner_tm
.
server_args
.
dist_init_addr
.
split
(
":"
)[
0
]
)
else
:
bootstrap_host
=
get_ip
()
return
bootstrap_host
,
bootstrap_port
async
def
generate
(
self
,
request
:
str
):
req
=
msgspec
.
json
.
decode
(
request
,
type
=
dict
)
bootstrap_room
=
self
.
_generate_bootstrap_room
()
bootstrap_info
=
{
"bootstrap_host"
:
self
.
bootstrap_host
,
"bootstrap_port"
:
self
.
bootstrap_port
,
"bootstrap_room"
:
bootstrap_room
,
}
yield
bootstrap_info
results
=
await
self
.
engine
.
async_generate
(
input_ids
=
req
[
"request"
][
"token_ids"
],
sampling_params
=
req
[
"sampling_params"
],
stream
=
True
,
bootstrap_host
=
self
.
bootstrap_host
,
bootstrap_port
=
self
.
bootstrap_port
,
bootstrap_room
=
bootstrap_room
,
)
asyncio
.
create_task
(
self
.
_consume_results
(
results
))
async
def
_consume_results
(
self
,
results
):
async
for
_
in
results
:
pass
components/backends/sglang/src/dynamo/sglang/utils/clear_namespace.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2020 Atalaya Tech. Inc
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import
argparse
import
argparse
import
asyncio
import
asyncio
...
@@ -23,7 +9,6 @@ from dynamo.runtime import DistributedRuntime, EtcdKvCache, dynamo_worker
...
@@ -23,7 +9,6 @@ from dynamo.runtime import DistributedRuntime, EtcdKvCache, dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
configure_dynamo_logging
()
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
@
dynamo_worker
()
@
dynamo_worker
()
...
@@ -34,7 +19,7 @@ async def clear_namespace(runtime: DistributedRuntime, namespace: str):
...
@@ -34,7 +19,7 @@ async def clear_namespace(runtime: DistributedRuntime, namespace: str):
{},
{},
)
)
await
etcd_kv_cache
.
clear_all
()
await
etcd_kv_cache
.
clear_all
()
logg
er
.
info
(
f
"Cleared /
{
namespace
}
in EtcdKvCache"
)
logg
ing
.
info
(
f
"Cleared /
{
namespace
}
in EtcdKvCache"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
components/backends/sglang/src/dynamo/sglang/utils/sgl_http_server.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
asyncio
import
logging
import
uvicorn
import
uvloop
from
fastapi
import
FastAPI
from
fastapi.routing
import
APIRoute
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
FLUSH_CACHE_ENDPOINT
=
"flush_cache"
configure_dynamo_logging
()
class
SglangHttpServer
:
def
__init__
(
self
,
port
:
int
,
runtime
:
DistributedRuntime
,
args
):
self
.
port
=
port
self
.
app
=
FastAPI
()
self
.
runtime
=
runtime
self
.
args
=
args
self
.
setup_routes
()
async
def
_discover_endpoints
(
self
,
endpoint_name
):
"""Discover endpoints that match the pattern"""
etcd_client
=
self
.
runtime
.
etcd_client
()
if
etcd_client
is
None
:
raise
RuntimeError
(
"Runtime has no etcd client; cannot discover endpoints"
)
prefix
=
"instances/"
kvs
=
await
etcd_client
.
kv_get_prefix
(
prefix
)
# Collect (namespace, component) combos that expose the target endpoint
discovered
=
set
()
for
kv
in
kvs
:
key
=
kv
[
"key"
]
if
isinstance
(
kv
,
dict
)
else
kv
.
key
if
isinstance
(
key
,
bytes
):
key
=
key
.
decode
()
if
not
key
.
startswith
(
prefix
):
continue
segments
=
key
.
split
(
"/"
)
# Format: instances/<ns>/<comp>/<endpoint:lease>
if
len
(
segments
)
<
4
:
continue
ns
,
comp
,
ep_with_lease
=
segments
[
1
],
segments
[
2
],
segments
[
3
]
if
self
.
args
.
ns
and
ns
!=
self
.
args
.
ns
:
continue
if
self
.
args
.
comp
and
comp
!=
self
.
args
.
comp
:
continue
ep_name
=
ep_with_lease
.
split
(
":"
,
1
)[
0
]
if
ep_name
==
endpoint_name
:
discovered
.
add
((
ns
,
comp
))
logging
.
debug
(
f
"Discovered endpoint:
{
ns
}
.
{
comp
}
"
)
logging
.
debug
(
f
"Endpoint discovery complete. Found
{
len
(
discovered
)
}
matching endpoints"
)
return
discovered
async
def
_dispatch_command
(
self
,
endpoint_name
:
str
,
payload
:
dict
|
str
=
"{}"
,
success_message
:
str
=
""
):
"""Dispatches a command to all instances of a discovered endpoint."""
discovered
=
await
self
.
_discover_endpoints
(
endpoint_name
=
endpoint_name
)
if
not
discovered
:
return
{
"message"
:
"No matching endpoints found"
,
"success"
:
False
}
logging
.
debug
(
f
"Found components:
{
', '
.
join
([
f
'
{
ns
}
.
{
comp
}
' for ns, comp in discovered])
}
"
)
for
ns
,
comp
in
discovered
:
ep
=
self
.
runtime
.
namespace
(
ns
).
component
(
comp
).
endpoint
(
endpoint_name
)
client
=
await
ep
.
client
()
await
client
.
wait_for_instances
()
ids
=
client
.
instance_ids
()
logging
.
debug
(
f
"--
{
ns
}
.
{
comp
}
:
{
len
(
ids
)
}
instances --"
)
for
inst_id
in
ids
:
try
:
stream
=
await
client
.
direct
(
payload
,
inst_id
)
async
for
stream_payload
in
stream
:
logging
.
debug
(
f
"[
{
ns
}
.
{
comp
}
][
{
inst_id
}
] ->
{
stream_payload
}
"
)
except
Exception
as
e
:
logging
.
error
(
f
"[
{
ns
}
.
{
comp
}
][
{
inst_id
}
]
{
endpoint_name
}
error:
{
e
}
"
)
return
{
"message"
:
success_message
,
"success"
:
True
}
def
setup_routes
(
self
):
@
self
.
app
.
post
(
"/flush_cache"
)
async
def
flush_cache
():
"""Flush the radix cache."""
endpoint_name
=
self
.
args
.
endpoint
try
:
return
await
self
.
_dispatch_command
(
endpoint_name
,
success_message
=
"Cache flush initiated"
,
)
except
Exception
as
e
:
logging
.
error
(
f
"Cache flush error:
{
e
}
"
)
return
{
"message"
:
f
"Cache flush failed:
{
str
(
e
)
}
"
,
"success"
:
False
}
@
self
.
app
.
post
(
"/start_expert_distribution_record"
)
async
def
start_expert_distribution_record
():
"""Start recording expert distribution."""
endpoint_name
=
"start_expert_distribution_record"
try
:
return
await
self
.
_dispatch_command
(
endpoint_name
,
success_message
=
"Expert distribution recording started"
,
)
except
Exception
as
e
:
logging
.
error
(
f
"Start expert distribution error:
{
e
}
"
)
return
{
"message"
:
f
"Start expert distribution failed:
{
str
(
e
)
}
"
,
"success"
:
False
,
}
@
self
.
app
.
post
(
"/stop_expert_distribution_record"
)
async
def
stop_expert_distribution_record
():
"""Stop recording expert distribution."""
endpoint_name
=
"stop_expert_distribution_record"
try
:
return
await
self
.
_dispatch_command
(
endpoint_name
,
success_message
=
"Expert distribution recording stopped"
,
)
except
Exception
as
e
:
logging
.
error
(
f
"Stop expert distribution error:
{
e
}
"
)
return
{
"message"
:
f
"Stop expert distribution failed:
{
str
(
e
)
}
"
,
"success"
:
False
,
}
@
self
.
app
.
post
(
"/dump_expert_distribution_record"
)
async
def
dump_expert_distribution_record
(
request
:
dict
):
"""Dump expert distribution recording to specified directory."""
endpoint_name
=
"dump_expert_distribution_record"
try
:
return
await
self
.
_dispatch_command
(
endpoint_name
,
success_message
=
"Expert distribution recording dumped to directory"
,
)
except
Exception
as
e
:
logging
.
error
(
f
"Dump expert distribution error:
{
e
}
"
)
return
{
"message"
:
f
"Dump expert distribution failed:
{
str
(
e
)
}
"
,
"success"
:
False
,
}
async
def
start_server
(
self
):
"""Start the HTTP server"""
config
=
uvicorn
.
Config
(
self
.
app
,
host
=
"0.0.0.0"
,
port
=
self
.
port
,
)
server
=
uvicorn
.
Server
(
config
)
# Debug: print all registered routes
for
route
in
self
.
app
.
routes
:
if
isinstance
(
route
,
APIRoute
):
logging
.
debug
(
f
"Registered route:
{
route
.
methods
}
{
route
.
path
}
"
)
await
server
.
serve
()
def
parse_args
():
p
=
argparse
.
ArgumentParser
(
description
=
"SGLang HTTP server for cache management"
)
p
.
add_argument
(
"--port"
,
type
=
int
,
default
=
9001
,
help
=
"Port to listen on"
)
p
.
add_argument
(
"--ns"
,
"--namespace"
,
default
=
"dynamo"
,
help
=
"Specify Dynamo namespace (default: discover all)"
,
)
p
.
add_argument
(
"--comp"
,
"--component"
,
default
=
None
,
help
=
"Specify component name (default: discover all)"
,
)
return
p
.
parse_args
()
@
dynamo_worker
(
static
=
False
)
async
def
main
(
runtime
:
DistributedRuntime
):
args
=
parse_args
()
http_server
=
SglangHttpServer
(
args
.
port
,
runtime
,
args
)
await
http_server
.
start_server
()
if
__name__
==
"__main__"
:
uvloop
.
install
()
asyncio
.
run
(
main
())
components/backends/sglang/src/dynamo/sglang/worker/__init__.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# This module is deprecated. Use `python3 -m dynamo.sglang` instead.
components/backends/sglang/src/dynamo/sglang/worker/__main__.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
dynamo.sglang.worker.main
import
main
import
logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.main
import
main
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
configure_dynamo_logging
()
logging
.
warning
(
"DEPRECATION WARNING: `python3 -m dynamo.sglang.worker` is deprecated and will be removed in dynamo v0.5.0."
"Use `python3 -m dynamo.sglang` instead."
,
)
main
()
main
()
components/backends/sglang/src/dynamo/sglang/worker/main.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
json
import
logging
import
random
import
signal
import
socket
import
sys
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
sglang
as
sgl
import
uvloop
import
zmq
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_ip
,
get_zmq_socket
from
dynamo._core
import
Endpoint
from
dynamo.llm
import
(
ForwardPassMetrics
,
KvStats
,
ModelRuntimeConfig
,
ModelType
,
WorkerMetricsPublisher
,
WorkerStats
,
ZmqKvEventPublisher
,
ZmqKvEventPublisherConfig
,
register_llm
,
)
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.common
import
(
BaseWorkerHandler
,
DisaggPreprocessedRequest
,
graceful_shutdown
,
parse_sglang_args_inc
,
setup_native_endpoints
,
)
configure_dynamo_logging
()
class
RequestHandler
(
BaseWorkerHandler
):
def
__init__
(
self
,
engine
:
sgl
.
Engine
,
server_args
:
ServerArgs
,
component
,
decode_client
:
Optional
[
Any
]
=
None
,
):
super
().
__init__
(
engine
,
server_args
,
component
,
decode_client
)
self
.
metrics_publisher
=
WorkerMetricsPublisher
()
self
.
zmq_context
=
zmq
.
asyncio
.
Context
()
# type: ignore
self
.
receive_metrics_from_scheduler
=
None
if
server_args
.
disaggregation_mode
!=
"null"
:
self
.
bootstrap_host
,
self
.
bootstrap_port
=
self
.
_get_bootstrap_info
()
if
decode_client
is
None
:
raise
ValueError
(
"decode_client must be provided when disaggregation_mode is not 'null'"
)
self
.
decode_client
=
decode_client
logging
.
info
(
f
"Disaggregation enabled - bootstrap host:
{
self
.
bootstrap_host
}
, bootstrap port:
{
self
.
bootstrap_port
}
"
)
logging
.
info
(
"Request handler initialized"
)
def
setup_metrics
(
self
):
"""Set up metrics publisher"""
self
.
receive_metrics_from_scheduler
=
get_zmq_socket
(
self
.
zmq_context
,
zmq
.
PULL
,
self
.
engine
.
port_args
.
metrics_ipc_name
,
True
)
self
.
init_publish
()
asyncio
.
create_task
(
self
.
_receive_and_publish_metrics_loop
())
task
=
asyncio
.
create_task
(
self
.
create_metrics_publisher_endpoint
())
task
.
add_done_callback
(
lambda
_
:
logging
.
debug
(
"metrics publisher endpoint created"
)
)
def
init_publish
(
self
):
"""Publish initial set of warmup metrics"""
worker_stats
=
WorkerStats
(
request_active_slots
=
0
,
request_total_slots
=
1024
,
num_requests_waiting
=
0
,
data_parallel_rank
=
0
,
)
kv_stats
=
KvStats
(
kv_active_blocks
=
0
,
kv_total_blocks
=
1024
,
gpu_cache_usage_perc
=
0
,
gpu_prefix_cache_hit_rate
=
0
,
)
metrics
=
ForwardPassMetrics
(
worker_stats
=
worker_stats
,
kv_stats
=
kv_stats
,
spec_decode_stats
=
None
,
)
self
.
metrics_publisher
.
publish
(
metrics
)
async
def
create_metrics_publisher_endpoint
(
self
):
logging
.
debug
(
"Creating metrics publisher endpoint"
)
await
self
.
metrics_publisher
.
create_endpoint
(
self
.
component
)
async
def
_receive_and_publish_metrics_loop
(
self
):
"""Receive metrics from SGL scheduler and publish them"""
while
True
:
try
:
kv_metrics
=
await
self
.
receive_metrics_from_scheduler
.
recv_pyobj
()
# type: ignore
worker_stats
=
WorkerStats
(
request_active_slots
=
kv_metrics
.
request_active_slots
,
request_total_slots
=
kv_metrics
.
request_total_slots
,
num_requests_waiting
=
kv_metrics
.
num_requests_waiting
,
data_parallel_rank
=
kv_metrics
.
data_parallel_rank
,
# Note: 0 means it's either 0 or None from sglang
)
kv_stats
=
KvStats
(
kv_active_blocks
=
kv_metrics
.
kv_active_blocks
,
kv_total_blocks
=
kv_metrics
.
kv_total_blocks
,
gpu_cache_usage_perc
=
kv_metrics
.
gpu_cache_usage_perc
,
gpu_prefix_cache_hit_rate
=
kv_metrics
.
gpu_prefix_cache_hit_rate
,
)
spec_dec_stats
=
None
metrics
=
ForwardPassMetrics
(
worker_stats
=
worker_stats
,
kv_stats
=
kv_stats
,
spec_decode_stats
=
spec_dec_stats
,
)
self
.
metrics_publisher
.
publish
(
metrics
)
except
Exception
:
logging
.
exception
(
"Failed to recieve or publish metrics"
)
def
_get_bootstrap_info
(
self
):
"""Bootstrap info from tokenizer manager"""
inner_tm
=
self
.
engine
.
tokenizer_manager
bootstrap_port
=
inner_tm
.
server_args
.
disaggregation_bootstrap_port
if
inner_tm
.
server_args
.
dist_init_addr
:
bootstrap_host
=
socket
.
gethostbyname
(
inner_tm
.
server_args
.
dist_init_addr
.
split
(
":"
)[
0
]
)
else
:
bootstrap_host
=
get_ip
()
return
bootstrap_host
,
bootstrap_port
def
_build_sampling_params
(
self
,
request
:
dict
)
->
dict
:
sampling_params
=
{}
if
request
[
"sampling_options"
][
"temperature"
]:
sampling_params
[
"temperature"
]
=
request
[
"sampling_options"
][
"temperature"
]
if
request
[
"sampling_options"
][
"top_p"
]:
sampling_params
[
"top_p"
]
=
request
[
"sampling_options"
][
"top_p"
]
if
request
[
"sampling_options"
][
"top_k"
]:
sampling_params
[
"top_k"
]
=
request
[
"sampling_options"
][
"top_k"
]
sampling_params
[
"max_new_tokens"
]
=
request
[
"stop_conditions"
][
"max_tokens"
]
if
request
[
"stop_conditions"
][
"ignore_eos"
]:
sampling_params
[
"ignore_eos"
]
=
request
[
"stop_conditions"
][
"ignore_eos"
]
return
sampling_params
def
_get_request_batch_size
(
self
,
request
:
dict
):
"""Get batch size from request, returns None for single requests"""
if
request
[
"batch_token_ids"
]
is
not
None
:
return
len
(
request
[
"batch_token_ids"
])
return
None
def
_is_batch_request
(
self
,
request
:
dict
):
"""Check if request is in batch mode"""
return
request
[
"batch_token_ids"
]
is
not
None
def
_generate_bootstrap_room
(
self
):
return
random
.
randint
(
0
,
2
**
63
-
1
)
async
def
generate
(
self
,
request
:
dict
):
is_batch
=
self
.
_is_batch_request
(
request
)
batch_size
=
self
.
_get_request_batch_size
(
request
)
# TODO: maintain a mapping from SGLang's Ouput struct to LLMEngineOuput
sampling_params
=
self
.
_build_sampling_params
(
request
)
if
self
.
server_args
.
disaggregation_mode
!=
"null"
:
if
is_batch
:
bootstrap_room
=
[
self
.
_generate_bootstrap_room
()
for
_
in
range
(
batch_size
)
]
bootstrap_host
=
[
self
.
bootstrap_host
]
*
batch_size
bootstrap_port
=
[
self
.
bootstrap_port
]
*
batch_size
else
:
bootstrap_host
=
self
.
bootstrap_host
bootstrap_port
=
self
.
bootstrap_port
bootstrap_room
=
self
.
_generate_bootstrap_room
()
# decode worker request
disagg_request
=
DisaggPreprocessedRequest
(
request
=
request
,
sampling_params
=
sampling_params
,
bootstrap_host
=
bootstrap_host
,
bootstrap_port
=
bootstrap_port
,
bootstrap_room
=
bootstrap_room
,
)
# prefill response is not used
prefill
=
await
self
.
engine
.
async_generate
(
input_ids
=
request
[
"token_ids"
]
if
not
is_batch
else
request
[
"batch_token_ids"
],
sampling_params
=
sampling_params
,
stream
=
True
,
bootstrap_host
=
bootstrap_host
,
bootstrap_port
=
bootstrap_port
,
bootstrap_room
=
bootstrap_room
,
)
prefill_task
=
asyncio
.
create_task
(
self
.
_prefill_generator
(
prefill
))
decode
=
await
self
.
decode_client
.
generate
(
disagg_request
.
model_dump_json
())
async
for
out
in
self
.
_process_stream
(
decode
,
unpack
=
True
,
is_batch
=
is_batch
):
yield
out
await
prefill_task
else
:
g
=
await
self
.
engine
.
async_generate
(
input_ids
=
request
[
"token_ids"
]
if
not
is_batch
else
request
[
"batch_token_ids"
],
sampling_params
=
sampling_params
,
stream
=
True
,
)
async
for
out
in
self
.
_process_stream
(
g
,
unpack
=
False
,
is_batch
=
is_batch
):
yield
out
async
def
_process_stream
(
self
,
stream_source
,
unpack
:
bool
,
is_batch
:
bool
):
# Initialize based on batch mode
num_output_tokens_so_far
:
Union
[
Dict
[
int
,
int
],
int
]
if
is_batch
:
num_output_tokens_so_far
=
{}
else
:
num_output_tokens_so_far
=
0
async
for
res
in
stream_source
:
data
=
res
.
data
()
if
unpack
else
res
finish_reason
=
data
[
"meta_info"
][
"finish_reason"
]
if
is_batch
:
# Handle batch response
assert
isinstance
(
num_output_tokens_so_far
,
dict
)
index
=
data
.
get
(
"index"
,
0
)
if
index
not
in
num_output_tokens_so_far
:
num_output_tokens_so_far
[
index
]
=
0
if
finish_reason
:
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
],
"index"
:
index
,
}
else
:
next_total_toks
=
len
(
data
[
"output_ids"
])
new_tokens
=
data
[
"output_ids"
][
num_output_tokens_so_far
[
index
]
:]
out
=
{
"token_ids"
:
new_tokens
,
"index"
:
index
,
}
num_output_tokens_so_far
[
index
]
=
next_total_toks
else
:
# Handle single response
assert
isinstance
(
num_output_tokens_so_far
,
int
)
if
finish_reason
:
out
=
{
"token_ids"
:
[],
"finish_reason"
:
finish_reason
[
"type"
]}
else
:
next_total_toks
=
len
(
data
[
"output_ids"
])
out
=
{
"token_ids"
:
data
[
"output_ids"
][
num_output_tokens_so_far
:]}
num_output_tokens_so_far
=
next_total_toks
yield
out
async
def
_prefill_generator
(
self
,
prefill
):
async
for
_
in
prefill
:
pass
async
def
flush_cache
(
self
,
request
:
dict
):
_
=
request
asyncio
.
create_task
(
self
.
engine
.
tokenizer_manager
.
flush_cache
())
yield
{
"status"
:
"success"
,
"message"
:
"Cache flush initiated. Check backend logs for status"
,
}
@
dynamo_worker
(
static
=
False
)
async
def
worker
(
runtime
:
DistributedRuntime
):
# Set up signal handler for graceful shutdown
loop
=
asyncio
.
get_running_loop
()
def
signal_handler
():
# Schedule the shutdown coroutine instead of calling it directly
asyncio
.
create_task
(
graceful_shutdown
(
runtime
))
for
sig
in
(
signal
.
SIGTERM
,
signal
.
SIGINT
):
loop
.
add_signal_handler
(
sig
,
signal_handler
)
logging
.
info
(
"Signal handlers set up for graceful shutdown"
)
# TODO: Better handle non-sglang args
sys_argv
=
sys
.
argv
[
1
:]
migration_limit
=
0
try
:
idx
=
sys_argv
.
index
(
"--migration-limit"
)
migration_limit
=
int
(
sys_argv
[
idx
+
1
])
del
sys_argv
[
idx
:
idx
+
2
]
# Remove the args from sys_argv
except
Exception
:
pass
server_args
=
parse_sglang_args_inc
(
sys_argv
)
await
init
(
runtime
,
server_args
,
migration_limit
)
async
def
init
(
runtime
:
DistributedRuntime
,
server_args
:
ServerArgs
,
migration_limit
:
int
):
"""Initialize worker (either prefill or aggregated)"""
engine
=
sgl
.
Engine
(
server_args
=
server_args
)
component
=
runtime
.
namespace
(
"dynamo"
).
component
(
"worker"
)
await
component
.
create_service
()
endpoint
=
component
.
endpoint
(
"generate"
)
await
register_llm_with_runtime_config
(
engine
,
endpoint
,
server_args
,
migration_limit
)
if
server_args
.
disaggregation_mode
!=
"null"
:
decode_client
=
(
await
runtime
.
namespace
(
"dynamo"
)
.
component
(
"decode"
)
.
endpoint
(
"generate"
)
.
client
()
)
handler
=
RequestHandler
(
engine
,
server_args
,
component
,
decode_client
)
else
:
handler
=
RequestHandler
(
engine
,
server_args
,
component
)
# Set up the engine metrics reciever
handler
.
setup_metrics
()
# Set up ZMQ kv event publisher
if
server_args
.
kv_events_config
:
kv_events
=
json
.
loads
(
server_args
.
kv_events_config
)
ep
=
kv_events
.
get
(
"endpoint"
)
zmq_ep
=
ep
.
replace
(
"*"
,
get_ip
())
if
ep
else
None
zmq_config
=
ZmqKvEventPublisherConfig
(
worker_id
=
endpoint
.
lease_id
(),
kv_block_size
=
server_args
.
page_size
,
zmq_endpoint
=
zmq_ep
,
)
logging
.
info
(
f
"Setting up ZMQ kv event publisher at
{
zmq_ep
}
"
)
_
=
ZmqKvEventPublisher
(
component
=
component
,
config
=
zmq_config
)
tasks
=
[
endpoint
.
serve_endpoint
(
handler
.
generate
)]
tasks
.
extend
(
setup_native_endpoints
(
server_args
,
component
,
handler
))
await
asyncio
.
gather
(
*
tasks
)
async
def
register_llm_with_runtime_config
(
engine
:
sgl
.
Engine
,
endpoint
:
Endpoint
,
server_args
:
ServerArgs
,
migration_limit
:
int
,
):
"""Register LLM with runtime config"""
runtime_config
=
await
_get_runtime_config
(
engine
)
try
:
await
register_llm
(
ModelType
.
Backend
,
endpoint
,
server_args
.
model_path
,
server_args
.
served_model_name
,
kv_cache_block_size
=
server_args
.
page_size
,
migration_limit
=
migration_limit
,
runtime_config
=
runtime_config
,
)
except
Exception
as
e
:
logging
.
error
(
f
"Failed to register with runtime config:
{
e
}
"
)
return
None
async
def
_get_runtime_config
(
engine
:
sgl
.
Engine
)
->
Optional
[
ModelRuntimeConfig
]:
"""Get runtime config from SGLang engine"""
try
:
# Try to check if the engine has a scheduler attribute with the computed values
if
hasattr
(
engine
,
"scheduler_info"
)
and
engine
.
scheduler_info
is
not
None
:
runtime_config
=
ModelRuntimeConfig
()
# Get max_total_num_tokens from scheduler_info
if
"max_total_num_tokens"
in
engine
.
scheduler_info
:
max_total_tokens
=
engine
.
scheduler_info
[
"max_total_num_tokens"
]
if
max_total_tokens
and
hasattr
(
engine
.
tokenizer_manager
,
"server_args"
):
page_size
=
engine
.
tokenizer_manager
.
server_args
.
page_size
if
page_size
:
runtime_config
.
total_kv_blocks
=
(
max_total_tokens
+
page_size
-
1
)
//
page_size
logging
.
info
(
f
"Got total KV blocks from scheduler:
{
runtime_config
.
total_kv_blocks
}
"
f
"(max_total_tokens=
{
max_total_tokens
}
, page_size=
{
page_size
}
)"
)
# Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
# TODO: figure out where they are
return
runtime_config
# If scheduler approach doesn't work, log and return None to indicate we'll skip runtime config
logging
.
warning
(
"Could not access runtime config from SGLang engine. "
"The engine may compute these values internally after initialization. "
"Proceeding without runtime config - SGLang will use its internal defaults."
)
return
None
except
Exception
as
e
:
logging
.
warning
(
f
"Failed to get runtime config:
{
e
}
. Proceeding without it."
)
return
None
def
main
():
uvloop
.
install
()
asyncio
.
run
(
worker
())
if
__name__
==
"__main__"
:
main
()
container/Dockerfile.sglang
View file @
80d8aa19
...
@@ -27,7 +27,7 @@ ARG ARCH=amd64
...
@@ -27,7 +27,7 @@ ARG ARCH=amd64
ARG ARCH_ALT=x86_64
ARG ARCH_ALT=x86_64
# Make sure to update the dependency version in pyproject.toml when updating this
# Make sure to update the dependency version in pyproject.toml when updating this
ARG SGLANG_VERSION="0.
4.9.post6
"
ARG SGLANG_VERSION="0.
5.0rc2
"
##################################
##################################
########## Base Image ############
########## Base Image ############
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment