Unverified Commit 71f94eda authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: lora - centralize lora cache key, restructure folders, s3 resiliency (#4644)

parent 4c1bc4ee
...@@ -5,6 +5,6 @@ ...@@ -5,6 +5,6 @@
LoRA management infrastructure LoRA management infrastructure
""" """
from .lora import LoRAManager, LoRASourceProtocol from .manager import LoRAManager, LoRASourceProtocol
__all__ = ["LoRAManager", "LoRASourceProtocol"] __all__ = ["LoRAManager", "LoRASourceProtocol"]
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Minimal LoRA management layer with extensible sources.
"""
from .manager import LoRAManager, LoRASourceProtocol
__all__ = ["LoRAManager", "LoRASourceProtocol"]
...@@ -106,13 +106,9 @@ class LoRAManager: ...@@ -106,13 +106,9 @@ class LoRAManager:
def is_cached(self, lora_uri: str) -> bool: def is_cached(self, lora_uri: str) -> bool:
"""Check if LoRA is already cached locally.""" """Check if LoRA is already cached locally."""
cache_key = self._uri_to_cache_key(lora_uri) cache_key = LoRADownloader.uri_to_cache_key(lora_uri)
return self._downloader.is_cached(cache_key) return self._downloader.is_cached(cache_key)
def _uri_to_cache_key(self, uri: str) -> str: def _uri_to_cache_key(self, uri: str) -> str:
return ( """Convert URI to cache key. Delegates to Rust for consistency."""
uri.replace("://", "__") return LoRADownloader.uri_to_cache_key(uri)
.replace(".", "_")
.replace("/", "_")
.replace("\\", "_")
)
...@@ -10,11 +10,19 @@ from contextlib import asynccontextmanager ...@@ -10,11 +10,19 @@ from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Final from typing import Any, AsyncGenerator, Dict, Final
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.llm import ZmqKvEventPublisher from dynamo.llm import (
ModelInput,
ModelType,
ZmqKvEventPublisher,
lora_name_to_id,
register_llm,
unregister_llm,
)
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor from .engine_monitor import VllmEngineMonitor
...@@ -29,6 +37,32 @@ DECODED_VARIANT_KEY: Final = "Decoded" ...@@ -29,6 +37,32 @@ DECODED_VARIANT_KEY: Final = "Decoded"
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# LoRAManager singleton - initialized lazily when DYN_LORA_ENABLED is set.
# None = not yet initialized (or currently disabled via the environment),
# False = initialization was attempted and failed (not retried),
# LoRAManager = initialized
_lora_manager = None


def get_lora_manager():
    """Return the shared LoRAManager, creating it on first use.

    The manager is only created when the ``DYN_LORA_ENABLED`` environment
    variable is one of ``true``/``1``/``yes`` (case-insensitive). A failed
    construction is remembered so that subsequent calls do not repeat the
    import/initialization attempt and its warning on every request.

    Returns:
        The LoRAManager instance, or ``None`` when LoRA support is disabled
        or the manager could not be initialized.
    """
    global _lora_manager

    if _lora_manager is False:
        # A previous initialization attempt failed; callers only ever see None.
        return None
    if _lora_manager is not None:
        return _lora_manager

    if os.environ.get("DYN_LORA_ENABLED", "").lower() not in ("true", "1", "yes"):
        return None

    try:
        # Imported lazily so the module is only required when LoRA is enabled.
        from dynamo.common.lora import LoRAManager

        _lora_manager = LoRAManager()
        logger.info("LoRAManager initialized successfully")
        return _lora_manager
    except Exception as e:
        logger.warning(
            f"Failed to initialize LoRAManager: {e}. URI-based LoRA loading will be disabled."
        )
        # Record the failure so the documented tri-state actually holds.
        _lora_manager = False
        return None
def build_sampling_params( def build_sampling_params(
request: Dict[str, Any], request: Dict[str, Any],
...@@ -86,17 +120,24 @@ class BaseWorkerHandler(ABC): ...@@ -86,17 +120,24 @@ class BaseWorkerHandler(ABC):
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
): ):
self.runtime = runtime self.runtime = runtime
self.component = component self.component = component
self.engine_client = engine self.engine_client = engine
self.default_sampling_params = default_sampling_params self.default_sampling_params = default_sampling_params
self.kv_publishers: list[ZmqKvEventPublisher] | None = None self.kv_publishers: list[ZmqKvEventPublisher] | None = None
self.generate_endpoint = generate_endpoint
self.config = config
self.engine_monitor = VllmEngineMonitor(runtime, engine) self.engine_monitor = VllmEngineMonitor(runtime, engine)
self.image_loader = ImageLoader() self.image_loader = ImageLoader()
self.temp_dirs: list[tempfile.TemporaryDirectory] = [] self.temp_dirs: list[tempfile.TemporaryDirectory] = []
self.model_max_len = model_max_len self.model_max_len = model_max_len
self.enable_multimodal = enable_multimodal self.enable_multimodal = enable_multimodal
# LoRA tracking
self.lora_id_for_name: dict[str, int] = {}
self.lora_name_to_path: dict[str, str] = {}
@abstractmethod @abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]: async def generate(self, request, context) -> AsyncGenerator[dict, None]:
...@@ -144,6 +185,296 @@ class BaseWorkerHandler(ABC): ...@@ -144,6 +185,296 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None: if temp_dir is not None:
self.temp_dirs.append(temp_dir) self.temp_dirs.append(temp_dir)
async def load_lora(self, request=None):
    """
    Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.

    Request format:
    {
        "lora_name": str,
        "source": {
            "uri": str  # e.g., "s3://bucket/path" or "file:///path"
        }
    }

    The adapter is downloaded through the LoRAManager, added to the engine,
    recorded in ``lora_id_for_name`` / ``lora_name_to_path`` and, when both
    ``generate_endpoint`` and ``config`` are set, published via
    ``register_llm``. If publishing fails the adapter is removed from the
    engine again and an error is reported.

    Yields:
        Exactly one dict with a "status" of "success" or "error" and a
        human-readable "message"; success responses also carry "lora_name"
        and "lora_id".
    """
    from collections.abc import Mapping

    try:
        if request is None:
            yield {
                "status": "error",
                "message": "Request is required with 'lora_name' and 'source.uri'",
            }
            return

        # Reject non-mapping payloads up front instead of surfacing an
        # AttributeError text from `.get` as the error message.
        if not isinstance(request, Mapping):
            yield {
                "status": "error",
                "message": (
                    "Request must be a JSON object with 'lora_name' and "
                    f"'source.uri', got {type(request).__name__}"
                ),
            }
            return

        lora_name = request.get("lora_name")
        if not lora_name:
            yield {
                "status": "error",
                "message": "'lora_name' is required in request",
            }
            return

        # Debug: Log the incoming request
        logger.debug(f"load_lora request keys: {list(request.keys())}")
        logger.debug(f"load_lora request: {request}")

        # Check for URI-based API format (source.uri)
        source = request.get("source")
        if not source or not isinstance(source, dict):
            yield {
                "status": "error",
                "message": "'source' object is required in request",
            }
            return

        lora_uri = source.get("uri")
        if not lora_uri:
            yield {
                "status": "error",
                "message": "'source.uri' is required in request",
            }
            return

        # Use LoRAManager to download from URI
        lora_manager = get_lora_manager()
        if lora_manager is None:
            yield {
                "status": "error",
                "message": "LoRAManager not initialized. Set DYN_LORA_ENABLED=true to enable URI-based LoRA loading.",
            }
            return

        logger.info(f"Downloading LoRA adapter: {lora_name} from {lora_uri}")
        download_result = await lora_manager.download_lora(lora_uri)

        if download_result["status"] != "success":
            yield {
                "status": "error",
                "message": f"Failed to download LoRA: {download_result.get('message', 'Unknown error')}",
            }
            return

        lora_path = download_result["local_path"]
        logger.debug(f"LoRA downloaded to: {lora_path}")

        # Generate deterministic ID from lora_name before using it
        lora_id = lora_name_to_id(lora_name)

        # Add the LoRA to the engine
        await self.engine_client.add_lora(
            LoRARequest(
                lora_name=lora_name, lora_int_id=lora_id, lora_path=lora_path
            )
        )

        # Track the LoRA
        self.lora_id_for_name[lora_name] = lora_id
        self.lora_name_to_path[lora_name] = lora_path
        logger.info(
            f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}"
        )

        # Publish LoRA as a ModelDeploymentCard with format:
        # v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
        # This allows the frontend to discover it and route correctly to the worker instance
        if self.generate_endpoint is not None and self.config is not None:
            logger.debug(
                f"Publishing LoRA '{lora_name}' ModelDeploymentCard to {self.generate_endpoint}"
            )
            try:
                # Mark this as a LoRA in user_data
                user_data = {
                    "lora_adapter": True,
                    "lora_id": lora_id,
                }

                await register_llm(
                    model_input=ModelInput.Tokens,
                    model_type=ModelType.Chat | ModelType.Completions,
                    endpoint=self.generate_endpoint,
                    model_path=self.config.model,
                    kv_cache_block_size=self.config.engine_args.block_size,
                    user_data=user_data,
                    lora_name=lora_name,
                    base_model_path=self.config.model,
                )
                logger.info(
                    f"Successfully published LoRA '{lora_name}' ModelDeploymentCard"
                )
            except Exception as e:
                import traceback

                logger.error(
                    f"Failed to publish LoRA {lora_name} ModelDeploymentCard: {e}"
                )
                logger.debug(f"Traceback: {traceback.format_exc()}")

                # Rollback: remove the LoRA from the engine to maintain consistency
                try:
                    logger.debug(
                        f"Rolling back: removing LoRA '{lora_name}' from engine"
                    )
                    await self.engine_client.remove_lora(lora_id)
                    # Drop tracking only once the engine removal succeeded, so
                    # an adapter still held by the engine stays unloadable.
                    self.lora_id_for_name.pop(lora_name, None)
                    self.lora_name_to_path.pop(lora_name, None)
                    logger.debug(f"Successfully rolled back LoRA '{lora_name}'")
                except Exception as rollback_error:
                    logger.error(
                        f"Failed to rollback LoRA {lora_name}: {rollback_error}"
                    )

                # Return error status since registration failed
                yield {
                    "status": "error",
                    "message": f"Failed to register LoRA '{lora_name}' in discovery registry: {str(e)}",
                    "lora_name": lora_name,
                }
                return
        else:
            logger.debug(
                f"Cannot publish LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}, config={self.config}"
            )

        yield {
            "status": "success",
            "message": f"LoRA adapter '{lora_name}' loaded successfully",
            "lora_name": lora_name,
            "lora_id": lora_id,
        }
    except Exception as e:
        # Catch-all boundary for the endpoint: keep the traceback in the log
        # and report the failure to the caller instead of raising.
        logger.exception(f"Failed to load LoRA adapter: {e}")
        yield {"status": "error", "message": str(e)}
async def unload_lora(self, request=None):
    """
    Unload a LoRA adapter dynamically from the vLLM's AsyncLLM engine.

    Expected request format:
    {
        "lora_name": str,
    }

    The adapter is removed from the engine and from the tracking
    dictionaries, then unregistered via ``unregister_llm`` when
    ``generate_endpoint`` is set. If unregistering fails the adapter is
    re-added to the engine and an error is reported.

    Yields:
        Exactly one dict with a "status" of "success" or "error" and a
        human-readable "message"; success responses also carry "lora_name"
        and "lora_id".
    """
    from collections.abc import Mapping

    try:
        if request is None:
            yield {
                "status": "error",
                "message": "Request is required with 'lora_name' field",
            }
            return

        # Reject non-mapping payloads up front instead of surfacing an
        # AttributeError text from `.get` as the error message.
        if not isinstance(request, Mapping):
            yield {
                "status": "error",
                "message": (
                    "Request must be a JSON object with a 'lora_name' field, "
                    f"got {type(request).__name__}"
                ),
            }
            return

        lora_name = request.get("lora_name")
        if not lora_name:
            yield {
                "status": "error",
                "message": "'lora_name' is required in request",
            }
            return

        # Check if the LoRA exists
        if lora_name not in self.lora_id_for_name:
            yield {
                "status": "error",
                "message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.lora_id_for_name.keys())}",
            }
            return

        logger.debug(f"Unloading LoRA adapter: {lora_name}")
        lora_id = self.lora_id_for_name[lora_name]
        # Remember the path so the adapter can be re-added on rollback.
        lora_path = self.lora_name_to_path.get(lora_name)

        await self.engine_client.remove_lora(lora_id)

        # Remove from tracking dictionaries
        self.lora_id_for_name.pop(lora_name, None)
        self.lora_name_to_path.pop(lora_name, None)

        # Unregister the LoRA model from the model registry.
        # NOTE(review): the adapter leaves the engine before it is
        # unregistered; requests naming it in that window would not match a
        # tracked LoRA — confirm this ordering is acceptable.
        if self.generate_endpoint is not None:
            logger.debug(f"Unregistering LoRA '{lora_name}' ModelDeploymentCard")
            try:
                await unregister_llm(
                    endpoint=self.generate_endpoint,
                    lora_name=lora_name,
                )
                logger.info(
                    f"Successfully unregistered LoRA '{lora_name}' ModelDeploymentCard"
                )
            except Exception as e:
                import traceback

                logger.error(
                    f"Failed to unregister LoRA {lora_name} ModelDeploymentCard: {e}"
                )
                logger.debug(f"Traceback: {traceback.format_exc()}")

                # Rollback: re-add the LoRA to the engine to maintain consistency
                try:
                    logger.debug(
                        f"Rolling back: re-adding LoRA '{lora_name}' to engine"
                    )
                    await self.engine_client.add_lora(
                        LoRARequest(
                            lora_name=lora_name,
                            lora_int_id=lora_id,
                            lora_path=lora_path,
                        )
                    )
                    # Re-add to tracking dictionaries
                    self.lora_id_for_name[lora_name] = lora_id
                    if lora_path:
                        self.lora_name_to_path[lora_name] = lora_path
                    logger.debug(f"Successfully rolled back LoRA '{lora_name}'")
                except Exception as rollback_error:
                    logger.error(
                        f"Failed to rollback LoRA {lora_name}: {rollback_error}"
                    )

                # Return error status since unregistration failed
                yield {
                    "status": "error",
                    "message": f"Failed to unregister LoRA '{lora_name}' from discovery registry: {str(e)}",
                    "lora_name": lora_name,
                }
                return
        else:
            logger.debug(
                f"Cannot unregister LoRA '{lora_name}': generate_endpoint={self.generate_endpoint}"
            )

        logger.info(
            f"Successfully unloaded LoRA adapter: {lora_name} with ID {lora_id}"
        )
        yield {
            "status": "success",
            "message": f"LoRA adapter '{lora_name}' unloaded successfully",
            "lora_name": lora_name,
            "lora_id": lora_id,
        }
    except Exception as e:
        # Catch-all boundary for the endpoint: keep the traceback in the log
        # and report the failure to the caller instead of raising.
        logger.exception(f"Failed to unload LoRA adapter: {e}")
        yield {"status": "error", "message": str(e)}
async def list_loras(self, request=None):
    """
    Report the LoRA adapters currently tracked by this handler.

    Yields a single dict: on success it carries "loras" (a copy of the
    lora_name -> lora_id mapping) and "count"; on failure it carries an
    error "message". The ``request`` argument is accepted but not read.
    """
    try:
        # Copy so the caller never holds a reference to the live mapping.
        snapshot = self.lora_id_for_name.copy()
        response = {
            "status": "success",
            "loras": snapshot,
            "count": len(snapshot),
        }
        yield response
    except Exception as e:
        logger.error(f"Failed to list LoRA adapters: {e}")
        yield {"status": "error", "message": str(e)}
def cleanup(self): def cleanup(self):
"""Clean up resources including temporary directories.""" """Clean up resources including temporary directories."""
for temp_dir in self.temp_dirs: for temp_dir in self.temp_dirs:
...@@ -226,13 +557,30 @@ class BaseWorkerHandler(ABC): ...@@ -226,13 +557,30 @@ class BaseWorkerHandler(ABC):
} }
async def generate_tokens( async def generate_tokens(
self, prompt, sampling_params, request_id, data_parallel_rank=None self,
prompt,
sampling_params,
request_id,
data_parallel_rank=None,
lora_request=None,
): ):
try: try:
# Log LoRA usage for this generation (debug level to avoid log spam)
if lora_request:
logger.debug(
f"Starting token generation for request {request_id} with LoRA: "
f"{lora_request.lora_name} (ID: {lora_request.lora_int_id})"
)
else:
logger.debug(
f"Starting token generation for request {request_id} (no LoRA)"
)
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt, prompt,
sampling_params, sampling_params,
request_id, request_id,
lora_request=lora_request,
data_parallel_rank=data_parallel_rank, data_parallel_rank=data_parallel_rank,
) )
...@@ -242,6 +590,11 @@ class BaseWorkerHandler(ABC): ...@@ -242,6 +590,11 @@ class BaseWorkerHandler(ABC):
# res is vllm's RequestOutput # res is vllm's RequestOutput
if not res.outputs: if not res.outputs:
if lora_request:
logger.debug(
f"Request {request_id} with LoRA {lora_request.lora_name} "
"returned no outputs"
)
yield {"finish_reason": "error", "token_ids": []} yield {"finish_reason": "error", "token_ids": []}
break break
...@@ -255,6 +608,18 @@ class BaseWorkerHandler(ABC): ...@@ -255,6 +608,18 @@ class BaseWorkerHandler(ABC):
] = BaseWorkerHandler._build_completion_usage( ] = BaseWorkerHandler._build_completion_usage(
request_output=res request_output=res
) )
# Log completion with LoRA info (debug level to avoid log spam)
if lora_request:
logger.debug(
f"Completed token generation for request {request_id} with LoRA "
f"{lora_request.lora_name}: {next_total_toks} output tokens, "
f"finish_reason={output.finish_reason}"
)
else:
logger.debug(
f"Completed token generation for request {request_id}: "
f"{next_total_toks} output tokens, finish_reason={output.finish_reason}"
)
if output.stop_reason: if output.stop_reason:
out["stop_reason"] = output.stop_reason out["stop_reason"] = output.stop_reason
yield out yield out
...@@ -281,6 +646,8 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -281,6 +646,8 @@ class DecodeWorkerHandler(BaseWorkerHandler):
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
): ):
super().__init__( super().__init__(
runtime, runtime,
...@@ -289,6 +656,8 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -289,6 +656,8 @@ class DecodeWorkerHandler(BaseWorkerHandler):
default_sampling_params, default_sampling_params,
model_max_len, model_max_len,
enable_multimodal, enable_multimodal,
generate_endpoint,
config,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -327,12 +696,36 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -327,12 +696,36 @@ class DecodeWorkerHandler(BaseWorkerHandler):
prefill_result.get("prompt_tokens_details") if prefill_result else None prefill_result.get("prompt_tokens_details") if prefill_result else None
) )
# Extract LoRA request if present
# Check if model name matches a loaded LoRA adapter
lora_request = None
model_name = request.get("model")
if model_name and model_name in self.lora_id_for_name:
lora_id = self.lora_id_for_name[model_name]
lora_request = LoRARequest(
lora_name=model_name,
lora_int_id=lora_id,
lora_path=self.lora_name_to_path[model_name],
)
logger.info(
f"Decode request {request_id} will use LoRA adapter: {model_name} (ID: {lora_id})"
)
else:
logger.debug(
f"Decode request {request_id} has no LoRA specified (model: {model_name})"
)
dp_rank = request.get("dp_rank", None) dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id): async with self._abort_monitor(context, request_id):
try: try:
async for tok in self.generate_tokens( async for tok in self.generate_tokens(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank prompt,
sampling_params,
request_id,
data_parallel_rank=dp_rank,
lora_request=lora_request,
): ):
if prefill_result is not None and "completion_usage" in tok: if prefill_result is not None and "completion_usage" in tok:
tok["completion_usage"][ tok["completion_usage"][
...@@ -355,6 +748,8 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -355,6 +748,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
default_sampling_params, default_sampling_params,
model_max_len: int | None = None, model_max_len: int | None = None,
enable_multimodal: bool = False, enable_multimodal: bool = False,
generate_endpoint=None,
config=None,
): ):
super().__init__( super().__init__(
runtime, runtime,
...@@ -363,6 +758,8 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -363,6 +758,8 @@ class PrefillWorkerHandler(BaseWorkerHandler):
default_sampling_params, default_sampling_params,
model_max_len, model_max_len,
enable_multimodal, enable_multimodal,
generate_endpoint,
config,
) )
async def generate(self, request, context): async def generate(self, request, context):
...@@ -403,12 +800,37 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -403,12 +800,37 @@ class PrefillWorkerHandler(BaseWorkerHandler):
sampling_params.max_tokens = 1 sampling_params.max_tokens = 1
sampling_params.min_tokens = 1 sampling_params.min_tokens = 1
# Extract LoRA request if present
# Check if model name matches a loaded LoRA adapter
lora_request = None
model_name = request.get("model")
if model_name and model_name in self.lora_id_for_name:
lora_id = self.lora_id_for_name[model_name]
lora_request = LoRARequest(
lora_name=model_name,
lora_int_id=lora_id,
lora_path=self.lora_name_to_path[model_name],
)
logger.info(
f"Prefill request {request_id} will use LoRA adapter: {model_name} (ID: {lora_id}), "
f"path: {self.lora_name_to_path[model_name]}"
)
else:
logger.debug(
f"Prefill request {request_id} has no LoRA specified (model: {model_name})"
)
dp_rank = request.get("dp_rank", None) dp_rank = request.get("dp_rank", None)
async with self._abort_monitor(context, request_id, is_prefill=True): async with self._abort_monitor(context, request_id, is_prefill=True):
try: try:
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt, sampling_params, request_id, data_parallel_rank=dp_rank prompt,
sampling_params,
request_id,
data_parallel_rank=dp_rank,
lora_request=lora_request,
) )
except EngineDeadError as e: except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}") logger.error(f"vLLM EngineDeadError: {e}")
...@@ -434,6 +856,14 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -434,6 +856,14 @@ class PrefillWorkerHandler(BaseWorkerHandler):
), ),
} }
# Log prefill completion with LoRA info
if lora_request:
logger.info(
f"Prefill completed for request {request_id} with LoRA {lora_request.lora_name}: "
f"generated {len(token_ids)} token(s), "
f"has_kv_params={res.kv_transfer_params is not None}"
)
yield output yield output
except asyncio.CancelledError: except asyncio.CancelledError:
# raise the error because we cannot migrate prefill requests # raise the error because we cannot migrate prefill requests
......
...@@ -198,6 +198,11 @@ def setup_vllm_engine(config, stat_logger=None): ...@@ -198,6 +198,11 @@ def setup_vllm_engine(config, stat_logger=None):
engine_args = config.engine_args engine_args = config.engine_args
if engine_args.enable_lora:
if "VLLM_ALLOW_RUNTIME_LORA_UPDATING" not in os.environ:
os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
if "VLLM_LORA_MODULES_LOADING_TIMEOUT" not in os.environ:
os.environ["VLLM_LORA_MODULES_LOADING_TIMEOUT"] = "600"
# Load default sampling params from `generation_config.json` # Load default sampling params from `generation_config.json`
default_sampling_params = ( default_sampling_params = (
engine_args.create_model_config().get_diff_sampling_param() engine_args.create_model_config().get_diff_sampling_param()
...@@ -318,6 +323,8 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -318,6 +323,8 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
default_sampling_params, default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None), getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -425,6 +432,9 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -425,6 +432,9 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras")
factory = StatLoggerFactory( factory = StatLoggerFactory(
component, component,
...@@ -450,6 +460,8 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -450,6 +460,8 @@ async def init(runtime: DistributedRuntime, config: Config):
default_sampling_params, default_sampling_params,
getattr(getattr(vllm_config, "model_config", None), "max_model_len", None), getattr(getattr(vllm_config, "model_config", None), "max_model_len", None),
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -534,6 +546,18 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -534,6 +546,18 @@ async def init(runtime: DistributedRuntime, config: Config):
handler.clear_kv_blocks, handler.clear_kv_blocks,
metrics_labels=[("model", config.served_model_name or config.model)], metrics_labels=[("model", config.served_model_name or config.model)],
), ),
load_lora_endpoint.serve_endpoint(
handler.load_lora,
metrics_labels=[("model", config.served_model_name or config.model)],
),
unload_lora_endpoint.serve_endpoint(
handler.unload_lora,
metrics_labels=[("model", config.served_model_name or config.model)],
),
list_loras_endpoint.serve_endpoint(
handler.list_loras,
metrics_labels=[("model", config.served_model_name or config.model)],
),
) )
logger.debug("serve_endpoint completed for decode worker") logger.debug("serve_endpoint completed for decode worker")
except Exception as e: except Exception as e:
......
# LoRA Integration Guide for S3-Compatible Storage Backends
This guide explains how to set up and use LoRA (Low-Rank Adaptation) adapters with Dynamo using an S3-compatible storage backend (e.g., MinIO, AWS S3, or GCS).
## Overview
This example demonstrates how to:
1. Set up MinIO as a local S3-compatible storage
2. Download LoRA adapters from Hugging Face Hub
3. Upload LoRA adapters to MinIO
4. Load and use LoRA adapters with Dynamo
5. Run inference with LoRA-adapted models
6. Manage (load/unload) LoRA adapters
## Prerequisites
### Required Software
- Docker (for running MinIO)
- Python 3.8+
- AWS CLI: `pip install awscli`
- Hugging Face CLI: `pip install huggingface-hub`
- jq (optional, for pretty JSON output): `sudo apt install jq`
### Python Dependencies
Make sure you have Dynamo installed with vLLM support:
```bash
pip install dynamo vllm
```
## Quick Start
### Step 1: Setup MinIO and Upload LoRA
Run the setup script to start MinIO and download/upload a LoRA adapter from Hugging Face:
```bash
./setup_minio.sh
```
This script will:
- Start MinIO in a Docker container
- Download a LoRA adapter from Hugging Face Hub (default: `Neural-Hacker/Qwen3-Math-Reasoning-LoRA`)
- Upload the LoRA to MinIO at `s3://my-loras/Neural-Hacker/Qwen3-Math-Reasoning-LoRA`
#### Script Options
The setup script supports different modes:
```bash
# Full setup (default) - start MinIO, download & upload LoRA
./setup_minio.sh
# Start MinIO only (without downloading/uploading)
./setup_minio.sh --start
# Stop MinIO
./setup_minio.sh --stop
# Show help
./setup_minio.sh --help
```
#### Customize the LoRA to Download
You can specify a different LoRA repository and name:
```bash
HF_LORA_REPO="username/lora-repo" \
LORA_NAME="my-lora" \
./setup_minio.sh
```
### Step 2: Launch Dynamo with LoRA Support
Start the Dynamo frontend and worker with LoRA support enabled:
```bash
./agg_lora_s3.sh
```
This will:
- Set up AWS credentials for MinIO
- Start the Dynamo frontend on port 8000
- Start the Dynamo worker (vLLM) with LoRA support, with its system port (used for the `/v1/loras` calls below) on 8081
Wait for the services to start (check the logs for "Application startup complete").
## Working with LoRAs
### 1. Check Available Models
List all available models (base model only at first):
```bash
curl http://localhost:8000/v1/models | jq .
```
### 2. Load a LoRA Adapter
Load a LoRA from an S3-compatible storage backend (e.g., MinIO):
```bash
curl -X POST http://localhost:8081/v1/loras \
-H "Content-Type: application/json" \
-d '{
"lora_name": "Neural-Hacker/Qwen3-Math-Reasoning-LoRA",
"source": {
"uri": "s3://my-loras/Neural-Hacker/Qwen3-Math-Reasoning-LoRA"
}
}' | jq .
```
Expected response:
```json
{
"status": "success",
"message": "LoRA adapter 'Neural-Hacker/Qwen3-Math-Reasoning-LoRA' loaded successfully",
"lora_name": "Neural-Hacker/Qwen3-Math-Reasoning-LoRA",
"lora_id": 1207343256
}
```
### 3. List Loaded LoRAs
Check which LoRAs are currently loaded:
```bash
curl http://localhost:8081/v1/loras | jq .
```
### 4. Verify LoRA in Models List
After loading, the LoRA should appear in the models list:
```bash
curl http://localhost:8000/v1/models | jq .
```
You should see both the base model and the LoRA adapter listed.
### 5. Run Inference with LoRA
#### Using the LoRA-adapted model:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Neural-Hacker/Qwen3-Math-Reasoning-LoRA",
"messages": [{
"role": "user",
"content": "What is good low risk investment strategy?"
}],
"max_tokens": 300,
"temperature": 0.1
}' | jq .
```
#### For comparison, using the base model:
```bash
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen3-0.6B",
"messages": [{
"role": "user",
"content": "What is good low risk investment strategy?"
}],
"max_tokens": 300
}' | jq .
```
### 6. Unload a LoRA
When you no longer need a LoRA, unload it to free up resources:
```bash
curl -X DELETE http://localhost:8081/v1/loras/Neural-Hacker/Qwen3-Math-Reasoning-LoRA | jq .
```
Expected response:
```json
{
"status": "success",
"message": "LoRA unloaded successfully"
}
```
After unloading, the LoRA will be removed from both `/v1/loras` and `/v1/models` endpoints.
## Configuration
### Environment Variables
The following environment variables can be configured:
```bash
# S3-compatible storage backend Configuration
export AWS_ENDPOINT=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
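export AWS_REGION=us-east-1
export AWS_ALLOW_HTTP=true  # also set by agg_lora_s3.sh; the local MinIO endpoint above is plain HTTP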
# Dynamo LoRA Configuration
export DYN_LORA_ENABLED=true
export DYN_LORA_PATH=/tmp/dynamo_loras_minio
```
### MinIO Console
Access the MinIO web console at http://localhost:9001
- Username: `minioadmin`
- Password: `minioadmin`
## Troubleshooting
### MinIO won't start
- Check if ports 9000 and 9001 are already in use
- Ensure Docker is running
- Check Docker logs: `docker logs dynamo-minio`
- Try stopping any existing MinIO containers: `./setup_minio.sh --stop`
- Restart MinIO: `./setup_minio.sh --start`
### LoRA fails to load
- Verify the LoRA is uploaded to MinIO: `aws --endpoint-url=http://localhost:9000 s3 ls s3://my-loras/`
- Check AWS credentials are set correctly
- Ensure the LoRA files are compatible with the base model
- Check vLLM logs for detailed error messages
### Inference fails
- Verify the model name matches exactly (case-sensitive)
- Check if the LoRA is loaded: `curl http://localhost:8081/v1/loras`
- Ensure the base model supports the LoRA rank
- Check that the worker's `--max-lora-rank` value (32 in `agg_lora_s3.sh`) is >= the LoRA's rank
### Cache issues
- Check the cache directory: `ls -la /tmp/dynamo_loras_minio/`
- Clear the cache if needed: `rm -rf /tmp/dynamo_loras_minio/*`
- Ensure the cache directory is writable
## Advanced Usage
### Loading Multiple LoRAs
You can load multiple LoRA adapters simultaneously:
```bash
# Load first LoRA
curl -X POST http://localhost:8081/v1/loras \
-H "Content-Type: application/json" \
-d '{"lora_name": "lora1", "source": {"uri": "s3://my-loras/lora1"}}'
# Load second LoRA
curl -X POST http://localhost:8081/v1/loras \
-H "Content-Type: application/json" \
-d '{"lora_name": "lora2", "source": {"uri": "s3://my-loras/lora2"}}'
```
### Using Different Base Models
To use a different base model, modify the `--model` parameter in `agg_lora_s3.sh`:
```bash
python -m dynamo.vllm --model meta-llama/Llama-2-7b-hf --enable-lora --max-lora-rank 64
```
Ensure your LoRAs are compatible with the chosen base model.
## Cleanup
### Stop Services
Press `Ctrl+C` in the terminal running `agg_lora_s3.sh` to stop Dynamo services.
### Stop MinIO
```bash
# Using the setup script (recommended)
./setup_minio.sh --stop
# Or manually with Docker
docker stop dynamo-minio
docker rm dynamo-minio
```
### Clean Up Data
```bash
# Remove MinIO data
rm -rf ~/dynamo_minio_data
# Remove LoRA cache
rm -rf /tmp/dynamo_loras_minio
```
## API Reference
### Load LoRA
- **Endpoint**: `POST /v1/loras`
- **Body**: `{"lora_name": "string", "source": {"uri": "string"}}`
- **Response**: `{"status": "success", "message": "string", "lora_name": "string", "lora_id": int}`
### List LoRAs
- **Endpoint**: `GET /v1/loras`
- **Response**: Array of loaded LoRAs
### Unload LoRA
- **Endpoint**: `DELETE /v1/loras/{lora_name}`
- **Response**: `{"status": "success", "message": "string"}`
### List Models
- **Endpoint**: `GET /v1/models`
- **Response**: OpenAI-compatible models list
### Chat Completions
- **Endpoint**: `POST /v1/chat/completions`
- **Body**: OpenAI-compatible chat completion request
- **Response**: OpenAI-compatible chat completion response
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Abort the script as soon as any command exits with a non-zero status.
set -e
# On exit (normal or error), signal every process in this script's process
# group so the backgrounded frontend does not outlive the script.
trap 'echo Cleaning up...; kill 0' EXIT
# Prerequisite: follow the README.md instructions to set up MinIO, or upload
# the LoRA to S3/MinIO yourself.
# Adjust the values below to match your local MinIO or S3 setup.
# Example: load the math LoRA into MinIO with
# LORA_NAME=Neural-Hacker/Qwen3-Math-Reasoning-LoRA HF_LORA_REPO=Neural-Hacker/Qwen3-Math-Reasoning-LoRA ./setup_minio.sh
export AWS_ENDPOINT=http://localhost:9000
export AWS_ACCESS_KEY_ID=minioadmin
export AWS_SECRET_ACCESS_KEY=minioadmin
export AWS_REGION=us-east-1
# NOTE(review): presumably needed because the endpoint above is plain HTTP — confirm.
export AWS_ALLOW_HTTP=true
# Dynamo LoRA configuration
export DYN_LORA_ENABLED=true
export DYN_LORA_PATH=/tmp/dynamo_loras_minio
export DYN_LOG=debug
# export DYN_LOG_LEVEL=debug
# Make sure the LoRA directory exists before the worker starts.
mkdir -p $DYN_LORA_PATH
# Run the ingress (HTTP frontend) in the background on port 8000.
python -m dynamo.frontend --http-port=8000 &
# Run the worker in the foreground; the script blocks here until it exits.
# DYN_SYSTEM_PORT=8081 is the port the /v1/loras example requests target.
# --enforce-eager is only used here for quick deployment; remove this flag for production use.
DYN_SYSTEM_ENABLED=true DYN_SYSTEM_PORT=8081 \
python -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager \
--connector none \
--enable-lora \
--max-lora-rank 32
################################## Example Usage ##################################
# The commands below are reference examples, meant to be run from a second
# terminal while this script is serving. They are kept as comments because the
# worker above runs in the foreground: anything placed after it would only
# execute once the worker has exited, when the services are no longer reachable.
#
# Check available models
#   curl http://localhost:8000/v1/models | jq .
#
# Load LoRA using s3 uri
#   curl -X POST http://localhost:8081/v1/loras \
#     -H "Content-Type: application/json" \
#     -d '{
#       "lora_name": "Neural-Hacker/Qwen3-Math-Reasoning-LoRA",
#       "source": {
#         "uri": "s3://my-loras/Neural-Hacker/Qwen3-Math-Reasoning-LoRA"
#       }
#     }'
#
# Test LoRA inference
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#       "model": "Neural-Hacker/Qwen3-Math-Reasoning-LoRA",
#       "messages": [{"role": "user", "content": "Solve (x*x - x + 1 = 0) for x"}],
#       "max_tokens": 300,
#       "temperature": 0.0
#     }'
#
# Alternative prompt:
# Find the minimum possible value of \( x^2 + y^2 \) given that \( x \) and \( y \) are real numbers satisfying \( xy(x^2 - y^2) = x^2 + y^2 \) and \( x \neq 0 \)
#
# Test base model inference (for comparison)
#   curl -X POST http://localhost:8000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#       "model": "Qwen/Qwen3-0.6B",
#       "messages": [{"role": "user", "content": "Solve (x*x - x + 1 = 0) for x"}],
#       "max_tokens": 300,
#       "temperature": 0.0
#     }'
#
# Unload LoRA
#   curl -X DELETE http://localhost:8081/v1/loras/Neural-Hacker/Qwen3-Math-Reasoning-LoRA
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Script to setup MinIO and upload LoRA adapters from Hugging Face Hub
# Abort on the first command that exits with a non-zero status.
set -o errexit

# ANSI colour escape sequences; stored unexpanded and interpreted by `echo -e`.
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color (reset)

# MinIO deployment settings
MINIO_ENDPOINT="http://localhost:9000"
MINIO_DATA_DIR="${HOME}/dynamo_minio_data"
MINIO_ACCESS_KEY="minioadmin"
MINIO_SECRET_KEY="minioadmin"
BUCKET_NAME="my-loras"

# LoRA to download; each value keeps whatever the environment already provides
# and falls back to the default otherwise.
: "${HF_LORA_REPO:=Neural-Hacker/Qwen3-Math-Reasoning-LoRA}"
: "${LORA_NAME:=Neural-Hacker/Qwen3-Math-Reasoning-LoRA}"

# Scratch directory for the download; filled in via mktemp only when needed.
TEMP_DIR=""
# Parse command line arguments.
# Only the first argument is inspected; no argument selects the full setup.
MODE="full"
case "${1:-}" in
    "")
        ;;
    --start)
        MODE="start"
        ;;
    --stop)
        MODE="stop"
        ;;
    --help | -h)
        MODE="help"
        ;;
    *)
        # A usage error is reported on stderr and ends the script with a
        # non-zero status so callers can detect it; routing it through the
        # help mode would make the script exit 0.
        echo -e "${RED}Error: Unknown option '$1'${NC}" >&2
        echo "Run '$0 --help' for usage." >&2
        exit 1
        ;;
esac
# Print an informational message in yellow.
#   $1 - message text
print_info() {
    local message="$1"
    # %b expands the backslash escapes stored in the colour variables.
    printf '%b\n' "${YELLOW}${message}${NC}"
}
# Print a success message in green.
#   $1 - message text
print_success() {
    local message="$1"
    # %b expands the backslash escapes stored in the colour variables.
    printf '%b\n' "${GREEN}${message}${NC}"
}
# Print an error message in red.
#   $1 - message text
# Written to stderr so errors stay visible when stdout is piped or captured.
print_error() {
    echo -e "${RED}$1${NC}" >&2
}
# Print usage, options, environment overrides and examples to stdout.
show_help() {
    # One argument per output line; "" produces a blank line.
    printf '%s\n' \
        "Usage: $0 [OPTIONS]" \
        "" \
        "Setup MinIO and upload LoRA adapters from Hugging Face Hub" \
        "" \
        "Options:" \
        " (no options) Run full setup: start MinIO, download and upload LoRA" \
        " --start Only start MinIO container" \
        " --stop Stop and remove MinIO container" \
        " --help, -h Show this help message" \
        "" \
        "Environment Variables:" \
        " HF_LORA_REPO Hugging Face repository (default: ${HF_LORA_REPO:-Neural-Hacker/Qwen3-Math-Reasoning-LoRA})" \
        " LORA_NAME Local name for the LoRA (default: ${LORA_NAME:-Neural-Hacker/Qwen3-Math-Reasoning-LoRA})" \
        "" \
        "Examples:" \
        " $0 # Full setup" \
        " $0 --start # Start MinIO only" \
        " $0 --stop # Stop MinIO" \
        " HF_LORA_REPO=user/repo $0 # Use custom LoRA" \
        ""
}
# Check that the required command-line tools are installed.
# Reports every missing tool on stderr (so one run shows everything that has
# to be installed) and exits with status 1 if any is missing.
check_dependencies() {
    print_info "Checking dependencies..."
    local tool
    local missing=0
    for tool in docker aws huggingface-cli; do
        if command -v "$tool" > /dev/null 2>&1; then
            continue
        fi
        missing=1
        case "$tool" in
            docker)
                echo "Error: docker is not installed" >&2
                ;;
            aws)
                echo "Error: aws-cli is not installed. Install with: pip install awscli" >&2
                ;;
            huggingface-cli)
                echo "Error: huggingface-cli is not installed. Install with: pip install huggingface-hub" >&2
                ;;
        esac
    done
    if [ "$missing" -ne 0 ]; then
        exit 1
    fi
    print_success "All dependencies are installed"
}
# Start MinIO in a Docker container named dynamo-minio and wait until it is live.
# Any existing dynamo-minio container is stopped and removed first. The data
# directory ${MINIO_DATA_DIR} is mounted at /data inside the container.
# Exits with status 1 if the health endpoint does not respond within 30 tries.
start_minio() {
    print_info "Setting up MinIO..."
    # Create data directory
    mkdir -p "${MINIO_DATA_DIR}"
    # Stop and remove existing container if it exists; failures are ignored
    # because the container may simply not be there.
    docker stop dynamo-minio 2>/dev/null || true
    docker rm dynamo-minio 2>/dev/null || true
    # Start MinIO (9000 = S3 API, 9001 = web console)
    print_info "Starting MinIO container..."
    docker run -d \
        --name dynamo-minio \
        -p 9000:9000 \
        -p 9001:9001 \
        -v "${MINIO_DATA_DIR}:/data" \
        quay.io/minio/minio server /data \
        --console-address ":9001"
    # Wait for MinIO to be ready
    print_info "Waiting for MinIO to be ready..."
    local attempt
    for attempt in {1..30}; do
        # -f makes curl fail on HTTP error statuses, so a server that answers
        # but is not yet healthy is not mistaken for a ready one.
        if curl -sf "${MINIO_ENDPOINT}/minio/health/live" > /dev/null 2>&1; then
            print_success "MinIO is ready"
            break
        fi
        if [ "$attempt" -eq 30 ]; then
            echo "Error: MinIO did not start in time" >&2
            exit 1
        fi
        sleep 1
    done
    print_success "MinIO started successfully"
    echo " - MinIO API: ${MINIO_ENDPOINT}"
    echo " - MinIO Console: http://localhost:9001"
    echo " - Username: ${MINIO_ACCESS_KEY}"
    echo " - Password: ${MINIO_SECRET_KEY}"
}
# Point the AWS CLI at MinIO and make sure the target bucket exists.
# Exports the MinIO credentials and endpoint for subsequent aws invocations.
configure_aws_cli() {
    local bucket_uri="s3://${BUCKET_NAME}"
    print_info "Configuring AWS CLI for MinIO..."
    export AWS_ACCESS_KEY_ID="${MINIO_ACCESS_KEY}"
    export AWS_SECRET_ACCESS_KEY="${MINIO_SECRET_KEY}"
    export AWS_ENDPOINT_URL="${MINIO_ENDPOINT}"
    # A successful listing means the bucket is already there; otherwise create it.
    if aws --endpoint-url=${MINIO_ENDPOINT} s3 ls "${bucket_uri}" 2>/dev/null; then
        print_success "Bucket already exists: ${BUCKET_NAME}"
    else
        print_info "Creating bucket: ${BUCKET_NAME}"
        aws --endpoint-url=${MINIO_ENDPOINT} s3 mb "${bucket_uri}"
        print_success "Bucket created"
    fi
}
# Download the LoRA adapter ${HF_LORA_REPO} from Hugging Face Hub into a fresh
# temporary directory. The directory path is stored in the global TEMP_DIR so
# that the upload and cleanup steps can find it.
download_lora_from_hf() {
    print_info "Downloading LoRA from Hugging Face Hub..."
    echo " - Repository: ${HF_LORA_REPO}"
    echo " - Local name: ${LORA_NAME}"
    # Create temporary directory using mktemp (global variable for cleanup)
    TEMP_DIR=$(mktemp -d -t lora_download_XXXXXX)
    # `set -e` aborts the script if the download or a later step fails, which
    # would skip the explicit cleanup call and leak the directory. The EXIT
    # trap guarantees removal; cleanup does nothing once the directory is
    # gone, so the explicit call and the trap do not conflict.
    trap cleanup EXIT
    # Download LoRA adapter files
    print_info "Downloading adapter files..."
    huggingface-cli download "${HF_LORA_REPO}" \
        --local-dir "${TEMP_DIR}" \
        --local-dir-use-symlinks False
    print_success "LoRA downloaded to ${TEMP_DIR}"
    # List downloaded files
    echo "Downloaded files:"
    ls -lh "${TEMP_DIR}"
}
# Sync the downloaded adapter from ${TEMP_DIR} into the MinIO bucket under
# ${LORA_NAME}, then list what was uploaded.
upload_lora_to_minio() {
    local destination="s3://${BUCKET_NAME}/${LORA_NAME}"
    print_info "Uploading LoRA to MinIO..."
    # Copy everything except paths matching *.git*
    aws --endpoint-url=${MINIO_ENDPOINT} s3 sync "${TEMP_DIR}" "${destination}" --exclude "*.git*"
    print_success "LoRA uploaded to ${destination}"
    # Show the resulting objects
    echo "Uploaded files:"
    aws --endpoint-url=${MINIO_ENDPOINT} s3 ls "${destination}/" --recursive
}
# Remove the temporary download directory, if one was created.
# Does nothing when TEMP_DIR is empty or no longer points at a directory.
cleanup() {
    [ -n "${TEMP_DIR}" ] || return 0
    [ -d "${TEMP_DIR}" ] || return 0
    print_info "Cleaning up temporary files..."
    rm -rf "${TEMP_DIR}"
    print_success "Cleanup complete"
}
# Stop and remove the dynamo-minio container; data in ${MINIO_DATA_DIR} is kept.
# The container is looked up by its exact name with `docker container inspect`
# rather than by grepping `docker ps` output, which would also match other
# containers or images whose name merely contains "dynamo-minio".
stop_minio() {
    print_info "Stopping MinIO..."
    local running
    # inspect fails when the container does not exist; `|| true` keeps
    # `set -e` from aborting and leaves $running empty in that case.
    running=$(docker container inspect -f '{{.State.Running}}' dynamo-minio 2>/dev/null || true)
    if [ "$running" = "true" ]; then
        docker stop dynamo-minio 2>/dev/null
        print_success "MinIO container stopped"
    else
        print_info "MinIO container is not running"
    fi
    if docker container inspect dynamo-minio > /dev/null 2>&1; then
        docker rm dynamo-minio 2>/dev/null
        print_success "MinIO container removed"
    fi
    echo ""
    echo "MinIO has been stopped."
    echo "Data is preserved in: ${MINIO_DATA_DIR}"
    echo ""
    echo "To start MinIO again:"
    echo " $0 --start"
    echo ""
}
# Start MinIO only (no LoRA download or upload), then print follow-up hints.
start_only() {
    # Banner; one argument per output line, "" produces a blank line.
    printf '%s\n' \
        "========================================" \
        "Starting MinIO" \
        "========================================" \
        ""
    start_minio
    printf '%s\n' \
        "" \
        "========================================" \
        "MinIO Started!" \
        "========================================" \
        "" \
        "MinIO is now running." \
        "" \
        "To upload a LoRA, run the full setup:" \
        " $0" \
        "" \
        "Or manually upload using AWS CLI:" \
        " export AWS_ACCESS_KEY_ID=${MINIO_ACCESS_KEY}" \
        " export AWS_SECRET_ACCESS_KEY=${MINIO_SECRET_KEY}" \
        " aws --endpoint-url=${MINIO_ENDPOINT} s3 cp your-lora/ s3://${BUCKET_NAME}/your-lora/ --recursive" \
        "" \
        "To stop MinIO:" \
        " $0 --stop" \
        ""
}
# Full setup: check tools, start MinIO, create the bucket, download the LoRA,
# upload it, remove the temporary files, then print the next steps.
full_setup() {
    local step
    printf '%s\n' \
        "========================================" \
        "MinIO Setup & LoRA Upload Script" \
        "========================================" \
        ""
    # Run the stages in order, each followed by a blank line. Under `set -e`
    # a failing stage ends the script.
    for step in \
        check_dependencies \
        start_minio \
        configure_aws_cli \
        download_lora_from_hf \
        upload_lora_to_minio \
        cleanup; do
        "$step"
        echo ""
    done
    printf '%s\n' \
        "========================================" \
        "Setup Complete!" \
        "========================================" \
        "" \
        "MinIO is running and LoRA has been uploaded." \
        "" \
        "Next steps:" \
        " 1. Run the Dynamo service with LoRA support:" \
        " ./agg_lora_s3.sh" \
        "" \
        " 2. Load the LoRA adapter:" \
        " curl -X POST http://localhost:8081/v1/loras \\" \
        " -H \"Content-Type: application/json\" \\" \
        " -d '{\"lora_name\": \"${LORA_NAME}\", \"source\": {\"uri\": \"s3://${BUCKET_NAME}/${LORA_NAME}\"}}'" \
        "" \
        " 3. Run inference with the LoRA:" \
        " curl -X POST http://localhost:8000/v1/chat/completions \\" \
        " -H \"Content-Type: application/json\" \\" \
        " -d '{\"model\": \"${LORA_NAME}\", \"messages\": [{\"role\": \"user\", \"content\": \"your prompt here\"}]}'" \
        "" \
        "To stop MinIO:" \
        " $0 --stop" \
        ""
}
# Main execution: run the action selected by MODE.
if [ "$MODE" = "start" ]; then
    start_only
elif [ "$MODE" = "stop" ]; then
    stop_minio
elif [ "$MODE" = "help" ]; then
    show_help
    exit 0
elif [ "$MODE" = "full" ]; then
    full_setup
else
    # Any other value is rejected.
    echo "Error: Unknown mode '$MODE'"
    show_help
    exit 1
fi
...@@ -227,8 +227,15 @@ fn lora_name_to_id(lora_name: &str) -> i32 { ...@@ -227,8 +227,15 @@ fn lora_name_to_id(lora_name: &str) -> i32 {
/// Create an engine and attach it to an endpoint to make it visible to the frontend. /// Create an engine and attach it to an endpoint to make it visible to the frontend.
/// This is the main way you create a Dynamo worker / backend. /// This is the main way you create a Dynamo worker / backend.
///
/// If `lora_name` is provided, this function will publish a LoRA adapter instead of a base model:
/// - LoRA path: v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}
/// - Base model path: v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}
///
/// For LoRA mode, both `lora_name` and `base_model_path` must be provided together.
/// Providing only one of them will result in an error.
#[pyfunction] #[pyfunction]
#[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None, media_decoder=None, media_fetcher=None))] #[pyo3(signature = (model_input, model_type, endpoint, model_path, model_name=None, context_length=None, kv_cache_block_size=None, router_mode=None, migration_limit=0, runtime_config=None, user_data=None, custom_template_path=None, media_decoder=None, media_fetcher=None, lora_name=None, base_model_path=None))]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn register_llm<'p>( fn register_llm<'p>(
py: Python<'p>, py: Python<'p>,
...@@ -246,6 +253,8 @@ fn register_llm<'p>( ...@@ -246,6 +253,8 @@ fn register_llm<'p>(
custom_template_path: Option<&str>, custom_template_path: Option<&str>,
media_decoder: Option<MediaDecoder>, media_decoder: Option<MediaDecoder>,
media_fetcher: Option<MediaFetcher>, media_fetcher: Option<MediaFetcher>,
lora_name: Option<&str>,
base_model_path: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
// Validate Prefill model type requirements // Validate Prefill model type requirements
if model_type.inner == llm_rs::model_type::ModelType::Prefill { if model_type.inner == llm_rs::model_type::ModelType::Prefill {
...@@ -270,7 +279,7 @@ fn register_llm<'p>( ...@@ -270,7 +279,7 @@ fn register_llm<'p>(
let model_type_obj = model_type.inner; let model_type_obj = model_type.inner;
let inner_path = model_path.to_string(); let inner_path = model_path.to_string();
let mut model_name = model_name.map(|n| n.to_string()); let model_name = model_name.map(|n| n.to_string());
let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin); let router_mode = router_mode.unwrap_or(RouterMode::RoundRobin);
let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default()); let router_config = RouterConfig::new(router_mode.into(), KvRouterConfig::default());
...@@ -294,16 +303,31 @@ fn register_llm<'p>( ...@@ -294,16 +303,31 @@ fn register_llm<'p>(
PyErr::new::<PyException, _>(format!("Failed to convert user_data: {}", err)) PyErr::new::<PyException, _>(format!("Failed to convert user_data: {}", err))
})?; })?;
// Validate LoRA parameters: both or neither must be provided
if lora_name.is_some() ^ base_model_path.is_some() {
return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"lora_name and base_model_path must both be provided together, or neither",
));
}
// Determine source_path and lora_identifier based on registration mode
let (source_path, lora_identifier) = match (lora_name, base_model_path) {
(Some(lora), Some(base)) => (base.to_string(), Some(lora.to_string())),
_ => (inner_path, None),
};
// Model name: use lora name if present, otherwise provided name or default to source path
let model_name = lora_identifier
.clone()
.or(model_name)
.or_else(|| Some(source_path.clone()));
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let model_path = if fs::exists(&inner_path)? { // Resolve the model path (local or fetch from HuggingFace)
PathBuf::from(inner_path) let model_path = if fs::exists(&source_path)? {
PathBuf::from(&source_path)
} else { } else {
// Preserve the model name LocalModel::fetch(&source_path, false)
if model_name.is_none() {
model_name = Some(inner_path.clone());
}
// Likely it's a Hugging Face repo, download it
LocalModel::fetch(&inner_path, false)
.await .await
.map_err(to_pyerr)? .map_err(to_pyerr)?
}; };
...@@ -311,7 +335,7 @@ fn register_llm<'p>( ...@@ -311,7 +335,7 @@ fn register_llm<'p>(
let mut builder = dynamo_llm::local_model::LocalModelBuilder::default(); let mut builder = dynamo_llm::local_model::LocalModelBuilder::default();
builder builder
.model_path(model_path) .model_path(model_path)
.model_name(model_name) .model_name(model_name.clone())
.context_length(context_length) .context_length(context_length)
.kv_cache_block_size(kv_cache_block_size) .kv_cache_block_size(kv_cache_block_size)
.router_config(Some(router_config)) .router_config(Some(router_config))
...@@ -321,24 +345,53 @@ fn register_llm<'p>( ...@@ -321,24 +345,53 @@ fn register_llm<'p>(
.custom_template_path(custom_template_path_owned) .custom_template_path(custom_template_path_owned)
.media_decoder(media_decoder.map(|m| m.inner)) .media_decoder(media_decoder.map(|m| m.inner))
.media_fetcher(media_fetcher.map(|m| m.inner)); .media_fetcher(media_fetcher.map(|m| m.inner));
// Load the ModelDeploymentCard
let mut local_model = builder.build().await.map_err(to_pyerr)?; let mut local_model = builder.build().await.map_err(to_pyerr)?;
// Advertise ourself so ingress can find us
local_model local_model
.attach(&endpoint.inner, model_type_obj, model_input) .attach(
&endpoint.inner,
model_type_obj,
model_input,
lora_identifier.as_deref(),
)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
if let Some(lora_name) = lora_identifier {
tracing::info!("Registered LoRA '{}' MDC", lora_name);
} else {
tracing::info!("Registered base model '{:?}' MDC", model_name);
}
Ok(()) Ok(())
}) })
} }
/// Unregister a model from the endpoint. /// Unregister a Model Deployment Card (MDC) from the service registry
///
/// This removes an LLM deployment from the discovery system.
///
/// # Arguments
///
/// * `endpoint` - The endpoint where the model is registered
/// * `lora_name` - Optional LoRA adapter name (if unregistering a LoRA deployment)
///
/// # MDC Path Format
///
/// - Base model: `v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}`
/// - LoRA model: `v1/mdc/{namespace}/{component}/{endpoint}/{instance_id}/{lora_slug}`
#[pyfunction] #[pyfunction]
#[pyo3(signature = (endpoint))] #[pyo3(signature = (endpoint, lora_name=None))]
fn unregister_llm<'p>(py: Python<'p>, endpoint: Endpoint) -> PyResult<Bound<'p, PyAny>> { fn unregister_llm<'p>(
py: Python<'p>,
endpoint: Endpoint,
lora_name: Option<&str>,
) -> PyResult<Bound<'p, PyAny>> {
let lora_name_owned = lora_name.map(|s| s.to_string());
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
LocalModel::detach_model_from_endpoint(&endpoint.inner) // Unified detach method handles both base models and LoRA adapters
LocalModel::detach_from_endpoint(&endpoint.inner, lora_name_owned.as_deref())
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
Ok(()) Ok(())
...@@ -606,7 +659,7 @@ impl Endpoint { ...@@ -606,7 +659,7 @@ impl Endpoint {
generator, generator,
self.event_loop.clone(), self.event_loop.clone(),
)?); )?);
let ingress = JsonServerStreamingIngress::for_engine(engine).map_err(to_pyerr)?; let ingress = JsonServerStreamingIngress::for_engine(engine.clone()).map_err(to_pyerr)?;
// Convert Python dict to serde_json::Value if provided and validate it's an object // Convert Python dict to serde_json::Value if provided and validate it's an object
let health_payload_json = health_check_payload let health_payload_json = health_check_payload
...@@ -638,6 +691,9 @@ impl Endpoint { ...@@ -638,6 +691,9 @@ impl Endpoint {
builder = builder.health_check_payload(payload); builder = builder.health_check_payload(payload);
} }
// Register the engine in the local endpoint registry for in-process calls
builder = builder.register_local_engine(engine).map_err(to_pyerr)?;
let graceful_shutdown = graceful_shutdown.unwrap_or(true); let graceful_shutdown = graceful_shutdown.unwrap_or(true);
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
builder builder
......
...@@ -85,4 +85,11 @@ impl LoRADownloader { ...@@ -85,4 +85,11 @@ impl LoRADownloader {
pyo3::exceptions::PyRuntimeError::new_err(format!("Validation failed: {}", e)) pyo3::exceptions::PyRuntimeError::new_err(format!("Validation failed: {}", e))
}) })
} }
/// Convert a LoRA URI to a cache key.
///
/// Exposed to Python as a static method. The work is forwarded to
/// `RsLoRACache::uri_to_cache_key`, so Rust and Python callers derive the
/// same key for the same URI.
#[staticmethod]
fn uri_to_cache_key(uri: &str) -> String {
    RsLoRACache::uri_to_cache_key(uri)
}
} }
...@@ -1067,8 +1067,32 @@ async def register_llm( ...@@ -1067,8 +1067,32 @@ async def register_llm(
runtime_config: Optional[ModelRuntimeConfig] = None, runtime_config: Optional[ModelRuntimeConfig] = None,
user_data: Optional[Dict[str, Any]] = None, user_data: Optional[Dict[str, Any]] = None,
custom_template_path: Optional[str] = None, custom_template_path: Optional[str] = None,
lora_name: Optional[str] = None,
base_model_path: Optional[str] = None,
) -> None: ) -> None:
"""Attach the model at path to the given endpoint, and advertise it as model_type""" """
Attach the model at path to the given endpoint, and advertise it as model_type.
LoRA Registration:
The `lora_name` and `base_model_path` parameters must be provided together or not at all.
Providing only one of these parameters will raise a ValueError.
- `lora_name`: The served model name for the LoRA model
- `base_model_path`: Path to the base model that the LoRA extends
"""
...
async def unregister_llm(
    endpoint: Endpoint,
    lora_name: Optional[str] = None,
) -> None:
    """
    Unregister a model from the discovery system.

    Args:
        endpoint: The endpoint the model is registered on.
        lora_name: Optional LoRA adapter name. If provided, unregisters that
            LoRA adapter instead of a base model.
    """
    ...
def lora_name_to_id(lora_name: str) -> int:
"""Generate a deterministic integer ID from a LoRA name using blake3 hash."""
... ...
async def fetch_llm(remote_name: str) -> str: async def fetch_llm(remote_name: str) -> str:
......
...@@ -45,7 +45,7 @@ pub async fn run( ...@@ -45,7 +45,7 @@ pub async fn run(
Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>, Pin<Box<dyn AsyncEngineStream<Annotated<NvCreateChatCompletionStreamResponse>>>>,
>::for_engine(engine)?; >::for_engine(engine)?;
model model
.attach(&endpoint, ModelType::Chat, ModelInput::Text) .attach(&endpoint, ModelType::Chat, ModelInput::Text, None)
.await?; .await?;
let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start(); let fut_chat = endpoint.endpoint_builder().handler(ingress_chat).start();
...@@ -76,7 +76,7 @@ pub async fn run( ...@@ -76,7 +76,7 @@ pub async fn run(
ModelType::Chat | ModelType::Completions ModelType::Chat | ModelType::Completions
}; };
model model
.attach(&endpoint, model_type, ModelInput::Tokens) .attach(&endpoint, model_type, ModelInput::Tokens, None)
.await?; .await?;
let fut = endpoint.endpoint_builder().handler(ingress).start(); let fut = endpoint.endpoint_builder().handler(ingress).start();
......
...@@ -424,24 +424,46 @@ impl LocalModel { ...@@ -424,24 +424,46 @@ impl LocalModel {
self.card self.card
} }
/// Attach this model the endpoint. This registers it on the network /// Attach this model to the endpoint. This registers it on the network
/// allowing ingress to discover it. /// allowing ingress to discover it.
///
/// For base models, pass `lora_name = None`.
/// For LoRA adapters, pass `lora_name = Some("adapter-name")`.
pub async fn attach( pub async fn attach(
&mut self, &mut self,
endpoint: &Endpoint, endpoint: &Endpoint,
model_type: ModelType, model_type: ModelType,
model_input: ModelInput, model_input: ModelInput,
lora_name: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.card.model_type = model_type; self.card.model_type = model_type;
self.card.model_input = model_input; self.card.model_input = model_input;
// Compute model_suffix from lora_name if present
let model_suffix = lora_name.map(|name| Slug::slugify(name).to_string());
let suffix_for_log = model_suffix
.as_ref()
.map(|s| format!("/{}", s))
.unwrap_or_default();
tracing::debug!(
"Registering MDC at path: {}/{}/{}/{:x}{}",
endpoint.component().namespace().name(),
endpoint.component().name(),
endpoint.name(),
endpoint.drt().connection_id(),
suffix_for_log
);
// Register the Model Deployment Card via discovery interface // Register the Model Deployment Card via discovery interface
// The model_suffix (for LoRA) will be appended AFTER the instance_id
let discovery = endpoint.drt().discovery(); let discovery = endpoint.drt().discovery();
let spec = DiscoverySpec::from_model( let spec = DiscoverySpec::from_model_with_suffix(
endpoint.component().namespace().name().to_string(), endpoint.component().namespace().name().to_string(),
endpoint.component().name().to_string(), endpoint.component().name().to_string(),
endpoint.name().to_string(), endpoint.name().to_string(),
&self.card, &self.card,
model_suffix,
)?; )?;
let _instance = discovery.register(spec).await?; let _instance = discovery.register(spec).await?;
...@@ -449,24 +471,40 @@ impl LocalModel { ...@@ -449,24 +471,40 @@ impl LocalModel {
} }
/// Helper associated function to detach a model from an endpoint /// Helper associated function to detach a model from an endpoint
pub async fn detach_model_from_endpoint(endpoint: &Endpoint) -> anyhow::Result<()> { ///
/// For base models, pass `lora_name = None`.
/// For LoRA adapters, pass `lora_name = Some("adapter-name")`.
pub async fn detach_from_endpoint(
endpoint: &Endpoint,
lora_name: Option<&str>,
) -> anyhow::Result<()> {
let drt = endpoint.drt(); let drt = endpoint.drt();
let instance_id = drt.connection_id(); let instance_id = drt.connection_id();
let endpoint_id = endpoint.id(); let endpoint_id = endpoint.id();
// Compute model_suffix from lora_name if present
let model_suffix = lora_name.map(|name| Slug::slugify(name).to_string());
let instance = DiscoveryInstance::Model { let instance = DiscoveryInstance::Model {
namespace: endpoint_id.namespace, namespace: endpoint_id.namespace,
component: endpoint_id.component, component: endpoint_id.component,
endpoint: endpoint_id.name, endpoint: endpoint_id.name,
instance_id, instance_id,
card_json: serde_json::Value::Null, card_json: serde_json::Value::Null,
model_suffix,
}; };
let discovery = drt.discovery(); let discovery = drt.discovery();
discovery.unregister(instance).await?; discovery.unregister(instance).await?;
tracing::info!("Successfully unregistered model from discovery"); if let Some(lora_name) = lora_name {
tracing::info!(
"Successfully unregistered LoRA '{}' from discovery",
lora_name
);
} else {
tracing::info!("Successfully unregistered model from discovery");
}
Ok(()) Ok(())
} }
......
...@@ -43,6 +43,13 @@ impl LoRACache { ...@@ -43,6 +43,13 @@ impl LoRACache {
self.get_cache_path(lora_id).exists() self.get_cache_path(lora_id).exists()
} }
/// Derive the cache key for a LoRA URI.
///
/// The scheme separator `://` becomes `__`; every remaining `/`, `\` and `.`
/// becomes `_`. This is an associated function (no `self`) so that Rust and
/// Python code can share a single key derivation.
pub fn uri_to_cache_key(uri: &str) -> String {
    uri.replace("://", "__")
        .chars()
        .map(|ch| match ch {
            '/' | '\\' | '.' => '_',
            other => other,
        })
        .collect()
}
/// Validate cached LoRA has required files /// Validate cached LoRA has required files
/// TODO: Add support for other weight file formats supported by trtllm /// TODO: Add support for other weight file formats supported by trtllm
pub fn validate_cached(&self, lora_id: &str) -> Result<bool> { pub fn validate_cached(&self, lora_id: &str) -> Result<bool> {
...@@ -121,4 +128,16 @@ mod tests { ...@@ -121,4 +128,16 @@ mod tests {
assert!(!cache.validate_cached("invalid-lora").unwrap()); assert!(!cache.validate_cached("invalid-lora").unwrap());
} }
/// Pins the cache-key format: `://` maps to `__`, and `/`, `\` and `.` each
/// map to `_`.
#[test]
fn test_uri_to_cache_key() {
    assert_eq!(
        LoRACache::uri_to_cache_key("s3://bucket/path/to/lora"),
        "s3__bucket_path_to_lora"
    );
    assert_eq!(
        LoRACache::uri_to_cache_key("file:///local/path"),
        "file___local_path"
    );
    // Dots are replaced as well.
    assert_eq!(
        LoRACache::uri_to_cache_key("hf://org/model.v1"),
        "hf__org_model_v1"
    );
    // Backslashes are treated like forward slashes.
    assert_eq!(
        LoRACache::uri_to_cache_key(r"C:\loras\adapter"),
        "C:_loras_adapter"
    );
}
} }
...@@ -65,8 +65,8 @@ impl LoRADownloader { ...@@ -65,8 +65,8 @@ impl LoRADownloader {
anyhow::bail!("LoRA {} not found in any source", lora_uri) anyhow::bail!("LoRA {} not found in any source", lora_uri)
} }
/// Convert URI to cache key /// Convert URI to cache key (delegates to LoRACache for consistency)
fn uri_to_cache_key(&self, uri: &str) -> String { fn uri_to_cache_key(&self, uri: &str) -> String {
uri.replace("://", "_").replace(['/', '\\'], "_") LoRACache::uri_to_cache_key(uri)
} }
} }
...@@ -3,11 +3,13 @@ ...@@ -3,11 +3,13 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt; use futures::StreamExt;
use object_store::{ObjectStore, aws::AmazonS3Builder, path::Path as ObjectPath}; use object_store::{ObjectStore, aws::AmazonS3Builder, path::Path as ObjectPath};
use std::{ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::Arc,
time::Duration,
}; };
use url::Url; use url::Url;
...@@ -85,6 +87,57 @@ pub struct S3LoRASource { ...@@ -85,6 +87,57 @@ pub struct S3LoRASource {
endpoint: Option<String>, endpoint: Option<String>,
} }
/// Retry policy and retrying download helper for S3 operations
impl S3LoRASource {
    /// Maximum number of attempts per object (the first try plus retries)
    const MAX_RETRIES: u32 = 3;
    /// Backoff before the first retry, in milliseconds
    const INITIAL_BACKOFF_MS: u64 = 1000;
    /// Upper bound for the backoff, in milliseconds
    const MAX_BACKOFF_MS: u64 = 30000;

    /// Download a single object, retrying with exponential backoff.
    ///
    /// Both the `get` request and the read of its body are retried. The wait
    /// doubles after every failed attempt and is capped at `MAX_BACKOFF_MS`.
    /// Returns the object's bytes, or the error of the last attempt once
    /// `MAX_RETRIES` attempts have failed.
    async fn download_file_with_retry(
        store: &Arc<dyn ObjectStore>,
        location: &ObjectPath,
    ) -> Result<Bytes> {
        let mut backoff_ms = std::cmp::min(Self::INITIAL_BACKOFF_MS, Self::MAX_BACKOFF_MS);
        let mut attempt: u32 = 1;
        loop {
            // A successful read returns immediately; anything else becomes
            // the failure reported for this attempt.
            let failure = match store.get(location).await {
                Ok(object) => match object.bytes().await {
                    Ok(payload) => return Ok(payload),
                    Err(read_err) => anyhow::anyhow!("Failed to read bytes: {}", read_err),
                },
                Err(get_err) => anyhow::anyhow!("Failed to get object: {}", get_err),
            };

            // Out of attempts: surface the most recent failure.
            if attempt >= Self::MAX_RETRIES {
                return Err(failure);
            }

            tracing::warn!(
                "S3 download failed (attempt {}/{}), retrying in {}ms: {}",
                attempt,
                Self::MAX_RETRIES,
                backoff_ms,
                failure
            );
            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;

            // Double the wait for the next round, never exceeding the cap.
            backoff_ms = std::cmp::min(backoff_ms.saturating_mul(2), Self::MAX_BACKOFF_MS);
            attempt += 1;
        }
    }
}
impl S3LoRASource { impl S3LoRASource {
/// Create S3 source from environment variables: /// Create S3 source from environment variables:
/// - AWS_ACCESS_KEY_ID /// - AWS_ACCESS_KEY_ID
...@@ -173,12 +226,47 @@ impl LoRASource for S3LoRASource { ...@@ -173,12 +226,47 @@ impl LoRASource for S3LoRASource {
let object_prefix = ObjectPath::from(prefix.clone()); let object_prefix = ObjectPath::from(prefix.clone());
let mut list_stream = bucket_store.list(Some(&object_prefix)); let mut list_stream = bucket_store.list(Some(&object_prefix));
// Create destination directory // Create a temporary directory in the same parent as dest_path for atomic download
tokio::fs::create_dir_all(dest_path).await?; // This prevents data loss if dest_path already exists
let parent = dest_path
.parent()
.ok_or_else(|| anyhow::anyhow!("Destination path has no parent directory"))?;
let dest_name = dest_path
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Destination path has no file name"))?;
// Generate unique temp directory name
let temp_suffix = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos();
let temp_dir_name = format!("{}.tmp.{}", dest_name, temp_suffix);
let temp_path = parent.join(&temp_dir_name);
// Create temporary directory
tokio::fs::create_dir_all(&temp_path)
.await
.context("Failed to create temporary directory")?;
// Cleanup closure that only removes the temp directory on error
let cleanup_on_error = async |err: anyhow::Error| -> anyhow::Error {
tracing::warn!(
"S3 download failed, cleaning up temporary directory at {:?}",
temp_path
);
if let Err(cleanup_err) = tokio::fs::remove_dir_all(&temp_path).await {
tracing::warn!("Failed to cleanup temporary directory: {}", cleanup_err);
}
err
};
let mut file_count = 0; let mut file_count = 0;
while let Some(meta_result) = list_stream.next().await { while let Some(meta_result) = list_stream.next().await {
let meta = meta_result?; let meta = match meta_result {
Ok(m) => m,
Err(e) => return Err(cleanup_on_error(e.into()).await),
};
// Get relative path (remove prefix) // Get relative path (remove prefix)
let rel_path = meta let rel_path = meta
...@@ -192,24 +280,47 @@ impl LoRASource for S3LoRASource { ...@@ -192,24 +280,47 @@ impl LoRASource for S3LoRASource {
continue; // Skip the prefix itself continue; // Skip the prefix itself
} }
let file_path = dest_path.join(rel_path); let file_path = temp_path.join(rel_path);
// Create parent directories // Create parent directories
#[allow(clippy::collapsible_if)]
if let Some(parent) = file_path.parent() { if let Some(parent) = file_path.parent() {
tokio::fs::create_dir_all(parent).await?; if let Err(e) = tokio::fs::create_dir_all(parent).await {
return Err(cleanup_on_error(e.into()).await);
}
} }
// Download file // Download file with retry logic
let bytes = bucket_store.get(&meta.location).await?.bytes().await?; let bytes = match Self::download_file_with_retry(&bucket_store, &meta.location).await {
tokio::fs::write(&file_path, &bytes).await?; Ok(b) => b,
Err(e) => return Err(cleanup_on_error(e).await),
};
if let Err(e) = tokio::fs::write(&file_path, &bytes).await {
return Err(cleanup_on_error(e.into()).await);
}
file_count += 1; file_count += 1;
tracing::debug!("Downloaded: {} ({} bytes)", rel_path, bytes.len()); tracing::debug!("Downloaded: {} ({} bytes)", rel_path, bytes.len());
} }
if file_count == 0 { if file_count == 0 {
anyhow::bail!("No files found at S3 URI: {}", s3_uri); return Err(
cleanup_on_error(anyhow::anyhow!("No files found at S3 URI: {}", s3_uri)).await,
);
}
// Atomically rename temp directory to final destination
// Remove dest_path if it exists (only after successful download to avoid data loss)
if dest_path.exists() {
tokio::fs::remove_dir_all(dest_path)
.await
.context("Failed to remove existing destination directory")?;
} }
// Rename is atomic on most filesystems
tokio::fs::rename(&temp_path, dest_path)
.await
.context("Failed to atomically move temporary directory to destination")?;
tracing::info!("Downloaded {} files from S3 to {:?}", file_count, dest_path); tracing::info!("Downloaded {} files from S3 to {:?}", file_count, dest_path);
......
...@@ -378,6 +378,7 @@ mod integration_tests { ...@@ -378,6 +378,7 @@ mod integration_tests {
&test_endpoint, &test_endpoint,
dynamo_llm::model_type::ModelType::Chat, dynamo_llm::model_type::ModelType::Chat,
dynamo_llm::model_type::ModelInput::Text, dynamo_llm::model_type::ModelInput::Text,
None,
) )
.await .await
.unwrap(); .unwrap();
......
...@@ -63,6 +63,22 @@ impl EndpointConfigBuilder { ...@@ -63,6 +63,22 @@ impl EndpointConfigBuilder {
self._stats_handler(Some(Box::new(handler))) self._stats_handler(Some(Box::new(handler)))
} }
/// Make `engine` reachable through the local endpoint registry so it can be
/// invoked directly in-process.
///
/// The engine is stored under the name of the endpoint currently set on this
/// builder, using the registry obtained from that endpoint's runtime
/// (`endpoint.drt().local_endpoint_registry()`).
///
/// If no endpoint has been set on the builder, nothing is registered and the
/// builder is handed back unchanged. As written, this method always returns
/// `Ok`.
pub fn register_local_engine(
    self,
    engine: crate::local_endpoint_registry::LocalAsyncEngine,
) -> Result<Self> {
    // Without an endpoint there is no name to register the engine under.
    let Some(target) = self.endpoint.as_ref() else {
        return Ok(self);
    };

    target
        .drt()
        .local_endpoint_registry()
        .register(target.name.clone(), engine);
    tracing::debug!(
        "Registered engine for endpoint '{}' in local registry",
        target.name
    );

    Ok(self)
}
pub async fn start(self) -> Result<()> { pub async fn start(self) -> Result<()> {
let ( let (
endpoint, endpoint,
......
...@@ -225,6 +225,9 @@ pub mod llm { ...@@ -225,6 +225,9 @@ pub mod llm {
/// HTTP body size limit in MB /// HTTP body size limit in MB
pub const DYN_HTTP_BODY_LIMIT_MB: &str = "DYN_HTTP_BODY_LIMIT_MB"; pub const DYN_HTTP_BODY_LIMIT_MB: &str = "DYN_HTTP_BODY_LIMIT_MB";
/// Enable LoRA adapter support (set to "true" to enable)
pub const DYN_LORA_ENABLED: &str = "DYN_LORA_ENABLED";
/// LoRA cache directory path /// LoRA cache directory path
pub const DYN_LORA_PATH: &str = "DYN_LORA_PATH"; pub const DYN_LORA_PATH: &str = "DYN_LORA_PATH";
...@@ -356,6 +359,7 @@ mod tests { ...@@ -356,6 +359,7 @@ mod tests {
kvbm::leader::DYN_KVBM_LEADER_ZMQ_ACK_PORT, kvbm::leader::DYN_KVBM_LEADER_ZMQ_ACK_PORT,
// LLM // LLM
llm::DYN_HTTP_BODY_LIMIT_MB, llm::DYN_HTTP_BODY_LIMIT_MB,
llm::DYN_LORA_ENABLED,
llm::DYN_LORA_PATH, llm::DYN_LORA_PATH,
llm::metrics::DYN_METRICS_PREFIX, llm::metrics::DYN_METRICS_PREFIX,
// Model // Model
......
...@@ -154,17 +154,39 @@ impl Discovery for KVStoreDiscovery { ...@@ -154,17 +154,39 @@ impl Discovery for KVStoreDiscovery {
component, component,
endpoint, endpoint,
instance_id, instance_id,
model_suffix,
.. ..
} => { } => {
let key = Self::model_key(namespace, component, endpoint, *instance_id); let mut key = Self::model_key(namespace, component, endpoint, *instance_id);
tracing::debug!(
"KVStoreDiscovery::register: Registering model instance_id={}, namespace={}, component={}, endpoint={}, key={}", // If there's a model_suffix (e.g., for LoRA adapters), append it after the instance_id
instance_id, // Key format: {namespace}/{component}/{endpoint}/{instance_id:x}/{model_suffix}
namespace, if let Some(suffix) = model_suffix
component, && !suffix.is_empty()
endpoint, {
key key = format!("{}/{}", key, suffix);
); tracing::debug!(
"KVStoreDiscovery::register: Registering LoRA model with suffix={}, instance_id={}, namespace={}, component={}, endpoint={}, key={}",
suffix,
instance_id,
namespace,
component,
endpoint,
key
);
}
// Log for base models (no suffix or empty suffix)
if model_suffix.as_ref().is_none_or(|s| s.is_empty()) {
tracing::debug!(
"KVStoreDiscovery::register: Registering base model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
instance_id,
namespace,
component,
endpoint,
key
);
}
(MODELS_BUCKET, key) (MODELS_BUCKET, key)
} }
}; };
...@@ -227,17 +249,38 @@ impl Discovery for KVStoreDiscovery { ...@@ -227,17 +249,38 @@ impl Discovery for KVStoreDiscovery {
component, component,
endpoint, endpoint,
instance_id, instance_id,
model_suffix,
.. ..
} => { } => {
let key = Self::model_key(namespace, component, endpoint, *instance_id); let mut key = Self::model_key(namespace, component, endpoint, *instance_id);
tracing::debug!(
"Unregistering model instance_id={}, namespace={}, component={}, endpoint={}, key={}", // If there's a model_suffix (e.g., for LoRA adapters), append it after the instance_id
instance_id, if let Some(suffix) = model_suffix
namespace, && !suffix.is_empty()
component, {
endpoint, key = format!("{}/{}", key, suffix);
key tracing::debug!(
); "KVStoreDiscovery::unregister: Unregistering LoRA model with suffix={}, instance_id={}, namespace={}, component={}, endpoint={}, key={}",
suffix,
instance_id,
namespace,
component,
endpoint,
key
);
}
// Log for base models (no suffix or empty suffix)
if model_suffix.as_ref().is_none_or(|s| s.is_empty()) {
tracing::debug!(
"Unregistering base model instance_id={}, namespace={}, component={}, endpoint={}, key={}",
instance_id,
namespace,
component,
endpoint,
key
);
}
(MODELS_BUCKET, key) (MODELS_BUCKET, key)
} }
}; };
...@@ -353,18 +396,38 @@ impl Discovery for KVStoreDiscovery { ...@@ -353,18 +396,38 @@ impl Discovery for KVStoreDiscovery {
// Extract instance_id from the key path, not the value // Extract instance_id from the key path, not the value
// Delete events have empty values in etcd, so we parse the instance_id from the key // Delete events have empty values in etcd, so we parse the instance_id from the key
// Key format: "v1/instances/namespace/component/endpoint/{instance_id:x}" //
let key_parts: Vec<&str> = key_str.split('/').collect(); // Key format (relative to bucket, after stripping bucket prefix):
match key_parts.last() { // - Instances: "namespace/component/endpoint/{instance_id:x}"
// - Models: "namespace/component/endpoint/{instance_id:x}"
// - LoRA models: "namespace/component/endpoint/{instance_id:x}/{lora_slug}"
//
// The instance_id is always at index 3 in the RELATIVE key (after bucket prefix).
// Use strip_bucket_prefix for consistency with matches_prefix().
let relative_key = Self::strip_bucket_prefix(key_str, bucket_name);
let key_parts: Vec<&str> = relative_key.split('/').collect();
// In relative key: namespace/component/endpoint/{instance_id}[/{lora_slug}]
// instance_id is at index 3
let instance_id_index = 3;
match key_parts.get(instance_id_index) {
Some(instance_id_hex) => { Some(instance_id_hex) => {
match u64::from_str_radix(instance_id_hex, 16) { match u64::from_str_radix(instance_id_hex, 16) {
Ok(instance_id) => { Ok(instance_id) => {
tracing::debug!(
"KVStoreDiscovery::list_and_watch: Emitting Removed event for instance_id={:x}, key={}",
instance_id,
key_str
);
Some(DiscoveryEvent::Removed(instance_id)) Some(DiscoveryEvent::Removed(instance_id))
} }
Err(e) => { Err(e) => {
tracing::warn!( tracing::warn!(
key = %key_str, key = %key_str,
relative_key = %relative_key,
error = %e, error = %e,
instance_id_hex = %instance_id_hex,
"Failed to parse instance_id hex from deleted key" "Failed to parse instance_id hex from deleted key"
); );
None None
...@@ -374,7 +437,10 @@ impl Discovery for KVStoreDiscovery { ...@@ -374,7 +437,10 @@ impl Discovery for KVStoreDiscovery {
None => { None => {
tracing::warn!( tracing::warn!(
key = %key_str, key = %key_str,
"Delete event key has no path components" relative_key = %relative_key,
expected_index = instance_id_index,
actual_parts = key_parts.len(),
"Delete event key doesn't have instance_id at expected position"
); );
None None
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment