Unverified commit 49538d11, authored by Lifu Huang, committed by GitHub

Support dynamic LoRA loading / unloading in engine/server API (#7446)

parent cfe2edac
@@ -48,6 +48,14 @@ class EngineBase(ABC):
         """Update model weights with in-memory tensor data."""
         pass

+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+        pass
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+        pass
+
     @abstractmethod
     def release_memory_occupation(self):
         """Release GPU memory occupation temporarily."""
...
@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     ImageDataItem,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -478,6 +480,29 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )

+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+        obj = LoadLoRAAdapterReqInput(
+            lora_name=lora_name,
+            lora_path=lora_path,
+        )
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.load_lora_adapter(obj, None)
+        )
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+        obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.unload_lora_adapter(obj, None)
+        )
+
     def release_memory_occupation(self, tags: Optional[List[str]] = None):
         obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
...
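The new `Engine.load_lora_adapter` / `Engine.unload_lora_adapter` methods above are synchronous wrappers around the tokenizer manager. A minimal offline usage sketch follows; the constructor arguments (`model_path`, `lora_paths`, `max_loras_per_batch`) and all names/paths are illustrative assumptions rather than values from this diff — only the two new method signatures come from the change.

```python
# Hypothetical usage of the new Engine-level LoRA APIs; model and adapter paths are placeholders.
import sglang as sgl

engine = sgl.Engine(
    model_path="/path/to/base_model",           # assumption: any LoRA-compatible base model
    lora_paths=["preloaded=/path/to/adapter"],  # assumption: adapters preloaded at startup
    max_loras_per_batch=2,                      # assumption: existing server argument
)

# Dynamically load another adapter without restarting the engine.
result = engine.load_lora_adapter(lora_name="my_adapter", lora_path="/path/to/my_adapter")
print(result)  # LoRAUpdateResult(success=..., error_message=..., loaded_adapters={...})

# ...generate with the adapter, then unload it when it is no longer needed...
engine.unload_lora_adapter(lora_name="my_adapter")
engine.shutdown()
```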
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     OpenSessionReqInput,
     ParseFunctionCallReq,
     ProfileReqInput,

@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
     SeparateReasoningReqInput,
     SetInternalStateReq,
     SlowDownReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
         return _create_error_response(e)
@app.api_route("/load_lora_adapter", methods=["POST"])
async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
if result.success:
return ORJSONResponse(
result,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
result,
status_code=HTTPStatus.BAD_REQUEST,
)
@app.api_route("/unload_lora_adapter", methods=["POST"])
async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
if result.success:
return ORJSONResponse(
result,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
result,
status_code=HTTPStatus.BAD_REQUEST,
)
@app.api_route("/open_session", methods=["GET", "POST"]) @app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request): async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id.""" """Open a session, and return its unique session id."""
......
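Once a server is running, the two new endpoints can be exercised with plain HTTP. A hedged client sketch follows; the base URL assumes the default port, and the JSON field names mirror `LoadLoRAAdapterReqInput` / `UnloadLoRAAdapterReqInput` defined later in this diff.

```python
# Illustrative client calls against the new endpoints; host, port, and paths are placeholders.
import requests

base_url = "http://127.0.0.1:30000"  # assumption: default sglang server port

# Load an adapter at runtime; returns 200 with a LoRAUpdateResult body, or 400 on failure.
resp = requests.post(
    f"{base_url}/load_lora_adapter",
    json={"lora_name": "my_adapter", "lora_path": "/path/to/my_adapter"},
)
print(resp.status_code, resp.json())

# Unload it again.
resp = requests.post(
    f"{base_url}/unload_lora_adapter",
    json={"lora_name": "my_adapter"},
)
print(resp.status_code, resp.json())
```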
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for i in range(base_hf_config.num_hidden_layers)
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )

@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
             else:
                 self.weights[name] = loaded_weight.cpu()

-        # stack kv_proj and gate_up_proj
-        for i in range(self.base_hf_config.num_hidden_layers):
-            layer = self.layers[i]
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
...
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule
@@ -98,44 +99,96 @@ class LoRAManager:
             ],
         )

-    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
         """
         Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.

         Args:
             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
                 If a LoRA adapter is already loaded, it will be skipped with a warning.
         """
+        results = []
         for lora_name, lora_path in lora_paths.items():
-            if lora_name in self.loras:
-                logger.warning(
-                    f"LoRA adapter {lora_name} is already loaded."
-                    "If you want to reload it, please unload it first."
-                )
-                continue
-
-            self.configs[lora_name] = LoRAConfig(lora_path)
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)

         self.update_state_from_configs()

-    def unload_lora_adapters(self, lora_names: Set[str]):
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
+            self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
+
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
-
-        Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
         """
-        for lora_name in lora_names:
-            if lora_name in self.loras:
-                del self.configs[lora_name]
-            else:
-                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
+        success = True
+        error_message = ""
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False

         self.update_state_from_configs()

+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)

@@ -372,8 +425,8 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter

-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
...
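To make the batch semantics of `load_lora_adapters` concrete: each adapter is loaded with `update_state=False`, the internal state is refreshed once at the end, and the per-adapter results are folded into a single `LoRAUpdateResult`. A small illustrative sketch, where `manager` stands in for an already-initialized `LoRAManager` and the adapter names/paths are placeholders:

```python
# Hypothetical illustration of the aggregated result; `manager` is an initialized LoRAManager.
result = manager.load_lora_adapters(
    {
        "good_adapter": "/path/to/good_adapter",     # loads fine
        "bad_adapter": "/path/that/does/not/exist",  # LoRAConfig raises -> per-adapter failure
    }
)

assert result.success is False  # any single failure flips the aggregate flag
print(result.error_message)     # newline-joined messages from the failing adapters only
print(result.loaded_adapters)   # name -> path for everything currently in manager.configs
```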
@@ -20,7 +20,7 @@
 import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

 from sglang.srt.multimodal.mm_utils import has_valid_data

@@ -1002,3 +1002,27 @@ class RpcReqInput:
 class RpcReqOutput:
     success: bool
     message: str
+
+
+@dataclass
+class LoadLoRAAdapterReqInput:
+    # The name of the LoRA adapter to be newly loaded.
+    lora_name: str
+    # The path to load the LoRA adapter from.
+    lora_path: str
+
+
+@dataclass
+class UnloadLoRAAdapterReqInput:
+    # The name of the LoRA adapter to unload.
+    lora_name: str
+
+
+@dataclass
+class LoRAUpdateResult:
+    success: bool
+    error_message: Optional[str] = None
+    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+
+
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,

@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,

@@ -519,6 +523,8 @@ class Scheduler(
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
                 (ExpertDistributionReq, self.expert_distribution_handle),
+                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
+                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
             ]
         )

@@ -2241,6 +2247,36 @@ class Scheduler(
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)
+    def load_lora_adapter(
+        self, recv_req: LoadLoRAAdapterReqInput
+    ) -> LoadLoRAAdapterReqOutput:
+        """Load a new LoRA adapter in place from disk or HuggingFace."""
+
+        result = self.tp_worker.load_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after loading LoRA adapter."
+        else:
+            logger.error(result.error_message)
+        return result
+
+    def unload_lora_adapter(
+        self, recv_req: UnloadLoRAAdapterReqInput
+    ) -> UnloadLoRAAdapterReqOutput:
+        """Unload a LoRA adapter."""
+
+        result = self.tp_worker.unload_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert (
+                flush_cache_success
+            ), "Cache flush failed after unloading LoRA weights."
+        else:
+            logger.error(result.error_message)
+        return result
+
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
...
@@ -83,6 +83,9 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
+    LoRAUpdateResult,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,

@@ -99,6 +102,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,

@@ -311,6 +316,9 @@ class TokenizerManager:
         self.expert_distribution_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_lora_adapter_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         self._result_dispatcher = TypeBasedDispatcher(
             [

@@ -377,6 +385,10 @@ class TokenizerManager:
                     ExpertDistributionReqOutput,
                     self.expert_distribution_communicator.handle_recv,
                 ),
+                (
+                    LoRAUpdateResult,
+                    self.update_lora_adapter_communicator.handle_recv,
+                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

@@ -960,6 +972,49 @@ class TokenizerManager:
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message
+    async def load_lora_adapter(
+        self,
+        obj: LoadLoRAAdapterReqInput,
+        _: Optional[fastapi.Request] = None,
+    ) -> LoadLoRAAdapterReqOutput:
+        self.auto_create_handle_loop()
+
+        # TODO (lifuhuang): Remove this after we verify that dynamic LoRA loading works
+        # with dp_size > 1.
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for dynamic LoRA loading"
+        logger.info(
+            "Start loading LoRA adapter. name=%s, path=%s",
+            obj.lora_name,
+            obj.lora_path,
+        )
+
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_lora_adapter_communicator(obj))[0]
+            return result
+
+    async def unload_lora_adapter(
+        self,
+        obj: UnloadLoRAAdapterReqInput,
+        _: Optional[fastapi.Request] = None,
+    ) -> UnloadLoRAAdapterReqOutput:
+        self.auto_create_handle_loop()
+
+        # TODO (lifuhuang): Remove this after we verify that dynamic LoRA loading works
+        # with dp_size > 1.
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for dynamic LoRA loading"
+        logger.info(
+            "Start unloading LoRA adapter. name=%s",
+            obj.lora_name,
+        )
+
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_lora_adapter_communicator(obj))[0]
+            return result
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
...
@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -275,3 +277,13 @@ class TpModelWorker:
             recv_req.name, recv_req.truncate_size
         )
         return parameter
+
+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        result = self.model_runner.load_lora_adapter(
+            recv_req.lora_name, recv_req.lora_path
+        )
+        return result
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
+        return result
@@ -26,6 +26,8 @@ import torch
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,

@@ -268,6 +270,12 @@ class TpModelWorkerClient:
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)

+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        return self.worker.load_lora_adapter(recv_req)
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        return self.worker.unload_lora_adapter(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
@@ -26,7 +26,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist

-from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -819,8 +818,47 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
-        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
-        logger.info("LoRA manager ready.")
+        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
+        if result.success:
+            logger.info(
+                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
+            )
+        else:
+            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
+
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter from disk or HuggingFace."""
+
+        logger.info(
+            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+        logger.info(
+            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+        return result
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter that was previously loaded during initialization or dynamic loading."""
+
+        logger.info(
+            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+        result = self.lora_manager.unload_lora_adapter(lora_name)
+        logger.info(
+            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+        return result
+
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
...
@@ -503,6 +503,7 @@ class SRTRunner:
         disable_overlap_schedule: bool = False,
         disable_custom_all_reduce: bool = False,
         torchao_config: Optional[str] = None,
+        cuda_graph_max_bs: int = 4,
         sleep_on_idle=False,
     ):
         self.model_type = model_type

@@ -539,7 +540,7 @@ class SRTRunner:
                 tokenizer_path=tokenizer_path,
                 enable_ep_moe=enable_ep_moe,
                 disable_overlap_schedule=disable_overlap_schedule,
-                cuda_graph_max_bs=4,
+                cuda_graph_max_bs=cuda_graph_max_bs,
                 disable_custom_all_reduce=disable_custom_all_reduce,
                 sleep_on_idle=sleep_on_idle,
                 **spec_kwargs,

@@ -552,6 +553,12 @@ class SRTRunner:
         else:
             self.tokenizer = None

+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        return self.engine.load_lora_adapter(lora_name, lora_path)
+
+    def unload_lora_adapter(self, lora_name: str):
+        return self.engine.unload_lora_adapter(lora_name)
+
     def forward(
         self,
         prompts: Union[
...
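The two SRTRunner passthroughs make the feature reachable from the test harness (e.g. the new `models/lora/test_lora_update.py` suite registered below). A hedged sketch of how a test might use them; the model, adapter names/paths, and extra constructor kwargs are placeholders, not values taken from this diff.

```python
# Hypothetical test-style usage of the new SRTRunner helpers; all concrete values are placeholders.
import torch

from sglang.test.runners import SRTRunner

with SRTRunner(
    "/path/to/base_model",
    torch_dtype=torch.float16,
    model_type="generation",
    lora_paths=["/path/to/initial_adapter"],  # assumption: adapters preloaded at startup
    cuda_graph_max_bs=2,                      # exercises the new constructor knob added above
) as runner:
    result = runner.load_lora_adapter("extra_adapter", "/path/to/extra_adapter")
    assert result.success, result.error_message
    runner.unload_lora_adapter("extra_adapter")
```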
(The diff of one additional file is collapsed and not shown here.)
@@ -17,6 +17,7 @@ suites = {
         TestFile("models/lora/test_lora_backend.py", 99),
         TestFile("models/lora/test_multi_lora_backend.py", 60),
         TestFile("models/lora/test_lora_cuda_graph.py", 250),
+        TestFile("models/lora/test_lora_update.py", 400),
         TestFile("models/test_embedding_models.py", 73),
         # TestFile("models/test_clip_models.py", 52),
         TestFile("models/test_encoder_embedding_models.py", 100),
...