Commit 026f361d (unverified), authored by Biswa Panda, committed by GitHub
Browse files

feat: resolve lora request for multimodal workers (#6399)

parent bc320806
...@@ -12,6 +12,7 @@ import threading ...@@ -12,6 +12,7 @@ import threading
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Final from typing import Any, AsyncGenerator, Dict, Final
import torch import torch
...@@ -50,6 +51,14 @@ configure_dynamo_logging() ...@@ -50,6 +51,14 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class LoRAInfo:
    """Metadata for a loaded LoRA adapter.

    Instances are immutable (``frozen=True``): to change what is tracked
    for an adapter, replace the whole entry instead of mutating it.
    """

    # Integer adapter ID; passed as ``lora_int_id`` when a LoRARequest is built.
    id: int
    # Adapter path; passed as ``lora_path`` when a LoRARequest is built.
    path: str
def _compute_mm_uuids( def _compute_mm_uuids(
multi_modal_data: Dict[str, Any] | None multi_modal_data: Dict[str, Any] | None
) -> Dict[str, list[str]] | None: ) -> Dict[str, list[str]] | None:
...@@ -287,9 +296,8 @@ class BaseWorkerHandler(ABC): ...@@ -287,9 +296,8 @@ class BaseWorkerHandler(ABC):
# NIXL connector for frontend decoding - lazy initialized # NIXL connector for frontend decoding - lazy initialized
self._nixl_connector = None self._nixl_connector = None
self._nixl_connector_lock = asyncio.Lock() self._nixl_connector_lock = asyncio.Lock()
# LoRA tracking # LoRA tracking: name -> LoRAInfo(id, path)
self.lora_id_for_name: dict[str, int] = {} self.loaded_loras: dict[str, LoRAInfo] = {}
self.lora_name_to_path: dict[str, str] = {}
# Per-LoRA locks to prevent concurrent load operations for the same LoRA # Per-LoRA locks to prevent concurrent load operations for the same LoRA
self._lora_load_locks: dict[str, asyncio.Lock] = {} self._lora_load_locks: dict[str, asyncio.Lock] = {}
# Guard lock-map access in case handlers are invoked from multiple threads. # Guard lock-map access in case handlers are invoked from multiple threads.
...@@ -458,6 +466,16 @@ class BaseWorkerHandler(ABC): ...@@ -458,6 +466,16 @@ class BaseWorkerHandler(ABC):
if temp_dir is not None: if temp_dir is not None:
self.temp_dirs.append(temp_dir) self.temp_dirs.append(temp_dir)
def _resolve_lora_request(self, model_name: str | None) -> LoRARequest | None:
    """Return a LoRARequest if model_name is a loaded adapter, else None.

    Args:
        model_name: Name to look up in ``self.loaded_loras``. May be
            ``None`` or empty, in which case no lookup is performed.

    Returns:
        A ``LoRARequest`` built from the tracked adapter's ID and path, or
        ``None`` when ``model_name`` is falsy or has no tracked entry.
    """
    # No model name means there is nothing to resolve.
    if not model_name:
        return None

    lora = self.loaded_loras.get(model_name)
    # Explicit None check: a missing key is the only "not loaded" signal.
    if lora is None:
        return None

    return LoRARequest(
        lora_name=model_name,
        lora_int_id=lora.id,
        lora_path=lora.path,
    )
def _get_lora_lock(self, lora_name: str) -> asyncio.Lock: def _get_lora_lock(self, lora_name: str) -> asyncio.Lock:
"""Get/create the per-LoRA lock without eagerly allocating a new lock each call.""" """Get/create the per-LoRA lock without eagerly allocating a new lock each call."""
with self._lora_load_locks_guard: with self._lora_load_locks_guard:
...@@ -534,8 +552,8 @@ class BaseWorkerHandler(ABC): ...@@ -534,8 +552,8 @@ class BaseWorkerHandler(ABC):
try: try:
# Check if already loaded (idempotency check after acquiring lock). # Check if already loaded (idempotency check after acquiring lock).
# Another concurrent request may have loaded this LoRA while we waited. # Another concurrent request may have loaded this LoRA while we waited.
if lora_name in self.lora_id_for_name: if lora_name in self.loaded_loras:
lora_id = self.lora_id_for_name[lora_name] lora_id = self.loaded_loras[lora_name].id
logger.info( logger.info(
f"LoRA adapter already loaded (concurrent request completed): " f"LoRA adapter already loaded (concurrent request completed): "
f"{lora_name} with ID {lora_id}" f"{lora_name} with ID {lora_id}"
...@@ -576,8 +594,7 @@ class BaseWorkerHandler(ABC): ...@@ -576,8 +594,7 @@ class BaseWorkerHandler(ABC):
) )
# Track the LoRA # Track the LoRA
self.lora_id_for_name[lora_name] = lora_id self.loaded_loras[lora_name] = LoRAInfo(id=lora_id, path=lora_path)
self.lora_name_to_path[lora_name] = lora_path
logger.info( logger.info(
f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}" f"Successfully loaded LoRA adapter: {lora_name} with ID {lora_id}"
) )
...@@ -625,11 +642,7 @@ class BaseWorkerHandler(ABC): ...@@ -625,11 +642,7 @@ class BaseWorkerHandler(ABC):
f"Rolling back: removing LoRA '{lora_name}' from engine" f"Rolling back: removing LoRA '{lora_name}' from engine"
) )
await self.engine_client.remove_lora(lora_id) await self.engine_client.remove_lora(lora_id)
# Remove from tracking dictionaries self.loaded_loras.pop(lora_name, None)
if lora_name in self.lora_id_for_name:
del self.lora_id_for_name[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
logger.debug( logger.debug(
f"Successfully rolled back LoRA '{lora_name}'" f"Successfully rolled back LoRA '{lora_name}'"
) )
...@@ -661,7 +674,7 @@ class BaseWorkerHandler(ABC): ...@@ -661,7 +674,7 @@ class BaseWorkerHandler(ABC):
# loaded, remove the lock entry (best-effort). # loaded, remove the lock entry (best-effort).
with self._lora_load_locks_guard: with self._lora_load_locks_guard:
if ( if (
lora_name not in self.lora_id_for_name lora_name not in self.loaded_loras
and self._lora_load_locks.get(lora_name) is lock and self._lora_load_locks.get(lora_name) is lock
): ):
self._lora_load_locks.pop(lora_name, None) self._lora_load_locks.pop(lora_name, None)
...@@ -697,23 +710,22 @@ class BaseWorkerHandler(ABC): ...@@ -697,23 +710,22 @@ class BaseWorkerHandler(ABC):
async with lock: async with lock:
try: try:
# Check if the LoRA exists *after* waiting for any in-progress load. # Check if the LoRA exists *after* waiting for any in-progress load.
if lora_name not in self.lora_id_for_name: lora = self.loaded_loras.get(lora_name)
if lora is None:
yield { yield {
"status": "error", "status": "error",
"message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.lora_id_for_name.keys())}", "message": f"LoRA adapter '{lora_name}' not found. Available LoRAs: {list(self.loaded_loras.keys())}",
} }
return return
logger.debug(f"Unloading LoRA adapter: {lora_name}") logger.debug(f"Unloading LoRA adapter: {lora_name}")
lora_id = self.lora_id_for_name[lora_name] lora_id = lora.id
lora_path = self.lora_name_to_path.get(lora_name) lora_path = lora.path
await self.engine_client.remove_lora(lora_id) await self.engine_client.remove_lora(lora_id)
# Remove from tracking dictionaries # Remove from tracking
del self.lora_id_for_name[lora_name] del self.loaded_loras[lora_name]
if lora_name in self.lora_name_to_path:
del self.lora_name_to_path[lora_name]
# Unregister the LoRA model from the model registry # Unregister the LoRA model from the model registry
if self.generate_endpoint is not None: if self.generate_endpoint is not None:
...@@ -734,11 +746,6 @@ class BaseWorkerHandler(ABC): ...@@ -734,11 +746,6 @@ class BaseWorkerHandler(ABC):
) )
# Rollback: re-add the LoRA to the engine to maintain consistency # Rollback: re-add the LoRA to the engine to maintain consistency
if lora_path is None:
logger.error(
f"Cannot rollback LoRA '{lora_name}': lora_path is None (data inconsistency)"
)
else:
try: try:
logger.debug( logger.debug(
f"Rolling back: re-adding LoRA '{lora_name}' to engine" f"Rolling back: re-adding LoRA '{lora_name}' to engine"
...@@ -750,9 +757,10 @@ class BaseWorkerHandler(ABC): ...@@ -750,9 +757,10 @@ class BaseWorkerHandler(ABC):
lora_path=lora_path, lora_path=lora_path,
) )
) )
# Re-add to tracking dictionaries # Re-add to tracking
self.lora_id_for_name[lora_name] = lora_id self.loaded_loras[lora_name] = LoRAInfo(
self.lora_name_to_path[lora_name] = lora_path id=lora_id, path=lora_path
)
logger.debug( logger.debug(
f"Successfully rolled back LoRA '{lora_name}'" f"Successfully rolled back LoRA '{lora_name}'"
) )
...@@ -786,7 +794,7 @@ class BaseWorkerHandler(ABC): ...@@ -786,7 +794,7 @@ class BaseWorkerHandler(ABC):
# Remove lock entry once the LoRA is not loaded (or never was). # Remove lock entry once the LoRA is not loaded (or never was).
with self._lora_load_locks_guard: with self._lora_load_locks_guard:
if ( if (
lora_name not in self.lora_id_for_name lora_name not in self.loaded_loras
and self._lora_load_locks.get(lora_name) is lock and self._lora_load_locks.get(lora_name) is lock
): ):
self._lora_load_locks.pop(lora_name, None) self._lora_load_locks.pop(lora_name, None)
...@@ -800,7 +808,7 @@ class BaseWorkerHandler(ABC): ...@@ -800,7 +808,7 @@ class BaseWorkerHandler(ABC):
Returns a dictionary of lora_name -> lora_id mappings. Returns a dictionary of lora_name -> lora_id mappings.
""" """
try: try:
loras = dict(self.lora_id_for_name) loras = {name: lora.id for name, lora in self.loaded_loras.items()}
yield { yield {
"status": "success", "status": "success",
"loras": loras, "loras": loras,
...@@ -1354,19 +1362,11 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -1354,19 +1362,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
) )
# Extract LoRA request if present # Extract LoRA request if present
# Check if model name matches a loaded LoRA adapter
lora_request = None
model_name = request.get("model") model_name = request.get("model")
lora_request = self._resolve_lora_request(model_name)
if model_name and model_name in self.lora_id_for_name: if lora_request:
lora_id = self.lora_id_for_name[model_name]
lora_request = LoRARequest(
lora_name=model_name,
lora_int_id=lora_id,
lora_path=self.lora_name_to_path[model_name],
)
logger.info( logger.info(
f"Decode request {request_id} will use LoRA adapter: {model_name} (ID: {lora_id})" f"Decode request {request_id} will use LoRA adapter: {model_name} (ID: {lora_request.lora_int_id})"
) )
else: else:
logger.debug( logger.debug(
...@@ -1570,20 +1570,12 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -1570,20 +1570,12 @@ class PrefillWorkerHandler(BaseWorkerHandler):
sampling_params.min_tokens = 1 sampling_params.min_tokens = 1
# Extract LoRA request if present # Extract LoRA request if present
# Check if model name matches a loaded LoRA adapter
lora_request = None
model_name = request.get("model") model_name = request.get("model")
lora_request = self._resolve_lora_request(model_name)
if model_name and model_name in self.lora_id_for_name: if lora_request:
lora_id = self.lora_id_for_name[model_name]
lora_request = LoRARequest(
lora_name=model_name,
lora_int_id=lora_id,
lora_path=self.lora_name_to_path[model_name],
)
logger.info( logger.info(
f"Prefill request {request_id} will use LoRA adapter: {model_name} (ID: {lora_id}), " f"Prefill request {request_id} will use LoRA adapter: {model_name} "
f"path: {self.lora_name_to_path[model_name]}" f"(ID: {lora_request.lora_int_id}), path: {lora_request.lora_path}"
) )
else: else:
logger.debug( logger.debug(
......
...@@ -687,6 +687,9 @@ async def init( ...@@ -687,6 +687,9 @@ async def init(
) )
component = generate_endpoint.component() component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora") load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora") unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras") list_loras_endpoint = component.endpoint("list_loras")
...@@ -812,13 +815,8 @@ async def init( ...@@ -812,13 +815,8 @@ async def init(
try: try:
logger.debug("Starting serve_endpoint for decode worker") logger.debug("Starting serve_endpoint for decode worker")
await asyncio.gather(
# for decode, we want to transfer the in-flight requests to other decode engines, model_metrics_labels = [
# because waiting them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=[
( (
prometheus_names.labels.MODEL, prometheus_names.labels.MODEL,
config.served_model_name or config.model, config.served_model_name or config.model,
...@@ -827,62 +825,42 @@ async def init( ...@@ -827,62 +825,42 @@ async def init(
prometheus_names.labels.MODEL_NAME, prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model, config.served_model_name or config.model,
), ),
], ]
serve_tasks = [
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(
handler.generate,
graceful_shutdown=True,
metrics_labels=model_metrics_labels,
health_check_payload=health_check_payload, health_check_payload=health_check_payload,
), ),
clear_endpoint.serve_endpoint( clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, handler.clear_kv_blocks,
metrics_labels=[ metrics_labels=model_metrics_labels,
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
), ),
]
if lora_enabled:
serve_tasks.extend(
[
load_lora_endpoint.serve_endpoint( load_lora_endpoint.serve_endpoint(
handler.load_lora, handler.load_lora,
metrics_labels=[ metrics_labels=model_metrics_labels,
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
), ),
unload_lora_endpoint.serve_endpoint( unload_lora_endpoint.serve_endpoint(
handler.unload_lora, handler.unload_lora,
metrics_labels=[ metrics_labels=model_metrics_labels,
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
), ),
list_loras_endpoint.serve_endpoint( list_loras_endpoint.serve_endpoint(
handler.list_loras, handler.list_loras,
metrics_labels=[ metrics_labels=model_metrics_labels,
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
), ),
]
) )
await asyncio.gather(*serve_tasks)
logger.debug("serve_endpoint completed for decode worker") logger.debug("serve_endpoint completed for decode worker")
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoints: {e}") logger.error(f"Failed to serve endpoints: {e}")
......
...@@ -56,6 +56,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -56,6 +56,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
encode_worker_client: Client | None = None, encode_worker_client: Client | None = None,
decode_worker_client: Client | None = None, decode_worker_client: Client | None = None,
shutdown_event=None, shutdown_event=None,
generate_endpoint=None,
): ):
# Get default_sampling_params from config # Get default_sampling_params from config
default_sampling_params = ( default_sampling_params = (
...@@ -69,6 +70,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -69,6 +70,8 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client, engine_client,
default_sampling_params, default_sampling_params,
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
) )
...@@ -318,6 +321,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -318,6 +321,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
received_tensor_ids: list[int], received_tensor_ids: list[int],
): ):
"""Run prefill and decode on this worker (aggregated mode).""" """Run prefill and decode on this worker (aggregated mode)."""
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"], prompt_token_ids=request.engine_prompt["prompt_token_ids"],
...@@ -325,6 +329,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -325,6 +329,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
), ),
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
request_id=request.request_id, request_id=request.request_id,
lora_request=lora_request,
) )
for tensor_id in received_tensor_ids: for tensor_id in received_tensor_ids:
...@@ -358,6 +363,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -358,6 +363,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
prefill_only_request.sampling_params.min_tokens = 1 prefill_only_request.sampling_params.min_tokens = 1
logger.debug("Prefill request: %s", prefill_only_request) logger.debug("Prefill request: %s", prefill_only_request)
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=prefill_only_request.engine_prompt["prompt_token_ids"], prompt_token_ids=prefill_only_request.engine_prompt["prompt_token_ids"],
...@@ -365,6 +371,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -365,6 +371,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
), ),
sampling_params=prefill_only_request.sampling_params, sampling_params=prefill_only_request.sampling_params,
request_id=prefill_only_request.request_id, request_id=prefill_only_request.request_id,
lora_request=lora_request,
) )
for tensor_id in received_tensor_ids: for tensor_id in received_tensor_ids:
...@@ -400,6 +407,14 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -400,6 +407,14 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# embeddings_shape). Heavy multimodal data was consumed locally by # embeddings_shape). Heavy multimodal data was consumed locally by
# engine_client.generate() and multimodal_inputs was cleared by # engine_client.generate() and multimodal_inputs was cleared by
# `_finalize_request_metadata`. # `_finalize_request_metadata`.
#
# request.model (LoRA name) is preserved in the serialized request
# so the decode worker can resolve the same LoRA adapter.
if lora_request and request.model:
logger.debug(
f"Forwarding disaggregated decode with LoRA '{request.model}' "
f"— ensure the same adapter is loaded on the decode worker."
)
async for ( async for (
decode_response decode_response
) in await self.decode_worker_client.round_robin( # type: ignore[union-attr] ) in await self.decode_worker_client.round_robin( # type: ignore[union-attr]
......
...@@ -26,6 +26,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -26,6 +26,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client, engine_client,
config: Config, config: Config,
shutdown_event=None, shutdown_event=None,
generate_endpoint=None,
): ):
# Get default_sampling_params from config # Get default_sampling_params from config
default_sampling_params = ( default_sampling_params = (
...@@ -39,6 +40,8 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -39,6 +40,8 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client, engine_client,
default_sampling_params, default_sampling_params,
enable_multimodal=config.enable_multimodal, enable_multimodal=config.enable_multimodal,
generate_endpoint=generate_endpoint,
config=config,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
) )
...@@ -82,6 +85,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -82,6 +85,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
image_grid_thw, embeddings_shape, request.request_id image_grid_thw, embeddings_shape, request.request_id
) )
lora_request = self._resolve_lora_request(request.model)
gen = self.engine_client.generate( gen = self.engine_client.generate(
prompt=TokensPrompt( prompt=TokensPrompt(
prompt_token_ids=request.engine_prompt["prompt_token_ids"], prompt_token_ids=request.engine_prompt["prompt_token_ids"],
...@@ -89,6 +93,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -89,6 +93,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
), ),
sampling_params=request.sampling_params, sampling_params=request.sampling_params,
request_id=request.request_id, request_id=request.request_id,
lora_request=lora_request,
) )
async for response in gen: async for response in gen:
......
...@@ -94,6 +94,12 @@ class WorkerFactory: ...@@ -94,6 +94,12 @@ class WorkerFactory:
component = generate_endpoint.component() component = generate_endpoint.component()
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
lora_enabled = config.engine_args.enable_lora
if lora_enabled:
load_lora_endpoint = component.endpoint("load_lora")
unload_lora_endpoint = component.endpoint("unload_lora")
list_loras_endpoint = component.endpoint("list_loras")
# Use pre-created engine if provided (checkpoint mode), otherwise create new # Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None: if pre_created_engine is not None:
( (
...@@ -134,7 +140,12 @@ class WorkerFactory: ...@@ -134,7 +140,12 @@ class WorkerFactory:
# Choose handler based on worker type # Choose handler based on worker type
if config.multimodal_decode_worker: if config.multimodal_decode_worker:
handler = MultimodalDecodeWorkerHandler( handler = MultimodalDecodeWorkerHandler(
runtime, component, engine_client, config, shutdown_event runtime,
component,
engine_client,
config,
shutdown_event,
generate_endpoint=generate_endpoint,
) )
else: else:
handler = MultimodalPDWorkerHandler( handler = MultimodalPDWorkerHandler(
...@@ -145,6 +156,7 @@ class WorkerFactory: ...@@ -145,6 +156,7 @@ class WorkerFactory:
encode_worker_client, encode_worker_client,
decode_worker_client, decode_worker_client,
shutdown_event, shutdown_event,
generate_endpoint=generate_endpoint,
) )
handler.add_temp_dir(prometheus_temp_dir) handler.add_temp_dir(prometheus_temp_dir)
...@@ -173,7 +185,7 @@ class WorkerFactory: ...@@ -173,7 +185,7 @@ class WorkerFactory:
metrics_labels = [("model", config.served_model_name or config.model)] metrics_labels = [("model", config.served_model_name or config.model)]
try: try:
await asyncio.gather( serve_tasks = [
generate_endpoint.serve_endpoint( generate_endpoint.serve_endpoint(
handler.generate, handler.generate,
metrics_labels=metrics_labels, metrics_labels=metrics_labels,
...@@ -182,7 +194,27 @@ class WorkerFactory: ...@@ -182,7 +194,27 @@ class WorkerFactory:
handler.clear_kv_blocks, handler.clear_kv_blocks,
metrics_labels=metrics_labels, metrics_labels=metrics_labels,
), ),
]
if lora_enabled:
serve_tasks.extend(
[
load_lora_endpoint.serve_endpoint(
handler.load_lora,
metrics_labels=metrics_labels,
),
unload_lora_endpoint.serve_endpoint(
handler.unload_lora,
metrics_labels=metrics_labels,
),
list_loras_endpoint.serve_endpoint(
handler.list_loras,
metrics_labels=metrics_labels,
),
]
) )
await asyncio.gather(*serve_tasks)
except Exception as e: except Exception as e:
logger.error(f"Failed to serve endpoints: {e}") logger.error(f"Failed to serve endpoints: {e}")
raise raise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment