Unverified Commit 4e3defe5 authored by Lifu Huang, committed by GitHub

Support start up LoRA server without initial adapters (#8019)

parent 60468da4
......@@ -27,6 +27,8 @@
"source": [
"The following server arguments are relevant for multi-LoRA serving:\n",
"\n",
"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
"\n",
"* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n",
"\n",
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
......@@ -35,7 +37,7 @@
"\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n",
"\n",
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
"\n",
......@@ -79,6 +81,7 @@
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" --max-loras-per-batch 1 --lora-backend triton \\\n",
" --disable-radix-cache\n",
......@@ -98,7 +101,7 @@
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses the base model\n",
......@@ -137,6 +140,7 @@
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
......@@ -157,7 +161,7 @@
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
......@@ -191,11 +195,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Basic Usage\n",
"\n",
"Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n",
"\n",
"(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)"
"When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you would have to ensure that all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
]
},
{
......@@ -204,13 +206,22 @@
"metadata": {},
"outputs": [],
"source": [
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"lora0_new = \"philschmid/code-llama-3-1-8b-text-to-sql-lora\" # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"\n",
"\n",
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
"# We are adding it here just to demonstrate usage.\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n",
" --enable-lora \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" --max-lora-rank 256\n",
" --lora-target-modules all\n",
" \"\"\"\n",
")\n",
"\n",
......@@ -218,6 +229,13 @@
"wait_for_server(url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load adapter lora0"
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -227,8 +245,8 @@
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
" \"lora_name\": \"lora0\",\n",
" \"lora_path\": lora0,\n",
" },\n",
")\n",
"\n",
......@@ -239,38 +257,10 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
" },\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" },\n",
")"
"Load adapter lora1:"
]
},
{
......@@ -282,8 +272,8 @@
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora2\",\n",
" \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" },\n",
")\n",
"\n",
......@@ -294,24 +284,10 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"cell_type": "markdown",
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora1\", \"lora2\"],\n",
" },\n",
")\n",
"print(f\"Output from lora1: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora2: {response.json()[1]['text']}\")"
"Check inference output:"
]
},
{
......@@ -320,18 +296,29 @@
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Advanced: hosting adapters of different shapes\n",
"\n",
"In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n",
"\n",
"For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
"Unload lora0 and replace it with a different adapter:"
]
},
{
......@@ -340,39 +327,18 @@
"metadata": {},
"outputs": [],
"source": [
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"\n",
"\n",
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
"# We are adding it here just to demonstrate usage.\n",
"server_process, port = launch_server_cmd(\n",
" f\"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0={lora0} \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" --max-lora-rank 64\n",
" --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n",
" \"\"\"\n",
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" },\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" \"lora_name\": \"lora0\",\n",
" \"lora_path\": lora0_new,\n",
" },\n",
")\n",
"\n",
......@@ -382,6 +348,13 @@
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check output again:"
]
},
{
"cell_type": "code",
"execution_count": null,
......@@ -392,7 +365,7 @@
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
......@@ -402,8 +375,8 @@
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
]
},
{
......
......@@ -176,8 +176,9 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults |
|-----------|-------------|----------|
| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
| `--lora-paths` | The list of LoRA adapters to load at startup. Each entry can be either a plain adapter path or a named path in the format `{name}={path}`. | None |
| `--max-loras-per-batch` | Maximum number of adapters in a running batch, including base-only requests. | 8 |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
......
......@@ -186,9 +186,9 @@ class LoRAManager:
)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration."
"We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
"You can specify expected configs via --max_lora_rank and --enable_lora_modules."
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`."
)
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
......
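The reworded error above fires when a dynamically loaded adapter does not fit the memory pool configured at startup. Below is a minimal sketch of the compatibility rule it describes, using hypothetical names; this is not the actual `LoRAManager` implementation.

```python
from typing import Iterable, Set


def adapter_fits_pool(
    adapter_rank: int,
    adapter_modules: Iterable[str],
    max_lora_rank: int,
    target_modules: Set[str],
) -> bool:
    """Hypothetical check mirroring the error message above: the adapter's rank
    must not exceed --max-lora-rank, and every module it touches must be covered
    by --lora-target-modules."""
    return adapter_rank <= max_lora_rank and set(adapter_modules) <= target_modules


# A rank-256 adapter cannot be loaded into a pool sized for rank 64.
assert not adapter_fits_pool(256, {"q_proj", "v_proj"}, 64, {"q_proj", "k_proj", "v_proj"})
# A rank-4 adapter targeting a subset of the configured modules fits.
assert adapter_fits_pool(4, {"q_proj"}, 64, {"q_proj", "k_proj", "v_proj"})
```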
......@@ -574,7 +574,7 @@ class TokenizerManager:
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if self.server_args.lora_paths and obj.lora_path:
if self.server_args.enable_lora and obj.lora_path:
self._validate_lora_adapters(obj)
def _validate_input_ids_in_vocab(
......@@ -1037,6 +1037,10 @@ class TokenizerManager:
_: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
......@@ -1060,6 +1064,10 @@ class TokenizerManager:
_: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
self.auto_create_handle_loop()
if not self.server_args.enable_lora:
raise ValueError(
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
)
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
# with dp_size > 1.
......
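With the guards above, calling the dynamic-loading endpoints against a server that was started without `--enable-lora` is rejected up front. A hedged client-side sketch follows; the endpoint names come from this diff, while the server address and exactly how the error surfaces over HTTP (status code versus failure payload) are assumptions.

```python
import requests

# Assumes a server started WITHOUT --enable-lora; 30000 is a placeholder port.
url = "http://127.0.0.1:30000"
response = requests.post(
    url + "/load_lora_adapter",
    json={
        "lora_name": "lora0",
        "lora_path": "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
    },
)
# Expect a rejection telling you to set --enable-lora.
print(response.status_code, response.text)
```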
......@@ -264,7 +264,7 @@ class CudaGraphRunner:
if self.enable_torch_compile:
set_torch_compile_config()
if self.model_runner.server_args.lora_paths is not None:
if self.model_runner.server_args.enable_lora:
self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
# Graph inputs
......@@ -510,11 +510,10 @@ class CudaGraphRunner:
spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
)
if self.model_runner.server_args.lora_paths is not None:
# Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
# different logic to handle lora, so we need to set `lora_paths` to a list of non-None
# values if lora is enabled.
lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
if self.model_runner.server_args.enable_lora:
# It is safe to capture the CUDA graph with an empty LoRA path, as the LoRA kernels are always launched whenever
# `--enable-lora` is set to True (and return immediately when the LoRA path is empty, as a perf optimization).
lora_paths = [None] * bs
else:
lora_paths = None
......
......@@ -418,7 +418,7 @@ class ForwardBatch:
ret._compute_mrope_positions(model_runner, batch)
# Init lora information
if model_runner.server_args.lora_paths is not None:
if model_runner.server_args.enable_lora:
model_runner.lora_manager.prepare_lora_batch(ret)
TboForwardBatchPreparer.prepare(
......
......@@ -304,11 +304,7 @@ class ModelRunner:
self.apply_torch_tp()
# Init lora
# TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
# a new server arg `enable_lora` to control whether to init LoRA manager to be more
# explicit, as it is perfectly valid to start a server with an empty lora_paths and
# load LoRA adapters dynamically later.
if server_args.lora_paths is not None:
if server_args.enable_lora:
self.init_lora_manager()
# Init memory pool and attention backends
......@@ -895,7 +891,7 @@ class ModelRunner:
max_lora_rank=self.server_args.max_lora_rank,
target_modules=self.server_args.lora_target_modules,
)
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {})
if result.success:
logger.info(
f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
......
......@@ -26,6 +26,8 @@ from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.utils import (
LORA_TARGET_ALL_MODULES,
SUPPORTED_LORA_TARGET_MODULES,
configure_ipv6,
get_device,
get_device_memory_capacity,
......@@ -140,8 +142,9 @@ class ServerArgs:
preferred_sampling_params: Optional[str] = None
# LoRA
enable_lora: Optional[bool] = None
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[List[str]] = None
lora_target_modules: Optional[Union[set[str], List[str]]] = None
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
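Because these are plain `ServerArgs` fields, the same options can be set programmatically when embedding SGLang as an offline engine. A hedged sketch, assuming `sglang.Engine` forwards these fields as keyword arguments (values are illustrative):

```python
import sglang as sgl

# Sketch: offline engine with LoRA enabled but no initial adapters.
llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    enable_lora=True,
    max_lora_rank=64,             # required when no lora_paths are given
    lora_target_modules=["all"],  # required when no lora_paths are given
    max_loras_per_batch=4,
    lora_backend="triton",
)
```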
......@@ -1148,6 +1151,12 @@ class ServerArgs:
)
# LoRA
parser.add_argument(
"--enable-lora",
default=ServerArgs.enable_lora,
action="store_true",
help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.",
)
parser.add_argument(
"--max-lora-rank",
default=ServerArgs.max_lora_rank,
......@@ -1157,18 +1166,12 @@ class ServerArgs:
parser.add_argument(
"--lora-target-modules",
type=str,
choices=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
choices=SUPPORTED_LORA_TARGET_MODULES + [LORA_TARGET_ALL_MODULES],
nargs="*",
default=None,
help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
help="The union set of all target modules where LoRA should be applied. If not specified, "
"it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, "
"all supported modules will be targeted.",
)
parser.add_argument(
"--lora-paths",
......@@ -1816,15 +1819,46 @@ class ServerArgs:
None,
}, "moe_dense_tp_size only support 1 and None currently"
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
self.lora_paths = {}
for lora_path in lora_paths:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
self.lora_paths[name] = path
else:
self.lora_paths[lora_path] = lora_path
self.check_lora_server_args()
def check_lora_server_args(self):
# Enable LoRA if any LoRA paths are provided for backward compatibility.
if self.lora_paths:
if self.enable_lora is None:
self.enable_lora = True
logger.info(
"--enable-lora is set to True because --lora-paths is provided."
)
elif self.enable_lora is False:
logger.warning(
"--enable-lora is set to False, any provided lora_paths will be ignored."
)
if self.enable_lora:
# Normalize lora_paths to a dictionary if it is a list.
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
self.lora_paths = {}
for lora_path in lora_paths:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
self.lora_paths[name] = path
else:
self.lora_paths[lora_path] = lora_path
# Expand target modules
if self.lora_target_modules:
self.lora_target_modules = set(self.lora_target_modules)
if "all" in self.lora_target_modules:
assert (
len(self.lora_target_modules) == 1
), "If 'all' is specified in --lora-target-modules, it should be the only module specified."
self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES)
# Ensure sufficient information is provided for LoRA initialization.
assert self.lora_paths or (
self.max_lora_rank and self.lora_target_modules
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
larger_tp = max(decode_tp, prefill_tp)
......
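To make the new `check_lora_server_args` behavior concrete, here is a small standalone worked example of the two normalizations shown above: turning `--lora-paths` entries into a `{name: path}` mapping, and expanding `all` into the full set of supported modules. The local adapter path is a hypothetical placeholder, and this is a simplified sketch rather than the `ServerArgs` code itself.

```python
SUPPORTED_LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj",
]

lora_paths = [
    "lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora",
    "/local/adapters/my_lora",  # hypothetical bare path with no explicit name
]
lora_target_modules = ["all"]

# --lora-paths entries become a {name: path} dict; bare paths map to themselves.
normalized_paths = {}
for entry in lora_paths:
    name, sep, path = entry.partition("=")
    normalized_paths[name] = path if sep else name

# "all" must be the only entry and expands to every supported module.
modules = set(lora_target_modules)
if "all" in modules:
    assert len(modules) == 1, "'all' must be the only --lora-target-modules entry"
    modules = set(SUPPORTED_LORA_TARGET_MODULES)

print(normalized_paths)  # {'lora0': 'algoprog/...', '/local/adapters/my_lora': '/local/adapters/my_lora'}
print(sorted(modules))   # ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']
```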
......@@ -2892,3 +2892,17 @@ def parse_module_path(module_path, function_name, create_dummy):
return final_module, getattr(final_module, function_name)
return final_module, None
# LoRA-related constants and utilities
SUPPORTED_LORA_TARGET_MODULES = [
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
]
LORA_TARGET_ALL_MODULES = "all"
......@@ -507,6 +507,7 @@ class SRTRunner:
sleep_on_idle=False,
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
enable_lora: Optional[bool] = None,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
......@@ -547,6 +548,7 @@ class SRTRunner:
sleep_on_idle=sleep_on_idle,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
enable_lora=enable_lora,
**spec_kwargs,
)
......
......@@ -64,8 +64,9 @@ class TestCase:
base: str
max_loras_per_batch: int
all_adapters: List[str]
initial_adapters: List[str]
op_sequence: List[Operation]
initial_adapters: Optional[List[str]] = None
enable_lora: Optional[bool] = None
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[List] = None
max_new_tokens: int = 32
......@@ -171,6 +172,64 @@ BASIC_TESTS = [
),
],
),
TestCase(
description="dynamic lora update without initial lora_paths",
base="meta-llama/Llama-3.1-8B-Instruct",
enable_lora=True,
max_lora_rank=256,
lora_target_modules=["all"],
max_loras_per_batch=4,
all_adapters=[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
],
op_sequence=[
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
None,
]
),
),
Operation(
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
None,
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
None,
]
),
),
],
),
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
......@@ -371,7 +430,7 @@ TARGET_MODULE_TESTS = [
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
expected_error="updating LoRA shapes",
expected_error="incompatible",
),
Operation(
type=OperationType.FORWARD,
......@@ -431,7 +490,7 @@ MAX_LORA_RANK_TESTS = [
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="updating LoRA shapes",
expected_error="incompatible",
),
Operation(
type=OperationType.FORWARD,
......@@ -470,7 +529,7 @@ MAX_LORA_RANK_TESTS = [
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="updating LoRA shapes",
expected_error="incompatible",
),
Operation(
type=OperationType.FORWARD,
......@@ -521,6 +580,7 @@ class LoRAUpdateTestSessionBase:
lora_paths: list[str],
max_loras_per_batch: int,
max_lora_rank: Optional[int],
enable_lora: Optional[bool] = None,
lora_target_modules: Optional[List[str]] = None,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
......@@ -535,8 +595,9 @@ class LoRAUpdateTestSessionBase:
self.lora_backend = lora_backend
self.disable_cuda_graph = disable_cuda_graph
self.cuda_graph_max_bs = cuda_graph_max_bs
self.enable_lora = enable_lora
self.expected_adapters = set(lora_paths)
self.expected_adapters = set(lora_paths or [])
self.handle = None # Will be set in __enter__
def __enter__(self):
......@@ -596,6 +657,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
disable_cuda_graph=self.disable_cuda_graph,
cuda_graph_max_bs=self.cuda_graph_max_bs,
disable_radix_cache=True,
enable_lora=self.enable_lora,
)
self.handle.__enter__()
return self
......@@ -690,8 +752,6 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
other_args = [
"--cuda-graph-max-bs",
str(self.cuda_graph_max_bs),
"--lora-paths",
*self.lora_paths,
"--max-loras-per-batch",
str(self.max_loras_per_batch),
"--lora-backend",
......@@ -704,6 +764,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
"--mem-fraction-static",
str(MEM_FRACTION_STATIC),
]
if self.enable_lora:
other_args.append("--enable-lora")
if self.lora_paths:
other_args.extend(["--lora-paths"] + self.lora_paths)
if self.disable_cuda_graph:
other_args.append("--disable-cuda-graph")
if self.max_lora_rank is not None:
......@@ -836,6 +900,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
initial_adapters: List[str],
max_loras_per_batch: int,
op_sequence: List[Operation],
enable_lora: Optional[bool] = None,
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
max_new_tokens: int = 32,
......@@ -854,6 +919,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
max_loras_per_batch=max_loras_per_batch,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
enable_lora=enable_lora,
) as session:
for op in op_sequence:
op_type = op.type
......@@ -903,6 +969,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
dynamic_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.initial_adapters,
enable_lora=test_case.enable_lora,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=test_case.op_sequence,
......@@ -923,6 +990,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
static_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.all_adapters,
enable_lora=test_case.enable_lora,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=forward_ops,
......
......@@ -18,7 +18,7 @@ suites = {
TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/lora/test_lora_cuda_graph.py", 250),
TestFile("models/lora/test_lora_update.py", 700),
TestFile("models/lora/test_lora_update.py", 800),
TestFile("models/test_embedding_models.py", 73),
# TestFile("models/test_clip_models.py", 52),
TestFile("models/test_encoder_embedding_models.py", 100),
......