Unverified Commit 1998ce40 authored by Lifu Huang, committed by GitHub


Refactor LoRAManager and LoRAMemoryPool state management logic for dynamic LoRA loading support (#7412)
parent 72676cd6
......@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
from typing import Dict, List, Set, Tuple
from typing import Dict, Set, Tuple
import torch
......@@ -45,7 +45,6 @@ class LoRAManager:
def __init__(
self,
base_model: torch.nn.Module,
lora_paths: Dict[str, str],
base_hf_config: AutoConfig,
max_loras_per_batch: int,
load_config: LoadConfig,
......@@ -55,7 +54,6 @@ class LoRAManager:
tp_rank: int = 0,
):
self.base_model: torch.nn.Module = base_model
self.lora_paths: Dict[str, str] = lora_paths
self.base_hf_config: AutoConfig = base_hf_config
self.max_loras_per_batch: int = max_loras_per_batch
self.load_config: LoadConfig = load_config
......@@ -69,8 +67,8 @@ class LoRAManager:
backend_type = get_backend_from_name(lora_backend)
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
self.init_loras()
self.init_lora_memory_pool()
# Initialize mutable internal state of the LoRAManager.
self.init_state()
def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
......@@ -100,72 +98,49 @@ class LoRAManager:
],
)
def init_loras(self):
# Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {}
def load_lora_adapters(self, lora_paths: Dict[str, str]):
"""
Load LoRA adapters from the specified paths.
TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
Args:
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
If a LoRA adapter is already loaded, it will be skipped with a warning.
"""
for lora_name, lora_path in lora_paths.items():
if lora_name in self.loras:
logger.warning(
f"LoRA adapter {lora_name} is already loaded."
"If you want to reload it, please unload it first."
)
continue
# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
self.hf_target_names: Set[str] = set()
for name, path in self.lora_paths.items():
self.configs[name] = LoRAConfig(path)
self.hf_target_names.update(self.configs[name].target_modules)
self.configs[lora_name] = LoRAConfig(lora_path)
# Target lora weight names for lora_a and lora_b modules respectively.
weights_A: List[str] = []
weights_B: List[str] = []
for module in self.hf_target_names:
lora_A, lora_B = get_normalized_lora_weight_names(module)
weights_A += lora_A
weights_B += lora_B
self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
self.update_state_from_configs()
# load all weights to cpu
self.loras: Dict[str, LoRAAdapter] = {}
for name in self.lora_paths.keys():
lora_adapter = LoRAAdapter(
name,
self.configs[name],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# misc lora configs
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
def unload_lora_adapters(self, lora_names: Set[str]):
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
delete the corresponding LoRA modules.
if self.lora_backend == "flashinfer":
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
scaling = list(self.loras.values())[0].scaling
assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
assert all(x.scaling == scaling for x in self.loras.values())
Args:
lora_names (Set[str]): A set of LoRA adapter names to unload.
"""
for lora_name in lora_names:
if lora_name in self.loras:
del self.configs[lora_name]
else:
logger.warning(f"LoRA adapter {lora_name} is not loaded.")
# Convert original model layers to layers with LoRA
self.convert_to_lora_layers()
def init_lora_memory_pool(self):
# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config,
self.max_loras_per_batch,
self.max_lora_dim,
self.dtype,
self.tp_size,
self.tp_rank,
self.lora_modules,
)
# Initialize target lora modules in memory pool
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
self.update_state_from_configs()
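# --- Illustrative usage sketch (not part of this commit's diff) ---
# The refactoring replaces the constructor-time `lora_paths` argument with explicit
# load/unload calls, which is what enables dynamic LoRA management. Adapter names and
# paths below are hypothetical; `lora_manager` is assumed to be constructed as in
# ModelRunner.init_lora_manager.
def _example_manage_adapters(lora_manager):
    # Load two adapters; names that are already loaded are skipped with a warning.
    lora_manager.load_lora_adapters({
        "sql_adapter": "/models/loras/sql_adapter",
        "chat_adapter": "/models/loras/chat_adapter",
    })
    # Later, drop one of them again; this removes its config and CPU weights and
    # refreshes the manager state via update_state_from_configs().
    lora_manager.unload_lora_adapters({"sql_adapter"})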
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
......@@ -267,9 +242,16 @@ class LoRAManager:
)
self.lora_backend.set_batch_info(batch_info)
# call set_lora_info for each lora modules
for layer_id, modules in self.lora_modules.items():
for module_name, module in modules:
# TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
# this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
self.update_lora_info()
def update_lora_info(self):
"""
Update all LoRA modules to associate them with the latest memory buffer.
"""
for layer_id, layer_modules in self.lora_modules.items():
for module_name, module in layer_modules.items():
if "qkv_proj" in module_name:
module.set_lora_info(
self.memory_pool.get_tensor(
......@@ -295,23 +277,139 @@ class LoRAManager:
),
)
def init_state(self):
"""
Initialize the internal (mutable) state of the LoRAManager.
This state is mutated by `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
"""
# Configs of all active LoRA adapters.
self.configs: Dict[str, LoRAConfig] = {}
# LoRA adapter weights cached in CPU memory.
self.loras: Dict[str, LoRAAdapter] = {}
# Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str]] = (set(), set())
# Look-up table mapping (layer_index, module_name) to the corresponding LoRA module.
self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
i: {} for i in range(self.base_hf_config.num_hidden_layers)
}
# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config,
self.max_loras_per_batch,
self.dtype,
self.tp_size,
self.tp_rank,
)
def update_state_from_configs(self):
"""
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
This includes:
- Initializing LoRA adapters if they are not already loaded.
- Collecting all LoRA weight names based on the currently loaded adapters.
- Lazily monkey-patching the base model to use LoRA layers where applicable.
- Preparing the GPU buffer pool for active LoRA weights.
"""
# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
hf_target_module_names: Set[str] = set()
for config in self.configs.values():
hf_target_module_names.update(config.target_modules)
# default=0 guards against the edge case where all adapters have been unloaded.
max_lora_dim: int = max((x.hf_config["r"] for x in self.configs.values()), default=0)
# Loads / unloads LoRA adapters based on the latest configs.
self.update_lora_adapters()
# Lazily update state for new LoRA weight names (e.g., qkv_proj) as needed.
#
# Please note that the following update operations are "monotonic" by design: we add support
# for a new weight name in multiple places when the first adapter targeting it is loaded, but
# we never "roll back" that support (e.g., convert a LoRA layer back to a base layer) even if
# all associated adapters are unloaded later, for both simplicity and practicality: the set of
# LoRA weight names is expected to be small and stable. (An illustrative sketch follows this
# method.)
self.update_lora_weight_names(hf_target_module_names)
self.update_lora_modules(hf_target_module_names)
self.update_memory_buffers(max_lora_dim)
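# --- Illustrative sketch (not part of this commit's diff) ---
# Demonstrates the "monotonic" behavior described above. Adapter names and paths are
# hypothetical, and the q_proj -> "qkv_proj" normalization on the LoRA A side is an
# assumption based on the stacked weight-name examples elsewhere in this diff.
def _example_monotonic_state(lora_manager):
    lora_manager.load_lora_adapters({
        "adapter_q": "/models/loras/adapter_q",  # assumed to target q_proj
        "adapter_o": "/models/loras/adapter_o",  # assumed to target o_proj
    })
    # Loading the first adapter targeting q_proj adds support for the stacked name.
    assert "qkv_proj" in lora_manager.lora_weight_names[0]

    lora_manager.unload_lora_adapters({"adapter_q"})
    # The adapter's config and CPU weights are gone...
    assert "adapter_q" not in lora_manager.loras
    # ...but weight-name support, converted LoRA layers, and GPU buffers remain.
    assert "qkv_proj" in lora_manager.lora_weight_names[0]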
def update_lora_weight_names(self, hf_target_names: Set[str]):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""
# Target lora weight names for lora_a and lora_b modules respectively.
for module in hf_target_names:
lora_A, lora_B = get_normalized_lora_weight_names(module)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
def update_lora_adapters(self):
"""
Update the LoRA adapters in CPU memory based on the current `self.configs`.
It loads any new adapters that are not already loaded, and unloads any adapters
that are no longer in `self.configs` (e.g., unloaded).
"""
# Load new adapter weights to cpu
for name, config in self.configs.items():
if name not in self.loras:
logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
lora_adapter = LoRAAdapter(
name,
config,
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# Clean up unused LoRA adapters (iterate over a copy, since entries may be deleted)
for name in list(self.loras):
if name not in self.configs:
logger.info(f"Unloading LoRA adapter {name}")
del self.loras[name]
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
if self.lora_backend == "flashinfer":
lora_dims = set(x.hf_config["r"] for x in self.configs.values())
scalings = set(x.scaling for x in self.loras.values())
assert (
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def update_memory_buffers(self, max_lora_dim: int):
"""
Update the LoRA memory pool buffers based on the current LoRA configurations.
This method should be called after the LoRA configurations are set or updated.
"""
self.memory_pool.init_buffers(
self.lora_weight_names, self.base_model, max_lora_dim
)
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(module, self.lora_backend)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def convert_to_lora_layers(self):
def update_lora_modules(self, hf_target_names: Set[str]):
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
self.hf_target_names, self.base_model
hf_target_names, self.base_model
)
# Monkey patch to use the LoRA version layers
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
i: [] for i in range(self.base_hf_config.num_hidden_layers)
}
for module_name, module in self.base_model.named_modules():
# TODO (lifuhuang): in the future, we should consider generalizing the
# should_apply_lora function to support mapping by full module name instead
......@@ -326,6 +424,7 @@ class LoRAManager:
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
layer_id = get_layer_id(module_name)
self.lora_modules[layer_id].append(
(module_name, self.set_lora_module(module_name, module))
)
if module_name not in self.lora_modules[layer_id]:
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
)
from typing import Dict, List, Optional, Set, Tuple
from typing import Callable, Dict, List, Optional, Set, Tuple
import torch
......@@ -22,21 +22,16 @@ class LoRAMemoryPool:
self,
base_hf_config: AutoConfig,
max_loras_per_batch: int,
max_lora_dim: int,
dtype: torch.dtype,
tp_size: int,
tp_rank: int,
lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]],
):
self.base_hf_config: AutoConfig = base_hf_config
self.num_layer: int = base_hf_config.num_hidden_layers
self.max_loras_per_batch: int = max_loras_per_batch
self.max_lora_dim: int = max_lora_dim
self.dtype: torch.dtype = dtype
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = lora_modules
# Both A_buffer and B_buffer map LoRA weight names to their buffer space.
# A_buffer contains num_layer row-major tensors with shape
......@@ -55,79 +50,84 @@ class LoRAMemoryPool:
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
def get_lora_A_shape(
self, module_name: str, base_model: torch.nn.Module
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
) -> Tuple[int]:
"""
Given a module_name (which may be a stacked name), return the shape of the corresponding LoRA A weight buffer.
"""
input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
c = get_stacked_multiply(module_name)
if self.tp_size > 1:
if module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
input_dim = divide(input_dim, self.tp_size)
if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
input_dim = divide(input_dim, self.tp_size)
return (
self.max_loras_per_batch,
self.max_lora_dim * c,
max_lora_dim * c,
input_dim,
)
def get_lora_B_shape(
self, module_name: str, base_model: torch.nn.Module
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
) -> Tuple[int]:
"""
Given a module_name (which may be a stacked name), return the shape of the corresponding LoRA B weight buffer.
"""
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
c = get_stacked_multiply(module_name)
if self.tp_size > 1:
if module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
output_dim = divide(output_dim, self.tp_size)
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
output_dim = divide(output_dim, self.tp_size)
return (
c,
self.max_loras_per_batch,
output_dim,
self.max_lora_dim,
max_lora_dim,
)
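# --- Worked shape example (not part of this commit's diff) ---
# All concrete numbers below are illustrative assumptions: rank 16, batch capacity 4,
# hidden size 4096, tp_size 1, and a non-stacked module such as o_proj (stacking
# factor c == 1, with input and output dims both equal to the hidden size).
max_loras_per_batch = 4
max_lora_dim = 16   # LoRA rank r
hidden_size = 4096
c = 1               # get_stacked_multiply("o_proj") assumed to be 1

lora_a_shape = (max_loras_per_batch, max_lora_dim * c, hidden_size)  # (4, 16, 4096)
lora_b_shape = (c, max_loras_per_batch, hidden_size, max_lora_dim)   # (1, 4, 4096, 16)
# init_buffers below allocates one tensor of each shape per transformer layer.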
def init_buffers(
self,
lora_weight_names: Tuple[Set[str]],
base_model: torch.nn.Module,
max_lora_dim: int,
):
# lora_weight_names is a tuple of two sets: the weight names to support for LoRA A and
# LoRA B respectively, e.g., ({"qkv_proj", "o_proj"}, {"q_proj", "kv_proj", "o_proj"})
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
device = next(base_model.parameters()).device
# Init A tensor, column_major=False
for module_A in lora_weight_names[0]:
lora_A_shape = self.get_lora_A_shape(module_A, base_model)
self.A_buffer[module_A] = [
torch.empty(
lora_A_shape,
dtype=self.dtype,
device=device,
)
for _ in range(self.num_layer)
]
# Init B tensor, column_major=True
for module_B in lora_weight_names[1]:
lora_B_shape = self.get_lora_B_shape(module_B, base_model)
self.B_buffer[module_B] = [
torch.empty(
lora_B_shape,
dtype=self.dtype,
device=device,
)
for _ in range(self.num_layer)
]
def update_buffer(
buffer: Dict[str, List[torch.Tensor]],
lora_weight_names: Set[str],
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
):
new_weight_names = lora_weight_names - buffer.keys()
for module_name in new_weight_names:
lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
buffer[module_name] = [
torch.empty(
lora_shape,
dtype=self.dtype,
device=device,
)
for _ in range(self.num_layer)
]
update_buffer(
self.A_buffer,
lora_weight_names[0],
self.get_lora_A_shape,
)
update_buffer(
self.B_buffer,
lora_weight_names[1],
self.get_lora_B_shape,
)
def prepare_lora_batch(
self,
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
# Prioritize empty slots
......@@ -147,14 +147,19 @@ class LoRAMemoryPool:
for uid in cur_uids:
if uid not in self.uid_to_buffer_id:
buffer_id = get_available_buffer_slot()
lora_adapter = lora_adapters.get(uid, None)
self.load_lora_weight_to_buffer(
uid, buffer_id, lora_adapters.get(uid, None)
uid, buffer_id, lora_adapter, lora_modules
)
self.uid_to_buffer_id[uid] = buffer_id
self.buffer_id_to_uid[buffer_id] = uid
def load_lora_weight_to_buffer(
self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
self,
uid: str,
buffer_id: int,
lora_adapter: LoRAAdapter,
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
):
def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
assert (
......@@ -186,8 +191,8 @@ class LoRAMemoryPool:
temp_B_buffer[lora_weight_name] = weights
if self.tp_size > 1:
cur_layer_modules = self.lora_modules[layer_id]
for module_name, module in cur_layer_modules:
cur_layer_modules = lora_modules[layer_id]
for module_name, module in cur_layer_modules.items():
if "qkv_proj" in module_name:
temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
temp_A_buffer["qkv_proj"], self.tp_rank
......@@ -236,7 +241,6 @@ class LoRAMemoryPool:
def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType
) -> torch.Tensor:
if lora_type == LoRAType.LORA_A:
return self.A_buffer[weight_name][layer_id]
......
......@@ -108,7 +108,7 @@ def get_hidden_dim(
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
"""
Mapping a target module name to names of the normized LoRA weights.
Mapping a target module name to names of the normalized LoRA weights.
The returned tuple contains (names for LoRA A, names for LoRA B).
"""
params_mapping = {
......
......@@ -278,6 +278,10 @@ class ModelRunner:
self.apply_torch_tp()
# Init lora
# TODO (lifuhuang): once dynamic LoRA loading / unloading is supported, we should add a
# new server arg `enable_lora` that explicitly controls whether to init the LoRA manager,
# since it is perfectly valid to start a server with an empty lora_paths and load LoRA
# adapters dynamically later.
if server_args.lora_paths is not None:
self.init_lora_manager()
......@@ -796,7 +800,6 @@ class ModelRunner:
def init_lora_manager(self):
self.lora_manager = LoRAManager(
base_model=self.model,
lora_paths=self.server_args.lora_paths,
base_hf_config=self.model_config.hf_config,
max_loras_per_batch=self.server_args.max_loras_per_batch,
load_config=self.load_config,
......@@ -805,6 +808,7 @@ class ModelRunner:
tp_size=self.tp_size,
tp_rank=self.tp_rank,
)
self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
logger.info("LoRA manager ready.")
def profile_max_num_token(self, total_gpu_memory: int):
......