Unverified Commit 70817a7e authored by Baizhou Zhang, committed by GitHub

[Feature] Define backends and add Triton backend for Lora (#3161)


Co-authored-by: Ying Sheng <sqy1415@gmail.com>
parent 7b5a3741
import argparse
import os

-NUM_LORAS = 8
+NUM_LORAS = 4
LORA_PATH = {
-    "base": "mistralai/Mistral-7B-Instruct-v0.3",
-    "lora": "/home/ying/test_lora",
+    "base": "meta-llama/Llama-2-7b-hf",
+    "lora": "winddude/wizardLM-LlaMA-LoRA-7B",
}
@@ -21,7 +21,8 @@ def launch_server(args):
        cmd += f"{lora_name}={lora_path} "
    cmd += f"--disable-radix --disable-cuda-graph "
    cmd += f"--max-loras-per-batch {args.max_loras_per_batch} "
-    cmd += f"--max-running-requests {args.max_running_requests}"
+    cmd += f"--max-running-requests {args.max_running_requests} "
+    cmd += f"--lora-backend {args.lora_backend}"
    print(cmd)
    os.system(cmd)
@@ -42,6 +43,11 @@ if __name__ == "__main__":
        type=int,
        default=8,
    )
+    parser.add_argument(
+        "--lora-backend",
+        type=str,
+        default="triton",
+    )
    args = parser.parse_args()
    launch_server(args)
@@ -183,6 +183,7 @@ async def benchmark(
        api_url=api_url,
        prompt_len=test_prompt_len,
        output_len=test_output_len,
+        lora_name="dummy",  # the lora_name argument will not be used
        extra_request_body=extra_request_body,
    )
    test_output = await request_func(request_func_input=test_input)
@@ -206,6 +207,7 @@ async def benchmark(
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
+            lora_name="dummy",
            extra_request_body=extra_request_body,
        )
        tasks.append(
@@ -255,6 +257,9 @@ async def benchmark(
            "Output token throughput (tok/s):", metrics.output_throughput
        )
    )
+    print(
+        "{:<40} {:<10.2f}".format("Total throughput (tok/s):", metrics.total_throughput)
+    )
    print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
    print(
        "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
...
@@ -124,6 +124,7 @@ Please consult the documentation below to learn more about the parameters you may provide.
* `lora_paths`: You may provide a list of adapters to your model. Each batch element will get a model response with the corresponding LoRA adapter applied. Currently `cuda_graph` and `radix_attention` are not supported with this option, so you need to disable them manually. We are still working through these [issues](https://github.com/sgl-project/sglang/issues/2929).
* `max_loras_per_batch`: Maximum number of LoRAs in a running batch, including the base model.
+* `lora_backend`: The backend used to run GEMM kernels for LoRA modules. It can be one of `triton` or `flashinfer`, and defaults to `triton`.
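For example, a server with one adapter on the Triton backend can be launched roughly like this (a sketch only; the model and adapter names are placeholders taken from the examples elsewhere in this change, so adjust them to your setup): `python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-hf --lora-paths lora0=winddude/wizardLM-LlaMA-LoRA-7B --max-loras-per-batch 4 --lora-backend triton --disable-radix-cache --disable-cuda-graph`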
## Kernel backend
...
from .base_backend import BaseLoraBackend
from .flashinfer_backend import FlashInferLoraBackend
from .triton_backend import TritonLoraBackend
__all__ = [
"FlashInferLoraBackend",
"TritonLoraBackend",
]
from typing import Tuple, Union
import torch
from sglang.srt.lora.lora import LoraBatchInfo
def get_fuse_output_scaling_add_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
def get_fuse_qkv_lora_b_from_name(name: str) -> bool:
mapping = {
"triton": True,
"flashinfer": False,
}
return mapping.get(name, False)
class BaseLoraBackend:
"""Base class for different Lora backends.
Each backend has its own implementation of Lora kernels.
Args:
name: name of backend
batch_info: information of current batch for use
fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
and the operation of scaling and adding will be fused into kernel
"""
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
self.name = name
self.batch_info = batch_info
self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
self.fuse_qkv_lora_b = get_fuse_qkv_lora_b_from_name(name)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora a modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
weights: a set of lora weights with shape (num_lora, r, input_dim), here r is lora rank
usually input_dim is much larger than r
Returns:
result with shape (s, r)
"""
pass
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
"""Run segment Gemm of lora b modules with current backend.
The definition of segment Gemm can be referred to https://docs.flashinfer.ai/api/gemm.html.
Args:
x: input matrix with shape (s, r), here s is the sum of all sequence lengths, r is lora rank
weights: a set of lora weights with shape (num_lora, output_dim, r)
usually output_dim is much larger than r
Returns:
result with shape (s, output_dim)
"""
pass
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
*args,
**kwargs
) -> torch.Tensor:
"""Run the lora pass for QKV Layer.
Args:
x: input matrix with shape (s, input_dim), here s is the sum of all sequence lengths
qkv_lora_a: lora_a module for qkv, with shape (num_lora, 3 * r, input_dim)
qkv_lora_b: lora_b module for qkv.
If passed in as a tensor, its shape should be (num_lora,output_dim_q + 2 * output_dim_kv, r)
If passed in as a tuple of two tensors containing:
a lora_b module for q, with shape (1, num_lora, output_dim_q, r)
and a combined lora_b module for kv, with shape (2, num_lora, output_dim_kv, r)
Returns:
result with shape (s, output_dim_q + 2 * output_dim_kv)
"""
pass
def set_batch_info(self, batch_info: LoraBatchInfo):
self.batch_info = batch_info
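To make the contract above concrete, here is a minimal usage sketch of a backend (illustrative only, not part of the commit; it assumes a CUDA device with the Triton/flashinfer dependencies installed and uses toy shapes):

```python
import torch

from sglang.srt.lora.backend import TritonLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo

# Two sequences of lengths 3 and 2, served by adapters 0 and 1.
batch_info = LoraBatchInfo(
    bs=2,
    seg_lens=torch.tensor([3, 2], dtype=torch.int32, device="cuda"),
    seg_indptr=torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda"),
    max_len=3,
    weight_indices=torch.tensor([0, 1], dtype=torch.int64, device="cuda"),
)

backend = TritonLoraBackend("triton")
backend.set_batch_info(batch_info)

s, input_dim, r, num_lora = 5, 64, 16, 2
x = torch.randn(s, input_dim, dtype=torch.float16, device="cuda")
lora_a = torch.randn(num_lora, r, input_dim, dtype=torch.float16, device="cuda")

# (s, input_dim) x (num_lora, r, input_dim) -> (s, r), one adapter per sequence.
lora_a_out = backend.run_lora_a_sgemm(x, lora_a)
assert lora_a_out.shape == (s, r)
```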
from typing import Tuple
import torch
from flashinfer import SegmentGEMMWrapper
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
class FlashInferLoraBackend(BaseLoraBackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)
# Set up SGemm Wrapper from flashinfer
# FIXME wait for flashinfer segment gemm update
workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
def run_lora_b_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return self.segment_gemm.run(
x=x,
weights=weights,
batch_size=self.batch_info.bs,
weight_column_major=True,
seg_indptr=self.batch_info.seg_indptr,
weight_indices=self.batch_info.weight_indices,
)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: Tuple[torch.Tensor],
*args,
**kwargs,
) -> torch.Tensor:
# Shape of lora_a_output: (s, 3 * r)
lora_a_output = self.run_lora_a_sgemm(x=x, weights=qkv_lora_a)
q_lora_b, kv_lora_b = qkv_lora_b
lora_rank = kv_lora_b.shape[-1]
output_dim_q = q_lora_b.shape[-2]
output_dim_kv = kv_lora_b.shape[-2]
lora_output = torch.empty(
(x.shape[0], output_dim_q + 2 * output_dim_kv),
device=x.device,
dtype=x.dtype,
)
# q
lora_output[:, :output_dim_q] = self.run_lora_b_sgemm(
x=lora_a_output[:, :lora_rank].contiguous(), weights=q_lora_b[0]
)
# kv
lora_output[:, output_dim_q : output_dim_q + output_dim_kv] = (
self.run_lora_b_sgemm(
x=lora_a_output[:, lora_rank : 2 * lora_rank].contiguous(),
weights=kv_lora_b[0],
)
)
lora_output[
:, output_dim_q + output_dim_kv : output_dim_q + 2 * output_dim_kv
] = self.run_lora_b_sgemm(
x=lora_a_output[:, 2 * lora_rank : 3 * lora_rank].contiguous(),
weights=kv_lora_b[1],
)
return lora_output
import torch
from sglang.srt.lora.backend import BaseLoraBackend
from sglang.srt.lora.lora import LoraBatchInfo
from sglang.srt.lora.triton_ops import (
qkv_lora_b_fwd,
sgemm_lora_a_fwd,
sgemm_lora_b_fwd,
)
class TritonLoraBackend(BaseLoraBackend):
def __init__(self, name: str, batch_info: LoraBatchInfo = None):
super().__init__(name, batch_info)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return sgemm_lora_a_fwd(x, weights, self.batch_info)
def run_lora_b_sgemm(
self,
x: torch.Tensor,
weights: torch.Tensor,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: torch.Tensor,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# qkv_lora_a: (num_lora, 3 * r, input_dim)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert isinstance(qkv_lora_b, torch.Tensor)
lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
lora_output = qkv_lora_b_fwd(
lora_a_output,
qkv_lora_b,
self.batch_info,
output_offset,
max_qkv_out_dim,
base_output,
scaling,
)
return lora_output
@@ -18,8 +18,8 @@
# LoRA layers class inheritance adapted from:
# https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py

import re
+from dataclasses import dataclass

import torch
from torch import nn

@@ -34,14 +34,32 @@ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_loader.loader import DefaultModelLoader


+@dataclass
+class LoraBatchInfo:
+    # Batch size
+    bs: int
+
+    # Lengths of each sequence in shape (bs,)
+    seg_lens: torch.Tensor
+
+    # Indice pointers of each sequence in shape (bs + 1, )
+    seg_indptr: torch.Tensor
+
+    # Maximum sequence length of current batch
+    max_len: int
+
+    # The index of lora adapter used by each sequence, in shape (bs,)
+    weight_indices: torch.Tensor
+
+
class BaseLayerWithLoRA(nn.Module):
-    def __init__(self, base_layer, segment_gemm, lora_rank, scaling):
+    def __init__(self, base_layer, lora_rank, scaling, lora_backend):
        super().__init__()
        self.base_layer = base_layer
-        self.segment_gemm = segment_gemm
        self.lora_rank = lora_rank
        self.scaling = scaling
        self.set_lora = False
+        self.lora_backend = lora_backend

    def forward(self, x: torch.Tensor):
        return self.base_layer.forward(x)

@@ -52,17 +70,17 @@ class BaseLayerWithLoRA(nn.Module):
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def __init__(
-        self, base_layer: VocabParallelEmbedding, segment_gemm, lora_rank, scaling
+        self, base_layer: VocabParallelEmbedding, lora_rank, scaling, lora_backend
    ) -> None:
-        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)
        self.weight = base_layer.weight


class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
-        self, base_layer: ColumnParallelLinear, segment_gemm, lora_rank, scaling
+        self, base_layer: ColumnParallelLinear, lora_rank, scaling, lora_backend
    ) -> None:
-        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def apply_lora(self, output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # TODO
@@ -88,136 +106,127 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    def __init__(
-        self, base_layer: MergedColumnParallelLinear, segment_gemm, lora_rank, scaling
+        self, base_layer: MergedColumnParallelLinear, lora_rank, scaling, lora_backend
    ) -> None:
-        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
+    def set_lora_info(
+        self,
+        A_buffer,
+        B_buffer,
+    ):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
-        self.bs = bs
-        self.seg_indptr = seg_indptr
-        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.segment_gemm.run(
-            x=x,
-            weights=self.A_buffer,
-            batch_size=self.bs,
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
-        )
-        # FIXME
+        lora_a_output = self.lora_backend.run_lora_a_sgemm(x=x, weights=self.A_buffer)
+
+        output_dim = base_output.shape[-1]
        lora_output = torch.empty_like(base_output)
-        output_dim = lora_output.shape[-1] // 2
-        for i in range(2):
-            left = output_dim * i
-            right = left + output_dim
-            lora_output[:, left:right] = self.segment_gemm.run(
-                x=lora_a_output[
-                    :, self.lora_rank * i : self.lora_rank * (i + 1)
-                ].contiguous(),
-                weights=self.B_buffer[:, left:right, :].contiguous(),
-                batch_size=self.bs,
-                weight_column_major=True,
-                seg_indptr=self.seg_indptr,
-                weight_indices=self.weight_indices,
-            )
+        lora_output[:, :output_dim] = self.lora_backend.run_lora_b_sgemm(
+            x=lora_a_output[:, 0 : self.lora_rank].contiguous(),
+            weights=self.B_buffer[0],
+        )
+
+        lora_output[:, output_dim : 2 * output_dim] = (
+            self.lora_backend.run_lora_b_sgemm(
+                x=lora_a_output[:, self.lora_rank : 2 * self.lora_rank].contiguous(),
+                weights=self.B_buffer[1],
+            )
+        )
        return base_output + lora_output * self.scaling


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def __init__(
+    def init__(
-        self, base_layer: QKVParallelLinear, segment_gemm, lora_rank, scaling
+        self, base_layer: QKVParallelLinear, lora_rank, scaling, lora_backend
    ) -> None:
-        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)

    def set_lora_info(
-        self, A_buffer_qkv, B_buffer_q, B_buffer_kv, bs, seg_indptr, weight_indices
+        self,
+        A_buffer_qkv,
+        B_buffer_q,
+        B_buffer_kv,
    ):
        self.set_lora = True
        self.A_buffer_qkv = A_buffer_qkv
-        self.B_buffer_q = B_buffer_q
-        self.B_buffer_kv = B_buffer_kv
-        self.bs = bs
-        self.seg_indptr = seg_indptr
-        self.weight_indices = weight_indices
+
+        if self.lora_backend.fuse_qkv_lora_b:
+            assert (
+                B_buffer_q.shape[-1] == B_buffer_kv.shape[-1]
+            ), "The lora rank of q and kv should be the same when enabling fusion of qkv lora_b"
+            output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
+
+            # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
+            self.B_buffer_qkv = torch.cat(
+                (B_buffer_q[0], B_buffer_kv[0], B_buffer_kv[1]), dim=-2
+            ).contiguous()
+
+            # Offsets of q/k/v in output dimension
+            self.output_offset = torch.tensor(
+                [
+                    0,
+                    output_dim_q,
+                    output_dim_q + output_dim_kv,
+                    output_dim_q + 2 * output_dim_kv,
+                ],
+                dtype=torch.int32,
+                device=B_buffer_q.device,
+            )
+            # For computing number of launched blocks
+            self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
+        else:
+            self.B_buffer_qkv = (
+                B_buffer_q,
+                B_buffer_kv,
+            )
+            self.output_offset = None
+            self.max_qkv_out_dim = None

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_a_output = self.segment_gemm.run(
-            x=x,
-            weights=self.A_buffer_qkv,
-            batch_size=self.bs,
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
-        )
-        # FIXME parallelize qkv
-        lora_output = torch.empty_like(base_output)
-
-        # q
-        output_dim_q = self.B_buffer_q.shape[-2]
-        lora_output[:, :output_dim_q] = self.segment_gemm.run(
-            x=lora_a_output[:, : self.lora_rank].contiguous(),
-            weights=self.B_buffer_q,
-            batch_size=self.bs,
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
-        )
-
-        # kv
-        output_dim_kv = self.B_buffer_kv.shape[-2] // 2
-        for i in range(2):
-            left = output_dim_kv * i
-            right = left + output_dim_kv
-            lora_output[:, output_dim_q + left : output_dim_q + right] = (
-                self.segment_gemm.run(
-                    x=lora_a_output[
-                        :, self.lora_rank * (i + 1) : self.lora_rank * (i + 2)
-                    ].contiguous(),
-                    weights=self.B_buffer_kv[:, left:right, :].contiguous(),
-                    batch_size=self.bs,
-                    weight_column_major=True,
-                    seg_indptr=self.seg_indptr,
-                    weight_indices=self.weight_indices,
-                )
-            )
-        return base_output + lora_output * self.scaling
+        lora_output = self.lora_backend.run_qkv_lora(
+            x,
+            self.A_buffer_qkv,
+            self.B_buffer_qkv,
+            output_offset=self.output_offset,
+            max_qkv_out_dim=self.max_qkv_out_dim,
+            base_output=base_output,
+            scaling=self.scaling,
+        )
+        return (
+            lora_output
+            if self.lora_backend.fuse_output_scaling_add
+            else base_output + lora_output * self.scaling
+        )


class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(
-        self, base_layer: RowParallelLinear, segment_gemm, lora_rank, scaling
+        self, base_layer: RowParallelLinear, lora_rank, scaling, lora_backend
    ) -> None:
-        super().__init__(base_layer, segment_gemm, lora_rank, scaling)
+        super().__init__(base_layer, lora_rank, scaling, lora_backend)

-    def set_lora_info(self, A_buffer, B_buffer, bs, seg_indptr, weight_indices):
+    def set_lora_info(self, A_buffer, B_buffer):
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer
-        self.bs = bs
-        self.seg_indptr = seg_indptr
-        self.weight_indices = weight_indices

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        lora_output = self.segment_gemm.run(
-            x=x,
-            weights=self.A_buffer,
-            batch_size=self.bs,
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
-        )
-        lora_output = self.segment_gemm.run(
-            x=lora_output,
-            weights=self.B_buffer,
-            batch_size=self.bs,
-            weight_column_major=True,
-            seg_indptr=self.seg_indptr,
-            weight_indices=self.weight_indices,
-        )
-        return base_output + lora_output * self.scaling
+        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
+        lora_output = self.lora_backend.run_lora_b_sgemm(
+            lora_a_output,
+            self.B_buffer[0],
+            base_output=base_output,
+            scaling=self.scaling,
+        )
+        return (
+            lora_output
+            if self.lora_backend.fuse_output_scaling_add
+            else base_output + lora_output * self.scaling
+        )

    def forward(self, input_):
        # duplicate the logic in RowParallelLinear
@@ -255,7 +264,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def get_lora_layer(
-    layer: nn.Module, segment_gemm, lora_rank, scaling
+    layer: nn.Module, lora_rank, scaling, lora_backend
) -> BaseLayerWithLoRA:
    supported_layer_types = {
        # the order matters
@@ -267,7 +276,7 @@ def get_lora_layer(
    }
    for src_layer_type, lora_layer_type in supported_layer_types.items():
        if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer, segment_gemm, lora_rank, scaling)
+            ret = lora_layer_type(layer, lora_rank, scaling, lora_backend)
            return ret
    raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
@@ -297,13 +306,14 @@ class LoRALayer(nn.Module):
class LoRAAdapter(nn.Module):
-    def __init__(self, uid, config, base_hf_config, load_config):
+    def __init__(self, uid, config, base_hf_config, load_config, lora_backend):
        super().__init__()
        self.uid = uid
        self.config = config
        assert self.config.hf_config["peft_type"].lower() == "lora"
        self.base_hf_config = base_hf_config
        self.load_config = load_config
+        self.lora_backend = lora_backend
        self.scaling = self.config.lora_alpha / self.config.r

        self.layers = nn.ModuleList(
@@ -376,20 +386,25 @@ class LoRAAdapter(nn.Module):
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                    else:
-                        layer.weights[kv_name] = torch.cat(
-                            (
+                        layer.weights[kv_name] = torch.stack(
+                            [
                                layer.weights[weight_name],
                                layer.weights[v_name],
-                            ),
-                            0,
+                            ],
+                            dim=0,
                        )
                        layer.weights.pop(weight_name)
                        layer.weights.pop(v_name)
                elif "gate_proj" in weight_name:
                    up_name = weight_name.replace("gate_proj", "up_proj")
                    gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
-                    layer.weights[gate_up_name] = torch.cat(
-                        (layer.weights[weight_name], layer.weights[up_name]), 0
-                    )
+                    if "lora_A" in weight_name:
+                        layer.weights[gate_up_name] = torch.cat(
+                            (layer.weights[weight_name], layer.weights[up_name]), 0
+                        )
+                    else:
+                        layer.weights[gate_up_name] = torch.stack(
+                            [layer.weights[weight_name], layer.weights[up_name]], dim=0
+                        )
                    layer.weights.pop(weight_name)
                    layer.weights.pop(up_name)
@@ -20,16 +20,14 @@ import re

import torch

-from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
+from sglang.srt.lora.backend import FlashInferLoraBackend, TritonLoraBackend
+from sglang.srt.lora.lora import LoRAAdapter, LoraBatchInfo, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_flashinfer_available, replace_submodule

logger = logging.getLogger(__name__)

-if is_flashinfer_available():
-    from flashinfer import SegmentGEMMWrapper


def get_module_name(name):
    # Fallback solution of mapping from config module name to module name in model class.
@@ -77,6 +75,20 @@ def get_stacked_name(name):
    return params_mapping.get(name, (name, name))


+def get_backend_from_name(name):
+    backend_mapping = {
+        "triton": TritonLoraBackend,
+        "flashinfer": FlashInferLoraBackend,
+    }
+
+    if name in backend_mapping:
+        return backend_mapping[name]
+
+    raise Exception(
+        f"No supported lora backend called {name}. It should be one of {list(backend_mapping.keys())}"
+    )
+
+
def get_layer_id(name):
    match = re.search(r"layers\.(\d+)\.", name)
    if match is None:
@@ -93,6 +105,7 @@ class LoRAManager:
        max_loras_per_batch,
        load_config,
        dtype,
+        lora_backend,
    ):
        self.base_model = base_model
        self.lora_paths = lora_paths
@@ -101,8 +114,9 @@
        self.load_config = load_config
        self.dtype = dtype

-        workspace_buffer = torch.empty(1 * 1024 * 1024, dtype=torch.int8, device="cuda")
-        self.segment_gemm = SegmentGEMMWrapper(workspace_buffer)
+        logger.info(f"Using {lora_backend} as backend of Lora kernels.")
+        backend_type = get_backend_from_name(lora_backend)
+        self.lora_backend = backend_type(lora_backend)

        self.init_loras()
        self.init_lora_memory_pool()
@@ -123,7 +137,7 @@
    def set_lora_module(self, module_name, module):
        lora_module = get_lora_layer(
-            module, self.segment_gemm, self.max_lora_dim, self.scaling
+            module, self.max_lora_dim, self.scaling, self.lora_backend
        )
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module
@@ -162,7 +176,11 @@
            self.lora_id[name] = len(self.loras)
            self.loras.append(
                LoRAAdapter(
-                    name, self.configs[name], self.base_hf_config, self.load_config
+                    name,
+                    self.configs[name],
+                    self.base_hf_config,
+                    self.load_config,
+                    self.lora_backend,
                )
            )
            self.loras[-1].initialize_weights()
@@ -226,8 +244,9 @@
            self.B_buffer[module_B] = [
                torch.empty(
                    (
+                        c,
                        self.max_loras_per_batch,
-                        hidden_dim_B * c,
+                        hidden_dim_B,
                        self.max_lora_dim,
                    ),
                    dtype=self.dtype,
@@ -263,7 +282,16 @@
                else:
                    lora_weight_name = self.get_weight_name(name, 1)
                    if lora_weight_name:
-                        self.B_buffer[lora_weight_name][i][buffer_id].copy_(weights)
+                        c = self.loras[-1].get_stacked_multiply(lora_weight_name)
+                        if c > 1:
+                            for j in range(c):
+                                self.B_buffer[lora_weight_name][i][j][buffer_id].copy_(
+                                    weights[j]
+                                )
+                        else:
+                            self.B_buffer[lora_weight_name][i][0][buffer_id].copy_(
+                                weights
+                            )

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
        # load active loras into lora memory pool
@@ -292,20 +320,30 @@
        if cur_uids == set([None]):
            return

-        # setup lora in forward modules
+        # set up batch info shared by all lora modules
        bs = forward_batch.batch_size
        seg_lens = (
            forward_batch.extend_seq_lens
            if forward_batch.forward_mode.is_extend()
            else torch.ones(bs, device="cuda")
        )
+        # FIXME: reuse the data rather than recompute
        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+        max_len = int(torch.max(seg_lens))
        weight_indices = torch.empty((bs,), dtype=torch.int64, device="cuda")
        for i, lora_path in enumerate(forward_batch.lora_paths):
            weight_indices[i] = self.buffer_id[lora_path]

+        batch_info = LoraBatchInfo(
+            bs=bs,
+            seg_lens=seg_lens,
+            seg_indptr=seg_indptr,
+            max_len=max_len,
+            weight_indices=weight_indices,
+        )
+        self.lora_backend.set_batch_info(batch_info)
+
+        # call set_lora_info for each lora module
        for module_name, module in self.lora_modules:
            layer_id = get_layer_id(module_name)
@@ -314,16 +352,10 @@
                module.set_lora_info(
                    self.A_buffer[weight_name][layer_id],
                    self.B_buffer[weight_name][layer_id],
-                    bs,
-                    seg_indptr,
-                    weight_indices,
                )
            else:
                module.set_lora_info(
                    self.A_buffer["qkv_proj"][layer_id],
                    self.B_buffer["q_proj"][layer_id],
                    self.B_buffer["kv_proj"][layer_id],
-                    bs,
-                    seg_indptr,
-                    weight_indices,
                )
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
from .sgemm_lora_b import sgemm_lora_b_fwd
__all__ = ["qkv_lora_b_fwd", "sgemm_lora_a_fwd", "sgemm_lora_b_fwd"]
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _qkv_lora_b_kernel(
# Pointers to matrices
x,
weights,
output,
# Parameters of size
K, # K = R
max_qkv_out_dim, # max(output_q_dim, output_kv_dim)
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Offsets of q/k/v slice on output dimension
n_offs,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
):
# This kernel packs 3 sgemms (q/k/v) into a single kernel.
# x: (s, 3 * K), s is the sum of sequence lengths, K equals to lora rank
# weights: (num_lora, N_Q + 2 * N_KV, K)
# output: (s, N_Q + 2 * N_KV)
# N_Q >> K, N_KV >> K
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len.
# qkv_id decides which of q,k,v to compute (0: q, 1: k, 2: v)
batch_id = tl.program_id(axis=2)
qkv_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
n_start = tl.load(n_offs + qkv_id)
n_size = tl.load(n_offs + qkv_id + 1) - n_start
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0 + (qkv_id * K) * x_stride_1) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < n_size),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < n_size)
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
def qkv_lora_b_fwd(
x: torch.Tensor,
qkv_lora_b: torch.Tensor,
batch_info: LoraBatchInfo,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, 3 * r)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
# output_offset = [0, output_dim_q, output_dim_q + output_dim_kv,
# output_dim_q + 2 * output_dim_kv]
# max_qkv_out_dim = max(output_dim_q, output_dim_kv)
# output: (s, output_dim_q + 2 * output_dim_kv)
# Compute lora_output with shape (s, output_dim) as follows:
# lora_output[:, :output_dim_q] = sgemm(lora_output_a[:, :r], )
# lora_output[:, output_dim_q: output_dim_q + output_dim_kv]
# = sgemm(lora_output_a[:, r: 2 * r], kv_lora_b[0])
# lora_output[:, output_dim_q + output_dim_kv: ]
# = sgemm(lora_output_a[:, 2 * r: 3 * r], kv_lora_b[1])
# Get dims
s = x.shape[0]
input_dim = x.shape[1]
r = qkv_lora_b.shape[-1]
output_dim = qkv_lora_b.shape[-2]
assert input_dim == 3 * r
assert output_offset.shape[0] == 4
BLOCK_S = 16
BLOCK_R = 16
BLOCK_OUT = 64
grid_b = (
triton.cdiv(batch_info.max_len, BLOCK_S)
* triton.cdiv(max_qkv_out_dim, BLOCK_OUT),
3, # this dimension decides current block computes on q, k or v
batch_info.bs,
)
if base_output is None:
output = torch.empty((s, output_dim), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
else:
output = base_output
fuse_scaling_add = True
_qkv_lora_b_kernel[grid_b](
x,
qkv_lora_b,
output,
r,
max_qkv_out_dim,
x.stride(0),
x.stride(1),
qkv_lora_b.stride(0),
qkv_lora_b.stride(1),
qkv_lora_b.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
output_offset,
BLOCK_S,
BLOCK_OUT,
BLOCK_R,
fuse_scaling_add,
scaling,
)
return output
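For intuition, the fused kernel above is equivalent to the following plain PyTorch loop over sequences (a hypothetical reference sketch for the fused scaling/add path, not part of the commit; `qkv_lora_b_reference` is a made-up name):

```python
import torch


def qkv_lora_b_reference(x, qkv_lora_b, batch_info, output_offset, base_output, scaling):
    # x: (s, 3 * r), qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
    out = base_output.clone()
    r = qkv_lora_b.shape[-1]
    for b in range(batch_info.bs):
        start, end = batch_info.seg_indptr[b], batch_info.seg_indptr[b + 1]
        w = qkv_lora_b[batch_info.weight_indices[b]]  # weights of this sequence's adapter
        for j in range(3):  # 0: q, 1: k, 2: v
            n0, n1 = output_offset[j], output_offset[j + 1]
            # j-th rank-r slice of x times the matching q/k/v slice of lora_b, scaled and added
            out[start:end, n0:n1] += scaling * (x[start:end, j * r : (j + 1) * r] @ w[n0:n1].T)
    return out
```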
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _sgemm_lora_a_kernel(
# Pointers to matrices
x,
weights,
output,
# Matrix dimensions
N, # r
K, # input_dim
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# output: (s, N)
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len
batch_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) and (n_offset[None, :] < N),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = (s_offset[:, None] < seg_len) and (n_offset[None, :] < N)
tl.store(output_ptr, partial_sum, mask=output_mask)
def sgemm_lora_a_fwd(
x: torch.Tensor, weights: torch.Tensor, batch_info: LoraBatchInfo
) -> torch.Tensor:
# x: (s, input_dim)
# weights: (num_lora, r, input_dim)
# output: (s, r)
# when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
# input_dim is much larger than r
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
S = x.shape[0]
R = weights.shape[-2]
K = weights.shape[-1]
assert x.shape[-1] == K
# Block shapes
BLOCK_S = 16
BLOCK_K = 256
BLOCK_R = 16
grid = (
triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(R, BLOCK_R),
batch_info.bs,
)
output = torch.empty((S, R), device=x.device, dtype=x.dtype)
_sgemm_lora_a_kernel[grid](
x,
weights,
output,
R,
K,
x.stride(0),
x.stride(1),
weights.stride(0),
weights.stride(1),
weights.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
BLOCK_S,
BLOCK_R,
BLOCK_K,
)
return output
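The segment GEMM above can likewise be read as a per-sequence loop in plain PyTorch (an illustrative sketch only; `sgemm_lora_a_reference` is a made-up name):

```python
import torch


def sgemm_lora_a_reference(x, weights, batch_info):
    # x: (s, input_dim), weights: (num_lora, r, input_dim) -> output: (s, r)
    output = torch.empty((x.shape[0], weights.shape[-2]), device=x.device, dtype=x.dtype)
    for b in range(batch_info.bs):
        start, end = batch_info.seg_indptr[b], batch_info.seg_indptr[b + 1]
        w = weights[batch_info.weight_indices[b]]  # (r, input_dim) for this sequence's adapter
        output[start:end] = x[start:end] @ w.T
    return output
```

The lora_b kernel that follows has the same structure, with weights of shape (num_lora, output_dim, r) plus the optional fused scaling and add into `base_output`.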
import torch
import triton
import triton.language as tl
from sglang.srt.lora.lora import LoraBatchInfo
@triton.jit
def _sgemm_lora_b_kernel(
# Pointers to matrices
x,
weights,
output,
# Matrix dimensions
N, # output_dim
K, # r
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_lens,
seg_indptr,
weight_indices,
# Meta parameters
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
# For fused output scaling and adding
fuse_scaling_add,
scaling,
):
# x: (s, K), s is the sum of sequence lengths
# weights: (num_lora, N, K)
# output: (s, N)
# Current block computes sequence with batch_id,
# which starts from row seg_start of x with length seg_len
batch_id = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
seg_len = tl.load(seg_lens + batch_id)
w_index = tl.load(weight_indices + batch_id)
seg_start = tl.load(seg_indptr + batch_id)
# The tile in output matrix will have (pid_s, pid_n) as id
num_pid_n = tl.cdiv(N, BLOCK_N)
pid_s = pid // num_pid_n
pid_n = pid % num_pid_n
# Create pointers for the first block of x and weights[batch_id]
# The pointers will be advanced as we move in the K direction
# and accumulate
s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (x + seg_start * x_stride_0) + (
s_offset[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset[:, None] < seg_len)
and (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = (output + seg_start * output_stride_0) + (
s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
)
output_mask = s_offset[:, None] < seg_len
if fuse_scaling_add:
partial_sum += tl.load(output_ptr, mask=output_mask)
tl.store(output_ptr, partial_sum, mask=output_mask)
def sgemm_lora_b_fwd(
x: torch.Tensor,
weights: torch.Tensor,
batch_info: LoraBatchInfo,
base_output: torch.Tensor = None,
scaling: float = 1.0,
) -> torch.Tensor:
# x: (s, r)
# weights: (num_lora, output_dim, r)
# output: (s, output_dim)
# output_dim is much larger than r
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
S = x.shape[0]
N = weights.shape[-2]
R = weights.shape[-1]
assert x.shape[-1] == R
# Block shapes
BLOCK_S = 16
BLOCK_R = 16
BLOCK_N = 256
grid = (
triton.cdiv(batch_info.max_len, BLOCK_S) * triton.cdiv(N, BLOCK_N),
batch_info.bs,
)
if base_output is None:
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
fuse_scaling_add = False
else:
output = base_output
fuse_scaling_add = True
_sgemm_lora_b_kernel[grid](
x,
weights,
output,
N,
R,
x.stride(0),
x.stride(1),
weights.stride(0),
weights.stride(1),
weights.stride(2),
output.stride(0),
output.stride(1),
batch_info.seg_lens,
batch_info.seg_indptr,
batch_info.weight_indices,
BLOCK_S,
BLOCK_N,
BLOCK_R,
fuse_scaling_add,
scaling,
)
return output
@@ -530,6 +530,7 @@ class ModelRunner:
                max_loras_per_batch=self.server_args.max_loras_per_batch,
                load_config=self.load_config,
                dtype=self.dtype,
+                lora_backend=self.server_args.lora_backend,
            )
            logger.info("LoRA manager ready.")
...
@@ -113,6 +113,7 @@ class ServerArgs:
    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8
+    lora_backend: str = "triton"

    # Kernel backend
    attention_backend: Optional[str] = None
@@ -653,13 +654,19 @@ class ServerArgs:
            nargs="*",
            default=None,
            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
+            help="Maximum number of adapters for a running batch, include base-only request.",
+        )
+        parser.add_argument(
+            "--lora-backend",
+            type=str,
+            default="triton",
+            help="Choose the kernel backend for multi-LoRA serving.",
        )

        # Kernel backend
...
@@ -272,6 +272,7 @@ class SRTRunner:
        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
        lora_paths: List[str] = None,
        max_loras_per_batch: int = 4,
+        lora_backend: str = "triton",
        disable_cuda_graph: bool = False,
        disable_radix_cache: bool = False,
    ):
@@ -287,6 +288,7 @@ class SRTRunner:
            is_embedding=not self.is_generation,
            lora_paths=lora_paths,
            max_loras_per_batch=max_loras_per_batch,
+            lora_backend=lora_backend,
            disable_cuda_graph=disable_cuda_graph,
            disable_radix_cache=disable_radix_cache,
        )
...
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import multiprocessing as mp
import unittest
import torch
from sglang.test.runners import HFRunner, SRTRunner
from sglang.test.test_utils import calculate_rouge_l
LORA_SETS = [
{"base": "meta-llama/Llama-2-7b-hf", "loras": ["winddude/wizardLM-LlaMA-LoRA-7B"]},
# {"base": "meta-llama/Llama-2-7b-hf", "loras": ["RuterNorway/Llama-2-7b-chat-norwegian-LoRa"]}
]
TORCH_DTYPES = [torch.float16]
PROMPTS = [
"AI is a field of computer science focused on",
"""
### Instruction:
Tell me about llamas and alpacas
### Response:
Llamas are large, long-necked animals with a woolly coat. They have two toes on each foot instead of three like other camelids (camels, dromedaries). Llamas live in the Andean mountains of South America where they graze on grasses and shrubs. Alpaca is another name for domesticated llama. The word "alpaca" comes from an Incan language meaning "golden fleece." Alpacas look very similar to llamas but are smaller than their wild relatives. Both species were used by ancient people as pack animals and for meat. Today both llamas and alpacas are raised primarily for their fiber which can be spun into yarn or knitted into clothing.
### Question 2:
What do you know about llamas?
### Answer:
""",
]
BACKENDS = ["triton", "flashinfer"]
prefill_tolerance: float = 5e-2
decode_tolerance: float = 5e-2
rouge_l_tolerance: float = 1
class TestLoRABackend(unittest.TestCase):
def run_backend(
self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens, backend
):
print(f"=================== testing {backend} backend =======================")
base_path = lora_set["base"]
all_lora_paths = lora_set["loras"]
batch_lora_paths = []
i = 0
for _ in range(len(prompts)):
batch_lora_paths.append(all_lora_paths[i])
i = (i + 1) % len(all_lora_paths)
print(f"batch lora paths={batch_lora_paths}")
with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
tp_size=tp_size,
lora_paths=all_lora_paths,
max_loras_per_batch=3,
lora_backend=backend,
disable_cuda_graph=True,
disable_radix_cache=True,
) as srt_runner:
srt_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
with HFRunner(
base_path, torch_dtype=torch_dtype, model_type="generation"
) as hf_runner:
hf_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens, lora_paths=batch_lora_paths
)
with SRTRunner(
base_path,
tp_size=tp_size,
torch_dtype=torch_dtype,
model_type="generation",
) as srt_runner:
srt_no_lora_outputs = srt_runner.forward(
prompts, max_new_tokens=max_new_tokens
)
with HFRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
) as hf_runner:
hf_no_lora_outputs = hf_runner.forward(
prompts, max_new_tokens=max_new_tokens
)
for i in range(len(prompts)):
print(f"Prompt {i} with lora path {batch_lora_paths[i]}:")
# compare input logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
hf_no_lora_logprobs = torch.Tensor(hf_no_lora_outputs.top_input_logprobs[i])
srt_no_lora_logprobs = torch.Tensor(
srt_no_lora_outputs.top_input_logprobs[i]
)
print(
"max input diff between hf_lora and srt_lora",
torch.max(abs(hf_logprobs - srt_logprobs)),
)
print(
"max input diff between srt_base and srt_lora",
torch.max(abs(srt_no_lora_logprobs - srt_logprobs)),
)
print(
"max input diff between srt_base and hf_base",
torch.max(abs(srt_no_lora_logprobs - hf_no_lora_logprobs)),
)
print(
"max input diff between hf_lora and hf_base",
torch.max(abs(hf_logprobs - hf_no_lora_logprobs)),
)
if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
f"prefill logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"prefill_tolerance={prefill_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# compare output logprobs
hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
print(
"max output diff between hf_lora and srt_lora",
torch.max(abs(hf_logprobs - srt_logprobs)),
"\n",
)
if hf_logprobs.shape[0] <= 100:
assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
f"decode logprobs are not all close with model_path={base_path},"
f"lora_path={batch_lora_paths[i]}, backend={backend}, prompt={prompts[i]}"
f"decode_tolerance={decode_tolerance}."
f"{hf_logprobs=}, {srt_logprobs=}"
)
# compare output strings
srt_output_str = srt_outputs.output_strs[i].strip(" ")
hf_output_str = hf_outputs.output_strs[i]
print(f"srt_output_str={srt_output_str}")
print(f"hf_output_str={hf_output_str}")
rouge_l_scores = calculate_rouge_l([srt_output_str], [hf_output_str])
print(f"{rouge_l_scores=}")
assert (
rouge_l_scores[0] >= rouge_l_tolerance
), f"ROUGE-L scores of prompt {i} outputs are greater than rouge_l_tolerance={rouge_l_tolerance}"
def test_all(self):
for lora_set in LORA_SETS:
print(f"Testing lora set {lora_set}: ")
for torch_dtype in TORCH_DTYPES:
tp_size = 1
max_new_tokens = 32
for backend in BACKENDS:
self.run_backend(
PROMPTS, lora_set, tp_size, torch_dtype, max_new_tokens, backend
)
if __name__ == "__main__":
try:
mp.set_start_method("spawn")
except RuntimeError:
pass
unittest.main(warnings="ignore")
@@ -8,6 +8,7 @@ suites = {
        "models/test_embedding_models.py",
        "models/test_generation_models.py",
        "models/test_lora.py",
+        "models/test_lora_backend.py",
        "models/test_qwen_models.py",
        "models/test_reward_models.py",
        "sampling/penaltylib",
...