Unverified commit 07721d1c, authored by Biswa Panda and committed by GitHub
Browse files

fix: handle concurrent load lora calls (#5184)

parent d9cc6f6b
...@@ -5,6 +5,7 @@ import asyncio ...@@ -5,6 +5,7 @@ import asyncio
import logging import logging
import os import os
import tempfile import tempfile
import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
...@@ -254,6 +255,10 @@ class BaseWorkerHandler(ABC): ...@@ -254,6 +255,10 @@ class BaseWorkerHandler(ABC):
# LoRA tracking # LoRA tracking
self.lora_id_for_name: dict[str, int] = {} self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {} self.lora_name_to_path: dict[str, str] = {}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA
self._lora_load_locks: dict[str, asyncio.Lock] = {}
# Guard lock-map access in case handlers are invoked from multiple threads.
self._lora_load_locks_guard = threading.Lock()
self.use_vllm_tokenizer = use_vllm_tokenizer self.use_vllm_tokenizer = use_vllm_tokenizer
...@@ -309,6 +314,15 @@ class BaseWorkerHandler(ABC): ...@@ -309,6 +314,15 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None: if temp_dir is not None:
self.temp_dirs.append(temp_dir) self.temp_dirs.append(temp_dir)
def _get_lora_lock(self, lora_name: str) -> asyncio.Lock:
"""Get/create the per-LoRA lock without eagerly allocating a new lock each call."""
with self._lora_load_locks_guard:
lock = self._lora_load_locks.get(lora_name)
if lock is None:
lock = asyncio.Lock()
self._lora_load_locks[lora_name] = lock
return lock
async def load_lora(self, request=None): async def load_lora(self, request=None):
""" """
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine. Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
...@@ -320,6 +334,9 @@ class BaseWorkerHandler(ABC): ...@@ -320,6 +334,9 @@ class BaseWorkerHandler(ABC):
"uri": str # e.g., "s3://bucket/path" or "file:///path" "uri": str # e.g., "s3://bucket/path" or "file:///path"
} }
} }
This method is idempotent - concurrent calls for the same LoRA will be
serialized and only one load operation will happen.
""" """
try: try:
if request is None: if request is None:
...@@ -367,112 +384,145 @@ class BaseWorkerHandler(ABC): ...@@ -367,112 +384,145 @@ class BaseWorkerHandler(ABC):
} }
return return
logger.info(f"Downloading LoRA adapter: {lora_name} from {lora_uri}") # Serialize load/unload operations per lora_name.
download_result = await lora_manager.download_lora(lora_uri) lock = self._get_lora_lock(lora_name)
async with lock:
if download_result["status"] != "success": try:
yield { # Check if already loaded (idempotency check after acquiring lock).
"status": "error", # Another concurrent request may have loaded this LoRA while we waited.
"message": f"Failed to download LoRA: {download_result.get('message', 'Unknown error')}", if lora_name in self.lora_id_for_name:
} lora_id = self.lora_id_for_name[lora_name]
return logger.info(
f"LoRA adapter already loaded (concurrent request completed): "
lora_path = download_result["local_path"] f"{lora_name} with ID {lora_id}"
logger.debug(f"LoRA downloaded to: {lora_path}") )
yield {
# Generate deterministic ID from lora_name before using it "status": "success",
lora_id = lora_name_to_id(lora_name) "message": f"LoRA adapter '{lora_name}' already loaded",
"lora_name": lora_name,
# Add the LoRA to the engine "lora_id": lora_id,
await self.engine_client.add_lora( }
LoRARequest( return
lora_name=lora_name, lora_int_id=lora_id, lora_path=lora_path
)
)
# Track the LoRA logger.info(
self.lora_id_for_name[lora_name] = lora_id f"Downloading LoRA adapter: {lora_name} from {lora_uri}"
self.lora_name_to_path[lora_name] = lora_path )
logger.info( download_result = await lora_manager.download_lora(lora_uri)
f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}"
)
# Publish LoRA as a ModelDeploymentCard with format: if download_result["status"] != "success":
# v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug} yield {
# This allows the frontend to discover it and route correctly to the worker instance "status": "error",
"message": f"Failed to download LoRA: {download_result.get('message', 'Unknown error')}",
}
return
if self.generate_endpoint is not None and self.config is not None: lora_path = download_result["local_path"]
logger.debug( logger.debug(f"LoRA downloaded to: {lora_path}")
f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
)
try:
logger.debug(f"Publishing LoRA '{lora_name}' ModelDeploymentCard")
# Mark this as a LoRA in user_data # Generate deterministic ID from lora_name before using it
user_data = { lora_id = lora_name_to_id(lora_name)
"lora_adapter": True,
"lora_id": lora_id,
}
# Publish with format: v1/mdc/dynamo/backend/generate/{instance_id}/{lora_slug} # Add the LoRA to the engine
await register_llm( await self.engine_client.add_lora(
model_input=ModelInput.Tokens, LoRARequest(
model_type=ModelType.Chat | ModelType.Completions, lora_name=lora_name,
endpoint=self.generate_endpoint, lora_int_id=lora_id,
model_path=self.config.model, lora_path=lora_path,
kv_cache_block_size=self.config.engine_args.block_size, )
user_data=user_data,
lora_name=lora_name,
base_model_path=self.config.model,
)
logger.info(
f"Successfully published LoRA '{lora_name}' ModelDeploymentCard"
) )
except Exception as e:
import traceback
logger.error( # Track the LoRA
f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}" self.lora_id_for_name[lora_name] = lora_id
self.lora_name_to_path[lora_name] = lora_path
logger.info(
f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}"
) )
logger.debug(f"Traceback: {traceback.format_exc()}")
# Rollback: remove the LoRA from the engine to maintain consistency # Publish LoRA as a ModelDeploymentCard with format:
try: # v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
# This allows the frontend to discover it and route correctly to the worker instance
if self.generate_endpoint is not None and self.config is not None:
logger.debug( logger.debug(
f"Rolling back: removing LoRA '{lora_name}' from engine" f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
) )
await self.engine_client.remove_lora(lora_id) try:
# Remove from tracking dictionaries logger.debug(
if lora_name in self.lora_id_for_name: f"Publishing LoRA '{lora_name}' ModelDeploymentCard"
del self.lora_id_for_name[lora_name] )
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name] # Mark this as a LoRA in user_data
logger.debug(f"Successfully rolled back LoRA '{lora_name}'") user_data = {
except Exception as rollback_error: "lora_adapter": True,
logger.error( "lora_id": lora_id,
f"Failed to rollback LoRA {lora_name}: {rollback_error}" }
# Publish with format: v1/mdc/dynamo/backend/generate/{instance_id}/{lora_slug}
await register_llm(
model_input=ModelInput.Tokens,
model_type=ModelType.Chat | ModelType.Completions,
endpoint=self.generate_endpoint,
model_path=self.config.model,
kv_cache_block_size=self.config.engine_args.block_size,
user_data=user_data,
lora_name=lora_name,
base_model_path=self.config.model,
)
logger.info(
f"Successfully published LoRA '{lora_name}' ModelDeploymentCard"
)
except Exception as e:
logger.exception(
f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}"
)
# Rollback: remove the LoRA from the engine to maintain consistency
try:
logger.debug(
f"Rolling back: removing LoRA '{lora_name}' from engine"
)
await self.engine_client.remove_lora(lora_id)
# Remove from tracking dictionaries
if lora_name in self.lora_id_for_name:
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
except Exception as rollback_error:
logger.exception(
f"Failed to rollback LoRA {lora_name}: {rollback_error}"
)
# Return error status since registration failed
yield {
"status": "error",
"message": f"Failed to register LoRA '{lora_name}' in discovery registry: {str(e)}",
"lora_name": lora_name,
}
return
else:
logger.debug(
f"Cannot publish LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}, config={self.config}"
) )
# Return error status since registration failed
yield { yield {
"status": "error", "status": "success",
"message": f"Failed to register LoRA '{lora_name}' in discovery registry: {str(e)}", "message": f"LoRA adapter '{lora_name}' loaded successfully",
"lora_name": lora_name, "lora_name": lora_name,
"lora_id": lora_id,
} }
return finally:
else: # Avoid lock-map growth on failed loads: if this attempt did not leave the LoRA
logger.debug( # loaded, remove the lock entry (best-effort).
f"Cannot publish LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}, config={self.config}" with self._lora_load_locks_guard:
) if (
lora_name not in self.lora_id_for_name
yield { and self._lora_load_locks.get(lora_name) is lock
"status": "success", ):
"message": f"LoRA adapter '{lora_name}' loaded successfully", self._lora_load_locks.pop(lora_name, None)
"lora_name": lora_name,
"lora_id": lora_id,
}
except Exception as e: except Exception as e:
logger.error(f"Failed to load LoRA adapter: {e}") logger.exception(f"Failed to load LoRA adapter: {e}")
yield {"status": "error", "message": str(e)} yield {"status": "error", "message": str(e)}
async def unload_lora(self, request=None): async def unload_lora(self, request=None):
...@@ -498,89 +548,106 @@ class BaseWorkerHandler(ABC): ...@@ -498,89 +548,106 @@ class BaseWorkerHandler(ABC):
} }
return return
# Check if the LoRA exists # Serialize load/unload operations per lora_name.
if lora_name not in self.lora_id_for_name: lock = self._get_lora_lock(lora_name)
yield { async with lock:
"status": "error", try:
"message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.lora_id_for_name.keys())}", # Check if the LoRA exists *after* waiting for any in-progress load.
} if lora_name not in self.lora_id_for_name:
return yield {
"status": "error",
logger.debug(f"Unloading LoRA adapter: {lora_name}") "message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.lora_id_for_name.keys())}",
lora_id = self.lora_id_for_name[lora_name] }
lora_path = self.lora_name_to_path.get(lora_name) return
await self.engine_client.remove_lora(lora_id)
# Remove from tracking dictionaries logger.debug(f"Unloading LoRA adapter: {lora_name}")
del self.lora_id_for_name[lora_name] lora_id = self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path: lora_path = self.lora_name_to_path.get(lora_name)
del self.lora_name_to_path[lora_name]
# Unregister the LoRA model from the model registry (outside lock) await self.engine_client.remove_lora(lora_id)
if self.generate_endpoint is not None:
logger.debug(f"Unregistering LoRA '{lora_name}' ModelDeploymentCard")
try:
await unregister_llm(
endpoint=self.generate_endpoint,
lora_name=lora_name,
)
logger.info(
f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard"
)
except Exception as e:
import traceback
logger.error( # Remove from tracking dictionaries
f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}" del self.lora_id_for_name[lora_name]
) if lora_name in self.lora_name_to_path:
logger.debug(f"Traceback: {traceback.format_exc()}") del self.lora_name_to_path[lora_name]
# Rollback: re-add the LoRA to the engine to maintain consistency # Unregister the LoRA model from the model registry
try: if self.generate_endpoint is not None:
logger.debug( logger.debug(
f"Rolling back: re-adding LoRA '{lora_name}' to engine" f"Unregistering LoRA '{lora_name}' ModelDeploymentCard"
) )
await self.engine_client.add_lora( try:
LoRARequest( await unregister_llm(
endpoint=self.generate_endpoint,
lora_name=lora_name, lora_name=lora_name,
lora_int_id=lora_id,
lora_path=lora_path,
) )
) logger.info(
# Re-add to tracking dictionaries f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard"
self.lora_id_for_name[lora_name] = lora_id )
if lora_path: except Exception as e:
self.lora_name_to_path[lora_name] = lora_path logger.exception(
logger.debug(f"Successfully rolled back LoRA '{lora_name}'") f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}"
except Exception as rollback_error: )
logger.error(
f"Failed to rollback LoRA {lora_name}: {rollback_error}" # Rollback: re-add the LoRA to the engine to maintain consistency
if lora_path is None:
logger.error(
f"Cannot rollback LoRA '{lora_name}': lora_path is None (data inconsistency)"
)
else:
try:
logger.debug(
f"Rolling back: re-adding LoRA '{lora_name}' to engine"
)
await self.engine_client.add_lora(
LoRARequest(
lora_name=lora_name,
lora_int_id=lora_id,
lora_path=lora_path,
)
)
# Re-add to tracking dictionaries
self.lora_id_for_name[lora_name] = lora_id
self.lora_name_to_path[lora_name] = lora_path
logger.debug(
f"Successfully rolled back LoRA '{lora_name}'"
)
except Exception as rollback_error:
logger.exception(
f"Failed to rollback LoRA {lora_name}: {rollback_error}"
)
# Return error status since unregistration failed
yield {
"status": "error",
"message": f"Failed to unregister LoRA '{lora_name}' from discovery registry: {str(e)}",
"lora_name": lora_name,
}
return
else:
logger.debug(
f"Cannot unregister LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}"
) )
# Return error status since unregistration failed logger.info(
f"Successfully unloaded LoRA adapter: {lora_name} with ID {lora_id}"
)
yield { yield {
"status": "error", "status": "success",
"message": f"Failed to unregister LoRA '{lora_name}' from discovery registry: {str(e)}", "message": f"LoRA adapter '{lora_name}' unloaded successfully",
"lora_name": lora_name, "lora_name": lora_name,
"lora_id": lora_id,
} }
return finally:
else: # Remove lock entry once the LoRA is not loaded (or never was).
logger.debug( with self._lora_load_locks_guard:
f"Cannot unregister LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}" if (
) lora_name not in self.lora_id_for_name
and self._lora_load_locks.get(lora_name) is lock
logger.info( ):
f"Successfully unloaded LoRA adapter: {lora_name} with ID {lora_id}" self._lora_load_locks.pop(lora_name, None)
)
yield {
"status": "success",
"message": f"LoRA adapter '{lora_name}' unloaded successfully",
"lora_name": lora_name,
"lora_id": lora_id,
}
except Exception as e: except Exception as e:
logger.error(f"Failed to unload LoRA adapter: {e}") logger.exception(f"Failed to unload LoRA adapter: {e}")
yield {"status": "error", "message": str(e)} yield {"status": "error", "message": str(e)}
async def list_loras(self, request=None): async def list_loras(self, request=None):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment