Unverified commit e2ed9d04, authored by Lifu Huang and committed by GitHub

Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)

parent b5dd5e87
@@ -33,6 +33,10 @@
"\n",
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
"\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n",
"\n",
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n", "* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
"\n", "\n",
"From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to." "From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to."
@@ -176,6 +180,241 @@
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dynamic LoRA loading"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Basic Usage\n",
"\n",
"Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n",
"\n",
"(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" \"\"\"\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
" },\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/unload_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora0\",\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora2\",\n",
" \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/generate\",\n",
" json={\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"List 3 countries and their capitals.\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" \"lora_path\": [\"lora1\", \"lora2\"],\n",
" },\n",
")\n",
"print(f\"Output from lora1: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora2: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Advanced: hosting adapters of different shapes\n",
"\n",
"In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n",
"\n",
"For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
"\n",
"\n",
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
"# We are adding it here just to demonstrate usage.\n",
"server_process, port = launch_server_cmd(\n",
" f\"\"\"\n",
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --lora-paths lora0={lora0} \\\n",
" --cuda-graph-max-bs 2 \\\n",
" --max-loras-per-batch 2 --lora-backend triton \\\n",
" --disable-radix-cache\n",
" --max-lora-rank 64\n",
" --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n",
" \"\"\"\n",
")\n",
"\n",
"url = f\"http://127.0.0.1:{port}\"\n",
"wait_for_server(url)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"response = requests.post(\n",
" url + \"/load_lora_adapter\",\n",
" json={\n",
" \"lora_name\": \"lora1\",\n",
" \"lora_path\": lora1,\n",
" },\n",
")\n",
"\n",
"if response.status_code == 200:\n",
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
"else:\n",
" print(\"Failed to load LoRA adapter.\", response.json())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"url = f\"http://127.0.0.1:{port}\"\n",
"json_data = {\n",
" \"text\": [\n",
" \"List 3 countries and their capitals.\",\n",
" \"AI is a field of computer science focused on\",\n",
" ],\n",
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
" # The first input uses lora0, and the second input uses lora1\n",
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
"}\n",
"response = requests.post(\n",
" url + \"/generate\",\n",
" json=json_data,\n",
")\n",
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
...
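For readers following along outside the notebook, the calls demonstrated above can be wrapped in a small client-side helper. The sketch below assumes a running server reachable at `base_url` that exposes the `/load_lora_adapter`, `/unload_lora_adapter`, and `/generate` endpoints exactly as used in the cells above; the helper names themselves are illustrative and are not part of this change.

```python
import requests


def load_adapter(base_url: str, name: str, path: str) -> dict:
    # POST /load_lora_adapter; the server rejects adapters whose shape the memory pool cannot hold.
    resp = requests.post(
        f"{base_url}/load_lora_adapter",
        json={"lora_name": name, "lora_path": path},
    )
    resp.raise_for_status()
    return resp.json()


def unload_adapter(base_url: str, name: str) -> dict:
    # POST /unload_lora_adapter to free the adapter's slot in the LoRA memory pool.
    resp = requests.post(
        f"{base_url}/unload_lora_adapter",
        json={"lora_name": name},
    )
    resp.raise_for_status()
    return resp.json()


def generate(base_url: str, prompts: list, adapters: list) -> list:
    # Each prompt is paired with the adapter name at the same index via "lora_path".
    resp = requests.post(
        f"{base_url}/generate",
        json={
            "text": prompts,
            "sampling_params": {"max_new_tokens": 32, "temperature": 0},
            "lora_path": adapters,
        },
    )
    resp.raise_for_status()
    return [item["text"] for item in resp.json()]
```

Usage mirrors the notebook cells: build `base_url` from the port returned by `launch_server_cmd`, then call these helpers in place of the raw `requests.post` calls.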
@@ -167,6 +167,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None |
## Kernel backend
...
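When choosing values for the two new arguments, the rank and target modules of adapters you plan to load later can usually be read from their PEFT `adapter_config.json` files. The snippet below is a rough sketch under that assumption; the adapter directory paths are placeholders, not part of this change.

```python
import json
from pathlib import Path

# Placeholder paths to locally downloaded PEFT adapter directories.
adapter_dirs = [Path("adapters/lora0"), Path("adapters/lora1")]

max_rank = 0
target_modules = set()
for adapter_dir in adapter_dirs:
    config = json.loads((adapter_dir / "adapter_config.json").read_text())
    # "r" and "target_modules" are standard fields in PEFT adapter configs.
    max_rank = max(max_rank, config["r"])
    target_modules.update(config["target_modules"])

print(f"--max-lora-rank {max_rank}")
print(f"--lora-target-modules {' '.join(sorted(target_modules))}")
```

Passing the printed values at startup ensures the preallocated memory pool can host every adapter you intend to load dynamically.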
@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
-from typing import Dict, Set, Tuple
+from typing import Dict, Iterable, Optional, Set, Tuple

import torch
@@ -53,6 +53,8 @@ class LoRAManager:
        lora_backend: str = "triton",
        tp_size: int = 1,
        tp_rank: int = 0,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
    ):
        self.base_model: torch.nn.Module = base_model
        self.base_hf_config: AutoConfig = base_hf_config
@@ -62,6 +64,10 @@ class LoRAManager:
        self.device: torch.device = next(self.base_model.parameters()).device
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank
+        self.max_lora_rank: Optional[int] = max_lora_rank
+        self.target_modules: Optional[Set[str]] = (
+            set(target_modules) if target_modules else None
+        )

        # LoRA backend for running sgemm kernels
        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -153,7 +159,9 @@ class LoRAManager:
        error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."

        try:
-            self.configs[lora_name] = LoRAConfig(lora_path)
+            new_adapter = LoRAConfig(lora_path)
+            self.validate_new_adapter(lora_name, new_adapter)
+            self.configs[lora_name] = new_adapter
        except Exception as e:
            success = False
            error_message = (
@@ -168,6 +176,21 @@
            error_message=error_message,
        )

+    def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
+        """
+        Validate whether an adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
+        """
+        incompatible = self.memory_pool and not self.memory_pool.can_support(
+            lora_config
+        )
+        if incompatible:
+            raise ValueError(
+                f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                "We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
+                "you can specify the expected configs via --max-lora-rank and --lora-target-modules."
+            )
+
    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
        """
        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
@@ -214,7 +237,7 @@
            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
            if lora_path is not None:
                lora = self.loras[lora_path]
-                lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                lora_ranks[weight_indices[i]] = lora.config.r
                scalings[weight_indices[i]] = lora.scaling
        # Use pinned memory to avoid synchronizations during host-to-device transfer
@@ -319,7 +342,7 @@
                )
            else:
                weight_name = get_weight_name(
-                    module_name, self.lora_weight_names, LoRAType.LORA_A
+                    module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
                )
                module.set_lora_info(
                    self.memory_pool.get_tensor(
@@ -351,56 +374,65 @@
            i: {} for i in range(self.base_hf_config.num_hidden_layers)
        }

-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-        )
+        # The LoRA memory pool that manages the GPU buffers for active LoRA weights.
+        # It is initialized lazily when the first LoRA adapter is loaded.
+        self.memory_pool: Optional[LoRAMemoryPool] = None

    def update_state_from_configs(self):
        """
        Update the internal state of the LoRAManager based on the current `self.configs`. This method
        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
-
-        This includes:
-        - Initializing LoRA adapters if they are not already loaded.
-        - Collect all LoRA weight names based on the current loaded adapters.
-        - Lazily monkey-patching the base model to use LoRA layers where applicable.
-        - Preparing the GPU buffer pool for active LoRA weights.
-        """
-
-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        hf_target_module_names: Set[str] = set()
-        for config in self.configs.values():
-            hf_target_module_names.update(config.target_modules)
-        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
-        # Loads / unloads LoRA adapters based on the latest configs.
-        self.update_lora_adapters()
-
-        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
-        #
-        # Please note that the following update operations are "monotonic" by design, meaning that we update
-        # multiple places to support the new weight names when the first adapter targeting such weight names
-        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
-        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
-        # list of LoRA weight names is expected to be extremely finite and stable.
-        self.update_lora_weight_names(hf_target_module_names)
-        self.update_lora_modules(hf_target_module_names)
-        self.update_memory_buffers(max_lora_dim)
-
-    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        """
+        # Loads / unloads LoRA adapters based on the latest configs.
+        self.update_lora_adapters()
+
+        # Apply the latest LoRA configurations to the internal state for inferencing.
+        self.apply_lora_configs()
+
+    def apply_lora_configs(self):
+        """
+        Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
+
+        Notes:
+            - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
+              we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
+              LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
+              early CY25H2.
+        """
+        if self.memory_pool is None:
+            # Infer max_lora_rank and target_modules if not explicitly specified in server args.
+            if self.target_modules is None:
+                self.target_modules = set()
+                for config in self.configs.values():
+                    self.target_modules.update(config.target_modules)
+
+            if self.max_lora_rank is None:
+                self.max_lora_rank = max(
+                    [x.hf_config["r"] for x in self.configs.values()],
+                    default=0,
+                )
+
+            self.update_lora_weight_names()
+            self.update_lora_modules()
+            self.update_memory_buffers()
+        else:
+            # No-op if the memory pool can support the current LoRA configurations.
+            # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
+            # module is changed once FlashInfer backend is deprecated.
+            assert self.memory_pool.can_support(self.configs.values()), (
+                "LoRA memory pool cannot support the current LoRA configuration. "
+                "This should never happen as we should have validated adapter compatibility. "
+                "Please create a Github issue to report.",
+            )
+
+    def update_lora_weight_names(self):
        """
        Add new LoRA weight names if needed based on the current `self.configs`.
        """
        # Target lora weight names for lora_a and lora_b modules respectively.
-        for module in hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            self.lora_weight_names[0].update(lora_A)
-            self.lora_weight_names[1].update(lora_B)
+        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
+        self.lora_weight_names[0].update(lora_A)
+        self.lora_weight_names[1].update(lora_B)
@@ -434,21 +466,23 @@ class LoRAManager:
        # Additional checks for flashinfer backend
        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
        if self.lora_backend == "flashinfer":
-            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            lora_dims = set(x.r for x in self.configs.values())
            scalings = set(x.scaling for x in self.loras.values())
            assert (
                len(lora_dims) == 1 and len(scalings) == 1
            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

-    def update_memory_buffers(self, max_lora_dim: int):
-        """
-        Update the LoRA memory pool buffers based on the current LoRA configurations and update
-        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
-        are set or updated.
-        """
-
-        self.memory_pool.init_buffers(
-            self.lora_weight_names, self.base_model, max_lora_dim
+    def update_memory_buffers(self):
+        """(Re)initialize the LoRA memory pool based on the current configurations."""
+        self.memory_pool = LoRAMemoryPool(
+            base_hf_config=self.base_hf_config,
+            max_loras_per_batch=self.max_loras_per_batch,
+            dtype=self.dtype,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            max_lora_rank=self.max_lora_rank,
+            lora_weight_names=self.lora_weight_names,
+            base_model=self.base_model,
        )

    def set_lora_module(self, module_name, module):
@@ -456,11 +490,11 @@
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

-    def update_lora_modules(self, hf_target_names: Set[str]):
+    def update_lora_modules(self):
        # Target module names of customized layers defined in python/sglang/srt/layers
        # e.g., {"qkv_proj", "o_proj"}
        customized_target_names = get_customized_names_from_hf_names(
-            hf_target_names, self.base_model
+            self.target_modules, self.base_model
        )

        for module_name, module in self.base_model.named_modules():
...
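To make the new validation path concrete: `validate_new_adapter` delegates to `LoRAMemoryPool.can_support`, which accepts an adapter only if its rank fits within `max_lora_rank` and its normalized target modules are a subset of the modules the pool preallocated buffers for. The standalone toy below mirrors that rule; the class names are illustrative stand-ins, not the actual SGLang types.

```python
from dataclasses import dataclass
from typing import Set


@dataclass
class AdapterShape:
    # Simplified stand-in for LoRAConfig: only the fields the check cares about.
    rank: int
    target_modules: Set[str]


@dataclass
class PoolShape:
    # Simplified stand-in for LoRAMemoryPool: buffers are preallocated for these limits.
    max_lora_rank: int
    target_modules: Set[str]

    def can_support(self, adapter: AdapterShape) -> bool:
        # Rank must fit, and every target module must already have a preallocated buffer.
        return (
            adapter.rank <= self.max_lora_rank
            and adapter.target_modules <= self.target_modules
        )


pool = PoolShape(max_lora_rank=64, target_modules={"q_proj", "k_proj", "v_proj", "o_proj"})
print(pool.can_support(AdapterShape(rank=16, target_modules={"q_proj", "v_proj"})))  # True
print(pool.can_support(AdapterShape(rank=128, target_modules={"q_proj"})))           # False: rank too large
print(pool.can_support(AdapterShape(rank=8, target_modules={"gate_proj"})))          # False: module not preallocated
```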
-from typing import Callable, Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch

@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
+from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.utils import (
    ROW_PARALLELISM_LINEAR_LORA_NAMES,
    LoRAType,
    get_hidden_dim,
+    get_normalized_lora_weight_names,
    get_stacked_multiply,
    get_weight_name,
)
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
        dtype: torch.dtype,
        tp_size: int,
        tp_rank: int,
+        max_lora_rank: int,
+        lora_weight_names: Tuple[Set[str], Set[str]],
+        base_model: torch.nn.Module,
    ):
        self.base_hf_config: AutoConfig = base_hf_config
        self.num_layer: int = base_hf_config.num_hidden_layers
@@ -32,6 +37,10 @@
        self.dtype: torch.dtype = dtype
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank
+        self.max_lora_rank: int = max_lora_rank
+
+        # lora weight names for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names

        # Both A_buffer and B_buffer maps lora weight names to its buffer space.
        # A_buffer contains num_layer number of row-major tensors with shape
@@ -49,6 +58,31 @@
        # Here we don't initialize to None since None is a valid uid
        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

+        self.init_buffers(base_model)
+
+    def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
+        """
+        Check if the memory pool can support the given LoRA adapters.
+        """
+
+        def _can_support(config: LoRAConfig) -> bool:
+            """
+            Check if the memory pool can support a single LoRA adapter.
+            """
+            if config.r > self.max_lora_rank:
+                return False
+            weights_a, weights_b = get_normalized_lora_weight_names(
+                config.target_modules
+            )
+            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
+                self.lora_weight_names[1]
+            )
+
+        if isinstance(config, LoRAConfig):
+            return _can_support(config)
+        else:
+            return all(_can_support(x) for x in config)
+
    def get_lora_A_shape(
        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
    ) -> Tuple[int]:
@@ -82,25 +116,18 @@
            max_lora_dim,
        )

-    def init_buffers(
-        self,
-        lora_weight_names: Tuple[Set[str]],
-        base_model: torch.nn.Module,
-        max_lora_dim: int,
-    ):
-        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
+    def init_buffers(self, base_model: torch.nn.Module):
        device = next(base_model.parameters()).device

-        def update_buffer(
+        def init_buffer(
            buffer: Dict[str, List[torch.Tensor]],
            lora_weight_names: Set[str],
            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
        ):
-            new_weight_names = lora_weight_names - buffer.keys()
-            for module_name in new_weight_names:
-                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+            for module_name in lora_weight_names:
+                lora_shape = get_lora_shape_fn(
+                    module_name, base_model, self.max_lora_rank
+                )
                buffer[module_name] = [
                    torch.empty(
                        lora_shape,
@@ -110,15 +137,15 @@
                    for _ in range(self.num_layer)
                ]

-        update_buffer(
+        init_buffer(
            self.A_buffer,
-            lora_weight_names[0],
+            self.lora_weight_names[0],
            self.get_lora_A_shape,
        )

-        update_buffer(
+        init_buffer(
            self.B_buffer,
-            lora_weight_names[1],
+            self.lora_weight_names[1],
            self.get_lora_B_shape,
        )
...
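The reason `max_lora_rank` has to be fixed before the pool is built is visible in `init_buffers` above: every buffer is allocated once at the maximum supported rank, and adapters of smaller rank occupy only a slice of their slot. The toy example below illustrates the idea with made-up dimensions; it is not the actual buffer layout used by SGLang.

```python
import torch

max_lora_rank = 64        # capacity the pool is built for
hidden_dim = 4096         # made-up model hidden size
max_loras_per_batch = 2

# Preallocate one LoRA-A style buffer sized for the largest supported rank.
A_buffer = torch.empty(max_loras_per_batch, max_lora_rank, hidden_dim)


def load_adapter_weights(slot: int, lora_A: torch.Tensor) -> None:
    # An adapter of rank r <= max_lora_rank only fills the first r rows of its slot;
    # a larger rank cannot fit, which is exactly what the can_support() check guards against.
    r = lora_A.shape[0]
    assert r <= max_lora_rank, "adapter rank exceeds the preallocated pool capacity"
    A_buffer[slot, :r].copy_(lora_A)


load_adapter_weights(0, torch.randn(16, hidden_dim))  # rank-16 adapter fits in part of its slot
load_adapter_weights(1, torch.randn(64, hidden_dim))  # rank-64 adapter fills its slot exactly
```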
import re
from dataclasses import dataclass
from enum import Enum
-from typing import List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

import torch

@@ -106,9 +106,11 @@ def get_hidden_dim(
    raise NotImplementedError()


-def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
+def get_normalized_lora_weight_names(
+    target_modules: Iterable[str],
+) -> Tuple[set[str], set[str]]:
    """
-    Mapping a target module name to names of the normalized LoRA weights.
+    Mapping a list of target module name to names of the normalized LoRA weights.
    Returned tuple contains (name for Lora A, name for Lora B)
    """
    params_mapping = {
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
    }
-    stacked = params_mapping.get(name, ([name], [name]))
-    return stacked
+
+    result = (set(), set())
+    for name in target_modules:
+        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
+        result[0].update(lora_a)
+        result[1].update(lora_b)
+    return result


def get_stacked_multiply(module_name: str) -> int:
...
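For a quick sense of what the reworked helper returns, the standalone replica below keeps only the two `params_mapping` entries visible in the diff; any module not in the mapping falls through to itself via the default branch. The real mapping in SGLang contains additional entries, so treat this purely as a sketch.

```python
from typing import Iterable, Set, Tuple


def normalized_lora_weight_names(target_modules: Iterable[str]) -> Tuple[Set[str], Set[str]]:
    # Partial replica of params_mapping: only the entries shown in the diff.
    params_mapping = {
        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
    }
    result = (set(), set())
    for name in target_modules:
        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
        result[0].update(lora_a)  # normalized names for LoRA A weights
        result[1].update(lora_b)  # normalized names for LoRA B weights
    return result


# Unknown modules map to themselves; stacked modules expand on the B side.
print(normalized_lora_weight_names(["qkv_proj", "o_proj"]))
# -> ({'qkv_proj', 'o_proj'}, {'q_proj', 'kv_proj', 'o_proj'})  (set ordering may vary)
```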
@@ -891,6 +891,8 @@ class ModelRunner:
            lora_backend=self.server_args.lora_backend,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
+            max_lora_rank=self.server_args.max_lora_rank,
+            target_modules=self.server_args.lora_target_modules,
        )
        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
        if result.success:
...
@@ -134,6 +134,8 @@ class ServerArgs:
    preferred_sampling_params: Optional[str] = None

    # LoRA
+    max_lora_rank: Optional[int] = None
+    lora_target_modules: Optional[List[str]] = None
    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
    max_loras_per_batch: int = 8
    lora_backend: str = "triton"
@@ -1129,6 +1131,28 @@
        )

        # LoRA
+        parser.add_argument(
+            "--max-lora-rank",
+            default=ServerArgs.max_lora_rank,
+            type=int,
+            help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
+        )
+        parser.add_argument(
+            "--lora-target-modules",
+            type=str,
+            choices=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ],
+            nargs="*",
+            default=None,
+            help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
+        )
        parser.add_argument(
            "--lora-paths",
            type=str,
...
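Putting the new server arguments together, a server intended to accept dynamically loaded adapters of up to rank 64 on the attention projections could be launched roughly as follows. The model path, adapter path, and flag values are placeholders; only `--max-lora-rank` and `--lora-target-modules` are the flags added by this change.

```python
import shlex
import subprocess

# Placeholder command: adjust the model path, adapter, and module list for your setup.
cmd = (
    "python3 -m sglang.launch_server "
    "--model-path meta-llama/Meta-Llama-3.1-8B-Instruct "
    "--lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora "
    "--max-loras-per-batch 2 --lora-backend triton "
    # Reserve pool capacity for any adapter with rank <= 64 ...
    "--max-lora-rank 64 "
    # ... and preallocate buffers for these target modules.
    "--lora-target-modules q_proj k_proj v_proj o_proj"
)
server = subprocess.Popen(shlex.split(cmd))
```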
@@ -505,6 +505,8 @@ class SRTRunner:
        torchao_config: Optional[str] = None,
        cuda_graph_max_bs: int = 4,
        sleep_on_idle=False,
+        max_lora_rank: Optional[int] = None,
+        lora_target_modules: Optional[List[str]] = None,
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
@@ -543,6 +545,8 @@ class SRTRunner:
            cuda_graph_max_bs=cuda_graph_max_bs,
            disable_custom_all_reduce=disable_custom_all_reduce,
            sleep_on_idle=sleep_on_idle,
+            max_lora_rank=max_lora_rank,
+            lora_target_modules=lora_target_modules,
            **spec_kwargs,
        )
...
This diff is collapsed.
@@ -17,7 +17,7 @@ suites = {
        TestFile("models/lora/test_lora_backend.py", 99),
        TestFile("models/lora/test_multi_lora_backend.py", 60),
        TestFile("models/lora/test_lora_cuda_graph.py", 250),
-        TestFile("models/lora/test_lora_update.py", 400),
+        TestFile("models/lora/test_lora_update.py", 700),
        TestFile("models/test_embedding_models.py", 73),
        # TestFile("models/test_clip_models.py", 52),
        TestFile("models/test_encoder_embedding_models.py", 100),
...