"torchvision/transforms/v2/__init__.py" did not exist on "aea748b3ef0387b0b6ac1efa8aee07bb109a9561"
Unverified commit e2ed9d04 authored by Lifu Huang, committed by GitHub

Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)

parent b5dd5e87
......@@ -33,6 +33,10 @@
"\n",
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
"\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n",
"\n",
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
"\n",
"From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to."
......@@ -176,6 +180,241 @@
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dynamic LoRA loading"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Basic Usage\n",
"\n",
"Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n",
"\n",
"(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" \"\"\"\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
" },\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora2\",\n",
" \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora1\", \"lora2\"],\n",
" },\n",
")\n",
"print(f\"Output from lora1: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora2: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Advanced: hosting adapters of different shapes\n",
"\n",
"In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n",
"\n",
"For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"\n",
"\n",
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
"# We are adding it here just to demonstrate usage.\n",
"server_process, port = launch_server_cmd(\n",
" f\"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0={lora0} \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" --max-lora-rank 64\n",
" --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n",
" \"\"\"\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
......
......@@ -167,6 +167,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--lora-paths` | The list of LoRA adapters. You can provide either plain paths or renamed paths in the format `{name}={path}`. | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None |
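For example, to start a server that can later accept dynamically loaded adapters with rank up to 64 on any of the listed projections, both flags can be passed at startup. A sketch reusing the `launch_server_cmd`/`wait_for_server` helpers from the notebook above (their import path is assumed to be `sglang.utils`; the model and adapter are illustrative):

```python
from sglang.utils import launch_server_cmd, wait_for_server  # assumed helper location

server_process, port = launch_server_cmd(
    """
    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
        --lora-paths lora0=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \
        --max-loras-per-batch 2 --lora-backend triton \
        --max-lora-rank 64 \
        --lora-target-modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj
    """
)
wait_for_server(f"http://127.0.0.1:{port}")
```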
## Kernel backend
......
......@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
from typing import Dict, Set, Tuple
from typing import Dict, Iterable, Optional, Set, Tuple
import torch
......@@ -53,6 +53,8 @@ class LoRAManager:
lora_backend: str = "triton",
tp_size: int = 1,
tp_rank: int = 0,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
):
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
......@@ -62,6 +64,10 @@ class LoRAManager:
self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: Optional[int] = max_lora_rank
self.target_modules: Optional[Set[str]] = (
set(target_modules) if target_modules else None
)
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
......@@ -153,7 +159,9 @@ class LoRAManager:
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
try:
self.configs[lora_name] = LoRAConfig(lora_path)
new_adapter = LoRAConfig(lora_path)
self.validate_new_adapter(lora_name, new_adapter)
self.configs[lora_name] = new_adapter
except Exception as e:
success = False
error_message = (
......@@ -168,6 +176,21 @@ class LoRAManager:
error_message=error_message,
)
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
"""
        Validate whether an adapter can be loaded into the current LoRA memory pool, raising an error if it is incompatible.
"""
incompatible = self.memory_pool and not self.memory_pool.can_support(
lora_config
)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration."
"We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
"You can specify expected configs via --max_lora_rank and --enable_lora_modules."
)
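From the client's perspective, a rejected `/load_lora_adapter` call surfaces this message with a 400 status (as the server-mode test below asserts). A hedged sketch against a hypothetical server whose memory pool was sized for rank 32:

```python
import requests

url = "http://127.0.0.1:30000"  # placeholder for a server started with --max-lora-rank 32

# Per the test comments below, this adapter has rank 256, which exceeds the pool's limit.
response = requests.post(
    url + "/load_lora_adapter",
    json={
        "lora_name": "sql_lora",
        "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
    },
)
if response.status_code != 200:
    # The error text points at --max-lora-rank / --lora-target-modules.
    print("Load rejected:", response.text)
```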
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
......@@ -214,7 +237,7 @@ class LoRAManager:
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
if lora_path is not None:
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling
# Use pinned memory to avoid synchronizations during host-to-device transfer
......@@ -319,7 +342,7 @@ class LoRAManager:
)
else:
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.memory_pool.get_tensor(
......@@ -351,58 +374,67 @@ class LoRAManager:
i: {} for i in range(self.base_hf_config.num_hidden_layers)
}
# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config,
self.max_loras_per_batch,
self.dtype,
self.tp_size,
self.tp_rank,
)
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
# It is initialized lazily when the first LoRA adapter is loaded.
self.memory_pool: Optional[LoRAMemoryPool] = None
def update_state_from_configs(self):
"""
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
This includes:
- Initializing LoRA adapters if they are not already loaded.
        - Collecting all LoRA weight names based on the currently loaded adapters.
- Lazily monkey-patching the base model to use LoRA layers where applicable.
- Preparing the GPU buffer pool for active LoRA weights.
"""
# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
hf_target_module_names: Set[str] = set()
for config in self.configs.values():
hf_target_module_names.update(config.target_modules)
max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
# Loads / unloads LoRA adapters based on the latest configs.
self.update_lora_adapters()
# Apply the latest LoRA configurations to the internal state for inferencing.
self.apply_lora_configs()
def apply_lora_configs(self):
"""
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
Notes:
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
          LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completion of this
          work in early CY25H2.
"""
if self.memory_pool is None:
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
if self.target_modules is None:
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)
if self.max_lora_rank is None:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
self.update_lora_weight_names()
self.update_lora_modules()
self.update_memory_buffers()
else:
# No-op if the memory pool can support the current LoRA configurations.
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
# module is changed once FlashInfer backend is deprecated.
assert self.memory_pool.can_support(self.configs.values()), (
"LoRA memory pool cannot support the current LoRA configuration. "
"This should never happen as we should have validated adapter compatibility. "
"Please create a Github issue to report.",
)
# Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
#
# Please note that the following update operations are "monotonic" by design, meaning that we update
# multiple places to support the new weight names when the first adapter targeting such weight names
# is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
# even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
# list of LoRA weight names is expected to be extremely finite and stable.
self.update_lora_weight_names(hf_target_module_names)
self.update_lora_modules(hf_target_module_names)
self.update_memory_buffers(max_lora_dim)
def update_lora_weight_names(self, hf_target_names: Set[str]):
def update_lora_weight_names(self):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""
# Target lora weight names for lora_a and lora_b modules respectively.
for module in hf_target_names:
lora_A, lora_B = get_normalized_lora_weight_names(module)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
def update_lora_adapters(self):
"""
......@@ -434,21 +466,23 @@ class LoRAManager:
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
if self.lora_backend == "flashinfer":
lora_dims = set(x.hf_config["r"] for x in self.configs.values())
lora_dims = set(x.r for x in self.configs.values())
scalings = set(x.scaling for x in self.loras.values())
assert (
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def update_memory_buffers(self, max_lora_dim: int):
"""
Update the LoRA memory pool buffers based on the current LoRA configurations and update
LoRA modules to use the new buffers. This method should be called after the LoRA configurations
are set or updated.
"""
self.memory_pool.init_buffers(
self.lora_weight_names, self.base_model, max_lora_dim
def update_memory_buffers(self):
"""(Re)initialize the LoRA memory pool based on the current configurations."""
self.memory_pool = LoRAMemoryPool(
base_hf_config=self.base_hf_config,
max_loras_per_batch=self.max_loras_per_batch,
dtype=self.dtype,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
max_lora_rank=self.max_lora_rank,
lora_weight_names=self.lora_weight_names,
base_model=self.base_model,
)
def set_lora_module(self, module_name, module):
......@@ -456,11 +490,11 @@ class LoRAManager:
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def update_lora_modules(self, hf_target_names: Set[str]):
def update_lora_modules(self):
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
hf_target_names, self.base_model
self.target_modules, self.base_model
)
for module_name, module in self.base_model.named_modules():
......
from typing import Callable, Dict, List, Optional, Set, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
......@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType,
get_hidden_dim,
get_normalized_lora_weight_names,
get_stacked_multiply,
get_weight_name,
)
......@@ -25,6 +27,9 @@ class LoRAMemoryPool:
dtype: torch.dtype,
tp_size: int,
tp_rank: int,
max_lora_rank: int,
lora_weight_names: Tuple[Set[str], Set[str]],
base_model: torch.nn.Module,
):
self.base_hf_config: AutoConfig = base_hf_config
self.num_layer: int = base_hf_config.num_hidden_layers
......@@ -32,6 +37,10 @@ class LoRAMemoryPool:
self.dtype: torch.dtype = dtype
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: int = max_lora_rank
# lora weight names for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
......@@ -49,6 +58,31 @@ class LoRAMemoryPool:
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
self.init_buffers(base_model)
def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
"""
Check if the memory pool can support the given LoRA adapters.
"""
def _can_support(config: LoRAConfig) -> bool:
"""
Check if the memory pool can support a single LoRA adapter.
"""
if config.r > self.max_lora_rank:
return False
weights_a, weights_b = get_normalized_lora_weight_names(
config.target_modules
)
return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
self.lora_weight_names[1]
)
if isinstance(config, LoRAConfig):
return _can_support(config)
else:
return all(_can_support(x) for x in config)
def get_lora_A_shape(
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
) -> Tuple[int]:
......@@ -82,25 +116,18 @@ class LoRAMemoryPool:
max_lora_dim,
)
def init_buffers(
self,
lora_weight_names: Tuple[Set[str]],
base_model: torch.nn.Module,
max_lora_dim: int,
):
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
def init_buffers(self, base_model: torch.nn.Module):
device = next(base_model.parameters()).device
def update_buffer(
def init_buffer(
buffer: Dict[str, List[torch.Tensor]],
lora_weight_names: Set[str],
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
):
new_weight_names = lora_weight_names - buffer.keys()
for module_name in new_weight_names:
lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
for module_name in lora_weight_names:
lora_shape = get_lora_shape_fn(
module_name, base_model, self.max_lora_rank
)
buffer[module_name] = [
torch.empty(
lora_shape,
......@@ -110,15 +137,15 @@ class LoRAMemoryPool:
for _ in range(self.num_layer)
]
update_buffer(
init_buffer(
self.A_buffer,
lora_weight_names[0],
self.lora_weight_names[0],
self.get_lora_A_shape,
)
update_buffer(
init_buffer(
self.B_buffer,
lora_weight_names[1],
self.lora_weight_names[1],
self.get_lora_B_shape,
)
......
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple
import torch
......@@ -106,9 +106,11 @@ def get_hidden_dim(
raise NotImplementedError()
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
def get_normalized_lora_weight_names(
target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]:
"""
Mapping a target module name to names of the normalized LoRA weights.
    Map a list of target module names to the names of the normalized LoRA weights.
    The returned tuple contains (names for LoRA A, names for LoRA B).
"""
params_mapping = {
......@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
}
stacked = params_mapping.get(name, ([name], [name]))
return stacked
result = (set(), set())
for name in target_modules:
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
result[0].update(lora_a)
result[1].update(lora_b)
return result
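For example, with the stacked-projection entries shown above, the helper unions per-module results into one pair of sets. The exact values below assume the elided part of `params_mapping` folds `q_proj`/`k_proj`/`v_proj` into the stacked `qkv_proj` weight, so treat them as illustrative:

```python
lora_a_names, lora_b_names = get_normalized_lora_weight_names(
    ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"]
)
# Illustrative result under the assumption above:
#   lora_a_names == {"qkv_proj", "o_proj", "gate_up_proj"}
#   lora_b_names == {"q_proj", "kv_proj", "o_proj", "gate_up_proj"}
```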
def get_stacked_multiply(module_name: str) -> int:
......
......@@ -891,6 +891,8 @@ class ModelRunner:
lora_backend=self.server_args.lora_backend,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
max_lora_rank=self.server_args.max_lora_rank,
target_modules=self.server_args.lora_target_modules,
)
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
if result.success:
......
......@@ -134,6 +134,8 @@ class ServerArgs:
preferred_sampling_params: Optional[str] = None
# LoRA
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[List[str]] = None
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
......@@ -1129,6 +1131,28 @@ class ServerArgs:
)
# LoRA
parser.add_argument(
"--max-lora-rank",
default=ServerArgs.max_lora_rank,
type=int,
help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
)
parser.add_argument(
"--lora-target-modules",
type=str,
choices=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
nargs="*",
default=None,
help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
)
parser.add_argument(
"--lora-paths",
type=str,
......
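The same two options can be supplied programmatically as well. A minimal sketch assuming the offline `sglang.Engine` entry point forwards keyword arguments into `ServerArgs` (as the `SRTRunner` change below does); model and adapter paths are illustrative:

```python
import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    lora_paths=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
    max_loras_per_batch=2,
    max_lora_rank=64,
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
# ... generate with per-request lora_path as usual, then clean up.
llm.shutdown()
```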
......@@ -505,6 +505,8 @@ class SRTRunner:
torchao_config: Optional[str] = None,
cuda_graph_max_bs: int = 4,
sleep_on_idle=False,
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
......@@ -543,6 +545,8 @@ class SRTRunner:
cuda_graph_max_bs=cuda_graph_max_bs,
disable_custom_all_reduce=disable_custom_all_reduce,
sleep_on_idle=sleep_on_idle,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
**spec_kwargs,
)
......
......@@ -16,7 +16,7 @@ import multiprocessing as mp
import unittest
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Union
from typing import Any, Iterable, List, Optional, Union
import requests
import torch
......@@ -27,6 +27,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
......@@ -45,24 +46,28 @@ class OperationType(Enum):
LOAD = "load"
UNLOAD = "unload"
FORWARD = "forward"
EXPECT_ERROR = "expect_error"
@dataclass
class Operation:
# Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
    # Operation type, can be LOAD, UNLOAD, or FORWARD
type: OperationType
# Data associated with the operation. Exact type varies depending on the operation
data: Optional[Any]
# If the operation is expected to fail, this is the error message to expect
expected_error: Optional[str] = None
@dataclass
class TestCase:
description: str
base: str
max_loras_per_batch: int
all_adapters: List[str]
initial_adapters: List[str]
op_sequence: List[Operation]
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[List] = None
max_new_tokens: int = 32
......@@ -72,9 +77,9 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
TEST_CASES = [
# basic test, no eviction
BASIC_TESTS = [
TestCase(
description="dynamic lora update with initial lora_paths",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
all_adapters=[
......@@ -89,20 +94,16 @@ TEST_CASES = [
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
"not loaded",
),
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
......@@ -127,11 +128,9 @@ TEST_CASES = [
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
"not loaded",
),
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
......@@ -147,13 +146,11 @@ TEST_CASES = [
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
......@@ -174,8 +171,8 @@ TEST_CASES = [
),
],
),
# Eviction
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=1,
all_adapters=[
......@@ -190,20 +187,16 @@ TEST_CASES = [
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
"not loaded",
),
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
......@@ -214,11 +207,9 @@ TEST_CASES = [
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
"not loaded",
),
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
......@@ -263,6 +254,253 @@ TEST_CASES = [
],
),
]
TARGET_MODULE_TESTS = [
TestCase(
description="Test explicitly specified lora-target-modules.",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
lora_target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="Test inferred lora-target-modules - start with larger adapter",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="Test inferred lora-target-modules - start with smaller adapter",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
None,
]
),
),
],
),
]
MAX_LORA_RANK_TESTS = [
TestCase(
description="Test explicitly specified max-lora-rank.",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=32,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
],
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"pbevan11/llama-3.1-8b-ocr-correction",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"pbevan11/llama-3.1-8b-ocr-correction",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="test implicitly inferred max-lora-rank",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
],
initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
None,
]
),
),
],
),
]
ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
class LoRAUpdateTestSessionMode(Enum):
......@@ -281,7 +519,9 @@ class LoRAUpdateTestSessionBase:
testcase: Optional[TestCase],
model_path: str,
lora_paths: list[str],
max_loras_per_batch: int = 1,
max_loras_per_batch: int,
max_lora_rank: Optional[int],
lora_target_modules: Optional[List[str]] = None,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
cuda_graph_max_bs: int = 4,
......@@ -289,6 +529,8 @@ class LoRAUpdateTestSessionBase:
self.testcase = testcase
self.model_path = model_path
self.lora_paths = lora_paths
self.max_lora_rank = max_lora_rank
self.lora_target_modules = lora_target_modules
self.max_loras_per_batch = max_loras_per_batch
self.lora_backend = lora_backend
self.disable_cuda_graph = disable_cuda_graph
......@@ -304,7 +546,12 @@ class LoRAUpdateTestSessionBase:
# Don't suppress exceptions by default
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
"""
Load a LoRA adapter by name and path.
"""
......@@ -321,6 +568,7 @@ class LoRAUpdateTestSessionBase:
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: Optional[str] = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
......@@ -339,6 +587,8 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
model_path=self.model_path,
model_type="generation",
lora_paths=self.lora_paths,
max_lora_rank=self.max_lora_rank,
lora_target_modules=self.lora_target_modules,
lora_backend=self.lora_backend,
torch_dtype=torch.float16,
mem_fraction_static=MEM_FRACTION_STATIC,
......@@ -357,24 +607,32 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
# don't suppress exceptions
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
"""
Load a LoRA adapter by name and path.
"""
if lora_path is None:
lora_path = lora_name
self.expected_adapters.add(lora_name)
response = self.handle.load_lora_adapter(
lora_name=lora_name,
lora_path=lora_path,
)
self.testcase.assertTrue(response.success)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
if expected_error:
self.testcase.assertFalse(response.success)
self.testcase.assertIn(expected_error, response.error_message)
print(f"Received error as expected: {response.error_message}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.success)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
def unload_lora_adapter(self, lora_name: str):
"""
......@@ -396,7 +654,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: str = None,
expected_error: Optional[str] = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
......@@ -448,6 +706,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
]
if self.disable_cuda_graph:
other_args.append("--disable-cuda-graph")
if self.max_lora_rank is not None:
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
if self.lora_target_modules is not None:
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
# launch external server
self.handle = popen_launch_server(
......@@ -464,24 +726,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
# don't suppress exceptions
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
"""
Load a LoRA adapter by name and path.
"""
if lora_path is None:
lora_path = lora_name
self.expected_adapters.add(lora_name)
response = requests.post(
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
json={"lora_name": lora_name, "lora_path": lora_path},
)
self.testcase.assertTrue(response.ok)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
print(f"Received error as expected: {response.text}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.ok)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
def unload_lora_adapter(self, lora_name: str):
"""
......@@ -504,7 +774,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: str = None,
expected_error: Optional[str] = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
......@@ -537,30 +807,14 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
# Factory function to create the appropriate LoRA test session based on mode
def LoRAUpdateTestSession(
*,
testcase: Optional[TestCase],
mode: LoRAUpdateTestSessionMode,
model_path: str,
lora_paths: list[str],
max_loras_per_batch: int = 1,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
cuda_graph_max_bs: int = 4,
**kwargs: Any,
):
common_kwargs = {
"testcase": testcase,
"model_path": model_path,
"lora_paths": lora_paths,
"max_loras_per_batch": max_loras_per_batch,
"lora_backend": lora_backend,
"disable_cuda_graph": disable_cuda_graph,
"cuda_graph_max_bs": cuda_graph_max_bs,
}
if mode == LoRAUpdateTestSessionMode.ENGINE:
return LoRAUpdateEngineTestSession(**common_kwargs)
return LoRAUpdateEngineTestSession(testcase=testcase, **kwargs)
elif mode == LoRAUpdateTestSessionMode.SERVER:
return LoRAUpdateServerTestSession(**common_kwargs)
return LoRAUpdateServerTestSession(testcase=testcase, **kwargs)
else:
raise ValueError(f"Unrecognized mode: {mode!r}")
......@@ -582,6 +836,8 @@ class TestLoRADynamicUpdate(CustomTestCase):
initial_adapters: List[str],
max_loras_per_batch: int,
op_sequence: List[Operation],
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
max_new_tokens: int = 32,
) -> List[tuple]:
"""
......@@ -596,10 +852,13 @@ class TestLoRADynamicUpdate(CustomTestCase):
model_path=base,
lora_paths=initial_adapters,
max_loras_per_batch=max_loras_per_batch,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
) as session:
for op in op_sequence:
op_type = op.type
data = op.data
expected_error = op.expected_error
print("-" * 100)
print(
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
......@@ -608,6 +867,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
result = session.load_lora_adapter(
lora_name=data,
lora_path=data,
expected_error=expected_error,
)
elif op_type == OperationType.UNLOAD:
result = session.unload_lora_adapter(
......@@ -615,91 +875,105 @@ class TestLoRADynamicUpdate(CustomTestCase):
)
elif op_type == OperationType.FORWARD:
prompts, adapters = zip(*data)
result = session.forward(
prompts=list(prompts),
lora_paths=list(adapters),
max_new_tokens=max_new_tokens,
)
forward_outputs.append(result)
elif op_type == OperationType.EXPECT_ERROR:
input_data, expected_error = data
prompts, adapters = zip(*input_data)
result = session.forward(
prompts=list(prompts),
lora_paths=list(adapters),
max_new_tokens=max_new_tokens,
expected_error=expected_error,
)
if not expected_error:
forward_outputs.append(result)
return forward_outputs
def test_dynamic_adapter_updates(self):
for case_idx, test_case in enumerate(TEST_CASES, start=1):
for mode in [
LoRAUpdateTestSessionMode.ENGINE,
LoRAUpdateTestSessionMode.SERVER,
]:
print("=" * 100)
print(f"Starting test case {case_idx} in {mode.value} mode.")
print("=" * 100)
print(
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
)
# Test dynamic loading of adapters
# TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora,
# we should fix this in the future https://github.com/sgl-project/sglang/issues/7463.
dynamic_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.initial_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=test_case.op_sequence,
max_new_tokens=test_case.max_new_tokens,
)
def _run_dynamic_adapter_updates(
self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase]
):
for case_idx, test_case in enumerate(test_cases, start=1):
print("=" * 100)
print(
f"Starting test case {case_idx} in {mode.value} mode. Test description: {test_case.description}"
)
print("=" * 100)
# static loading
forward_ops = [
x for x in test_case.op_sequence if x.type == OperationType.FORWARD
]
print(
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
)
# Test dynamic loading of adapters
dynamic_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.initial_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=test_case.op_sequence,
max_new_tokens=test_case.max_new_tokens,
max_lora_rank=test_case.max_lora_rank,
lora_target_modules=test_case.lora_target_modules,
)
print("=" * 100)
print(
f"\n--- Running static pass with {len(forward_ops)} operations ---"
)
static_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.all_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=forward_ops,
max_new_tokens=test_case.max_new_tokens,
)
# static loading
forward_ops = [
x
for x in test_case.op_sequence
if x.type == OperationType.FORWARD and x.expected_error is None
]
print("=" * 100)
print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
static_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.all_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=forward_ops,
max_new_tokens=test_case.max_new_tokens,
)
print(f"Dynamic output: {dynamic_output}")
print(f"Static output: {static_output}")
print("=" * 100)
print(f"Dynamic output: {dynamic_output}")
print(f"Static output: {static_output}")
print("=" * 100)
self.assertEqual(
len(dynamic_output),
len(static_output),
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
)
for i, (dynamic, static) in enumerate(
zip(dynamic_output, static_output), start=1
):
self.assertEqual(
len(dynamic_output),
len(static_output),
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
len(dynamic),
len(static),
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
)
for i, (dynamic, static) in enumerate(
zip(dynamic_output, static_output), start=1
):
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
d_out = d_out.strip()
s_out = s_out.strip()
self.assertEqual(
len(dynamic),
len(static),
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
d_out,
s_out,
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
)
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
d_out = d_out.strip()
s_out = s_out.strip()
self.assertEqual(
d_out,
s_out,
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
)
def test_dynamic_lora_update_engine(self):
"""
Test dynamic LoRA updates in engine mode.
"""
test_cases = ALL_TESTS
self._run_dynamic_adapter_updates(
mode=LoRAUpdateTestSessionMode.ENGINE,
test_cases=test_cases,
)
def test_dynamic_lora_update_server(self):
"""
Test dynamic LoRA updates in server mode.
"""
        # In CI, we only run the basic test cases to save time, as the engine-mode test should be mostly sufficient for ensuring correctness.
test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
self._run_dynamic_adapter_updates(
mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
)
if __name__ == "__main__":
......
......@@ -17,7 +17,7 @@ suites = {
TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/lora/test_lora_cuda_graph.py", 250),
TestFile("models/lora/test_lora_update.py", 400),
TestFile("models/lora/test_lora_update.py", 700),
TestFile("models/test_embedding_models.py", 73),
# TestFile("models/test_clip_models.py", 52),
TestFile("models/test_encoder_embedding_models.py", 100),
......