Unverified commit 07721d1c, authored by Biswa Panda and committed by GitHub
Browse files

fix: handle concurrent load lora calls (#5184)

parent d9cc6f6b
......@@ -5,6 +5,7 @@ import asyncio
import logging
import os
import tempfile
import threading
import time
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
......@@ -254,6 +255,10 @@ class BaseWorkerHandler(ABC):
# LoRA tracking
self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA
self._lora_load_locks: dict[str, asyncio.Lock] = {}
# Guard lock-map access in case handlers are invoked from multiple threads.
self._lora_load_locks_guard = threading.Lock()
self.use_vllm_tokenizer = use_vllm_tokenizer
......@@ -309,6 +314,15 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None:
self.temp_dirs.append(temp_dir)
def _get_lora_lock(self, lora_name: str) -> asyncio.Lock:
    """Return the asyncio lock dedicated to ``lora_name``, creating it on first use.

    A lock object is only constructed when the name has no entry yet, so
    repeated calls for an already-known name allocate nothing.
    """
    # The guard serializes lookups and inserts on the lock map.
    with self._lora_load_locks_guard:
        existing = self._lora_load_locks.get(lora_name)
        if existing is not None:
            return existing
        created = asyncio.Lock()
        self._lora_load_locks[lora_name] = created
        return created
async def load_lora(self, request=None):
"""
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
......@@ -320,6 +334,9 @@ class BaseWorkerHandler(ABC):
"uri": str # e.g., "s3://bucket/path" or "file:///path"
}
}
This method is idempotent - concurrent calls for the same LoRA will be
serialized and only one load operation will happen.
"""
try:
if request is None:
......@@ -367,112 +384,145 @@ class BaseWorkerHandler(ABC):
}
return
logger.info(f"Downloading LoRA adapter: {lora_name} from {lora_uri}")
download_result = await lora_manager.download_lora(lora_uri)
if download_result["status"] != "success":
yield {
"status": "error",
"message": f"Failed to download LoRA: {download_result.get('message', 'Unknown error')}",
}
return
lora_path = download_result["local_path"]
logger.debug(f"LoRA downloaded to: {lora_path}")
# Generate deterministic ID from lora_name before using it
lora_id = lora_name_to_id(lora_name)
# Add the LoRA to the engine
await self.engine_client.add_lora(
LoRARequest(
lora_name=lora_name, lora_int_id=lora_id, lora_path=lora_path
)
)
# Serialize load/unload operations per lora_name.
lock = self._get_lora_lock(lora_name)
async with lock:
try:
# Check if already loaded (idempotency check after acquiring lock).
# Another concurrent request may have loaded this LoRA while we waited.
if lora_name in self.lora_id_for_name:
lora_id = self.lora_id_for_name[lora_name]
logger.info(
f"LoRA adapter already loaded (concurrent request completed): "
f"{lora_name} with ID {lora_id}"
)
yield {
"status": "success",
"message": f"LoRA adapter '{lora_name}' already loaded",
"lora_name": lora_name,
"lora_id": lora_id,
}
return
# Track the LoRA
self.lora_id_for_name[lora_name] = lora_id
self.lora_name_to_path[lora_name] = lora_path
logger.info(
f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}"
)
logger.info(
f"Downloading LoRA adapter: {lora_name} from {lora_uri}"
)
download_result = await lora_manager.download_lora(lora_uri)
# Publish LoRA as a ModelDeploymentCard with format:
# v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
# This allows the frontend to discover it and route correctly to the worker instance
if download_result["status"] != "success":
yield {
"status": "error",
"message": f"Failed to download LoRA: {download_result.get('message', 'Unknown error')}",
}
return
if self.generate_endpoint is not None and self.config is not None:
logger.debug(
f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
)
try:
logger.debug(f"Publishing LoRA '{lora_name}' ModelDeploymentCard")
lora_path = download_result["local_path"]
logger.debug(f"LoRA downloaded to: {lora_path}")
# Mark this as a LoRA in user_data
user_data = {
"lora_adapter": True,
"lora_id": lora_id,
}
# Generate deterministic ID from lora_name before using it
lora_id = lora_name_to_id(lora_name)
# Publish with format: v1/mdc/dynamo/backend/generate/{instance_id}/{lora_slug}
await register_llm(
model_input=ModelInput.Tokens,
model_type=ModelType.Chat | ModelType.Completions,
endpoint=self.generate_endpoint,
model_path=self.config.model,
kv_cache_block_size=self.config.engine_args.block_size,
user_data=user_data,
lora_name=lora_name,
base_model_path=self.config.model,
)
logger.info(
f"Successfully published LoRA '{lora_name}' ModelDeploymentCard"
# Add the LoRA to the engine
await self.engine_client.add_lora(
LoRARequest(
lora_name=lora_name,
lora_int_id=lora_id,
lora_path=lora_path,
)
)
except Exception as e:
import traceback
logger.error(
f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}"
# Track the LoRA
self.lora_id_for_name[lora_name] = lora_id
self.lora_name_to_path[lora_name] = lora_path
logger.info(
f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}"
)
logger.debug(f"Traceback: {traceback.format_exc()}")
# Rollback: remove the LoRA from the engine to maintain consistency
try:
# Publish LoRA as a ModelDeploymentCard with format:
# v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
# This allows the frontend to discover it and route correctly to the worker instance
if self.generate_endpoint is not None and self.config is not None:
logger.debug(
f"Rolling back: removing LoRA '{lora_name}' from engine"
f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
)
await self.engine_client.remove_lora(lora_id)
# Remove from tracking dictionaries
if lora_name in self.lora_id_for_name:
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
logger.debug(f"Successfully rolled back LoRA '{lora_name}'")
except Exception as rollback_error:
logger.error(
f"Failed to rollback LoRA {lora_name}: {rollback_error}"
try:
logger.debug(
f"Publishing LoRA '{lora_name}' ModelDeploymentCard"
)
# Mark this as a LoRA in user_data
user_data = {
"lora_adapter": True,
"lora_id": lora_id,
}
# Publish with format: v1/mdc/dynamo/backend/generate/{instance_id}/{lora_slug}
await register_llm(
model_input=ModelInput.Tokens,
model_type=ModelType.Chat | ModelType.Completions,
endpoint=self.generate_endpoint,
model_path=self.config.model,
kv_cache_block_size=self.config.engine_args.block_size,
user_data=user_data,
lora_name=lora_name,
base_model_path=self.config.model,
)
logger.info(
f"Successfully published LoRA '{lora_name}' ModelDeploymentCard"
)
except Exception as e:
logger.exception(
f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}"
)
# Rollback: remove the LoRA from the engine to maintain consistency
try:
logger.debug(
f"Rolling back: removing LoRA '{lora_name}' from engine"
)
await self.engine_client.remove_lora(lora_id)
# Remove from tracking dictionaries
if lora_name in self.lora_id_for_name:
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
except Exception as rollback_error:
logger.exception(
f"Failed to rollback LoRA {lora_name}: {rollback_error}"
)
# Return error status since registration failed
yield {
"status": "error",
"message": f"Failed to register LoRA '{lora_name}' in discovery registry: {str(e)}",
"lora_name": lora_name,
}
return
else:
logger.debug(
f"Cannot publish LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}, config={self.config}"
)
# Return error status since registration failed
yield {
"status": "error",
"message": f"Failed to register LoRA '{lora_name}' in discovery registry: {str(e)}",
"status": "success",
"message": f"LoRA adapter '{lora_name}' loaded successfully",
"lora_name": lora_name,
"lora_id": lora_id,
}
return
else:
logger.debug(
f"Cannot publish LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}, config={self.config}"
)
yield {
"status": "success",
"message": f"LoRA adapter '{lora_name}' loaded successfully",
"lora_name": lora_name,
"lora_id": lora_id,
}
finally:
# Avoid lock-map growth on failed loads: if this attempt did not leave the LoRA
# loaded, remove the lock entry (best-effort).
with self._lora_load_locks_guard:
if (
lora_name not in self.lora_id_for_name
and self._lora_load_locks.get(lora_name) is lock
):
self._lora_load_locks.pop(lora_name, None)
except Exception as e:
logger.error(f"Failed to load LoRA adapter: {e}")
logger.exception(f"Failed to load LoRA adapter: {e}")
yield {"status": "error", "message": str(e)}
async def unload_lora(self, request=None):
......@@ -498,89 +548,106 @@ class BaseWorkerHandler(ABC):
}
return
# Check if the LoRA exists
if lora_name not in self.lora_id_for_name:
yield {
"status": "error",
"message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.lora_id_for_name.keys())}",
}
return
logger.debug(f"Unloading LoRA adapter: {lora_name}")
lora_id = self.lora_id_for_name[lora_name]
lora_path = self.lora_name_to_path.get(lora_name)
await self.engine_client.remove_lora(lora_id)
# Serialize load/unload operations per lora_name.
lock = self._get_lora_lock(lora_name)
async with lock:
try:
# Check if the LoRA exists *after* waiting for any in-progress load.
if lora_name not in self.lora_id_for_name:
yield {
"status": "error",
"message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.lora_id_for_name.keys())}",
}
return
# Remove from tracking dictionaries
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
logger.debug(f"Unloading LoRA adapter: {lora_name}")
lora_id = self.lora_id_for_name[lora_name]
lora_path = self.lora_name_to_path.get(lora_name)
# Unregister the LoRA model from the model registry (outside lock)
if self.generate_endpoint is not None:
logger.debug(f"Unregistering LoRA '{lora_name}' ModelDeploymentCard")
try:
await unregister_llm(
endpoint=self.generate_endpoint,
lora_name=lora_name,
)
logger.info(
f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard"
)
except Exception as e:
import traceback
await self.engine_client.remove_lora(lora_id)
logger.error(
f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}"
)
logger.debug(f"Traceback: {traceback.format_exc()}")
# Remove from tracking dictionaries
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
# Rollback: re-add the LoRA to the engine to maintain consistency
try:
# Unregister the LoRA model from the model registry
if self.generate_endpoint is not None:
logger.debug(
f"Rolling back: re-adding LoRA '{lora_name}' to engine"
f"Unregistering LoRA '{lora_name}' ModelDeploymentCard"
)
await self.engine_client.add_lora(
LoRARequest(
try:
await unregister_llm(
endpoint=self.generate_endpoint,
lora_name=lora_name,
lora_int_id=lora_id,
lora_path=lora_path,
)
)
# Re-add to tracking dictionaries
self.lora_id_for_name[lora_name] = lora_id
if lora_path:
self.lora_name_to_path[lora_name] = lora_path
logger.debug(f"Successfully rolled back LoRA '{lora_name}'")
except Exception as rollback_error:
logger.error(
f"Failed to rollback LoRA {lora_name}: {rollback_error}"
logger.info(
f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard"
)
except Exception as e:
logger.exception(
f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}"
)
# Rollback: re-add the LoRA to the engine to maintain consistency
if lora_path is None:
logger.error(
f"Cannot rollback LoRA '{lora_name}': lora_path is None (data inconsistency)"
)
else:
try:
logger.debug(
f"Rolling back: re-adding LoRA '{lora_name}' to engine"
)
await self.engine_client.add_lora(
LoRARequest(
lora_name=lora_name,
lora_int_id=lora_id,
lora_path=lora_path,
)
)
# Re-add to tracking dictionaries
self.lora_id_for_name[lora_name] = lora_id
self.lora_name_to_path[lora_name] = lora_path
logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
except Exception as rollback_error:
logger.exception(
f"Failed to rollback LoRA {lora_name}: {rollback_error}"
)
# Return error status since unregistration failed
yield {
"status": "error",
"message": f"Failed to unregister LoRA '{lora_name}' from discovery registry: {str(e)}",
"lora_name": lora_name,
}
return
else:
logger.debug(
f"Cannot unregister LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}"
)
# Return error status since unregistration failed
logger.info(
f"Successfully unloaded LoRA adapter: {lora_name} with ID {lora_id}"
)
yield {
"status": "error",
"message": f"Failed to unregister LoRA '{lora_name}' from discovery registry: {str(e)}",
"status": "success",
"message": f"LoRA adapter '{lora_name}' unloaded successfully",
"lora_name": lora_name,
"lora_id": lora_id,
}
return
else:
logger.debug(
f"Cannot unregister LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}"
)
logger.info(
f"Successfully unloaded LoRA adapter: {lora_name} with ID {lora_id}"
)
yield {
"status": "success",
"message": f"LoRA adapter '{lora_name}' unloaded successfully",
"lora_name": lora_name,
"lora_id": lora_id,
}
finally:
# Remove lock entry once the LoRA is not loaded (or never was).
with self._lora_load_locks_guard:
if (
lora_name not in self.lora_id_for_name
and self._lora_load_locks.get(lora_name) is lock
):
self._lora_load_locks.pop(lora_name, None)
except Exception as e:
logger.error(f"Failed to unload LoRA adapter: {e}")
logger.exception(f"Failed to unload LoRA adapter: {e}")
yield {"status": "error", "message": str(e)}
async def list_loras(self, request=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment