Unverified Commit 49538d11 authored by Lifu Huang, committed by GitHub

Support dynamic LoRA loading / unloading in engine/server API (#7446)

parent cfe2edac
......@@ -48,6 +48,14 @@ class EngineBase(ABC):
"""Update model weights with in-memory tensor data."""
pass
def load_lora_adapter(self, lora_name: str, lora_path: str):
"""Load a new LoRA adapter without re-launching the engine."""
pass
def unload_lora_adapter(self, lora_name: str):
"""Unload a LoRA adapter without re-launching the engine."""
pass
@abstractmethod
def release_memory_occupation(self):
"""Release GPU memory occupation temporarily."""
......
......@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
ImageDataItem,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
RpcReqInput,
RpcReqOutput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
......@@ -478,6 +480,29 @@ class Engine(EngineBase):
self.tokenizer_manager.get_weights_by_name(obj, None)
)
def load_lora_adapter(self, lora_name: str, lora_path: str):
"""Load a new LoRA adapter without re-launching the engine."""
obj = LoadLoRAAdapterReqInput(
lora_name=lora_name,
lora_path=lora_path,
)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.load_lora_adapter(obj, None)
)
def unload_lora_adapter(self, lora_name: str):
"""Unload a LoRA adapter without re-launching the engine."""
obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
loop = asyncio.get_event_loop()
return loop.run_until_complete(
self.tokenizer_manager.unload_lora_adapter(obj, None)
)
def release_memory_occupation(self, tags: Optional[List[str]] = None):
obj = ReleaseMemoryOccupationReqInput(tags=tags)
loop = asyncio.get_event_loop()
......
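Taken together, the new Engine-level methods above can be exercised roughly as in the minimal sketch below. This is illustrative only: the model and adapter names are placeholders, and it assumes the Engine constructor forwards lora_paths and max_loras_per_batch to ServerArgs as it does for other fields, with at least one adapter supplied at startup (see the TODO in the test file later in this commit).

# Illustrative sketch of the new dynamic LoRA Engine API (placeholder names/paths).
import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    lora_paths=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
    max_loras_per_batch=2,
)

# Dynamically load a second adapter without restarting the engine.
result = engine.load_lora_adapter(
    lora_name="ocr-lora",
    lora_path="pbevan11/llama-3.1-8b-ocr-correction",
)
print(result.success, result.loaded_adapters)

# ... issue generation requests that reference "ocr-lora" ...

# Unload it again once it is no longer needed.
result = engine.unload_lora_adapter(lora_name="ocr-lora")
print(result.success, result.loaded_adapters)

engine.shutdown()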
......@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
GenerateReqInput,
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
OpenSessionReqInput,
ParseFunctionCallReq,
ProfileReqInput,
......@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
SeparateReasoningReqInput,
SetInternalStateReq,
SlowDownReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
......@@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
return _create_error_response(e)
@app.api_route("/load_lora_adapter", methods=["POST"])
async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
if result.success:
return ORJSONResponse(
result,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
result,
status_code=HTTPStatus.BAD_REQUEST,
)
@app.api_route("/unload_lora_adapter", methods=["POST"])
async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
"""Load a new LoRA adapter without re-launching the server."""
result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
if result.success:
return ORJSONResponse(
result,
status_code=HTTPStatus.OK,
)
else:
return ORJSONResponse(
result,
status_code=HTTPStatus.BAD_REQUEST,
)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
......
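Over HTTP, the same flow looks roughly like the sketch below, mirroring the requests-based calls in the test file at the end of this commit; the base URL and adapter name/path are placeholders.

# Illustrative sketch: calling the new endpoints with `requests`.
import requests

BASE_URL = "http://127.0.0.1:30000"  # placeholder; use your server's URL

resp = requests.post(
    BASE_URL + "/load_lora_adapter",
    json={"lora_name": "ocr-lora", "lora_path": "pbevan11/llama-3.1-8b-ocr-correction"},
)
# 200 with the full `loaded_adapters` mapping on success, 400 with `error_message` otherwise.
print(resp.status_code, resp.json())

resp = requests.post(BASE_URL + "/unload_lora_adapter", json={"lora_name": "ocr-lora"})
print(resp.status_code, resp.json())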
......@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
self.layers: List[LoRALayer] = nn.ModuleList(
[
LoRALayer(config, base_hf_config)
for i in range(base_hf_config.num_hidden_layers)
for _ in range(base_hf_config.num_hidden_layers)
]
)
......@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
else:
self.weights[name] = loaded_weight.cpu()
# stack kv_proj and gate_up_proj
for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()]
# normalize kv_proj and gate_up_proj
for layer in self.layers:
weight_names = list(layer.weights.keys())
self.normalize_qkv_proj(weight_names, layer.weights)
self.normalize_gate_up_proj(weight_names, layer.weights)
......
......@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
get_normalized_lora_weight_names,
get_weight_name,
)
from sglang.srt.managers.io_struct import LoRAUpdateResult
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import replace_submodule
......@@ -98,44 +99,96 @@ class LoRAManager:
],
)
def load_lora_adapters(self, lora_paths: Dict[str, str]):
def create_lora_update_result(
self, success: bool, error_message: str = ""
) -> LoRAUpdateResult:
return LoRAUpdateResult(
success=success,
error_message=error_message,
loaded_adapters={
name: config.path for name, config in self.configs.items()
},
)
def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
"""
Load LoRA adapters from the specified paths.
TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped with a warning.
"""
results = []
for lora_name, lora_path in lora_paths.items():
if lora_name in self.loras:
logger.warning(
f"LoRA adapter {lora_name} is already loaded."
"If you want to reload it, please unload it first."
)
continue
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
results.append(result)
self.update_state_from_configs()
return self.create_lora_update_result(
success=all(result.success for result in results),
error_message="\n".join(
result.error_message for result in results if not result.success
),
)
def load_lora_adapter(
self, lora_name: str, lora_path: str, update_state: bool = True
) -> LoRAUpdateResult:
"""
Load a single LoRA adapter from the specified path.
Args:
lora_name (str): The name of the LoRA adapter.
lora_path (str): The file path to the LoRA adapter.
update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
"""
success = True
error_message = ""
        if lora_name in self.loras:
            success = False
            error_message = (
                f"LoRA adapter {lora_name} is skipped as it is already loaded. "
                "If you want to reload it, please unload it first."
            )
        else:
            try:
                self.configs[lora_name] = LoRAConfig(lora_path)
            except Exception as e:
                success = False
                error_message = (
                    f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
                )
self.update_state_from_configs()
if update_state:
self.update_state_from_configs()
return self.create_lora_update_result(
success=success,
error_message=error_message,
)
def unload_lora_adapters(self, lora_names: Set[str]):
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules.
Args:
lora_names (Set[str]): A set of LoRA adapter names to unload.
"""
for lora_name in lora_names:
if lora_name in self.loras:
del self.configs[lora_name]
else:
logger.warning(f"LoRA adapter {lora_name} is not loaded.")
success = True
error_message = ""
if lora_name in self.loras:
del self.configs[lora_name]
else:
error_message = f"LoRA adapter {lora_name} is not loaded."
success = False
self.update_state_from_configs()
return self.create_lora_update_result(
success=success,
error_message=error_message,
)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(forward_batch.lora_paths)
......@@ -372,8 +425,8 @@ class LoRAManager:
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# Clean up unused LoRA adapters
for name in self.loras:
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
for name in list(self.loras):
if name not in self.configs:
logger.info(f"Unloading LoRA adapter {name}")
del self.loras[name]
......
......@@ -20,7 +20,7 @@ import copy
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
from sglang.srt.multimodal.mm_utils import has_valid_data
......@@ -1002,3 +1002,27 @@ class RpcReqInput:
class RpcReqOutput:
success: bool
message: str
@dataclass
class LoadLoRAAdapterReqInput:
    # The name under which the new LoRA adapter will be registered.
    lora_name: str
    # The path (local or Hugging Face) to load the LoRA adapter from.
    lora_path: str

@dataclass
class UnloadLoRAAdapterReqInput:
    # The name of the LoRA adapter to unload.
    lora_name: str
@dataclass
class LoRAUpdateResult:
success: bool
error_message: Optional[str] = None
loaded_adapters: Dict[str, str] = field(default_factory=dict)
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
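A quick sketch of how these structures relate (values are placeholders, for illustration only):

# Illustrative use of the new io_struct types defined above.
from sglang.srt.managers.io_struct import (
    LoadLoRAAdapterReqInput,
    LoadLoRAAdapterReqOutput,
    LoRAUpdateResult,
)

req = LoadLoRAAdapterReqInput(
    lora_name="sql-lora",
    lora_path="philschmid/code-llama-3-1-8b-text-to-sql-lora",
)
result = LoRAUpdateResult(success=True, loaded_adapters={req.lora_name: req.lora_path})
assert LoadLoRAAdapterReqOutput is LoRAUpdateResult  # the *ReqOutput names alias LoRAUpdateResult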
......@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
......@@ -519,6 +523,8 @@ class Scheduler(
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
(ExpertDistributionReq, self.expert_distribution_handle),
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
]
)
......@@ -2241,6 +2247,36 @@ class Scheduler(
logger.error(message)
return UpdateWeightFromDiskReqOutput(success, message, 0)
def load_lora_adapter(
self, recv_req: LoadLoRAAdapterReqInput
) -> LoadLoRAAdapterReqOutput:
"""In-place loading a new lora adapter from disk or huggingface."""
result = self.tp_worker.load_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after loading lora adapter."
else:
logger.error(result.error_message)
return result
def unload_lora_adapter(
self, recv_req: UnloadLoRAAdapterReqInput
) -> UnloadLoRAAdapterReqOutput:
"""Unload the lora adapter."""
result = self.tp_worker.unload_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert (
flush_cache_success
), "Cache flush failed after unloading LoRA weights"
else:
logger.error(result.error_message)
return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
"""Initialize the online model parameter update group."""
success, message = self.tp_worker.init_weights_update_group(recv_req)
......
......@@ -83,6 +83,9 @@ from sglang.srt.managers.io_struct import (
HealthCheckOutput,
InitWeightsUpdateGroupReqInput,
InitWeightsUpdateGroupReqOutput,
LoadLoRAAdapterReqInput,
LoadLoRAAdapterReqOutput,
LoRAUpdateResult,
OpenSessionReqInput,
OpenSessionReqOutput,
ProfileReq,
......@@ -99,6 +102,8 @@ from sglang.srt.managers.io_struct import (
SlowDownReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UnloadLoRAAdapterReqInput,
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
......@@ -311,6 +316,9 @@ class TokenizerManager:
self.expert_distribution_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_lora_adapter_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self._result_dispatcher = TypeBasedDispatcher(
[
......@@ -377,6 +385,10 @@ class TokenizerManager:
ExpertDistributionReqOutput,
self.expert_distribution_communicator.handle_recv,
),
(
LoRAUpdateResult,
self.update_lora_adapter_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
]
)
......@@ -960,6 +972,49 @@ class TokenizerManager:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
async def load_lora_adapter(
self,
obj: LoadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
        logger.info(
            "Start to load LoRA adapter. name=%s, path=%s",
            obj.lora_name,
            obj.lora_path,
        )
async with self.model_update_lock.writer_lock:
result = (await self.update_lora_adapter_communicator(obj))[0]
return result
async def unload_lora_adapter(
self,
obj: UnloadLoRAAdapterReqInput,
_: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for dynamic lora loading"
        logger.info(
            "Start to unload LoRA adapter. name=%s",
            obj.lora_name,
        )
async with self.model_update_lock.writer_lock:
result = (await self.update_lora_adapter_communicator(obj))[0]
return result
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
):
......
......@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
......@@ -275,3 +277,13 @@ class TpModelWorker:
recv_req.name, recv_req.truncate_size
)
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(
recv_req.lora_name, recv_req.lora_path
)
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
return result
......@@ -26,6 +26,8 @@ import torch
from sglang.srt.managers.io_struct import (
GetWeightsByNameReqInput,
InitWeightsUpdateGroupReqInput,
LoadLoRAAdapterReqInput,
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
......@@ -268,6 +270,12 @@ class TpModelWorkerClient:
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
return self.worker.get_weights_by_name(recv_req)
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
return self.worker.load_lora_adapter(recv_req)
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
return self.worker.unload_lora_adapter(recv_req)
def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))
......@@ -26,7 +26,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt import debug_utils
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
......@@ -819,8 +818,47 @@ class ModelRunner:
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)
self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
logger.info("LoRA manager ready.")
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
if result.success:
logger.info(
f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
)
else:
raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
def load_lora_adapter(self, lora_name: str, lora_path: str):
"""Load a new lora adapter from disk or huggingface."""
logger.info(
f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
logger.info(
f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
return result
def unload_lora_adapter(self, lora_name: str):
"""Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
logger.info(
f"LoRA adapter unloading starts: name={lora_name}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
result = self.lora_manager.unload_lora_adapter(lora_name)
logger.info(
f"LoRA adapter unloading completes: name={lora_name}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
return result
def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
......
......@@ -503,6 +503,7 @@ class SRTRunner:
disable_overlap_schedule: bool = False,
disable_custom_all_reduce: bool = False,
torchao_config: Optional[str] = None,
cuda_graph_max_bs: int = 4,
sleep_on_idle=False,
):
self.model_type = model_type
......@@ -539,7 +540,7 @@ class SRTRunner:
tokenizer_path=tokenizer_path,
enable_ep_moe=enable_ep_moe,
disable_overlap_schedule=disable_overlap_schedule,
cuda_graph_max_bs=4,
cuda_graph_max_bs=cuda_graph_max_bs,
disable_custom_all_reduce=disable_custom_all_reduce,
sleep_on_idle=sleep_on_idle,
**spec_kwargs,
......@@ -552,6 +553,12 @@ class SRTRunner:
else:
self.tokenizer = None
def load_lora_adapter(self, lora_name: str, lora_path: str):
return self.engine.load_lora_adapter(lora_name, lora_path)
def unload_lora_adapter(self, lora_name: str):
return self.engine.unload_lora_adapter(lora_name)
def forward(
self,
prompts: Union[
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
import requests
import torch
from sglang.srt.utils import kill_process_tree
from sglang.test.runners import SRTRunner
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
PROMPTS = [
"SGL is a",
"AI is a field of computer science focused on",
"Computer science is the study of",
"Write a short story.",
"What are the main components of a computer?",
]
class OperationType(Enum):
LOAD = "load"
UNLOAD = "unload"
NOOP = "noop"
FORWARD = "forward"
@dataclass
class Operation:
type: OperationType
data: Optional[str]
@dataclass
class TestCase:
base: str
max_loras_per_batch: int
all_adapters: List[str]
initial_adapters: List[str]
op_sequence: List[Operation]
max_new_tokens: int = 32
def create_batch_data(adapters: Union[str, list]) -> dict:
if not isinstance(adapters, list):
adapters = [adapters]
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
TEST_CASES = [
# basic test, no eviction
TestCase(
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
all_adapters=[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
op_sequence=[
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
]
),
),
Operation(
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
]
),
),
Operation(
type=OperationType.UNLOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pbevan11/llama-3.1-8b-ocr-correction",
]
),
),
],
),
# Eviction
TestCase(
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=1,
all_adapters=[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
],
),
]
class LoRAUpdateTestSessionMode(Enum):
ENGINE = "engine"
SERVER = "server"
class LoRAUpdateTestSessionBase:
"""
Base context manager for testing LoRA adapters.
"""
def __init__(
self,
*,
testcase: Optional[TestCase],
model_path: str,
lora_paths: list[str],
max_loras_per_batch: int = 1,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
cuda_graph_max_bs: int = 4,
):
self.testcase = testcase
self.model_path = model_path
self.lora_paths = lora_paths
self.max_loras_per_batch = max_loras_per_batch
self.lora_backend = lora_backend
self.disable_cuda_graph = disable_cuda_graph
self.cuda_graph_max_bs = cuda_graph_max_bs
self.expected_adapters = set(lora_paths)
self.handle = None # Will be set in __enter__
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Don't suppress exceptions by default
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
"""
Load a LoRA adapter by name and path.
"""
raise NotImplementedError("Subclasses must implement load_lora_adapter")
def unload_lora_adapter(self, lora_name: str):
"""
Unload a LoRA adapter by name.
"""
raise NotImplementedError("Subclasses must implement unload_lora_adapter")
def forward(
self,
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
"""
raise NotImplementedError("Subclasses must implement forward")
class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
"""
Context manager for testing LoRA adapters with in-process engine.
"""
def __enter__(self):
# in-process runner
self.handle = SRTRunner(
model_path=self.model_path,
model_type="generation",
lora_paths=self.lora_paths,
lora_backend=self.lora_backend,
torch_dtype=torch.float16,
max_loras_per_batch=self.max_loras_per_batch,
disable_cuda_graph=self.disable_cuda_graph,
cuda_graph_max_bs=self.cuda_graph_max_bs,
disable_radix_cache=True,
)
self.handle.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.handle is not None:
# delegate cleanup to SRTRunner
return self.handle.__exit__(exc_type, exc_val, exc_tb)
# don't suppress exceptions
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
"""
Load a LoRA adapter by name and path.
"""
if lora_path is None:
lora_path = lora_name
self.expected_adapters.add(lora_name)
response = self.handle.load_lora_adapter(
lora_name=lora_name,
lora_path=lora_path,
)
self.testcase.assertTrue(response.success)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
def unload_lora_adapter(self, lora_name: str):
"""
Unload a LoRA adapter by name.
"""
self.expected_adapters.remove(lora_name)
response = self.handle.unload_lora_adapter(
lora_name=lora_name,
)
self.testcase.assertTrue(response.success)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
def forward(
self,
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
"""
response = self.handle.batch_forward(
prompts=prompts,
lora_paths=lora_paths,
max_new_tokens=max_new_tokens,
)
output_strs = response.output_strs
print(f"output_strs: {output_strs}")
return output_strs
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
"""
Context manager for testing LoRA adapters with standalone server.
"""
def __enter__(self):
other_args = [
"--cuda-graph-max-bs",
str(self.cuda_graph_max_bs),
"--lora-paths",
*self.lora_paths,
"--max-loras-per-batch",
str(self.max_loras_per_batch),
"--lora-backend",
self.lora_backend,
"--disable-radix-cache",
"--random-seed",
"42",
"--max-running-request",
"1",
]
if self.disable_cuda_graph:
other_args.append("--disable-cuda-graph")
# launch external server
self.handle = popen_launch_server(
self.model_path,
DEFAULT_URL_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_args,
)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.handle is not None:
kill_process_tree(self.handle.pid)
# don't suppress exceptions
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
"""
Load a LoRA adapter by name and path.
"""
if lora_path is None:
lora_path = lora_name
self.expected_adapters.add(lora_name)
response = requests.post(
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
json={"lora_name": lora_name, "lora_path": lora_path},
)
self.testcase.assertTrue(response.ok)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
def unload_lora_adapter(self, lora_name: str):
"""
Unload a LoRA adapter by name.
"""
self.expected_adapters.remove(lora_name)
response = requests.post(
DEFAULT_URL_FOR_TEST + "/unload_lora_adapter",
json={"lora_name": lora_name},
)
self.testcase.assertTrue(response.ok)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
def forward(
self,
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
"""
response = requests.post(
DEFAULT_URL_FOR_TEST + "/generate",
json={
"text": prompts,
"lora_path": lora_paths,
"sampling_params": {
"temperature": 0,
"top_k": 1,
"max_new_tokens": max_new_tokens,
},
},
)
self.testcase.assertTrue(response.ok)
output_strs = [r["text"] for r in response.json()]
print(f"output_strs: {output_strs}")
return output_strs
# Factory function to create the appropriate LoRA test session based on mode
def LoRAUpdateTestSession(
*,
testcase: Optional[TestCase],
mode: LoRAUpdateTestSessionMode,
model_path: str,
lora_paths: list[str],
max_loras_per_batch: int = 1,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
cuda_graph_max_bs: int = 4,
):
common_kwargs = {
"testcase": testcase,
"model_path": model_path,
"lora_paths": lora_paths,
"max_loras_per_batch": max_loras_per_batch,
"lora_backend": lora_backend,
"disable_cuda_graph": disable_cuda_graph,
"cuda_graph_max_bs": cuda_graph_max_bs,
}
if mode == LoRAUpdateTestSessionMode.ENGINE:
return LoRAUpdateEngineTestSession(**common_kwargs)
elif mode == LoRAUpdateTestSessionMode.SERVER:
return LoRAUpdateServerTestSession(**common_kwargs)
else:
raise ValueError(f"Unrecognized mode: {mode!r}")
class TestLoRADynamicUpdate(CustomTestCase):
"""
This test case verifies that the SRT runner can dynamically load and unload LoRA adapters
during a sequence of operations, and that the outputs of forward passes with dynamically loaded
adapters match the outputs of forward passes with statically loaded adapters.
"""
    @staticmethod
    def _repeat_each(lst, n):
        return [x for x in lst for _ in range(n)]
def _run_operation_sequence(
self,
mode: LoRAUpdateTestSessionMode,
base: str,
initial_adapters: List[str],
max_loras_per_batch: int,
op_sequence: List[Operation],
max_new_tokens: int = 32,
) -> List[tuple]:
"""
Runs a sequence of operations on the SRT runner, including loading and unloading LoRA adapters,
and performing forward passes with the current set of loaded adapters.
"""
forward_outputs = []
with LoRAUpdateTestSession(
testcase=self,
mode=mode,
model_path=base,
lora_paths=initial_adapters,
max_loras_per_batch=max_loras_per_batch,
) as session:
for op in op_sequence:
op_type = op.type
data = op.data
print("-" * 100)
print(
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
)
if op_type == OperationType.LOAD:
result = session.load_lora_adapter(
lora_name=data,
lora_path=data,
)
elif op_type == OperationType.UNLOAD:
result = session.unload_lora_adapter(
lora_name=data,
)
elif op_type == OperationType.FORWARD:
prompts, adapters = zip(*data)
result = session.forward(
prompts=list(prompts),
lora_paths=list(adapters),
max_new_tokens=max_new_tokens,
)
forward_outputs.append(result)
return forward_outputs
def test_dynamic_adapter_updates(self):
for case_idx, test_case in enumerate(TEST_CASES, start=1):
for mode in [
LoRAUpdateTestSessionMode.SERVER,
LoRAUpdateTestSessionMode.ENGINE,
]:
print("=" * 100)
print(f"Starting test case {case_idx} in {mode.value} mode.")
print("=" * 100)
print(
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
)
# Test dynamic loading of adapters
# TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora,
# we should fix this in the future https://github.com/sgl-project/sglang/issues/7463.
dynamic_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.initial_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=test_case.op_sequence,
max_new_tokens=test_case.max_new_tokens,
)
# static loading
forward_ops = [
x for x in test_case.op_sequence if x.type == OperationType.FORWARD
]
print("=" * 100)
print(
f"\n--- Running static pass with {len(forward_ops)} operations ---"
)
static_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.all_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=forward_ops,
max_new_tokens=test_case.max_new_tokens,
)
print(f"Dynamic output: {dynamic_output}")
print(f"Static output: {static_output}")
print("=" * 100)
self.assertEqual(
len(dynamic_output),
len(static_output),
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
)
for i, (dynamic, static) in enumerate(
zip(dynamic_output, static_output), start=1
):
self.assertEqual(
len(dynamic),
len(static),
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
)
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
d_out = d_out.strip()
s_out = s_out.strip()
self.assertEqual(
d_out,
s_out,
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")
......@@ -17,6 +17,7 @@ suites = {
TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/lora/test_lora_cuda_graph.py", 250),
TestFile("models/lora/test_lora_update.py", 400),
TestFile("models/test_embedding_models.py", 73),
# TestFile("models/test_clip_models.py", 52),
TestFile("models/test_encoder_embedding_models.py", 100),
......