Unverified Commit b0980af8 authored by Lifu Huang, committed by GitHub

Support pinning adapter via server args. (#9249)

parent 24eaebeb
......@@ -29,7 +29,7 @@
"\n",
"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
"\n",
"* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n",
"* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {\"lora_name\":str,\"lora_path\":str,\"pinned\":bool}.\n",
"\n",
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
"\n",
......@@ -372,6 +372,15 @@
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
......@@ -387,7 +396,40 @@
"\n",
"This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n",
"\n",
"In the example below, we unload `lora1` and reload it as a `pinned` adapter:"
"In the example below, we start a server with `lora1` loaded as pinned, `lora2` and `lora3` loaded as regular (unpinned) adapters. Please note that, we intentionally specify `lora2` and `lora3` in two different formats to demonstrate that both are supported."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --cuda-graph-max-bs 8 \\\n",
" --max-loras-per-batch 3 --lora-backend triton \\\n",
" --max-lora-rank 256 \\\n",
" --lora-target-modules all \\\n",
" --lora-paths \\\n",
" {\"lora_name\":\"lora0\",\"lora_path\":\"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\"pinned\":true} \\\n",
" {\"lora_name\":\"lora1\",\"lora_path\":\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\"} \\\n",
" lora2=philschmid/code-llama-3-1-8b-text-to-sql-lora\n",
" \"\"\"\n",
")\n",
"\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also specify adapter as pinned during dynamic adapter loading. In the example below, we reload `lora2` as pinned adapter:"
]
},
{
......@@ -407,7 +449,7 @@
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" \"lora_path\": \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\",\n",
" \"pinned\": True, # Pin the adapter to GPU\n",
" },\n",
")"
......@@ -417,7 +459,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Verify that the result is identical as before:"
"Verify that the results are expected:"
]
},
{
......@@ -431,17 +473,19 @@
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
" \"lora_path\": [\"lora0\", \"lora1\", \"lora2\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")"
"print(f\"Output from lora0 (pinned): \\n{response.json()[0]['text']}\\n\")\n",
"print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")\n",
"print(f\"Output from lora2 (not pinned): \\n{response.json()[2]['text']}\\n\")"
]
},
{
......
......@@ -179,7 +179,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
| `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: `<PATH>`, `<NAME>=<PATH>`, or a JSON object with the schema `{"lora_name":str,"lora_path":str,"pinned":bool}`. | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
......
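To illustrate the `--lora-paths` formats listed in the table above, here is a minimal sketch of a launch command that mixes all three forms. The adapter names and Hugging Face paths are the ones used in the notebook example and serve only as placeholders; substitute your own adapters.

```bash
# Minimal sketch: the three supported --lora-paths formats in one command.
# Adapter paths are illustrative placeholders.
python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \
  --enable-lora \
  --max-loras-per-batch 3 \
  --lora-paths \
    philschmid/code-llama-3-1-8b-text-to-sql-lora \
    lora1=algoprog/fact-generation-llama-3.1-8b-instruct-lora \
    '{"lora_name": "lora2", "lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", "pinned": true}'
```

Note that the JSON form is wrapped in single quotes so the shell passes it through unmodified, and that the number of pinned adapters must stay below `--max-loras-per-batch`.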
......@@ -55,7 +55,7 @@ class LoRAManager:
tp_rank: int = 0,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
lora_paths: Optional[List[LoRARef]] = None,
):
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
......@@ -370,7 +370,7 @@ class LoRAManager:
self,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[Dict[str, LoRARef]] = None,
lora_paths: Optional[List[LoRARef]] = None,
):
"""
Initialize the internal (mutable) state of the LoRAManager.
......@@ -392,7 +392,7 @@ class LoRAManager:
self.init_memory_pool()
self.update_lora_info()
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
# Configs of all active LoRA adapters, indexed by LoRA ID.
self.configs: Dict[str, LoRAConfig] = {}
......@@ -406,7 +406,7 @@ class LoRAManager:
self.num_pinned_loras: int = 0
if lora_paths:
for lora_ref in lora_paths.values():
for lora_ref in lora_paths:
result = self.load_lora_adapter(lora_ref)
if not result.success:
raise RuntimeError(
......
......@@ -59,9 +59,9 @@ class LoRARegistry:
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
"""
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
assert lora_paths is None or all(
isinstance(lora, LoRARef) for lora in lora_paths.values()
isinstance(lora, LoRARef) for lora in lora_paths
), (
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
"Please file an issue if you see this error."
......@@ -78,7 +78,7 @@ class LoRARegistry:
# Initialize the registry with provided LoRA paths, if present.
if lora_paths:
for lora_ref in lora_paths.values():
for lora_ref in lora_paths:
self._register_adapter(lora_ref)
async def register(self, lora_ref: LoRARef):
......
......@@ -298,7 +298,7 @@ class TokenizerManager:
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
# serves as the source of truth for available adapters and maps user-friendly LoRA names
# to internally used unique LoRA IDs.
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
self.lora_registry = LoRARegistry(self.server_args.lora_paths)
# Lock to serialize LoRA update operations.
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
# LoRA updates and inference to overlap.
......
......@@ -153,7 +153,9 @@ class ServerArgs:
enable_lora: Optional[bool] = None
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[Union[set[str], List[str]]] = None
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
lora_paths: Optional[
Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
] = None
max_loaded_loras: Optional[int] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
......@@ -1319,7 +1321,7 @@ class ServerArgs:
nargs="*",
default=None,
action=LoRAPathAction,
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}',
)
parser.add_argument(
"--max-loras-per-batch",
......@@ -2086,28 +2088,42 @@ class ServerArgs:
)
if self.enable_lora:
            # Normalize lora_paths to a list of LoRARef objects if it is a list.
# TODO (lifuhuang): support specifying pinned adapters in server_args.
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
self.lora_paths = {}
self.lora_paths = []
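                # Each entry is either a plain path string, a "{name}={path}" string,
                # or a dict parsed from a JSON entry that may carry the optional "pinned" flag.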
for lora_path in lora_paths:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
self.lora_paths[name] = LoRARef(
lora_name=name, lora_path=path, pinned=False
if isinstance(lora_path, str):
if "=" in lora_path:
name, path = lora_path.split("=", 1)
lora_ref = LoRARef(
lora_name=name, lora_path=path, pinned=False
)
else:
lora_ref = LoRARef(
lora_name=lora_path, lora_path=lora_path, pinned=False
)
elif isinstance(lora_path, dict):
assert (
"lora_name" in lora_path and "lora_path" in lora_path
), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}"
lora_ref = LoRARef(
lora_name=lora_path["lora_name"],
lora_path=lora_path["lora_path"],
pinned=lora_path.get("pinned", False),
)
else:
self.lora_paths[lora_path] = LoRARef(
lora_name=lora_path, lora_path=lora_path, pinned=False
raise ValueError(
f"Invalid type for item in --lora-paths list: {type(lora_path)}. "
"Expected a string or a dictionary."
)
self.lora_paths.append(lora_ref)
elif isinstance(self.lora_paths, dict):
self.lora_paths = {
k: LoRARef(lora_name=k, lora_path=v, pinned=False)
self.lora_paths = [
LoRARef(lora_name=k, lora_path=v, pinned=False)
for k, v in self.lora_paths.items()
}
]
elif self.lora_paths is None:
self.lora_paths = {}
self.lora_paths = []
else:
raise ValueError(
f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
......@@ -2134,9 +2150,7 @@ class ServerArgs:
"max_loaded_loras should be greater than or equal to max_loras_per_batch. "
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
)
assert (
not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
), (
assert len(self.lora_paths) <= self.max_loaded_loras, (
"The number of LoRA paths should not exceed max_loaded_loras. "
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
)
......@@ -2357,13 +2371,22 @@ class PortArgs:
class LoRAPathAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
setattr(namespace, self.dest, {})
for lora_path in values:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
getattr(namespace, self.dest)[name] = path
else:
getattr(namespace, self.dest)[lora_path] = lora_path
lora_paths = []
if values:
assert isinstance(values, list), "Expected a list of LoRA paths."
for lora_path in values:
lora_path = lora_path.strip()
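                # JSON-formatted entries are parsed here so they can carry extra
                # metadata (e.g., "pinned"); plain PATH / NAME=PATH strings are
                # normalized later in ServerArgs.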
if lora_path.startswith("{") and lora_path.endswith("}"):
obj = json.loads(lora_path)
assert "lora_path" in obj and "lora_name" in obj, (
f"{repr(lora_path)} looks like a JSON str, "
"but it does not contain 'lora_name' and 'lora_path' keys."
)
lora_paths.append(obj)
else:
lora_paths.append(lora_path)
setattr(namespace, self.dest, lora_paths)
class DeprecatedAction(argparse.Action):
......
......@@ -491,7 +491,7 @@ class SRTRunner:
tp_size: int = 1,
model_impl: str = "auto",
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths: List[str] = None,
lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None,
max_loras_per_batch: int = 4,
attention_backend: Optional[str] = None,
prefill_attention_backend: Optional[str] = None,
......
......@@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================
import json
import multiprocessing as mp
import unittest
from dataclasses import dataclass
......@@ -89,8 +90,35 @@ BASIC_TESTS = [
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
initial_adapters=[
# Testing 3 supported lora-path formats.
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
{
"lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
"pinned": False,
},
],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
]
),
),
Operation(
type=OperationType.UNLOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.UNLOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
......@@ -147,6 +175,10 @@ BASIC_TESTS = [
type=OperationType.UNLOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.UNLOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
......@@ -157,18 +189,12 @@ BASIC_TESTS = [
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pbevan11/llama-3.1-8b-ocr-correction",
]
None,
),
),
],
......@@ -705,7 +731,7 @@ class LoRAUpdateTestSessionBase:
*,
testcase: Optional[TestCase],
model_path: str,
lora_paths: list[str],
lora_paths: List[Union[str, dict]],
max_loras_per_batch: int,
max_loaded_loras: Optional[int] = None,
max_lora_rank: Optional[int],
......@@ -727,7 +753,17 @@ class LoRAUpdateTestSessionBase:
self.cuda_graph_max_bs = cuda_graph_max_bs
self.enable_lora = enable_lora
self.expected_adapters = set(lora_paths or [])
self.expected_adapters = set()
if self.lora_paths:
for adapter in self.lora_paths:
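                # Derive the adapter name from whichever of the three supported
                # formats was used: dict -> "lora_name", "name=path" -> name,
                # plain path -> the path string itself.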
if isinstance(adapter, dict):
lora_name = adapter["lora_name"]
elif "=" in adapter:
lora_name = adapter.split("=")[0]
else:
lora_name = adapter
self.expected_adapters.add(lora_name)
self.handle = None # Will be set in __enter__
def __enter__(self):
......@@ -926,7 +962,11 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
if self.enable_lora:
other_args.append("--enable-lora")
if self.lora_paths:
other_args.extend(["--lora-paths"] + self.lora_paths)
other_args.append("--lora-paths")
for lora_path in self.lora_paths:
if isinstance(lora_path, dict):
lora_path = json.dumps(lora_path)
other_args.append(lora_path)
if self.disable_cuda_graph:
other_args.append("--disable-cuda-graph")
if self.max_lora_rank is not None:
......@@ -1093,7 +1133,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
self,
mode: LoRAUpdateTestSessionMode,
base: str,
initial_adapters: List[str],
initial_adapters: List[Union[str, dict]],
op_sequence: List[Operation],
max_loras_per_batch: int,
max_loaded_loras: Optional[int] = None,
......