Unverified Commit 026f361d authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: resolve lora request for multimodal workers (#6399)

parent bc320806
......@@ -12,6 +12,7 @@ import threading
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Final
import torch
......@@ -50,6 +51,14 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class LoRAInfo:
    """Immutable record describing one loaded LoRA adapter.

    Instances are stored as values in a handler's ``loaded_loras`` mapping,
    keyed by the adapter name.
    """

    # Integer identifier of the adapter; passed to the engine as ``lora_int_id``.
    id: int
    # Path the adapter was loaded from; passed to the engine as ``lora_path``.
    path: str
def _compute_mm_uuids(
multi_modal_data: Dict[str, Any] | None
) -> Dict[str, list[str]] | None:
......@@ -287,9 +296,8 @@ class BaseWorkerHandler(ABC):
# NIXL connector for frontend decoding - lazy initialized
self._nixl_connector = None
self._nixl_connector_lock = asyncio.Lock()
# LoRA tracking
self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {}
# LoRA tracking: name -> LoRAInfo(id, path)
self.loaded_loras: dict[str, LoRAInfo] = {}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA
self._lora_load_locks: dict[str, asyncio.Lock] = {}
# Guard lock-map access in case handlers are invoked from multiple threads.
......@@ -458,6 +466,16 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None:
self.temp_dirs.append(temp_dir)
def _resolve_lora_request(self, model_name: str | None) -> LoRARequest | None:
"""Return a LoRARequest if model_name is a loaded adapter, else None."""
if model_name and (lora := self.loaded_loras.get(model_name)):
return LoRARequest(
lora_name=model_name,
lora_int_id=lora.id,
lora_path=lora.path,
)
return None
def _get_lora_lock(self, lora_name: str) -> asyncio.Lock:
"""Get/create the per-LoRA lock without eagerly allocating a new lock each call."""
with self._lora_load_locks_guard:
......@@ -534,8 +552,8 @@ class BaseWorkerHandler(ABC):
try:
# Check if already loaded (idempotency check after acquiring lock).
# Another concurrent request may have loaded this LoRA while we waited.
if lora_name in self.lora_id_for_name:
lora_id = self.lora_id_for_name[lora_name]
if lora_name in self.loaded_loras:
lora_id = self.loaded_loras[lora_name].id
logger.info(
f"LoRA adapter already loaded (concurrent request completed): "
f"{lora_name} with ID {lora_id}"
......@@ -576,8 +594,7 @@ class BaseWorkerHandler(ABC):
)
# Track the LoRA
self.lora_id_for_name[lora_name] = lora_id
self.lora_name_to_path[lora_name] = lora_path
self.loaded_loras[lora_name] = LoRAInfo(id=lora_id, path=lora_path)
logger.info(
f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}"
)
......@@ -625,11 +642,7 @@ class BaseWorkerHandler(ABC):
f"Rolling back: removing LoRA '{lora_name}' from engine"
)
await self.engine_client.remove_lora(lora_id)
# Remove from tracking dictionaries
if lora_name in self.lora_id_for_name:
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
self.loaded_loras.pop(lora_name, None)
logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
......@@ -661,7 +674,7 @@ class BaseWorkerHandler(ABC):
# loaded, remove the lock entry (best-effort).
with self._lora_load_locks_guard:
if (
lora_name not in self.lora_id_for_name
lora_name not in self.loaded_loras
and self._lora_load_locks.get(lora_name) is lock
):
self._lora_load_locks.pop(lora_name, None)
......@@ -697,23 +710,22 @@ class BaseWorkerHandler(ABC):
async with lock:
try:
# Check if the LoRA exists *after* waiting for any in-progress load.
if lora_name not in self.lora_id_for_name:
lora = self.loaded_loras.get(lora_name)
if lora is None:
yield {
"status": "error",
"message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.lora_id_for_name.keys())}",
"message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.loaded_loras.keys())}",
}
return
logger.debug(f"Unloading LoRA adapter: {lora_name}")
lora_id = self.lora_id_for_name[lora_name]
lora_path = self.lora_name_to_path.get(lora_name)
lora_id = lora.id
lora_path = lora.path
await self.engine_client.remove_lora(lora_id)
# Remove from tracking dictionaries
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
# Remove from tracking
del self.loaded_loras[lora_name]
# Unregister the LoRA model from the model registry
if self.generate_endpoint is not None:
......@@ -734,11 +746,6 @@ class BaseWorkerHandler(ABC):
)
# Rollback: re-add the LoRA to the engine to maintain consistency
if lora_path is None:
logger.error(
f"Cannot rollback LoRA '{lora_name}': lora_path is None (data inconsistency)"
)
else:
try:
logger.debug(
f"Rolling back: re-adding LoRA '{lora_name}' to engine"
......@@ -750,9 +757,10 @@ class BaseWorkerHandler(ABC):
lora_path=lora_path,
)
)
# Re-add to tracking dictionaries
self.lora_id_for_name[lora_name] = lora_id
self.lora_name_to_path[lora_name] = lora_path
# Re-add to tracking
self.loaded_loras[lora_name] = LoRAInfo(
id=lora_id, path=lora_path
)
logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
......@@ -786,7 +794,7 @@ class BaseWorkerHandler(ABC):
# Remove lock entry once the LoRA is not loaded (or never was).
with self._lora_load_locks_guard:
if (
lora_name not in self.lora_id_for_name
lora_name not in self.loaded_loras
and self._lora_load_locks.get(lora_name) is lock
):
self._lora_load_locks.pop(lora_name, None)
......@@ -800,7 +808,7 @@ class BaseWorkerHandler(ABC):
Returns a dictionary of lora_name -> lora_id mappings.
"""
try:
loras = dict(self.lora_id_for_name)
loras = {name: lora.id for name, lora in self.loaded_loras.items()}
yield {
"status": "success",
"loras": loras,
......@@ -1354,19 +1362,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
)
# Extract LoRA request if present
# Check if model name matches a loaded LoRA adapter
lora_request = None
model_name = request.get("model")
if model_name and model_name in self.lora_id_for_name:
lora_id = self.lora_id_for_name[model_name]
lora_request = LoRARequest(
lora_name=model_name,
lora_int_id=lora_id,
lora_path=self.lora_name_to_path[model_name],
)
lora_request = self._resolve_lora_request(model_name)
if lora_request:
logger.info(
f"Decode request {request_id} will use LoRA adapter: {model_name} (ID: {lora_id})"
f"Decode request {request_id} will use LoRA adapter: {model_name} (ID: {lora_request.lora_int_id})"
)
else:
logger.debug(
......@@ -1570,20 +1570,12 @@ class PrefillWorkerHandler(BaseWorkerHandler):
sampling_params.min_tokens = 1
# Extract LoRA request if present
# Check if model name matches a loaded LoRA adapter
lora_request = None
model_name = request.get("model")
if model_name and model_name in self.lora_id_for_name:
lora_id = self.lora_id_for_name[model_name]
lora_request = LoRARequest(
lora_name=model_name,
lora_int_id=lora_id,
lora_path=self.lora_name_to_path[model_name],
)
lora_request = self._resolve_lora_request(model_name)
if lora_request:
logger.info(
f"Prefill request {request_id} will use LoRA adapter: {model_name} (ID: {lora_id}), "
f"path: {self.lora_name_to_path[model_name]}"
f"Prefill request {request_id} will use LoRA adapter: {model_name} "
f"(ID: {lora_request.lora_int_id}), path: {lora_request.lora_path}"
)
else:
logger.debug(
......
......@@ -687,6 +687,9 @@ async def init(
)
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras")
......@@ -812,13 +815,8 @@ async def init(
try:
logger.debug("Starting serve_endpoint for decode worker")
await asyncio.gather(
# For decode, we want to transfer the in-flight requests to other decode engines,
# because waiting for them to finish can take a long time for long output sequence lengths (OSLs).
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[
model_metrics_labels = [
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
......@@ -827,62 +825,42 @@ async def init(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
]
serve_tasks = [
# For decode, we want to transfer the in-flight requests to other decode engines,
# because waiting for them to finish can take a long time for long output sequence lengths (OSLs).
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=model_metrics_labels,
health_check_payload=health_check_payload,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks,
metrics_labels=[
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
metrics_labels=model_metrics_labels,
),
]
if lora_enabled:
serve_tasks.extend(
[
load_lora_endpoint.serve_endpoint(
handler.load_lora,
metrics_labels=[
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
metrics_labels=model_metrics_labels,
),
unload_lora_endpoint.serve_endpoint(
handler.unload_lora,
metrics_labels=[
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
metrics_labels=model_metrics_labels,
),
list_loras_endpoint.serve_endpoint(
handler.list_loras,
metrics_labels=[
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
metrics_labels=model_metrics_labels,
),
]
)
await asyncio.gather(*serve_tasks)
logger.debug("serve_endpoint completed for decode worker")
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
......
......@@ -56,6 +56,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
encode_worker_client: Client | None = None,
decode_worker_client: Client | None = None,
shutdown_event=None,
generate_endpoint=None,
):
# Get default_sampling_params from config
default_sampling_params = (
......@@ -69,6 +70,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client,
default_sampling_params,
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
shutdown_event=shutdown_event,
)
......@@ -318,6 +321,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
received_tensor_ids: list[int],
):
"""Run prefill and decode on this worker (aggregated mode)."""
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
......@@ -325,6 +329,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
),
sampling_params=request.sampling_params,
request_id=request.request_id,
lora_request=lora_request,
)
for tensor_id in received_tensor_ids:
......@@ -358,6 +363,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
prefill_only_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=prefill_only_request.engine_prompt["prompt_token_ids"],
......@@ -365,6 +371,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
),
sampling_params=prefill_only_request.sampling_params,
request_id=prefill_only_request.request_id,
lora_request=lora_request,
)
for tensor_id in received_tensor_ids:
......@@ -400,6 +407,14 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# embeddings_shape). Heavy multimodal data was consumed locally by
# engine_client.generate() and multimodal_inputs was cleared by
# `_finalize_request_metadata`.
#
# request.model (LoRA name) is preserved in the serialized request
# so the decode worker can resolve the same LoRA adapter.
if lora_request and request.model:
logger.debug(
f"Forwarding disaggregated decode with LoRA '{request.model}' "
f"— ensure the same adapter is loaded on the decode worker."
)
async for (
decode_response
) in await self.decode_worker_client.round_robin( # type: ignore[union-attr]
......
......@@ -26,6 +26,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client,
config: Config,
shutdown_event=None,
generate_endpoint=None,
):
# Get default_sampling_params from config
default_sampling_params = (
......@@ -39,6 +40,8 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client,
default_sampling_params,
enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
shutdown_event=shutdown_event,
)
......@@ -82,6 +85,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
image_grid_thw, embeddings_shape, request.request_id
)
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate(
prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"],
......@@ -89,6 +93,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
),
sampling_params=request.sampling_params,
request_id=request.request_id,
lora_request=lora_request,
)
async for response in gen:
......
......@@ -94,6 +94,12 @@ class WorkerFactory:
component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks")
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras")
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
(
......@@ -134,7 +140,12 @@ class WorkerFactory:
# Choose handler based on worker type
if config.multimodal_decode_worker:
handler = MultimodalDecodeWorkerHandler(
runtime, component, engine_client, config, shutdown_event
runtime,
component,
engine_client,
config,
shutdown_event,
generate_endpoint=generate_endpoint,
)
else:
handler = MultimodalPDWorkerHandler(
......@@ -145,6 +156,7 @@ class WorkerFactory:
encode_worker_client,
decode_worker_client,
shutdown_event,
generate_endpoint=generate_endpoint,
)
handler.add_temp_dir(prometheus_temp_dir)
......@@ -173,7 +185,7 @@ class WorkerFactory:
metrics_labels = [("model", config.served_model_name or config.model)]
try:
await asyncio.gather(
serve_tasks = [
generate_endpoint.serve_endpoint(
handler.generate,
metrics_labels=metrics_labels,
......@@ -182,7 +194,27 @@ class WorkerFactory:
handler.clear_kv_blocks,
metrics_labels=metrics_labels,
),
]
if lora_enabled:
serve_tasks.extend(
[
load_lora_endpoint.serve_endpoint(
handler.load_lora,
metrics_labels=metrics_labels,
),
unload_lora_endpoint.serve_endpoint(
handler.unload_lora,
metrics_labels=metrics_labels,
),
list_loras_endpoint.serve_endpoint(
handler.list_loras,
metrics_labels=metrics_labels,
),
]
)
await asyncio.gather(*serve_tasks)
except Exception as e:
logger.error(f"Failed to serve endpoints: {e}")
raise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment