Unverified Commit 07721d1c authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix: handle concurrent load lora calls (#5184)

parent d9cc6f6b
...@@ -5,6 +5,7 @@ import asyncio ...@@ -5,6 +5,7 @@ import asyncio
import logging import logging
import os import os
import tempfile import tempfile
import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
...@@ -254,6 +255,10 @@ class BaseWorkerHandler(ABC): ...@@ -254,6 +255,10 @@ class BaseWorkerHandler(ABC):
# LoRA tracking # LoRA tracking
self.lora_id_for_name: dict[str, int] = {} self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {} self.lora_name_to_path: dict[str, str] = {}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA
self._lora_load_locks: dict[str, asyncio.Lock] = {}
# Guard lock-map access in case handlers are invoked from multiple threads.
self._lora_load_locks_guard = threading.Lock()
self.use_vllm_tokenizer = use_vllm_tokenizer self.use_vllm_tokenizer = use_vllm_tokenizer
...@@ -309,6 +314,15 @@ class BaseWorkerHandler(ABC): ...@@ -309,6 +314,15 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None: if temp_dir is not None:
self.temp_dirs.append(temp_dir) self.temp_dirs.append(temp_dir)
def _get_lora_lock(self, lora_name: str) -> asyncio.Lock:
"""Get/create the per-LoRA lock without eagerly allocating a new lock each call."""
with self._lora_load_locks_guard:
lock = self._lora_load_locks.get(lora_name)
if lock is None:
lock = asyncio.Lock()
self._lora_load_locks[lora_name] = lock
return lock
async def load_lora(self, request=None): async def load_lora(self, request=None):
""" """
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine. Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
...@@ -320,6 +334,9 @@ class BaseWorkerHandler(ABC): ...@@ -320,6 +334,9 @@ class BaseWorkerHandler(ABC):
"uri": str # e.g., "s3://bucket/path" or "file:///path" "uri": str # e.g., "s3://bucket/path" or "file:///path"
} }
} }
This method is idempotent - concurrent calls for the same LoRA will be
serialized and only one load operation will happen.
""" """
try: try:
if request is None: if request is None:
...@@ -367,7 +384,29 @@ class BaseWorkerHandler(ABC): ...@@ -367,7 +384,29 @@ class BaseWorkerHandler(ABC):
} }
return return
logger.info(f"Downloading LoRA adapter: {lora_name} from {lora_uri}") # Serialize load/unload operations per lora_name.
lock = self._get_lora_lock(lora_name)
async with lock:
try:
# Check if already loaded (idempotency check after acquiring lock).
# Another concurrent request may have loaded this LoRA while we waited.
if lora_name in self.lora_id_for_name:
lora_id = self.lora_id_for_name[lora_name]
logger.info(
f"LoRA adapter already loaded (concurrent request completed): "
f"{lora_name} with ID {lora_id}"
)
yield {
"status": "success",
"message": f"LoRA adapter '{lora_name}' already loaded",
"lora_name": lora_name,
"lora_id": lora_id,
}
return
logger.info(
f"Downloading LoRA adapter: {lora_name} from {lora_uri}"
)
download_result = await lora_manager.download_lora(lora_uri) download_result = await lora_manager.download_lora(lora_uri)
if download_result["status"] != "success": if download_result["status"] != "success":
...@@ -386,7 +425,9 @@ class BaseWorkerHandler(ABC): ...@@ -386,7 +425,9 @@ class BaseWorkerHandler(ABC):
# Add the LoRA to the engine # Add the LoRA to the engine
await self.engine_client.add_lora( await self.engine_client.add_lora(
LoRARequest( LoRARequest(
lora_name=lora_name, lora_int_id=lora_id, lora_path=lora_path lora_name=lora_name,
lora_int_id=lora_id,
lora_path=lora_path,
) )
) )
...@@ -400,13 +441,14 @@ class BaseWorkerHandler(ABC): ...@@ -400,13 +441,14 @@ class BaseWorkerHandler(ABC):
# Publish LoRA as a ModelDeploymentCard with format: # Publish LoRA as a ModelDeploymentCard with format:
# v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug} # v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
# This allows the frontend to discover it and route correctly to the worker instance # This allows the frontend to discover it and route correctly to the worker instance
if self.generate_endpoint is not None and self.config is not None: if self.generate_endpoint is not None and self.config is not None:
logger.debug( logger.debug(
f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}" f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
) )
try: try:
logger.debug(f"Publishing LoRA '{lora_name}' ModelDeploymentCard") logger.debug(
f"Publishing LoRA '{lora_name}' ModelDeploymentCard"
)
# Mark this as a LoRA in user_data # Mark this as a LoRA in user_data
user_data = { user_data = {
...@@ -429,12 +471,9 @@ class BaseWorkerHandler(ABC): ...@@ -429,12 +471,9 @@ class BaseWorkerHandler(ABC):
f"Successfully published LoRA '{lora_name}' ModelDeploymentCard" f"Successfully published LoRA '{lora_name}' ModelDeploymentCard"
) )
except Exception as e: except Exception as e:
import traceback logger.exception(
logger.error(
f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}" f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}"
) )
logger.debug(f"Traceback: {traceback.format_exc()}")
# Rollback: remove the LoRA from the engine to maintain consistency # Rollback: remove the LoRA from the engine to maintain consistency
try: try:
...@@ -447,9 +486,11 @@ class BaseWorkerHandler(ABC): ...@@ -447,9 +486,11 @@ class BaseWorkerHandler(ABC):
del self.lora_id_for_name[lora_name] del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path: if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name] del self.lora_name_to_path[lora_name]
logger.debug(f"Successfully rolled back LoRA '{lora_name}'") logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
except Exception as rollback_error: except Exception as rollback_error:
logger.error( logger.exception(
f"Failed to rollback LoRA {lora_name}: {rollback_error}" f"Failed to rollback LoRA {lora_name}: {rollback_error}"
) )
...@@ -471,8 +512,17 @@ class BaseWorkerHandler(ABC): ...@@ -471,8 +512,17 @@ class BaseWorkerHandler(ABC):
"lora_name": lora_name, "lora_name": lora_name,
"lora_id": lora_id, "lora_id": lora_id,
} }
finally:
# Avoid lock-map growth on failed loads: if this attempt did not leave the LoRA
# loaded, remove the lock entry (best-effort).
with self._lora_load_locks_guard:
if (
lora_name not in self.lora_id_for_name
and self._lora_load_locks.get(lora_name) is lock
):
self._lora_load_locks.pop(lora_name, None)
except Exception as e: except Exception as e:
logger.error(f"Failed to load LoRA adapter: {e}") logger.exception(f"Failed to load LoRA adapter: {e}")
yield {"status": "error", "message": str(e)} yield {"status": "error", "message": str(e)}
async def unload_lora(self, request=None): async def unload_lora(self, request=None):
...@@ -498,7 +548,11 @@ class BaseWorkerHandler(ABC): ...@@ -498,7 +548,11 @@ class BaseWorkerHandler(ABC):
} }
return return
# Check if the LoRA exists # Serialize load/unload operations per lora_name.
lock = self._get_lora_lock(lora_name)
async with lock:
try:
# Check if the LoRA exists *after* waiting for any in-progress load.
if lora_name not in self.lora_id_for_name: if lora_name not in self.lora_id_for_name:
yield { yield {
"status": "error", "status": "error",
...@@ -517,9 +571,11 @@ class BaseWorkerHandler(ABC): ...@@ -517,9 +571,11 @@ class BaseWorkerHandler(ABC):
if lora_name in self.lora_name_to_path: if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name] del self.lora_name_to_path[lora_name]
# Unregister the LoRA model from the model registry (outside lock) # Unregister the LoRA model from the model registry
if self.generate_endpoint is not None: if self.generate_endpoint is not None:
logger.debug(f"Unregistering LoRA '{lora_name}' ModelDeploymentCard") logger.debug(
f"Unregistering LoRA '{lora_name}' ModelDeploymentCard"
)
try: try:
await unregister_llm( await unregister_llm(
endpoint=self.generate_endpoint, endpoint=self.generate_endpoint,
...@@ -529,14 +585,16 @@ class BaseWorkerHandler(ABC): ...@@ -529,14 +585,16 @@ class BaseWorkerHandler(ABC):
f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard" f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard"
) )
except Exception as e: except Exception as e:
import traceback logger.exception(
logger.error(
f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}" f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}"
) )
logger.debug(f"Traceback: {traceback.format_exc()}")
# Rollback: re-add the LoRA to the engine to maintain consistency # Rollback: re-add the LoRA to the engine to maintain consistency
if lora_path is None:
logger.error(
f"Cannot rollback LoRA '{lora_name}': lora_path is None (data inconsistency)"
)
else:
try: try:
logger.debug( logger.debug(
f"Rolling back: re-adding LoRA '{lora_name}' to engine" f"Rolling back: re-adding LoRA '{lora_name}' to engine"
...@@ -550,11 +608,12 @@ class BaseWorkerHandler(ABC): ...@@ -550,11 +608,12 @@ class BaseWorkerHandler(ABC):
) )
# Re-add to tracking dictionaries # Re-add to tracking dictionaries
self.lora_id_for_name[lora_name] = lora_id self.lora_id_for_name[lora_name] = lora_id
if lora_path:
self.lora_name_to_path[lora_name] = lora_path self.lora_name_to_path[lora_name] = lora_path
logger.debug(f"Successfully rolled back LoRA '{lora_name}'") logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
except Exception as rollback_error: except Exception as rollback_error:
logger.error( logger.exception(
f"Failed to rollback LoRA {lora_name}: {rollback_error}" f"Failed to rollback LoRA {lora_name}: {rollback_error}"
) )
...@@ -579,8 +638,16 @@ class BaseWorkerHandler(ABC): ...@@ -579,8 +638,16 @@ class BaseWorkerHandler(ABC):
"lora_name": lora_name, "lora_name": lora_name,
"lora_id": lora_id, "lora_id": lora_id,
} }
finally:
# Remove lock entry once the LoRA is not loaded (or never was).
with self._lora_load_locks_guard:
if (
lora_name not in self.lora_id_for_name
and self._lora_load_locks.get(lora_name) is lock
):
self._lora_load_locks.pop(lora_name, None)
except Exception as e: except Exception as e:
logger.error(f"Failed to unload LoRA adapter: {e}") logger.exception(f"Failed to unload LoRA adapter: {e}")
yield {"status": "error", "message": str(e)} yield {"status": "error", "message": str(e)}
async def list_loras(self, request=None): async def list_loras(self, request=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment