"torchvision/vscode:/vscode.git/clone" did not exist on "517f6d3bd9ff12a3e02f39b816edce7346bc1075"
Unverified Commit e2ed9d04 authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)

parent b5dd5e87
...@@ -33,6 +33,10 @@
"\n",
"* `lora_backend`: The backend for running GEMM kernels for LoRA modules. It can be one of `triton` or `flashinfer`, and is set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backends built upon Cutlass or CUDA kernels will be added.\n",
"\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n",
"\n",
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n", "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
"\n", "\n",
"From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to." "From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to."
...@@ -176,6 +180,241 @@ ...@@ -176,6 +180,241 @@
"terminate_process(server_process)" "terminate_process(server_process)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dynamic LoRA loading"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Basic Usage\n",
"\n",
"Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n",
"\n",
"(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" \"\"\"\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
" },\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora2\",\n",
" \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora1\", \"lora2\"],\n",
" },\n",
")\n",
"print(f\"Output from lora1: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora2: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Advanced: hosting adapters of different shapes\n",
"\n",
"In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n",
"\n",
"For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"\n",
"\n",
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
"# We are adding it here just to demonstrate usage.\n",
"server_process, port = launch_server_cmd(\n",
" f\"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0={lora0} \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" --max-lora-rank 64\n",
" --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n",
" \"\"\"\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
......
...@@ -167,6 +167,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either plain paths (str) or renamed paths in the format {name}={path}. | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None |
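For illustration, a minimal Python sketch of a launch that reserves headroom for adapters loaded dynamically later (it mirrors the notebook cells in this commit and assumes the same `launch_server_cmd`/`wait_for_server` helpers from `sglang.utils`; the adapter name and flag values are just examples):

```python
# Minimal sketch, assuming the docs helpers used in the notebook above.
from sglang.utils import launch_server_cmd, wait_for_server

# Start with a small rank-4 adapter, but size the LoRA memory pool for
# rank-64 adapters and a wider set of target modules so that larger
# adapters can be loaded later via /load_lora_adapter.
server_process, port = launch_server_cmd(
    """
    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
        --lora-paths lora0=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \
        --max-loras-per-batch 2 --lora-backend triton \
        --max-lora-rank 64 \
        --lora-target-modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj
    """
)
wait_for_server(f"http://127.0.0.1:{port}")
```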
## Kernel backend
......
...@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
-from typing import Dict, Set, Tuple
from typing import Dict, Iterable, Optional, Set, Tuple

import torch
...@@ -53,6 +53,8 @@ class LoRAManager:
        lora_backend: str = "triton",
        tp_size: int = 1,
        tp_rank: int = 0,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
    ):
        self.base_model: torch.nn.Module = base_model
        self.base_hf_config: AutoConfig = base_hf_config
...@@ -62,6 +64,10 @@ class LoRAManager:
        self.device: torch.device = next(self.base_model.parameters()).device
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank
        self.max_lora_rank: Optional[int] = max_lora_rank
        self.target_modules: Optional[Set[str]] = (
            set(target_modules) if target_modules else None
        )

        # LoRA backend for running sgemm kernels
        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
...@@ -153,7 +159,9 @@ class LoRAManager:
        error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."

        try:
-            self.configs[lora_name] = LoRAConfig(lora_path)
            new_adapter = LoRAConfig(lora_path)
            self.validate_new_adapter(lora_name, new_adapter)
            self.configs[lora_name] = new_adapter
        except Exception as e:
            success = False
            error_message = (
...@@ -168,6 +176,21 @@ class LoRAManager:
            error_message=error_message,
        )

    def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
        """
        Validate if an adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
        """
        incompatible = self.memory_pool and not self.memory_pool.can_support(
            lora_config
        )
        if incompatible:
            raise ValueError(
                f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
                "We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
                "you can specify the expected shapes via --max-lora-rank and --lora-target-modules at server startup."
            )
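A hypothetical client-side sketch of how this validation surfaces over HTTP, consistent with the server-mode test expectations further down (HTTP 400 plus an error message about LoRA shapes); `url` and the adapter/rank choices are illustrative only:

```python
import requests

# Assume the server was launched with a small --max-lora-rank (e.g. 32), and we
# try to dynamically load a higher-rank adapter; per the test data below,
# "philschmid/code-llama-3-1-8b-text-to-sql-lora" has rank 256.
response = requests.post(
    url + "/load_lora_adapter",  # `url` points at a running SGLang server (placeholder)
    json={
        "lora_name": "big_lora",
        "lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
    },
)

# Instead of silently corrupting the memory pool, the request is rejected; the
# server tests below assert a 400 status and an error message mentioning that
# dynamically updating LoRA shapes is not yet supported.
assert response.status_code == 400
print(response.text)
```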
    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
        """
        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
...@@ -214,7 +237,7 @@ class LoRAManager:
            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
            if lora_path is not None:
                lora = self.loras[lora_path]
-                lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
                lora_ranks[weight_indices[i]] = lora.config.r
                scalings[weight_indices[i]] = lora.scaling
        # Use pinned memory to avoid synchronizations during host-to-device transfer
...@@ -319,7 +342,7 @@ class LoRAManager:
                )
            else:
                weight_name = get_weight_name(
-                    module_name, self.lora_weight_names, LoRAType.LORA_A
                    module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
                )
                module.set_lora_info(
                    self.memory_pool.get_tensor(
...@@ -351,56 +374,65 @@ class LoRAManager:
            i: {} for i in range(self.base_hf_config.num_hidden_layers)
        }

-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-        )
        # The LoRA memory pool that manages the GPU buffers for active LoRA weights.
        # It is initialized lazily when the first LoRA adapter is loaded.
        self.memory_pool: Optional[LoRAMemoryPool] = None
    def update_state_from_configs(self):
        """
        Update the internal state of the LoRAManager based on the current `self.configs`. This method
        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
-
-        This includes:
-        - Initializing LoRA adapters if they are not already loaded.
-        - Collect all LoRA weight names based on the current loaded adapters.
-        - Lazily monkey-patching the base model to use LoRA layers where applicable.
-        - Preparing the GPU buffer pool for active LoRA weights.
        """
-
-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        hf_target_module_names: Set[str] = set()
-        for config in self.configs.values():
-            hf_target_module_names.update(config.target_modules)
-        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
        # Loads / unloads LoRA adapters based on the latest configs.
        self.update_lora_adapters()

-        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
-        #
-        # Please note that the following update operations are "monotonic" by design, meaning that we update
-        # multiple places to support the new weight names when the first adapter targeting such weight names
-        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
-        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
-        # list of LoRA weight names is expected to be extremely finite and stable.
-        self.update_lora_weight_names(hf_target_module_names)
-        self.update_lora_modules(hf_target_module_names)
-        self.update_memory_buffers(max_lora_dim)
        # Apply the latest LoRA configurations to the internal state for inferencing.
        self.apply_lora_configs()

    def apply_lora_configs(self):
        """
        Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.

        Notes:
        - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
          we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
          LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing this work in
          early CY25H2.
        """
        if self.memory_pool is None:
            # Infer max_lora_rank and target_modules if not explicitly specified in server args.
            if self.target_modules is None:
                self.target_modules = set()
                for config in self.configs.values():
                    self.target_modules.update(config.target_modules)

            if self.max_lora_rank is None:
                self.max_lora_rank = max(
                    [x.hf_config["r"] for x in self.configs.values()],
                    default=0,
                )

            self.update_lora_weight_names()
            self.update_lora_modules()
            self.update_memory_buffers()
        else:
            # No-op if the memory pool can support the current LoRA configurations.
            # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
            # module is changed once FlashInfer backend is deprecated.
            assert self.memory_pool.can_support(self.configs.values()), (
                "LoRA memory pool cannot support the current LoRA configuration. "
                "This should never happen as we should have validated adapter compatibility. "
                "Please create a Github issue to report.",
            )

-    def update_lora_weight_names(self, hf_target_names: Set[str]):
    def update_lora_weight_names(self):
        """
        Add new LoRA weight names if needed based on the current `self.configs`.
        """
        # Target lora weight names for lora_a and lora_b modules respectively.
-        for module in hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            self.lora_weight_names[0].update(lora_A)
-            self.lora_weight_names[1].update(lora_B)
        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
        self.lora_weight_names[0].update(lora_A)
        self.lora_weight_names[1].update(lora_B)
...@@ -434,21 +466,23 @@ class LoRAManager:
        # Additional checks for flashinfer backend
        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
        if self.lora_backend == "flashinfer":
-            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
            lora_dims = set(x.r for x in self.configs.values())
            scalings = set(x.scaling for x in self.loras.values())
            assert (
                len(lora_dims) == 1 and len(scalings) == 1
            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

-    def update_memory_buffers(self, max_lora_dim: int):
-        """
-        Update the LoRA memory pool buffers based on the current LoRA configurations and update
-        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
-        are set or updated.
-        """
-        self.memory_pool.init_buffers(
-            self.lora_weight_names, self.base_model, max_lora_dim
-        )
    def update_memory_buffers(self):
        """(Re)initialize the LoRA memory pool based on the current configurations."""
        self.memory_pool = LoRAMemoryPool(
            base_hf_config=self.base_hf_config,
            max_loras_per_batch=self.max_loras_per_batch,
            dtype=self.dtype,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.max_lora_rank,
            lora_weight_names=self.lora_weight_names,
            base_model=self.base_model,
        )
    def set_lora_module(self, module_name, module):
...@@ -456,11 +490,11 @@ class LoRAManager:
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

-    def update_lora_modules(self, hf_target_names: Set[str]):
    def update_lora_modules(self):
        # Target module names of customized layers defined in python/sglang/srt/layers
        # e.g., {"qkv_proj", "o_proj"}
        customized_target_names = get_customized_names_from_hf_names(
-            hf_target_names, self.base_model
            self.target_modules, self.base_model
        )
        for module_name, module in self.base_model.named_modules():
......
-from typing import Callable, Dict, List, Optional, Set, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
...@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.utils import (
    ROW_PARALLELISM_LINEAR_LORA_NAMES,
    LoRAType,
    get_hidden_dim,
    get_normalized_lora_weight_names,
    get_stacked_multiply,
    get_weight_name,
)
...@@ -25,6 +27,9 @@ class LoRAMemoryPool:
        dtype: torch.dtype,
        tp_size: int,
        tp_rank: int,
        max_lora_rank: int,
        lora_weight_names: Tuple[Set[str], Set[str]],
        base_model: torch.nn.Module,
    ):
        self.base_hf_config: AutoConfig = base_hf_config
        self.num_layer: int = base_hf_config.num_hidden_layers
...@@ -32,6 +37,10 @@ class LoRAMemoryPool:
        self.dtype: torch.dtype = dtype
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank
        self.max_lora_rank: int = max_lora_rank
        # lora weight names for LoRA A and B respectively.
        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names

        # Both A_buffer and B_buffer maps lora weight names to its buffer space.
        # A_buffer contains num_layer number of row-major tensors with shape
...@@ -49,6 +58,31 @@
        # Here we don't initialize to None since None is a valid uid
        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

        self.init_buffers(base_model)

    def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
        """
        Check if the memory pool can support the given LoRA adapters.
        """

        def _can_support(config: LoRAConfig) -> bool:
            """
            Check if the memory pool can support a single LoRA adapter.
            """
            if config.r > self.max_lora_rank:
                return False
            weights_a, weights_b = get_normalized_lora_weight_names(
                config.target_modules
            )
            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
                self.lora_weight_names[1]
            )

        if isinstance(config, LoRAConfig):
            return _can_support(config)
        else:
            return all(_can_support(x) for x in config)
    def get_lora_A_shape(
        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
    ) -> Tuple[int]:
...@@ -82,25 +116,18 @@
            max_lora_dim,
        )

-    def init_buffers(
-        self,
-        lora_weight_names: Tuple[Set[str]],
-        base_model: torch.nn.Module,
-        max_lora_dim: int,
-    ):
-        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
    def init_buffers(self, base_model: torch.nn.Module):
        device = next(base_model.parameters()).device

-        def update_buffer(
        def init_buffer(
            buffer: Dict[str, List[torch.Tensor]],
            lora_weight_names: Set[str],
            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
        ):
-            new_weight_names = lora_weight_names - buffer.keys()
-            for module_name in new_weight_names:
-                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
            for module_name in lora_weight_names:
                lora_shape = get_lora_shape_fn(
                    module_name, base_model, self.max_lora_rank
                )
                buffer[module_name] = [
                    torch.empty(
                        lora_shape,
...@@ -110,15 +137,15 @@
                    for _ in range(self.num_layer)
                ]

-        update_buffer(
        init_buffer(
            self.A_buffer,
-            lora_weight_names[0],
            self.lora_weight_names[0],
            self.get_lora_A_shape,
        )

-        update_buffer(
        init_buffer(
            self.B_buffer,
-            lora_weight_names[1],
            self.lora_weight_names[1],
            self.get_lora_B_shape,
        )
......
import re
from dataclasses import dataclass
from enum import Enum
-from typing import List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple

import torch
...@@ -106,9 +106,11 @@ def get_hidden_dim(
    raise NotImplementedError()


-def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
def get_normalized_lora_weight_names(
    target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]:
    """
-    Mapping a target module name to names of the normalized LoRA weights.
    Map a list of target module names to the names of the normalized LoRA weights.
    Returned tuple contains (name for Lora A, name for Lora B)
    """
    params_mapping = {
...@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
    }
-    stacked = params_mapping.get(name, ([name], [name]))
-    return stacked
    result = (set(), set())
    for name in target_modules:
        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
        result[0].update(lora_a)
        result[1].update(lora_b)
    return result
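For readers of the diff, a small hypothetical usage sketch of the new set-based helper; it assumes only the two stacked entries visible in `params_mapping` above plus the identity fallback for unmapped names such as `o_proj`:

```python
# Hypothetical illustration of get_normalized_lora_weight_names: "qkv_proj"
# and "gate_up_proj" use the stacked entries shown above, while "o_proj"
# falls through to the identity default ([name], [name]).
lora_a_names, lora_b_names = get_normalized_lora_weight_names(
    {"qkv_proj", "gate_up_proj", "o_proj"}
)
assert lora_a_names == {"qkv_proj", "gate_up_proj", "o_proj"}
assert lora_b_names == {"q_proj", "kv_proj", "gate_up_proj", "o_proj"}
```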
def get_stacked_multiply(module_name: str) -> int:
......
...@@ -891,6 +891,8 @@ class ModelRunner:
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.server_args.max_lora_rank,
            target_modules=self.server_args.lora_target_modules,
        )
        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
        if result.success:
......
...@@ -134,6 +134,8 @@ class ServerArgs:
    preferred_sampling_params: Optional[str] = None

    # LoRA
    max_lora_rank: Optional[int] = None
    lora_target_modules: Optional[List[str]] = None
    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"
...@@ -1129,6 +1131,28 @@ class ServerArgs:
        )

        # LoRA
parser.add_argument(
"--max-lora-rank",
default=ServerArgs.max_lora_rank,
type=int,
help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
)
parser.add_argument(
"--lora-target-modules",
type=str,
choices=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
nargs="*",
default=None,
help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
)
parser.add_argument(
"--lora-paths",
type=str,
......
...@@ -505,6 +505,8 @@ class SRTRunner:
        torchao_config: Optional[str] = None,
        cuda_graph_max_bs: int = 4,
        sleep_on_idle=False,
        max_lora_rank: Optional[int] = None,
        lora_target_modules: Optional[List[str]] = None,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
...@@ -543,6 +545,8 @@ class SRTRunner:
            cuda_graph_max_bs=cuda_graph_max_bs,
            disable_custom_all_reduce=disable_custom_all_reduce,
            sleep_on_idle=sleep_on_idle,
            max_lora_rank=max_lora_rank,
            lora_target_modules=lora_target_modules,
            **spec_kwargs,
        )
......
...@@ -16,7 +16,7 @@ import multiprocessing as mp ...@@ -16,7 +16,7 @@ import multiprocessing as mp
import unittest import unittest
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Any, List, Optional, Union from typing import Any, Iterable, List, Optional, Union
import requests import requests
import torch import torch
...@@ -27,6 +27,7 @@ from sglang.test.test_utils import ( ...@@ -27,6 +27,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase, CustomTestCase,
is_in_ci,
popen_launch_server, popen_launch_server,
) )
...@@ -45,24 +46,28 @@ class OperationType(Enum):
    LOAD = "load"
    UNLOAD = "unload"
    FORWARD = "forward"
-    EXPECT_ERROR = "expect_error"


@dataclass
class Operation:
-    # Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
    # Operation type, can be LOAD, UNLOAD, FORWARD
    type: OperationType
    # Data associated with the operation. Exact type varies depending on the operation
    data: Optional[Any]
    # If the operation is expected to fail, this is the error message to expect
    expected_error: Optional[str] = None


@dataclass
class TestCase:
    description: str
    base: str
    max_loras_per_batch: int
    all_adapters: List[str]
    initial_adapters: List[str]
    op_sequence: List[Operation]
    max_lora_rank: Optional[int] = None
    lora_target_modules: Optional[List] = None
    max_new_tokens: int = 32
...@@ -72,9 +77,9 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]: ...@@ -72,9 +77,9 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters] return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
TEST_CASES = [ BASIC_TESTS = [
# basic test, no eviction
TestCase( TestCase(
description="dynamic lora update with initial lora_paths",
base="meta-llama/Llama-3.1-8B-Instruct", base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3, max_loras_per_batch=3,
all_adapters=[ all_adapters=[
...@@ -89,20 +94,16 @@ TEST_CASES = [ ...@@ -89,20 +94,16 @@ TEST_CASES = [
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
), ),
Operation( Operation(
type=OperationType.EXPECT_ERROR, type=OperationType.FORWARD,
data=( data=create_batch_data(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
), ),
"not loaded", expected_error="not loaded",
),
), ),
Operation( Operation(
type=OperationType.EXPECT_ERROR, type=OperationType.FORWARD,
data=( data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), expected_error="not loaded",
"not loaded",
),
), ),
Operation( Operation(
type=OperationType.LOAD, type=OperationType.LOAD,
...@@ -127,11 +128,9 @@ TEST_CASES = [ ...@@ -127,11 +128,9 @@ TEST_CASES = [
data="philschmid/code-llama-3-1-8b-text-to-sql-lora", data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
), ),
Operation( Operation(
type=OperationType.EXPECT_ERROR, type=OperationType.FORWARD,
data=( data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), expected_error="not loaded",
"not loaded",
),
), ),
Operation( Operation(
type=OperationType.FORWARD, type=OperationType.FORWARD,
...@@ -147,13 +146,11 @@ TEST_CASES = [ ...@@ -147,13 +146,11 @@ TEST_CASES = [
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
), ),
Operation( Operation(
type=OperationType.EXPECT_ERROR, type=OperationType.FORWARD,
data=( data=create_batch_data(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
), ),
"not loaded", expected_error="not loaded",
),
), ),
Operation( Operation(
type=OperationType.FORWARD, type=OperationType.FORWARD,
...@@ -174,8 +171,8 @@ TEST_CASES = [ ...@@ -174,8 +171,8 @@ TEST_CASES = [
), ),
], ],
), ),
# Eviction
TestCase( TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct", base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=1, max_loras_per_batch=1,
all_adapters=[ all_adapters=[
...@@ -190,20 +187,16 @@ TEST_CASES = [ ...@@ -190,20 +187,16 @@ TEST_CASES = [
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
), ),
Operation( Operation(
type=OperationType.EXPECT_ERROR, type=OperationType.FORWARD,
data=( data=create_batch_data(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16" "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
), ),
"not loaded", expected_error="not loaded",
),
), ),
Operation( Operation(
type=OperationType.EXPECT_ERROR, type=OperationType.FORWARD,
data=( data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"), expected_error="not loaded",
"not loaded",
),
), ),
Operation( Operation(
type=OperationType.LOAD, type=OperationType.LOAD,
...@@ -214,11 +207,9 @@ TEST_CASES = [ ...@@ -214,11 +207,9 @@ TEST_CASES = [
data="philschmid/code-llama-3-1-8b-text-to-sql-lora", data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
), ),
Operation( Operation(
type=OperationType.EXPECT_ERROR, type=OperationType.FORWARD,
data=( data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"), expected_error="not loaded",
"not loaded",
),
), ),
Operation( Operation(
type=OperationType.FORWARD, type=OperationType.FORWARD,
...@@ -263,6 +254,253 @@ TEST_CASES = [ ...@@ -263,6 +254,253 @@ TEST_CASES = [
], ],
), ),
] ]
TARGET_MODULE_TESTS = [
TestCase(
description="Test explicitly specified lora-target-modules.",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
lora_target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="Test inferred lora-target-modules - start with larger adapter",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="Test inferred lora-target-modules - start with smaller adapter",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
None,
]
),
),
],
),
]
MAX_LORA_RANK_TESTS = [
TestCase(
description="Test explicitly specified max-lora-rank.",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=32,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
],
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"pbevan11/llama-3.1-8b-ocr-correction",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"pbevan11/llama-3.1-8b-ocr-correction",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="test implicitly inferred max-lora-rank",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
],
initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
None,
]
),
),
],
),
]
ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
class LoRAUpdateTestSessionMode(Enum): class LoRAUpdateTestSessionMode(Enum):
...@@ -281,7 +519,9 @@ class LoRAUpdateTestSessionBase: ...@@ -281,7 +519,9 @@ class LoRAUpdateTestSessionBase:
testcase: Optional[TestCase], testcase: Optional[TestCase],
model_path: str, model_path: str,
lora_paths: list[str], lora_paths: list[str],
max_loras_per_batch: int = 1, max_loras_per_batch: int,
max_lora_rank: Optional[int],
lora_target_modules: Optional[List[str]] = None,
lora_backend: str = "triton", lora_backend: str = "triton",
disable_cuda_graph: bool = False, disable_cuda_graph: bool = False,
cuda_graph_max_bs: int = 4, cuda_graph_max_bs: int = 4,
...@@ -289,6 +529,8 @@ class LoRAUpdateTestSessionBase: ...@@ -289,6 +529,8 @@ class LoRAUpdateTestSessionBase:
self.testcase = testcase self.testcase = testcase
self.model_path = model_path self.model_path = model_path
self.lora_paths = lora_paths self.lora_paths = lora_paths
self.max_lora_rank = max_lora_rank
self.lora_target_modules = lora_target_modules
self.max_loras_per_batch = max_loras_per_batch self.max_loras_per_batch = max_loras_per_batch
self.lora_backend = lora_backend self.lora_backend = lora_backend
self.disable_cuda_graph = disable_cuda_graph self.disable_cuda_graph = disable_cuda_graph
...@@ -304,7 +546,12 @@ class LoRAUpdateTestSessionBase: ...@@ -304,7 +546,12 @@ class LoRAUpdateTestSessionBase:
# Don't suppress exceptions by default # Don't suppress exceptions by default
return False return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
""" """
Load a LoRA adapter by name and path. Load a LoRA adapter by name and path.
""" """
...@@ -321,6 +568,7 @@ class LoRAUpdateTestSessionBase: ...@@ -321,6 +568,7 @@ class LoRAUpdateTestSessionBase:
prompts: List[str], prompts: List[str],
lora_paths: List[str], lora_paths: List[str],
max_new_tokens: int = 32, max_new_tokens: int = 32,
expected_error: Optional[str] = None,
): ):
""" """
Perform a batch forward pass with the current set of loaded LoRA adapters. Perform a batch forward pass with the current set of loaded LoRA adapters.
...@@ -339,6 +587,8 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): ...@@ -339,6 +587,8 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
model_path=self.model_path, model_path=self.model_path,
model_type="generation", model_type="generation",
lora_paths=self.lora_paths, lora_paths=self.lora_paths,
max_lora_rank=self.max_lora_rank,
lora_target_modules=self.lora_target_modules,
lora_backend=self.lora_backend, lora_backend=self.lora_backend,
torch_dtype=torch.float16, torch_dtype=torch.float16,
mem_fraction_static=MEM_FRACTION_STATIC, mem_fraction_static=MEM_FRACTION_STATIC,
...@@ -357,22 +607,30 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): ...@@ -357,22 +607,30 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
# don't suppress exceptions # don't suppress exceptions
return False return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
""" """
Load a LoRA adapter by name and path. Load a LoRA adapter by name and path.
""" """
if lora_path is None: if lora_path is None:
lora_path = lora_name lora_path = lora_name
self.expected_adapters.add(lora_name)
response = self.handle.load_lora_adapter( response = self.handle.load_lora_adapter(
lora_name=lora_name, lora_name=lora_name,
lora_path=lora_path, lora_path=lora_path,
) )
if expected_error:
self.testcase.assertFalse(response.success)
self.testcase.assertIn(expected_error, response.error_message)
print(f"Received error as expected: {response.error_message}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.success) self.testcase.assertTrue(response.success)
loaded_adapters = set(response.loaded_adapters) loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}") print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters) self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
...@@ -396,7 +654,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase): ...@@ -396,7 +654,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
prompts: List[str], prompts: List[str],
lora_paths: List[str], lora_paths: List[str],
max_new_tokens: int = 32, max_new_tokens: int = 32,
expected_error: str = None, expected_error: Optional[str] = None,
): ):
""" """
Perform a batch forward pass with the current set of loaded LoRA adapters. Perform a batch forward pass with the current set of loaded LoRA adapters.
...@@ -448,6 +706,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -448,6 +706,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
] ]
if self.disable_cuda_graph: if self.disable_cuda_graph:
other_args.append("--disable-cuda-graph") other_args.append("--disable-cuda-graph")
if self.max_lora_rank is not None:
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
if self.lora_target_modules is not None:
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
# launch external server # launch external server
self.handle = popen_launch_server( self.handle = popen_launch_server(
...@@ -464,22 +726,30 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -464,22 +726,30 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
# don't suppress exceptions # don't suppress exceptions
return False return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None): def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
""" """
Load a LoRA adapter by name and path. Load a LoRA adapter by name and path.
""" """
if lora_path is None: if lora_path is None:
lora_path = lora_name lora_path = lora_name
self.expected_adapters.add(lora_name)
response = requests.post( response = requests.post(
DEFAULT_URL_FOR_TEST + "/load_lora_adapter", DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
json={"lora_name": lora_name, "lora_path": lora_path}, json={"lora_name": lora_name, "lora_path": lora_path},
) )
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
print(f"Received error as expected: {response.text}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.ok) self.testcase.assertTrue(response.ok)
loaded_adapters = set(response.json()["loaded_adapters"]) loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}") print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters) self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
...@@ -504,7 +774,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -504,7 +774,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
prompts: List[str], prompts: List[str],
lora_paths: List[str], lora_paths: List[str],
max_new_tokens: int = 32, max_new_tokens: int = 32,
expected_error: str = None, expected_error: Optional[str] = None,
): ):
""" """
Perform a batch forward pass with the current set of loaded LoRA adapters. Perform a batch forward pass with the current set of loaded LoRA adapters.
...@@ -537,30 +807,14 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase): ...@@ -537,30 +807,14 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
# Factory function to create the appropriate LoRA test session based on mode
def LoRAUpdateTestSession(
    *,
    testcase: Optional[TestCase],
    mode: LoRAUpdateTestSessionMode,
-    model_path: str,
-    lora_paths: list[str],
-    max_loras_per_batch: int = 1,
-    lora_backend: str = "triton",
-    disable_cuda_graph: bool = False,
-    cuda_graph_max_bs: int = 4,
    **kwargs: Any,
):
-    common_kwargs = {
-        "testcase": testcase,
-        "model_path": model_path,
-        "lora_paths": lora_paths,
-        "max_loras_per_batch": max_loras_per_batch,
-        "lora_backend": lora_backend,
-        "disable_cuda_graph": disable_cuda_graph,
-        "cuda_graph_max_bs": cuda_graph_max_bs,
-    }
    if mode == LoRAUpdateTestSessionMode.ENGINE:
-        return LoRAUpdateEngineTestSession(**common_kwargs)
        return LoRAUpdateEngineTestSession(testcase=testcase, **kwargs)
    elif mode == LoRAUpdateTestSessionMode.SERVER:
-        return LoRAUpdateServerTestSession(**common_kwargs)
        return LoRAUpdateServerTestSession(testcase=testcase, **kwargs)
    else:
        raise ValueError(f"Unrecognized mode: {mode!r}")
...@@ -582,6 +836,8 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -582,6 +836,8 @@ class TestLoRADynamicUpdate(CustomTestCase):
initial_adapters: List[str], initial_adapters: List[str],
max_loras_per_batch: int, max_loras_per_batch: int,
op_sequence: List[Operation], op_sequence: List[Operation],
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
max_new_tokens: int = 32, max_new_tokens: int = 32,
) -> List[tuple]: ) -> List[tuple]:
""" """
...@@ -596,10 +852,13 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -596,10 +852,13 @@ class TestLoRADynamicUpdate(CustomTestCase):
model_path=base, model_path=base,
lora_paths=initial_adapters, lora_paths=initial_adapters,
max_loras_per_batch=max_loras_per_batch, max_loras_per_batch=max_loras_per_batch,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
) as session: ) as session:
for op in op_sequence: for op in op_sequence:
op_type = op.type op_type = op.type
data = op.data data = op.data
expected_error = op.expected_error
print("-" * 100) print("-" * 100)
print( print(
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---" f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
...@@ -608,6 +867,7 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -608,6 +867,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
result = session.load_lora_adapter( result = session.load_lora_adapter(
lora_name=data, lora_name=data,
lora_path=data, lora_path=data,
expected_error=expected_error,
) )
elif op_type == OperationType.UNLOAD: elif op_type == OperationType.UNLOAD:
result = session.unload_lora_adapter( result = session.unload_lora_adapter(
...@@ -615,40 +875,31 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -615,40 +875,31 @@ class TestLoRADynamicUpdate(CustomTestCase):
) )
elif op_type == OperationType.FORWARD: elif op_type == OperationType.FORWARD:
prompts, adapters = zip(*data) prompts, adapters = zip(*data)
result = session.forward(
prompts=list(prompts),
lora_paths=list(adapters),
max_new_tokens=max_new_tokens,
)
forward_outputs.append(result)
elif op_type == OperationType.EXPECT_ERROR:
input_data, expected_error = data
prompts, adapters = zip(*input_data)
result = session.forward( result = session.forward(
prompts=list(prompts), prompts=list(prompts),
lora_paths=list(adapters), lora_paths=list(adapters),
max_new_tokens=max_new_tokens, max_new_tokens=max_new_tokens,
expected_error=expected_error, expected_error=expected_error,
) )
if not expected_error:
forward_outputs.append(result)
return forward_outputs return forward_outputs
def test_dynamic_adapter_updates(self): def _run_dynamic_adapter_updates(
for case_idx, test_case in enumerate(TEST_CASES, start=1): self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase]
for mode in [ ):
LoRAUpdateTestSessionMode.ENGINE, for case_idx, test_case in enumerate(test_cases, start=1):
LoRAUpdateTestSessionMode.SERVER,
]:
print("=" * 100) print("=" * 100)
print(f"Starting test case {case_idx} in {mode.value} mode.") print(
f"Starting test case {case_idx} in {mode.value} mode. Test description: {test_case.description}"
)
print("=" * 100) print("=" * 100)
print( print(
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---" f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
) )
# Test dynamic loading of adapters # Test dynamic loading of adapters
# TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora,
# we should fix this in the future https://github.com/sgl-project/sglang/issues/7463.
dynamic_output = self._run_operation_sequence( dynamic_output = self._run_operation_sequence(
mode=mode, mode=mode,
initial_adapters=test_case.initial_adapters, initial_adapters=test_case.initial_adapters,
...@@ -656,17 +907,19 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -656,17 +907,19 @@ class TestLoRADynamicUpdate(CustomTestCase):
max_loras_per_batch=test_case.max_loras_per_batch, max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=test_case.op_sequence, op_sequence=test_case.op_sequence,
max_new_tokens=test_case.max_new_tokens, max_new_tokens=test_case.max_new_tokens,
max_lora_rank=test_case.max_lora_rank,
lora_target_modules=test_case.lora_target_modules,
) )
# static loading # static loading
forward_ops = [ forward_ops = [
x for x in test_case.op_sequence if x.type == OperationType.FORWARD x
for x in test_case.op_sequence
if x.type == OperationType.FORWARD and x.expected_error is None
] ]
print("=" * 100) print("=" * 100)
print( print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
f"\n--- Running static pass with {len(forward_ops)} operations ---"
)
static_output = self._run_operation_sequence( static_output = self._run_operation_sequence(
mode=mode, mode=mode,
initial_adapters=test_case.all_adapters, initial_adapters=test_case.all_adapters,
...@@ -701,6 +954,27 @@ class TestLoRADynamicUpdate(CustomTestCase): ...@@ -701,6 +954,27 @@ class TestLoRADynamicUpdate(CustomTestCase):
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'", f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
) )
def test_dynamic_lora_update_engine(self):
"""
Test dynamic LoRA updates in engine mode.
"""
test_cases = ALL_TESTS
self._run_dynamic_adapter_updates(
mode=LoRAUpdateTestSessionMode.ENGINE,
test_cases=test_cases,
)
def test_dynamic_lora_update_server(self):
"""
Test dynamic LoRA updates in server mode.
"""
# In CI, we only run the basic test cases to save time, as the engine test should be mostly sufficient for ensuring correctness.
test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
self._run_dynamic_adapter_updates(
mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
)
if __name__ == "__main__": if __name__ == "__main__":
try: try:
......
...@@ -17,7 +17,7 @@ suites = {
        TestFile("models/lora/test_lora_backend.py", 99),
        TestFile("models/lora/test_multi_lora_backend.py", 60),
        TestFile("models/lora/test_lora_cuda_graph.py", 250),
-        TestFile("models/lora/test_lora_update.py", 400),
        TestFile("models/lora/test_lora_update.py", 700),
        TestFile("models/test_embedding_models.py", 73),
        # TestFile("models/test_clip_models.py", 52),
        TestFile("models/test_encoder_embedding_models.py", 100),
......