Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
2712426f
Unverified
Commit
2712426f
authored
Mar 19, 2026
by
jh-nv
Committed by
GitHub
Mar 19, 2026
Browse files
feat: enable mypy in pre-merge (#6732)
parent
e5e118a1
Changes
52
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
142 additions
and
72 deletions
+142
-72
components/src/dynamo/planner/utils/disagg_planner.py
components/src/dynamo/planner/utils/disagg_planner.py
+3
-7
components/src/dynamo/planner/utils/dryrun.py
components/src/dynamo/planner/utils/dryrun.py
+9
-0
components/src/dynamo/planner/utils/load_predictor.py
components/src/dynamo/planner/utils/load_predictor.py
+2
-2
components/src/dynamo/planner/utils/perf_interpolation.py
components/src/dynamo/planner/utils/perf_interpolation.py
+2
-0
components/src/dynamo/planner/utils/planner_core.py
components/src/dynamo/planner/utils/planner_core.py
+33
-14
components/src/dynamo/planner/utils/prefill_planner.py
components/src/dynamo/planner/utils/prefill_planner.py
+2
-0
components/src/dynamo/sglang/backend_args.py
components/src/dynamo/sglang/backend_args.py
+5
-4
components/src/dynamo/sglang/request_handlers/handler_base.py
...onents/src/dynamo/sglang/request_handlers/handler_base.py
+25
-17
components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
.../src/dynamo/sglang/request_handlers/llm/decode_handler.py
+2
-2
components/src/dynamo/sglang/request_handlers/llm/diffusion_handler.py
...c/dynamo/sglang/request_handlers/llm/diffusion_handler.py
+1
-1
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
...lang/request_handlers/multimodal/encode_worker_handler.py
+13
-5
components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py
...o/sglang/request_handlers/multimodal/processor_handler.py
+7
-4
components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py
...namo/sglang/request_handlers/multimodal/worker_handler.py
+4
-2
components/src/dynamo/sglang/request_handlers/video_generation/video_generation_handler.py
...est_handlers/video_generation/video_generation_handler.py
+6
-0
components/src/dynamo/trtllm/logits_processing/adapter.py
components/src/dynamo/trtllm/logits_processing/adapter.py
+2
-2
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
.../src/dynamo/trtllm/request_handlers/aggregated_handler.py
+5
-3
components/src/dynamo/trtllm/request_handlers/handler_base.py
...onents/src/dynamo/trtllm/request_handlers/handler_base.py
+6
-6
components/src/dynamo/trtllm/request_handlers/handlers.py
components/src/dynamo/trtllm/request_handlers/handlers.py
+4
-0
components/src/dynamo/trtllm/workers/video_diffusion_worker.py
...nents/src/dynamo/trtllm/workers/video_diffusion_worker.py
+3
-0
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+8
-3
No files found.
components/src/dynamo/planner/utils/disagg_planner.py
View file @
2712426f
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
import
asyncio
import
asyncio
import
logging
import
logging
import
time
import
time
from
typing
import
Optional
from
dynamo.planner
import
SubComponentType
,
TargetReplica
from
dynamo.planner
import
SubComponentType
,
TargetReplica
from
dynamo.planner.utils.decode_planner
import
DecodePlanner
from
dynamo.planner.utils.decode_planner
import
DecodePlanner
...
@@ -24,9 +23,7 @@ logger = logging.getLogger(__name__)
...
@@ -24,9 +23,7 @@ logger = logging.getLogger(__name__)
class
DisaggPlanner
:
class
DisaggPlanner
:
def
__init__
(
def
__init__
(
self
,
runtime
:
DistributedRuntime
,
config
:
PlannerConfig
)
->
None
:
self
,
runtime
:
Optional
[
DistributedRuntime
],
config
:
PlannerConfig
)
->
None
:
self
.
config
=
config
self
.
config
=
config
self
.
shared_state
=
PlannerSharedState
()
self
.
shared_state
=
PlannerSharedState
()
prometheus_metrics
=
PlannerPrometheusMetrics
()
prometheus_metrics
=
PlannerPrometheusMetrics
()
...
@@ -89,13 +86,12 @@ class DisaggPlanner:
...
@@ -89,13 +86,12 @@ class DisaggPlanner:
logger
.
info
(
f
"Detected model name from deployment:
{
model_name
}
"
)
logger
.
info
(
f
"Detected model name from deployment:
{
model_name
}
"
)
model_name
=
model_name
.
lower
()
model_name
=
model_name
.
lower
()
else
:
else
:
model_name
=
getattr
(
self
.
config
,
"model_name"
,
None
)
if
not
self
.
config
.
model_name
:
if
not
model_name
:
raise
ValueError
(
raise
ValueError
(
"Model name is required in no-operation mode. "
"Model name is required in no-operation mode. "
"Please set model_name in the config."
"Please set model_name in the config."
)
)
model_name
=
model_name
.
lower
()
model_name
=
self
.
config
.
model_name
.
lower
()
self
.
prefill_planner
.
model_name
=
model_name
self
.
prefill_planner
.
model_name
=
model_name
self
.
decode_planner
.
model_name
=
model_name
self
.
decode_planner
.
model_name
=
model_name
...
...
components/src/dynamo/planner/utils/dryrun.py
View file @
2712426f
...
@@ -127,6 +127,13 @@ def run_sla_planner_dryrun(
...
@@ -127,6 +127,13 @@ def run_sla_planner_dryrun(
time_series
.
append
(
time_series
[
-
1
]
+
interval
)
time_series
.
append
(
time_series
[
-
1
]
+
interval
)
_est_rr
,
_est_isl
,
_est_osl
=
predictor_planner
.
predict_load
()
_est_rr
,
_est_isl
,
_est_osl
=
predictor_planner
.
predict_load
()
# predict_load() returns Optional[float] values; in dryrun mode with
# pre-loaded data the predictors always return valid floats.
assert
(
_est_rr
is
not
None
and
_est_isl
is
not
None
and
_est_osl
is
not
None
),
"predict_load() returned None in dryrun mode"
est_rr
.
append
(
_est_rr
)
est_rr
.
append
(
_est_rr
)
est_isl
.
append
(
_est_isl
)
est_isl
.
append
(
_est_isl
)
est_osl
.
append
(
_est_osl
)
est_osl
.
append
(
_est_osl
)
...
@@ -145,10 +152,12 @@ def run_sla_planner_dryrun(
...
@@ -145,10 +152,12 @@ def run_sla_planner_dryrun(
if
prefill_planner
is
not
None
and
decode_planner
is
not
None
:
if
prefill_planner
is
not
None
and
decode_planner
is
not
None
:
_num_p
,
_num_d
=
_apply_global_gpu_budget
(
_num_p
,
_num_d
,
config
)
_num_p
,
_num_d
=
_apply_global_gpu_budget
(
_num_p
,
_num_d
,
config
)
elif
prefill_planner
is
not
None
:
elif
prefill_planner
is
not
None
:
assert
config
.
prefill_engine_num_gpu
is
not
None
_num_p
=
_apply_component_gpu_budget
(
_num_p
=
_apply_component_gpu_budget
(
_num_p
,
config
.
prefill_engine_num_gpu
,
config
_num_p
,
config
.
prefill_engine_num_gpu
,
config
)
)
elif
decode_planner
is
not
None
:
elif
decode_planner
is
not
None
:
assert
config
.
decode_engine_num_gpu
is
not
None
_num_d
=
_apply_component_gpu_budget
(
_num_d
=
_apply_component_gpu_budget
(
_num_d
,
config
.
decode_engine_num_gpu
,
config
_num_d
,
config
.
decode_engine_num_gpu
,
config
)
)
...
...
components/src/dynamo/planner/utils/load_predictor.py
View file @
2712426f
...
@@ -19,7 +19,7 @@ import warnings
...
@@ -19,7 +19,7 @@ import warnings
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
datetime
import
datetime
,
timedelta
from
datetime
import
datetime
,
timedelta
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Any
from
typing
import
Any
,
Callable
import
numpy
as
np
import
numpy
as
np
import
pandas
as
pd
import
pandas
as
pd
...
@@ -389,7 +389,7 @@ class KalmanPredictor(BasePredictor):
...
@@ -389,7 +389,7 @@ class KalmanPredictor(BasePredictor):
)
)
LOAD_PREDICTORS
=
{
LOAD_PREDICTORS
:
dict
[
str
,
Callable
[[
PlannerConfig
],
BasePredictor
]]
=
{
"constant"
:
ConstantPredictor
,
"constant"
:
ConstantPredictor
,
"arima"
:
ARIMAPredictor
,
"arima"
:
ARIMAPredictor
,
"kalman"
:
KalmanPredictor
,
"kalman"
:
KalmanPredictor
,
...
...
components/src/dynamo/planner/utils/perf_interpolation.py
View file @
2712426f
...
@@ -151,6 +151,8 @@ class DecodeInterpolator:
...
@@ -151,6 +151,8 @@ class DecodeInterpolator:
self
.
resolution
=
resolution
self
.
resolution
=
resolution
self
.
xi
=
np
.
linspace
(
0
,
1
,
resolution
)
self
.
xi
=
np
.
linspace
(
0
,
1
,
resolution
)
self
.
yi
=
np
.
linspace
(
0
,
max
(
self
.
y_context_length
),
resolution
)
self
.
yi
=
np
.
linspace
(
0
,
max
(
self
.
y_context_length
),
resolution
)
self
.
X
:
np
.
ndarray
self
.
Y
:
np
.
ndarray
self
.
X
,
self
.
Y
=
np
.
meshgrid
(
self
.
xi
,
self
.
yi
)
self
.
X
,
self
.
Y
=
np
.
meshgrid
(
self
.
xi
,
self
.
yi
)
# Lazy import scipy only when interpolation is actually needed
# Lazy import scipy only when interpolation is actually needed
...
...
components/src/dynamo/planner/utils/planner_core.py
View file @
2712426f
...
@@ -6,7 +6,7 @@ import logging
...
@@ -6,7 +6,7 @@ import logging
import
math
import
math
import
time
import
time
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Optional
,
Union
from
prometheus_client
import
Gauge
,
start_http_server
from
prometheus_client
import
Gauge
,
start_http_server
...
@@ -36,6 +36,9 @@ from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_moonc
...
@@ -36,6 +36,9 @@ from dynamo.planner.utils.trace_data_extractor import extract_metrics_from_moonc
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
# Union of all connector types used by the planner
ConnectorType
=
Union
[
GlobalPlannerConnector
,
KubernetesConnector
,
VirtualConnector
]
configure_dynamo_logging
()
configure_dynamo_logging
()
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -248,14 +251,14 @@ class BasePlanner:
...
@@ -248,14 +251,14 @@ class BasePlanner:
def
__init__
(
def
__init__
(
self
,
self
,
runtime
:
DistributedRuntime
,
runtime
:
Optional
[
DistributedRuntime
]
,
config
:
PlannerConfig
,
config
:
PlannerConfig
,
dryrun
:
bool
=
False
,
dryrun
:
bool
=
False
,
shared_state
:
Optional
[
PlannerSharedState
]
=
None
,
shared_state
:
Optional
[
PlannerSharedState
]
=
None
,
prometheus_metrics
:
Optional
[
PlannerPrometheusMetrics
]
=
None
,
prometheus_metrics
:
Optional
[
PlannerPrometheusMetrics
]
=
None
,
prometheus_traffic_client
:
Optional
[
PrometheusAPIClient
]
=
None
,
prometheus_traffic_client
:
Optional
[
PrometheusAPIClient
]
=
None
,
prometheus_engine_client
:
Optional
[
DirectRouterMetricsClient
]
=
None
,
prometheus_engine_client
:
Optional
[
DirectRouterMetricsClient
]
=
None
,
connector
=
None
,
connector
:
Optional
[
ConnectorType
]
=
None
,
start_prometheus_server
:
bool
=
True
,
start_prometheus_server
:
bool
=
True
,
component_type
:
Optional
[
SubComponentType
]
=
None
,
component_type
:
Optional
[
SubComponentType
]
=
None
,
):
):
...
@@ -272,11 +275,13 @@ class BasePlanner:
...
@@ -272,11 +275,13 @@ class BasePlanner:
if
not
self
.
dryrun
:
if
not
self
.
dryrun
:
self
.
runtime
=
runtime
self
.
runtime
=
runtime
self
.
namespace
=
config
.
namespace
self
.
namespace
=
config
.
namespace
self
.
connector
:
ConnectorType
if
not
config
.
no_operation
:
if
not
config
.
no_operation
:
# Initialize connector based on environment
# Initialize connector based on environment
if
config
.
environment
==
"global-planner"
:
if
config
.
environment
==
"global-planner"
:
assert
config
.
global_planner_namespace
is
not
None
assert
config
.
global_planner_namespace
is
not
None
assert
runtime
is
not
None
self
.
connector
=
GlobalPlannerConnector
(
self
.
connector
=
GlobalPlannerConnector
(
runtime
,
runtime
,
self
.
namespace
,
self
.
namespace
,
...
@@ -289,6 +294,7 @@ class BasePlanner:
...
@@ -289,6 +294,7 @@ class BasePlanner:
self
.
namespace
,
self
.
model_name
self
.
namespace
,
self
.
model_name
)
)
elif
config
.
environment
==
"virtual"
:
elif
config
.
environment
==
"virtual"
:
assert
runtime
is
not
None
self
.
connector
=
VirtualConnector
(
self
.
connector
=
VirtualConnector
(
runtime
,
runtime
,
self
.
namespace
,
self
.
namespace
,
...
@@ -430,11 +436,12 @@ class BasePlanner:
...
@@ -430,11 +436,12 @@ class BasePlanner:
self
.
prometheus_engine_client
=
prometheus_engine_client
self
.
prometheus_engine_client
=
prometheus_engine_client
else
:
else
:
# Auto-discover frontend metrics URL in Kubernetes mode
# Auto-discover frontend metrics URL in Kubernetes mode
connector
=
getattr
(
self
,
"connector"
,
None
)
if
not
config
.
load_router_metrics_url
and
isinstance
(
if
not
config
.
load_router_metrics_url
and
isinstance
(
getattr
(
self
,
"
connector
"
,
None
)
,
KubernetesConnector
connector
,
KubernetesConnector
):
):
config
.
load_router_metrics_url
=
(
config
.
load_router_metrics_url
=
(
self
.
connector
.
get_frontend_metrics_url
()
connector
.
get_frontend_metrics_url
()
)
)
if
not
config
.
load_router_metrics_url
:
if
not
config
.
load_router_metrics_url
:
raise
ValueError
(
raise
ValueError
(
...
@@ -447,6 +454,9 @@ class BasePlanner:
...
@@ -447,6 +454,9 @@ class BasePlanner:
f
"Auto-discovered frontend metrics URL:
{
config
.
load_router_metrics_url
}
"
f
"Auto-discovered frontend metrics URL:
{
config
.
load_router_metrics_url
}
"
)
)
assert
(
config
.
load_router_metrics_url
is
not
None
),
"load_router_metrics_url must be set when load-based scaling is enabled"
self
.
prometheus_engine_client
=
DirectRouterMetricsClient
(
self
.
prometheus_engine_client
=
DirectRouterMetricsClient
(
config
.
load_router_metrics_url
,
config
.
namespace
config
.
load_router_metrics_url
,
config
.
namespace
)
)
...
@@ -494,6 +504,7 @@ class BasePlanner:
...
@@ -494,6 +504,7 @@ class BasePlanner:
async
def
_get_or_create_client
(
self
,
component_name
:
str
,
endpoint_name
:
str
):
async
def
_get_or_create_client
(
self
,
component_name
:
str
,
endpoint_name
:
str
):
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
"""Create a client for the given component and endpoint, with a brief sleep for state sync."""
assert
self
.
runtime
is
not
None
,
"Runtime is not initialized"
client
=
await
self
.
runtime
.
endpoint
(
client
=
await
self
.
runtime
.
endpoint
(
f
"
{
self
.
namespace
}
.
{
component_name
}
.
{
endpoint_name
}
"
f
"
{
self
.
namespace
}
.
{
component_name
}
.
{
endpoint_name
}
"
).
client
()
).
client
()
...
@@ -604,41 +615,46 @@ class BasePlanner:
...
@@ -604,41 +615,46 @@ class BasePlanner:
)
)
# Prometheus returns seconds, convert to milliseconds
# Prometheus returns seconds, convert to milliseconds
assert
(
self
.
model_name
is
not
None
),
"model_name must be set before observing traffic stats"
interval_str
=
f
"
{
self
.
config
.
throughput_adjustment_interval
}
s"
self
.
last_metrics
.
ttft
=
(
self
.
last_metrics
.
ttft
=
(
self
.
prometheus_traffic_client
.
get_avg_time_to_first_token
(
self
.
prometheus_traffic_client
.
get_avg_time_to_first_token
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
self
.
model_name
,
)
)
*
1000
*
1000
)
)
self
.
last_metrics
.
itl
=
(
self
.
last_metrics
.
itl
=
(
self
.
prometheus_traffic_client
.
get_avg_inter_token_latency
(
self
.
prometheus_traffic_client
.
get_avg_inter_token_latency
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
self
.
model_name
,
)
)
*
1000
*
1000
)
)
self
.
last_metrics
.
num_req
=
(
self
.
last_metrics
.
num_req
=
(
self
.
prometheus_traffic_client
.
get_avg_request_count
(
self
.
prometheus_traffic_client
.
get_avg_request_count
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
self
.
model_name
,
)
)
)
)
self
.
last_metrics
.
request_duration
=
(
self
.
last_metrics
.
request_duration
=
(
self
.
prometheus_traffic_client
.
get_avg_request_duration
(
self
.
prometheus_traffic_client
.
get_avg_request_duration
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
self
.
model_name
,
)
)
)
)
self
.
last_metrics
.
isl
=
(
self
.
last_metrics
.
isl
=
(
self
.
prometheus_traffic_client
.
get_avg_input_sequence_tokens
(
self
.
prometheus_traffic_client
.
get_avg_input_sequence_tokens
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
self
.
model_name
,
)
)
)
)
self
.
last_metrics
.
osl
=
(
self
.
last_metrics
.
osl
=
(
self
.
prometheus_traffic_client
.
get_avg_output_sequence_tokens
(
self
.
prometheus_traffic_client
.
get_avg_output_sequence_tokens
(
f
"
{
self
.
config
.
throughput_adjustment_
interval
}
s"
,
interval
_str
,
self
.
model_name
,
self
.
model_name
,
)
)
)
)
...
@@ -666,8 +682,11 @@ class BasePlanner:
...
@@ -666,8 +682,11 @@ class BasePlanner:
self
.
update_predictors_from_metrics
(
self
.
last_metrics
)
self
.
update_predictors_from_metrics
(
self
.
last_metrics
)
def
update_predictors_from_metrics
(
self
,
metrics
:
Metrics
)
->
None
:
def
update_predictors_from_metrics
(
self
,
metrics
:
Metrics
)
->
None
:
if
metrics
.
num_req
is
not
None
:
self
.
num_req_predictor
.
add_data_point
(
metrics
.
num_req
)
self
.
num_req_predictor
.
add_data_point
(
metrics
.
num_req
)
if
metrics
.
isl
is
not
None
:
self
.
isl_predictor
.
add_data_point
(
metrics
.
isl
)
self
.
isl_predictor
.
add_data_point
(
metrics
.
isl
)
if
metrics
.
osl
is
not
None
:
self
.
osl_predictor
.
add_data_point
(
metrics
.
osl
)
self
.
osl_predictor
.
add_data_point
(
metrics
.
osl
)
def
predict_load
(
self
)
->
tuple
[
Optional
[
float
],
Optional
[
float
],
Optional
[
float
]]:
def
predict_load
(
self
)
->
tuple
[
Optional
[
float
],
Optional
[
float
],
Optional
[
float
]]:
...
...
components/src/dynamo/planner/utils/prefill_planner.py
View file @
2712426f
...
@@ -94,6 +94,7 @@ class PrefillPlanner(BasePlanner):
...
@@ -94,6 +94,7 @@ class PrefillPlanner(BasePlanner):
return
None
return
None
def
_update_correction_factor
(
self
)
->
bool
:
def
_update_correction_factor
(
self
)
->
bool
:
assert
self
.
last_metrics
.
isl
is
not
None
and
self
.
last_metrics
.
ttft
is
not
None
expect_ttft
=
self
.
prefill_interpolator
.
interpolate_ttft
(
self
.
last_metrics
.
isl
)
expect_ttft
=
self
.
prefill_interpolator
.
interpolate_ttft
(
self
.
last_metrics
.
isl
)
self
.
p_correction_factor
=
self
.
last_metrics
.
ttft
/
expect_ttft
self
.
p_correction_factor
=
self
.
last_metrics
.
ttft
/
expect_ttft
logger
.
info
(
f
"Correction factor (prefill TTFT):
{
self
.
p_correction_factor
:.
3
f
}
"
)
logger
.
info
(
f
"Correction factor (prefill TTFT):
{
self
.
p_correction_factor
:.
3
f
}
"
)
...
@@ -117,6 +118,7 @@ class PrefillPlanner(BasePlanner):
...
@@ -117,6 +118,7 @@ class PrefillPlanner(BasePlanner):
"(no throughput satisfies TTFT target), falling back to min_endpoint"
"(no throughput satisfies TTFT target), falling back to min_endpoint"
)
)
return
self
.
config
.
min_endpoint
return
self
.
config
.
min_endpoint
assert
self
.
config
.
prefill_engine_num_gpu
is
not
None
next_num_p
=
math
.
ceil
(
next_num_p
=
math
.
ceil
(
pred_prefill_throughput
pred_prefill_throughput
/
p_thpt_per_gpu
/
p_thpt_per_gpu
...
...
components/src/dynamo/sglang/backend_args.py
View file @
2712426f
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
"""Dynamo SGLang wrapper configuration ArgGroup."""
"""Dynamo SGLang wrapper configuration ArgGroup."""
import
argparse
import
argparse
from
typing
import
Optional
,
Union
from
typing
import
Optional
from
dynamo.common.configuration.arg_group
import
ArgGroup
from
dynamo.common.configuration.arg_group
import
ArgGroup
from
dynamo.common.configuration.config_base
import
ConfigBase
from
dynamo.common.configuration.config_base
import
ConfigBase
...
@@ -117,7 +117,7 @@ class DynamoSGLangConfig(ConfigBase):
...
@@ -117,7 +117,7 @@ class DynamoSGLangConfig(ConfigBase):
multimodal_processor
:
bool
multimodal_processor
:
bool
multimodal_encode_worker
:
bool
multimodal_encode_worker
:
bool
multimodal_worker
:
bool
multimodal_worker
:
bool
embedding_transfer_mode
:
Union
[
str
,
EmbeddingTransferMode
]
embedding_transfer_mode
:
EmbeddingTransferMode
embedding_worker
:
bool
embedding_worker
:
bool
image_diffusion_worker
:
bool
image_diffusion_worker
:
bool
...
@@ -127,10 +127,11 @@ class DynamoSGLangConfig(ConfigBase):
...
@@ -127,10 +127,11 @@ class DynamoSGLangConfig(ConfigBase):
video_generation_worker
:
bool
video_generation_worker
:
bool
def
validate
(
self
)
->
None
:
def
validate
(
self
)
->
None
:
if
isinstance
(
self
.
embedding_transfer_mode
,
str
):
if
not
isinstance
(
self
.
embedding_transfer_mode
,
EmbeddingTransferMode
):
self
.
embedding_transfer_mode
=
EmbeddingTransferMode
(
self
.
embedding_transfer_mode
=
EmbeddingTransferMode
(
self
.
embedding_transfer_mode
str
(
self
.
embedding_transfer_mode
)
)
)
if
(
self
.
disagg_config
is
not
None
)
^
(
self
.
disagg_config_key
is
not
None
):
if
(
self
.
disagg_config
is
not
None
)
^
(
self
.
disagg_config_key
is
not
None
):
raise
ValueError
(
raise
ValueError
(
"Both 'disagg_config' and 'disagg_config_key' must be provided together."
"Both 'disagg_config' and 'disagg_config_key' must be provided together."
...
...
components/src/dynamo/sglang/request_handlers/handler_base.py
View file @
2712426f
...
@@ -8,13 +8,23 @@ import random
...
@@ -8,13 +8,23 @@ import random
import
socket
import
socket
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Optional
,
Tuple
from
typing
import
(
Any
,
AsyncGenerator
,
AsyncIterator
,
Dict
,
Generic
,
Optional
,
Tuple
,
TypeVar
,
)
import
sglang
as
sgl
import
sglang
as
sgl
from
sglang.srt.utils
import
get_local_ip_auto
from
sglang.srt.utils
import
get_local_ip_auto
from
dynamo._core
import
Context
from
dynamo._core
import
Context
from
dynamo.common.utils.input_params
import
InputParamManager
from
dynamo.common.utils.input_params
import
InputParamManager
from
dynamo.llm
import
KvEventPublisher
,
WorkerMetricsPublisher
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.runtime
import
DistributedRuntime
from
dynamo.sglang.args
import
Config
from
dynamo.sglang.args
import
Config
from
dynamo.sglang.publisher
import
DynamoSglangPublisher
from
dynamo.sglang.publisher
import
DynamoSglangPublisher
...
@@ -72,7 +82,11 @@ class SGLangEngineQuiesceController:
...
@@ -72,7 +82,11 @@ class SGLangEngineQuiesceController:
self
.
_is_quiesced
=
False
self
.
_is_quiesced
=
False
class
BaseGenerativeHandler
(
ABC
):
RequestT
=
TypeVar
(
"RequestT"
)
ResponseT
=
TypeVar
(
"ResponseT"
)
class
BaseGenerativeHandler
(
ABC
,
Generic
[
RequestT
,
ResponseT
]):
"""Minimal base class for all generative handlers (LLM, diffusion, etc.).
"""Minimal base class for all generative handlers (LLM, diffusion, etc.).
Provides common infrastructure for:
Provides common infrastructure for:
...
@@ -95,27 +109,24 @@ class BaseGenerativeHandler(ABC):
...
@@ -95,27 +109,24 @@ class BaseGenerativeHandler(ABC):
self
.
config
=
config
self
.
config
=
config
# Set up metrics and KV publishers
# Set up metrics and KV publishers
self
.
metrics_publisher
:
Optional
[
WorkerMetricsPublisher
]
=
None
self
.
kv_publisher
:
Optional
[
KvEventPublisher
]
=
None
if
publisher
is
not
None
:
if
publisher
is
not
None
:
self
.
metrics_publisher
=
publisher
.
metrics_publisher
self
.
metrics_publisher
=
publisher
.
metrics_publisher
self
.
kv_publisher
=
publisher
.
kv_publisher
self
.
kv_publisher
=
publisher
.
kv_publisher
else
:
self
.
metrics_publisher
=
None
self
.
kv_publisher
=
None
@
abstractmethod
@
abstractmethod
async
def
generate
(
def
generate
(
self
,
request
:
RequestT
,
context
:
Context
)
->
AsyncIterator
[
ResponseT
]:
self
,
request
:
Dict
[
str
,
Any
],
context
:
Context
)
->
AsyncGenerator
[
Dict
[
str
,
Any
],
None
]:
"""Generate response from request.
"""Generate response from request.
Args:
Args:
request: Request
dict
with input and parameters.
request: Request with input and parameters.
context: Context object for cancellation handling.
context: Context object for cancellation handling.
Yields:
Yields:
Response data (format varies by handler implementation).
Response data (format varies by handler implementation).
"""
"""
pass
...
def
cleanup
(
self
)
->
None
:
def
cleanup
(
self
)
->
None
:
"""Cleanup resources. Override in subclasses as needed."""
"""Cleanup resources. Override in subclasses as needed."""
...
@@ -137,7 +148,7 @@ class BaseGenerativeHandler(ABC):
...
@@ -137,7 +148,7 @@ class BaseGenerativeHandler(ABC):
return
{
"traceparent"
:
f
"00-
{
trace_id
}
-
{
span_id
}
-01"
}
return
{
"traceparent"
:
f
"00-
{
trace_id
}
-
{
span_id
}
-01"
}
class
BaseWorkerHandler
(
BaseGenerativeHandler
):
class
BaseWorkerHandler
(
BaseGenerativeHandler
[
RequestT
,
ResponseT
]
):
"""Abstract base class for SGLang LLM worker handlers.
"""Abstract base class for SGLang LLM worker handlers.
Extends BaseGenerativeHandler with LLM-specific functionality:
Extends BaseGenerativeHandler with LLM-specific functionality:
...
@@ -175,9 +186,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
...
@@ -175,9 +186,6 @@ class BaseWorkerHandler(BaseGenerativeHandler):
if
publisher
is
not
None
:
if
publisher
is
not
None
:
self
.
metrics_publisher
=
publisher
.
metrics_publisher
self
.
metrics_publisher
=
publisher
.
metrics_publisher
self
.
kv_publisher
=
publisher
.
kv_publisher
self
.
kv_publisher
=
publisher
.
kv_publisher
else
:
self
.
metrics_publisher
=
None
self
.
kv_publisher
=
None
self
.
serving_mode
=
config
.
serving_mode
self
.
serving_mode
=
config
.
serving_mode
self
.
skip_tokenizer_init
=
config
.
server_args
.
skip_tokenizer_init
self
.
skip_tokenizer_init
=
config
.
server_args
.
skip_tokenizer_init
self
.
enable_trace
=
config
.
server_args
.
enable_trace
self
.
enable_trace
=
config
.
server_args
.
enable_trace
...
@@ -454,17 +462,17 @@ class BaseWorkerHandler(BaseGenerativeHandler):
...
@@ -454,17 +462,17 @@ class BaseWorkerHandler(BaseGenerativeHandler):
)
)
@
abstractmethod
@
abstractmethod
async
def
generate
(
self
,
request
:
Dict
[
str
,
Any
]
,
context
:
Context
):
def
generate
(
self
,
request
:
RequestT
,
context
:
Context
)
->
AsyncIterator
[
ResponseT
]
:
"""Generate response from request.
"""Generate response from request.
Args:
Args:
request: Request
dict
with input and parameters.
request: Request with input and parameters.
context: Context object for cancellation handling.
context: Context object for cancellation handling.
Yields:
Yields:
Response data (format varies by handler implementation).
Response data (format varies by handler implementation).
"""
"""
pass
...
def
cleanup
(
self
)
->
None
:
def
cleanup
(
self
)
->
None
:
"""Cleanup resources. Override in subclasses as needed."""
"""Cleanup resources. Override in subclasses as needed."""
...
...
components/src/dynamo/sglang/request_handlers/llm/decode_handler.py
View file @
2712426f
...
@@ -24,7 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -24,7 +24,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
self
,
self
,
engine
:
sgl
.
Engine
,
engine
:
sgl
.
Engine
,
config
:
Config
,
config
:
Config
,
publisher
:
DynamoSglangPublisher
,
publisher
:
Optional
[
DynamoSglangPublisher
]
=
None
,
generate_endpoint
=
None
,
generate_endpoint
=
None
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
)
->
None
:
)
->
None
:
...
@@ -230,7 +230,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -230,7 +230,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
# This lets SGLang proceed to the second token generation, which will
# This lets SGLang proceed to the second token generation, which will
# async context switch and allow the abort monitor to signal cancellation.
# async context switch and allow the abort monitor to signal cancellation.
# The loop should exit by itself when context.is_stopped() returns True.
# The loop should exit by itself when context.is_stopped() returns True.
out
=
{}
out
:
dict
[
str
,
Any
]
=
{}
finish_reason
=
res
[
"meta_info"
][
"finish_reason"
]
finish_reason
=
res
[
"meta_info"
][
"finish_reason"
]
if
finish_reason
:
if
finish_reason
:
out
[
"finish_reason"
]
=
normalize_finish_reason
(
out
[
"finish_reason"
]
=
normalize_finish_reason
(
...
...
components/src/dynamo/sglang/request_handlers/llm/diffusion_handler.py
View file @
2712426f
...
@@ -21,7 +21,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
...
@@ -21,7 +21,7 @@ class DiffusionWorkerHandler(DecodeWorkerHandler):
self
,
self
,
engine
:
sgl
.
Engine
,
engine
:
sgl
.
Engine
,
config
:
Config
,
config
:
Config
,
publisher
:
DynamoSglangPublisher
=
None
,
publisher
:
Optional
[
DynamoSglangPublisher
]
=
None
,
generate_endpoint
=
None
,
generate_endpoint
=
None
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
)
->
None
:
)
->
None
:
...
...
components/src/dynamo/sglang/request_handlers/multimodal/encode_worker_handler.py
View file @
2712426f
...
@@ -38,7 +38,7 @@ except ImportError as e:
...
@@ -38,7 +38,7 @@ except ImportError as e:
DEVICE
=
"cpu"
DEVICE
=
"cpu"
class
MultimodalEncodeWorkerHandler
(
BaseWorkerHandler
):
class
MultimodalEncodeWorkerHandler
(
BaseWorkerHandler
[
SglangMultimodalRequest
,
str
]
):
"""
"""
Handler for multimodal encode worker component that processes images/videos
Handler for multimodal encode worker component that processes images/videos
and forwards them to the downstream worker.
and forwards them to the downstream worker.
...
@@ -84,12 +84,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
...
@@ -84,12 +84,19 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
if
image_token_str
==
"<|vision_start|><|image_pad|><|vision_end|>"
:
if
image_token_str
==
"<|vision_start|><|image_pad|><|vision_end|>"
:
# These are likely the individual special tokens for Qwen2.5-VL
# These are likely the individual special tokens for Qwen2.5-VL
image_pad_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
"<|image_pad|>"
)
image_pad_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
"<|image_pad|>"
)
assert
isinstance
(
image_pad_id
,
int
),
f
"Expected int token id, got
{
type
(
image_pad_id
)
}
"
# Use the image_pad token as the main image token
# Use the image_pad token as the main image token
self
.
image_token_id
=
image_pad_id
self
.
image_token_id
:
int
=
image_pad_id
else
:
else
:
# Fallback for other models
# Fallback for other models
self
.
image_token_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
image_token_str
)
token_id
=
self
.
tokenizer
.
convert_tokens_to_ids
(
image_token_str
)
assert
isinstance
(
token_id
,
int
),
f
"Expected int token id, got
{
type
(
token_id
)
}
"
self
.
image_token_id
=
token_id
self
.
min_workers
=
1
self
.
min_workers
=
1
...
@@ -230,10 +237,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
...
@@ -230,10 +237,11 @@ class MultimodalEncodeWorkerHandler(BaseWorkerHandler):
zip
(
multimodal_groups
,
image_grid_thw_list
)
zip
(
multimodal_groups
,
image_grid_thw_list
)
):
):
mm_group
.
image_grid_thw
=
image_grid_thw
mm_group
.
image_grid_thw
=
image_grid_thw
if
mm_group
.
multimodal_input
is
not
None
:
mm_group
.
multimodal_input
.
image_url
=
None
mm_group
.
multimodal_input
.
image_url
=
None
# Store shared tensor transfer metadata at request level.
# Store shared tensor transfer metadata at request level.
request
.
embeddings_shape
=
tuple
(
precomputed_embeddings
.
shape
)
request
.
embeddings_shape
=
tuple
(
precomputed_embeddings
.
shape
)
# type: ignore[assignment]
request
.
transfer_payload
=
None
request
.
transfer_payload
=
None
search_start
=
0
search_start
=
0
...
...
components/src/dynamo/sglang/request_handlers/multimodal/processor_handler.py
View file @
2712426f
...
@@ -6,7 +6,7 @@ import json
...
@@ -6,7 +6,7 @@ import json
import
logging
import
logging
import
time
import
time
import
uuid
import
uuid
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Optional
from
transformers
import
AutoTokenizer
from
transformers
import
AutoTokenizer
...
@@ -20,6 +20,7 @@ from dynamo.sglang.protocol import (
...
@@ -20,6 +20,7 @@ from dynamo.sglang.protocol import (
MultiModalGroup
,
MultiModalGroup
,
MultiModalInput
,
MultiModalInput
,
MultiModalRequest
,
MultiModalRequest
,
PreprocessedRequest
,
SglangMultimodalRequest
,
SglangMultimodalRequest
,
)
)
from
dynamo.sglang.request_handlers.handler_base
import
BaseWorkerHandler
from
dynamo.sglang.request_handlers.handler_base
import
BaseWorkerHandler
...
@@ -27,7 +28,7 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
...
@@ -27,7 +28,7 @@ from dynamo.sglang.request_handlers.handler_base import BaseWorkerHandler
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
class
MultimodalProcessorHandler
(
BaseWorkerHandler
):
class
MultimodalProcessorHandler
(
BaseWorkerHandler
[
MultiModalRequest
,
Dict
[
str
,
Any
]]
):
"""
"""
Handler for multimodal processor component that processes multimodal requests
Handler for multimodal processor component that processes multimodal requests
and forwards them to the encode worker.
and forwards them to the encode worker.
...
@@ -56,7 +57,9 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
...
@@ -56,7 +57,9 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
def
cleanup
(
self
):
def
cleanup
(
self
):
pass
pass
async
def
generate
(
self
,
raw_request
:
MultiModalRequest
,
context
:
Context
):
async
def
generate
(
self
,
raw_request
:
MultiModalRequest
,
context
:
Context
)
->
AsyncGenerator
[
Dict
[
str
,
Any
],
None
]:
"""
"""
Process multimodal request and forward to encode worker.
Process multimodal request and forward to encode worker.
...
@@ -119,7 +122,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
...
@@ -119,7 +122,7 @@ class MultimodalProcessorHandler(BaseWorkerHandler):
)
)
worker_request
=
SglangMultimodalRequest
(
worker_request
=
SglangMultimodalRequest
(
request
=
sglang_request
,
request
=
PreprocessedRequest
(
**
sglang_request
)
,
multimodal_inputs
=
multimodal_groups
,
multimodal_inputs
=
multimodal_groups
,
)
)
...
...
components/src/dynamo/sglang/request_handlers/multimodal/worker_handler.py
View file @
2712426f
...
@@ -256,7 +256,7 @@ async def _build_mm_items(
...
@@ -256,7 +256,7 @@ async def _build_mm_items(
return
mm_items
,
embeddings
,
tensor_id
return
mm_items
,
embeddings
,
tensor_id
class
MultimodalWorkerHandler
(
BaseWorkerHandler
):
class
MultimodalWorkerHandler
(
BaseWorkerHandler
[
SglangMultimodalRequest
,
str
]
):
"""
"""
Multimodal worker handler for LLM inference with multimodal data.
Multimodal worker handler for LLM inference with multimodal data.
Handles both aggregated and disaggregated modes.
Handles both aggregated and disaggregated modes.
...
@@ -490,7 +490,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
...
@@ -490,7 +490,9 @@ class MultimodalWorkerHandler(BaseWorkerHandler):
logger
.
info
(
"Multimodal worker engine shutdown"
)
logger
.
info
(
"Multimodal worker engine shutdown"
)
class
MultimodalPrefillWorkerHandler
(
BaseWorkerHandler
):
class
MultimodalPrefillWorkerHandler
(
BaseWorkerHandler
[
DisaggSglangMultimodalRequest
,
str
]
):
"""
"""
Multimodal prefill worker handler for disaggregated inference
Multimodal prefill worker handler for disaggregated inference
Processes multimodal inputs and coordinates with decode worker.
Processes multimodal inputs and coordinates with decode worker.
...
...
components/src/dynamo/sglang/request_handlers/video_generation/video_generation_handler.py
View file @
2712426f
...
@@ -103,16 +103,22 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
...
@@ -103,16 +103,22 @@ class VideoGenerationWorkerHandler(BaseGenerativeHandler):
)
)
# Parse size
# Parse size
assert
req
.
size
is
not
None
,
"Size is required"
width
,
height
=
self
.
_parse_size
(
req
.
size
)
width
,
height
=
self
.
_parse_size
(
req
.
size
)
# Calculate num_frames if not explicitly provided
# Calculate num_frames if not explicitly provided
num_frames
=
nvext
.
num_frames
num_frames
=
nvext
.
num_frames
assert
nvext
.
fps
is
not
None
,
"FPS is required"
if
num_frames
is
None
:
if
num_frames
is
None
:
assert
req
.
seconds
is
not
None
,
"Seconds is required"
num_frames
=
nvext
.
fps
*
req
.
seconds
num_frames
=
nvext
.
fps
*
req
.
seconds
# Generate video
# Generate video
context_id
=
context
.
id
()
context_id
=
context
.
id
()
assert
context_id
is
not
None
assert
context_id
is
not
None
assert
(
nvext
.
num_inference_steps
is
not
None
),
"Num inference steps is required"
video_bytes
=
await
self
.
_generate_video
(
video_bytes
=
await
self
.
_generate_video
(
prompt
=
req
.
prompt
,
prompt
=
req
.
prompt
,
width
=
width
,
width
=
width
,
...
...
components/src/dynamo/trtllm/logits_processing/adapter.py
View file @
2712426f
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
logging
import
logging
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Sequence
import
torch
import
torch
from
tensorrt_llm.sampling_params
import
LogitsProcessor
from
tensorrt_llm.sampling_params
import
LogitsProcessor
...
@@ -70,7 +70,7 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
...
@@ -70,7 +70,7 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
def
create_trtllm_adapters
(
def
create_trtllm_adapters
(
processors
:
List
[
BaseLogitsProcessor
],
processors
:
Sequence
[
BaseLogitsProcessor
],
)
->
List
[
TrtllmDynamoLogitsAdapter
]:
)
->
List
[
TrtllmDynamoLogitsAdapter
]:
"""
"""
Create TensorRT-LLM compatible adapters from Dynamo logits processors.
Create TensorRT-LLM compatible adapters from Dynamo logits processors.
...
...
components/src/dynamo/trtllm/request_handlers/aggregated_handler.py
View file @
2712426f
...
@@ -5,7 +5,9 @@
...
@@ -5,7 +5,9 @@
import
logging
import
logging
from
collections.abc
import
AsyncGenerator
from
collections.abc
import
AsyncGenerator
from
typing
import
Optional
from
typing
import
Optional
,
Union
import
torch
from
dynamo._core
import
Context
from
dynamo._core
import
Context
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
from
dynamo.common.memory.multimodal_embedding_cache_manager
import
(
...
@@ -40,7 +42,7 @@ class AggregatedHandler(HandlerBase):
...
@@ -40,7 +42,7 @@ class AggregatedHandler(HandlerBase):
"""Generate response, optionally using remote encoder for multimodal."""
"""Generate response, optionally using remote encoder for multimodal."""
logging
.
debug
(
f
"AggregatedHandler Request ID:
{
context
.
id
()
}
"
)
logging
.
debug
(
f
"AggregatedHandler Request ID:
{
context
.
id
()
}
"
)
embeddings
=
None
embeddings
:
Optional
[
Union
[
torch
.
Tensor
,
dict
]]
=
None
ep_disaggregated_params
=
None
ep_disaggregated_params
=
None
if
self
.
multimodal_processor
and
self
.
encode_client
:
if
self
.
multimodal_processor
and
self
.
encode_client
:
messages
=
request
.
get
(
"extra_args"
,
{}).
get
(
messages
=
request
.
get
(
"extra_args"
,
{}).
get
(
...
@@ -57,7 +59,7 @@ class AggregatedHandler(HandlerBase):
...
@@ -57,7 +59,7 @@ class AggregatedHandler(HandlerBase):
self
.
_encoder_cache
,
self
.
_encoder_cache
,
)
)
if
isinstance
(
result
,
list
):
if
isinstance
(
result
,
list
):
embeddings
=
result
embeddings
=
result
# type: ignore[assignment]
else
:
else
:
ep_disaggregated_params
=
result
ep_disaggregated_params
=
result
...
...
components/src/dynamo/trtllm/request_handlers/handler_base.py
View file @
2712426f
...
@@ -31,7 +31,7 @@ from tensorrt_llm.llmapi.llm import SamplingParams
...
@@ -31,7 +31,7 @@ from tensorrt_llm.llmapi.llm import SamplingParams
from
tensorrt_llm.sampling_params
import
GuidedDecodingParams
from
tensorrt_llm.sampling_params
import
GuidedDecodingParams
from
tensorrt_llm.scheduling_params
import
SchedulingParams
from
tensorrt_llm.scheduling_params
import
SchedulingParams
from
dynamo._core
import
Context
from
dynamo._core
import
Client
,
Context
from
dynamo.common.utils.otel_tracing
import
build_trace_headers
from
dynamo.common.utils.otel_tracing
import
build_trace_headers
from
dynamo.logits_processing.examples
import
HelloWorldLogitsProcessor
from
dynamo.logits_processing.examples
import
HelloWorldLogitsProcessor
from
dynamo.nixl_connect
import
Connector
from
dynamo.nixl_connect
import
Connector
...
@@ -65,9 +65,9 @@ class RequestHandlerConfig:
...
@@ -65,9 +65,9 @@ class RequestHandlerConfig:
engine
:
TensorRTLLMEngine
engine
:
TensorRTLLMEngine
default_sampling_params
:
SamplingParams
default_sampling_params
:
SamplingParams
publisher
:
Publisher
publisher
:
Optional
[
Publisher
]
disaggregation_mode
:
DisaggregationMode
disaggregation_mode
:
DisaggregationMode
encode_client
:
Optional
[
objec
t
]
=
None
encode_client
:
Optional
[
Clien
t
]
=
None
multimodal_processor
:
Optional
[
multimodal_processor
:
Optional
[
MultimodalRequestProcessor
MultimodalRequestProcessor
]
=
None
# for multimodal support
]
=
None
# for multimodal support
...
@@ -558,11 +558,11 @@ class HandlerBase(BaseGenerativeHandler):
...
@@ -558,11 +558,11 @@ class HandlerBase(BaseGenerativeHandler):
# PREFILL/ENCODE/AGGREGATED: Process multimodal content if available
# PREFILL/ENCODE/AGGREGATED: Process multimodal content if available
if
self
.
multimodal_processor
:
if
self
.
multimodal_processor
:
processed_inpu
t
=
await
self
.
multimodal_processor
.
process_openai_request
(
mm_resul
t
=
await
self
.
multimodal_processor
.
process_openai_request
(
request
,
embeddings
,
ep_disaggregated_params
request
,
embeddings
,
ep_disaggregated_params
)
)
if
processed_inpu
t
:
if
mm_resul
t
:
return
processed_inpu
t
return
mm_resul
t
# If multimodal processing returned None but request has multimodal data,
# If multimodal processing returned None but request has multimodal data,
# this is an error (not a text-only request). Raise instead of falling back.
# this is an error (not a text-only request). Raise instead of falling back.
...
...
components/src/dynamo/trtllm/request_handlers/handlers.py
View file @
2712426f
...
@@ -111,6 +111,8 @@ class PrefillHandler(HandlerBase):
...
@@ -111,6 +111,8 @@ class PrefillHandler(HandlerBase):
Encoder's embeddings tensor to be used by the prefill worker
Encoder's embeddings tensor to be used by the prefill worker
"""
"""
# Get response with shape info and readable metadata
# Get response with shape info and readable metadata
if
self
.
encode_client
is
None
:
raise
RuntimeError
(
"Encode client is not configured."
)
encode_response
=
None
encode_response
=
None
async
for
res
in
await
self
.
encode_client
.
round_robin
(
request
):
async
for
res
in
await
self
.
encode_client
.
round_robin
(
request
):
encode_response
=
res
.
data
()
encode_response
=
res
.
data
()
...
@@ -119,6 +121,8 @@ class PrefillHandler(HandlerBase):
...
@@ -119,6 +121,8 @@ class PrefillHandler(HandlerBase):
if
not
encode_response
:
if
not
encode_response
:
raise
RuntimeError
(
"Did not receive a response from the encode worker."
)
raise
RuntimeError
(
"Did not receive a response from the encode worker."
)
if
self
.
connector
is
None
:
raise
RuntimeError
(
"Connector is not configured."
)
# Use utility function to handle NIXL reading and reconstruction
# Use utility function to handle NIXL reading and reconstruction
return
await
EncodeHelper
.
read_embeddings_from_encode_response
(
return
await
EncodeHelper
.
read_embeddings_from_encode_response
(
encode_response
,
self
.
connector
encode_response
,
self
.
connector
...
...
components/src/dynamo/trtllm/workers/video_diffusion_worker.py
View file @
2712426f
...
@@ -61,6 +61,9 @@ async def init_video_diffusion_worker(
...
@@ -61,6 +61,9 @@ async def init_video_diffusion_worker(
else
[]
else
[]
)
)
if
not
config
.
endpoint
:
raise
ValueError
(
"endpoint must be configured for video diffusion worker"
)
# Build DiffusionConfig from the main Config
# Build DiffusionConfig from the main Config
diffusion_config
=
DiffusionConfig
(
diffusion_config
=
DiffusionConfig
(
namespace
=
config
.
namespace
,
namespace
=
config
.
namespace
,
...
...
components/src/dynamo/vllm/handlers.py
View file @
2712426f
...
@@ -13,7 +13,7 @@ import time
...
@@ -13,7 +13,7 @@ import time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Any
,
Async
Gen
erator
,
Dict
,
Final
from
typing
import
Any
,
Async
It
erator
,
Dict
,
Final
,
Generic
,
TypeVar
import
torch
import
torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -23,6 +23,7 @@ from vllm.outputs import RequestOutput
...
@@ -23,6 +23,7 @@ from vllm.outputs import RequestOutput
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.sampling_params
import
SamplingParams
,
StructuredOutputsParams
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
vllm.v1.engine.exceptions
import
EngineDeadError
from
dynamo._core
import
Context
from
dynamo.common.multimodal.image_loader
import
ImageLoader
from
dynamo.common.multimodal.image_loader
import
ImageLoader
from
dynamo.common.utils.engine_response
import
normalize_finish_reason
from
dynamo.common.utils.engine_response
import
normalize_finish_reason
from
dynamo.common.utils.input_params
import
InputParamManager
from
dynamo.common.utils.input_params
import
InputParamManager
...
@@ -325,7 +326,11 @@ def get_dp_range_for_worker(vllm_config: VllmConfig) -> tuple[int, int]:
...
@@ -325,7 +326,11 @@ def get_dp_range_for_worker(vllm_config: VllmConfig) -> tuple[int, int]:
)
)
class
BaseWorkerHandler
(
ABC
):
RequestT
=
TypeVar
(
"RequestT"
)
ResponseT
=
TypeVar
(
"ResponseT"
)
class
BaseWorkerHandler
(
ABC
,
Generic
[
RequestT
,
ResponseT
]):
"""
"""
Request handler for the generate and clear_kv_blocks endpoints.
Request handler for the generate and clear_kv_blocks endpoints.
"""
"""
...
@@ -459,7 +464,7 @@ class BaseWorkerHandler(ABC):
...
@@ -459,7 +464,7 @@ class BaseWorkerHandler(ABC):
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
return
{
"status"
:
"error"
,
"message"
:
str
(
e
)}
@
abstractmethod
@
abstractmethod
async
def
generate
(
self
,
request
,
c
ontext
)
->
Async
Gen
erator
[
dict
,
None
]:
def
generate
(
self
,
request
:
RequestT
,
context
:
C
ontext
)
->
Async
It
erator
[
ResponseT
]:
raise
NotImplementedError
raise
NotImplementedError
async
def
_monitor_abort
(
self
,
context
,
request_id
,
is_prefill
):
async
def
_monitor_abort
(
self
,
context
,
request_id
,
is_prefill
):
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment