Unverified commit c45cab1c authored by Baizhou Zhang, committed by GitHub

[Fix] Fix accuracy bug and refactor codes for lora (#3413)

parent 27c4c9cf
from .base_backend import BaseLoRABackend
from .flashinfer_backend import FlashInferLoRABackend
from .triton_backend import TritonLoRABackend
def get_backend_from_name(name: str) -> BaseLoRABackend:
"""
Get corresponding backend class from backend's name
"""
backend_mapping = {
"triton": TritonLoRABackend,
"flashinfer": FlashInferLoRABackend,
}
if name in backend_mapping:
return backend_mapping[name]
raise Exception(
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
)
__all__ = [
"BaseLoRABackend",
"FlashInferLoRABackend",
"TritonLoRABackend",
"get_backend_from_name",
]
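For reference, a minimal usage sketch of this factory (illustrative only, assuming sglang is installed and these names are re-exported from sglang.srt.lora.backend as listed above):

# Illustrative sketch, not part of the commit:
from sglang.srt.lora.backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")   # -> TritonLoRABackend
backend = backend_cls("triton")                 # batch_info is attached later via set_batch_info()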
@@ -2,7 +2,7 @@ from typing import Tuple, Union
import torch
from sglang.srt.lora.utils import LoRABatchInfo
def get_fuse_output_scaling_add_from_name(name: str) -> bool:
@@ -13,7 +13,7 @@ def get_fuse_output_scaling_add_from_name(name: str) -> bool:
return mapping.get(name, False)
def get_fuse_stacked_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
@@ -21,7 +21,7 @@ def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
return mapping.get(name, False)
class BaseLoRABackend:
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.
@@ -32,11 +32,11 @@ class BaseLoraBackend:
and the operation of scaling and adding will be fused into kernel
"""
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -46,10 +46,11 @@
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of lora weights with shape (num_lora, c * r, input_dim),
here r is lora rank, c is a multiplier for stacked modules (e.g., c=3 for qkv_proj, c=2 for gate_up_proj)
usually input_dim is much larger than r
Returns:
result with shape (s, c * r)
"""
pass
@@ -83,7 +84,7 @@
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
qkv_lora_b: lora_b module for qkv.
If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r)
If passed in as a tuple of two tensors, it should contain:
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
Returns:
@@ -91,5 +92,26 @@
"""
pass
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
) -> torch.Tensor:
"""Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLinear.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
gate_up_lora_a: lora_a module for gate_up_proj, with shape (num_lora, 2 * r, input_dim)
gate_up_lora_b: lora_b module for gate_up_proj.
If passed in as a tensor, its shape should be (num_lora, 2 * output_dim, r)
If passed in as a tuple, it should contain two tensors with shape (num_lora, output_dim, r)
Returns:
result with shape (s, 2 * output_dim)
"""
pass
def set_batch_info(self, batch_info: LoRABatchInfo):
self.batch_info = batch_info
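To make the shape contract above concrete, here is a minimal pure-PyTorch reference of what run_lora_a_sgemm is expected to compute for a batch described by seg_indptr and weight_indices. This is an illustrative sketch only, not the Triton or FlashInfer kernels; the helper name and toy shapes are assumptions.

import torch

def reference_lora_a_sgemm(x, weights, seg_indptr, weight_indices):
    # x: (s, input_dim); weights: (num_lora, c * r, input_dim)
    # Rows seg_indptr[i]:seg_indptr[i + 1] of x belong to sequence i and use adapter weight_indices[i].
    out = torch.empty(x.shape[0], weights.shape[1], dtype=x.dtype)
    for i in range(len(weight_indices)):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        out[start:end] = x[start:end] @ weights[int(weight_indices[i])].T
    return out

x = torch.randn(8, 16)                   # two sequences of lengths 3 and 5
weights = torch.randn(4, 2 * 8, 16)      # 4 adapters, c=2 stacked modules, rank r=8
out = reference_lora_a_sgemm(x, weights, torch.tensor([0, 3, 8]), torch.tensor([1, 3]))
print(out.shape)                         # torch.Size([8, 16]) == (s, c * r)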
@@ -2,17 +2,17 @@ from typing import Tuple
import torch
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
from flashinfer import SegmentGEMMWrapper
class FlashInferLoRABackend(BaseLoRABackend):
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
super().__init__(name, batch_info)
# Set up SGemm Wrapper from flashinfer
@@ -55,6 +55,8 @@ class FlashInferLoraBackend(BaseLoraBackend):
**kwargs,
) -> torch.Tensor:
assert isinstance(qkv_lora_b, tuple) and len(qkv_lora_b) == 2
# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
@@ -89,3 +91,38 @@ class FlashInferLoraBackend(BaseLoraBackend):
)
return lora_output
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
assert isinstance(gate_up_lora_b, tuple) and len(gate_up_lora_b) == 2
lora_rank = gate_up_lora_b[0].shape[-1]
output_dim = gate_up_lora_b[0].shape[-2]
# Shape of lora_a_output: (s, 2 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=gate_up_lora_a)
lora_output = torch.empty(
(x.shape[0], 2 * output_dim),
device=x.device,
dtype=x.dtype,
)
# Compute lora for gate and up proj respectively
lora_output[:, :output_dim] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(),
weights=gate_up_lora_b[0],
)
lora_output[:, output_dim:] = self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank:].contiguous(),
weights=gate_up_lora_b[1],
)
return lora_output
import torch
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
gate_up_lora_b_fwd,
qkv_lora_b_fwd,
sgemm_lora_a_fwd,
sgemm_lora_b_fwd,
)
from sglang.srt.lora.utils import LoRABatchInfo
class TritonLoRABackend(BaseLoRABackend):
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
super().__init__(name, batch_info)
def run_lora_a_sgemm(
@@ -59,3 +60,32 @@ class TritonLoraBackend(BaseLoraBackend):
scaling,
)
return lora_output
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# gate_up_lora_a: (num_lora, 2 * r, input_dim)
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
assert isinstance(gate_up_lora_b, torch.Tensor)
output_dim = gate_up_lora_b.shape[-2] // 2
# lora_a_output: (s, 2 * r)
lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
lora_output = gate_up_lora_b_fwd(
lora_a_output,
gate_up_lora_b,
self.batch_info,
output_dim,
base_output,
scaling,
)
return lora_output
import torch
from torch import nn
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.lora.backend import BaseLoRABackend
class BaseLayerWithLoRA(nn.Module):
def __init__(
self,
base_layer: nn.Module,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
):
super().__init__()
self.base_layer: nn.Module = base_layer
self.lora_rank: int = lora_rank
self.scaling: float = scaling
self.set_lora: bool = False
self.lora_backend: BaseLoRABackend = lora_backend
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
def set_lora_info(self, *args):
pass
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: VocabParallelEmbedding,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: ColumnParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer: torch.Tensor,
B_buffer: torch.Tensor,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
def forward(self, input_: torch.Tensor):
# duplicate the logic in ColumnParallelLinear
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_, bias
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_)
if self.base_layer.gather_output:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self,
base_layer: MergedColumnParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer: torch.Tensor,
B_buffer: torch.Tensor,
):
self.set_lora = True
self.A_buffer_gate_up = A_buffer
if self.lora_backend.fuse_stacked_lora_b:
# B_buffer_gate_up: (num_lora, 2 * output_dim, r)
self.B_buffer_gate_up = torch.cat(
(B_buffer[0], B_buffer[1]), dim=-2
).contiguous()
else:
self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
lora_output = self.lora_backend.run_gate_up_lora(
x,
self.A_buffer_gate_up,
self.B_buffer_gate_up,
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self,
base_layer: QKVParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer_qkv: torch.Tensor,
B_buffer_q: torch.Tensor,
B_buffer_kv: torch.Tensor,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
if self.lora_backend.fuse_stacked_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
if self.lora_backend.fuse_stacked_lora_b:
backend_kwargs["output_offset"] = self.output_offset
backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
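For illustration, the output_offset tensor built in set_lora_info partitions the fused qkv output dimension into q, k, and v slices. With hypothetical GQA dimensions (assumed for this example, not taken from the commit), it looks like this:

import torch

# Assumed example dims: 32 query heads, 8 kv heads, head_dim=128, tp_size=1
output_dim_q, output_dim_kv = 32 * 128, 8 * 128
output_offset = torch.tensor(
    [0, output_dim_q, output_dim_q + output_dim_kv, output_dim_q + 2 * output_dim_kv],
    dtype=torch.int32,
)
print(output_offset.tolist())  # [0, 4096, 5120, 6144] -> boundaries of the q, k, v slices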
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self,
base_layer: RowParallelLinear,
lora_rank: int,
scaling: float,
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
backend_kwargs = {"base_output": base_output, "scaling": self.scaling}
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
**backend_kwargs,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
def forward(self, input_: torch.Tensor):
# duplicate the logic in RowParallelLinear
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_parallel
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
def get_lora_layer(
layer: nn.Module, lora_rank: int, scaling: int, lora_backend: BaseLoRABackend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLoRA,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
@@ -19,282 +19,25 @@
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import re
from typing import Dict, List
import torch
from torch import nn
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader
@dataclass
class LoraBatchInfo:
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Indice pointers of each sequence in shape (bs + 1, )
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
class BaseLayerWithLoRA(nn.Module):
def __init__(self, base_layer, lora_rank, scaling, lora_backend):
super().__init__()
self.base_layer = base_layer
self.lora_rank = lora_rank
self.scaling = scaling
self.set_lora = False
self.lora_backend = lora_backend
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
def set_lora_info(self, *args):
pass
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# TODO
return output
def forward(self, input_: torch.Tensor):
# duplicate the logic in ColumnParallelLinear
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_, bias
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_)
if self.base_layer.gather_output:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer,
B_buffer,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
output_dim = base_output.shape[-1]
lora_output = torch.empty_like(base_output)
lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
weights=self.B_buffer[0],
)
lora_output[:, output_dim : 2 * output_dim] = (
self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
weights=self.B_buffer[1],
)
)
return base_output + lora_output * self.scaling
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def init__(
self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self,
A_buffer_qkv,
B_buffer_q,
B_buffer_kv,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
if self.lora_backend.fuse_qkv_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
self.output_offset = None
self.max_qkv_out_dim = None
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
output_offset=self.output_offset,
max_qkv_out_dim=self.max_qkv_out_dim,
base_output=base_output,
scaling=self.scaling,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer, B_buffer):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
base_output=base_output,
scaling=self.scaling,
)
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
def forward(self, input_):
# duplicate the logic in RowParallelLinear
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.base_layer.tp_size
)
input_parallel = splitted_input[tp_rank].contiguous()
output_parallel = self.base_layer.quant_method.apply(
self.base_layer, input_parallel
)
if self.set_lora:
output_parallel = self.apply_lora(output_parallel, input_parallel)
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
return output, output_bias
def get_lora_layer(
layer: nn.Module, lora_rank, scaling, lora_backend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
QKVParallelLinear: QKVParallelLinearWithLoRA,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
def get_mapped_params(module_names):
ret = set()
for module_name in module_names:
ret.add(params_mapping(module_name))
return list(ret)
class LoRALayer(nn.Module):
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
super().__init__()
self.config: LoRAConfig = config
self.base_hf_config: AutoConfig = base_hf_config
self.weights: Dict[str, torch.Tensor] = {}
self.weight_gpu: Dict[str, torch.Tensor] = {}
def load_to_gpu(self):
for name, weight in self.weights.items():
@@ -306,33 +49,32 @@ class LoRALayer(nn.Module):
class LoRAAdapter(nn.Module):
def __init__(
self,
uid: str,
config: LoRAConfig,
base_hf_config: AutoConfig,
load_config: LoadConfig,
lora_backend: BaseLoRABackend,
):
super().__init__()
self.uid: str = uid
self.config: LoRAConfig = config
assert self.config.hf_config["peft_type"].lower() == "lora"
self.base_hf_config: AutoConfig = base_hf_config
self.load_config: LoadConfig = load_config
self.lora_backend: BaseLoRABackend = lora_backend
self.scaling: float = self.config.lora_alpha / self.config.r
self.layers: List[LoRALayer] = nn.ModuleList(
[
LoRALayer(config, base_hf_config)
for i in range(base_hf_config.num_hidden_layers)
]
)
self.weights: Dict[str, torch.Tensor] = {}
self.weights_gpu: Dict[str, torch.Tensor] = {}
def get_stacked_multiply(self, module_name):
stacked_rank = {
"qkv_proj": 3,
"kv_proj": 2,
"gate_up_proj": 2,
}
return stacked_rank[module_name] if module_name in stacked_rank else 1
def load_to_gpu(self):
for name, weight in self.weights.items():
@@ -367,44 +109,77 @@ class LoRAAdapter(nn.Module):
for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()]
self.stack_qkv_proj(weight_names, layer.weights)
self.stack_gate_up_proj(weight_names, layer.weights)
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
target_module = set()
for weight_name in weight_names:
if "k_proj" in weight_name:
target_module.add("k_proj")
if "q_proj" in weight_name:
target_module.add("q_proj")
if "v_proj" in weight_name:
target_module.add("v_proj")
if len(target_module) == 0:
return
for weight_name in weight_names:
# We assume every lora adapter should contain lora modules for q_proj
if "q_proj" in weight_name:
q_name = weight_name
k_name = weight_name.replace("q_proj", "k_proj")
v_name = weight_name.replace("q_proj", "v_proj")
kv_name = weight_name.replace("q_proj", "kv_proj")
qkv_name = weight_name.replace("q_proj", "qkv_proj")
# If k_proj doesn't have lora, initialize it to zero
k_proj_weight = (
weights[k_name]
if "k_proj" in target_module
else torch.zeros_like(weights[v_name])
)
if "lora_A" in weight_name:
weights[qkv_name] = torch.cat(
(
weights[q_name],
k_proj_weight,
weights[v_name],
),
0,
)
weights.pop(q_name)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
else:
weights[kv_name] = torch.stack(
[
k_proj_weight,
weights[v_name],
],
dim=0,
)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
def stack_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
):
for weight_name in weight_names:
if "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if "lora_A" in weight_name:
weights[gate_up_name] = torch.cat(
(weights[weight_name], weights[up_name]), 0
)
else:
weights[gate_up_name] = torch.stack(
[weights[weight_name], weights[up_name]], dim=0
)
weights.pop(weight_name)
weights.pop(up_name)
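A small sketch of the stacking rule implemented above, showing how a missing k_proj adapter is zero-filled when building the fused qkv lora_A weight. The weight names and toy shapes here are assumptions chosen only for illustration:

import torch

r, hidden = 8, 32
weights = {
    "layers.0.self_attn.q_proj.lora_A.weight": torch.randn(r, hidden),
    "layers.0.self_attn.v_proj.lora_A.weight": torch.randn(r, hidden),
    # no k_proj entry -> treated as zeros during stacking
}
k_proj_weight = torch.zeros_like(weights["layers.0.self_attn.v_proj.lora_A.weight"])
qkv = torch.cat(
    (
        weights["layers.0.self_attn.q_proj.lora_A.weight"],
        k_proj_weight,
        weights["layers.0.self_attn.v_proj.lora_A.weight"],
    ),
    0,
)
print(qkv.shape)  # torch.Size([24, 32]) == (3 * r, hidden)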
@@ -16,307 +16,115 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
from typing import Dict, List, Set, Tuple
import torch
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
LoRABatchInfo,
LoRAType,
get_customized_names_from_hf_names,
get_layer_id,
get_stacked_name,
get_weight_name,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import replace_submodule
logger = logging.getLogger(__name__)
def get_module_name(name):
# Fallback solution of mapping from config module name to module name in model class.
# Please check if it aligns with your base model.
# Please implement the function in the model class if it is not.
# You can reference this function in llama.py.
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return params_mapping.get(name, name)
def get_hidden_dim(module_name, config):
# Fallback solution of get_hidden_dim for different modules
# Please check if it aligns with your base model.
# Please implement the function in the model class if it is not.
# You can reference this function in llama.py.
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
return config.hidden_size, config.hidden_size
elif module_name in ["kv_proj"]:
return config.hidden_size, config.hidden_size // (
config.num_attention_heads // config.num_key_value_heads
)
elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size
elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size
else:
raise NotImplementedError()
def get_stacked_name(name):
# origin name -> (name for A, name for B)
params_mapping = {
"q_proj": ("qkv_proj", "q_proj"),
"k_proj": ("qkv_proj", "kv_proj"),
"v_proj": ("qkv_proj", "kv_proj"),
"gate_proj": ("gate_up_proj", "gate_up_proj"),
"up_proj": ("gate_up_proj", "gate_up_proj"),
}
return params_mapping.get(name, (name, name))
def get_backend_from_name(name):
backend_mapping = {
"triton": TritonLoraBackend,
"flashinfer": FlashInferLoraBackend,
}
if name in backend_mapping:
return backend_mapping[name]
raise Exception(
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
)
def get_layer_id(name):
match = re.search(r"layers\.(\d+)\.", name)
if match is None:
return None
return int(match.group(1))
class LoRAManager:
def __init__(
self,
base_model: torch.nn.Module,
lora_paths: Dict[str, str],
base_hf_config: AutoConfig,
max_loras_per_batch: int,
load_config: LoadConfig,
dtype: torch.dtype,
lora_backend: str = "triton",
):
self.base_model: torch.nn.Module = base_model
self.lora_paths: Dict[str, str] = lora_paths
self.base_hf_config: AutoConfig = base_hf_config
self.max_loras_per_batch: int = max_loras_per_batch
self.load_config: LoadConfig = load_config
self.dtype: torch.dtype = dtype
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
backend_type = get_backend_from_name(lora_backend)
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
self.init_loras()
self.init_lora_memory_pool()
self.init_lora_batch()
def match_target_modules(self, module_name):
for target_module in self.target_modules:
if module_name.split(".")[-1] == target_module:
return True
return False
def get_target_modules(self):
modules = []
for module_name, module in self.base_model.named_modules():
if self.match_target_modules(module_name):
modules.append((module_name, module))
return modules
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
module, self.max_lora_dim, self.scaling, self.lora_backend
)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def init_loras(self):
# Config of each LoRA adapter
self.configs: Dict[str, LoRAConfig] = {}
# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
self.hf_target_names: Set[str] = set()
for name, path in self.lora_paths.items():
self.configs[name] = LoRAConfig(path)
self.hf_target_names = set(self.hf_target_names) | set(
self.configs[name].target_modules
)
# Target lora weight names for lora_a and lora_b modules respectively.
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
self.lora_weight_names: Set[Tuple[str]] = set(
[get_stacked_name(module) for module in self.hf_target_names]
)
# load all weights to cpu
self.loras: Dict[str, LoRAAdapter] = {}
for name in self.lora_paths.keys():
lora_adapter = LoRAAdapter(
name,
self.configs[name],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
lora_adapter.initialize_weights()
self.loras[name] = lora_adapter
# misc lora configs
# FIXME remove the restrictions after implementing unified paging
self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
self.scaling: float = list(self.loras.values())[0].scaling
assert all(x.hf_config["r"] == self.max_lora_dim for x in self.configs.values())
assert all(x.scaling == self.scaling for x in self.loras.values())
# Convert original model layers to layers with LoRA
self.convert_to_lora_layers()
def init_lora_memory_pool(self):
# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config, self.max_loras_per_batch, self.max_lora_dim, self.dtype
)
for module_A, module_B in self.target_weights:
# init A tensor, column_major=True
if hasattr(self.base_model, "get_hidden_dim"):
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
else:
logger.warning(
"WARNING: get_hidden_dim() is not defined, "
"which is used to get the hidden dim for different lora modules"
"Use the default one, but please check if it is correct for your model."
)
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
c = self.loras[-1].get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
torch.empty(
(
self.max_loras_per_batch,
self.max_lora_dim * c,
hidden_dim_A,
),
dtype=self.dtype,
device="cuda",
)
for i in range(num_layer)
]
# init B tensor, column_major=True
if hasattr(self.base_model, "get_hidden_dim"):
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
else:
logger.warning(
"WARNING: get_hidden_dim() is not defined, "
"which is used to get the hidden dim for different lora modules"
"Use the default one, but please check if it is correct for your model."
)
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
c = self.loras[-1].get_stacked_multiply(module_B)
if module_B not in self.B_buffer:
self.B_buffer[module_B] = [
torch.empty(
(
c,
self.max_loras_per_batch,
hidden_dim_B,
self.max_lora_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(num_layer)
]
def init_lora_batch(self):
self.active_uids = set() # set of active loras
self.buffer_id = {} # lora uid -> idx in memory pool
def get_weight_name(self, name, idx):
for target_weight_name in self.target_weights:
if target_weight_name[idx] in name:
return target_weight_name[idx]
def load_lora(self, uid, buffer_id):
num_layer = self.base_hf_config.num_hidden_layers
if uid is None:
for i in range(num_layer):
for k in self.A_buffer.keys():
self.A_buffer[k][i][buffer_id] *= 0
return
# Initialize target lora modules in memory pool
self.memory_pool.init_buffers(self.lora_weight_names, self.base_model)
for name, weights in layer_weights.items():
if "lora_A" in name:
lora_weight_name = self.get_weight_name(name, 0)
if lora_weight_name:
self.A_buffer[lora_weight_name][i][buffer_id].copy_(weights)
else:
lora_weight_name = self.get_weight_name(name, 1)
if lora_weight_name:
c = self.loras[-1].get_stacked_multiply(lora_weight_name)
if c > 1:
for j in range(c):
self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
weights[j]
)
else:
self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
weights
)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
cur_uids = set(forward_batch.lora_paths)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
evictable_uids = list(self.active_uids)
for uid in cur_uids:
if uid not in self.active_uids:
if j < self.max_loras_per_batch:
index = j
j += 1
else:
while i < len(evictable_uids) and evictable_uids[i] in cur_uids:
i += 1
assert i < len(evictable_uids)
self.active_uids.remove(evictable_uids[i])
self.buffer_id.pop(evictable_uids[i])
index = i
i += 1
self.load_lora(uid, index)
self.active_uids.add(uid)
self.buffer_id[uid] = index
# FIXME: Handle lora uid with None more safely
if cur_uids == set([None]):
return
@@ -332,9 +140,9 @@ class LoRAManager:
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
@@ -346,16 +154,40 @@ class LoRAManager:
# call set_lora_info for each lora modules
for module_name, module in self.lora_modules:
layer_id = get_layer_id(module_name)
if "qkv_proj" not in module_name:
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
)
else:
module.set_lora_info(
self.memory_pool.get_tensor("qkv_proj", layer_id, LoRAType.LORA_A),
self.memory_pool.get_tensor("q_proj", layer_id, LoRAType.LORA_B),
self.memory_pool.get_tensor("kv_proj", layer_id, LoRAType.LORA_B),
)
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
module, self.max_lora_dim, self.scaling, self.lora_backend
)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def convert_to_lora_layers(self):
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
self.hf_target_names, self.base_model
)
# Monkey patch to use the LoRA version layers
self.lora_modules: List[Tuple[str, torch.nn.Module]] = []
for module_name, module in self.base_model.named_modules():
# The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names:
self.lora_modules.append(
(module_name, self.set_lora_module(module_name, module))
) )
from typing import Dict, List, Optional, Set, Tuple
import torch
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.utils import (
LoRAType,
get_hidden_dim,
get_stacked_multiply,
get_weight_name,
)
class LoRAMemoryPool:
"""Class for memory pool management of lora modules"""
def __init__(
self,
base_hf_config: AutoConfig,
max_loras_per_batch: int,
max_lora_dim: int,
dtype: torch.dtype,
):
self.base_hf_config: AutoConfig = base_hf_config
self.num_layer: int = base_hf_config.num_hidden_layers
self.max_loras_per_batch: int = max_loras_per_batch
self.max_lora_dim: int = max_lora_dim
self.dtype: torch.dtype = dtype
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
# B_buffer contains num_layer number of column-major tensors with shape
# (stacked_num, max_loras_per_batch, output_dim, max_lora_dim)
self.A_buffer: Dict[str, List[torch.Tensor]] = {}
self.B_buffer: Dict[str, List[torch.Tensor]] = {}
# Lora uid -> buffer idx in memory pool
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
# Buffer idx -> lora uid in memory pool
# All uids are initialized as empty strings for empty buffer slots
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
def init_buffers(
self,
lora_weight_names: Set[Tuple[str]],
base_model: torch.nn.Module,
):
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
for module_A, module_B in lora_weight_names:
# Init A tensor, column_major=False
input_dim, _ = get_hidden_dim(module_A, self.base_hf_config, base_model)
c = get_stacked_multiply(module_A)
if module_A not in self.A_buffer:
self.A_buffer[module_A] = [
torch.empty(
(
self.max_loras_per_batch,
self.max_lora_dim * c,
input_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(self.num_layer)
]
# Init B tensor, column_major=True
_, output_dim = get_hidden_dim(module_B, self.base_hf_config, base_model)
c = get_stacked_multiply(module_B)
if module_B not in self.B_buffer:
self.B_buffer[module_B] = [
torch.empty(
(
c, # stacked lora_b modules might need separation
self.max_loras_per_batch,
output_dim,
self.max_lora_dim,
),
dtype=self.dtype,
device="cuda",
)
for i in range(self.num_layer)
]
def prepare_lora_batch(
self,
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
# Prioritize empty slots
if self.buffer_id_to_uid[buffer_id] == "":
return buffer_id, ""
for buffer_id in range(self.max_loras_per_batch):
# Evict unneeded lora
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
return buffer_id, self.buffer_id_to_uid[buffer_id]
raise ValueError(
"No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
)
for uid in cur_uids:
if uid not in self.uid_to_buffer_id:
buffer_id, evicted_lora_uid = get_available_buffer_slot()
if evicted_lora_uid != "":
self.uid_to_buffer_id.pop(evicted_lora_uid)
self.load_lora_weight_to_buffer(
uid, buffer_id, lora_adapters.get(uid, None)
)
self.uid_to_buffer_id[uid] = buffer_id
self.buffer_id_to_uid[buffer_id] = uid
def load_lora_weight_to_buffer(
self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
):
if uid is None:
for i in range(self.num_layer):
for k in self.A_buffer.keys():
self.A_buffer[k][i][buffer_id] *= 0
return
assert lora_adapter is not None
for layer_id in range(self.num_layer):
layer_weights = lora_adapter.layers[layer_id].weights
for name, weights in layer_weights.items():
if "lora_A" in name:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_A
)
if lora_weight_name:
self.A_buffer[lora_weight_name][layer_id][buffer_id].copy_(
weights
)
else:
lora_weight_name = get_weight_name(
name, self.lora_weight_names, LoRAType.LORA_B
)
if lora_weight_name:
c = get_stacked_multiply(lora_weight_name)
if c > 1:
for stacked_id in range(c):
self.B_buffer[lora_weight_name][layer_id][stacked_id][
buffer_id
].copy_(weights[stacked_id])
else:
self.B_buffer[lora_weight_name][layer_id][0][
buffer_id
].copy_(weights)
def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType
) -> torch.Tensor:
if lora_type == LoRAType.LORA_A:
return self.A_buffer[weight_name][layer_id]
return self.B_buffer[weight_name][layer_id]
def get_buffer_id(self, lora_uid: str):
return self.uid_to_buffer_id[lora_uid]
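The slot-allocation policy of prepare_lora_batch can be summarized with a standalone sketch (illustrative only; the real method also copies adapter weights into the buffers and updates the uid maps):

def pick_slot(buffer_id_to_uid, cur_uids):
    # Prefer empty slots (marked ""), otherwise evict an adapter not needed by the current batch.
    for i, uid in enumerate(buffer_id_to_uid):
        if uid == "":
            return i, ""
    for i, uid in enumerate(buffer_id_to_uid):
        if uid not in cur_uids:
            return i, uid
    raise ValueError("No available buffer slots found.")

print(pick_slot(["a", "b", ""], {"a", "c"}))   # (2, '')  -> the empty slot is used first
print(pick_slot(["a", "b", "d"], {"a", "c"}))  # (1, 'b') -> "b" is evicted to make room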
from .gate_up_lora_b import gate_up_lora_b_fwd
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
from .sgemm_lora_b import sgemm_lora_b_fwd
__all__ = [
"gate_up_lora_b_fwd",
"qkv_lora_b_fwd",
"sgemm_lora_a_fwd",
"sgemm_lora_b_fwd",
]
import torch
import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
def _gate_up_lora_b_kernel(
# Pointers to matrices
x,
weights,
output,
# Parameters of size
K, # K = R
output_dim,
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
):
# This kernel packs 2 sgemms (gate/up) into a single kernel.
# x: (s, 2 * K), s is the sum of sequence lengths, K equals to lora rank
# weights: (num_lora, 2 * output_dim, K)
# output: (s, 2 * output_dim)
# output_dim >> K
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len.
# gate_up_id decides which of gate or up (0: gate, 1: up)
batch_id = tl.program_id(axis=2)
gate_up_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
n_start = gate_up_id * output_dim # offset on output dim
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(output_dim, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0 + (gate_up_id * K) * x_stride_1) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in the output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K)
and (n_offset[None, :] < output_dim),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < output_dim)
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
def gate_up_lora_b_fwd(
x: torch.Tensor,
gate_up_lora_b: torch.Tensor,
batch_info: LoRABatchInfo,
output_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, 2 * r)
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
# output: (s, 2 * output_dim)
# Compute lora_output with shape (s, 2 * output_dim) as follows:
# lora_output[:, :output_dim] = sgemm(x[:, :r], gate_up_lora_b[:, :output_dim, :])
# lora_output[:, output_dim:]
# = sgemm(x[:, r:], gate_up_lora_b[:, output_dim:, :])
# Get dims
s = x.shape[0]
input_dim = x.shape[1]
r = gate_up_lora_b.shape[-1]
assert input_dim == 2 * r
BLOCK_S = 16
BLOCK_R = 16
BLOCK_OUT = 64
grid_b = (
triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(output_dim, BLOCK_OUT),
2, # this dimension decides current block computes on gate or up proj
batch_info.bs,
)
if base_output is None:
output = torch.empty((s, 2 * output_dim), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
else:
output = base_output
fuse_scaling_add = True
_gate_up_lora_b_kernel[grid_b](
x,
gate_up_lora_b,
output,
r,
output_dim,
x.stride(0),
x.stride(1),
gate_up_lora_b.stride(0),
gate_up_lora_b.stride(1),
gate_up_lora_b.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
scaling,
)
return output
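As a sanity reference for the fused kernel above, the same computation in plain PyTorch (a sketch for checking shapes and semantics on CPU; the helper name is an assumption and this is not a replacement for the Triton kernel):

import torch

def reference_gate_up_lora_b(x, gate_up_lora_b, seg_indptr, weight_indices,
                             output_dim, base_output=None, scaling=1.0):
    # x: (s, 2 * r); gate_up_lora_b: (num_lora, 2 * output_dim, r)
    r = gate_up_lora_b.shape[-1]
    out = (torch.zeros(x.shape[0], 2 * output_dim, dtype=x.dtype)
           if base_output is None else base_output)
    for i in range(len(weight_indices)):
        s0, s1 = int(seg_indptr[i]), int(seg_indptr[i + 1])
        w = gate_up_lora_b[int(weight_indices[i])]        # (2 * output_dim, r)
        out[s0:s1, :output_dim] += scaling * (x[s0:s1, :r] @ w[:output_dim].T)
        out[s0:s1, output_dim:] += scaling * (x[s0:s1, r:] @ w[output_dim:].T)
    return out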
@@ -2,7 +2,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
@@ -108,7 +108,7 @@ def _qkv_lora_b_kernel(
def qkv_lora_b_fwd(
x: torch.Tensor,
qkv_lora_b: torch.Tensor,
batch_info: LoRABatchInfo,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
@@ -123,11 +123,11 @@ def qkv_lora_b_fwd(
# output: (s, output_dim_q + 2 * output_dim_kv)
# Compute lora_output with shape (s, output_dim_q + 2 * output_dim_kv) as follows:
# lora_output[:, :output_dim_q] = sgemm(x[:, :r], qkv_lora_b[:, :output_dim_q, :])
# lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
#     = sgemm(x[:, r: 2 * r], qkv_lora_b[:, output_dim_q: output_dim_q + output_dim_kv, :])
# lora_output[:, output_dim_q + output_dim_kv: ]
#     = sgemm(x[:, 2 * r:], qkv_lora_b[:, output_dim_q + output_dim_kv:, :])
# Get dims
s = x.shape[0]
...
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit @triton.jit
...@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel( ...@@ -91,7 +91,7 @@ def _sgemm_lora_a_kernel(
def sgemm_lora_a_fwd( def sgemm_lora_a_fwd(
x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
) -> torch.Tensor: ) -> torch.Tensor:
# x: (s, input_dim) # x: (s, input_dim)
# weights: (num_lora, r, input_dim) # weights: (num_lora, r, input_dim)
......
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit @triton.jit
...@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel( ...@@ -98,7 +98,7 @@ def _sgemm_lora_b_kernel(
def sgemm_lora_b_fwd( def sgemm_lora_b_fwd(
x: torch.Tensor, x: torch.Tensor,
weights: torch.Tensor, weights: torch.Tensor,
batch_info: LoraBatchInfo, batch_info: LoRABatchInfo,
base_output: torch.Tensor = None, base_output: torch.Tensor = None,
scaling: float = 1.0, scaling: float = 1.0,
) -> torch.Tensor: ) -> torch.Tensor:
......
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set, Tuple
import torch
from sglang.srt.hf_transformers_utils import AutoConfig
@dataclass
class LoRABatchInfo:
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Index pointers (prefix sums of seg_lens) in shape (bs + 1,)
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
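For illustration only (not in the original file), a LoRABatchInfo could be assembled by hand for a batch of two sequences of lengths 3 and 5 that both use adapter 0; seg_indptr is the exclusive prefix sum of seg_lens, hence of length bs + 1:

def _demo_batch_info() -> LoRABatchInfo:
    seg_lens = torch.tensor([3, 5], dtype=torch.int32)
    seg_indptr = torch.zeros(len(seg_lens) + 1, dtype=torch.int32)
    seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)  # tensor([0, 3, 8])
    return LoRABatchInfo(
        bs=2,
        seg_lens=seg_lens,
        seg_indptr=seg_indptr,
        max_len=5,
        weight_indices=torch.zeros(2, dtype=torch.int32),
    )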
class LoRAType(Enum):
LORA_A = 0
LORA_B = 1
def get_layer_id(name: str) -> Optional[int]:
"""
Extract the integer layer id from a module name string, or return None if the name contains no layer index.
"""
match = re.search(r"layers\.(\d+)\.", name)
if match is None:
return None
return int(match.group(1))
def get_customized_names_from_hf_names(
hf_module_names: Set[str], base_model: torch.nn.Module
) -> Set[str]:
"""
This function takes in a set of huggingface style module names:
e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
and outputs a set of module names of customized sglang layers:
e.g., {"qkv_proj", "o_proj"}
"""
if hasattr(base_model, "get_module_name"):
return {base_model.get_module_name(name) for name in hf_module_names}
else:
"""
Fallback solution of mapping from config module name to module name in model class.
Please check if it aligns with your base model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
"up_proj": "gate_up_proj",
}
return {params_mapping.get(name, name) for name in hf_module_names}
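A small illustrative sketch (not part of the original file): with a base model that lacks get_module_name, the fallback mapping collapses the per-projection HF names onto sglang's stacked modules:

def _demo_customized_names() -> Set[str]:
    hf_names = {"q_proj", "k_proj", "v_proj", "o_proj"}
    # torch.nn.Module() has no get_module_name, so the fallback mapping is used
    # and the result is {"qkv_proj", "o_proj"}.
    return get_customized_names_from_hf_names(hf_names, torch.nn.Module())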
def get_hidden_dim(
module_name: str, config: AutoConfig, base_model: torch.nn.Module
) -> Tuple[int, int]:
"""
Given a module_name (possibly a stacked name), return the hidden dims of the module's input and output.
"""
if hasattr(base_model, "get_hidden_dim"):
return base_model.get_hidden_dim(module_name)
else:
"""
WARNING: get_hidden_dim() is not defined,
which is used to get the hidden dim for different lora modules
Use the default one, but please check if it is correct for your model.
Please implement the function in the model class if it is not.
You can reference this function in llama.py.
"""
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
return config.hidden_size, config.hidden_size
elif module_name in ["kv_proj"]:
return config.hidden_size, config.hidden_size // (
config.num_attention_heads // config.num_key_value_heads
)
elif module_name == "gate_up_proj":
return config.hidden_size, config.intermediate_size
elif module_name == "down_proj":
return config.intermediate_size, config.hidden_size
else:
raise NotImplementedError()
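Illustrative only: assuming a hypothetical Llama-like config (hidden_size 4096, intermediate_size 11008, 32 attention heads, 8 KV heads) stood in for AutoConfig, the fallback branch would return the following (input_dim, output_dim) pairs:

def _demo_hidden_dims() -> None:
    from types import SimpleNamespace

    # SimpleNamespace is a stand-in for a real AutoConfig in this sketch.
    cfg = SimpleNamespace(
        hidden_size=4096,
        intermediate_size=11008,
        num_attention_heads=32,
        num_key_value_heads=8,
    )
    base = torch.nn.Module()  # no get_hidden_dim -> fallback branch
    assert get_hidden_dim("qkv_proj", cfg, base) == (4096, 4096)
    assert get_hidden_dim("kv_proj", cfg, base) == (4096, 1024)  # 4096 // (32 // 8)
    assert get_hidden_dim("gate_up_proj", cfg, base) == (4096, 11008)
    assert get_hidden_dim("down_proj", cfg, base) == (11008, 4096)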
def get_stacked_name(name: str) -> Tuple[str, str]:
"""
Map a target module name to (stacked name for Lora A, stacked name for Lora B).
"""
params_mapping = {
"q_proj": ("qkv_proj", "q_proj"),
"k_proj": ("qkv_proj", "kv_proj"),
"v_proj": ("qkv_proj", "kv_proj"),
"gate_proj": ("gate_up_proj", "gate_up_proj"),
"up_proj": ("gate_up_proj", "gate_up_proj"),
}
return params_mapping.get(name, (name, name))
def get_stacked_multiply(module_name: str) -> int:
"""
Mapping a lora module name to its magnification at output dimension
"""
stacked_rank = {
"qkv_proj": 3,
"kv_proj": 2,
"gate_up_proj": 2,
}
return stacked_rank[module_name] if module_name in stacked_rank else 1
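Taken together (illustrative sketch, not part of the original file), the two helpers above behave as follows:

def _demo_stacking() -> None:
    assert get_stacked_name("k_proj") == ("qkv_proj", "kv_proj")
    assert get_stacked_name("up_proj") == ("gate_up_proj", "gate_up_proj")
    assert get_stacked_multiply("qkv_proj") == 3
    assert get_stacked_multiply("kv_proj") == 2
    assert get_stacked_multiply("down_proj") == 1  # non-stacked modules default to 1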
def get_weight_name(
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
) -> Optional[str]:
"""
target_name is name of a given module,
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
If there is a weight name in lora_weight_names that can match target_name, return this name
Else return None
"""
idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name_pair in lora_weight_names:
if weight_name_pair[idx] in target_name:
return weight_name_pair[idx]
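A final illustrative sketch (not in the original file): lora_weight_names holds the stacked-name pairs produced by get_stacked_name, and LoRAType selects which side of each pair must occur in the target module name:

def _demo_weight_name() -> None:
    lora_weight_names = {get_stacked_name(n) for n in ("q_proj", "k_proj", "up_proj")}
    # -> {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("gate_up_proj", "gate_up_proj")}
    target = "model.layers.0.self_attn.qkv_proj"
    assert get_weight_name(target, lora_weight_names, LoRAType.LORA_A) == "qkv_proj"
    assert get_weight_name("model.layers.0.mlp.down_proj", lora_weight_names, LoRAType.LORA_B) is None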
...@@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l ...@@ -22,7 +22,11 @@ from sglang.test.test_utils import calculate_rouge_l
LORA_SETS = [ LORA_SETS = [
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]}, {"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]} {
"base": "meta-llama/Llama-3.1-8B-Instruct",
"loras": ["reissbaker/llama-3.1-8b-abliterated-lora"],
"decode_tolerance": 8e-2,
},
] ]
TORCH_DTYPES = [torch.float16] TORCH_DTYPES = [torch.float16]
...@@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase): ...@@ -128,7 +132,8 @@ class TestLoRABackend(unittest.TestCase):
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)), torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
) )
if hf_logprobs.shape[0] <= 100: if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), ( tol = lora_set.get("prefill_tolerance", prefill_tolerance)
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
f"prefill logprobs are not all close with model_path={base_path}," f"prefill logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"prefill_tolerance={prefill_tolerance}." f"prefill_tolerance={prefill_tolerance}."
...@@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase): ...@@ -144,7 +149,8 @@ class TestLoRABackend(unittest.TestCase):
"\n", "\n",
) )
if hf_logprobs.shape[0] <= 100: if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), ( tol = lora_set.get("decode_tolerance", decode_tolerance)
assert torch.all(abs(hf_logprobs - srt_logprobs) < tol), (
f"decode logprobs are not all close with model_path={base_path}," f"decode logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}" f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"decode_tolerance={decode_tolerance}." f"decode_tolerance={decode_tolerance}."
...@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase): ...@@ -153,7 +159,7 @@ class TestLoRABackend(unittest.TestCase):
# compare output strings # compare output strings
srt_output_str = srt_outputs.output_strs[i].strip(" ") srt_output_str = srt_outputs.output_strs[i].strip(" ")
hf_output_str = hf_outputs.output_strs[i] hf_output_str = hf_outputs.output_strs[i].strip(" ")
print(f"srt_output_str={srt_output_str}") print(f"srt_output_str={srt_output_str}")
print(f"hf_output_str={hf_output_str}") print(f"hf_output_str={hf_output_str}")
rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str]) rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
......