Unverified Commit 8abd3e77 authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)

parent e885bfdc
......@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
from typing import Dict, Iterable, Optional, Set, Tuple
from typing import Dict, Iterable, List, Optional, Set, Tuple
import torch
......@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
LoRABatchInfo,
......@@ -55,6 +56,7 @@ class LoRAManager:
tp_rank: int = 0,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
):
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
......@@ -64,10 +66,6 @@ class LoRAManager:
self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: Optional[int] = max_lora_rank
self.target_modules: Optional[Set[str]] = (
set(target_modules) if target_modules else None
)
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
......@@ -75,7 +73,11 @@ class LoRAManager:
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
# Initialize mutable internal state of the LoRAManager.
self.init_state()
self.init_state(
max_lora_rank=max_lora_rank,
target_modules=target_modules,
lora_paths=lora_paths,
)
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
......@@ -112,108 +114,87 @@ class LoRAManager:
success=success,
error_message=error_message,
loaded_adapters={
name: config.path for name, config in self.configs.items()
lora_ref.lora_name: lora_ref.lora_path
for lora_ref in self.lora_refs.values()
},
)
def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
"""
Load LoRA adapters from the specified paths.
Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped with a warning.
"""
results = []
for lora_name, lora_path in lora_paths.items():
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
results.append(result)
self.update_state_from_configs()
return self.create_lora_update_result(
success=all(result.success for result in results),
error_message="\n".join(
result.error_message for result in results if not result.success
),
)
def load_lora_adapter(
self, lora_name: str, lora_path: str, update_state: bool = True
) -> LoRAUpdateResult:
def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
"""
Load a single LoRA adapter from the specified path.
Args:
lora_name (str): The name of the LoRA adapter.
lora_path (str): The file path to the LoRA adapter.
update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
"""
assert (
lora_ref.lora_name is not None and lora_ref.lora_path is not None
), "LoRARef must have both lora_name and lora_path set for loading."
assert (
lora_ref.lora_id not in self.loras
), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
success = True
error_message = ""
try:
# load configs
new_adapter = LoRAConfig(lora_ref.lora_path)
self.validate_new_adapter(new_adapter, lora_ref)
self.configs[lora_ref.lora_id] = new_adapter
if lora_name in self.loras:
success = False
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
# load weights
self.load_lora_weights(lora_ref)
try:
new_adapter = LoRAConfig(lora_path)
self.validate_new_adapter(lora_name, new_adapter)
self.configs[lora_name] = new_adapter
# keep metadata for displayed messages
self.lora_refs[lora_ref.lora_id] = lora_ref
except Exception as e:
success = False
error_message = (
f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
return self.create_lora_update_result(
success=False,
error_message=str(e),
)
if update_state:
self.update_state_from_configs()
return self.create_lora_update_result(success=True)
return self.create_lora_update_result(
success=success,
error_message=error_message,
)
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
"""
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
"""
incompatible = self.memory_pool and not self.memory_pool.can_support(
lora_config
)
memory_pool = getattr(self, "memory_pool", None)
incompatible = memory_pool and not memory_pool.can_support(lora_config)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`."
)
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules.
"""
success = True
error_message = ""
if lora_name in self.loras:
del self.configs[lora_name]
else:
error_message = f"LoRA adapter {lora_name} is not loaded."
success = False
adapter = self.configs.get(lora_ref.lora_id, None)
assert (
adapter is not None
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
self.update_state_from_configs()
try:
del self.configs[lora_ref.lora_id]
del self.loras[lora_ref.lora_id]
del self.lora_refs[lora_ref.lora_id]
except Exception as e:
return self.create_lora_update_result(
success=False,
error_message=str(e),
)
return self.create_lora_update_result(
success=success,
error_message=error_message,
)
return self.create_lora_update_result(success=True)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
# Load active loras into lora memory pool
# TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
# LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
# should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
# the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
......@@ -233,10 +214,10 @@ class LoRAManager:
weight_indices = [0] * len(forward_batch.lora_paths)
lora_ranks = [0] * self.max_loras_per_batch
scalings = [0] * self.max_loras_per_batch
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
if lora_path is not None:
lora = self.loras[lora_path]
for i, uid in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
if uid is not None:
lora = self.loras[uid]
lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling
......@@ -326,7 +307,7 @@ class LoRAManager:
"""
Update all LoRA modules to associate them with the latest memory buffer.
"""
for layer_id, layer_modules in self.lora_modules.items():
for layer_id, layer_modules in enumerate(self.lora_modules):
for module_name, module in layer_modules.items():
if "qkv_proj" in module_name:
module.set_lora_info(
......@@ -353,115 +334,94 @@ class LoRAManager:
),
)
def init_state(self):
def init_state(
self,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
):
"""
Initialize the internal (mutable) state of the LoRAManager.
These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
the target modules and max_lora_rank.
"""
# Configs of all active LoRA adapters.
self.configs: Dict[str, LoRAConfig] = {}
# LoRA adapter weights cached in CPU memory.
self.loras: Dict[str, LoRAAdapter] = {}
assert lora_paths or (
max_lora_rank is not None and target_modules is not None
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
# Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str]] = (set(), set())
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
i: {} for i in range(self.base_hf_config.num_hidden_layers)
}
self.init_lora_adapters(lora_paths)
self.init_lora_shapes(
max_lora_rank=max_lora_rank,
target_modules=target_modules,
)
self.init_lora_weight_names()
self.init_lora_modules()
self.init_memory_pool()
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
# It is initialized lazily when the first LoRA adapter is loaded.
self.memory_pool: Optional[LoRAMemoryPool] = None
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
# Configs of all active LoRA adapters, indexed by LoRA ID.
self.configs: Dict[str, LoRAConfig] = {}
def update_state_from_configs(self):
"""
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
"""
# LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
self.loras: Dict[str, LoRAAdapter] = {}
# Loads / unloads LoRA adapters based on the latest configs.
self.update_lora_adapters()
# Apply the latest LoRA configurations to the internal state for inferencing.
self.apply_lora_configs()
# Mapping from LoRA ID to LoRARef object.
self.lora_refs: Dict[str, LoRARef] = {}
def apply_lora_configs(self):
"""
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
if lora_paths:
for lora_ref in lora_paths.values():
result = self.load_lora_adapter(lora_ref)
if not result.success:
raise RuntimeError(
f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
)
Notes:
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
early CY25H2.
"""
def init_lora_shapes(
self,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
):
"""Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""
if self.memory_pool is None:
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
if self.target_modules is None:
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)
if self.max_lora_rank is None:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
if target_modules is not None:
self.target_modules = set(target_modules)
else:
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)
self.update_lora_weight_names()
self.update_lora_modules()
self.update_memory_buffers()
if max_lora_rank is not None:
self.max_lora_rank = max_lora_rank
else:
# No-op if the memory pool can support the current LoRA configurations.
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
# module is changed once FlashInfer backend is deprecated.
assert self.memory_pool.can_support(self.configs.values()), (
"LoRA memory pool cannot support the current LoRA configuration. "
"This should never happen as we should have validated adapter compatibility. "
"Please create a Github issue to report.",
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
def update_lora_weight_names(self):
def init_lora_weight_names(self):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""
# Target lora weight names for lora_a and lora_b modules respectively.
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
def update_lora_adapters(self):
def load_lora_weights(self, lora_ref: LoRARef):
"""
Update the LoRA adapters in CPU memory based on the current `self.configs`.
It loads any new adapters that are not already loaded, and unloads any adapters
that are no longer in `self.configs` (e.g., unloaded).
Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
"""
# Load new adapter weights to cpu
for name, config in self.configs.items():
if name not in self.loras:
logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
lora_adapter = LoRAAdapter(
name,
config,
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
for name in list(self.loras):
if name not in self.configs:
logger.info(f"Unloading LoRA adapter {name}")
del self.loras[name]
lora_adapter = LoRAAdapter(
lora_ref.lora_id,
self.configs[lora_ref.lora_id],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[lora_ref.lora_id] = lora_adapter
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
......@@ -472,7 +432,7 @@ class LoRAManager:
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def update_memory_buffers(self):
def init_memory_pool(self):
"""(Re)initialize the LoRA memory pool based on the current configurations."""
self.memory_pool = LoRAMemoryPool(
base_hf_config=self.base_hf_config,
......@@ -490,7 +450,12 @@ class LoRAManager:
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def update_lora_modules(self):
def init_lora_modules(self):
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
{} for _ in range(self.base_hf_config.num_hidden_layers)
]
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
......@@ -511,7 +476,6 @@ class LoRAManager:
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
layer_id = get_layer_id(module_name)
if module_name not in self.lora_modules[layer_id]:
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import asyncio
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
@dataclass(frozen=True, slots=True)
class LoRARef:
"""
Reference record for a LoRA model.
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache).
"""
lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None
lora_path: Optional[str] = None
def __post_init__(self):
if self.lora_id is None:
raise ValueError("lora_id cannot be None")
def __str__(self) -> str:
parts = [
f"{f.name}={value}"
for f in fields(self)
if (value := getattr(self, f.name)) is not None
]
return f"{self.__class__.__name__}({', '.join(parts)})"
class LoRARegistry:
"""
The central registry to keep track of available LoRA adapters.
TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
"""
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
assert lora_paths is None or all(
isinstance(lora, LoRARef) for lora in lora_paths.values()
), (
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
"Please file an issue if you see this error."
)
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
async def register(self, lora_ref: LoRARef):
"""
Register a new LoRARef object in the registry.
Args:
lora_ref (LoRARef): The LoRARef object to register.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
self._registry[lora_ref.lora_name] = lora_ref
async def unregister(self, lora_name: str) -> str:
"""
Unregister a LoRARef object from the registry and returns the removed LoRA ID.
Args:
lora_name (str): The name of the LoRA model to unregister.
"""
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
return lora_ref.lora_id
async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
"""
Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
by incrementing its counter.
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
"""
async def _acquire_single(name: str) -> str:
lora_ref = self._registry.get(name, None)
if lora_ref is None:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {name}\n"
f"Loaded adapters: {self._registry.keys()}."
)
# await self._counters[lora_ref.lora_id].increment()
return lora_ref.lora_id
if isinstance(lora_name, str):
lora_id = await _acquire_single(lora_name)
return lora_id
elif isinstance(lora_name, list):
lora_ids = await asyncio.gather(
*[_acquire_single(name) for name in lora_name]
)
return lora_ids
else:
raise TypeError("lora_name must be either a string or a list of strings.")
......@@ -153,7 +153,7 @@ class LoRAMemoryPool:
self,
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
......@@ -186,7 +186,7 @@ class LoRAMemoryPool:
uid: str,
buffer_id: int,
lora_adapter: LoRAAdapter,
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
):
def load_lora_weight_tensor(
buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
......
......@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams
......@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
lora_name: str
# The path of loading.
lora_path: str
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
def to_ref(self) -> LoRARef:
return LoRARef(
lora_id=self.lora_id,
lora_name=self.lora_name,
lora_path=self.lora_path,
)
@dataclass
class UnloadLoRAAdapterReqInput:
# The name of lora module to unload.
lora_name: str
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
def to_ref(self) -> LoRARef:
return LoRARef(
lora_id=self.lora_id,
lora_name=self.lora_name,
)
@dataclass
class LoRAUpdateResult:
success: bool
error_message: Optional[str] = None
loaded_adapters: Dict[str, str] = field(default_factory=dict)
loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
......@@ -247,7 +247,7 @@ class Scheduler(
self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.lora_paths = server_args.lora_paths
self.enable_lora = server_args.enable_lora
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init
......@@ -1706,13 +1706,13 @@ class Scheduler(
self.chunked_req.init_next_round_input()
self.chunked_req = adder.add_chunked_req(self.chunked_req)
if self.lora_paths:
if self.enable_lora:
lora_set = set([req.lora_path for req in self.running_batch.reqs])
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
self.lora_paths
self.enable_lora
and len(
lora_set
| set([req.lora_path for req in adder.can_run_list])
......@@ -2466,12 +2466,6 @@ class Scheduler(
"""In-place loading a new lora adapter from disk or huggingface."""
result = self.tp_worker.load_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after loading lora adapter."
else:
logger.error(result.error_message)
return result
def unload_lora_adapter(
......@@ -2480,14 +2474,6 @@ class Scheduler(
"""Unload the lora adapter."""
result = self.tp_worker.unload_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert (
flush_cache_success
), "Cache flush failed after unloading LoRA weights"
else:
logger.error(result.error_message)
return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
......
......@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer,
get_tokenizer_from_processor,
)
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.io_struct import (
AbortReq,
BatchEmbeddingOut,
......@@ -242,11 +243,11 @@ class TokenizerManager:
revision=server_args.revision,
)
# Initialize loaded loRA adapters with the initial lora paths in the server_args.
# This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
self.loaded_lora_adapters: Dict[str, str] = dict(
self.server_args.lora_paths or {}
)
# Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Store states
self.no_create_loop = False
......@@ -523,6 +524,10 @@ class TokenizerManager:
else:
mm_inputs = None
if self.server_args.enable_lora and obj.lora_path:
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
self._validate_one_request(obj, input_ids)
return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
......@@ -574,8 +579,6 @@ class TokenizerManager:
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if self.server_args.enable_lora and obj.lora_path:
self._validate_lora_adapters(obj)
def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
......@@ -689,21 +692,6 @@ class TokenizerManager:
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
)
def _validate_lora_adapters(self, obj: GenerateReqInput):
"""Validate that the requested LoRA adapters are loaded."""
requested_adapters = (
set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
)
loaded_adapters = (
self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
)
unloaded_adapters = requested_adapters - loaded_adapters
if unloaded_adapters:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
f"Loaded adapters: {loaded_adapters}."
)
def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
......@@ -1054,8 +1042,18 @@ class TokenizerManager:
)
async with self.model_update_lock.writer_lock:
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
)
# Register the new adapter in the registry.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
if result.success:
await self.lora_registry.register(new_adapter)
return result
async def unload_lora_adapter(
......@@ -1069,6 +1067,10 @@ class TokenizerManager:
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
assert (
obj.lora_name is not None
), "lora_name must be provided to unload LoRA adapter"
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
assert (
......@@ -1080,8 +1082,9 @@ class TokenizerManager:
)
async with self.model_update_lock.writer_lock:
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
return result
async def get_weights_by_name(
......@@ -1309,7 +1312,7 @@ class TokenizerManager:
filename = os.path.join(
self.crash_dump_folder,
os.getenv("HOSTNAME", None),
f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
)
os.makedirs(os.path.dirname(filename), exist_ok=True)
......
......@@ -293,11 +293,9 @@ class TpModelWorker:
return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter(
recv_req.lora_name, recv_req.lora_path
)
result = self.model_runner.load_lora_adapter(recv_req.to_ref())
return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
......@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS,
global_server_args_dict,
......@@ -890,44 +891,38 @@ class ModelRunner:
tp_rank=self.tp_rank,
max_lora_rank=self.server_args.max_lora_rank,
target_modules=self.server_args.lora_target_modules,
lora_paths=self.server_args.lora_paths,
)
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {})
if result.success:
logger.info(
f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
)
else:
raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
def load_lora_adapter(self, lora_name: str, lora_path: str):
def load_lora_adapter(self, lora_ref: LoRARef):
"""Load a new lora adapter from disk or huggingface."""
logger.info(
f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
f"LoRA adapter loading starts: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
result = self.lora_manager.load_lora_adapter(lora_ref)
logger.info(
f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
f"LoRA adapter loading completes: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
return result
def unload_lora_adapter(self, lora_name: str):
def unload_lora_adapter(self, lora_ref: LoRARef):
"""Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
logger.info(
f"LoRA adapter unloading starts: name={lora_name}. "
f"LoRA adapter unloading starts: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
result = self.lora_manager.unload_lora_adapter(lora_name)
result = self.lora_manager.unload_lora_adapter(lora_ref)
logger.info(
f"LoRA adapter unloading completes: name={lora_name}. "
f"LoRA adapter unloading completes: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
......
......@@ -20,10 +20,10 @@ import logging
import os
import random
import tempfile
from token import OP
from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
LORA_TARGET_ALL_MODULES,
......@@ -145,7 +145,7 @@ class ServerArgs:
enable_lora: Optional[bool] = None
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[Union[set[str], List[str]]] = None
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
......@@ -1843,9 +1843,24 @@ class ServerArgs:
for lora_path in lora_paths:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
self.lora_paths[name] = path
self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
else:
self.lora_paths[lora_path] = lora_path
self.lora_paths[lora_path] = LoRARef(
lora_name=lora_path,
lora_path=lora_path,
)
elif isinstance(self.lora_paths, dict):
self.lora_paths = {
k: LoRARef(lora_name=k, lora_path=v)
for k, v in self.lora_paths.items()
}
elif self.lora_paths is None:
self.lora_paths = {}
else:
raise ValueError(
f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
"Expected a list or a dictionary."
)
# Expand target modules
if self.lora_target_modules:
......
......@@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================
import contextlib
import multiprocessing as mp
import unittest
from typing import Dict, List, Tuple
......@@ -39,6 +40,16 @@ ADAPTERS = [
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
@contextlib.contextmanager
def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str):
"""A context manager to load and automatically unload a LoRA adapter."""
try:
runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path)
yield
finally:
runner.unload_lora_adapter(lora_name=lora_name)
class TestLoRAEviction(CustomTestCase):
def test_lora_eviction_with_different_target_modules(self):
"""
......@@ -51,55 +62,80 @@ class TestLoRAEviction(CustomTestCase):
self._run_test(ADAPTERS, output_history, reverse=False)
self._run_test(ADAPTERS, output_history, reverse=True)
def test_lora_eviction_with_reused_lora_name(self):
"""
Test LoRA eviction with reused LoRA names.
This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior
works correctly when reusing LoRA names.
"""
output_history = {}
self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1)
self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1)
def _run_test(
self,
lora_paths: List[str],
output_history: Dict[Tuple[str, str], str],
reverse: bool,
reverse: bool = False,
repeat: int = 2,
reuse_lora_name: bool = False,
):
REUSED_LORA_NAME = "lora"
max_new_tokens = 256
backend = "triton"
torch_dtype = torch.float16
base_path = BASE_MODEL
assert len(lora_paths) >= 2
initial_lora_paths = lora_paths if not reuse_lora_name else None
# Initialize runners
with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
lora_paths=lora_paths,
lora_paths=initial_lora_paths,
max_loras_per_batch=1,
lora_backend=backend,
disable_radix_cache=True,
enable_lora=True,
max_lora_rank=256,
lora_target_modules=["all"],
) as srt_runner:
adapter_sequence = lora_paths if not reverse else lora_paths[::-1]
for i in range(repeat):
for j, adapter in enumerate(adapter_sequence):
for j, lora_path in enumerate(adapter_sequence):
print(
f"\n========== Testing LoRA eviction with adapter '{adapter}' (#{j+1}/{len(adapter_sequence)}), reversed: {reverse}, repeat: {i+1}/{repeat} ---"
f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---"
)
lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path
context = (
dynamically_loaded_adapter(srt_runner, lora_path, lora_name)
if reuse_lora_name
else contextlib.nullcontext()
)
for prompt in PROMPTS:
print("\nprompt:\n", prompt)
srt_outputs = srt_runner.forward(
[prompt],
max_new_tokens=max_new_tokens,
lora_paths=[adapter],
)
output = srt_outputs.output_strs[0].strip()
print("\noutput:\n", output)
prev_output = output_history.get((adapter, prompt))
if prev_output is not None:
self.assertEqual(
prev_output,
output,
f"Output mismatch for adapter {adapter} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
with context:
for prompt in PROMPTS:
print("\nprompt:\n", prompt)
srt_outputs = srt_runner.forward(
[prompt],
max_new_tokens=max_new_tokens,
lora_paths=[lora_name],
)
else:
output_history[(adapter, prompt)] = output
output = srt_outputs.output_strs[0].strip()
print("\noutput:\n", output)
prev_output = output_history.get((lora_path, prompt))
if prev_output is not None:
self.assertEqual(
prev_output,
output,
f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
)
else:
output_history[(lora_path, prompt)] = output
if __name__ == "__main__":
......
......@@ -14,7 +14,7 @@ class TestFile:
suites = {
"per-commit": [
TestFile("models/lora/test_lora.py", 200),
TestFile("models/lora/test_lora_eviction.py", 120),
TestFile("models/lora/test_lora_eviction.py", 200),
TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/lora/test_lora_cuda_graph.py", 250),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment