Unverified Commit f8a173bb authored by Lifu Huang's avatar Lifu Huang Committed by GitHub
Browse files

Improve LoRA Perf by Deprecating FlashInfer and Eliminating Redundant Tensor Ops (#8940)

parent 6b847a9a
......@@ -35,7 +35,7 @@
"\n",
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
"\n",
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we only support Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
"\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n",
......
......@@ -5,22 +5,6 @@ import torch
from sglang.srt.lora.utils import LoRABatchInfo
def get_fuse_output_add_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
class BaseLoRABackend:
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.
......@@ -28,15 +12,11 @@ class BaseLoRABackend:
Args:
name: name of backend
batch_info: information of current batch for use
fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of adding will be fused into kernel
"""
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_add = get_fuse_output_add_from_name(name)
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
......@@ -126,8 +106,8 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
return TritonLoRABackend
elif name == "flashinfer":
from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
return FlashInferLoRABackend
raise ValueError(
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."
)
else:
raise ValueError(f"Invalid backend: {name}")
from typing import Tuple
import torch
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
from flashinfer import SegmentGEMMWrapper
class FlashInferLoRABackend(BaseLoRABackend):
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
super().__init__(name, batch_info)
# Set up SGemm Wrapper from flashinfer
# FIXME wait for flashinfer segment gemm update
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return (
self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
* self.batch_info.scalings[0]
)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
q_lora_b, kv_lora_b = qkv_lora_b
lora_rank = kv_lora_b.shape[-1]
output_dim_q = q_lora_b.shape[-2]
output_dim_kv = kv_lora_b.shape[-2]
lora_output = torch.empty(
(x.shape[0], output_dim_q + 2 * output_dim_kv),
device=x.device,
dtype=x.dtype,
)
# q
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
)
# kv
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
)
)
lora_output[
:, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
] = self.run_lora_b_sgemm(
x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
weights=kv_lora_b[1],
)
return lora_output * self.batch_info.scalings[0]
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
lora_rank = gate_up_lora_b[0].shape[-1]
output_dim = gate_up_lora_b[0].shape[-2]
# Shape of lora_a_output: (s, 2 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
lora_output = torch.empty(
(x.shape[0], 2 * output_dim),
device=x.device,
dtype=x.dtype,
)
# Compute lora for gate and up proj respectively
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(),
weights=gate_up_lora_b[0],
)
lora_output[:, output_dim:] = self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank:].contiguous(),
weights=gate_up_lora_b[1],
)
return lora_output * self.batch_info.scalings[0]
from typing import List, Tuple
import torch
from torch import nn
......@@ -79,18 +77,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
x=lora_a_output,
weights=self.B_buffer,
base_output=base_output,
)
return lora_output
def forward(self, input_: torch.Tensor):
# duplicate the logic in ColumnParallelLinear
......@@ -135,37 +128,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
):
self.set_lora = True
self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b:
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
if getattr(self, "B_buffer_gate_up", None) is None:
self.B_buffer_gate_up = torch.empty(
(
B_buffer[0].shape[0],
2 * B_buffer[0].shape[1],
B_buffer[0].shape[2],
),
dtype=B_buffer[0].dtype,
device=B_buffer[0].device,
)
self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
else:
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
self.B_buffer_gate_up = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
lora_output = self.lora_backend.run_gate_up_lora(
x,
self.A_buffer_gate_up,
self.B_buffer_gate_up,
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
x=x,
gate_up_lora_a=self.A_buffer_gate_up,
gate_up_lora_b=self.B_buffer_gate_up,
base_output=base_output,
)
return lora_output
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A
......@@ -173,9 +145,16 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
# Since the outputs for both gate and up are identical, we use a random one.
shard_size = self.base_layer.output_partition_sizes[0]
gate_size = self.base_layer.output_sizes[0]
start_idx = tp_rank * shard_size
end_idx = (tp_rank + 1) * shard_size
return B[:, start_idx:end_idx, :]
return torch.concat(
(
B[start_idx:end_idx, :],
B[gate_size + start_idx : gate_size + end_idx],
),
dim=0,
)
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
......@@ -185,86 +164,46 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_backend)
q_proj_shard_size = self.base_layer.q_proj_shard_size
kv_proj_shard_size = self.base_layer.kv_proj_shard_size
self.output_offset = torch.tensor(
[
0,
q_proj_shard_size,
q_proj_shard_size + kv_proj_shard_size,
q_proj_shard_size + 2 * kv_proj_shard_size,
],
dtype=torch.int32,
device=next(self.base_layer.parameters()).device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(q_proj_shard_size, kv_proj_shard_size)
def set_lora_info(
self,
A_buffer_qkv: torch.Tensor,
B_buffer_q: torch.Tensor,
B_buffer_kv: torch.Tensor,
B_buffer_qkv: torch.Tensor,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
if self.lora_backend.fuse_stacked_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
if getattr(self, "B_buffer_qkv", None) is None:
self.B_buffer_qkv = torch.empty(
(
B_buffer_q[0].shape[0],
output_dim_q + 2 * output_dim_kv,
B_buffer_q[0].shape[2],
),
dtype=B_buffer_q[0].dtype,
device=B_buffer_q[0].device,
)
self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
B_buffer_kv[0]
)
self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
B_buffer_kv[1]
)
# Offsets of q/k/v in output dimension
if getattr(self, "output_offset", None) is None:
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
self.B_buffer_qkv = B_buffer_qkv
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
if self.lora_backend.fuse_stacked_lora_b:
backend_kwargs["output_offset"] = self.output_offset
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
x=x,
qkv_lora_a=self.A_buffer_qkv,
qkv_lora_b=self.B_buffer_qkv,
base_output=base_output,
output_offset=self.output_offset,
max_qkv_out_dim=self.max_qkv_out_dim,
)
return lora_output
def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
return A
def slice_lora_b_weights(
self, B: List[torch.Tensor], tp_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
B_q, B_kv = B
def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int) -> torch.Tensor:
base_layer = self.base_layer
q_proj_shard_size = base_layer.q_proj_shard_size
kv_proj_shard_size = base_layer.kv_proj_shard_size
......@@ -277,7 +216,19 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
kv_start_idx = kv_proj_shard_size * kv_shard_id
kv_end_idx = kv_start_idx + kv_proj_shard_size
return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
q_size, k_size, _ = base_layer.output_sizes
B_q_shard = B[q_start_idx:q_end_idx, :]
B_k_shard = B[q_size + kv_start_idx : q_size + kv_end_idx, :]
B_v_shard = B[q_size + k_size + kv_start_idx : q_size + k_size + kv_end_idx, :]
return torch.concat(
(
B_q_shard,
B_k_shard,
B_v_shard,
),
dim=0,
)
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
......@@ -294,18 +245,13 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_add
else base_output + lora_output
x=lora_a_output,
weights=self.B_buffer,
base_output=base_output,
)
return lora_output
def forward(self, input_: torch.Tensor):
# duplicate the logic in RowParallelLinear
......
......@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
q_name = weight_name
k_name = weight_name.replace("q_proj", "k_proj")
v_name = weight_name.replace("q_proj", "v_proj")
kv_name = weight_name.replace("q_proj", "kv_proj")
qkv_name = weight_name.replace("q_proj", "qkv_proj")
# If k_proj doesn't have lora, initialize it to zero
......@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
if "k_proj" in target_module
else torch.zeros_like(weights[v_name])
)
if "lora_A" in weight_name:
weights[qkv_name] = torch.cat(
(
weights[q_name],
k_proj_weight,
weights[v_name],
),
0,
)
weights.pop(q_name)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
else:
weights[kv_name] = torch.stack(
[
k_proj_weight,
weights[v_name],
],
dim=0,
)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
weights[qkv_name] = torch.cat(
(
weights[q_name],
k_proj_weight,
weights[v_name],
),
0,
)
weights.pop(q_name)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
elif "qkv_proj" in weight_name:
# If qkv_proj is already stacked, we normalize it following the SGL convention.
qkv_name = weight_name
q_name = weight_name.replace("qkv_proj", "q_proj")
k_name = weight_name.replace("qkv_proj", "k_proj")
v_name = weight_name.replace("qkv_proj", "v_proj")
kv_name = weight_name.replace("qkv_proj", "kv_proj")
if "lora_A" in weight_name:
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
else:
head_size = (
self.base_hf_config.hidden_size
// self.base_hf_config.num_attention_heads
)
weights[q_name], k_proj_weight, v_proj_weight = torch.split(
weights[qkv_name],
[
head_size * self.base_hf_config.num_attention_heads,
head_size * self.base_hf_config.num_key_value_heads,
head_size * self.base_hf_config.num_key_value_heads,
],
dim=0,
)
weights[kv_name] = torch.stack(
[k_proj_weight, v_proj_weight],
dim=0,
)
# else: no-op as LoRA B weight is already stacked.
def normalize_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
......@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if up_name not in weights:
weights[up_name] = torch.zeros_like(weights[weight_name])
# FIXME: Add gate-only support for flashinfer in future implementations
assert self.lora_backend.name == "triton", (
f"LoRA weight initialization currently only supported for 'triton' backend. "
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
f"or consider implementing custom initialization logic for other backends."
)
if "lora_A" in weight_name:
weights[gate_up_name] = torch.cat(
(weights[weight_name], weights[up_name]), 0
)
else:
weights[gate_up_name] = torch.stack(
[weights[weight_name], weights[up_name]], dim=0
)
weights[gate_up_name] = torch.cat(
(weights[weight_name], weights[up_name]), 0
)
weights.pop(weight_name)
if up_name in weights:
weights.pop(up_name)
......@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
else:
output_dim = weights[gate_up_name].shape[0] // 2
weights[gate_up_name] = torch.stack(
[
weights[gate_up_name][:output_dim, :],
weights[gate_up_name][output_dim:, :],
],
dim=0,
)
# else: no-op as LoRA B weight is already stacked.
......@@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
LoRABatchInfo,
LoRAType,
get_customized_names_from_hf_names,
get_layer_id,
get_normalized_lora_weight_names,
get_weight_name,
......@@ -345,40 +344,19 @@ class LoRAManager:
)
self.lora_backend.set_batch_info(batch_info)
# TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
# this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
self.update_lora_info()
def update_lora_info(self):
"""
Update all LoRA modules to associate them with the latest memory buffer.
"""
for layer_id, layer_modules in enumerate(self.lora_modules):
for module_name, module in layer_modules.items():
if "qkv_proj" in module_name:
module.set_lora_info(
self.memory_pool.get_tensor(
"qkv_proj", layer_id, LoRAType.LORA_A
),
self.memory_pool.get_tensor(
"q_proj", layer_id, LoRAType.LORA_B
),
self.memory_pool.get_tensor(
"kv_proj", layer_id, LoRAType.LORA_B
),
)
else:
weight_name = get_weight_name(
module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.memory_pool.get_tensor(
weight_name, layer_id, LoRAType.LORA_A
),
self.memory_pool.get_tensor(
weight_name, layer_id, LoRAType.LORA_B
),
)
weight_name = get_weight_name(
module_name, self.memory_pool.lora_weight_names
)
module.set_lora_info(
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
)
def init_state(
self,
......@@ -405,6 +383,7 @@ class LoRAManager:
self.init_lora_weight_names()
self.init_lora_modules()
self.init_memory_pool()
self.update_lora_info()
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
# Configs of all active LoRA adapters, indexed by LoRA ID.
......@@ -461,9 +440,9 @@ class LoRAManager:
Add new LoRA weight names if needed based on the current `self.configs`.
"""
# Target lora weight names for lora_a and lora_b modules respectively.
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
self.target_modules
)
def load_lora_weights(self, lora_ref: LoRARef):
"""
......@@ -479,15 +458,6 @@ class LoRAManager:
lora_adapter.initialize_weights()
self.loras[lora_ref.lora_id] = lora_adapter
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
if self.lora_backend == "flashinfer":
lora_dims = set(x.r for x in self.configs.values())
scalings = set(x.scaling for x in self.loras.values())
assert (
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def init_memory_pool(self):
"""(Re)initialize the LoRA memory pool based on the current configurations."""
self.memory_pool = LoRAMemoryPool(
......@@ -512,12 +482,6 @@ class LoRAManager:
{} for _ in range(self.base_hf_config.num_hidden_layers)
]
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
self.target_modules, self.base_model
)
for module_name, module in self.base_model.named_modules():
# TODO (lifuhuang): in the future, we should consider generalizing the
# should_apply_lora function to support mapping by full module name instead
......@@ -530,7 +494,7 @@ class LoRAManager:
continue
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
if module_name.split(".")[-1] in self.lora_weight_names:
layer_id = get_layer_id(module_name)
self.lora_modules[layer_id][module_name] = self.set_lora_module(
module_name, module
......
......@@ -52,7 +52,7 @@ class LoRAMemoryPool:
tp_size: int,
tp_rank: int,
max_lora_rank: int,
lora_weight_names: Tuple[Set[str], Set[str]],
lora_weight_names: Set[str],
base_model: torch.nn.Module,
):
self.base_hf_config: AutoConfig = base_hf_config
......@@ -62,9 +62,7 @@ class LoRAMemoryPool:
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: int = max_lora_rank
# lora weight names for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
self.lora_weight_names: Set[str] = lora_weight_names
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
......@@ -97,12 +95,8 @@ class LoRAMemoryPool:
"""
if config.r > self.max_lora_rank:
return False
weights_a, weights_b = get_normalized_lora_weight_names(
config.target_modules
)
return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
self.lora_weight_names[1]
)
weights = get_normalized_lora_weight_names(config.target_modules)
return weights.issubset(self.lora_weight_names)
if isinstance(config, LoRAConfig):
return _can_support(config)
......@@ -132,11 +126,9 @@ class LoRAMemoryPool:
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
"""
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
c = get_stacked_multiply(module_name)
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
output_dim = divide(output_dim, self.tp_size)
return (
c,
self.max_loras_per_batch,
output_dim,
max_lora_dim,
......@@ -165,13 +157,13 @@ class LoRAMemoryPool:
init_buffer(
self.A_buffer,
self.lora_weight_names[0],
self.lora_weight_names,
self.get_lora_A_shape,
)
init_buffer(
self.B_buffer,
self.lora_weight_names[1],
self.lora_weight_names,
self.get_lora_B_shape,
)
......@@ -246,7 +238,7 @@ class LoRAMemoryPool:
return
assert lora_adapter is not None
lora_rank = lora_adapter.config.hf_config["r"]
lora_rank = lora_adapter.config.r
for layer_id in range(self.num_layer):
layer_weights = lora_adapter.layers[layer_id].weights
temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
......@@ -256,73 +248,38 @@ class LoRAMemoryPool:
weight_name: None for weight_name in self.B_buffer
}
for name, weights in layer_weights.items():
lora_weight_name = get_weight_name(name, self.lora_weight_names)
if "lora_A" in name:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_A
)
temp_A_buffer[lora_weight_name] = weights
else:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_B
)
temp_B_buffer[lora_weight_name] = weights
if self.tp_size > 1:
cur_layer_modules = lora_modules[layer_id]
for module_name, module in cur_layer_modules.items():
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
)
weight_name = get_weight_name(module_name, self.lora_weight_names)
if temp_A_buffer[weight_name] is None:
# Skip weight slicing if the weight is not present in the adapter
continue
if "qkv_proj" in module_name:
temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
temp_A_buffer["qkv_proj"], self.tp_rank
)
temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
module.slice_lora_b_weights(
[temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
self.tp_rank,
)
)
else:
# TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
# Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
# B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
# FlashInfer LoRA backend.
temp_A_buffer[weight_name] = module.slice_lora_a_weights(
temp_A_buffer[weight_name], self.tp_rank
)
temp_B_buffer[weight_name] = module.slice_lora_b_weights(
temp_B_buffer[weight_name], self.tp_rank
)
temp_A_buffer[weight_name] = module.slice_lora_a_weights(
temp_A_buffer[weight_name], self.tp_rank
)
temp_B_buffer[weight_name] = module.slice_lora_b_weights(
temp_B_buffer[weight_name], self.tp_rank
)
for name, weights in temp_A_buffer.items():
c = get_stacked_multiply(name)
buffer_view = self.A_buffer[name][layer_id][buffer_id][
: lora_rank * c, :
]
target_buffer = self.A_buffer[name][layer_id]
buffer_view = target_buffer[buffer_id, : lora_rank * c, :]
load_lora_weight_tensor(buffer_view, weights)
for name, weights in temp_B_buffer.items():
c = get_stacked_multiply(name)
if c > 1:
for stacked_id in range(c):
buffer_view = self.B_buffer[name][layer_id][stacked_id][
buffer_id
][:, :lora_rank]
weight_slice = (
weights[stacked_id] if weights is not None else None
)
load_lora_weight_tensor(buffer_view, weight_slice)
else:
buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
:, :lora_rank
]
load_lora_weight_tensor(buffer_view, weights)
target_buffer = self.B_buffer[name][layer_id]
buffer_view = target_buffer[buffer_id, :, :lora_rank]
load_lora_weight_tensor(buffer_view, weights)
def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType
......
......@@ -119,7 +119,7 @@ def _qkv_lora_b_kernel(
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
......
......@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
return int(match.group(1))
def get_customized_names_from_hf_names(
hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
"""
This function takes in a set of huggingface style module names:
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
and outputs a set of module names of customized sglang layers:
e.g., {"qkv_proj", "o_proj"}
"""
if hasattr(base_model, "get_module_name"):
return {base_model.get_module_name(name) for name in hf_module_names}
else:
"""
Fallback solution of mapping from config module name to module name in model class.
Please check if it aligns with your base model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return {params_mapping.get(name, name) for name in hf_module_names}
def get_hidden_dim(
module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int]:
......@@ -95,22 +67,9 @@ def get_hidden_dim(
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
# TODO: the special handling of qkv will be addressed in #8940.
if module_name == "qkv_proj":
return (
config.hidden_size,
None, # qkv_proj is only used in LoRA A
)
elif module_name == "kv_proj":
return (
None, # kv_proj is only used in LoRA B
head_dim * config.num_key_value_heads,
)
elif module_name == "q_proj":
return (
None, # q_proj is only used in LoRA B
head_dim * config.num_attention_heads,
return config.hidden_size, head_dim * (
config.num_attention_heads + config.num_key_value_heads * 2
)
elif module_name == "o_proj":
return (
......@@ -118,7 +77,7 @@ def get_hidden_dim(
config.hidden_size,
)
elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size
return config.hidden_size, config.intermediate_size * 2
elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size
else:
......@@ -127,26 +86,22 @@ def get_hidden_dim(
def get_normalized_lora_weight_names(
target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]:
) -> set[str]:
"""
Mapping a list of target module name to names of the normalized LoRA weights.
Returned tuple contains (name for Lora A, name for Lora B)
"""
params_mapping = {
"q_proj": (["qkv_proj"], ["q_proj"]),
"k_proj": (["qkv_proj"], ["kv_proj"]),
"v_proj": (["qkv_proj"], ["kv_proj"]),
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
result = (set(), set())
result = set()
for name in target_modules:
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
result[0].update(lora_a)
result[1].update(lora_b)
weight_name = params_mapping.get(name, name)
result.add(weight_name)
return result
......@@ -156,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
"""
stacked_rank = {
"qkv_proj": 3,
"kv_proj": 2,
"gate_up_proj": 2,
}
return stacked_rank[module_name] if module_name in stacked_rank else 1
def get_weight_name(
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
target_name: str, lora_weight_names: Tuple[Set[str]]
) -> Optional[str]:
"""
target_name is name of a given module,
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
Get the weight name in lora_weight_names that can match target_name.
If there is a weight name in lora_weight_names that can match target_name, return this name
Else raise ValueError.
"""
idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name in lora_weight_names[idx]:
for weight_name in lora_weight_names:
if weight_name in target_name:
return weight_name
raise ValueError(
......@@ -180,9 +133,4 @@ def get_weight_name(
)
# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
......@@ -501,23 +501,16 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
def get_hidden_dim(self, module_name):
# return input_dim, output_dim
# TODO: the special handling of qkv will be addressed in #8940.
if module_name == "qkv_proj":
return (
self.config.hidden_size,
None, # qkv_proj is only used in LoRA A
self.config.head_dim
* (
self.config.num_attention_heads
+ self.config.num_key_value_heads * 2
),
)
elif module_name == "kv_proj":
return (
None, # kv_proj is only used in LoRA B
self.config.head_dim * self.config.num_key_value_heads,
)
elif module_name == "q_proj":
return (
None, # q_proj is only used in LoRA B
self.config.head_dim * self.config.num_attention_heads,
)
elif module_name in ["o_proj"]:
elif module_name == "o_proj":
return (
self.config.head_dim * self.config.num_attention_heads,
self.config.hidden_size,
......@@ -527,7 +520,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
"Currently SGLang requires uniform intermediate size for all layers. "
"Please file an issue if you need support for non-uniform intermediate sizes."
)
return self.config.hidden_size, self.config.intermediate_size[0]
return self.config.hidden_size, self.config.intermediate_size[0] * 2
elif module_name == "down_proj":
assert len(set(self.config.intermediate_size)) == 1, (
"Currently SGLang requires uniform intermediate size for all layers. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment