Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
2712426f
Unverified
Commit
2712426f
authored
Mar 19, 2026
by
jh-nv
Committed by
GitHub
Mar 19, 2026
Browse files
feat: enable mypy in pre-merge (#6732)
parent
e5e118a1
Changes
52
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
142 additions
and
72 deletions
+142
-72
components/src/dynamo/planner/utils/disagg_planner.py
components/src/dynamo/planner/utils/disagg_planner.py
+3
-7
components/src/dynamo/planner/utils/dryrun.py
components/src/dynamo/planner/utils/dryrun.py
+9
-0
components/src/dynamo/planner/utils/load_predictor.py
components/src/dynamo/planner/utils/load_predictor.py
+2
-2
components/src/dynamo/planner/utils/perf_interpolation.py
components/src/dynamo/planner/utils/perf_interpolation.py
+2
-0
components/src/dynamo/planner/utils/planner_core.py
components/src/dynamo/planner/utils/planner_core.py
+33
-14
components/src/dynamo/planner/utils/prefill_planner.py
components/src/dynamo/planner/utils/prefill_planner.py
+2
-0
components/src/dynamo/sglang/backend_args.py
components/src/dynamo/sglang/backend_args.py
+5
-4
components/src/dynamo/sglang/request_handlers/handler_base.py
...onents/src/dynamo/sglang/request_handlers/handler_base.py
+25
-17
components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
.../src/dynamo/sglang/request_handlers/llm/decode_handler.py
+2
-2
components/src/dynamo/sglang/request_handlers/llm/diffusion_handler.py
...c/dynamo/sglang/request_handlers/llm/diffusion_handler.py
+1
-1
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
...lang/request_handlers/multimodal/encode_worker_handler.py
+13
-5
components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py
...o/sglang/request_handlers/multimodal/processor_handler.py
+7
-4
components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py
...namo/sglang/request_handlers/multimodal/worker_handler.py
+4
-2
components/src/dynamo/sglang/request_handlers/video_generation/video_generation_handler.py
...est_handlers/video_generation/video_generation_handler.py
+6
-0
components/src/dynamo/trtllm/logits_processing/adapter.py
components/src/dynamo/trtllm/logits_processing/adapter.py
+2
-2
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
.../src/dynamo/trtllm/request_handlers/aggregated_handler.py
+5
-3
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+6
-6
components/src/dynamo/trtllm/request_handlers/handlers.py
components/src/dynamo/trtllm/request_handlers/handlers.py
+4
-0
components/src/dynamo/trtllm/workers/video_diffusion_worker.py
...nents/src/dynamo/trtllm/workers/video_diffusion_worker.py
+3
-0
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+8
-3
No files found.
components/src/dynamo/planner/utils/disagg_planner.py
View file @
2712426f
...
...
@@ -4,7 +4,6 @@
import
asyncio
import
logging
import
time
from
typing
import
Optional
from
dynamo.planner
import
SubComponentType
,
TargetReplica
from
dynamo.planner.utils.decode_planner
import
DecodePlanner
...
...
@@ -24,9 +23,7 @@ logger = logging.getLogger(__name__)
class
DisaggPlanner
:
def
__init__
(
self
,
runtime
:
Optional
[
DistributedRuntime
],
config
:
PlannerConfig
)
->
None
:
def
__init__
(
self
,
runtime
:
DistributedRuntime
,
config
:
PlannerConfig
)
->
None
:
self
.
config
=
config
self
.
shared_state
=
PlannerSharedState
()
prometheus_metrics
=
PlannerPrometheusMetrics
()
...
...
@@ -89,13 +86,12 @@ class DisaggPlanner:
logger
.
info
(
f
"Detected model name from deployment:
{
model_name
}
"
)
model_name
=
model_name
.
lower
()
else
:
model_name
=
getattr
(
self
.
config
,
"model_name"
,
None
)
if
not
model_name
:
if
not
self
.
config
.
model_name
:
raise
ValueError
(
"Model name is required in no-operation mode. "
"Please set model_name in the config."
)
model_name
=
model_name
.
lower
()
model_name
=
self
.
config
.
model_name
.
lower
()
self
.
prefill_planner
.
model_name
=
model_name
self
.
decode_planner
.
model_name
=
model_name
...
...
components/src/dynamo/planner/utils/dryrun.py
View file @
2712426f
...
...
@@ -127,6 +127,13 @@ def run_sla_planner_dryrun(
time_series
.
append
(
time_series
[
-
1
]
+
interval
)
_est_rr
,
_est_isl
,
_est_osl
=
predictor_planner
.
predict_load
()
# predict_load() returns Optional[float] values; in dryrun mode with
# pre-loaded data the predictors always return valid floats.
assert
(
_est_rr
is
not
None
and
_est_isl
is
not
None
and
_est_osl
is
not
None
),
"predict_load() returned None in dryrun mode"
est_rr
.
append
(
_est_rr
)
est_isl
.
append
(
_est_isl
)
est_osl
.
append
(
_est_osl
)
...
...
@@ -145,10 +152,12 @@ def run_sla_planner_dryrun(
if
prefill_planner
is
not
None
and
decode_planner
is
not
None
:
_num_p
,
_num_d
=
_apply_global_gpu_budget
(
_num_p
,
_num_d
,
config
)
elif
prefill_planner
is
not
None
:
assert
config
.
prefill_engine_num_gpu
is
not
None
_num_p
=
_apply_component_gpu_budget
(
_num_p
,
config
.
prefill_engine_num_gpu
,
config
)
elif
decode_planner
is
not
None
:
assert
config
.
decode_engine_num_gpu
is
not
None
_num_d
=
_apply_component_gpu_budget
(
_num_d
,
config
.
decode_engine_num_gpu
,
config
)
...
...
components/src/dynamo/planner/utils/load_predictor.py
View file @
2712426f
...
...
@@ -19,7 +19,7 @@ import warnings
from
abc
import
ABC
,
abstractmethod
from
datetime
import
datetime
,
timedelta
from
enum
import
Enum
from
typing
import
Any
from
typing
import
Any
,
Callable
import
numpy
as
np
import
pandas
as
pd
...
...
@@ -389,7 +389,7 @@ class KalmanPredictor(BasePredictor):
)
LOAD_PREDICTORS
=
{
LOAD_PREDICTORS
:
dict
[
str
,
Callable
[[
PlannerConfig
],
BasePredictor
]]
=
{
"constant"
:
ConstantPredictor
,
"arima"
:
ARIMAPredictor
,
"kalman"
:
KalmanPredictor
,
...
...
components/src/dynamo/planner/utils/perf_interpolation.py
View file @
2712426f
...
...
@@ -151,6 +151,8 @@ class DecodeInterpolator:
self
.
resolution
=
resolution
self
.
xi
=
np
.
linspace
(
0
,
1
,
resolution
)
self
.
yi
=
np
.
linspace
(
0
,
max
(
self
.
y_context_length
),
resolution
)
self
.
X
:
np
.
ndarray
self
.
Y
:
np
.
ndarray
self
.
X
,
self
.
Y
=
np
.
meshgrid
(
self
.
xi
,
self
.
yi
)
# Lazy import scipy only when interpolation is actually needed
...
...
components/src/dynamo/planner/utils/planner_core.py
View file @
2712426f
...
...
@@ -6,7 +6,7 @@ import logging
import
math
import
time
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Optional
,
Union
from
prometheus_client
import
Gauge
,
start_http_server
...
...
@@ -36,6 +36,9 @@ from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_moonc
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime.logging
import
configure_dynamo_logging
# Union of all connector types used by the planner
ConnectorType
=
Union
[
GlobalPlannerConnector
,
KubernetesConnector
,
VirtualConnector
]
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -248,14 +251,14 @@ class BasePlanner:
def
__init__
(
self
,
runtime
:
DistributedRuntime
,
runtime
:
Optional
[
DistributedRuntime
]
,
config
:
PlannerConfig
,
dryrun
:
bool
=
False
,
shared_state
:
Optional
[
PlannerSharedState
]
=
None
,
prometheus_metrics
:
Optional
[
PlannerPrometheusMetrics
]
=
None
,
prometheus_traffic_client
:
Optional
[
PrometheusAPIClient
]
=
None
,
prometheus_engine_client
:
Optional
[
DirectRouterMetricsClient
]
=
None
,
connector
=
None
,
connector
:
Optional
[
ConnectorType
]
=
None
,
start_prometheus_server
:
bool
=
True
,
component_type
:
Optional
[
SubComponentType
]
=
None
,
):
...
...
@@ -272,11 +275,13 @@ class BasePlanner:
if
not
self
.
dryrun
:
self
.
runtime
=
runtime
self
.
namespace
=
config
.
namespace
self
.
connector
:
ConnectorType
if
not
config
.
no_operation
:
# Initialize connector based on environment
if
config
.
environment
==
"global-planner"
:
assert
config
.
global_planner_namespace
is
not
None
assert
runtime
is
not
None
self
.
connector
=
GlobalPlannerConnector
(
runtime
,
self
.
namespace
,
...
...
@@ -289,6 +294,7 @@ class BasePlanner:
self
.
namespace
,
self
.
model_name
)
elif
config
.
environment
==
"virtual"
:
assert
runtime
is
not
None
self
.
connector
=
VirtualConnector
(
runtime
,
self
.
namespace
,
...
...
@@ -430,11 +436,12 @@ class BasePlanner:
self
.
prometheus_engine_client
=
prometheus_engine_client
else
:
# Auto-discover frontend metrics URL in Kubernetes mode
connector
=
getattr
(
self
,
"connector"
,
None
)
if
not
config
.
load_router_metrics_url
and
isinstance
(
getattr
(
self
,
"
connector
"
,
None
)
,
KubernetesConnector
connector
,
KubernetesConnector
):
config
.
load_router_metrics_url
=
(
self
.
connector
.
get_frontend_metrics_url
()
connector
.
get_frontend_metrics_url
()
)
if
not
config
.
load_router_metrics_url
:
raise
ValueError
(
...
...
@@ -447,6 +454,9 @@ class BasePlanner:
f
"Auto-discovered frontend metrics URL:
{
config
.
load_router_metrics_url
}
"
)
assert
(
config
.
load_router_metrics_url
is
not
None
),
"load_router_metrics_url must be set when load-based scaling is enabled"
self
.
prometheus_engine_client
=
DirectRouterMetricsClient
(
config
.
load_router_metrics_url
,
config
.
namespace
)
...
...
@@ -494,6 +504,7 @@ class BasePlanner:
async
def
_get_or_create_client
(
self
,
component_name
:
str
,
endpoint_name
:
str
):
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
assert
self
.
runtime
is
not
None
,
"Runtime is not initialized"
client
=
await
self
.
runtime
.
endpoint
(
f
"
{
self
.
namespace
}
.
{
component_name
}
.
{
endpoint_name
}
"
).
client
()
...
...
@@ -604,41 +615,46 @@ class BasePlanner:
)
# Prometheus returns seconds, convert to milliseconds
assert
(
self
.
model_name
is
not
None
),
"model_name must be set before observing traffic stats"
interval_str
=
f
"
{
self
.
config
.
throughput_adjustment_interval
}
s"
self
.
last_metrics
.
ttft
=
(
self
.
prometheus_traffic_client
.
get_avg_time_to_first_token
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
)
*
1000
)
self
.
last_metrics
.
itl
=
(
self
.
prometheus_traffic_client
.
get_avg_inter_token_latency
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
)
*
1000
)
self
.
last_metrics
.
num_req
=
(
self
.
prometheus_traffic_client
.
get_avg_request_count
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
)
)
self
.
last_metrics
.
request_duration
=
(
self
.
prometheus_traffic_client
.
get_avg_request_duration
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
)
)
self
.
last_metrics
.
isl
=
(
self
.
prometheus_traffic_client
.
get_avg_input_sequence_tokens
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
)
)
self
.
last_metrics
.
osl
=
(
self
.
prometheus_traffic_client
.
get_avg_output_sequence_tokens
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
)
)
...
...
@@ -666,9 +682,12 @@ class BasePlanner:
self
.
update_predictors_from_metrics
(
self
.
last_metrics
)
def
update_predictors_from_metrics
(
self
,
metrics
:
Metrics
)
->
None
:
self
.
num_req_predictor
.
add_data_point
(
metrics
.
num_req
)
self
.
isl_predictor
.
add_data_point
(
metrics
.
isl
)
self
.
osl_predictor
.
add_data_point
(
metrics
.
osl
)
if
metrics
.
num_req
is
not
None
:
self
.
num_req_predictor
.
add_data_point
(
metrics
.
num_req
)
if
metrics
.
isl
is
not
None
:
self
.
isl_predictor
.
add_data_point
(
metrics
.
isl
)
if
metrics
.
osl
is
not
None
:
self
.
osl_predictor
.
add_data_point
(
metrics
.
osl
)
def
predict_load
(
self
)
->
tuple
[
Optional
[
float
],
Optional
[
float
],
Optional
[
float
]]:
try
:
...
...
components/src/dynamo/planner/utils/prefill_planner.py
View file @
2712426f
...
...
@@ -94,6 +94,7 @@ class PrefillPlanner(BasePlanner):
return
None
def
_update_correction_factor
(
self
)
->
bool
:
assert
self
.
last_metrics
.
isl
is
not
None
and
self
.
last_metrics
.
ttft
is
not
None
expect_ttft
=
self
.
prefill_interpolator
.
interpolate_ttft
(
self
.
last_metrics
.
isl
)
self
.
p_correction_factor
=
self
.
last_metrics
.
ttft
/
expect_ttft
logger
.
info
(
f
"Correction factor (prefill TTFT):
{
self
.
p_correction_factor
:.
3
f
}
"
)
...
...
@@ -117,6 +118,7 @@ class PrefillPlanner(BasePlanner):
"(no throughput satisfies TTFT target), falling back to min_endpoint"
)
return
self
.
config
.
min_endpoint
assert
self
.
config
.
prefill_engine_num_gpu
is
not
None
next_num_p
=
math
.
ceil
(
pred_prefill_throughput
/
p_thpt_per_gpu
...
...
components/src/dynamo/sglang/backend_args.py
View file @
2712426f
...
...
@@ -4,7 +4,7 @@
"""Dynamo SGLang wrapper configuration ArgGroup."""
import
argparse
from
typing
import
Optional
,
Union
from
typing
import
Optional
from
dynamo.common.configuration.arg_group
import
ArgGroup
from
dynamo.common.configuration.config_base
import
ConfigBase
...
...
@@ -117,7 +117,7 @@ class DynamoSGLangConfig(ConfigBase):
multimodal_processor
:
bool
multimodal_encode_worker
:
bool
multimodal_worker
:
bool
embedding_transfer_mode
:
Union
[
str
,
EmbeddingTransferMode
]
embedding_transfer_mode
:
EmbeddingTransferMode
embedding_worker
:
bool
image_diffusion_worker
:
bool
...
...
@@ -127,10 +127,11 @@ class DynamoSGLangConfig(ConfigBase):
video_generation_worker
:
bool
def
validate
(
self
)
->
None
:
if
isinstance
(
self
.
embedding_transfer_mode
,
str
):
if
not
isinstance
(
self
.
embedding_transfer_mode
,
EmbeddingTransferMode
):
self
.
embedding_transfer_mode
=
EmbeddingTransferMode
(
self
.
embedding_transfer_mode
str
(
self
.
embedding_transfer_mode
)
)
if
(
self
.
disagg_config
is
not
None
)
^
(
self
.
disagg_config_key
is
not
None
):
raise
ValueError
(
"Both 'disagg_config' and 'disagg_config_key' must be provided together."
...
...
components/src/dynamo/sglang/request_handlers/handler_base.py
View file @
2712426f
...
...
@@ -8,13 +8,23 @@ import random
import
socket
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
asynccontextmanager
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Optional
,
Tuple
from
typing
import
(
Any
,
AsyncGenerator
,
AsyncIterator
,
Dict
,
Generic
,
Optional
,
Tuple
,
TypeVar
,
)
import
sglang
as
sgl
from
sglang.srt.utils
import
get_local_ip_auto
from
dynamo._core
import
Context
from
dynamo.common.utils.input_params
import
InputParamManager
from
dynamo.llm
import
KvEventPublisher
,
WorkerMetricsPublisher
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.sglang.args
import
Config
from
dynamo.sglang.publisher
import
DynamoSglangPublisher
...
...
@@ -72,7 +82,11 @@ class SGLangEngineQuiesceController:
self
.
_is_quiesced
=
False
class
BaseGenerativeHandler
(
ABC
):
RequestT
=
TypeVar
(
"RequestT"
)
ResponseT
=
TypeVar
(
"ResponseT"
)
class
BaseGenerativeHandler
(
ABC
,
Generic
[
RequestT
,
ResponseT
]):
"""Minimal base class for all generative handlers (LLM, diffusion, etc.).
Provides common infrastructure for:
...
...
@@ -95,27 +109,24 @@ class BaseGenerativeHandler(ABC):
self
.
config
=
config
# Set up metrics and KV publishers
self
.
metrics_publisher
:
Optional
[
WorkerMetricsPublisher
]
=
None
self
.
kv_publisher
:
Optional
[
KvEventPublisher
]
=
None
if
publisher
is
not
None
:
self
.
metrics_publisher
=
publisher
.
metrics_publisher
self
.
kv_publisher
=
publisher
.
kv_publisher
else
:
self
.
metrics_publisher
=
None
self
.
kv_publisher
=
None
@
abstractmethod
async
def
generate
(
self
,
request
:
Dict
[
str
,
Any
],
context
:
Context
)
->
AsyncGenerator
[
Dict
[
str
,
Any
],
None
]:
def
generate
(
self
,
request
:
RequestT
,
context
:
Context
)
->
AsyncIterator
[
ResponseT
]:
"""Generate response from request.
Args:
request: Request
dict
with input and parameters.
request: Request with input and parameters.
context: Context object for cancellation handling.
Yields:
Response data (format varies by handler implementation).
"""
pass
...
def
cleanup
(
self
)
->
None
:
"""Cleanup resources. Override in subclasses as needed."""
...
...
@@ -137,7 +148,7 @@ class BaseGenerativeHandler(ABC):
return
{
"traceparent"
:
f
"00-
{
trace_id
}
-
{
span_id
}
-01"
}
class
BaseWorkerHandler
(
BaseGenerativeHandler
):
class
BaseWorkerHandler
(
BaseGenerativeHandler
[
RequestT
,
ResponseT
]
):
"""Abstract base class for SGLang LLM worker handlers.
Extends BaseGenerativeHandler with LLM-specific functionality:
...
...
@@ -175,9 +186,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
if
publisher
is
not
None
:
self
.
metrics_publisher
=
publisher
.
metrics_publisher
self
.
kv_publisher
=
publisher
.
kv_publisher
else
:
self
.
metrics_publisher
=
None
self
.
kv_publisher
=
None
self
.
serving_mode
=
config
.
serving_mode
self
.
skip_tokenizer_init
=
config
.
server_args
.
skip_tokenizer_init
self
.
enable_trace
=
config
.
server_args
.
enable_trace
...
...
@@ -454,17 +462,17 @@ class BaseWorkerHandler(BaseGenerativeHandler):
)
@
abstractmethod
async
def
generate
(
self
,
request
:
Dict
[
str
,
Any
]
,
context
:
Context
):
def
generate
(
self
,
request
:
RequestT
,
context
:
Context
)
->
AsyncIterator
[
ResponseT
]
:
"""Generate response from request.
Args:
request: Request
dict
with input and parameters.
request: Request with input and parameters.
context: Context object for cancellation handling.
Yields:
Response data (format varies by handler implementation).
"""
pass
...
def
cleanup
(
self
)
->
None
:
"""Cleanup resources. Override in subclasses as needed."""
...
...
components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
View file @
2712426f
...
...
@@ -24,7 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self
,
engine
:
sgl
.
Engine
,
config
:
Config
,
publisher
:
DynamoSglangPublisher
,
publisher
:
Optional
[
DynamoSglangPublisher
]
=
None
,
generate_endpoint
=
None
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
)
->
None
:
...
...
@@ -230,7 +230,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# This lets SGLang proceed to the second token generation, which will
# async context switch and allow the abort monitor to signal cancellation.
# The loop should exit by itself when context.is_stopped() returns True.
out
=
{}
out
:
dict
[
str
,
Any
]
=
{}
finish_reason
=
res
[
"meta_info"
][
"finish_reason"
]
if
finish_reason
:
out
[
"finish_reason"
]
=
normalize_finish_reason
(
...
...
components/src/dynamo/sglang/request_handlers/llm/diffusion_handler.py
View file @
2712426f
...
...
@@ -21,7 +21,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
self
,
engine
:
sgl
.
Engine
,
config
:
Config
,
publisher
:
DynamoSglangPublisher
=
None
,
publisher
:
Optional
[
DynamoSglangPublisher
]
=
None
,
generate_endpoint
=
None
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
)
->
None
:
...
...
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
View file @
2712426f
...
...
@@ -38,7 +38,7 @@ except ImportError as e:
DEVICE
=
"cpu"
class
MultimodalEncodeWorkerHandler
(
BaseWorkerHandler
):
class
MultimodalEncodeWorkerHandler
(
BaseWorkerHandler
[
SglangMultimodalRequest
,
str
]
):
"""
Handler for multimodal encode worker component that processes images/videos
and forwards them to the downstream worker.
...
...
@@ -84,12 +84,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
if
image_token_str
==
"<|vision_start|><|image_pad|><|vision_end|>"
:
# These are likely the individual special tokens for Qwen2.5-VL
image_pad_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
"<|image_pad|>"
)
assert
isinstance
(
image_pad_id
,
int
),
f
"Expected int token id, got
{
type
(
image_pad_id
)
}
"
# Use the image_pad token as the main image token
self
.
image_token_id
=
image_pad_id
self
.
image_token_id
:
int
=
image_pad_id
else
:
# Fallback for other models
self
.
image_token_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
image_token_str
)
token_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
image_token_str
)
assert
isinstance
(
token_id
,
int
),
f
"Expected int token id, got
{
type
(
token_id
)
}
"
self
.
image_token_id
=
token_id
self
.
min_workers
=
1
...
...
@@ -230,10 +237,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
zip
(
multimodal_groups
,
image_grid_thw_list
)
):
mm_group
.
image_grid_thw
=
image_grid_thw
mm_group
.
multimodal_input
.
image_url
=
None
if
mm_group
.
multimodal_input
is
not
None
:
mm_group
.
multimodal_input
.
image_url
=
None
# Store shared tensor transfer metadata at request level.
request
.
embeddings_shape
=
tuple
(
precomputed_embeddings
.
shape
)
request
.
embeddings_shape
=
tuple
(
precomputed_embeddings
.
shape
)
# type: ignore[assignment]
request
.
transfer_payload
=
None
search_start
=
0
...
...
components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py
View file @
2712426f
...
...
@@ -6,7 +6,7 @@ import json
import
logging
import
time
import
uuid
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Optional
from
transformers
import
AutoTokenizer
...
...
@@ -20,6 +20,7 @@ from dynamo.sglang.protocol import (
MultiModalGroup
,
MultiModalInput
,
MultiModalRequest
,
PreprocessedRequest
,
SglangMultimodalRequest
,
)
from
dynamo.sglang.request_handlers.handler_base
import
BaseWorkerHandler
...
...
@@ -27,7 +28,7 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
logger
=
logging
.
getLogger
(
__name__
)
class
MultimodalProcessorHandler
(
BaseWorkerHandler
):
class
MultimodalProcessorHandler
(
BaseWorkerHandler
[
MultiModalRequest
,
Dict
[
str
,
Any
]]
):
"""
Handler for multimodal processor component that processes multimodal requests
and forwards them to the encode worker.
...
...
@@ -56,7 +57,9 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
def
cleanup
(
self
):
pass
async
def
generate
(
self
,
raw_request
:
MultiModalRequest
,
context
:
Context
):
async
def
generate
(
self
,
raw_request
:
MultiModalRequest
,
context
:
Context
)
->
AsyncGenerator
[
Dict
[
str
,
Any
],
None
]:
"""
Process multimodal request and forward to encode worker.
...
...
@@ -119,7 +122,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
)
worker_request
=
SglangMultimodalRequest
(
request
=
sglang_request
,
request
=
PreprocessedRequest
(
**
sglang_request
)
,
multimodal_inputs
=
multimodal_groups
,
)
...
...
components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py
View file @
2712426f
...
...
@@ -256,7 +256,7 @@ async def _build_mm_items(
return
mm_items
,
embeddings
,
tensor_id
class
MultimodalWorkerHandler
(
BaseWorkerHandler
):
class
MultimodalWorkerHandler
(
BaseWorkerHandler
[
SglangMultimodalRequest
,
str
]
):
"""
Multimodal worker handler for LLM inference with multimodal data.
Handles both aggregated and disaggregated modes.
...
...
@@ -490,7 +490,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
logger
.
info
(
"Multimodal worker engine shutdown"
)
class
MultimodalPrefillWorkerHandler
(
BaseWorkerHandler
):
class
MultimodalPrefillWorkerHandler
(
BaseWorkerHandler
[
DisaggSglangMultimodalRequest
,
str
]
):
"""
Multimodal prefill worker handler for disaggregated inference
Processes multimodal inputs and coordinates with decode worker.
...
...
components/src/dynamo/sglang/request_handlers/video_generation/video_generation_handler.py
View file @
2712426f
...
...
@@ -103,16 +103,22 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
)
# Parse size
assert
req
.
size
is
not
None
,
"Size is required"
width
,
height
=
self
.
_parse_size
(
req
.
size
)
# Calculate num_frames if not explicitly provided
num_frames
=
nvext
.
num_frames
assert
nvext
.
fps
is
not
None
,
"FPS is required"
if
num_frames
is
None
:
assert
req
.
seconds
is
not
None
,
"Seconds is required"
num_frames
=
nvext
.
fps
*
req
.
seconds
# Generate video
context_id
=
context
.
id
()
assert
context_id
is
not
None
assert
(
nvext
.
num_inference_steps
is
not
None
),
"Num inference steps is required"
video_bytes
=
await
self
.
_generate_video
(
prompt
=
req
.
prompt
,
width
=
width
,
...
...
components/src/dynamo/trtllm/logits_processing/adapter.py
View file @
2712426f
...
...
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import
logging
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Sequence
import
torch
from
tensorrt_llm.sampling_params
import
LogitsProcessor
...
...
@@ -70,7 +70,7 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
def
create_trtllm_adapters
(
processors
:
List
[
BaseLogitsProcessor
],
processors
:
Sequence
[
BaseLogitsProcessor
],
)
->
List
[
TrtllmDynamoLogitsAdapter
]:
"""
Create TensorRT-LLM compatible adapters from Dynamo logits processors.
...
...
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
View file @
2712426f
...
...
@@ -5,7 +5,9 @@
import
logging
from
collections.abc
import
AsyncGenerator
from
typing
import
Optional
from
typing
import
Optional
,
Union
import
torch
from
dynamo._core
import
Context
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
...
...
@@ -40,7 +42,7 @@ class AggregatedHandler(HandlerBase):
"""Generate response, optionally using remote encoder for multimodal."""
logging
.
debug
(
f
"AggregatedHandler Request ID:
{
context
.
id
()
}
"
)
embeddings
=
None
embeddings
:
Optional
[
Union
[
torch
.
Tensor
,
dict
]]
=
None
ep_disaggregated_params
=
None
if
self
.
multimodal_processor
and
self
.
encode_client
:
messages
=
request
.
get
(
"extra_args"
,
{}).
get
(
...
...
@@ -57,7 +59,7 @@ class AggregatedHandler(HandlerBase):
self
.
_encoder_cache
,
)
if
isinstance
(
result
,
list
):
embeddings
=
result
embeddings
=
result
# type: ignore[assignment]
else
:
ep_disaggregated_params
=
result
...
...
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
2712426f
...
...
@@ -31,7 +31,7 @@ from tensorrt_llm.llmapi.llm import SamplingParams
from
tensorrt_llm.sampling_params
import
GuidedDecodingParams
from
tensorrt_llm.scheduling_params
import
SchedulingParams
from
dynamo._core
import
Context
from
dynamo._core
import
Client
,
Context
from
dynamo.common.utils.otel_tracing
import
build_trace_headers
from
dynamo.logits_processing.examples
import
HelloWorldLogitsProcessor
from
dynamo.nixl_connect
import
Connector
...
...
@@ -65,9 +65,9 @@ class RequestHandlerConfig:
engine
:
TensorRTLLMEngine
default_sampling_params
:
SamplingParams
publisher
:
Publisher
publisher
:
Optional
[
Publisher
]
disaggregation_mode
:
DisaggregationMode
encode_client
:
Optional
[
objec
t
]
=
None
encode_client
:
Optional
[
Clien
t
]
=
None
multimodal_processor
:
Optional
[
MultimodalRequestProcessor
]
=
None
# for multimodal support
...
...
@@ -558,11 +558,11 @@ class HandlerBase(BaseGenerativeHandler):
# PREFILL/ENCODE/AGGREGATED: Process multimodal content if available
if
self
.
multimodal_processor
:
processed_inpu
t
=
await
self
.
multimodal_processor
.
process_openai_request
(
mm_resul
t
=
await
self
.
multimodal_processor
.
process_openai_request
(
request
,
embeddings
,
ep_disaggregated_params
)
if
processed_inpu
t
:
return
processed_inpu
t
if
mm_resul
t
:
return
mm_resul
t
# If multimodal processing returned None but request has multimodal data,
# this is an error (not a text-only request). Raise instead of falling back.
...
...
components/src/dynamo/trtllm/request_handlers/handlers.py
View file @
2712426f
...
...
@@ -111,6 +111,8 @@ class PrefillHandler(HandlerBase):
Encoder's embeddings tensor to be used by the prefill worker
"""
# Get response with shape info and readable metadata
if
self
.
encode_client
is
None
:
raise
RuntimeError
(
"Encode client is not configured."
)
encode_response
=
None
async
for
res
in
await
self
.
encode_client
.
round_robin
(
request
):
encode_response
=
res
.
data
()
...
...
@@ -119,6 +121,8 @@ class PrefillHandler(HandlerBase):
if
not
encode_response
:
raise
RuntimeError
(
"Did not receive a response from the encode worker."
)
if
self
.
connector
is
None
:
raise
RuntimeError
(
"Connector is not configured."
)
# Use utility function to handle NIXL reading and reconstruction
return
await
EncodeHelper
.
read_embeddings_from_encode_response
(
encode_response
,
self
.
connector
...
...
components/src/dynamo/trtllm/workers/video_diffusion_worker.py
View file @
2712426f
...
...
@@ -61,6 +61,9 @@ async def init_video_diffusion_worker(
else
[]
)
if
not
config
.
endpoint
:
raise
ValueError
(
"endpoint must be configured for video diffusion worker"
)
# Build DiffusionConfig from the main Config
diffusion_config
=
DiffusionConfig
(
namespace
=
config
.
namespace
,
...
...
components/src/dynamo/vllm/handlers.py
View file @
2712426f
...
...
@@ -13,7 +13,7 @@ import time
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
asynccontextmanager
from
dataclasses
import
dataclass
from
typing
import
Any
,
Async
Gen
erator
,
Dict
,
Final
from
typing
import
Any
,
Async
It
erator
,
Dict
,
Final
,
Generic
,
TypeVar
import
torch
from
vllm.config
import
VllmConfig
...
...
@@ -23,6 +23,7 @@ from vllm.outputs import RequestOutput
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
dynamo._core
import
Context
from
dynamo.common.multimodal.image_loader
import
ImageLoader
from
dynamo.common.utils.engine_response
import
normalize_finish_reason
from
dynamo.common.utils.input_params
import
InputParamManager
...
...
@@ -325,7 +326,11 @@ def get_dp_range_for_worker(vllm_config: VllmConfig) -> tuple[int, int]:
)
class
BaseWorkerHandler
(
ABC
):
RequestT
=
TypeVar
(
"RequestT"
)
ResponseT
=
TypeVar
(
"ResponseT"
)
class
BaseWorkerHandler
(
ABC
,
Generic
[
RequestT
,
ResponseT
]):
"""
Request handler for the generate and clear_kv_blocks endpoints.
"""
...
...
@@ -459,7 +464,7 @@ class BaseWorkerHandler(ABC):
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
@
abstractmethod
async
def
generate
(
self
,
request
,
c
ontext
)
->
Async
Gen
erator
[
dict
,
None
]:
def
generate
(
self
,
request
:
RequestT
,
context
:
C
ontext
)
->
Async
It
erator
[
ResponseT
]:
raise
NotImplementedError
async
def
_monitor_abort
(
self
,
context
,
request_id
,
is_prefill
):
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment