"torchvision/transforms/_functional_tensor.py" did not exist on "00c119c853a74848655799c9b185cedf7a01f891"
Unverified Commit c45cab1c authored by Baizhou Zhang, committed by GitHub

[Fix] Fix accuracy bug and refactor codes for lora (#3413)

parent 27c4c9cf
from .base_backend import BaseLoraBackend
from .flashinfer_backend import FlashInferLoraBackend
from .triton_backend import TritonLoraBackend
from .base_backend import BaseLoRABackend
from .flashinfer_backend import FlashInferLoRABackend
from .triton_backend import TritonLoRABackend
def get_backend_from_name(name: str) -> BaseLoRABackend:
"""
Get corresponding backend class from backend's name
"""
backend_mapping = {
"triton": TritonLoRABackend,
"flashinfer": FlashInferLoRABackend,
}
if name in backend_mapping:
return backend_mapping[name]
raise Exception(
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
)
__all__ = [
"FlashInferLoraBackend",
"TritonLoraBackend",
"BaseLoRABackend",
"FlashInferLoRABackend",
"TritonLoRABackend",
"get_backend_from_name",
]
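For orientation, a minimal usage sketch of the registry above (the caller code is hypothetical and assumes both backends are importable in the current environment):

from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name

# Resolve a backend class by name and instantiate it.
backend_cls = get_backend_from_name("triton")      # -> TritonLoRABackend
backend: BaseLoRABackend = backend_cls("triton")
assert backend.fuse_stacked_lora_b                 # the triton backend fuses stacked lora_b
# An unknown name, e.g. get_backend_from_name("cutlass"), raises an Exception.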
......@@ -2,7 +2,7 @@ from typing import Tuple, Union
import torch
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.utils import LoRABatchInfo
def get_fuse_output_scaling_add_from_name(name: str) -> bool:
......@@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
return mapping.get(name, False)
def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
......@@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
return mapping.get(name, False)
class BaseLoraBackend:
class BaseLoRABackend:
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.
......@@ -32,11 +32,11 @@ class BaseLoraBackend:
and the operation of scaling and adding will be fused into kernel
"""
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
......@@ -46,10 +46,11 @@ class BaseLoraBackend:
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
weights: a set of lora weights with shape (num_lora, c * r, input_dim),
here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
usually input_dim is much larger than r
Returns:
result with shape (s, r)
result with shape (s, c * r)
"""
pass
......@@ -83,7 +84,7 @@ class BaseLoraBackend:
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
qkv_lora_b: lora_b module for qkv.
If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r)
If passed in as a tuple of two tensors containing:
If passed in as a tuple of two tensors, it should contain:
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
Returns:
......@@ -91,5 +92,26 @@ class BaseLoraBackend:
"""
pass
def set_batch_info(self, batch_info: LoraBatchInfo):
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
) -> torch.Tensor:
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
gate_up_lora_b: lora_b module for gate_up_proj.
If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
Returns:
result with shape (s, 2 * output_dim)
"""
pass
def set_batch_info(self, batch_info: LoRABatchInfo):
self.batch_info = batch_info
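A quick shape walk-through of the contract documented in the docstrings above, with illustrative numbers (s, r, and the projection dims below are assumptions, not values from this diff):

# Shape bookkeeping for the BaseLoRABackend contract (illustrative numbers only).
s, r = 7, 16                                    # total tokens in the batch, lora rank
output_dim_q, output_dim_kv, intermediate = 4096, 1024, 11008

# run_lora_a_sgemm: (s, input_dim) x (num_lora, c * r, input_dim) -> (s, c * r),
# where c = 3 for qkv_proj, c = 2 for gate_up_proj, c = 1 otherwise.
assert 3 * r == 48                              # lora_a output width for qkv_proj
# run_qkv_lora: fused lora_b output is (s, output_dim_q + 2 * output_dim_kv)
assert output_dim_q + 2 * output_dim_kv == 6144
# run_gate_up_lora: fused lora_b output is (s, 2 * output_dim)
assert 2 * intermediate == 22016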
......@@ -2,17 +2,17 @@ from typing import Tuple
import torch
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
from flashinfer import SegmentGEMMWrapper
class FlashInferLoraBackend(BaseLoraBackend):
class FlashInferLoRABackend(BaseLoRABackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
super().__init__(name, batch_info)
# Set up SGemm Wrapper from flashinfer
......@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
**kwargs,
) -> torch.Tensor:
assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
......@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
)
return lora_output
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
lora_rank = gate_up_lora_b[0].shape[-1]
output_dim = gate_up_lora_b[0].shape[-2]
# Shape of lora_a_output: (s, 2 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
lora_output = torch.empty(
(x.shape[0], 2 * output_dim),
device=x.device,
dtype=x.dtype,
)
# Compute lora for gate and up proj respectively
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(),
weights=gate_up_lora_b[0],
)
lora_output[:, output_dim:] = self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank:].contiguous(),
weights=gate_up_lora_b[1],
)
return lora_output
import torch
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
gate_up_lora_b_fwd,
qkv_lora_b_fwd,
sgemm_lora_a_fwd,
sgemm_lora_b_fwd,
)
from sglang.srt.lora.utils import LoRABatchInfo
class TritonLoraBackend(BaseLoraBackend):
class TritonLoRABackend(BaseLoRABackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
super().__init__(name, batch_info)
def run_lora_a_sgemm(
......@@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend):
scaling,
)
return lora_output
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# gate_up_lora_a: (num_lora, 2 * r, input_dim)
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
assert isinstance(gate_up_lora_b, torch.Tensor)
output_dim = gate_up_lora_b.shape[-2] // 2
# lora_a_output: (s, 2 * r)
lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
lora_output = gate_up_lora_b_fwd(
lora_a_output,
gate_up_lora_b,
self.batch_info,
output_dim,
base_output,
scaling,
)
return lora_output
import torch
from torch import nn
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.lora.backend import BaseLoRABackend
class BaseLayerWithLoRA(nn.Module):
def __init__(
self,
base_layer: nn.Module,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
):
super().__init__()
self.base_layer: nn.Module = base_layer
self.lora_rank: int = lora_rank
self.scaling: float = scaling
self.set_lora: bool = False
self.lora_backend: BaseLoRABackend = lora_backend
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
def set_lora_info(self, *args):
pass
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: VocabParallelEmbedding,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: ColumnParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer: torch.Tensor,
B_buffer: torch.Tensor,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
def forward(self, input_: torch.Tensor):
# duplicate the logic in ColumnParallelLinear
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_, bias
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_)
if self.base_layer.gather_output:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self,
base_layer: MergedColumnParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer: torch.Tensor,
B_buffer: torch.Tensor,
):
self.set_lora = True
self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b:
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
self.B_buffer_gate_up = torch.cat(
(B_buffer[0], B_buffer[1]), dim=-2
).contiguous()
else:
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
lora_output = self.lora_backend.run_gate_up_lora(
x,
self.A_buffer_gate_up,
self.B_buffer_gate_up,
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
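The fused branch above concatenates the gate and up lora_b weights along the output dimension; a minimal shape sketch (sizes are illustrative):

import torch

num_lora, output_dim, r = 4, 11008, 16
B_buffer = torch.randn(2, num_lora, output_dim, r)      # leading dim separates gate / up
fused = torch.cat((B_buffer[0], B_buffer[1]), dim=-2).contiguous()
assert fused.shape == (num_lora, 2 * output_dim, r)     # shape consumed by run_gate_up_lora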
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self,
base_layer: QKVParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer_qkv: torch.Tensor,
B_buffer_q: torch.Tensor,
B_buffer_kv: torch.Tensor,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
if self.lora_backend.fuse_stacked_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
if self.lora_backend.fuse_stacked_lora_b:
backend_kwargs["output_offset"] = self.output_offset
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
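As a worked example of the offsets built in set_lora_info above, with assumed GQA-style projection sizes (not values from this diff):

import torch

output_dim_q, output_dim_kv = 4096, 1024
output_offset = torch.tensor(
    [0, output_dim_q, output_dim_q + output_dim_kv, output_dim_q + 2 * output_dim_kv],
    dtype=torch.int32,
)
# -> tensor([   0, 4096, 5120, 6144], dtype=torch.int32): q occupies [0, 4096),
#    k occupies [4096, 5120), v occupies [5120, 6144) in the fused lora_b output.
# max_qkv_out_dim = max(4096, 1024) = 4096 bounds the tile count launched per slice.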
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: RowParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
def forward(self, input_: torch.Tensor):
# duplicate the logic in RowParallelLinear
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_parallel
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
def get_lora_layer(
layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLoRA,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
......@@ -19,282 +19,25 @@
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import re
from dataclasses import dataclass
from typing import Dict, List
import torch
from torch import nn
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader
@dataclass
class LoraBatchInfo:
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Indice pointers of each sequence in shape (bs + 1, )
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
class BaseLayerWithLoRA(nn.Module):
def __init__(self, base_layer, lora_rank, scaling, lora_backend):
super().__init__()
self.base_layer = base_layer
self.lora_rank = lora_rank
self.scaling = scaling
self.set_lora = False
self.lora_backend = lora_backend
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
def set_lora_info(self, *args):
pass
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# TODO
return output
def forward(self, input_: torch.Tensor):
# duplicate the logic in ColumnParallelLinear
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_, bias
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_)
if self.base_layer.gather_output:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer,
B_buffer,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
output_dim = base_output.shape[-1]
lora_output = torch.empty_like(base_output)
lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
weights=self.B_buffer[0],
)
lora_output[:, output_dim : 2 * output_dim] = (
self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
weights=self.B_buffer[1],
)
)
return base_output + lora_output * self.scaling
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer_qkv,
B_buffer_q,
B_buffer_kv,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
if self.lora_backend.fuse_qkv_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
self.output_offset = None
self.max_qkv_out_dim = None
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
output_offset=self.output_offset,
max_qkv_out_dim=self.max_qkv_out_dim,
base_output=base_output,
scaling=self.scaling,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer, B_buffer):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
base_output=base_output,
scaling=self.scaling,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
def forward(self, input_):
# duplicate the logic in RowParallelLinear
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_parallel
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
def get_lora_layer(
layer: nn.Module, lora_rank, scaling, lora_backend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLoRA,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
def get_mapped_params(module_names):
ret = set()
for module_name in module_names:
ret.add(params_mapping(module_name))
return list(ret)
class LoRALayer(nn.Module):
def __init__(self, config, base_hf_config):
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
super().__init__()
self.config = config
self.base_hf_config = base_hf_config
self.weights = {}
self.weight_gpu = {}
self.config: LoRAConfig = config
self.base_hf_config: AutoConfig = base_hf_config
self.weights: Dict[str, torch.Tensor] = {}
self.weight_gpu: Dict[str, torch.Tensor] = {}
def load_to_gpu(self):
for name, weight in self.weights.items():
......@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):
class LoRAAdapter(nn.Module):
def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
def __init__(
self,
uid: str,
config: LoRAConfig,
base_hf_config: AutoConfig,
load_config: LoadConfig,
lora_backend: BaseLoRABackend,
):
super().__init__()
self.uid = uid
self.config = config
self.uid: str = uid
self.config: LoRAConfig = config
assert self.config.hf_config["peft_type"].lower() == "lora"
self.base_hf_config = base_hf_config
self.load_config = load_config
self.lora_backend = lora_backend
self.scaling = self.config.lora_alpha / self.config.r
self.base_hf_config: AutoConfig = base_hf_config
self.load_config: LoadConfig = load_config
self.lora_backend: BaseLoRABackend = lora_backend
self.scaling: float = self.config.lora_alpha / self.config.r
self.layers = nn.ModuleList(
self.layers: List[LoRALayer] = nn.ModuleList(
[
LoRALayer(config, base_hf_config)
for i in range(base_hf_config.num_hidden_layers)
]
)
self.weights = {}
self.weights_gpu = {}
def get_stacked_multiply(self, module_name):
stacked_rank = {
"qkv_proj": 3,
"kv_proj": 2,
"gate_up_proj": 2,
}
return stacked_rank[module_name] if module_name in stacked_rank else 1
self.weights: Dict[str, torch.Tensor] = {}
self.weights_gpu: Dict[str, torch.Tensor] = {}
def load_to_gpu(self):
for name, weight in self.weights.items():
......@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()]
for weight_name in weight_names:
if "k_proj" in weight_name:
q_name = weight_name.replace("k_proj", "q_proj")
v_name = weight_name.replace("k_proj", "v_proj")
kv_name = weight_name.replace("k_proj", "kv_proj")
qkv_name = weight_name.replace("k_proj", "qkv_proj")
if "lora_A" in weight_name:
layer.weights[qkv_name] = torch.cat(
(
layer.weights[q_name],
layer.weights[weight_name],
layer.weights[v_name],
),
0,
)
layer.weights.pop(q_name)
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
else:
layer.weights[kv_name] = torch.stack(
[
layer.weights[weight_name],
layer.weights[v_name],
],
dim=0,
)
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
elif "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if "lora_A" in weight_name:
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
else:
layer.weights[gate_up_name] = torch.stack(
[layer.weights[weight_name], layer.weights[up_name]], dim=0
)
layer.weights.pop(weight_name)
layer.weights.pop(up_name)
self.stack_qkv_proj(weight_names, layer.weights)
self.stack_gate_up_proj(weight_names, layer.weights)
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
target_module = set()
for weight_name in weight_names:
if "k_proj" in weight_name:
target_module.add("k_proj")
if "q_proj" in weight_name:
target_module.add("q_proj")
if "v_proj" in weight_name:
target_module.add("v_proj")
if len(target_module) == 0:
return
for weight_name in weight_names:
# We assume every lora adapter contains lora modules for q_proj
if "q_proj" in weight_name:
q_name = weight_name
k_name = weight_name.replace("q_proj", "k_proj")
v_name = weight_name.replace("q_proj", "v_proj")
kv_name = weight_name.replace("q_proj", "kv_proj")
qkv_name = weight_name.replace("q_proj", "qkv_proj")
# If k_proj doesn't have lora, initialize it to zero
k_proj_weight = (
weights[k_name]
if "k_proj" in target_module
else torch.zeros_like(weights[v_name])
)
if "lora_A" in weight_name:
weights[qkv_name] = torch.cat(
(
weights[q_name],
k_proj_weight,
weights[v_name],
),
0,
)
weights.pop(q_name)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
else:
weights[kv_name] = torch.stack(
[
k_proj_weight,
weights[v_name],
],
dim=0,
)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
def stack_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
):
for weight_name in weight_names:
if "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if "lora_A" in weight_name:
weights[gate_up_name] = torch.cat(
(weights[weight_name], weights[up_name]), 0
)
else:
weights[gate_up_name] = torch.stack(
[weights[weight_name], weights[up_name]], dim=0
)
weights.pop(weight_name)
weights.pop(up_name)
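A shape sketch of the stacking performed by stack_qkv_proj and stack_gate_up_proj above (tensor sizes are illustrative; the zero tensor stands in for a missing k_proj adapter):

import torch

r, hidden, out_q, out_kv, out_mlp = 16, 4096, 4096, 1024, 11008
q_A, k_A, v_A = (torch.randn(r, hidden) for _ in range(3))
q_B, v_B = torch.randn(out_q, r), torch.randn(out_kv, r)
k_B = torch.zeros_like(v_B)                      # k_proj absent -> zero-initialized, as above

qkv_A = torch.cat((q_A, k_A, v_A), 0)            # lora_A stacked: (3 * r, hidden)
kv_B = torch.stack([k_B, v_B], dim=0)            # lora_B stacked: (2, out_kv, r); q_B kept separate
assert qkv_A.shape == (3 * r, hidden) and kv_B.shape == (2, out_kv, r)

gate_A, up_A = torch.randn(r, hidden), torch.randn(r, hidden)
gate_B, up_B = torch.randn(out_mlp, r), torch.randn(out_mlp, r)
assert torch.cat((gate_A, up_A), 0).shape == (2 * r, hidden)
assert torch.stack([gate_B, up_B], dim=0).shape == (2, out_mlp, r)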
......@@ -16,307 +16,115 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
import re
from typing import Dict, List, Set, Tuple
import torch
from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
LoRABatchInfo,
LoRAType,
get_customized_names_from_hf_names,
get_layer_id,
get_stacked_name,
get_weight_name,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available, replace_submodule
from sglang.srt.utils import replace_submodule
logger = logging.getLogger(__name__)
def get_module_name(name):
# Fallback solution of mapping from config module name to module name in model class.
# Please check if it aligns with your base model.
# Please implement the function in the model class if it is not.
# You can reference this function in llama.py.
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return params_mapping.get(name, name)
def get_hidden_dim(module_name, config):
# Fallback solution of get_hidden_dim for different modules
# Please check if it aligns with your base model.
# Please implement the function in the model class if it is not.
# You can reference this function in llama.py.
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
return config.hidden_size, config.hidden_size
elif module_name in ["kv_proj"]:
return config.hidden_size, config.hidden_size // (
config.num_attention_heads // config.num_key_value_heads
)
elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size
elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size
else:
raise NotImplementedError()
def get_stacked_name(name):
# origin name -> (name for A, name for B)
params_mapping = {
"q_proj": ("qkv_proj", "q_proj"),
"k_proj": ("qkv_proj", "kv_proj"),
"v_proj": ("qkv_proj", "kv_proj"),
"gate_proj": ("gate_up_proj", "gate_up_proj"),
"up_proj": ("gate_up_proj", "gate_up_proj"),
}
return params_mapping.get(name, (name, name))
def get_backend_from_name(name):
backend_mapping = {
"triton": TritonLoraBackend,
"flashinfer": FlashInferLoraBackend,
}
if name in backend_mapping:
return backend_mapping[name]
raise Exception(
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
)
def get_layer_id(name):
match = re.search(r"layers\.(\d+)\.", name)
if match is None:
return None
return int(match.group(1))
class LoRAManager:
def __init__(
self,
base_model,
lora_paths,
base_hf_config,
max_loras_per_batch,
load_config,
dtype,
lora_backend,
base_model: torch.nn.Module,
lora_paths: Dict[str, str],
base_hf_config: AutoConfig,
max_loras_per_batch: int,
load_config: LoadConfig,
dtype: torch.dtype,
lora_backend: str = "triton",
):
self.base_model = base_model
self.lora_paths = lora_paths
self.base_hf_config = base_hf_config
self.max_loras_per_batch = max_loras_per_batch
self.load_config = load_config
self.dtype = dtype
logger.info(f"Using {lora_backend} as backend of Lora kernels.")
self.base_model: torch.nn.Module = base_model
self.lora_paths: Dict[str, str] = lora_paths
self.base_hf_config: AutoConfig = base_hf_config
self.max_loras_per_batch: int = max_loras_per_batch
self.load_config: LoadConfig = load_config
self.dtype: torch.dtype = dtype
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
backend_type = get_backend_from_name(lora_backend)
self.lora_backend = backend_type(lora_backend)
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
self.init_loras()
self.init_lora_memory_pool()
self.init_lora_batch()
def match_target_modules(self, module_name):
for target_module in self.target_modules:
if module_name.split(".")[-1] == target_module:
return True
return False
def get_target_modules(self):
modules = []
for module_name, module in self.base_model.named_modules():
if self.match_target_modules(module_name):
modules.append((module_name, module))
return modules
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
module, self.max_lora_dim, self.scaling, self.lora_backend
)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def init_loras(self):
# get configs and target modules
self.configs = {}
self.origin_target_modules = set()
# Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {}
# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
self.hf_target_names: Set[str] = set()
for name, path in self.lora_paths.items():
self.configs[name] = LoRAConfig(path)
self.origin_target_modules = set(self.origin_target_modules) | set(
self.hf_target_names = set(self.hf_target_names) | set(
self.configs[name].target_modules
)
if hasattr(self.base_model, "get_module_name"):
self.target_modules = {
self.base_model.get_module_name(module)
for module in self.origin_target_modules
}
else:
logger.warning(
"WARNING: get_module_name() is not defined, "
"which is used to map config module name to model implementation module name."
"Use the default one, but please check if it is correct for your model."
)
self.target_modules = {
get_module_name(module) for module in self.origin_target_modules
}
self.target_weights = set(
[get_stacked_name(module) for module in self.origin_target_modules]
# Target lora weight names for lora_a and lora_b modules respectively.
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
self.lora_weight_names: Set[Tuple[str]] = set(
[get_stacked_name(module) for module in self.hf_target_names]
)
# load all weights to cpu
self.loras = []
self.lora_id = {}
self.loras: Dict[str, LoRAAdapter] = {}
for name in self.lora_paths.keys():
self.lora_id[name] = len(self.loras)
self.loras.append(
LoRAAdapter(
name,
self.configs[name],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter = LoRAAdapter(
name,
self.configs[name],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
self.loras[-1].initialize_weights()
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# misc lora configs
self.max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
self.scaling = self.loras[0].scaling
# FIXME remove the restrictions
# FIXME remove the restrictions after implementing unified paging
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
self.scaling: float = list(self.loras.values())[0].scaling
assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
assert all(x.scaling == self.scaling for x in self.loras)
assert all(x.scaling == self.scaling for x in self.loras.values())
# monkey patch to use the LoRA version
self.lora_modules = []
for module_name, module in self.get_target_modules():
self.lora_modules.append(
(module_name, self.set_lora_module(module_name, module))
)
# Convert original model layers to layers with LoRA
self.convert_to_lora_layers()
def init_lora_memory_pool(self):
# preallocate lora memory pool
self.A_buffer = {}
self.B_buffer = {}
num_layer = self.base_hf_config.num_hidden_layers
for module_A, module_B in self.target_weights:
# init A tensor, column_major=True
if hasattr(self.base_model, "get_hidden_dim"):
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
else:
logger.warning(
"WARNING: get_hidden_dim() is not defined, "
"which is used to get the hidden dim for different lora modules"
"Use the default one, but please check if it is correct for your model."
)
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
c = self.loras[-1].get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
torch.empty(
(
self.max_loras_per_batch,
self.max_lora_dim * c,
hidden_dim_A,
),
dtype=self.dtype,
device="cuda",
)
for i in range(num_layer)
]
# init B tensor, column_major=True
if hasattr(self.base_model, "get_hidden_dim"):
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
else:
logger.warning(
"WARNING: get_hidden_dim() is not defined, "
"which is used to get the hidden dim for different lora modules"
"Use the default one, but please check if it is correct for your model."
)
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
c = self.loras[-1].get_stacked_multiply(module_B)
if module_B not in self.B_buffer:
self.B_buffer[module_B] = [
torch.empty(
(
c,
self.max_loras_per_batch,
hidden_dim_B,
self.max_lora_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(num_layer)
]
def init_lora_batch(self):
self.active_uids = set() # set of active loras
self.buffer_id = {} # lora uid -> idx in memory pool
def get_weight_name(self, name, idx):
for target_weight_name in self.target_weights:
if target_weight_name[idx] in name:
return target_weight_name[idx]
def load_lora(self, uid, buffer_id):
num_layer = self.base_hf_config.num_hidden_layers
if uid is None:
for i in range(num_layer):
for k in self.A_buffer.keys():
self.A_buffer[k][i][buffer_id] *= 0
return
# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
)
for i in range(num_layer):
layer_weights = self.loras[self.lora_id[uid]].layers[i].weights
for name, weights in layer_weights.items():
if "lora_A" in name:
lora_weight_name = self.get_weight_name(name, 0)
if lora_weight_name:
self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
else:
lora_weight_name = self.get_weight_name(name, 1)
if lora_weight_name:
c = self.loras[-1].get_stacked_multiply(lora_weight_name)
if c > 1:
for j in range(c):
self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
weights[j]
)
else:
self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
weights
)
# Initialize target lora modules in memory pool
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
i = 0
j = len(self.active_uids)
evictable_uids = list(self.active_uids)
for uid in cur_uids:
if uid not in self.active_uids:
if j < self.max_loras_per_batch:
index = j
j += 1
else:
while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
i += 1
assert i < len(evictable_uids)
self.active_uids.remove(evictable_uids[i])
self.buffer_id.pop(evictable_uids[i])
index = i
i += 1
self.load_lora(uid, index)
self.active_uids.add(uid)
self.buffer_id[uid] = index
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
# FIXME: Handle lora uid with None more safely
if cur_uids == set([None]):
return
......@@ -332,9 +140,9 @@ class LoRAManager:
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
batch_info = LoraBatchInfo(
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
......@@ -346,16 +154,40 @@ class LoRAManager:
# call set_lora_info for each lora modules
for module_name, module in self.lora_modules:
layer_id = get_layer_id(module_name)
if "qkv_proj" not in module_name:
weight_name = self.get_weight_name(module_name, 0)
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.A_buffer[weight_name][layer_id],
self.B_buffer[weight_name][layer_id],
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
)
else:
module.set_lora_info(
self.A_buffer["qkv_proj"][layer_id],
self.B_buffer["q_proj"][layer_id],
self.B_buffer["kv_proj"][layer_id],
self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
)
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
module, self.max_lora_dim, self.scaling, self.lora_backend
)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def convert_to_lora_layers(self):
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
self.hf_target_names, self.base_model
)
# Monkey patch to use the LoRA version layers
self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
for module_name, module in self.base_model.named_modules():
# The module should be converted if it is included in customized_target_names
if module_name.split(".")[-1] in customized_target_names:
self.lora_modules.append(
(module_name, self.set_lora_module(module_name, module))
)
from typing import Dict, List, Optional, Set, Tuple
import torch
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.utils import (
LoRAType,
get_hidden_dim,
get_stacked_multiply,
get_weight_name,
)
class LoRAMemoryPool:
"""Class for memory pool management of lora modules"""
def __init__(
self,
base_hf_config: AutoConfig,
max_loras_per_batch: int,
max_lora_dim: int,
dtype: torch.dtype,
):
self.base_hf_config: AutoConfig = base_hf_config
self.num_layer: int = base_hf_config.num_hidden_layers
self.max_loras_per_batch: int = max_loras_per_batch
self.max_lora_dim: int = max_lora_dim
self.dtype: torch.dtype = dtype
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
# B_buffer contains num_layer number of column-major tensors with shape
# (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
self.A_buffer: Dict[str, List[torch.Tensor]] = {}
self.B_buffer: Dict[str, List[torch.Tensor]] = {}
# Lora uid -> buffer idx in memory pool
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
# Buffer idx -> lora uid in memory pool
# All uids are initialized as empty strings for empty buffer slots
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
def init_buffers(
self,
lora_weight_names: Set[Tuple[str]],
base_model: torch.nn.Module,
):
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
for module_A, module_B in lora_weight_names:
# Init A tensor, column_major=False
input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
c = get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
torch.empty(
(
self.max_loras_per_batch,
self.max_lora_dim * c,
input_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(self.num_layer)
]
# Init B tensor, column_major=True
_, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
c = get_stacked_multiply(module_B)
if module_B not in self.B_buffer:
self.B_buffer[module_B] = [
torch.empty(
(
c, # stacked lora_b modules might need separation
self.max_loras_per_batch,
output_dim,
self.max_lora_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(self.num_layer)
]
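For illustration, the per-layer buffer shapes this allocates for the ("qkv_proj", "q_proj") and ("qkv_proj", "kv_proj") pairs, under assumed settings (hidden_size = 4096, kv output dim = 1024, max_loras_per_batch = 8, max_lora_dim = 16; none of these come from the diff):

# Worked example of the pool shapes above (illustrative numbers only).
max_loras_per_batch, max_lora_dim = 8, 16
hidden_size, kv_dim = 4096, 1024

a_shape_qkv = (max_loras_per_batch, 3 * max_lora_dim, hidden_size)   # A_buffer["qkv_proj"][layer_id]
b_shape_q = (1, max_loras_per_batch, hidden_size, max_lora_dim)      # B_buffer["q_proj"][layer_id]
b_shape_kv = (2, max_loras_per_batch, kv_dim, max_lora_dim)          # B_buffer["kv_proj"][layer_id]
# Each A_buffer / B_buffer entry is a list holding num_hidden_layers such tensors.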
def prepare_lora_batch(
self,
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
# Prioritize empty slots
if self.buffer_id_to_uid[buffer_id] == "":
return buffer_id, ""
for buffer_id in range(self.max_loras_per_batch):
# Evict unneeded lora
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
return buffer_id, self.buffer_id_to_uid[buffer_id]
raise ValueError(
"No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
)
for uid in cur_uids:
if uid not in self.uid_to_buffer_id:
buffer_id, evicted_lora_uid = get_available_buffer_slot()
if evicted_lora_uid != "":
self.uid_to_buffer_id.pop(evicted_lora_uid)
self.load_lora_weight_to_buffer(
uid, buffer_id, lora_adapters.get(uid, None)
)
self.uid_to_buffer_id[uid] = buffer_id
self.buffer_id_to_uid[buffer_id] = uid
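The slot policy above prefers empty slots, then evicts an adapter the current batch does not need. A standalone trace that mirrors (but does not reuse) get_available_buffer_slot:

def pick_slot(buffer_id_to_uid, cur_uids):
    # Mirror of get_available_buffer_slot above, for illustration only.
    for i, uid in enumerate(buffer_id_to_uid):
        if uid == "":                        # empty slot first
            return i, ""
    for i, uid in enumerate(buffer_id_to_uid):
        if uid not in cur_uids:              # otherwise evict an adapter not in this batch
            return i, uid
    raise ValueError("no free slot: increase max_loras_per_batch")

slots = ["", "", ""]                         # max_loras_per_batch = 3, all empty
for uid in ("loraA", "loraB"):
    idx, evicted = pick_slot(slots, {"loraA", "loraB"})
    slots[idx] = uid
assert slots == ["loraA", "loraB", ""] and evicted == ""  # no eviction was needed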
def load_lora_weight_to_buffer(
self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
):
if uid is None:
for i in range(self.num_layer):
for k in self.A_buffer.keys():
self.A_buffer[k][i][buffer_id] *= 0
return
assert lora_adapter is not None
for layer_id in range(self.num_layer):
layer_weights = lora_adapter.layers[layer_id].weights
for name, weights in layer_weights.items():
if "lora_A" in name:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_A
)
if lora_weight_name:
self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
weights
)
else:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_B
)
if lora_weight_name:
c = get_stacked_multiply(lora_weight_name)
if c > 1:
for stacked_id in range(c):
self.B_buffer[lora_weight_name][layer_id][stacked_id][
buffer_id
].copy_(weights[stacked_id])
else:
self.B_buffer[lora_weight_name][layer_id][0][
buffer_id
].copy_(weights)
def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType
) -> torch.Tensor:
if lora_type == LoRAType.LORA_A:
return self.A_buffer[weight_name][layer_id]
return self.B_buffer[weight_name][layer_id]
def get_buffer_id(self, lora_uid: str):
return self.uid_to_buffer_id[lora_uid]
from .gate_up_lora_b import gate_up_lora_b_fwd
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
from .sgemm_lora_b import sgemm_lora_b_fwd
__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
__all__ = [
"gate_up_lora_b_fwd",
"qkv_lora_b_fwd",
"sgemm_lora_a_fwd",
"sgemm_lora_b_fwd",
]
import torch
import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
def _gate_up_lora_b_kernel(
# Pointers to matrices
x,
weights,
output,
# Parameters of size
K, # K = R
output_dim,
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
):
# This kernel packs 2 sgemms (gate/up) into a single kernel.
# x: (s, 2 * K), s is the sum of sequence lengths, K equals the lora rank
# weights: (num_lora, 2 * output_dim, K)
# output: (s, 2 * output_dim)
# output_dim >> K
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len.
# gate_up_id decides which of gate or up (0: gate, 1: up)
batch_id = tl.program_id(axis=2)
gate_up_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
n_start = gate_up_id * output_dim # offset on output dim
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(output_dim, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in the output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K)
and (n_offset[None, :] < output_dim),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
def gate_up_lora_b_fwd(
x: torch.Tensor,
gate_up_lora_b: torch.Tensor,
batch_info: LoRABatchInfo,
output_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, 2 * r)
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
# output: (s, 2 * output_dim)
# Compute lora_output with shape (s, 2 * output_dim) as follows:
# lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
# lora_output[:, output_dim:]
# = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])
# Get dims
s = x.shape[0]
input_dim = x.shape[1]
r = gate_up_lora_b.shape[-1]
assert input_dim == 2 * r
BLOCK_S = 16
BLOCK_R = 16
BLOCK_OUT = 64
grid_b = (
triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
2, # this dimension decides current block computes on gate or up proj
batch_info.bs,
)
if base_output is None:
output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
else:
output = base_output
fuse_scaling_add = True
_gate_up_lora_b_kernel[grid_b](
x,
gate_up_lora_b,
output,
r,
output_dim,
x.stride(0),
x.stride(1),
gate_up_lora_b.stride(0),
gate_up_lora_b.stride(1),
gate_up_lora_b.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
scaling,
)
return output
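For intuition, an unfused PyTorch reference of what the kernel computes, looping over sequences in the batch (a sketch only, using the segment layout carried by LoRABatchInfo; the base_output / fused-add path is omitted):

import torch

def gate_up_lora_b_ref(x, gate_up_lora_b, seg_indptr, weight_indices, scaling=1.0):
    # x: (s, 2 * r), gate_up_lora_b: (num_lora, 2 * output_dim, r)
    r = x.shape[1] // 2
    output_dim = gate_up_lora_b.shape[-2] // 2
    out = torch.zeros((x.shape[0], 2 * output_dim), dtype=x.dtype, device=x.device)
    for b in range(len(weight_indices)):
        lo, hi = seg_indptr[b], seg_indptr[b + 1]
        w = gate_up_lora_b[weight_indices[b]]                  # (2 * output_dim, r)
        out[lo:hi, :output_dim] = scaling * (x[lo:hi, :r] @ w[:output_dim].T)   # gate
        out[lo:hi, output_dim:] = scaling * (x[lo:hi, r:] @ w[output_dim:].T)   # up
    return out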
......@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
......@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
def qkv_lora_b_fwd(
x: torch.Tensor,
qkv_lora_b: torch.Tensor,
batch_info: LoraBatchInfo,
batch_info: LoRABatchInfo,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
......@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
# output: (s, output_dim_q + 2 * output_dim_kv)
# Compute lora_output with shape (s, output_dim) as follows:
# lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
# lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
# lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
# = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
# = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
# lora_output[:, output_dim_q + output_dim_kv: ]
# = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
# = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv:, :])
# Get dims
s = x.shape[0]
......
......@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
......@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(
def sgemm_lora_a_fwd(
x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
) -> torch.Tensor:
# x: (s, input_dim)
# weights: (num_lora, r, input_dim)
......
......@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
......@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
def sgemm_lora_b_fwd(
x: torch.Tensor,
weights: torch.Tensor,
batch_info: LoraBatchInfo,
batch_info: LoRABatchInfo,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
......
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set, Tuple
import torch
from sglang.srt.hf_transformers_utils import AutoConfig
@dataclass
class LoRABatchInfo:
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Indice pointers of each sequence in shape (bs + 1, )
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
class LoRAType(Enum):
LORA_A = 0
LORA_B = 1
def get_layer_id(name: str) -> int:
"""
Extract integer id of layer from its name in string.
"""
match = re.search(r"layers\.(\d+)\.", name)
if match is None:
return None
return int(match.group(1))
def get_customized_names_from_hf_names(
hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
"""
This function takes in a set of huggingface style module names:
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
and outputs a set of module names of customized sglang layers:
e.g., {"qkv_proj", "o_proj"}
"""
if hasattr(base_model, "get_module_name"):
return {base_model.get_module_name(name) for name in hf_module_names}
else:
"""
Fallback solution of mapping from config module name to module name in model class.
Please check if it aligns with your base model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return {params_mapping.get(name, name) for name in hf_module_names}
def get_hidden_dim(
module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int]:
"""
Given a module_name (might be a stacked name), return the hidden dims of the module's input and output.
"""
if hasattr(base_model, "get_hidden_dim"):
return base_model.get_hidden_dim(module_name)
else:
"""
WARNING: get_hidden_dim() is not defined,
which is used to get the hidden dim for different lora modules
Use the default one, but please check if it is correct for your model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
return config.hidden_size, config.hidden_size
elif module_name in ["kv_proj"]:
return config.hidden_size, config.hidden_size // (
config.num_attention_heads // config.num_key_value_heads
)
elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size
elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size
else:
raise NotImplementedError()
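The kv_proj branch above divides the hidden size by the GQA group count. A worked example with assumed Llama-3-8B-like config values (not taken from this diff):

# Worked example for the kv_proj branch of the fallback above (assumed config values).
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8

kv_out = hidden_size // (num_attention_heads // num_key_value_heads)
assert kv_out == 1024   # get_hidden_dim("kv_proj", ...) -> (4096, 1024) with these values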
def get_stacked_name(name: str) -> Tuple[str]:
"""
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
"""
params_mapping = {
"q_proj": ("qkv_proj", "q_proj"),
"k_proj": ("qkv_proj", "kv_proj"),
"v_proj": ("qkv_proj", "kv_proj"),
"gate_proj": ("gate_up_proj", "gate_up_proj"),
"up_proj": ("gate_up_proj", "gate_up_proj"),
}
return params_mapping.get(name, (name, name))
def get_stacked_multiply(module_name: str) -> int:
"""
Mapping a lora module name to the number of weights stacked along its output dimension
"""
stacked_rank = {
"qkv_proj": 3,
"kv_proj": 2,
"gate_up_proj": 2,
}
return stacked_rank[module_name] if module_name in stacked_rank else 1
def get_weight_name(
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
) -> Optional[str]:
"""
target_name is name of a given module,
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
If there is a weight name in lora_weight_names that matches target_name, return that name;
otherwise return None
"""
idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name_pair in lora_weight_names:
if weight_name_pair[idx] in target_name:
return weight_name_pair[idx]
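A small usage sketch of get_stacked_name and get_weight_name (the module names below are hypothetical):

from sglang.srt.lora.utils import LoRAType, get_stacked_name, get_weight_name

lora_weight_names = {get_stacked_name(m) for m in ("q_proj", "k_proj", "o_proj")}
# -> {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}

a_name = "base_model.model.layers.0.self_attn.qkv_proj.lora_A.weight"
assert get_weight_name(a_name, lora_weight_names, LoRAType.LORA_A) == "qkv_proj"

b_name = "base_model.model.layers.0.self_attn.kv_proj.lora_B.weight"
assert get_weight_name(b_name, lora_weight_names, LoRAType.LORA_B) == "kv_proj"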
......@@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l
LORA_SETS = [
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]}
{
"base": "meta-llama/Llama-3.1-8B-Instruct",
"loras": ["reissbaker/llama-3.1-8b-abliterated-lora"],
"decode_tolerance": 8e-2,
},
]
TORCH_DTYPES = [torch.float16]
......@@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase):
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
)
if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
tol = lora_set.get("prefill_tolerance", prefill_tolerance)
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
f"prefill logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"prefill_tolerance={prefill_tolerance}."
......@@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase):
"\n",
)
if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
tol = lora_set.get("decode_tolerance", decode_tolerance)
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
f"decode logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"decode_tolerance={decode_tolerance}."
......@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase):
# compare output strings
srt_output_str = srt_outputs.output_strs[i].strip(" ")
hf_output_str = hf_outputs.output_strs[i]
hf_output_str = hf_outputs.output_strs[i].strip(" ")
print(f"srt_output_str={srt_output_str}")
print(f"hf_output_str={hf_output_str}")
rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
......