Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
026f361d
Unverified
Commit
026f361d
authored
Feb 19, 2026
by
Biswa Panda
Committed by
GitHub
Feb 19, 2026
Browse files
feat: resolve lora request for multimodal workers (#6399)
parent
bc320806
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
156 additions
and
134 deletions
+156
-134
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+62
-70
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+39
-61
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
.../vllm/multimodal_handlers/multimodal_pd_worker_handler.py
+15
-0
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
...nts/src/dynamo/vllm/multimodal_handlers/worker_handler.py
+5
-0
components/src/dynamo/vllm/worker_factory.py
components/src/dynamo/vllm/worker_factory.py
+35
-3
No files found.
components/src/dynamo/vllm/handlers.py
View file @
026f361d
...
@@ -12,6 +12,7 @@ import threading
...
@@ -12,6 +12,7 @@ import threading
import
time
import
time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
asynccontextmanager
from
contextlib
import
asynccontextmanager
from
dataclasses
import
dataclass
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Final
from
typing
import
Any
,
AsyncGenerator
,
Dict
,
Final
import
torch
import
torch
...
@@ -50,6 +51,14 @@ configure_dynamo_logging()
...
@@ -50,6 +51,14 @@ configure_dynamo_logging()
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
(
frozen
=
True
)
class
LoRAInfo
:
"""Metadata for a loaded LoRA adapter."""
id
:
int
path
:
str
def
_compute_mm_uuids
(
def
_compute_mm_uuids
(
multi_modal_data
:
Dict
[
str
,
Any
]
|
None
multi_modal_data
:
Dict
[
str
,
Any
]
|
None
)
->
Dict
[
str
,
list
[
str
]]
|
None
:
)
->
Dict
[
str
,
list
[
str
]]
|
None
:
...
@@ -287,9 +296,8 @@ class BaseWorkerHandler(ABC):
...
@@ -287,9 +296,8 @@ class BaseWorkerHandler(ABC):
# NIXL connector for frontend decoding - lazy initialized
# NIXL connector for frontend decoding - lazy initialized
self
.
_nixl_connector
=
None
self
.
_nixl_connector
=
None
self
.
_nixl_connector_lock
=
asyncio
.
Lock
()
self
.
_nixl_connector_lock
=
asyncio
.
Lock
()
# LoRA tracking
# LoRA tracking: name -> LoRAInfo(id, path)
self
.
lora_id_for_name
:
dict
[
str
,
int
]
=
{}
self
.
loaded_loras
:
dict
[
str
,
LoRAInfo
]
=
{}
self
.
lora_name_to_path
:
dict
[
str
,
str
]
=
{}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA
# Per-LoRA locks to prevent concurrent load operations for the same LoRA
self
.
_lora_load_locks
:
dict
[
str
,
asyncio
.
Lock
]
=
{}
self
.
_lora_load_locks
:
dict
[
str
,
asyncio
.
Lock
]
=
{}
# Guard lock-map access in case handlers are invoked from multiple threads.
# Guard lock-map access in case handlers are invoked from multiple threads.
...
@@ -458,6 +466,16 @@ class BaseWorkerHandler(ABC):
...
@@ -458,6 +466,16 @@ class BaseWorkerHandler(ABC):
if
temp_dir
is
not
None
:
if
temp_dir
is
not
None
:
self
.
temp_dirs
.
append
(
temp_dir
)
self
.
temp_dirs
.
append
(
temp_dir
)
def
_resolve_lora_request
(
self
,
model_name
:
str
|
None
)
->
LoRARequest
|
None
:
"""Return a LoRARequest if model_name is a loaded adapter, else None."""
if
model_name
and
(
lora
:
=
self
.
loaded_loras
.
get
(
model_name
)):
return
LoRARequest
(
lora_name
=
model_name
,
lora_int_id
=
lora
.
id
,
lora_path
=
lora
.
path
,
)
return
None
def
_get_lora_lock
(
self
,
lora_name
:
str
)
->
asyncio
.
Lock
:
def
_get_lora_lock
(
self
,
lora_name
:
str
)
->
asyncio
.
Lock
:
"""Get/create the per-LoRA lock without eagerly allocating a new lock each call."""
"""Get/create the per-LoRA lock without eagerly allocating a new lock each call."""
with
self
.
_lora_load_locks_guard
:
with
self
.
_lora_load_locks_guard
:
...
@@ -534,8 +552,8 @@ class BaseWorkerHandler(ABC):
...
@@ -534,8 +552,8 @@ class BaseWorkerHandler(ABC):
try
:
try
:
# Check if already loaded (idempotency check after acquiring lock).
# Check if already loaded (idempotency check after acquiring lock).
# Another concurrent request may have loaded this LoRA while we waited.
# Another concurrent request may have loaded this LoRA while we waited.
if
lora_name
in
self
.
lo
ra_i
d_
f
or
_name
:
if
lora_name
in
self
.
lo
ade
d_
l
or
as
:
lora_id
=
self
.
lo
ra_i
d_
f
or
_name
[
lora_name
]
lora_id
=
self
.
lo
ade
d_
l
or
as
[
lora_name
]
.
id
logger
.
info
(
logger
.
info
(
f
"LoRA adapter already loaded (concurrent request completed): "
f
"LoRA adapter already loaded (concurrent request completed): "
f
"
{
lora_name
}
with ID
{
lora_id
}
"
f
"
{
lora_name
}
with ID
{
lora_id
}
"
...
@@ -576,8 +594,7 @@ class BaseWorkerHandler(ABC):
...
@@ -576,8 +594,7 @@ class BaseWorkerHandler(ABC):
)
)
# Track the LoRA
# Track the LoRA
self
.
lora_id_for_name
[
lora_name
]
=
lora_id
self
.
loaded_loras
[
lora_name
]
=
LoRAInfo
(
id
=
lora_id
,
path
=
lora_path
)
self
.
lora_name_to_path
[
lora_name
]
=
lora_path
logger
.
info
(
logger
.
info
(
f
"Successfully loaded LoRA adapter:
{
lora_name
}
with ID
{
lora_id
}
"
f
"Successfully loaded LoRA adapter:
{
lora_name
}
with ID
{
lora_id
}
"
)
)
...
@@ -625,11 +642,7 @@ class BaseWorkerHandler(ABC):
...
@@ -625,11 +642,7 @@ class BaseWorkerHandler(ABC):
f
"Rolling back: removing LoRA '
{
lora_name
}
' from engine"
f
"Rolling back: removing LoRA '
{
lora_name
}
' from engine"
)
)
await
self
.
engine_client
.
remove_lora
(
lora_id
)
await
self
.
engine_client
.
remove_lora
(
lora_id
)
# Remove from tracking dictionaries
self
.
loaded_loras
.
pop
(
lora_name
,
None
)
if
lora_name
in
self
.
lora_id_for_name
:
del
self
.
lora_id_for_name
[
lora_name
]
if
lora_name
in
self
.
lora_name_to_path
:
del
self
.
lora_name_to_path
[
lora_name
]
logger
.
debug
(
logger
.
debug
(
f
"Successfully rolled back LoRA '
{
lora_name
}
'"
f
"Successfully rolled back LoRA '
{
lora_name
}
'"
)
)
...
@@ -661,7 +674,7 @@ class BaseWorkerHandler(ABC):
...
@@ -661,7 +674,7 @@ class BaseWorkerHandler(ABC):
# loaded, remove the lock entry (best-effort).
# loaded, remove the lock entry (best-effort).
with
self
.
_lora_load_locks_guard
:
with
self
.
_lora_load_locks_guard
:
if
(
if
(
lora_name
not
in
self
.
lo
ra_i
d_
f
or
_name
lora_name
not
in
self
.
lo
ade
d_
l
or
as
and
self
.
_lora_load_locks
.
get
(
lora_name
)
is
lock
and
self
.
_lora_load_locks
.
get
(
lora_name
)
is
lock
):
):
self
.
_lora_load_locks
.
pop
(
lora_name
,
None
)
self
.
_lora_load_locks
.
pop
(
lora_name
,
None
)
...
@@ -697,23 +710,22 @@ class BaseWorkerHandler(ABC):
...
@@ -697,23 +710,22 @@ class BaseWorkerHandler(ABC):
async
with
lock
:
async
with
lock
:
try
:
try
:
# Check if the LoRA exists *after* waiting for any in-progress load.
# Check if the LoRA exists *after* waiting for any in-progress load.
if
lora_name
not
in
self
.
lora_id_for_name
:
lora
=
self
.
loaded_loras
.
get
(
lora_name
)
if
lora
is
None
:
yield
{
yield
{
"status"
:
"error"
,
"status"
:
"error"
,
"message"
:
f
"LoRA adapter '
{
lora_name
}
' not found. Available LoRAs:
{
list
(
self
.
lo
ra_i
d_
f
or
_name
.
keys
())
}
"
,
"message"
:
f
"LoRA adapter '
{
lora_name
}
' not found. Available LoRAs:
{
list
(
self
.
lo
ade
d_
l
or
as
.
keys
())
}
"
,
}
}
return
return
logger
.
debug
(
f
"Unloading LoRA adapter:
{
lora_name
}
"
)
logger
.
debug
(
f
"Unloading LoRA adapter:
{
lora_name
}
"
)
lora_id
=
self
.
lora
_
id
_for_name
[
lora_name
]
lora_id
=
lora
.
id
lora_path
=
self
.
lora_name_to_path
.
get
(
lora_name
)
lora_path
=
lora
.
path
await
self
.
engine_client
.
remove_lora
(
lora_id
)
await
self
.
engine_client
.
remove_lora
(
lora_id
)
# Remove from tracking dictionaries
# Remove from tracking
del
self
.
lora_id_for_name
[
lora_name
]
del
self
.
loaded_loras
[
lora_name
]
if
lora_name
in
self
.
lora_name_to_path
:
del
self
.
lora_name_to_path
[
lora_name
]
# Unregister the LoRA model from the model registry
# Unregister the LoRA model from the model registry
if
self
.
generate_endpoint
is
not
None
:
if
self
.
generate_endpoint
is
not
None
:
...
@@ -734,11 +746,6 @@ class BaseWorkerHandler(ABC):
...
@@ -734,11 +746,6 @@ class BaseWorkerHandler(ABC):
)
)
# Rollback: re-add the LoRA to the engine to maintain consistency
# Rollback: re-add the LoRA to the engine to maintain consistency
if
lora_path
is
None
:
logger
.
error
(
f
"Cannot rollback LoRA '
{
lora_name
}
': lora_path is None (data inconsistency)"
)
else
:
try
:
try
:
logger
.
debug
(
logger
.
debug
(
f
"Rolling back: re-adding LoRA '
{
lora_name
}
' to engine"
f
"Rolling back: re-adding LoRA '
{
lora_name
}
' to engine"
...
@@ -750,9 +757,10 @@ class BaseWorkerHandler(ABC):
...
@@ -750,9 +757,10 @@ class BaseWorkerHandler(ABC):
lora_path
=
lora_path
,
lora_path
=
lora_path
,
)
)
)
)
# Re-add to tracking dictionaries
# Re-add to tracking
self
.
lora_id_for_name
[
lora_name
]
=
lora_id
self
.
loaded_loras
[
lora_name
]
=
LoRAInfo
(
self
.
lora_name_to_path
[
lora_name
]
=
lora_path
id
=
lora_id
,
path
=
lora_path
)
logger
.
debug
(
logger
.
debug
(
f
"Successfully rolled back LoRA '
{
lora_name
}
'"
f
"Successfully rolled back LoRA '
{
lora_name
}
'"
)
)
...
@@ -786,7 +794,7 @@ class BaseWorkerHandler(ABC):
...
@@ -786,7 +794,7 @@ class BaseWorkerHandler(ABC):
# Remove lock entry once the LoRA is not loaded (or never was).
# Remove lock entry once the LoRA is not loaded (or never was).
with
self
.
_lora_load_locks_guard
:
with
self
.
_lora_load_locks_guard
:
if
(
if
(
lora_name
not
in
self
.
lo
ra_i
d_
f
or
_name
lora_name
not
in
self
.
lo
ade
d_
l
or
as
and
self
.
_lora_load_locks
.
get
(
lora_name
)
is
lock
and
self
.
_lora_load_locks
.
get
(
lora_name
)
is
lock
):
):
self
.
_lora_load_locks
.
pop
(
lora_name
,
None
)
self
.
_lora_load_locks
.
pop
(
lora_name
,
None
)
...
@@ -800,7 +808,7 @@ class BaseWorkerHandler(ABC):
...
@@ -800,7 +808,7 @@ class BaseWorkerHandler(ABC):
Returns a dictionary of lora_name -> lora_id mappings.
Returns a dictionary of lora_name -> lora_id mappings.
"""
"""
try
:
try
:
loras
=
dict
(
self
.
lo
ra_i
d_
f
or
_name
)
loras
=
{
name
:
lora
.
id
for
name
,
lora
in
self
.
lo
ade
d_
l
or
as
.
items
()}
yield
{
yield
{
"status"
:
"success"
,
"status"
:
"success"
,
"loras"
:
loras
,
"loras"
:
loras
,
...
@@ -1354,19 +1362,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -1354,19 +1362,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
)
)
# Extract LoRA request if present
# Extract LoRA request if present
# Check if model name matches a loaded LoRA adapter
lora_request
=
None
model_name
=
request
.
get
(
"model"
)
model_name
=
request
.
get
(
"model"
)
lora_request
=
self
.
_resolve_lora_request
(
model_name
)
if
model_name
and
model_name
in
self
.
lora_id_for_name
:
if
lora_request
:
lora_id
=
self
.
lora_id_for_name
[
model_name
]
lora_request
=
LoRARequest
(
lora_name
=
model_name
,
lora_int_id
=
lora_id
,
lora_path
=
self
.
lora_name_to_path
[
model_name
],
)
logger
.
info
(
logger
.
info
(
f
"Decode request
{
request_id
}
will use LoRA adapter:
{
model_name
}
(ID:
{
lora_id
}
)"
f
"Decode request
{
request_id
}
will use LoRA adapter:
{
model_name
}
(ID:
{
lora_
request
.
lora_int_
id
}
)"
)
)
else
:
else
:
logger
.
debug
(
logger
.
debug
(
...
@@ -1570,20 +1570,12 @@ class PrefillWorkerHandler(BaseWorkerHandler):
...
@@ -1570,20 +1570,12 @@ class PrefillWorkerHandler(BaseWorkerHandler):
sampling_params
.
min_tokens
=
1
sampling_params
.
min_tokens
=
1
# Extract LoRA request if present
# Extract LoRA request if present
# Check if model name matches a loaded LoRA adapter
lora_request
=
None
model_name
=
request
.
get
(
"model"
)
model_name
=
request
.
get
(
"model"
)
lora_request
=
self
.
_resolve_lora_request
(
model_name
)
if
model_name
and
model_name
in
self
.
lora_id_for_name
:
if
lora_request
:
lora_id
=
self
.
lora_id_for_name
[
model_name
]
lora_request
=
LoRARequest
(
lora_name
=
model_name
,
lora_int_id
=
lora_id
,
lora_path
=
self
.
lora_name_to_path
[
model_name
],
)
logger
.
info
(
logger
.
info
(
f
"Prefill request
{
request_id
}
will use LoRA adapter:
{
model_name
}
(ID:
{
lora_id
}
),
"
f
"Prefill request
{
request_id
}
will use LoRA adapter:
{
model_name
}
"
f
"
path:
{
self
.
lora_name_to_path
[
model_name
]
}
"
f
"
(ID:
{
lora_request
.
lora_int_id
}
), path:
{
lora_request
.
lora_path
}
"
)
)
else
:
else
:
logger
.
debug
(
logger
.
debug
(
...
...
components/src/dynamo/vllm/main.py
View file @
026f361d
...
@@ -687,6 +687,9 @@ async def init(
...
@@ -687,6 +687,9 @@ async def init(
)
)
component
=
generate_endpoint
.
component
()
component
=
generate_endpoint
.
component
()
clear_endpoint
=
component
.
endpoint
(
"clear_kv_blocks"
)
clear_endpoint
=
component
.
endpoint
(
"clear_kv_blocks"
)
lora_enabled
=
config
.
engine_args
.
enable_lora
if
lora_enabled
:
load_lora_endpoint
=
component
.
endpoint
(
"load_lora"
)
load_lora_endpoint
=
component
.
endpoint
(
"load_lora"
)
unload_lora_endpoint
=
component
.
endpoint
(
"unload_lora"
)
unload_lora_endpoint
=
component
.
endpoint
(
"unload_lora"
)
list_loras_endpoint
=
component
.
endpoint
(
"list_loras"
)
list_loras_endpoint
=
component
.
endpoint
(
"list_loras"
)
...
@@ -812,13 +815,8 @@ async def init(
...
@@ -812,13 +815,8 @@ async def init(
try
:
try
:
logger
.
debug
(
"Starting serve_endpoint for decode worker"
)
logger
.
debug
(
"Starting serve_endpoint for decode worker"
)
await
asyncio
.
gather
(
# for decode, we want to transfer the in-flight requests to other decode engines,
model_metrics_labels
=
[
# because waiting them to finish can take a long time for long OSLs
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
graceful_shutdown
=
True
,
metrics_labels
=
[
(
(
prometheus_names
.
labels
.
MODEL
,
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
or
config
.
model
,
config
.
served_model_name
or
config
.
model
,
...
@@ -827,62 +825,42 @@ async def init(
...
@@ -827,62 +825,42 @@ async def init(
prometheus_names
.
labels
.
MODEL_NAME
,
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
or
config
.
model
,
config
.
served_model_name
or
config
.
model
,
),
),
],
]
serve_tasks
=
[
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
graceful_shutdown
=
True
,
metrics_labels
=
model_metrics_labels
,
health_check_payload
=
health_check_payload
,
health_check_payload
=
health_check_payload
,
),
),
clear_endpoint
.
serve_endpoint
(
clear_endpoint
.
serve_endpoint
(
handler
.
clear_kv_blocks
,
handler
.
clear_kv_blocks
,
metrics_labels
=
[
metrics_labels
=
model_metrics_labels
,
(
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
or
config
.
model
,
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
or
config
.
model
,
),
],
),
),
]
if
lora_enabled
:
serve_tasks
.
extend
(
[
load_lora_endpoint
.
serve_endpoint
(
load_lora_endpoint
.
serve_endpoint
(
handler
.
load_lora
,
handler
.
load_lora
,
metrics_labels
=
[
metrics_labels
=
model_metrics_labels
,
(
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
or
config
.
model
,
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
or
config
.
model
,
),
],
),
),
unload_lora_endpoint
.
serve_endpoint
(
unload_lora_endpoint
.
serve_endpoint
(
handler
.
unload_lora
,
handler
.
unload_lora
,
metrics_labels
=
[
metrics_labels
=
model_metrics_labels
,
(
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
or
config
.
model
,
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
or
config
.
model
,
),
],
),
),
list_loras_endpoint
.
serve_endpoint
(
list_loras_endpoint
.
serve_endpoint
(
handler
.
list_loras
,
handler
.
list_loras
,
metrics_labels
=
[
metrics_labels
=
model_metrics_labels
,
(
prometheus_names
.
labels
.
MODEL
,
config
.
served_model_name
or
config
.
model
,
),
(
prometheus_names
.
labels
.
MODEL_NAME
,
config
.
served_model_name
or
config
.
model
,
),
],
),
),
]
)
)
await
asyncio
.
gather
(
*
serve_tasks
)
logger
.
debug
(
"serve_endpoint completed for decode worker"
)
logger
.
debug
(
"serve_endpoint completed for decode worker"
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Failed to serve endpoints:
{
e
}
"
)
logger
.
error
(
f
"Failed to serve endpoints:
{
e
}
"
)
...
...
components/src/dynamo/vllm/multimodal_handlers/multimodal_pd_worker_handler.py
View file @
026f361d
...
@@ -56,6 +56,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -56,6 +56,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
encode_worker_client
:
Client
|
None
=
None
,
encode_worker_client
:
Client
|
None
=
None
,
decode_worker_client
:
Client
|
None
=
None
,
decode_worker_client
:
Client
|
None
=
None
,
shutdown_event
=
None
,
shutdown_event
=
None
,
generate_endpoint
=
None
,
):
):
# Get default_sampling_params from config
# Get default_sampling_params from config
default_sampling_params
=
(
default_sampling_params
=
(
...
@@ -69,6 +70,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -69,6 +70,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client
,
engine_client
,
default_sampling_params
,
default_sampling_params
,
enable_multimodal
=
config
.
enable_multimodal
,
enable_multimodal
=
config
.
enable_multimodal
,
generate_endpoint
=
generate_endpoint
,
config
=
config
,
shutdown_event
=
shutdown_event
,
shutdown_event
=
shutdown_event
,
)
)
...
@@ -318,6 +321,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -318,6 +321,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
received_tensor_ids
:
list
[
int
],
received_tensor_ids
:
list
[
int
],
):
):
"""Run prefill and decode on this worker (aggregated mode)."""
"""Run prefill and decode on this worker (aggregated mode)."""
lora_request
=
self
.
_resolve_lora_request
(
request
.
model
)
gen
=
self
.
engine_client
.
generate
(
gen
=
self
.
engine_client
.
generate
(
prompt
=
TokensPrompt
(
prompt
=
TokensPrompt
(
prompt_token_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
],
prompt_token_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
],
...
@@ -325,6 +329,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -325,6 +329,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
),
),
sampling_params
=
request
.
sampling_params
,
sampling_params
=
request
.
sampling_params
,
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
lora_request
=
lora_request
,
)
)
for
tensor_id
in
received_tensor_ids
:
for
tensor_id
in
received_tensor_ids
:
...
@@ -358,6 +363,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -358,6 +363,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
prefill_only_request
.
sampling_params
.
min_tokens
=
1
prefill_only_request
.
sampling_params
.
min_tokens
=
1
logger
.
debug
(
"Prefill request: %s"
,
prefill_only_request
)
logger
.
debug
(
"Prefill request: %s"
,
prefill_only_request
)
lora_request
=
self
.
_resolve_lora_request
(
request
.
model
)
gen
=
self
.
engine_client
.
generate
(
gen
=
self
.
engine_client
.
generate
(
prompt
=
TokensPrompt
(
prompt
=
TokensPrompt
(
prompt_token_ids
=
prefill_only_request
.
engine_prompt
[
"prompt_token_ids"
],
prompt_token_ids
=
prefill_only_request
.
engine_prompt
[
"prompt_token_ids"
],
...
@@ -365,6 +371,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -365,6 +371,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
),
),
sampling_params
=
prefill_only_request
.
sampling_params
,
sampling_params
=
prefill_only_request
.
sampling_params
,
request_id
=
prefill_only_request
.
request_id
,
request_id
=
prefill_only_request
.
request_id
,
lora_request
=
lora_request
,
)
)
for
tensor_id
in
received_tensor_ids
:
for
tensor_id
in
received_tensor_ids
:
...
@@ -400,6 +407,14 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -400,6 +407,14 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# embeddings_shape). Heavy multimodal data was consumed locally by
# embeddings_shape). Heavy multimodal data was consumed locally by
# engine_client.generate() and multimodal_inputs was cleared by
# engine_client.generate() and multimodal_inputs was cleared by
# `_finalize_request_metadata`.
# `_finalize_request_metadata`.
#
# request.model (LoRA name) is preserved in the serialized request
# so the decode worker can resolve the same LoRA adapter.
if
lora_request
and
request
.
model
:
logger
.
debug
(
f
"Forwarding disaggregated decode with LoRA '
{
request
.
model
}
' "
f
"— ensure the same adapter is loaded on the decode worker."
)
async
for
(
async
for
(
decode_response
decode_response
)
in
await
self
.
decode_worker_client
.
round_robin
(
# type: ignore[union-attr]
)
in
await
self
.
decode_worker_client
.
round_robin
(
# type: ignore[union-attr]
...
...
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
View file @
026f361d
...
@@ -26,6 +26,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -26,6 +26,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client
,
engine_client
,
config
:
Config
,
config
:
Config
,
shutdown_event
=
None
,
shutdown_event
=
None
,
generate_endpoint
=
None
,
):
):
# Get default_sampling_params from config
# Get default_sampling_params from config
default_sampling_params
=
(
default_sampling_params
=
(
...
@@ -39,6 +40,8 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -39,6 +40,8 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client
,
engine_client
,
default_sampling_params
,
default_sampling_params
,
enable_multimodal
=
config
.
enable_multimodal
,
enable_multimodal
=
config
.
enable_multimodal
,
generate_endpoint
=
generate_endpoint
,
config
=
config
,
shutdown_event
=
shutdown_event
,
shutdown_event
=
shutdown_event
,
)
)
...
@@ -82,6 +85,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -82,6 +85,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
image_grid_thw
,
embeddings_shape
,
request
.
request_id
image_grid_thw
,
embeddings_shape
,
request
.
request_id
)
)
lora_request
=
self
.
_resolve_lora_request
(
request
.
model
)
gen
=
self
.
engine_client
.
generate
(
gen
=
self
.
engine_client
.
generate
(
prompt
=
TokensPrompt
(
prompt
=
TokensPrompt
(
prompt_token_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
],
prompt_token_ids
=
request
.
engine_prompt
[
"prompt_token_ids"
],
...
@@ -89,6 +93,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -89,6 +93,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
),
),
sampling_params
=
request
.
sampling_params
,
sampling_params
=
request
.
sampling_params
,
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
lora_request
=
lora_request
,
)
)
async
for
response
in
gen
:
async
for
response
in
gen
:
...
...
components/src/dynamo/vllm/worker_factory.py
View file @
026f361d
...
@@ -94,6 +94,12 @@ class WorkerFactory:
...
@@ -94,6 +94,12 @@ class WorkerFactory:
component
=
generate_endpoint
.
component
()
component
=
generate_endpoint
.
component
()
clear_endpoint
=
component
.
endpoint
(
"clear_kv_blocks"
)
clear_endpoint
=
component
.
endpoint
(
"clear_kv_blocks"
)
lora_enabled
=
config
.
engine_args
.
enable_lora
if
lora_enabled
:
load_lora_endpoint
=
component
.
endpoint
(
"load_lora"
)
unload_lora_endpoint
=
component
.
endpoint
(
"unload_lora"
)
list_loras_endpoint
=
component
.
endpoint
(
"list_loras"
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if
pre_created_engine
is
not
None
:
if
pre_created_engine
is
not
None
:
(
(
...
@@ -134,7 +140,12 @@ class WorkerFactory:
...
@@ -134,7 +140,12 @@ class WorkerFactory:
# Choose handler based on worker type
# Choose handler based on worker type
if
config
.
multimodal_decode_worker
:
if
config
.
multimodal_decode_worker
:
handler
=
MultimodalDecodeWorkerHandler
(
handler
=
MultimodalDecodeWorkerHandler
(
runtime
,
component
,
engine_client
,
config
,
shutdown_event
runtime
,
component
,
engine_client
,
config
,
shutdown_event
,
generate_endpoint
=
generate_endpoint
,
)
)
else
:
else
:
handler
=
MultimodalPDWorkerHandler
(
handler
=
MultimodalPDWorkerHandler
(
...
@@ -145,6 +156,7 @@ class WorkerFactory:
...
@@ -145,6 +156,7 @@ class WorkerFactory:
encode_worker_client
,
encode_worker_client
,
decode_worker_client
,
decode_worker_client
,
shutdown_event
,
shutdown_event
,
generate_endpoint
=
generate_endpoint
,
)
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
...
@@ -173,7 +185,7 @@ class WorkerFactory:
...
@@ -173,7 +185,7 @@ class WorkerFactory:
metrics_labels
=
[(
"model"
,
config
.
served_model_name
or
config
.
model
)]
metrics_labels
=
[(
"model"
,
config
.
served_model_name
or
config
.
model
)]
try
:
try
:
await
asyncio
.
gather
(
serve_tasks
=
[
generate_endpoint
.
serve_endpoint
(
generate_endpoint
.
serve_endpoint
(
handler
.
generate
,
handler
.
generate
,
metrics_labels
=
metrics_labels
,
metrics_labels
=
metrics_labels
,
...
@@ -182,7 +194,27 @@ class WorkerFactory:
...
@@ -182,7 +194,27 @@ class WorkerFactory:
handler
.
clear_kv_blocks
,
handler
.
clear_kv_blocks
,
metrics_labels
=
metrics_labels
,
metrics_labels
=
metrics_labels
,
),
),
]
if
lora_enabled
:
serve_tasks
.
extend
(
[
load_lora_endpoint
.
serve_endpoint
(
handler
.
load_lora
,
metrics_labels
=
metrics_labels
,
),
unload_lora_endpoint
.
serve_endpoint
(
handler
.
unload_lora
,
metrics_labels
=
metrics_labels
,
),
list_loras_endpoint
.
serve_endpoint
(
handler
.
list_loras
,
metrics_labels
=
metrics_labels
,
),
]
)
)
await
asyncio
.
gather
(
*
serve_tasks
)
except
Exception
as
e
:
except
Exception
as
e
:
logger
.
error
(
f
"Failed to serve endpoints:
{
e
}
"
)
logger
.
error
(
f
"Failed to serve endpoints:
{
e
}
"
)
raise
raise
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment