"encoding/vscode:/vscode.git/clone" did not exist on "c2cb2aab69d5d276fbcb847fb8277c1a52947661"
Unverified commit 70817a7e authored by Baizhou Zhang, committed by GitHub

[Feature] Define backends and add Triton backend for Lora (#3161)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
parent 7b5a3741
import argparse
import os
NUM_LORAS = 8
NUM_LORAS = 4
LORA_PATH = {
"base": "mistralai/Mistral-7B-Instruct-v0.3",
"lora": "/home/ying/test_lora",
"base": "meta-llama/Llama-2-7b-hf",
"lora": "winddude/wizardLM-LlaMA-LoRA-7B",
}
......@@ -21,7 +21,8 @@ def launch_server(args):
cmd += f"{lora_name}={lora_path} "
cmd += f"--disable-radix --disable-cuda-graph "
cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
cmd += f"--max-running-requests {args.max_running_requests}"
cmd += f"--max-running-requests {args.max_running_requests} "
cmd += f"--lora-backend {args.lora_backend}"
print(cmd)
os.system(cmd)
......@@ -42,6 +43,11 @@ if __name__ == "__main__":
type=int,
default=8,
)
parser.add_argument(
"--lora-backend",
type=str,
default="triton",
)
args = parser.parse_args()
launch_server(args)
......@@ -183,6 +183,7 @@ async def benchmark(
api_url=api_url,
prompt_len=test_prompt_len,
output_len=test_output_len,
lora_name="dummy", # the lora_name argument will not be used
extra_request_body=extra_request_body,
)
test_output = await request_func(request_func_input=test_input)
......@@ -206,6 +207,7 @@ async def benchmark(
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
lora_name="dummy",
extra_request_body=extra_request_body,
)
tasks.append(
......@@ -255,6 +257,9 @@ async def benchmark(
"Output token throughput (tok/s):", metrics.output_throughput
)
)
print(
"{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput)
)
print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
print(
"{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
......
......@@ -124,6 +124,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `lora_paths`: You may provide a list of adapters to your model. Each batch element will get a model response with the corresponding LoRA adapter applied. Currently `cuda_graph` and `radix_attention` are not supported with this option, so you need to disable them manually. We are still working through these [issues](https://github.com/sgl-project/sglang/issues/2929).
* `max_loras_per_batch`: Maximum number of LoRAs in a running batch, including the base model.
* `lora_backend`: The backend for running GEMM kernels for LoRA modules; it can be one of `triton` or `flashinfer`. Defaults to `triton`.
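For example, a multi-LoRA server with the Triton backend can be launched the same way the benchmark script in this PR builds its command. This is a minimal sketch: the `sglang.launch_server` entry point, the `--model-path`/`--lora-paths` flag names, the adapter name `lora0`, and the model/adapter paths are illustrative assumptions of this example, not values mandated by the diff above.

```python
# Minimal launch sketch, modeled on benchmark/lora/launch_server.py from this PR.
import os

cmd = "python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf "
cmd += "--lora-paths lora0=winddude/wizardLM-LlaMA-LoRA-7B "
cmd += "--disable-radix --disable-cuda-graph "   # not yet supported together with LoRA
cmd += "--max-loras-per-batch 4 "
cmd += "--lora-backend triton"                   # or "flashinfer"
print(cmd)
os.system(cmd)
```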
## Kernel backend
......
from .base_backend import BaseLoraBackend
from .flashinfer_backend import FlashInferLoraBackend
from .triton_backend import TritonLoraBackend
__all__ = [
"FlashInferLoraBackend",
"TritonLoraBackend",
]
from typing import Tuple, Union
import torch
from sglang.srt.lora.lora import LoraBatchInfo
def get_fuse_output_scaling_add_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
class BaseLoraBackend:
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.
Args:
name: name of backend
batch_info: information of current batch for use
fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of scaling and adding will be fused into kernel
"""
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora a modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
usually input_dim is much larger than r
Returns:
result with shape (s, r)
"""
pass
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora b modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
weights: a set of lora weights with shape (num_lora, output_dim, r)
usually output_dim is much larger than r
Returns:
result with shape (s, output_dim)
"""
pass
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
) -> torch.Tensor:
"""Run the lora pass for QKV Layer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
qkv_lora_b: lora_b module for qkv.
If passed in as a tensor, its shape should be (num_lora, output_dim_q + 2 * output_dim_kv, r).
If passed in as a tuple of two tensors, it contains:
a lora_b module for q, with shape (1, num_lora, output_dim_q, r),
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r).
Returns:
result with shape (s, output_dim_q + 2 * output_dim_kv)
"""
pass
def set_batch_info(self, batch_info: LoraBatchInfo):
self.batch_info = batch_info
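To make the interface above concrete, here is a hedged sketch (not part of this PR) of what a naive, pure-PyTorch backend could look like. The class name `NaiveTorchLoraBackend` is hypothetical; the per-segment loop only illustrates how `LoraBatchInfo` (`bs`, `seg_indptr`, `weight_indices`) drives the segment GEMM semantics documented in the docstrings, and a complete backend would also implement `run_qkv_lora`.

```python
import torch

from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo


class NaiveTorchLoraBackend(BaseLoraBackend):
    """Illustrative reference backend: plain PyTorch, no fusion (hypothetical)."""

    def __init__(self, name: str, batch_info: LoraBatchInfo = None):
        super().__init__(name, batch_info)

    def run_lora_a_sgemm(self, x, weights, *args, **kwargs):
        # x: (s, input_dim), weights: (num_lora, r, input_dim) -> (s, r)
        out = x.new_zeros((x.shape[0], weights.shape[-2]))
        for i in range(self.batch_info.bs):
            start = int(self.batch_info.seg_indptr[i])
            end = int(self.batch_info.seg_indptr[i + 1])
            w = weights[int(self.batch_info.weight_indices[i])]  # (r, input_dim)
            out[start:end] = x[start:end] @ w.T
        return out

    def run_lora_b_sgemm(self, x, weights, *args, **kwargs):
        # x: (s, r), weights: (num_lora, output_dim, r) -> (s, output_dim)
        out = x.new_zeros((x.shape[0], weights.shape[-2]))
        for i in range(self.batch_info.bs):
            start = int(self.batch_info.seg_indptr[i])
            end = int(self.batch_info.seg_indptr[i + 1])
            w = weights[int(self.batch_info.weight_indices[i])]  # (output_dim, r)
            out[start:end] = x[start:end] @ w.T
        return out
```

Because "naive" does not appear in the fusion mappings above, `fuse_output_scaling_add` and `fuse_qkv_lora_b` would default to False, so the LoRA layers would fall back to the unfused `base_output + lora_output * scaling` path.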
from typing import Tuple
import torch
from flashinfer import SegmentGEMMWrapper
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
class FlashInferLoraBackend(BaseLoraBackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)
# Set up SGemm Wrapper from flashinfer
# FIXME wait for flashinfer segment gemm update
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
q_lora_b, kv_lora_b = qkv_lora_b
lora_rank = kv_lora_b.shape[-1]
output_dim_q = q_lora_b.shape[-2]
output_dim_kv = kv_lora_b.shape[-2]
lora_output = torch.empty(
(x.shape[0], output_dim_q + 2 * output_dim_kv),
device=x.device,
dtype=x.dtype,
)
# q
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
)
# kv
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
)
)
lora_output[
:, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
] = self.run_lora_b_sgemm(
x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
weights=kv_lora_b[1],
)
return lora_output
import torch
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import (
qkv_lora_b_fwd,
sgemm_lora_a_fwd,
sgemm_lora_b_fwd,
)
class TritonLoraBackend(BaseLoraBackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return sgemm_lora_a_fwd(x, weights, self.batch_info)
def run_lora_b_sgemm(
self,
x: torch.Tensor,
weights: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: torch.Tensor,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# qkv_lora_a: (num_lora, 3 * r, input_dim)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert isinstance(qkv_lora_b, torch.Tensor)
lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
lora_output = qkv_lora_b_fwd(
lora_a_output,
qkv_lora_b,
self.batch_info,
output_offset,
max_qkv_out_dim,
base_output,
scaling,
)
return lora_output
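A hedged usage sketch of the backend interface (assuming a CUDA device with Triton available; the tensor sizes are arbitrary and only chosen to match the shapes documented in `BaseLoraBackend`). `LoRAManager.prepare_lora_batch` below builds the `LoraBatchInfo` in the same way before the per-layer calls.

```python
import torch

from sglang.srt.lora.backend import TritonLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo

# Toy batch: two sequences of lengths 3 and 5, both using LoRA adapter 0.
batch_info = LoraBatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 5], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda"),
    max_len=5,
    weight_indices=torch.zeros(2, dtype=torch.int64, device="cuda"),
)

backend = TritonLoraBackend("triton")
backend.set_batch_info(batch_info)

s, input_dim, r, output_dim = 8, 4096, 16, 4096  # s = sum of sequence lengths
x = torch.randn(s, input_dim, dtype=torch.float16, device="cuda")
lora_a = torch.randn(1, r, input_dim, dtype=torch.float16, device="cuda")
lora_b = torch.randn(1, output_dim, r, dtype=torch.float16, device="cuda")

h = backend.run_lora_a_sgemm(x, lora_a)                  # (s, r)
out = backend.run_lora_b_sgemm(h, lora_b, scaling=2.0)   # (s, output_dim), scaled in-kernel
```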
......@@ -18,8 +18,8 @@
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py
import re
from dataclasses import dataclass
import torch
from torch import nn
......@@ -34,14 +34,32 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_loader.loader import DefaultModelLoader
@dataclass
class LoraBatchInfo:
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Index pointers of each sequence, in shape (bs + 1,)
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
weight_indices: torch.Tensor
class BaseLayerWithLoRA(nn.Module):
def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
def __init__(self, base_layer, lora_rank, scaling, lora_backend):
super().__init__()
self.base_layer = base_layer
self.segment_gemm = segment_gemm
self.lora_rank = lora_rank
self.scaling = scaling
self.set_lora = False
self.lora_backend = lora_backend
def forward(self, x: torch.Tensor):
return self.base_layer.forward(x)
......@@ -52,17 +70,17 @@ class BaseLayerWithLoRA(nn.Module):
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
self.weight = base_layer.weight
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
# TODO
......@@ -88,136 +106,127 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
def set_lora_info(
self,
A_buffer,
B_buffer,
):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
# FIXME
lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
output_dim = base_output.shape[-1]
lora_output = torch.empty_like(base_output)
output_dim = lora_output.shape[-1] // 2
for i in range(2):
left = output_dim * i
right = left + output_dim
lora_output[:, left:right] = self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * i : self.lora_rank * (i + 1)
].contiguous(),
weights=self.B_buffer[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
weights=self.B_buffer[0],
)
lora_output[:, output_dim : 2 * output_dim] = (
self.lora_backend.run_lora_b_sgemm(
x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
weights=self.B_buffer[1],
)
)
return base_output + lora_output * self.scaling
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
def __init__(
self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(
self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
self,
A_buffer_qkv,
B_buffer_q,
B_buffer_kv,
):
self.set_lora = True
self.A_buffer_qkv = A_buffer_qkv
self.B_buffer_q = B_buffer_q
self.B_buffer_kv = B_buffer_kv
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
if self.lora_backend.fuse_qkv_lora_b:
assert (
B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
# B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
self.B_buffer_qkv = torch.cat(
(B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
).contiguous()
# Offsets of q/k/v in output dimension
self.output_offset = torch.tensor(
[
0,
output_dim_q,
output_dim_q + output_dim_kv,
output_dim_q + 2 * output_dim_kv,
],
dtype=torch.int32,
device=B_buffer_q.device,
)
# For computing number of launched blocks
self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
else:
self.B_buffer_qkv = (
B_buffer_q,
B_buffer_kv,
)
self.output_offset = None
self.max_qkv_out_dim = None
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer_qkv,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_output = self.lora_backend.run_qkv_lora(
x,
self.A_buffer_qkv,
self.B_buffer_qkv,
output_offset=self.output_offset,
max_qkv_out_dim=self.max_qkv_out_dim,
base_output=base_output,
scaling=self.scaling,
)
# FIXME parallelize qkv
lora_output = torch.empty_like(base_output)
# q
output_dim_q = self.B_buffer_q.shape[-2]
lora_output[:, :output_dim_q] = self.segment_gemm.run(
x=lora_a_output[:, : self.lora_rank].contiguous(),
weights=self.B_buffer_q,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
# kv
output_dim_kv = self.B_buffer_kv.shape[-2] // 2
for i in range(2):
left = output_dim_kv * i
right = left + output_dim_kv
lora_output[:, output_dim_q + left : output_dim_q + right] = (
self.segment_gemm.run(
x=lora_a_output[
:, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
].contiguous(),
weights=self.B_buffer_kv[:, left:right, :].contiguous(),
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
)
)
return base_output + lora_output * self.scaling
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def __init__(
self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
) -> None:
super().__init__(base_layer, segment_gemm, lora_rank, scaling)
super().__init__(base_layer, lora_rank, scaling, lora_backend)
def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
def set_lora_info(self, A_buffer, B_buffer):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
self.bs = bs
self.seg_indptr = seg_indptr
self.weight_indices = weight_indices
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_output = self.segment_gemm.run(
x=x,
weights=self.A_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
lora_a_output,
self.B_buffer[0],
base_output=base_output,
scaling=self.scaling,
)
lora_output = self.segment_gemm.run(
x=lora_output,
weights=self.B_buffer,
batch_size=self.bs,
weight_column_major=True,
seg_indptr=self.seg_indptr,
weight_indices=self.weight_indices,
return (
lora_output
if self.lora_backend.fuse_output_scaling_add
else base_output + lora_output * self.scaling
)
return base_output + lora_output * self.scaling
def forward(self, input_):
# duplicate the logic in RowParallelLinear
......@@ -255,7 +264,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def get_lora_layer(
layer: nn.Module, segment_gemm, lora_rank, scaling
layer: nn.Module, lora_rank, scaling, lora_backend
) -> BaseLayerWithLoRA:
supported_layer_types = {
# the order matters
......@@ -267,7 +276,7 @@ def get_lora_layer(
}
for src_layer_type, lora_layer_type in supported_layer_types.items():
if isinstance(layer, src_layer_type): # pylint: disable=unidiomatic-typecheck
ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
return ret
raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
......@@ -297,13 +306,14 @@ class LoRALayer(nn.Module):
class LoRAAdapter(nn.Module):
def __init__(self, uid, config, base_hf_config, load_config):
def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
super().__init__()
self.uid = uid
self.config = config
assert self.config.hf_config["peft_type"].lower() == "lora"
self.base_hf_config = base_hf_config
self.load_config = load_config
self.lora_backend = lora_backend
self.scaling = self.config.lora_alpha / self.config.r
self.layers = nn.ModuleList(
......@@ -376,20 +386,25 @@ class LoRAAdapter(nn.Module):
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
else:
layer.weights[kv_name] = torch.cat(
(
layer.weights[kv_name] = torch.stack(
[
layer.weights[weight_name],
layer.weights[v_name],
),
0,
],
dim=0,
)
layer.weights.pop(weight_name)
layer.weights.pop(v_name)
elif "gate_proj" in weight_name:
up_name = weight_name.replace("gate_proj", "up_proj")
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
if "lora_A" in weight_name:
layer.weights[gate_up_name] = torch.cat(
(layer.weights[weight_name], layer.weights[up_name]), 0
)
else:
layer.weights[gate_up_name] = torch.stack(
[layer.weights[weight_name], layer.weights[up_name]], dim=0
)
layer.weights.pop(weight_name)
layer.weights.pop(up_name)
......@@ -20,16 +20,14 @@ import re
import torch
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available, replace_submodule
logger = logging.getLogger(__name__)
if is_flashinfer_available():
from flashinfer import SegmentGEMMWrapper
def get_module_name(name):
# Fallback solution of mapping from config module name to module name in model class.
......@@ -77,6 +75,20 @@ def get_stacked_name(name):
return params_mapping.get(name, (name, name))
def get_backend_from_name(name):
backend_mapping = {
"triton": TritonLoraBackend,
"flashinfer": FlashInferLoraBackend,
}
if name in backend_mapping:
return backend_mapping[name]
raise Exception(
f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
)
def get_layer_id(name):
match = re.search(r"layers\.(\d+)\.", name)
if match is None:
......@@ -93,6 +105,7 @@ class LoRAManager:
max_loras_per_batch,
load_config,
dtype,
lora_backend,
):
self.base_model = base_model
self.lora_paths = lora_paths
......@@ -101,8 +114,9 @@ class LoRAManager:
self.load_config = load_config
self.dtype = dtype
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
logger.info(f"Using {lora_backend} as backend of Lora kernels.")
backend_type = get_backend_from_name(lora_backend)
self.lora_backend = backend_type(lora_backend)
self.init_loras()
self.init_lora_memory_pool()
......@@ -123,7 +137,7 @@ class LoRAManager:
def set_lora_module(self, module_name, module):
lora_module = get_lora_layer(
module, self.segment_gemm, self.max_lora_dim, self.scaling
module, self.max_lora_dim, self.scaling, self.lora_backend
)
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
......@@ -162,7 +176,11 @@ class LoRAManager:
self.lora_id[name] = len(self.loras)
self.loras.append(
LoRAAdapter(
name, self.configs[name], self.base_hf_config, self.load_config
name,
self.configs[name],
self.base_hf_config,
self.load_config,
self.lora_backend,
)
)
self.loras[-1].initialize_weights()
......@@ -226,8 +244,9 @@ class LoRAManager:
self.B_buffer[module_B] = [
torch.empty(
(
c,
self.max_loras_per_batch,
hidden_dim_B * c,
hidden_dim_B,
self.max_lora_dim,
),
dtype=self.dtype,
......@@ -263,7 +282,16 @@ class LoRAManager:
else:
lora_weight_name = self.get_weight_name(name, 1)
if lora_weight_name:
self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
c = self.loras[-1].get_stacked_multiply(lora_weight_name)
if c > 1:
for j in range(c):
self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
weights[j]
)
else:
self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
weights
)
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# load active loras into lora memory pool
......@@ -292,20 +320,30 @@ class LoRAManager:
if cur_uids == set([None]):
return
# setup lora in forward modules
# set up batch info shared by all LoRA modules
bs = forward_batch.batch_size
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device="cuda")
)
# FIXME: reuse the data rather than recompute
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
max_len = int(torch.max(seg_lens))
weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
for i, lora_path in enumerate(forward_batch.lora_paths):
weight_indices[i] = self.buffer_id[lora_path]
batch_info = LoraBatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
)
self.lora_backend.set_batch_info(batch_info)
# call set_lora_info for each LoRA module
for module_name, module in self.lora_modules:
layer_id = get_layer_id(module_name)
......@@ -314,16 +352,10 @@ class LoRAManager:
module.set_lora_info(
self.A_buffer[weight_name][layer_id],
self.B_buffer[weight_name][layer_id],
bs,
seg_indptr,
weight_indices,
)
else:
module.set_lora_info(
self.A_buffer["qkv_proj"][layer_id],
self.B_buffer["q_proj"][layer_id],
self.B_buffer["kv_proj"][layer_id],
bs,
seg_indptr,
weight_indices,
)
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
from .sgemm_lora_b import sgemm_lora_b_fwd
__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _qkv_lora_b_kernel(
# Pointers to matrices
x,
weights,
output,
# Parameters of size
K, # K = R
max_qkv_out_dim, # max(output_q_dim, output_kv_dim)
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Offsets of q/k/v slice on output dimension
n_offs,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
):
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
# x: (s, 3 * K), s is the sum of sequence lengths, K equals the LoRA rank
# weights: (num_lora, N_Q + 2 * N_KV, K)
# output: (s, N_Q + 2 * N_KV)
# N_Q >> K, N_KV >> K
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len.
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
batch_id = tl.program_id(axis=2)
qkv_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
n_start = tl.load(n_offs + qkv_id)
n_size = tl.load(n_offs + qkv_id + 1) - n_start
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in the output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
def qkv_lora_b_fwd(
x: torch.Tensor,
qkv_lora_b: torch.Tensor,
batch_info: LoraBatchInfo,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, 3 * r)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
# output_offset = [0, output_dim_q, output_dim_q + output_dim_kv,
# output_dim_q + 2 * output_dim_kv]
# max_qkv_out_dim = max(output_dim_q, output_dim_kv)
# output: (s, output_dim_q + 2 * output_dim_kv)
# Compute lora_output with shape (s, output_dim) as follows:
# lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], q_lora_b)
# lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
# = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
# lora_output[:, output_dim_q + output_dim_kv: ]
# = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
# Get dims
s = x.shape[0]
input_dim = x.shape[1]
r = qkv_lora_b.shape[-1]
output_dim = qkv_lora_b.shape[-2]
assert input_dim == 3 * r
assert output_offset.shape[0] == 4
BLOCK_S = 16
BLOCK_R = 16
BLOCK_OUT = 64
grid_b = (
triton.cdiv(batch_info.max_len, BLOCK_S)
* triton.cdiv(max_qkv_out_dim, BLOCK_OUT),
3,  # this dimension decides whether the current block computes q, k, or v
batch_info.bs,
)
if base_output is None:
output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
else:
output = base_output
fuse_scaling_add = True
_qkv_lora_b_kernel[grid_b](
x,
qkv_lora_b,
output,
r,
max_qkv_out_dim,
x.stride(0),
x.stride(1),
qkv_lora_b.stride(0),
qkv_lora_b.stride(1),
qkv_lora_b.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
output_offset,
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
scaling,
)
return output
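A hedged sketch of how the inputs to `qkv_lora_b_fwd` fit together, mirroring how `QKVParallelLinearWithLoRA.set_lora_info` packs `B_buffer_qkv` and `output_offset` above (assumes a CUDA device with Triton available; all dimensions are arbitrary examples).

```python
import torch

from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import qkv_lora_b_fwd

s, r, output_dim_q, output_dim_kv = 8, 16, 4096, 1024

# One sequence of length s using adapter 0.
info = LoraBatchInfo(
    bs=1,
    seg_lens=torch.tensor([s], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, s], dtype=torch.int32, device="cuda"),
    max_len=s,
    weight_indices=torch.zeros(1, dtype=torch.int64, device="cuda"),
)

# lora_a outputs for q/k/v packed along the last dimension: (s, 3 * r)
lora_a_output = torch.randn(s, 3 * r, dtype=torch.float16, device="cuda")
# q, k and v lora_b weights stacked along the output dimension:
# (num_lora, output_dim_q + 2 * output_dim_kv, r)
qkv_lora_b = torch.randn(
    1, output_dim_q + 2 * output_dim_kv, r, dtype=torch.float16, device="cuda"
)
# Slice boundaries of q/k/v in the packed output dimension.
output_offset = torch.tensor(
    [0, output_dim_q, output_dim_q + output_dim_kv, output_dim_q + 2 * output_dim_kv],
    dtype=torch.int32,
    device="cuda",
)

out = qkv_lora_b_fwd(
    lora_a_output,
    qkv_lora_b,
    info,
    output_offset,
    max_qkv_out_dim=max(output_dim_q, output_dim_kv),
)
# out: (s, output_dim_q + 2 * output_dim_kv)
```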
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _sgemm_lora_a_kernel(
# Pointers to matrices
x,
weights,
output,
# Matrix dimensions
N, # r
K, # input_dim
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# output: (s, N)
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len
batch_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in the output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
tl.store(output_ptr, partial_sum, mask=output_mask)
def sgemm_lora_a_fwd(
x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
) -> torch.Tensor:
# x: (s, input_dim)
# weights: (num_lora, r, input_dim)
# output: (s, r)
# when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
# input_dim is much larger than r
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
S = x.shape[0]
R = weights.shape[-2]
K = weights.shape[-1]
assert x.shape[-1] == K
# Block shapes
BLOCK_S = 16
BLOCK_K = 256
BLOCK_R = 16
grid = (
triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R),
batch_info.bs,
)
output = torch.empty((S, R), device=x.device, dtype=x.dtype)
_sgemm_lora_a_kernel[grid](
x,
weights,
output,
R,
K,
x.stride(0),
x.stride(1),
weights.stride(0),
weights.stride(1),
weights.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
BLOCK_S,
BLOCK_R,
BLOCK_K,
)
return output
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _sgemm_lora_b_kernel(
# Pointers to matrices
x,
weights,
output,
# Matrix dimensions
N, # output_dim
K, # r
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# output: (s, N)
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len
batch_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in the output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = s_offset[:, None] < seg_len
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
def sgemm_lora_b_fwd(
x: torch.Tensor,
weights: torch.Tensor,
batch_info: LoraBatchInfo,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, r)
# weights: (num_lora, output_dim, r)
# output: (s, output_dim)
# output_dim is much larger than r
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
S = x.shape[0]
N = weights.shape[-2]
R = weights.shape[-1]
assert x.shape[-1] == R
# Block shapes
BLOCK_S = 16
BLOCK_R = 16
BLOCK_N = 256
grid = (
triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N),
batch_info.bs,
)
if base_output is None:
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
else:
output = base_output
fuse_scaling_add = True
_sgemm_lora_b_kernel[grid](
x,
weights,
output,
N,
R,
x.stride(0),
x.stride(1),
weights.stride(0),
weights.stride(1),
weights.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
BLOCK_S,
BLOCK_N,
BLOCK_R,
fuse_scaling_add,
scaling,
)
return output
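For intuition, here is a hedged reference check of the fused scaling-and-add path (assumes a CUDA device with Triton available; names and sizes are local to this sketch). Passing `base_output` makes the kernel write `base_output + scaling * (x @ w.T)` for each segment instead of allocating a fresh output buffer.

```python
import torch

from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import sgemm_lora_b_fwd

# One sequence of length 4 using adapter 0; r = 16, output_dim = 512.
info = LoraBatchInfo(
    bs=1,
    seg_lens=torch.tensor([4], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 4], dtype=torch.int32, device="cuda"),
    max_len=4,
    weight_indices=torch.zeros(1, dtype=torch.int64, device="cuda"),
)
x = torch.randn(4, 16, dtype=torch.float16, device="cuda")
w = torch.randn(1, 512, 16, dtype=torch.float16, device="cuda")
base = torch.randn(4, 512, dtype=torch.float16, device="cuda")

fused = sgemm_lora_b_fwd(x, w, info, base_output=base.clone(), scaling=0.5)
reference = base + 0.5 * (x @ w[0].T)
print(torch.allclose(fused, reference, atol=1e-2, rtol=1e-2))  # expected: True
```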
......@@ -530,6 +530,7 @@ class ModelRunner:
max_loras_per_batch=self.server_args.max_loras_per_batch,
load_config=self.load_config,
dtype=self.dtype,
lora_backend=self.server_args.lora_backend,
)
logger.info("LoRA manager ready.")
......
......@@ -113,6 +113,7 @@ class ServerArgs:
# LoRA
lora_paths: Optional[List[str]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
# Kernel backend
attention_backend: Optional[str] = None
......@@ -653,13 +654,19 @@ class ServerArgs:
nargs="*",
default=None,
action=LoRAPathAction,
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
)
parser.add_argument(
"--max-loras-per-batch",
type=int,
default=8,
help="Maximum number of adapters for a running batch, include base-only request",
help="Maximum number of adapters for a running batch, include base-only request.",
)
parser.add_argument(
"--lora-backend",
type=str,
default="triton",
help="Choose the kernel backend for multi-LoRA serving.",
)
# Kernel backend
......
......@@ -272,6 +272,7 @@ class SRTRunner:
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
lora_paths: List[str] = None,
max_loras_per_batch: int = 4,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
disable_radix_cache: bool = False,
):
......@@ -287,6 +288,7 @@ class SRTRunner:
is_embedding=not self.is_generation,
lora_paths=lora_paths,
max_loras_per_batch=max_loras_per_batch,
lora_backend=lora_backend,
disable_cuda_graph=disable_cuda_graph,
disable_radix_cache=disable_radix_cache,
)
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest
import torch
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l
LORA_SETS = [
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]}
]
TORCH_DTYPES = [torch.float16]
PROMPTS = [
"AI is a field of computer science focused on",
"""
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
]
BACKENDS = ["triton", "flashinfer"]
prefill_tolerance: float = 5e-2
decode_tolerance: float = 5e-2
rouge_l_tolerance: float = 1
class TestLoRABackend(unittest.TestCase):
def run_backend(
self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens, backend
):
print(f"=================== testing {backend} backend =======================")
base_path = lora_set["base"]
all_lora_paths = lora_set["loras"]
batch_lora_paths = []
i = 0
for _ in range(len(prompts)):
batch_lora_paths.append(all_lora_paths[i])
i = (i + 1) % len(all_lora_paths)
print(f"batch lora paths={batch_lora_paths}")
with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
tp_size=tp_size,
lora_paths=all_lora_paths,
max_loras_per_batch=3,
lora_backend=backend,
disable_cuda_graph=True,
disable_radix_cache=True,
) as srt_runner:
srt_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
with HFRunner(
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="generation",
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
)
with HFRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
) as hf_runner:
hf_no_lora_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens
)
for i in range(len(prompts)):
print(f"Prompt {i} with lora path {batch_lora_paths[i]}:")
# compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i])
srt_no_lora_logprobs = torch.Tensor(
srt_no_lora_outputs.top_input_logprobs[i]
)
print(
"max input diff between hf_lora and srt_lora",
torch.max(abs(hf_logprobs - srt_logprobs)),
)
print(
"max input diff between srt_base and srt_lora",
torch.max(abs(srt_no_lora_logprobs - srt_logprobs)),
)
print(
"max input diff between srt_base and hf_base",
torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)),
)
print(
"max input diff between hf_lora and hf_base",
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
)
if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
print(
"max output diff between hf_lora and srt_lora",
torch.max(abs(hf_logprobs - srt_logprobs)),
"\n",
)
if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# compare output strings
srt_output_str = srt_outputs.output_strs[i].strip(" ")
hf_output_str = hf_outputs.output_strs[i]
print(f"srt_output_str={srt_output_str}")
print(f"hf_output_str={hf_output_str}")
rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
print(f"{rouge_l_scores=}")
assert (
rouge_l_scores[0] >= rouge_l_tolerance
), f"ROUGE-L scores of prompt {i} outputs are greater than rouge_l_tolerance={rouge_l_tolerance}"
def test_all(self):
for lora_set in LORA_SETS:
print(f"Testing lora set {lora_set}: ")
for torch_dtype in TORCH_DTYPES:
tp_size = 1
max_new_tokens = 32
for backend in BACKENDS:
self.run_backend(
PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens, backend
)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")
......@@ -8,6 +8,7 @@ suites = {
"models/test_embedding_models.py",
"models/test_generation_models.py",
"models/test_lora.py",
"models/test_lora_backend.py",
"models/test_qwen_models.py",
"models/test_reward_models.py",
"sampling/penaltylib",
......