Unverified Commit 8abd3e77 authored by Lifu Huang, committed by GitHub

Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)

parent e885bfdc
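In brief, this change threads a stable, unique LoRA ID through the stack so adapter loads and unloads can overlap with serving, and prefix-cache entries remain valid across reloads. A minimal sketch of the identity scheme introduced below; the cache-key shape shown here is an assumption for illustration only:

from uuid import uuid4

lora_id = uuid4().hex                          # generated once per loaded adapter, never reused
name_to_id = {"demo-adapter": lora_id}         # user-facing name -> stable internal ID
cache_key = (lora_id, "shared prompt prefix")  # assumed: the prefix cache keys on the ID, not the name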
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving" # and "Punica: Multi-Tenant LoRA Serving"
import logging import logging
from typing import Dict, Iterable, Optional, Set, Tuple from typing import Dict, Iterable, List, Optional, Set, Tuple
import torch import torch
...@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr ...@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.mem_pool import LoRAMemoryPool from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import ( from sglang.srt.lora.utils import (
LoRABatchInfo, LoRABatchInfo,
...@@ -55,6 +56,7 @@ class LoRAManager: ...@@ -55,6 +56,7 @@ class LoRAManager:
tp_rank: int = 0, tp_rank: int = 0,
max_lora_rank: Optional[int] = None, max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None, target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
): ):
self.base_model: torch.nn.Module = base_model self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config self.base_hf_config: AutoConfig = base_hf_config
...@@ -64,10 +66,6 @@ class LoRAManager: ...@@ -64,10 +66,6 @@ class LoRAManager:
self.device: torch.device = next(self.base_model.parameters()).device self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size self.tp_size: int = tp_size
self.tp_rank: int = tp_rank self.tp_rank: int = tp_rank
self.max_lora_rank: Optional[int] = max_lora_rank
self.target_modules: Optional[Set[str]] = (
set(target_modules) if target_modules else None
)
# LoRA backend for running sgemm kernels # LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.") logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
...@@ -75,7 +73,11 @@ class LoRAManager: ...@@ -75,7 +73,11 @@ class LoRAManager:
self.lora_backend: BaseLoRABackend = backend_type(lora_backend) self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
# Initialize mutable internal state of the LoRAManager. # Initialize mutable internal state of the LoRAManager.
self.init_state() self.init_state(
max_lora_rank=max_lora_rank,
target_modules=target_modules,
lora_paths=lora_paths,
)
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int): def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
...@@ -112,108 +114,87 @@ class LoRAManager: ...@@ -112,108 +114,87 @@ class LoRAManager:
success=success, success=success,
error_message=error_message, error_message=error_message,
loaded_adapters={ loaded_adapters={
name: config.path for name, config in self.configs.items() lora_ref.lora_name: lora_ref.lora_path
for lora_ref in self.lora_refs.values()
}, },
) )
def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult: def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
"""
Load LoRA adapters from the specified paths.
Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped with a warning.
"""
results = []
for lora_name, lora_path in lora_paths.items():
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
results.append(result)
self.update_state_from_configs()
return self.create_lora_update_result(
success=all(result.success for result in results),
error_message="\n".join(
result.error_message for result in results if not result.success
),
)
def load_lora_adapter(
self, lora_name: str, lora_path: str, update_state: bool = True
) -> LoRAUpdateResult:
""" """
Load a single LoRA adapter from the specified path. Load a single LoRA adapter from the specified path.
Args: Args:
lora_name (str): The name of the LoRA adapter. lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
lora_path (str): The file path to the LoRA adapter.
update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
""" """
assert (
lora_ref.lora_name is not None and lora_ref.lora_path is not None
), "LoRARef must have both lora_name and lora_path set for loading."
assert (
lora_ref.lora_id not in self.loras
), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
success = True try:
error_message = "" # load configs
new_adapter = LoRAConfig(lora_ref.lora_path)
self.validate_new_adapter(new_adapter, lora_ref)
self.configs[lora_ref.lora_id] = new_adapter
if lora_name in self.loras: # load weights
success = False self.load_lora_weights(lora_ref)
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
try: # keep metadata for displayed messages
new_adapter = LoRAConfig(lora_path) self.lora_refs[lora_ref.lora_id] = lora_ref
self.validate_new_adapter(lora_name, new_adapter)
self.configs[lora_name] = new_adapter
except Exception as e: except Exception as e:
success = False return self.create_lora_update_result(
error_message = ( success=False,
f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}" error_message=str(e),
) )
if update_state: return self.create_lora_update_result(success=True)
self.update_state_from_configs()
return self.create_lora_update_result( def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
success=success,
error_message=error_message,
)
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
""" """
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible. Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
""" """
incompatible = self.memory_pool and not self.memory_pool.can_support( memory_pool = getattr(self, "memory_pool", None)
lora_config incompatible = memory_pool and not memory_pool.can_support(lora_config)
)
if incompatible: if incompatible:
raise ValueError( raise ValueError(
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. " f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are " "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`." "included in `--enable_lora_modules`."
) )
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult: def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
""" """
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules. delete the corresponding LoRA modules.
""" """
success = True adapter = self.configs.get(lora_ref.lora_id, None)
error_message = "" assert (
if lora_name in self.loras: adapter is not None
del self.configs[lora_name] ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before the request is sent to the backend."
else:
error_message = f"LoRA adapter {lora_name} is not loaded."
success = False
self.update_state_from_configs() try:
del self.configs[lora_ref.lora_id]
del self.loras[lora_ref.lora_id]
del self.lora_refs[lora_ref.lora_id]
except Exception as e:
return self.create_lora_update_result(
success=False,
error_message=str(e),
)
return self.create_lora_update_result( return self.create_lora_update_result(success=True)
success=success,
error_message=error_message,
)
def prepare_lora_batch(self, forward_batch: ForwardBatch): def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool # Load active loras into lora memory pool
# TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
# LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in the API for backward compatibility, we
# should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
# the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
cur_uids = set(forward_batch.lora_paths) cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules) self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
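For clarity, a hedged illustration of what `forward_batch.lora_paths` now carries (the IDs are made up; as the TODO above notes, the field name is kept only for API backward compatibility):

# Each entry is the stable LoRA ID of the adapter serving that request,
# or None for requests that run against the base model.
lora_paths = ["9f1c2d3e4a5b", "9f1c2d3e4a5b", None]
cur_uids = set(lora_paths)   # {"9f1c2d3e4a5b", None}
assert len(cur_uids) <= 8    # must stay within --max-loras-per-batch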
...@@ -233,10 +214,10 @@ class LoRAManager: ...@@ -233,10 +214,10 @@ class LoRAManager:
weight_indices = [0] * len(forward_batch.lora_paths) weight_indices = [0] * len(forward_batch.lora_paths)
lora_ranks = [0] * self.max_loras_per_batch lora_ranks = [0] * self.max_loras_per_batch
scalings = [0] * self.max_loras_per_batch scalings = [0] * self.max_loras_per_batch
for i, lora_path in enumerate(forward_batch.lora_paths): for i, uid in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path) weight_indices[i] = self.memory_pool.get_buffer_id(uid)
if lora_path is not None: if uid is not None:
lora = self.loras[lora_path] lora = self.loras[uid]
lora_ranks[weight_indices[i]] = lora.config.r lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling scalings[weight_indices[i]] = lora.scaling
...@@ -326,7 +307,7 @@ class LoRAManager: ...@@ -326,7 +307,7 @@ class LoRAManager:
""" """
Update all LoRA modules to associate them with the latest memory buffer. Update all LoRA modules to associate them with the latest memory buffer.
""" """
for layer_id, layer_modules in self.lora_modules.items(): for layer_id, layer_modules in enumerate(self.lora_modules):
for module_name, module in layer_modules.items(): for module_name, module in layer_modules.items():
if "qkv_proj" in module_name: if "qkv_proj" in module_name:
module.set_lora_info( module.set_lora_info(
...@@ -353,115 +334,94 @@ class LoRAManager: ...@@ -353,115 +334,94 @@ class LoRAManager:
), ),
) )
def init_state(self): def init_state(
self,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
):
""" """
Initialize the internal (mutable) state of the LoRAManager. Initialize the internal (mutable) state of the LoRAManager.
These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically. When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
the target modules and max_lora_rank.
""" """
# Configs of all active LoRA adapters. assert lora_paths or (
self.configs: Dict[str, LoRAConfig] = {} max_lora_rank is not None and target_modules is not None
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
# LoRA adapter weights cached in CPU memory.
self.loras: Dict[str, LoRAAdapter] = {}
# Supported weight names (e.g., qkv_proj) for LoRA A and B respectively. self.init_lora_adapters(lora_paths)
self.lora_weight_names: Tuple[Set[str]] = (set(), set()) self.init_lora_shapes(
max_lora_rank=max_lora_rank,
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module. target_modules=target_modules,
self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = { )
i: {} for i in range(self.base_hf_config.num_hidden_layers) self.init_lora_weight_names()
} self.init_lora_modules()
self.init_memory_pool()
# The LoRA memory pool that manages the GPU buffers for active LoRA weights. def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
# It is initialized lazily when the first LoRA adapter is loaded. # Configs of all active LoRA adapters, indexed by LoRA ID.
self.memory_pool: Optional[LoRAMemoryPool] = None self.configs: Dict[str, LoRAConfig] = {}
def update_state_from_configs(self): # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
""" self.loras: Dict[str, LoRAAdapter] = {}
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
"""
# Loads / unloads LoRA adapters based on the latest configs. # Mapping from LoRA ID to LoRARef object.
self.update_lora_adapters() self.lora_refs: Dict[str, LoRARef] = {}
# Apply the latest LoRA configurations to the internal state for inferencing.
self.apply_lora_configs()
def apply_lora_configs(self): if lora_paths:
""" for lora_ref in lora_paths.values():
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing. result = self.load_lora_adapter(lora_ref)
if not result.success:
raise RuntimeError(
f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
)
Notes: def init_lora_shapes(
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as self,
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer max_lora_rank: Optional[int] = None,
LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in target_modules: Optional[Iterable[str]] = None,
early CY25H2. ):
""" """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""
if self.memory_pool is None: if target_modules is not None:
# Infer max_lora_rank and target_modules if not explicitly specified in server args. self.target_modules = set(target_modules)
if self.target_modules is None: else:
self.target_modules = set() self.target_modules = set()
for config in self.configs.values(): for config in self.configs.values():
self.target_modules.update(config.target_modules) self.target_modules.update(config.target_modules)
if self.max_lora_rank is None:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
self.update_lora_weight_names() if max_lora_rank is not None:
self.update_lora_modules() self.max_lora_rank = max_lora_rank
self.update_memory_buffers()
else: else:
# No-op if the memory pool can support the current LoRA configurations. self.max_lora_rank = max(
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target [x.hf_config["r"] for x in self.configs.values()],
# module is changed once FlashInfer backend is deprecated. default=0,
assert self.memory_pool.can_support(self.configs.values()), (
"LoRA memory pool cannot support the current LoRA configuration. "
"This should never happen as we should have validated adapter compatibility. "
"Please create a Github issue to report.",
) )
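A hedged illustration of the inference `init_lora_shapes` performs when `--max-lora-rank` and `--lora-target-modules` are omitted (the adapter configs below are stand-ins for real LoRAConfig objects):

from types import SimpleNamespace

configs = {  # lora_id -> config, as populated by init_lora_adapters
    "id-a": SimpleNamespace(target_modules=["q_proj", "v_proj"], hf_config={"r": 8}),
    "id-b": SimpleNamespace(target_modules=["q_proj", "o_proj"], hf_config={"r": 16}),
}

target_modules = set()
for cfg in configs.values():
    target_modules.update(cfg.target_modules)  # {"q_proj", "v_proj", "o_proj"}

max_lora_rank = max([cfg.hf_config["r"] for cfg in configs.values()], default=0)  # 16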
def update_lora_weight_names(self): def init_lora_weight_names(self):
""" """
Add new LoRA weight names if needed based on the current `self.configs`. Add new LoRA weight names if needed based on the current `self.configs`.
""" """
# Target lora weight names for lora_a and lora_b modules respectively. # Target lora weight names for lora_a and lora_b modules respectively.
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules) lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
self.lora_weight_names[0].update(lora_A) self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
self.lora_weight_names[1].update(lora_B)
def update_lora_adapters(self): def load_lora_weights(self, lora_ref: LoRARef):
""" """
Update the LoRA adapters in CPU memory based on the current `self.configs`. Load the weights of a LoRA adapter into CPU memory and conduct post-loading validation.
It loads any new adapters that are not already loaded, and unloads any adapters
that are no longer in `self.configs` (e.g., unloaded).
""" """
lora_adapter = LoRAAdapter(
# Load new adapter weights to cpu lora_ref.lora_id,
for name, config in self.configs.items(): self.configs[lora_ref.lora_id],
if name not in self.loras: self.base_hf_config,
logger.info(f"Loading weight of LoRA adapter {name} from {config.path}") self.load_config,
lora_adapter = LoRAAdapter( self.lora_backend,
name, )
config, lora_adapter.initialize_weights()
self.base_hf_config, self.loras[lora_ref.lora_id] = lora_adapter
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
for name in list(self.loras):
if name not in self.configs:
logger.info(f"Unloading LoRA adapter {name}")
del self.loras[name]
# Additional checks for flashinfer backend # Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
...@@ -472,7 +432,7 @@ class LoRAManager: ...@@ -472,7 +432,7 @@ class LoRAManager:
len(lora_dims) == 1 and len(scalings) == 1 len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. " ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def update_memory_buffers(self): def init_memory_pool(self):
"""(Re)initialize the LoRA memory pool based on the current configurations.""" """(Re)initialize the LoRA memory pool based on the current configurations."""
self.memory_pool = LoRAMemoryPool( self.memory_pool = LoRAMemoryPool(
base_hf_config=self.base_hf_config, base_hf_config=self.base_hf_config,
...@@ -490,7 +450,12 @@ class LoRAManager: ...@@ -490,7 +450,12 @@ class LoRAManager:
replace_submodule(self.base_model, module_name, lora_module) replace_submodule(self.base_model, module_name, lora_module)
return lora_module return lora_module
def update_lora_modules(self): def init_lora_modules(self):
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
{} for _ in range(self.base_hf_config.num_hidden_layers)
]
# Target module names of customized layers defined in python/sglang/srt/layers # Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"} # e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names( customized_target_names = get_customized_names_from_hf_names(
...@@ -511,7 +476,6 @@ class LoRAManager: ...@@ -511,7 +476,6 @@ class LoRAManager:
# The module should be converted if it is included in target_names # The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names: if module_name.split(".")[-1] in customized_target_names:
layer_id = get_layer_id(module_name) layer_id = get_layer_id(module_name)
if module_name not in self.lora_modules[layer_id]: self.lora_modules[layer_id][module_name] = self.set_lora_module(
self.lora_modules[layer_id][module_name] = self.set_lora_module( module_name, module
module_name, module )
)
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import asyncio
from dataclasses import dataclass, field, fields
from typing import Dict, List, Optional, Union
from uuid import uuid4
@dataclass(frozen=True, slots=True)
class LoRARef:
"""
Reference record for a LoRA model.
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache).
"""
lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None
lora_path: Optional[str] = None
def __post_init__(self):
if self.lora_id is None:
raise ValueError("lora_id cannot be None")
def __str__(self) -> str:
parts = [
f"{f.name}={value}"
for f in fields(self)
if (value := getattr(self, f.name)) is not None
]
return f"{self.__class__.__name__}({', '.join(parts)})"
class LoRARegistry:
"""
The central registry to keep track of available LoRA adapters.
TODO (lifuhuang): This registry is intended as the foundation for overlapped LoRA updates. We decided
to keep it in a separate PR to simplify code review and to unblock the radix cache work.
"""
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
assert lora_paths is None or all(
isinstance(lora, LoRARef) for lora in lora_paths.values()
), (
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
"Please file an issue if you see this error."
)
# A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
async def register(self, lora_ref: LoRARef):
"""
Register a new LoRARef object in the registry.
Args:
lora_ref (LoRARef): The LoRARef object to register.
"""
if lora_ref.lora_name in self._registry:
raise ValueError(
f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
)
self._registry[lora_ref.lora_name] = lora_ref
async def unregister(self, lora_name: str) -> str:
"""
Unregister a LoRARef object from the registry and return the removed LoRA ID.
Args:
lora_name (str): The name of the LoRA model to unregister.
"""
lora_ref = self._registry.get(lora_name, None)
if lora_ref is None:
raise ValueError(
f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
)
del self._registry[lora_name]
return lora_ref.lora_id
async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
"""
Query the registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
by incrementing their counters.
TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
"""
async def _acquire_single(name: str) -> str:
lora_ref = self._registry.get(name, None)
if lora_ref is None:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {name}\n"
f"Loaded adapters: {self._registry.keys()}."
)
# await self._counters[lora_ref.lora_id].increment()
return lora_ref.lora_id
if isinstance(lora_name, str):
lora_id = await _acquire_single(lora_name)
return lora_id
elif isinstance(lora_name, list):
lora_ids = await asyncio.gather(
*[_acquire_single(name) for name in lora_name]
)
return lora_ids
else:
raise TypeError("lora_name must be either a string or a list of strings.")
...@@ -153,7 +153,7 @@ class LoRAMemoryPool: ...@@ -153,7 +153,7 @@ class LoRAMemoryPool:
self, self,
cur_uids: Set[Optional[str]], cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter], lora_adapters: Dict[str, LoRAAdapter],
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], lora_modules: List[Dict[str, BaseLayerWithLoRA]],
): ):
def get_available_buffer_slot(): def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch): for buffer_id in range(self.max_loras_per_batch):
...@@ -186,7 +186,7 @@ class LoRAMemoryPool: ...@@ -186,7 +186,7 @@ class LoRAMemoryPool:
uid: str, uid: str,
buffer_id: int, buffer_id: int,
lora_adapter: LoRAAdapter, lora_adapter: LoRAAdapter,
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]], lora_modules: List[Dict[str, BaseLayerWithLoRA]],
): ):
def load_lora_weight_tensor( def load_lora_weight_tensor(
buffer_view: torch.Tensor, weight: Optional[torch.Tensor] buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
......
...@@ -22,6 +22,7 @@ from dataclasses import dataclass, field ...@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.multimodal.mm_utils import has_valid_data from sglang.srt.multimodal.mm_utils import has_valid_data
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import SamplingParams
...@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput: ...@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
lora_name: str lora_name: str
# The path of loading. # The path of loading.
lora_path: str lora_path: str
# The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
def to_ref(self) -> LoRARef:
return LoRARef(
lora_id=self.lora_id,
lora_name=self.lora_name,
lora_path=self.lora_path,
)
@dataclass @dataclass
class UnloadLoRAAdapterReqInput: class UnloadLoRAAdapterReqInput:
# The name of lora module to unload. # The name of lora module to unload.
lora_name: str lora_name: str
# The unique identifier for the LoRA adapter, which is automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
def to_ref(self) -> LoRARef:
return LoRARef(
lora_id=self.lora_id,
lora_name=self.lora_name,
)
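A brief hedged illustration of the request-to-ref conversion introduced above (field values are made up; in practice `lora_id` is filled in by the `TokenizerManager` before the request reaches the scheduler):

req = LoadLoRAAdapterReqInput(
    lora_name="demo",
    lora_path="/adapters/demo",
    lora_id="a1b2c3d4",
)
ref = req.to_ref()  # LoRARef(lora_id=a1b2c3d4, lora_name=demo, lora_path=/adapters/demo)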
@dataclass @dataclass
class LoRAUpdateResult: class LoRAUpdateResult:
success: bool success: bool
error_message: Optional[str] = None error_message: Optional[str] = None
loaded_adapters: Dict[str, str] = field(default_factory=dict) loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
...@@ -247,7 +247,7 @@ class Scheduler( ...@@ -247,7 +247,7 @@ class Scheduler(
self.pp_size = server_args.pp_size self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy self.schedule_policy = server_args.schedule_policy
self.lora_paths = server_args.lora_paths self.enable_lora = server_args.enable_lora
self.max_loras_per_batch = server_args.max_loras_per_batch self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule self.enable_overlap = not server_args.disable_overlap_schedule
self.skip_tokenizer_init = server_args.skip_tokenizer_init self.skip_tokenizer_init = server_args.skip_tokenizer_init
...@@ -1706,13 +1706,13 @@ class Scheduler( ...@@ -1706,13 +1706,13 @@ class Scheduler(
self.chunked_req.init_next_round_input() self.chunked_req.init_next_round_input()
self.chunked_req = adder.add_chunked_req(self.chunked_req) self.chunked_req = adder.add_chunked_req(self.chunked_req)
if self.lora_paths: if self.enable_lora:
lora_set = set([req.lora_path for req in self.running_batch.reqs]) lora_set = set([req.lora_path for req in self.running_batch.reqs])
# Get requests from the waiting queue to a new prefill batch # Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue: for req in self.waiting_queue:
if ( if (
self.lora_paths self.enable_lora
and len( and len(
lora_set lora_set
| set([req.lora_path for req in adder.can_run_list]) | set([req.lora_path for req in adder.can_run_list])
...@@ -2466,12 +2466,6 @@ class Scheduler( ...@@ -2466,12 +2466,6 @@ class Scheduler(
"""In-place loading a new lora adapter from disk or huggingface.""" """In-place loading a new lora adapter from disk or huggingface."""
result = self.tp_worker.load_lora_adapter(recv_req) result = self.tp_worker.load_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert flush_cache_success, "Cache flush failed after loading lora adapter."
else:
logger.error(result.error_message)
return result return result
def unload_lora_adapter( def unload_lora_adapter(
...@@ -2480,14 +2474,6 @@ class Scheduler( ...@@ -2480,14 +2474,6 @@ class Scheduler(
"""Unload the lora adapter.""" """Unload the lora adapter."""
result = self.tp_worker.unload_lora_adapter(recv_req) result = self.tp_worker.unload_lora_adapter(recv_req)
if result.success:
flush_cache_success = self.flush_cache()
assert (
flush_cache_success
), "Cache flush failed after unloading LoRA weights"
else:
logger.error(result.error_message)
return result return result
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
......
...@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import ( ...@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
get_tokenizer, get_tokenizer,
get_tokenizer_from_processor, get_tokenizer_from_processor,
) )
from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchEmbeddingOut, BatchEmbeddingOut,
...@@ -242,11 +243,11 @@ class TokenizerManager: ...@@ -242,11 +243,11 @@ class TokenizerManager:
revision=server_args.revision, revision=server_args.revision,
) )
# Initialize loaded loRA adapters with the initial lora paths in the server_args. # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
# This list will be updated when new LoRA adapters are loaded or unloaded dynamically. # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
self.loaded_lora_adapters: Dict[str, str] = dict( # serves as the source of truth for available adapters and maps user-friendly LoRA names
self.server_args.lora_paths or {} # to internally used unique LoRA IDs.
) self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
# Store states # Store states
self.no_create_loop = False self.no_create_loop = False
...@@ -523,6 +524,10 @@ class TokenizerManager: ...@@ -523,6 +524,10 @@ class TokenizerManager:
else: else:
mm_inputs = None mm_inputs = None
if self.server_args.enable_lora and obj.lora_path:
# Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
self._validate_one_request(obj, input_ids) self._validate_one_request(obj, input_ids)
return self._create_tokenized_object( return self._create_tokenized_object(
obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
...@@ -574,8 +579,6 @@ class TokenizerManager: ...@@ -574,8 +579,6 @@ class TokenizerManager:
"The server is not configured to enable custom logit processor. " "The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature." "Please set `--enable-custom-logits-processor` to enable this feature."
) )
if self.server_args.enable_lora and obj.lora_path:
self._validate_lora_adapters(obj)
def _validate_input_ids_in_vocab( def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int self, input_ids: List[int], vocab_size: int
...@@ -689,21 +692,6 @@ class TokenizerManager: ...@@ -689,21 +692,6 @@ class TokenizerManager:
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`." "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
) )
def _validate_lora_adapters(self, obj: GenerateReqInput):
"""Validate that the requested LoRA adapters are loaded."""
requested_adapters = (
set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
)
loaded_adapters = (
self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
)
unloaded_adapters = requested_adapters - loaded_adapters
if unloaded_adapters:
raise ValueError(
f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
f"Loaded adapters: {loaded_adapters}."
)
def _send_one_request( def _send_one_request(
self, self,
obj: Union[GenerateReqInput, EmbeddingReqInput], obj: Union[GenerateReqInput, EmbeddingReqInput],
...@@ -1054,8 +1042,18 @@ class TokenizerManager: ...@@ -1054,8 +1042,18 @@ class TokenizerManager:
) )
async with self.model_update_lock.writer_lock: async with self.model_update_lock.writer_lock:
# Generate new uniquely identifiable LoRARef object.
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
)
# Register the new adapter in the registry.
obj.lora_id = new_adapter.lora_id
result = (await self.update_lora_adapter_communicator(obj))[0] result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters if result.success:
await self.lora_registry.register(new_adapter)
return result return result
async def unload_lora_adapter( async def unload_lora_adapter(
...@@ -1069,6 +1067,10 @@ class TokenizerManager: ...@@ -1069,6 +1067,10 @@ class TokenizerManager:
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA." "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
) )
assert (
obj.lora_name is not None
), "lora_name must be provided to unload LoRA adapter"
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1. # with dp_size > 1.
assert ( assert (
...@@ -1080,8 +1082,9 @@ class TokenizerManager: ...@@ -1080,8 +1082,9 @@ class TokenizerManager:
) )
async with self.model_update_lock.writer_lock: async with self.model_update_lock.writer_lock:
obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
result = (await self.update_lora_adapter_communicator(obj))[0] result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
return result return result
async def get_weights_by_name( async def get_weights_by_name(
...@@ -1309,7 +1312,7 @@ class TokenizerManager: ...@@ -1309,7 +1312,7 @@ class TokenizerManager:
filename = os.path.join( filename = os.path.join(
self.crash_dump_folder, self.crash_dump_folder,
os.getenv("HOSTNAME", None), os.getenv("HOSTNAME", None),
f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl', f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
) )
os.makedirs(os.path.dirname(filename), exist_ok=True) os.makedirs(os.path.dirname(filename), exist_ok=True)
......
...@@ -293,11 +293,9 @@ class TpModelWorker: ...@@ -293,11 +293,9 @@ class TpModelWorker:
return parameter return parameter
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput): def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
result = self.model_runner.load_lora_adapter( result = self.model_runner.load_lora_adapter(recv_req.to_ref())
recv_req.lora_name, recv_req.lora_path
)
return result return result
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput): def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.lora_name) result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result return result
...@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler ...@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.lora.lora_manager import LoRAManager from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.schedule_batch import ( from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS, GLOBAL_SERVER_ARGS_KEYS,
global_server_args_dict, global_server_args_dict,
...@@ -890,44 +891,38 @@ class ModelRunner: ...@@ -890,44 +891,38 @@ class ModelRunner:
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
max_lora_rank=self.server_args.max_lora_rank, max_lora_rank=self.server_args.max_lora_rank,
target_modules=self.server_args.lora_target_modules, target_modules=self.server_args.lora_target_modules,
lora_paths=self.server_args.lora_paths,
) )
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {})
if result.success:
logger.info(
f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
)
else:
raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
def load_lora_adapter(self, lora_name: str, lora_path: str): def load_lora_adapter(self, lora_ref: LoRARef):
"""Load a new lora adapter from disk or huggingface.""" """Load a new lora adapter from disk or huggingface."""
logger.info( logger.info(
f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. " f"LoRA adapter loading starts: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
) )
result = self.lora_manager.load_lora_adapter(lora_name, lora_path) result = self.lora_manager.load_lora_adapter(lora_ref)
logger.info( logger.info(
f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. " f"LoRA adapter loading completes: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
) )
return result return result
def unload_lora_adapter(self, lora_name: str): def unload_lora_adapter(self, lora_ref: LoRARef):
"""Unload a lora adapter that was previously loaded during initialization or dynamic loading.""" """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
logger.info( logger.info(
f"LoRA adapter unloading starts: name={lora_name}. " f"LoRA adapter unloading starts: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
) )
result = self.lora_manager.unload_lora_adapter(lora_name) result = self.lora_manager.unload_lora_adapter(lora_ref)
logger.info( logger.info(
f"LoRA adapter unloading completes: name={lora_name}. " f"LoRA adapter unloading completes: {lora_ref}. "
f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
) )
......
...@@ -20,10 +20,10 @@ import logging ...@@ -20,10 +20,10 @@ import logging
import os import os
import random import random
import tempfile import tempfile
from token import OP
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import ( from sglang.srt.utils import (
LORA_TARGET_ALL_MODULES, LORA_TARGET_ALL_MODULES,
...@@ -145,7 +145,7 @@ class ServerArgs: ...@@ -145,7 +145,7 @@ class ServerArgs:
enable_lora: Optional[bool] = None enable_lora: Optional[bool] = None
max_lora_rank: Optional[int] = None max_lora_rank: Optional[int] = None
lora_target_modules: Optional[Union[set[str], List[str]]] = None lora_target_modules: Optional[Union[set[str], List[str]]] = None
lora_paths: Optional[Union[dict[str, str], List[str]]] = None lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
max_loras_per_batch: int = 8 max_loras_per_batch: int = 8
lora_backend: str = "triton" lora_backend: str = "triton"
...@@ -1843,9 +1843,24 @@ class ServerArgs: ...@@ -1843,9 +1843,24 @@ class ServerArgs:
for lora_path in lora_paths: for lora_path in lora_paths:
if "=" in lora_path: if "=" in lora_path:
name, path = lora_path.split("=", 1) name, path = lora_path.split("=", 1)
self.lora_paths[name] = path self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
else: else:
self.lora_paths[lora_path] = lora_path self.lora_paths[lora_path] = LoRARef(
lora_name=lora_path,
lora_path=lora_path,
)
elif isinstance(self.lora_paths, dict):
self.lora_paths = {
k: LoRARef(lora_name=k, lora_path=v)
for k, v in self.lora_paths.items()
}
elif self.lora_paths is None:
self.lora_paths = {}
else:
raise ValueError(
f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
"Expected a list or a dictionary."
)
# Expand target modules # Expand target modules
if self.lora_target_modules: if self.lora_target_modules:
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
import contextlib
import multiprocessing as mp import multiprocessing as mp
import unittest import unittest
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
...@@ -39,6 +40,16 @@ ADAPTERS = [ ...@@ -39,6 +40,16 @@ ADAPTERS = [
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct" BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
@contextlib.contextmanager
def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str):
"""A context manager to load and automatically unload a LoRA adapter."""
try:
runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path)
yield
finally:
runner.unload_lora_adapter(lora_name=lora_name)
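For reference, a hedged sketch of how the context manager above is used later in the test (the runner, path, and prompt are illustrative):

with dynamically_loaded_adapter(srt_runner, lora_path="/adapters/demo", lora_name="lora"):
    out = srt_runner.forward(["Hello"], max_new_tokens=16, lora_paths=["lora"])
# On exit the adapter is unloaded, freeing the reused name for the next iteration.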
class TestLoRAEviction(CustomTestCase): class TestLoRAEviction(CustomTestCase):
def test_lora_eviction_with_different_target_modules(self): def test_lora_eviction_with_different_target_modules(self):
""" """
...@@ -51,55 +62,80 @@ class TestLoRAEviction(CustomTestCase): ...@@ -51,55 +62,80 @@ class TestLoRAEviction(CustomTestCase):
self._run_test(ADAPTERS, output_history, reverse=False) self._run_test(ADAPTERS, output_history, reverse=False)
self._run_test(ADAPTERS, output_history, reverse=True) self._run_test(ADAPTERS, output_history, reverse=True)
def test_lora_eviction_with_reused_lora_name(self):
"""
Test LoRA eviction with reused LoRA names.
This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior
works correctly when reusing LoRA names.
"""
output_history = {}
self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1)
self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1)
def _run_test( def _run_test(
self, self,
lora_paths: List[str], lora_paths: List[str],
output_history: Dict[Tuple[str, str], str], output_history: Dict[Tuple[str, str], str],
reverse: bool, reverse: bool = False,
repeat: int = 2, repeat: int = 2,
reuse_lora_name: bool = False,
): ):
REUSED_LORA_NAME = "lora"
max_new_tokens = 256 max_new_tokens = 256
backend = "triton" backend = "triton"
torch_dtype = torch.float16 torch_dtype = torch.float16
base_path = BASE_MODEL base_path = BASE_MODEL
assert len(lora_paths) >= 2 assert len(lora_paths) >= 2
initial_lora_paths = lora_paths if not reuse_lora_name else None
# Initialize runners # Initialize runners
with SRTRunner( with SRTRunner(
base_path, base_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
model_type="generation", model_type="generation",
lora_paths=lora_paths, lora_paths=initial_lora_paths,
max_loras_per_batch=1, max_loras_per_batch=1,
lora_backend=backend, lora_backend=backend,
disable_radix_cache=True, disable_radix_cache=True,
enable_lora=True,
max_lora_rank=256,
lora_target_modules=["all"],
) as srt_runner: ) as srt_runner:
adapter_sequence = lora_paths if not reverse else lora_paths[::-1] adapter_sequence = lora_paths if not reverse else lora_paths[::-1]
for i in range(repeat): for i in range(repeat):
for j, adapter in enumerate(adapter_sequence): for j, lora_path in enumerate(adapter_sequence):
print( print(
f"\n========== Testing LoRA eviction with adapter '{adapter}' (#{j+1}/{len(adapter_sequence)}), reversed: {reverse}, repeat: {i+1}/{repeat} ---" f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---"
)
lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path
context = (
dynamically_loaded_adapter(srt_runner, lora_path, lora_name)
if reuse_lora_name
else contextlib.nullcontext()
) )
for prompt in PROMPTS: with context:
print("\nprompt:\n", prompt) for prompt in PROMPTS:
srt_outputs = srt_runner.forward( print("\nprompt:\n", prompt)
[prompt], srt_outputs = srt_runner.forward(
max_new_tokens=max_new_tokens, [prompt],
lora_paths=[adapter], max_new_tokens=max_new_tokens,
) lora_paths=[lora_name],
output = srt_outputs.output_strs[0].strip()
print("\noutput:\n", output)
prev_output = output_history.get((adapter, prompt))
if prev_output is not None:
self.assertEqual(
prev_output,
output,
f"Output mismatch for adapter {adapter} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
) )
else: output = srt_outputs.output_strs[0].strip()
output_history[(adapter, prompt)] = output print("\noutput:\n", output)
prev_output = output_history.get((lora_path, prompt))
if prev_output is not None:
self.assertEqual(
prev_output,
output,
f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
)
else:
output_history[(lora_path, prompt)] = output
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -14,7 +14,7 @@ class TestFile: ...@@ -14,7 +14,7 @@ class TestFile:
suites = { suites = {
"per-commit": [ "per-commit": [
TestFile("models/lora/test_lora.py", 200), TestFile("models/lora/test_lora.py", 200),
TestFile("models/lora/test_lora_eviction.py", 120), TestFile("models/lora/test_lora_eviction.py", 200),
TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/lora/test_lora_cuda_graph.py", 250), TestFile("models/lora/test_lora_cuda_graph.py", 250),
......