"deploy/vscode:/vscode.git/clone" did not exist on "5d90e530bc4ff683a779b2bc0b9237cfcc2504fd"
Unverified Commit 07721d1c authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix: handle concurrent load lora calls (#5184)

parent d9cc6f6b
......@@ -5,6 +5,7 @@ import asyncio
import logging
import os
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
......@@ -254,6 +255,10 @@ class BaseWorkerHandler(ABC):
# LoRA tracking
self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA
self._lora_load_locks: dict[str, asyncio.Lock] = {}
# Guard lock-map access in case handlers are invoked from multiple threads.
self._lora_load_locks_guard = threading.Lock()
self.use_vllm_tokenizer = use_vllm_tokenizer
......@@ -309,6 +314,15 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None:
self.temp_dirs.append(temp_dir)
def _get_lora_lock(self, lora_name: str) -> asyncio.Lock:
"""Get/create the per-LoRA lock without eagerly allocating a new lock each call."""
with self._lora_load_locks_guard:
lock = self._lora_load_locks.get(lora_name)
if lock is None:
lock = asyncio.Lock()
self._lora_load_locks[lora_name] = lock
return lock
async def load_lora(self, request=None):
"""
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
......@@ -320,6 +334,9 @@ class BaseWorkerHandler(ABC):
"uri": str # e.g., "s3://bucket/path" or "file:///path"
}
}
This method is idempotent - concurrent calls for the same LoRA will be
serialized and only one load operation will happen.
"""
try:
if request is None:
......@@ -367,7 +384,29 @@ class BaseWorkerHandler(ABC):
}
return
logger.info(f"Downloading LoRA adapter: {lora_name} from {lora_uri}")
# Serialize load/unload operations per lora_name.
lock = self._get_lora_lock(lora_name)
async with lock:
try:
# Check if already loaded (idempotency check after acquiring lock).
# Another concurrent request may have loaded this LoRA while we waited.
if lora_name in self.lora_id_for_name:
lora_id = self.lora_id_for_name[lora_name]
logger.info(
f"LoRA adapter already loaded (concurrent request completed): "
f"{lora_name} with ID {lora_id}"
)
yield {
"status": "success",
"message": f"LoRA adapter '{lora_name}' already loaded",
"lora_name": lora_name,
"lora_id": lora_id,
}
return
logger.info(
f"Downloading LoRA adapter: {lora_name} from {lora_uri}"
)
download_result = await lora_manager.download_lora(lora_uri)
if download_result["status"] != "success":
......@@ -386,7 +425,9 @@ class BaseWorkerHandler(ABC):
# Add the LoRA to the engine
await self.engine_client.add_lora(
LoRARequest(
lora_name=lora_name, lora_int_id=lora_id, lora_path=lora_path
lora_name=lora_name,
lora_int_id=lora_id,
lora_path=lora_path,
)
)
......@@ -400,13 +441,14 @@ class BaseWorkerHandler(ABC):
# Publish LoRA as a ModelDeploymentCard with format:
# v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
# This allows the frontend to discover it and route correctly to the worker instance
if self.generate_endpoint is not None and self.config is not None:
logger.debug(
f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
)
try:
logger.debug(f"Publishing LoRA '{lora_name}' ModelDeploymentCard")
logger.debug(
f"Publishing LoRA '{lora_name}' ModelDeploymentCard"
)
# Mark this as a LoRA in user_data
user_data = {
......@@ -429,12 +471,9 @@ class BaseWorkerHandler(ABC):
f"Successfully published LoRA '{lora_name}' ModelDeploymentCard"
)
except Exception as e:
import traceback
logger.error(
logger.exception(
f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}"
)
logger.debug(f"Traceback: {traceback.format_exc()}")
# Rollback: remove the LoRA from the engine to maintain consistency
try:
......@@ -447,9 +486,11 @@ class BaseWorkerHandler(ABC):
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
logger.debug(f"Successfully rolled back LoRA '{lora_name}'")
logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
except Exception as rollback_error:
logger.error(
logger.exception(
f"Failed to rollback LoRA {lora_name}: {rollback_error}"
)
......@@ -471,8 +512,17 @@ class BaseWorkerHandler(ABC):
"lora_name": lora_name,
"lora_id": lora_id,
}
finally:
# Avoid lock-map growth on failed loads: if this attempt did not leave the LoRA
# loaded, remove the lock entry (best-effort).
with self._lora_load_locks_guard:
if (
lora_name not in self.lora_id_for_name
and self._lora_load_locks.get(lora_name) is lock
):
self._lora_load_locks.pop(lora_name, None)
except Exception as e:
logger.error(f"Failed to load LoRA adapter: {e}")
logger.exception(f"Failed to load LoRA adapter: {e}")
yield {"status": "error", "message": str(e)}
async def unload_lora(self, request=None):
......@@ -498,7 +548,11 @@ class BaseWorkerHandler(ABC):
}
return
# Check if the LoRA exists
# Serialize load/unload operations per lora_name.
lock = self._get_lora_lock(lora_name)
async with lock:
try:
# Check if the LoRA exists *after* waiting for any in-progress load.
if lora_name not in self.lora_id_for_name:
yield {
"status": "error",
......@@ -517,9 +571,11 @@ class BaseWorkerHandler(ABC):
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
# Unregister the LoRA model from the model registry (outside lock)
# Unregister the LoRA model from the model registry
if self.generate_endpoint is not None:
logger.debug(f"Unregistering LoRA '{lora_name}' ModelDeploymentCard")
logger.debug(
f"Unregistering LoRA '{lora_name}' ModelDeploymentCard"
)
try:
await unregister_llm(
endpoint=self.generate_endpoint,
......@@ -529,14 +585,16 @@ class BaseWorkerHandler(ABC):
f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard"
)
except Exception as e:
import traceback
logger.error(
logger.exception(
f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}"
)
logger.debug(f"Traceback: {traceback.format_exc()}")
# Rollback: re-add the LoRA to the engine to maintain consistency
if lora_path is None:
logger.error(
f"Cannot rollback LoRA '{lora_name}': lora_path is None (data inconsistency)"
)
else:
try:
logger.debug(
f"Rolling back: re-adding LoRA '{lora_name}' to engine"
......@@ -550,11 +608,12 @@ class BaseWorkerHandler(ABC):
)
# Re-add to tracking dictionaries
self.lora_id_for_name[lora_name] = lora_id
if lora_path:
self.lora_name_to_path[lora_name] = lora_path
logger.debug(f"Successfully rolled back LoRA '{lora_name}'")
logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
except Exception as rollback_error:
logger.error(
logger.exception(
f"Failed to rollback LoRA {lora_name}: {rollback_error}"
)
......@@ -579,8 +638,16 @@ class BaseWorkerHandler(ABC):
"lora_name": lora_name,
"lora_id": lora_id,
}
finally:
# Remove lock entry once the LoRA is not loaded (or never was).
with self._lora_load_locks_guard:
if (
lora_name not in self.lora_id_for_name
and self._lora_load_locks.get(lora_name) is lock
):
self._lora_load_locks.pop(lora_name, None)
except Exception as e:
logger.error(f"Failed to unload LoRA adapter: {e}")
logger.exception(f"Failed to unload LoRA adapter: {e}")
yield {"status": "error", "message": str(e)}
async def list_loras(self, request=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment