Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
80d8aa19
Unverified
Commit
80d8aa19
authored
Aug 18, 2025
by
ishandhanani
Committed by
GitHub
Aug 18, 2025
Browse files
feat(sglang): unify entry point for SGLang backend architecture (#2493)
parent
28400714
Changes
41
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
624 additions
and
985 deletions
+624
-985
components/backends/sglang/src/dynamo/sglang/common/__init__.py
...ents/backends/sglang/src/dynamo/sglang/common/__init__.py
+0
-38
components/backends/sglang/src/dynamo/sglang/common/base_handlers.py
...backends/sglang/src/dynamo/sglang/common/base_handlers.py
+0
-62
components/backends/sglang/src/dynamo/sglang/common/sgl_utils.py
...nts/backends/sglang/src/dynamo/sglang/common/sgl_utils.py
+0
-102
components/backends/sglang/src/dynamo/sglang/decode_worker/__init__.py
...ckends/sglang/src/dynamo/sglang/decode_worker/__init__.py
+4
-0
components/backends/sglang/src/dynamo/sglang/decode_worker/__main__.py
...ckends/sglang/src/dynamo/sglang/decode_worker/__main__.py
+10
-1
components/backends/sglang/src/dynamo/sglang/decode_worker/main.py
...s/backends/sglang/src/dynamo/sglang/decode_worker/main.py
+0
-94
components/backends/sglang/src/dynamo/sglang/main.py
components/backends/sglang/src/dynamo/sglang/main.py
+146
-0
components/backends/sglang/src/dynamo/sglang/protocol.py
components/backends/sglang/src/dynamo/sglang/protocol.py
+1
-16
components/backends/sglang/src/dynamo/sglang/publisher.py
components/backends/sglang/src/dynamo/sglang/publisher.py
+139
-0
components/backends/sglang/src/dynamo/sglang/register.py
components/backends/sglang/src/dynamo/sglang/register.py
+74
-0
components/backends/sglang/src/dynamo/sglang/request_handlers/__init__.py
...nds/sglang/src/dynamo/sglang/request_handlers/__init__.py
+14
-0
components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py
...lang/src/dynamo/sglang/request_handlers/decode_handler.py
+109
-0
components/backends/sglang/src/dynamo/sglang/request_handlers/handler_base.py
...sglang/src/dynamo/sglang/request_handlers/handler_base.py
+36
-0
components/backends/sglang/src/dynamo/sglang/request_handlers/prefill_handler.py
...ang/src/dynamo/sglang/request_handlers/prefill_handler.py
+74
-0
components/backends/sglang/src/dynamo/sglang/utils/clear_namespace.py
...ackends/sglang/src/dynamo/sglang/utils/clear_namespace.py
+1
-16
components/backends/sglang/src/dynamo/sglang/utils/sgl_http_server.py
...ackends/sglang/src/dynamo/sglang/utils/sgl_http_server.py
+0
-207
components/backends/sglang/src/dynamo/sglang/worker/__init__.py
...ents/backends/sglang/src/dynamo/sglang/worker/__init__.py
+4
-0
components/backends/sglang/src/dynamo/sglang/worker/__main__.py
...ents/backends/sglang/src/dynamo/sglang/worker/__main__.py
+11
-1
components/backends/sglang/src/dynamo/sglang/worker/main.py
components/backends/sglang/src/dynamo/sglang/worker/main.py
+0
-447
container/Dockerfile.sglang
container/Dockerfile.sglang
+1
-1
No files found.
components/backends/sglang/src/dynamo/sglang/common/__init__.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Base handlers
from
.base_handlers
import
BaseWorkerHandler
# Protocol types
from
.protocol
import
(
DisaggPreprocessedRequest
,
PreprocessedRequest
,
SamplingOptions
,
StopConditions
,
TokenIdType
,
)
# Utilities
from
.sgl_utils
import
(
graceful_shutdown
,
parse_sglang_args_inc
,
reserve_free_port
,
setup_native_endpoints
,
)
__all__
=
[
# Protocol types
"DisaggPreprocessedRequest"
,
"PreprocessedRequest"
,
"SamplingOptions"
,
"StopConditions"
,
"TokenIdType"
,
# Utilities
"parse_sglang_args_inc"
,
"reserve_free_port"
,
"graceful_shutdown"
,
"setup_native_endpoints"
,
# Base handlers
"BaseWorkerHandler"
,
]
components/backends/sglang/src/dynamo/sglang/common/base_handlers.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Optional
import
sglang
as
sgl
from
sglang.srt.server_args
import
ServerArgs
class BaseWorkerHandler(ABC):
    """
    Abstract base class for sglang request handlers. Besides the abstract
    ``generate`` entry point it provides the native sglang endpoints
    (cache flush and expert-distribution recording) shared by all workers.
    """

    @abstractmethod
    def __init__(
        self,
        engine: sgl.Engine,
        server_args: ServerArgs,
        component,
        decode_client: Optional[Any] = None,
    ):
        # decode_client is accepted for subclasses; it is not stored here.
        self.component = component
        self.server_args = server_args
        self.engine = engine

    @abstractmethod
    async def generate(self, request):
        """Generate tokens from the engine"""
        ...

    async def flush_cache(self, request: dict):
        """Flush the KV cache via the engine's tokenizer manager, then yield True."""
        del request  # payload is intentionally ignored
        manager = self.engine.tokenizer_manager
        await manager.flush_cache()
        yield True

    async def start_expert_distribution_record(self, request: dict):
        """Start recording expert distribution, then yield True."""
        del request  # payload is intentionally ignored
        manager = self.engine.tokenizer_manager
        await manager.start_expert_distribution_record()
        yield True

    async def stop_expert_distribution_record(self, request: dict):
        """Stop recording expert distribution, then yield True."""
        del request  # payload is intentionally ignored
        manager = self.engine.tokenizer_manager
        await manager.stop_expert_distribution_record()
        yield True

    async def dump_expert_distribution_record(self, request: dict):
        """
        Dump the expert distribution record to the directory specified in the
        environment variable `SGLANG_EXPERT_DISTRIBUTION_RECORDER_DIR`, then yield True.
        """
        del request  # payload is intentionally ignored
        manager = self.engine.tokenizer_manager
        await manager.dump_expert_distribution_record()
        yield True
components/backends/sglang/src/dynamo/sglang/common/sgl_utils.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
contextlib
import
logging
import
socket
from
argparse
import
Namespace
from
sglang.srt.server_args
import
ServerArgs
class SkipTokenizerInitError(RuntimeError):
    """Raised when the mandatory ``--skip-tokenizer-init`` flag was not supplied."""

    def __init__(self):
        message = "--skip-tokenizer-init flag is required"
        super().__init__(message)
def parse_sglang_args_inc(args: list[str]) -> ServerArgs:
    """
    Parse SGLang CLI arguments into a ``ServerArgs`` instance.

    Args:
        args: Raw command-line arguments (without the program name).

    Returns:
        The ``ServerArgs`` built from the parsed arguments. When the caller did
        not pass ``--disaggregation-bootstrap-port``, a free port is reserved
        and used for ``disaggregation_bootstrap_port``.

    Raises:
        SkipTokenizerInitError: if none of ``--skip-tokenizer-init``,
            ``--version``, ``--help`` or ``-h`` is present in ``args``.
    """
    # Currently we only support Dynamo doing the tokenization, so we must give
    # sglang the skip-tokenizer-init flag. We don't default it because this is temporary.
    # Allow the --version and --help flags through.
    temp_need_tok = ("--skip-tokenizer-init", "--version", "--help", "-h")
    if not any(flag in args for flag in temp_need_tok):
        raise SkipTokenizerInitError()

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    parsed_args = parser.parse_args(args)

    port_supplied = any(
        arg.startswith("--disaggregation-bootstrap-port") for arg in args
    )
    if not port_supplied:
        # Only bind a socket when we actually need a port: previously a port was
        # reserved on every call, even when the user supplied one explicitly.
        parsed_args.disaggregation_bootstrap_port = (
            _reserve_disaggregation_bootstrap_port()
        )

    return ServerArgs.from_cli_args(parsed_args)
@contextlib.contextmanager
def reserve_free_port(host: str = "localhost"):
    """
    Bind a TCP socket to an OS-assigned port on ``host`` and yield that port.

    The socket stays open (holding the port) until the context exits, at which
    point it is closed.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as listener:
        # Port 0 asks the OS to pick any free port.
        listener.bind((host, 0))
        _, port = listener.getsockname()
        yield port
def _reserve_disaggregation_bootstrap_port():
    """
    Pick a port for ``disaggregation_bootstrap_port``.

    Each worker requires its own port, so we reuse ``reserve_free_port`` to ask
    the OS for a currently free one. The reservation is released as soon as
    this function returns.
    """
    with reserve_free_port() as reserved:
        chosen = reserved
    return chosen
async def graceful_shutdown(runtime):
    """Log the shutdown request, call ``runtime.shutdown()``, and log completion.

    Args:
        runtime: Object exposing a synchronous ``shutdown()`` method
            (a DistributedRuntime, per the log messages).
    """
    logging.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    logging.info("DistributedRuntime shutdown complete")
def setup_native_endpoints(server_args, component, handler):
    """
    Build the serve calls for the sgl native endpoints.

    ``flush_cache`` is always served. The three expert-distribution endpoints
    are added only when ``server_args.expert_distribution_recorder_mode`` is
    set. Returns the list of ``serve_endpoint(...)`` results in that order.
    """
    # flush cache
    flush_endpoint = component.endpoint("flush_cache")
    tasks = [flush_endpoint.serve_endpoint(handler.flush_cache)]

    # expert distribution endpoints
    if server_args.expert_distribution_recorder_mode is None:
        return tasks

    expert_names = (
        "start_expert_distribution_record",
        "stop_expert_distribution_record",
        "dump_expert_distribution_record",
    )
    # Create all endpoints first, then serve them, mirroring the original call order.
    expert_endpoints = [component.endpoint(name) for name in expert_names]
    tasks.extend(
        endpoint.serve_endpoint(getattr(handler, name))
        for endpoint, name in zip(expert_endpoints, expert_names)
    )
    return tasks
components/backends/sglang/src/dynamo/sglang/decode_worker/__init__.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# This module is deprecated. Use `python3 -m dynamo.sglang` instead.
components/backends/sglang/src/dynamo/sglang/decode_worker/__main__.py
View file @
80d8aa19
...
...
@@ -2,7 +2,16 @@
# SPDX-License-Identifier: Apache-2.0
from
dynamo.sglang.decode_worker.main
import
main
import
logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.main
import
main
if __name__ == "__main__":
    # Deprecated entry point: warn, then delegate to the unified dynamo.sglang main().
    configure_dynamo_logging()
    # Adjacent string literals are joined verbatim, so the first one must end
    # with a space; otherwise the message reads "...v0.5.0.Use `python3 ...".
    logging.warning(
        "DEPRECATION WARNING: `python3 -m dynamo.sglang.decode_worker` is deprecated and will be removed in dynamo v0.5.0. "
        "Use `python3 -m dynamo.sglang` instead.",
    )
    main()
components/backends/sglang/src/dynamo/sglang/decode_worker/main.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
__future__
import
annotations
import
asyncio
import
logging
import
signal
import
sys
import
msgspec
import
sglang
as
sgl
import
uvloop
from
sglang.srt.server_args
import
ServerArgs
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.common
import
(
BaseWorkerHandler
,
graceful_shutdown
,
parse_sglang_args_inc
,
setup_native_endpoints
,
)
configure_dynamo_logging
()
class DecodeRequestHandler(BaseWorkerHandler):
    """Request handler that runs generation on the local sglang engine using
    bootstrap details carried in the incoming request."""

    def __init__(self, engine: sgl.Engine, server_args: ServerArgs, component):
        super().__init__(engine, server_args, component)
        logging.info("Decode request handler initialized")

    async def generate(self, request: str):
        """Decode the JSON request, start a streaming generation and yield each result.

        The payload is read as a dict with keys ``request`` (containing
        ``token_ids`` and ``batch_token_ids``), ``sampling_params``,
        ``bootstrap_host``, ``bootstrap_port`` and ``bootstrap_room``.
        """
        payload = msgspec.json.decode(request, type=dict)

        # Prefer the batched token ids when present, otherwise the single sequence.
        inner = payload["request"]
        batch_ids = inner["batch_token_ids"]
        input_ids = inner["token_ids"] if batch_ids is None else batch_ids

        stream = await self.engine.async_generate(
            input_ids=input_ids,
            sampling_params=payload["sampling_params"],
            stream=True,
            bootstrap_host=payload["bootstrap_host"],
            bootstrap_port=payload["bootstrap_port"],
            bootstrap_room=payload["bootstrap_room"],
        )

        async for chunk in stream:
            yield chunk
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Install SIGTERM/SIGINT handlers, parse the CLI arguments and start the decode worker."""
    event_loop = asyncio.get_running_loop()

    def _on_signal():
        # Schedule the shutdown coroutine instead of calling it directly
        asyncio.create_task(graceful_shutdown(runtime))

    # Set up signal handlers for graceful shutdown
    for signum in (signal.SIGTERM, signal.SIGINT):
        event_loop.add_signal_handler(signum, _on_signal)

    logging.info("Signal handlers set up for graceful shutdown")

    cli_args = sys.argv[1:]
    await init(runtime, parse_sglang_args_inc(cli_args))
async def init(runtime: DistributedRuntime, server_args: ServerArgs):
    """Initialize decode worker"""
    engine = sgl.Engine(server_args=server_args)

    # The decode worker is served as dynamo/decode.
    dynamo_namespace = runtime.namespace("dynamo")
    component = dynamo_namespace.component("decode")
    await component.create_service()

    handler = DecodeRequestHandler(engine, server_args, component)
    gen_endpoint = component.endpoint("generate")

    # Serve generate together with the native sgl endpoints until they finish.
    await asyncio.gather(
        gen_endpoint.serve_endpoint(handler.generate),
        *setup_native_endpoints(server_args, component, handler),
    )
def main():
    """Install uvloop and run the decode worker."""
    uvloop.install()
    asyncio.run(worker())


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
components/backends/sglang/src/dynamo/sglang/main.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
json
import
logging
import
signal
import
sys
import
sglang
as
sgl
import
uvloop
from
sglang.srt.utils
import
get_ip
from
dynamo.llm
import
ZmqKvEventPublisher
,
ZmqKvEventPublisherConfig
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.args
import
Config
,
DisaggregationMode
,
parse_args
from
dynamo.sglang.publisher
import
setup_sgl_metrics
from
dynamo.sglang.register
import
register_llm_with_runtime_config
from
dynamo.sglang.request_handlers
import
DecodeWorkerHandler
,
PrefillWorkerHandler
configure_dynamo_logging
()
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Install signal handlers, parse the CLI arguments and start the matching worker.

    PREFILL serving mode goes through ``init_prefill``; every other mode goes
    through ``init``.
    """
    event_loop = asyncio.get_running_loop()

    def _on_signal():
        asyncio.create_task(graceful_shutdown(runtime))

    for signum in (signal.SIGTERM, signal.SIGINT):
        event_loop.add_signal_handler(signum, _on_signal)

    logging.info("Signal handlers will trigger a graceful shutdown of the runtime")

    config = parse_args(sys.argv[1:])
    start = init if config.serving_mode != DisaggregationMode.PREFILL else init_prefill
    await start(runtime, config)
async def init(runtime: DistributedRuntime, config: Config):
    """
    Start a worker that serves the generate endpoint (any mode except PREFILL).

    Creates the sglang engine and the Dynamo component/endpoint, optionally a
    client to the prefill worker (DECODE mode) and a ZMQ KV event publisher
    (when ``server_args.kv_events_config`` is set), registers the model, and
    then serves ``handler.generate`` until the endpoint stops.

    On exit the metrics task is cancelled and ``handler.cleanup()`` is always
    called, even if waiting for the metrics task fails.
    """
    server_args, dynamo_args = config.server_args, config.dynamo_args

    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace(dynamo_args.namespace).component(
        dynamo_args.component
    )
    await component.create_service()

    generate_endpoint = component.endpoint(dynamo_args.endpoint)

    # TODO: think about implementing DisaggregationStrategy for P->D
    # TODO: implement a `next` field in the config to dynamically set the next client
    prefill_client = None
    if config.serving_mode == DisaggregationMode.DECODE:
        logging.info("Initializing prefill client")
        prefill_client = (
            await runtime.namespace(dynamo_args.namespace)
            .component("prefill")
            .endpoint("generate")
            .client()
        )

    publisher, metrics_task = await setup_sgl_metrics(engine, component)

    kv_publisher = None
    if server_args.kv_events_config:
        kv_events = json.loads(server_args.kv_events_config)
        ep = kv_events.get("endpoint")
        # A wildcard bind address is replaced by this host's IP.
        zmq_ep = ep.replace("*", get_ip()) if ep else None

        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=generate_endpoint.lease_id(),
            kv_block_size=server_args.page_size,
            zmq_endpoint=zmq_ep,
        )
        logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
        kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)

    handler = DecodeWorkerHandler(
        component, engine, config, publisher, kv_publisher, prefill_client
    )

    await register_llm_with_runtime_config(
        engine, generate_endpoint, server_args, dynamo_args.migration_limit
    )

    try:
        # TODO: add in native endpoints
        await asyncio.gather(
            generate_endpoint.serve_endpoint(handler.generate, graceful_shutdown=False),
        )
    except Exception as e:
        logging.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        metrics_task.cancel()
        try:
            await metrics_task
        except asyncio.CancelledError:
            logging.info("Metrics task successfully cancelled")
        finally:
            # Run cleanup even if the metrics task ended with an unexpected error.
            handler.cleanup()
async def init_prefill(runtime: DistributedRuntime, config: Config):
    """
    Start a prefill worker: create the engine, the Dynamo component and its
    endpoint, then serve ``PrefillWorkerHandler.generate`` until it stops.
    ``handler.cleanup()`` runs on the way out.
    """
    server_args = config.server_args
    dynamo_args = config.dynamo_args

    engine = sgl.Engine(server_args=server_args)

    namespace = runtime.namespace(dynamo_args.namespace)
    component = namespace.component(dynamo_args.component)
    await component.create_service()

    generate_endpoint = component.endpoint(dynamo_args.endpoint)
    handler = PrefillWorkerHandler(component, engine, config)

    serving = generate_endpoint.serve_endpoint(
        handler.generate, graceful_shutdown=True
    )

    try:
        await asyncio.gather(serving)
    except Exception as e:
        logging.error(f"Failed to serve endpoints: {e}")
        raise
    finally:
        handler.cleanup()
async def graceful_shutdown(runtime):
    """Log the shutdown request, call ``runtime.shutdown()``, and log completion.

    Scheduled as a task by the signal handler installed in ``worker``.
    """
    logging.info("Received shutdown signal, shutting down DistributedRuntime")
    runtime.shutdown()
    logging.info("DistributedRuntime shutdown complete")
def main():
    """Entry point: run the worker coroutine on a uvloop event loop."""
    uvloop.run(worker())


# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
components/backends/sglang/src/dynamo/sglang/
common/
protocol.py
→
components/backends/sglang/src/dynamo/sglang/protocol.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
from
pydantic
import
BaseModel
,
Field
...
...
@@ -58,7 +46,4 @@ class PreprocessedRequest(BaseModel):
class DisaggPreprocessedRequest(BaseModel):
    # Envelope for a disaggregated request: the preprocessed request plus the
    # sampling parameters to use for it.
    # NOTE(review): this is a diff hunk (-58,7 +46,4); the three bootstrap_*
    # fields below look like the lines removed by this commit — confirm
    # against the checked-in file before relying on them.
    request: PreprocessedRequest
    sampling_params: dict
    bootstrap_host: Union[str, List[str]]
    bootstrap_port: Union[int, List[int]]
    bootstrap_room: Union[int, List[int]]
    data_parallel_rank: Optional[int] = None
components/backends/sglang/src/dynamo/sglang/publisher.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
from
typing
import
Optional
import
sglang
as
sgl
import
zmq
import
zmq.asyncio
from
sglang.srt.utils
import
get_zmq_socket
from
dynamo.llm
import
(
ForwardPassMetrics
,
KvStats
,
SpecDecodeStats
,
WorkerMetricsPublisher
,
WorkerStats
,
)
from
dynamo.runtime
import
Component
class DynamoSglangStatPublisher:
    """
    Receives SGLang scheduler metrics over a ZMQ socket and republishes them
    through a ``WorkerMetricsPublisher``.
    """

    def __init__(self, engine: sgl.Engine, component: Component) -> None:
        self.engine = engine

        self.inner = WorkerMetricsPublisher()
        self.inner.create_endpoint(component)

        # Placeholder values used by init_publish() (can be overridden later if needed)
        self.request_total_slots = 1024
        self.dp_rank = 0
        self.num_gpu_block = 1024

        # ZMQ PULL socket on the engine's metrics IPC channel
        self._ctx = zmq.asyncio.Context()  # type: ignore
        self._sock = get_zmq_socket(
            self._ctx,
            zmq.PULL,
            self.engine.port_args.metrics_ipc_name,
            True,  # presumably bind=True — confirm against sglang's get_zmq_socket
        )  # type: ignore

    async def run(self) -> None:
        """Receive scheduler metrics forever and publish each one.

        Any exception raised while receiving or publishing is logged and the
        loop continues.
        """
        metric_fields = (
            "request_active_slots",
            "request_total_slots",
            "kv_active_blocks",
            "kv_total_blocks",
            "num_requests_waiting",
            "gpu_cache_usage_perc",
            "gpu_prefix_cache_hit_rate",
            "data_parallel_rank",
        )
        while True:
            try:
                kv_metrics = await self._sock.recv_pyobj()  # type: ignore
                values = {name: getattr(kv_metrics, name) for name in metric_fields}
                self.record_values(**values)
            except Exception:
                logging.exception(
                    "Failed to receive or publish SGLang scheduler metrics"
                )

    def init_publish(self) -> None:
        """Publish one all-zero metrics sample built from the placeholder capacities."""
        idle_worker = WorkerStats(
            request_active_slots=0,
            request_total_slots=self.request_total_slots,
            num_requests_waiting=0,
            data_parallel_rank=self.dp_rank,
        )
        empty_cache = KvStats(
            kv_active_blocks=0,
            kv_total_blocks=self.num_gpu_block,
            gpu_cache_usage_perc=0.0,
            gpu_prefix_cache_hit_rate=0.0,
        )
        initial = ForwardPassMetrics(
            worker_stats=idle_worker,
            kv_stats=empty_cache,
            spec_decode_stats=None,
        )
        logging.info("Sending dummy metrics to initialize")
        self.inner.publish(initial)

    def record(
        self,
        worker_stats: WorkerStats,
        kv_stats: KvStats,
        spec_decode_stats: Optional[SpecDecodeStats] = None,
    ) -> None:
        """Wrap the given stats in a ``ForwardPassMetrics`` and publish it."""
        self.inner.publish(
            ForwardPassMetrics(
                worker_stats=worker_stats,
                kv_stats=kv_stats,
                spec_decode_stats=spec_decode_stats,
            )
        )

    def record_values(
        self,
        request_active_slots: int,
        request_total_slots: int,
        kv_active_blocks: int,
        kv_total_blocks: int,
        num_requests_waiting: int,
        gpu_cache_usage_perc: float,
        gpu_prefix_cache_hit_rate: float,
        data_parallel_rank: Optional[int] = None,
        spec_decode_stats: Optional[SpecDecodeStats] = None,
    ) -> None:
        """Build ``WorkerStats``/``KvStats`` from plain values and publish them.

        ``data_parallel_rank`` falls back to ``self.dp_rank`` when it is None.
        """
        rank = self.dp_rank if data_parallel_rank is None else data_parallel_rank
        worker_stats = WorkerStats(
            request_active_slots=request_active_slots,
            request_total_slots=request_total_slots,
            num_requests_waiting=num_requests_waiting,
            data_parallel_rank=rank,
        )
        kv_stats = KvStats(
            kv_active_blocks=kv_active_blocks,
            kv_total_blocks=kv_total_blocks,
            gpu_cache_usage_perc=gpu_cache_usage_perc,
            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
        )
        self.record(worker_stats, kv_stats, spec_decode_stats)
async def setup_sgl_metrics(
    engine: sgl.Engine,
    component: Component,
) -> tuple[DynamoSglangStatPublisher, asyncio.Task]:
    """
    Create the stat publisher, send its initial sample and start its receive
    loop as a background task.

    Returns:
        The publisher and the task running ``publisher.run()``; the caller is
        responsible for cancelling the task.
    """
    stat_publisher = DynamoSglangStatPublisher(engine, component)
    stat_publisher.init_publish()

    loop_task = asyncio.create_task(stat_publisher.run())
    logging.info("SGLang metrics loop started")
    return stat_publisher, loop_task
components/backends/sglang/src/dynamo/sglang/register.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
from
typing
import
Optional
import
sglang
as
sgl
from
sglang.srt.server_args
import
ServerArgs
from
dynamo._core
import
Endpoint
from
dynamo.llm
import
ModelRuntimeConfig
,
ModelType
,
register_llm
async def register_llm_with_runtime_config(
    engine: sgl.Engine,
    endpoint: Endpoint,
    server_args: ServerArgs,
    migration_limit: int,
):
    """
    Register the model served on ``endpoint``, attaching the runtime config
    read from the engine (which may be None).

    Registration failures are logged and swallowed; the function always
    returns None.
    """
    runtime_config = await _get_runtime_config(engine)

    try:
        await register_llm(
            ModelType.Backend,
            endpoint,
            server_args.model_path,
            server_args.served_model_name,
            kv_cache_block_size=server_args.page_size,
            migration_limit=migration_limit,
            runtime_config=runtime_config,
        )
    except Exception as e:
        logging.error(f"Failed to register with runtime config: {e}")
    return None
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
    """
    Build a ``ModelRuntimeConfig`` from the engine's ``scheduler_info``.

    Returns None when ``scheduler_info`` is missing/None or when anything
    raises while reading it. ``total_kv_blocks`` is only filled in when both
    ``max_total_num_tokens`` and the page size are available and non-zero.
    """
    try:
        scheduler_ready = (
            hasattr(engine, "scheduler_info") and engine.scheduler_info is not None
        )
        if not scheduler_ready:
            # Nothing to read from: log and signal that runtime config is skipped
            logging.warning(
                "Could not access runtime config from SGLang engine. "
                "The engine may compute these values internally after initialization. "
                "Proceeding without runtime config - SGLang will use its internal defaults."
            )
            return None

        runtime_config = ModelRuntimeConfig()

        # Get max_total_num_tokens from scheduler_info
        if "max_total_num_tokens" in engine.scheduler_info:
            max_total_tokens = engine.scheduler_info["max_total_num_tokens"]
            if max_total_tokens and hasattr(engine.tokenizer_manager, "server_args"):
                page_size = engine.tokenizer_manager.server_args.page_size
                if page_size:
                    # Ceiling division: a partially filled page still needs a block
                    runtime_config.total_kv_blocks = (
                        max_total_tokens + page_size - 1
                    ) // page_size
                    logging.info(
                        f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
                        f"(max_total_tokens={max_total_tokens}, page_size={page_size})"
                    )

        # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
        return runtime_config

    except Exception as e:
        logging.warning(f"Failed to get runtime config: {e}. Proceeding without it.")
        return None
components/backends/sglang/src/dynamo/sglang/request_handlers/__init__.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
.decode_handler
import
DecodeWorkerHandler
# Base handlers
from
.handler_base
import
BaseWorkerHandler
from
.prefill_handler
import
PrefillWorkerHandler
__all__
=
[
"BaseWorkerHandler"
,
"DecodeWorkerHandler"
,
"PrefillWorkerHandler"
,
]
components/backends/sglang/src/dynamo/sglang/request_handlers/decode_handler.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
logging
import
sglang
as
sgl
from
dynamo._core
import
Client
,
Component
from
dynamo.llm
import
WorkerMetricsPublisher
,
ZmqKvEventPublisher
from
dynamo.sglang.args
import
Config
,
DisaggregationMode
from
dynamo.sglang.protocol
import
DisaggPreprocessedRequest
from
dynamo.sglang.request_handlers.handler_base
import
BaseWorkerHandler
class DecodeWorkerHandler(BaseWorkerHandler):
    """Handler for the generate endpoint of aggregated and decode workers.

    In DECODE serving mode each request is first sent to the prefill worker to
    obtain bootstrap info, then decoded locally; in any other mode the request
    is generated directly on the local engine.
    """

    def __init__(
        self,
        component: Component,
        engine: sgl.Engine,
        config: Config,
        metrics_publisher: WorkerMetricsPublisher,
        kv_publisher: ZmqKvEventPublisher = None,
        prefill_client: Client = None,
    ):
        """
        Raises:
            ValueError: if the serving mode is DECODE and no ``prefill_client``
                was provided.
        """
        # The base class stores every argument, including prefill_client.
        super().__init__(
            component, engine, config, metrics_publisher, kv_publisher, prefill_client
        )
        if self.serving_mode == DisaggregationMode.DECODE:
            if self.prefill_client is None:
                raise ValueError(
                    "prefill_client must be provided when serving_mode is decode"
                )
            logging.info("Decode worker handler initialized")

        logging.info("Worker handler initialized")

    def cleanup(self):
        """Shut down the engine, then run the base-class cleanup."""
        self.engine.shutdown()
        logging.info("Engine shutdown")
        super().cleanup()

    def _build_sampling_params(self, request: dict) -> dict:
        """Translate the request's sampling options and stop conditions into
        the sampling-params dict passed to ``engine.async_generate``.

        ``max_new_tokens`` is always set (possibly to None); the other keys are
        only included when the request provides them.
        """
        sampling_options = request["sampling_options"]
        sampling_params = {}

        # temperature == 0 is a legitimate (greedy) setting, so test for None
        # rather than truthiness; a truthiness test silently dropped it.
        if sampling_options["temperature"] is not None:
            sampling_params["temperature"] = sampling_options["temperature"]
        # top_p / top_k keep the truthiness test: 0 is treated as "not set".
        if sampling_options["top_p"]:
            sampling_params["top_p"] = sampling_options["top_p"]
        if sampling_options["top_k"]:
            sampling_params["top_k"] = sampling_options["top_k"]

        stop_conditions = request["stop_conditions"]
        sampling_params["max_new_tokens"] = stop_conditions["max_tokens"]
        if stop_conditions["ignore_eos"]:
            sampling_params["ignore_eos"] = stop_conditions["ignore_eos"]
        return sampling_params

    async def generate(self, request: str):
        """Stream generation output for one request.

        NOTE(review): although annotated ``str``, ``request`` is indexed like a
        dict (``request["token_ids"]``, ``request["sampling_options"]``).

        Yields:
            Dicts with ``token_ids`` (newly produced tokens) and, on the final
            item, ``finish_reason``.

        Raises:
            RuntimeError: in DECODE mode, if the prefill worker's stream
                yields no bootstrap info.
        """
        sampling_params = self._build_sampling_params(request)

        if self.serving_mode == DisaggregationMode.DECODE:
            # request the bootstrap info from the target prefill worker
            prefill_stream = await self.prefill_client.generate(
                DisaggPreprocessedRequest(
                    request=request,
                    sampling_params=sampling_params,
                ).model_dump_json()
            )

            # Only the first item of the prefill stream is needed.
            bootstrap_info = None
            async for info in prefill_stream:
                bootstrap_info = info.data()
                break

            if not bootstrap_info:
                raise RuntimeError("No bootstrap info received from prefill worker")

            decode = await self.engine.async_generate(
                input_ids=request["token_ids"],
                sampling_params=sampling_params,
                stream=True,
                bootstrap_host=bootstrap_info["bootstrap_host"],
                bootstrap_port=bootstrap_info["bootstrap_port"],
                bootstrap_room=bootstrap_info["bootstrap_room"],
            )

            async for out in self._process_stream(decode):
                yield out
        else:
            agg = await self.engine.async_generate(
                input_ids=request["token_ids"],
                sampling_params=sampling_params,
                stream=True,
            )
            async for out in self._process_stream(agg):
                yield out

    async def _process_stream(self, stream_source):
        """Convert engine results into incremental token chunks.

        Each result's ``output_ids`` is treated as cumulative: only the ids
        past the previously seen count are emitted. A result carrying a
        ``finish_reason`` is emitted with an empty ``token_ids`` list.
        """
        num_output_tokens_so_far = 0

        async for res in stream_source:
            finish_reason = res["meta_info"]["finish_reason"]
            if finish_reason:
                out = {"token_ids": [], "finish_reason": finish_reason["type"]}
            else:
                next_total_toks = len(res["output_ids"])
                out = {"token_ids": res["output_ids"][num_output_tokens_so_far:]}
                num_output_tokens_so_far = next_total_toks

            yield out
components/backends/sglang/src/dynamo/sglang/request_handlers/handler_base.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
import
sglang
as
sgl
from
dynamo._core
import
Client
,
Component
from
dynamo.llm
import
WorkerMetricsPublisher
,
ZmqKvEventPublisher
from
dynamo.sglang.args
import
Config
class BaseWorkerHandler(ABC):
    """Common state and interface for the sglang worker request handlers."""

    def __init__(
        self,
        component: Component,
        engine: sgl.Engine,
        config: Config,
        metrics_publisher: WorkerMetricsPublisher = None,
        kv_publisher: ZmqKvEventPublisher = None,
        prefill_client: Client = None,
    ):
        # Core objects every handler needs
        self.component = component
        self.engine = engine
        self.config = config
        # Cached from the config so subclasses can branch on it directly
        self.serving_mode = config.serving_mode

        # Optional collaborators; None when not used by the worker
        self.metrics_publisher = metrics_publisher
        self.kv_publisher = kv_publisher
        self.prefill_client = prefill_client

    @abstractmethod
    async def generate(self, request: str):
        """Handle one request; must be implemented by subclasses."""
        pass

    def cleanup(self):
        """Release resources held by the handler. No-op by default."""
        pass
components/backends/sglang/src/dynamo/sglang/request_handlers/prefill_handler.py
0 → 100644
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
logging
import
random
import
socket
import
msgspec
import
sglang
as
sgl
from
sglang.srt.utils
import
get_ip
from
dynamo._core
import
Component
from
dynamo.sglang.args
import
Config
from
dynamo.sglang.request_handlers.handler_base
import
BaseWorkerHandler
class PrefillWorkerHandler(BaseWorkerHandler):
    """Handler for the generate endpoint of a prefill worker.

    For each request it first yields the bootstrap host/port/room, then starts
    the prefill on the local engine and drains its output in the background.
    """

    def __init__(self, component: Component, engine: sgl.Engine, config: Config):
        # engine must be set before _get_bootstrap_info(), which reads it.
        self.engine = engine
        self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info()
        # Strong references to in-flight drain tasks: the event loop only keeps
        # weak references, so an unreferenced task may be garbage collected.
        self._background_tasks: set = set()
        super().__init__(component, engine, config, None, None, None)
        logging.info(
            f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
        )

    def _generate_bootstrap_room(self):
        """Return a random integer in [0, 2**63 - 1] used as the bootstrap room."""
        return random.randint(0, 2**63 - 1)

    def cleanup(self):
        """Shut down the engine, then run the base-class cleanup."""
        self.engine.shutdown()
        logging.info("Prefill engine shutdown")
        super().cleanup()

    def _get_bootstrap_info(self):
        """Bootstrap info from tokenizer manager

        Returns:
            ``(bootstrap_host, bootstrap_port)``. The host is resolved from the
            host part of ``dist_init_addr`` when that is set, otherwise it is
            ``get_ip()``.
        """
        inner_tm = self.engine.tokenizer_manager
        bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port

        if inner_tm.server_args.dist_init_addr:
            bootstrap_host = socket.gethostbyname(
                inner_tm.server_args.dist_init_addr.split(":")[0]
            )
        else:
            bootstrap_host = get_ip()

        return bootstrap_host, bootstrap_port

    async def generate(self, request: str):
        """Yield the bootstrap info for this request, then start the prefill.

        ``request`` is a JSON string decoded into a dict with ``request``
        (containing ``token_ids``) and ``sampling_params``.
        """
        req = msgspec.json.decode(request, type=dict)
        bootstrap_room = self._generate_bootstrap_room()

        bootstrap_info = {
            "bootstrap_host": self.bootstrap_host,
            "bootstrap_port": self.bootstrap_port,
            "bootstrap_room": bootstrap_room,
        }

        # Hand the bootstrap info to the caller before starting the prefill.
        yield bootstrap_info

        results = await self.engine.async_generate(
            input_ids=req["request"]["token_ids"],
            sampling_params=req["sampling_params"],
            stream=True,
            bootstrap_host=self.bootstrap_host,
            bootstrap_port=self.bootstrap_port,
            bootstrap_room=bootstrap_room,
        )

        # Drain the engine output in the background; keep a reference until the
        # task finishes so it cannot be garbage collected mid-execution.
        task = asyncio.create_task(self._consume_results(results))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _consume_results(self, results):
        """Exhaust the engine's result stream, discarding every item."""
        async for _ in results:
            pass
components/backends/sglang/src/dynamo/sglang/utils/clear_namespace.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2020 Atalaya Tech. Inc
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modifications Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES
import
argparse
import
asyncio
...
...
@@ -23,7 +9,6 @@ from dynamo.runtime import DistributedRuntime, EtcdKvCache, dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
@
dynamo_worker
()
...
...
@@ -34,7 +19,7 @@ async def clear_namespace(runtime: DistributedRuntime, namespace: str):
{},
)
await
etcd_kv_cache
.
clear_all
()
logg
er
.
info
(
f
"Cleared /
{
namespace
}
in EtcdKvCache"
)
logg
ing
.
info
(
f
"Cleared /
{
namespace
}
in EtcdKvCache"
)
if
__name__
==
"__main__"
:
...
...
components/backends/sglang/src/dynamo/sglang/utils/sgl_http_server.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
argparse
import
asyncio
import
logging
import
uvicorn
import
uvloop
from
fastapi
import
FastAPI
from
fastapi.routing
import
APIRoute
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
FLUSH_CACHE_ENDPOINT
=
"flush_cache"
configure_dynamo_logging
()
class SglangHttpServer:
    """FastAPI application that fans management commands out to Dynamo endpoints.

    Each HTTP route discovers the ``(namespace, component)`` pairs that expose
    a named endpoint (by scanning etcd keys under ``instances/``) and sends a
    payload directly to every instance of each of them.
    """

    def __init__(self, port: int, runtime: DistributedRuntime, args):
        """Store the configuration and register the HTTP routes.

        Args:
            port: TCP port the uvicorn server listens on.
            runtime: Runtime used for the etcd client and endpoint clients.
            args: Parsed CLI namespace; ``ns`` and ``comp`` are read as
                discovery filters, and ``endpoint`` (optional) overrides the
                flush-cache endpoint name.
        """
        self.port = port
        self.app = FastAPI()
        self.runtime = runtime
        self.args = args
        self.setup_routes()

    async def _discover_endpoints(self, endpoint_name):
        """Discover endpoints that match the pattern.

        Args:
            endpoint_name: Endpoint name to look for.

        Returns:
            A set of ``(namespace, component)`` tuples exposing the endpoint.

        Raises:
            RuntimeError: If the runtime has no etcd client.
        """
        etcd_client = self.runtime.etcd_client()
        if etcd_client is None:
            raise RuntimeError("Runtime has no etcd client; cannot discover endpoints")

        prefix = "instances/"
        kvs = await etcd_client.kv_get_prefix(prefix)

        # Collect (namespace, component) combos that expose the target endpoint
        discovered = set()
        for kv in kvs:
            # Entries may be dicts or objects, and keys may be bytes or str.
            key = kv["key"] if isinstance(kv, dict) else kv.key
            if isinstance(key, bytes):
                key = key.decode()
            if not key.startswith(prefix):
                continue

            segments = key.split("/")
            # Format: instances/<ns>/<comp>/<endpoint:lease>
            if len(segments) < 4:
                continue

            ns, comp, ep_with_lease = segments[1], segments[2], segments[3]
            # A falsy filter means "do not filter on this field".
            if self.args.ns and ns != self.args.ns:
                continue
            if self.args.comp and comp != self.args.comp:
                continue

            ep_name = ep_with_lease.split(":", 1)[0]
            if ep_name == endpoint_name:
                discovered.add((ns, comp))
                logging.debug(f"Discovered endpoint: {ns}.{comp}")

        logging.debug(
            f"Endpoint discovery complete. Found {len(discovered)} matching endpoints"
        )
        return discovered

    async def _dispatch_command(
        self, endpoint_name: str, payload: dict | str = "{}", success_message: str = ""
    ):
        """Dispatches a command to all instances of a discovered endpoint.

        Args:
            endpoint_name: Endpoint to call on every discovered component.
            payload: Payload passed to ``client.direct`` for each instance.
            success_message: Text placed in the returned ``message`` field.

        Returns:
            ``{"message": ..., "success": bool}``. ``success`` is ``False``
            only when no component exposes the endpoint. Per-instance errors
            are logged and do not change the result.
        """
        discovered = await self._discover_endpoints(endpoint_name=endpoint_name)
        if not discovered:
            return {"message": "No matching endpoints found", "success": False}

        logging.debug(
            f"Found components: {', '.join([f'{ns}.{comp}' for ns, comp in discovered])}"
        )

        for ns, comp in discovered:
            ep = self.runtime.namespace(ns).component(comp).endpoint(endpoint_name)
            client = await ep.client()
            await client.wait_for_instances()
            ids = client.instance_ids()

            logging.debug(f"-- {ns}.{comp} : {len(ids)} instances --")

            for inst_id in ids:
                # One failing instance must not stop the fan-out to the rest.
                try:
                    stream = await client.direct(payload, inst_id)
                    async for stream_payload in stream:
                        logging.debug(f"[{ns}.{comp}][{inst_id}] -> {stream_payload}")
                except Exception as e:
                    logging.error(
                        f"[{ns}.{comp}][{inst_id}] {endpoint_name} error: {e}"
                    )

        return {"message": success_message, "success": True}

    async def _run_command(self, endpoint_name: str, success_message: str, label: str):
        """Dispatch one command and turn any exception into a failure response.

        ``label`` is the human-readable operation name used in the log line
        (``"<label> error: ..."``) and in the response message
        (``"<label> failed: ..."``).
        """
        try:
            return await self._dispatch_command(
                endpoint_name,
                success_message=success_message,
            )
        except Exception as e:
            logging.error(f"{label} error: {e}")
            return {"message": f"{label} failed: {str(e)}", "success": False}

    def setup_routes(self):
        """Register the POST routes on ``self.app``."""

        @self.app.post("/flush_cache")
        async def flush_cache():
            """Flush the radix cache."""
            # The CLI parser in this file defines no ``endpoint`` option, so a
            # bare ``self.args.endpoint`` raises AttributeError. Honour an
            # override when one exists, otherwise use the module constant.
            endpoint_name = getattr(self.args, "endpoint", None) or FLUSH_CACHE_ENDPOINT
            return await self._run_command(
                endpoint_name, "Cache flush initiated", "Cache flush"
            )

        @self.app.post("/start_expert_distribution_record")
        async def start_expert_distribution_record():
            """Start recording expert distribution."""
            return await self._run_command(
                "start_expert_distribution_record",
                "Expert distribution recording started",
                "Start expert distribution",
            )

        @self.app.post("/stop_expert_distribution_record")
        async def stop_expert_distribution_record():
            """Stop recording expert distribution."""
            return await self._run_command(
                "stop_expert_distribution_record",
                "Expert distribution recording stopped",
                "Stop expert distribution",
            )

        @self.app.post("/dump_expert_distribution_record")
        async def dump_expert_distribution_record(request: dict):
            """Dump expert distribution recording to specified directory."""
            # NOTE(review): the request body is accepted but not forwarded;
            # the default payload is dispatched. Confirm this is intended.
            return await self._run_command(
                "dump_expert_distribution_record",
                "Expert distribution recording dumped to directory",
                "Dump expert distribution",
            )

    async def start_server(self):
        """Start the HTTP server"""
        config = uvicorn.Config(
            self.app,
            host="0.0.0.0",
            port=self.port,
        )
        server = uvicorn.Server(config)

        # Debug: print all registered routes
        for route in self.app.routes:
            if isinstance(route, APIRoute):
                logging.debug(f"Registered route: {route.methods} {route.path}")

        await server.serve()
def parse_args():
    """Parse the command-line options for the SGLang HTTP server.

    Returns:
        An ``argparse.Namespace`` with ``port`` (int, default 9001), ``ns``
        (namespace filter, default ``"dynamo"``) and ``comp`` (component
        filter, default ``None``).
    """
    parser = argparse.ArgumentParser(
        description="SGLang HTTP server for cache management"
    )
    parser.add_argument("--port", type=int, default=9001, help="Port to listen on")
    parser.add_argument(
        "--ns",
        "--namespace",
        default="dynamo",
        # The default is a concrete namespace, not "discover all"; the help
        # text now states what actually happens.
        help="Specify Dynamo namespace (default: dynamo)",
    )
    parser.add_argument(
        "--comp",
        "--component",
        default=None,
        help="Specify component name (default: discover all)",
    )
    return parser.parse_args()
@dynamo_worker(static=False)
async def main(runtime: DistributedRuntime):
    """Parse the CLI options, build the HTTP server and serve until it stops."""
    cli_args = parse_args()
    # The full namespace is passed as well as the port: the server reads its
    # discovery filters from it.
    server = SglangHttpServer(cli_args.port, runtime, cli_args)
    await server.start_server()
# Script entry point: install uvloop first, then run the worker coroutine.
if __name__ == "__main__":
    uvloop.install()
    asyncio.run(main())
components/backends/sglang/src/dynamo/sglang/worker/__init__.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# This module is deprecated. Use `python3 -m dynamo.sglang` instead.
components/backends/sglang/src/dynamo/sglang/worker/__main__.py
View file @
80d8aa19
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from
dynamo.sglang.worker.main
import
main
import
logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.main
import
main
# Deprecated entry point: warn, then delegate to the unified `dynamo.sglang` main.
if __name__ == "__main__":
    configure_dynamo_logging()
    # Adjacent string literals are concatenated as-is, so the first one must
    # end with a space or the sentences run together ("...v0.5.0.Use ...").
    logging.warning(
        "DEPRECATION WARNING: `python3 -m dynamo.sglang.worker` is deprecated and will be removed in dynamo v0.5.0. "
        "Use `python3 -m dynamo.sglang` instead.",
    )
    main()
components/backends/sglang/src/dynamo/sglang/worker/main.py
deleted
100644 → 0
View file @
28400714
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import
asyncio
import
json
import
logging
import
random
import
signal
import
socket
import
sys
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
sglang
as
sgl
import
uvloop
import
zmq
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_ip
,
get_zmq_socket
from
dynamo._core
import
Endpoint
from
dynamo.llm
import
(
ForwardPassMetrics
,
KvStats
,
ModelRuntimeConfig
,
ModelType
,
WorkerMetricsPublisher
,
WorkerStats
,
ZmqKvEventPublisher
,
ZmqKvEventPublisherConfig
,
register_llm
,
)
from
dynamo.runtime
import
DistributedRuntime
,
dynamo_worker
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.common
import
(
BaseWorkerHandler
,
DisaggPreprocessedRequest
,
graceful_shutdown
,
parse_sglang_args_inc
,
setup_native_endpoints
,
)
configure_dynamo_logging
()
class RequestHandler(BaseWorkerHandler):
    """Request handler for an aggregated or prefill SGLang worker.

    It streams generation output from the local engine, or from a remote
    decode worker when disaggregation is enabled, and republishes the
    scheduler metrics it receives over a ZMQ socket.

    NOTE(review): ``self.engine``, ``self.server_args`` and ``self.component``
    are read below but never assigned here; they are presumably set by
    ``BaseWorkerHandler.__init__`` - confirm against the base class.
    """

    def __init__(
        self,
        engine: sgl.Engine,
        server_args: ServerArgs,
        component,
        decode_client: Optional[Any] = None,
    ):
        """Initialise publishers and, when disaggregated, the bootstrap info.

        Raises:
            ValueError: If ``server_args.disaggregation_mode`` is not
                ``"null"`` and no ``decode_client`` is given.
        """
        super().__init__(engine, server_args, component, decode_client)
        self.metrics_publisher = WorkerMetricsPublisher()

        self.zmq_context = zmq.asyncio.Context()  # type: ignore
        self.receive_metrics_from_scheduler = None

        # Strong references to fire-and-forget tasks; see _spawn_background.
        self._background_tasks = set()

        if server_args.disaggregation_mode != "null":
            self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info()
            if decode_client is None:
                raise ValueError(
                    "decode_client must be provided when disaggregation_mode is not 'null'"
                )
            self.decode_client = decode_client
            logging.info(
                f"Disaggregation enabled - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
            )

        logging.info("Request handler initialized")

    def _spawn_background(self, coro):
        """Schedule ``coro`` as a task and keep it referenced until it is done.

        The event loop only holds weak references to tasks, so a task whose
        result is discarded may be garbage-collected before it finishes.
        """
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def setup_metrics(self):
        """Set up metrics publisher"""
        self.receive_metrics_from_scheduler = get_zmq_socket(
            self.zmq_context, zmq.PULL, self.engine.port_args.metrics_ipc_name, True
        )
        self.init_publish()
        self._spawn_background(self._receive_and_publish_metrics_loop())
        task = self._spawn_background(self.create_metrics_publisher_endpoint())
        task.add_done_callback(
            lambda _: logging.debug("metrics publisher endpoint created")
        )

    def init_publish(self):
        """Publish initial set of warmup metrics"""
        worker_stats = WorkerStats(
            request_active_slots=0,
            request_total_slots=1024,
            num_requests_waiting=0,
            data_parallel_rank=0,
        )

        kv_stats = KvStats(
            kv_active_blocks=0,
            kv_total_blocks=1024,
            gpu_cache_usage_perc=0,
            gpu_prefix_cache_hit_rate=0,
        )

        metrics = ForwardPassMetrics(
            worker_stats=worker_stats,
            kv_stats=kv_stats,
            spec_decode_stats=None,
        )

        self.metrics_publisher.publish(metrics)

    async def create_metrics_publisher_endpoint(self):
        """Create the metrics publisher endpoint on ``self.component``."""
        logging.debug("Creating metrics publisher endpoint")
        await self.metrics_publisher.create_endpoint(self.component)

    async def _receive_and_publish_metrics_loop(self):
        """Receive metrics from SGL scheduler and publish them"""
        while True:
            # A failure on one message is logged and the loop continues.
            try:
                kv_metrics = await self.receive_metrics_from_scheduler.recv_pyobj()  # type: ignore
                worker_stats = WorkerStats(
                    request_active_slots=kv_metrics.request_active_slots,
                    request_total_slots=kv_metrics.request_total_slots,
                    num_requests_waiting=kv_metrics.num_requests_waiting,
                    data_parallel_rank=kv_metrics.data_parallel_rank,  # Note: 0 means it's either 0 or None from sglang
                )
                kv_stats = KvStats(
                    kv_active_blocks=kv_metrics.kv_active_blocks,
                    kv_total_blocks=kv_metrics.kv_total_blocks,
                    gpu_cache_usage_perc=kv_metrics.gpu_cache_usage_perc,
                    gpu_prefix_cache_hit_rate=kv_metrics.gpu_prefix_cache_hit_rate,
                )
                spec_dec_stats = None
                metrics = ForwardPassMetrics(
                    worker_stats=worker_stats,
                    kv_stats=kv_stats,
                    spec_decode_stats=spec_dec_stats,
                )

                self.metrics_publisher.publish(metrics)
            except Exception:
                logging.exception("Failed to receive or publish metrics")

    def _get_bootstrap_info(self):
        """Bootstrap info from tokenizer manager"""
        inner_tm = self.engine.tokenizer_manager
        bootstrap_port = inner_tm.server_args.disaggregation_bootstrap_port

        if inner_tm.server_args.dist_init_addr:
            # Only the host part of "host:port" is resolved.
            bootstrap_host = socket.gethostbyname(
                inner_tm.server_args.dist_init_addr.split(":")[0]
            )
        else:
            bootstrap_host = get_ip()

        return bootstrap_host, bootstrap_port

    def _build_sampling_params(self, request: dict) -> dict:
        """Translate the request's sampling and stop options into engine parameters.

        Args:
            request: Dict with ``sampling_options`` (``temperature``,
                ``top_p``, ``top_k``) and ``stop_conditions``
                (``max_tokens``, ``ignore_eos``) sub-dicts.

        Returns:
            Dict that always contains ``max_new_tokens``; the other keys are
            present only when the request sets them.
        """
        sampling_options = request["sampling_options"]
        stop_conditions = request["stop_conditions"]

        sampling_params = {}
        # ``is not None`` rather than truthiness: an explicit temperature of 0
        # is a value the caller asked for and must not be dropped.
        if sampling_options["temperature"] is not None:
            sampling_params["temperature"] = sampling_options["temperature"]
        # top_p / top_k keep their truthiness checks to preserve existing behaviour.
        if sampling_options["top_p"]:
            sampling_params["top_p"] = sampling_options["top_p"]
        if sampling_options["top_k"]:
            sampling_params["top_k"] = sampling_options["top_k"]
        sampling_params["max_new_tokens"] = stop_conditions["max_tokens"]
        if stop_conditions["ignore_eos"]:
            sampling_params["ignore_eos"] = stop_conditions["ignore_eos"]
        return sampling_params

    def _get_request_batch_size(self, request: dict):
        """Get batch size from request, returns None for single requests"""
        # .get(): a request without the key is treated as a single request.
        batch_token_ids = request.get("batch_token_ids")
        if batch_token_ids is not None:
            return len(batch_token_ids)
        return None

    def _is_batch_request(self, request: dict):
        """Check if request is in batch mode"""
        return request.get("batch_token_ids") is not None

    def _generate_bootstrap_room(self):
        """Return a random integer in ``[0, 2**63 - 1]`` used as the bootstrap room id."""
        return random.randint(0, 2**63 - 1)

    async def generate(self, request: dict):
        """Stream generation output for ``request``.

        In disaggregated mode the local engine runs the request with the
        bootstrap parameters (its output is drained and discarded) while the
        yielded output comes from the decode worker. Otherwise the local
        engine's own stream is yielded.

        Yields:
            Dicts produced by ``_process_stream``.
        """
        is_batch = self._is_batch_request(request)
        batch_size = self._get_request_batch_size(request)

        # TODO: maintain a mapping from SGLang's Output struct to LLMEngineOutput
        sampling_params = self._build_sampling_params(request)

        if self.server_args.disaggregation_mode != "null":
            if is_batch:
                # One room per batch item; host and port are repeated.
                bootstrap_room = [
                    self._generate_bootstrap_room() for _ in range(batch_size)
                ]
                bootstrap_host = [self.bootstrap_host] * batch_size
                bootstrap_port = [self.bootstrap_port] * batch_size
            else:
                bootstrap_host = self.bootstrap_host
                bootstrap_port = self.bootstrap_port
                bootstrap_room = self._generate_bootstrap_room()

            # decode worker request
            disagg_request = DisaggPreprocessedRequest(
                request=request,
                sampling_params=sampling_params,
                bootstrap_host=bootstrap_host,
                bootstrap_port=bootstrap_port,
                bootstrap_room=bootstrap_room,
            )

            # prefill response is not used
            prefill = await self.engine.async_generate(
                input_ids=request["token_ids"]
                if not is_batch
                else request["batch_token_ids"],
                sampling_params=sampling_params,
                stream=True,
                bootstrap_host=bootstrap_host,
                bootstrap_port=bootstrap_port,
                bootstrap_room=bootstrap_room,
            )
            prefill_task = asyncio.create_task(self._prefill_generator(prefill))

            decode = await self.decode_client.generate(disagg_request.model_dump_json())

            async for out in self._process_stream(
                decode, unpack=True, is_batch=is_batch
            ):
                yield out

            await prefill_task
        else:
            g = await self.engine.async_generate(
                input_ids=request["token_ids"]
                if not is_batch
                else request["batch_token_ids"],
                sampling_params=sampling_params,
                stream=True,
            )

            async for out in self._process_stream(g, unpack=False, is_batch=is_batch):
                yield out

    async def _process_stream(self, stream_source, unpack: bool, is_batch: bool):
        """Convert cumulative engine output into incremental token chunks.

        Args:
            stream_source: Async iterable of results.
            unpack: When true each item is unwrapped with ``.data()``.
            is_batch: When true progress is tracked per ``index`` and each
                output carries that ``index``.

        Yields:
            ``{"token_ids": [...]}`` with only the tokens added since the
            previous chunk, or ``{"token_ids": [], "finish_reason": ...}``
            when the item carries a finish reason.
        """
        # Initialize based on batch mode
        num_output_tokens_so_far: Union[Dict[int, int], int]
        if is_batch:
            num_output_tokens_so_far = {}
        else:
            num_output_tokens_so_far = 0

        async for res in stream_source:
            data = res.data() if unpack else res
            finish_reason = data["meta_info"]["finish_reason"]

            if is_batch:
                # Handle batch response
                assert isinstance(num_output_tokens_so_far, dict)
                index = data.get("index", 0)
                if index not in num_output_tokens_so_far:
                    num_output_tokens_so_far[index] = 0

                if finish_reason:
                    out = {
                        "token_ids": [],
                        "finish_reason": finish_reason["type"],
                        "index": index,
                    }
                else:
                    # output_ids is sliced from the last seen length, so only
                    # the newly added tokens are emitted.
                    next_total_toks = len(data["output_ids"])
                    new_tokens = data["output_ids"][num_output_tokens_so_far[index] :]
                    out = {
                        "token_ids": new_tokens,
                        "index": index,
                    }
                    num_output_tokens_so_far[index] = next_total_toks
            else:
                # Handle single response
                assert isinstance(num_output_tokens_so_far, int)
                if finish_reason:
                    out = {"token_ids": [], "finish_reason": finish_reason["type"]}
                else:
                    next_total_toks = len(data["output_ids"])
                    out = {"token_ids": data["output_ids"][num_output_tokens_so_far:]}
                    num_output_tokens_so_far = next_total_toks

            yield out

    async def _prefill_generator(self, prefill):
        """Drain the prefill stream, discarding every item."""
        async for _ in prefill:
            pass

    async def flush_cache(self, request: dict):
        """Schedule a cache flush on the tokenizer manager and acknowledge immediately.

        The flush runs as a background task, so the yielded status only
        confirms that it was initiated.
        """
        _ = request
        self._spawn_background(self.engine.tokenizer_manager.flush_cache())
        yield {
            "status": "success",
            "message": "Cache flush initiated. Check backend logs for status",
        }
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
    """Worker entry point: install signal handlers, parse arguments, run ``init``.

    ``--migration-limit <int>`` is consumed here and removed from the argument
    list; everything else is handed to ``parse_sglang_args_inc``.
    """
    # Set up signal handler for graceful shutdown
    loop = asyncio.get_running_loop()

    # Strong references so a scheduled shutdown task cannot be
    # garbage-collected before it finishes (the loop keeps weak refs only).
    shutdown_tasks = set()

    def signal_handler():
        # Schedule the shutdown coroutine instead of calling it directly
        task = asyncio.create_task(graceful_shutdown(runtime))
        shutdown_tasks.add(task)
        task.add_done_callback(shutdown_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    logging.info("Signal handlers set up for graceful shutdown")

    # TODO: Better handle non-sglang args
    sys_argv = sys.argv[1:]
    migration_limit = 0
    if "--migration-limit" in sys_argv:
        idx = sys_argv.index("--migration-limit")
        try:
            migration_limit = int(sys_argv[idx + 1])
        except (IndexError, ValueError):
            # Same outcome as before (default kept, arguments left in place),
            # but the malformed flag is no longer swallowed silently.
            logging.warning(
                "Ignoring --migration-limit: expected an integer value; using default of 0"
            )
        else:
            del sys_argv[idx : idx + 2]  # Remove the args from sys_argv

    server_args = parse_sglang_args_inc(sys_argv)
    await init(runtime, server_args, migration_limit)
async def init(
    runtime: DistributedRuntime, server_args: ServerArgs, migration_limit: int
):
    """Initialize worker (either prefill or aggregated)"""
    engine = sgl.Engine(server_args=server_args)

    component = runtime.namespace("dynamo").component("worker")
    await component.create_service()

    endpoint = component.endpoint("generate")
    await register_llm_with_runtime_config(
        engine, endpoint, server_args, migration_limit
    )

    # A decode client is only needed when disaggregation is enabled; passing
    # None is the same as omitting the argument.
    decode_client = None
    if server_args.disaggregation_mode != "null":
        decode_endpoint = (
            runtime.namespace("dynamo").component("decode").endpoint("generate")
        )
        decode_client = await decode_endpoint.client()
    handler = RequestHandler(engine, server_args, component, decode_client)

    # Set up the engine metrics receiver
    handler.setup_metrics()

    # Set up ZMQ kv event publisher
    if server_args.kv_events_config:
        kv_events = json.loads(server_args.kv_events_config)
        ep = kv_events.get("endpoint")
        if ep:
            # Replace the "*" wildcard with this host's IP.
            zmq_ep = ep.replace("*", get_ip())
        else:
            zmq_ep = None

        zmq_config = ZmqKvEventPublisherConfig(
            worker_id=endpoint.lease_id(),
            kv_block_size=server_args.page_size,
            zmq_endpoint=zmq_ep,
        )
        logging.info(f"Setting up ZMQ kv event publisher at {zmq_ep}")
        # Bound to a local name so the publisher stays referenced while the
        # endpoints below are being served.
        kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)

    serving = [
        endpoint.serve_endpoint(handler.generate),
        *setup_native_endpoints(server_args, component, handler),
    ]
    await asyncio.gather(*serving)
async def register_llm_with_runtime_config(
    engine: sgl.Engine,
    endpoint: Endpoint,
    server_args: ServerArgs,
    migration_limit: int,
):
    """Register LLM with runtime config

    The runtime config is fetched from the engine first (it may be ``None``).
    A registration failure is logged and not re-raised. The function returns
    ``None`` in every case.
    """
    runtime_config = await _get_runtime_config(engine)

    try:
        await register_llm(
            ModelType.Backend,
            endpoint,
            server_args.model_path,
            server_args.served_model_name,
            kv_cache_block_size=server_args.page_size,
            migration_limit=migration_limit,
            runtime_config=runtime_config,
        )
    except Exception as exc:
        # Best effort: report the failure and fall through to the implicit
        # None return.
        logging.error(f"Failed to register with runtime config: {exc}")
async def _get_runtime_config(engine: sgl.Engine) -> Optional[ModelRuntimeConfig]:
    """Get runtime config from SGLang engine

    Returns:
        A ``ModelRuntimeConfig`` when the engine exposes a non-None
        ``scheduler_info``. ``total_kv_blocks`` is filled in only when both
        ``max_total_num_tokens`` and a page size are available. Returns
        ``None`` when ``scheduler_info`` is unavailable or any error occurs.
    """
    try:
        # Guard: without scheduler_info there is nothing to derive.
        scheduler_info = getattr(engine, "scheduler_info", None)
        if scheduler_info is None:
            logging.warning(
                "Could not access runtime config from SGLang engine. "
                "The engine may compute these values internally after initialization. "
                "Proceeding without runtime config - SGLang will use its internal defaults."
            )
            return None

        runtime_config = ModelRuntimeConfig()

        # Get max_total_num_tokens from scheduler_info
        if "max_total_num_tokens" in scheduler_info:
            max_total_tokens = scheduler_info["max_total_num_tokens"]
            if max_total_tokens and hasattr(engine.tokenizer_manager, "server_args"):
                page_size = engine.tokenizer_manager.server_args.page_size
                if page_size:
                    # Ceiling division: a partially filled page still counts
                    # as one block.
                    runtime_config.total_kv_blocks = (
                        max_total_tokens + page_size - 1
                    ) // page_size
                    logging.info(
                        f"Got total KV blocks from scheduler: {runtime_config.total_kv_blocks} "
                        f"(max_total_tokens={max_total_tokens}, page_size={page_size})"
                    )

        # Note: max_running_requests and max_prefill_tokens are NOT available in scheduler_info
        # TODO: figure out where they are
        return runtime_config

    except Exception as exc:
        logging.warning(f"Failed to get runtime config: {exc}. Proceeding without it.")
        return None
def main():
    """Install uvloop, then run the ``worker`` coroutine to completion."""
    uvloop.install()
    asyncio.run(worker())
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
container/Dockerfile.sglang
View file @
80d8aa19
...
...
@@ -27,7 +27,7 @@ ARG ARCH=amd64
ARG ARCH_ALT=x86_64
# Make sure to update the dependency version in pyproject.toml when updating this
ARG SGLANG_VERSION="0.
4.9.post6
"
ARG SGLANG_VERSION="0.
5.0rc2
"
##################################
########## Base Image ############
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment