Unverified Commit 08ecd0aa authored by Lifu Huang, committed by GitHub

[3/4] Speed up CSGMV backend perf by 10% through dynamic chunking + kernel optimization (#10592)

parent 720c1c8c
......@@ -41,6 +41,8 @@
"\n",
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n",
"\n",
"* `--max-lora-chunk-size`: Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when --lora-backend is 'csgmv'. Choosing a larger value might improve performance. Please tune this value based on your hardware and workload as needed. Defaults to 16.\n",
"\n",
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
"\n",
"From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to."
......
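For context on the new flag, a minimal configuration sketch (assuming a standard sglang install; the model and adapter paths are placeholders, not part of this PR):

```python
from sglang.srt.server_args import ServerArgs

# CLI equivalent (per the argparse definition added in this PR):
#   --lora-backend csgmv --max-lora-chunk-size 64
args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    lora_paths=["/path/to/adapter"],                # placeholder adapter path
    lora_backend="csgmv",
    max_lora_chunk_size=64,  # power of 2 in [16, 128]; defaults to 16
)
```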
......@@ -9,6 +9,9 @@ from sglang.srt.lora.triton_ops import (
)
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
MIN_CHUNK_SIZE = 16
class ChunkedSgmvLoRABackend(BaseLoRABackend):
......@@ -23,17 +26,23 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
name = "csgmv"
def __init__(self, max_loras_per_batch: int, device: torch.device):
def __init__(
self,
max_loras_per_batch: int,
device: torch.device,
server_args: ServerArgs,
):
super().__init__(max_loras_per_batch, device)
self.segment_size = 16 # TODO (lifuhuang): make it configurable?
self.max_chunk_size = server_args.max_lora_chunk_size
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return chunked_sgmv_lora_shrink_forward(
x,
weights,
self.batch_info,
x=x,
weights=weights,
batch_info=self.batch_info,
num_slices=1,
)
def run_lora_b_sgemm(
......@@ -50,7 +59,7 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
max_slice_size = output_dim
return chunked_sgmv_lora_expand_forward(
x=x,
lora_weight_b=weights,
weights=weights,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=max_slice_size,
......@@ -75,14 +84,14 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
assert isinstance(qkv_lora_b, torch.Tensor)
lora_a_output = chunked_sgmv_lora_shrink_forward(
x,
qkv_lora_a,
self.batch_info,
x=x,
weights=qkv_lora_a,
batch_info=self.batch_info,
num_slices=3,
)
lora_output = chunked_sgmv_lora_expand_forward(
x=lora_a_output,
lora_weight_b=qkv_lora_b,
weights=qkv_lora_b,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=max_qkv_out_dim,
......@@ -109,14 +118,14 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
# lora_a_output: (s, 2 * r)
lora_a_output = chunked_sgmv_lora_shrink_forward(
x,
gate_up_lora_a,
self.batch_info,
x=x,
weights=gate_up_lora_a,
batch_info=self.batch_info,
num_slices=2,
)
lora_output = chunked_sgmv_lora_expand_forward(
x=lora_a_output,
lora_weight_b=gate_up_lora_b,
weights=gate_up_lora_b,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=output_dim,
......@@ -124,6 +133,33 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
)
return lora_output
def _determine_chunk_size(self, forward_batch: ForwardBatch) -> int:
"""
Heuristically determine the chunk size based on the number of tokens in the batch.
Args:
forward_batch (ForwardBatch): The batch information containing sequence lengths.
Returns:
The determined chunk size
"""
if self.max_chunk_size <= MIN_CHUNK_SIZE:
return MIN_CHUNK_SIZE
num_tokens = (
forward_batch.extend_num_tokens
if forward_batch.forward_mode.is_extend()
else forward_batch.batch_size
)
if num_tokens >= 256:
chunk_size = 128
elif num_tokens >= 64:
chunk_size = 32
else: # num_tokens < 64
chunk_size = 16
return min(self.max_chunk_size, chunk_size)
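A worked restatement of the heuristic above, as a standalone function (illustrative only; the real method reads the token count from the `ForwardBatch`):

```python
MIN_CHUNK_SIZE = 16

def determine_chunk_size(num_tokens: int, max_chunk_size: int) -> int:
    # Mirrors _determine_chunk_size: pick a data-driven chunk size,
    # then cap it at --max-lora-chunk-size.
    if max_chunk_size <= MIN_CHUNK_SIZE:
        return MIN_CHUNK_SIZE
    if num_tokens >= 256:
        chunk_size = 128
    elif num_tokens >= 64:
        chunk_size = 32
    else:
        chunk_size = 16
    return min(max_chunk_size, chunk_size)

assert determine_chunk_size(num_tokens=300, max_chunk_size=64) == 64  # capped by the flag
assert determine_chunk_size(num_tokens=100, max_chunk_size=64) == 32  # heuristic wins
assert determine_chunk_size(num_tokens=8, max_chunk_size=128) == 16   # small decode batch
```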
def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
......@@ -132,12 +168,16 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
):
chunk_size = self._determine_chunk_size(forward_batch)
permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
weight_indices, forward_batch
seq_weight_indices=weight_indices,
forward_batch=forward_batch,
)
seg_weight_indices, seg_indptr = self._get_segments_info(
weight_indices_reordered
weights_reordered=weight_indices_reordered,
chunk_size=chunk_size,
)
num_segments = len(seg_weight_indices)
......@@ -152,6 +192,7 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
batch_info = LoRABatchInfo(
bs=forward_batch.batch_size,
num_segments=num_segments,
max_len=chunk_size,
use_cuda_graph=False,
seg_indptr=torch.empty(
(num_segments + 1,), dtype=torch.int32, device=self.device
......@@ -169,12 +210,12 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
(len(permutation),), dtype=torch.int32, device=self.device
),
# Not used in chunked kernels
max_len=None,
seg_lens=None,
)
else:
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = num_segments
batch_info.max_len = chunk_size
# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
......@@ -241,7 +282,7 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
return permutation, weights_reordered
def _get_segments_info(self, weights_reordered: torch.Tensor):
def _get_segments_info(self, weights_reordered: torch.Tensor, chunk_size: int):
"""
Computes segment information for chunked SGMV operations.
......@@ -269,6 +310,7 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
Args:
weights_reordered (torch.Tensor): Sorted adapter indices for each token
chunk_size (int): Fixed size for each segment
Returns:
tuple: (weight_indices_list, seg_indptr) where:
......@@ -285,11 +327,11 @@ class ChunkedSgmvLoRABackend(BaseLoRABackend):
for weight_idx, group_len in zip(unique_weights, counts):
group_len = group_len.item()
num_segs = (group_len + self.segment_size - 1) // self.segment_size
num_segs = (group_len + chunk_size - 1) // chunk_size
weight_indices_list.extend([weight_idx.item()] * num_segs)
seg_lens_list.extend([self.segment_size] * (num_segs - 1))
seg_lens_list.append(group_len - (num_segs - 1) * self.segment_size)
seg_lens_list.extend([chunk_size] * (num_segs - 1))
seg_lens_list.append(group_len - (num_segs - 1) * chunk_size)
seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32)
......
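To make the segmentation concrete, a minimal sketch of what `_get_segments_info` computes once `chunk_size` is a parameter (simplified; the real method also derives `seg_indptr` from the segment lengths):

```python
import torch

def segments_info(weights_reordered: torch.Tensor, chunk_size: int):
    # Split each consecutive run of identical adapter indices into
    # segments of at most chunk_size tokens; only the last segment of
    # a run may be shorter.
    unique_weights, counts = torch.unique_consecutive(
        weights_reordered, return_counts=True
    )
    weight_indices, seg_lens = [], []
    for weight_idx, group_len in zip(unique_weights.tolist(), counts.tolist()):
        num_segs = (group_len + chunk_size - 1) // chunk_size
        weight_indices.extend([weight_idx] * num_segs)
        seg_lens.extend([chunk_size] * (num_segs - 1))
        seg_lens.append(group_len - (num_segs - 1) * chunk_size)
    return weight_indices, seg_lens

# 5 tokens on adapter 0 and 3 on adapter 1, chunk_size=4:
# adapter 0 -> segments [4, 1]; adapter 1 -> segment [3]
assert segments_info(torch.tensor([0, 0, 0, 0, 0, 1, 1, 1]), 4) == ([0, 0, 1], [4, 1, 3])
```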
......@@ -11,12 +11,18 @@ from sglang.srt.lora.triton_ops import (
)
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
class TritonLoRABackend(BaseLoRABackend):
name = "triton"
def __init__(self, max_loras_per_batch: int, device: torch.device):
def __init__(
self,
max_loras_per_batch: int,
device: torch.device,
**kwargs,
):
super().__init__(max_loras_per_batch, device)
def run_lora_a_sgemm(
......@@ -30,7 +36,7 @@ class TritonLoRABackend(BaseLoRABackend):
weights: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)
......@@ -43,7 +49,7 @@ class TritonLoRABackend(BaseLoRABackend):
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
# x: (s, input_dim)
......@@ -69,7 +75,7 @@ class TritonLoRABackend(BaseLoRABackend):
gate_up_lora_b: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs
**kwargs,
) -> torch.Tensor:
# x: (s, input_dim)
......
......@@ -37,6 +37,7 @@ from sglang.srt.lora.utils import (
)
from sglang.srt.managers.io_struct import LoRAUpdateResult
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import replace_submodule
logger = logging.getLogger(__name__)
......@@ -56,6 +57,7 @@ class LoRAManager:
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
lora_paths: Optional[List[LoRARef]] = None,
server_args: Optional[ServerArgs] = None,
):
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
......@@ -72,6 +74,7 @@ class LoRAManager:
self.lora_backend: BaseLoRABackend = backend_type(
max_loras_per_batch=max_loras_per_batch,
device=self.device,
server_args=server_args,
)
# Initialize mutable internal state of the LoRAManager.
......
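The constructor signatures diverge between backends, which is why `LoRAManager` can pass `server_args` unconditionally; a hypothetical minimal sketch of the pattern (names shortened; not the real classes):

```python
from types import SimpleNamespace

class Base:
    def __init__(self, max_loras_per_batch: int, device: str):
        self.max_loras_per_batch = max_loras_per_batch
        self.device = device

class Csgmv(Base):
    # Needs server_args explicitly, for max_lora_chunk_size.
    def __init__(self, max_loras_per_batch: int, device: str, server_args):
        super().__init__(max_loras_per_batch, device)
        self.max_chunk_size = server_args.max_lora_chunk_size

class Triton(Base):
    # Absorbs the unused server_args keyword via **kwargs.
    def __init__(self, max_loras_per_batch: int, device: str, **kwargs):
        super().__init__(max_loras_per_batch, device)

args = SimpleNamespace(max_lora_chunk_size=16)
for backend_type in (Csgmv, Triton):
    backend_type(max_loras_per_batch=8, device="cuda", server_args=args)
```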
......@@ -13,15 +13,6 @@ def _chunked_lora_expand_kernel(
x,
weights,
output,
# Parameters of size
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight ids
seg_indptr,
weight_indices,
......@@ -34,8 +25,9 @@ def _chunked_lora_expand_kernel(
slice_offsets,
# Meta parameters
NUM_SLICES: tl.constexpr,
OUTPUT_DIM: tl.constexpr,
MAX_RANK: tl.constexpr, # K = R
BLOCK_S: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
......@@ -57,6 +49,16 @@ def _chunked_lora_expand_kernel(
"""
tl.static_assert(NUM_SLICES <= 3)
x_stride_0: tl.constexpr = NUM_SLICES * MAX_RANK
x_stride_1: tl.constexpr = 1
w_stride_0: tl.constexpr = OUTPUT_DIM * MAX_RANK
w_stride_1: tl.constexpr = MAX_RANK
w_stride_2: tl.constexpr = 1
output_stride_0: tl.constexpr = OUTPUT_DIM
output_stride_1: tl.constexpr = 1
pid_s = tl.program_id(axis=2)
if pid_s >= num_segs:
return
......@@ -83,7 +85,7 @@ def _chunked_lora_expand_kernel(
cur_rank = tl.minimum(MAX_RANK, cur_rank)
# Map logical sequence index to physical index
s_offset_logical = tl.arange(0, BLOCK_S) + seg_start
s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
s_offset_physical = tl.load(
permutation + s_offset_logical, mask=s_offset_logical < seg_end
)
......@@ -105,7 +107,7 @@ def _chunked_lora_expand_kernel(
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(cur_rank, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
......@@ -140,32 +142,37 @@ def _chunked_lora_expand_kernel(
def chunked_sgmv_lora_expand_forward(
x: torch.Tensor,
lora_weight_b: torch.Tensor,
weights: torch.Tensor,
batch_info: LoRABatchInfo,
slice_offsets: torch.Tensor,
max_slice_size: int,
base_output: torch.Tensor = None,
base_output: Optional[torch.Tensor],
) -> torch.Tensor:
# x: (s, slice_num * r)
# lora_weight_b: (num_lora, output_dim, r)
# weights: (num_lora, output_dim, r)
# slice_offsets: boundaries for different slices in the output dimension
# output: (s, output_dim)
# Compute lora_output with shape (s, output_dim) as follows:
# For each slice i, accumulates:
# lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], lora_weight_b[:, slice_offsets[i]:slice_offsets[i+1], :])
# lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], weights[:, slice_offsets[i]:slice_offsets[i+1], :])
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
# Get dims
s = x.shape[0]
M = x.shape[0]
input_dim = x.shape[1]
max_lora_rank = lora_weight_b.shape[-1]
output_dim = lora_weight_b.shape[-2]
OUTPUT_DIM = weights.shape[1]
MAX_RANK = weights.shape[2]
num_slices = len(slice_offsets) - 1
assert input_dim == num_slices * max_lora_rank
assert input_dim == num_slices * MAX_RANK
# TODO (lifuhuang): fine-tune per operation
BLOCK_M = 16
BLOCK_M = batch_info.max_len
BLOCK_K = 16
BLOCK_N = 64
......@@ -178,21 +185,14 @@ def chunked_sgmv_lora_expand_forward(
)
if base_output is None:
output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
output = torch.zeros((M, OUTPUT_DIM), device=x.device, dtype=x.dtype)
else:
output = base_output
_chunked_lora_expand_kernel[grid](
x=x,
weights=lora_weight_b,
weights=weights,
output=output,
x_stride_0=x.stride(0),
x_stride_1=x.stride(1),
w_stride_0=lora_weight_b.stride(0),
w_stride_1=lora_weight_b.stride(1),
w_stride_2=lora_weight_b.stride(2),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
seg_indptr=batch_info.seg_indptr,
weight_indices=batch_info.weight_indices,
lora_ranks=batch_info.lora_ranks,
......@@ -202,8 +202,9 @@ def chunked_sgmv_lora_expand_forward(
slice_offsets=slice_offsets,
# constants
NUM_SLICES=num_slices,
MAX_RANK=max_lora_rank,
BLOCK_S=BLOCK_M,
OUTPUT_DIM=OUTPUT_DIM,
MAX_RANK=MAX_RANK,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
......
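This file's hunks replace the runtime stride arguments with `tl.constexpr` values derived from the shape constants, which is sound because the wrapper now asserts both inputs are contiguous. A quick sanity check of that equivalence (illustrative shapes only):

```python
import torch

num_lora, output_dim, max_rank, s, num_slices = 4, 1024, 16, 32, 3

x = torch.randn(s, num_slices * max_rank)  # (s, NUM_SLICES * MAX_RANK)
weights = torch.randn(num_lora, output_dim, max_rank)
output = torch.zeros(s, output_dim)

# For contiguous tensors, strides are fully determined by the shape,
# matching the constexpr values in _chunked_lora_expand_kernel:
assert x.stride() == (num_slices * max_rank, 1)                  # x_stride_0, x_stride_1
assert weights.stride() == (output_dim * max_rank, max_rank, 1)  # w_stride_0..2
assert output.stride() == (output_dim, 1)                        # output_stride_0, _1
```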
......@@ -11,14 +11,6 @@ def _chunked_lora_shrink_kernel(
x,
weights,
output,
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths, ranks, and weight ids
seg_indptr,
weight_indices,
......@@ -29,7 +21,7 @@ def _chunked_lora_shrink_kernel(
N: tl.constexpr, # num_slices * r
K: tl.constexpr, # input_dim
NUM_SLICES: tl.constexpr,
BLOCK_S: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
......@@ -48,6 +40,16 @@ def _chunked_lora_shrink_kernel(
with shape `(num_lora, N, K)` where N = num_slices * r.
output (torch.Tensor): The output tensor of shape `(s, N)`.
"""
x_stride_1: tl.constexpr = 1
x_stride_0: tl.constexpr = K
w_stride_0: tl.constexpr = N * K
w_stride_1: tl.constexpr = K
w_stride_2: tl.constexpr = 1
output_stride_0: tl.constexpr = N
output_stride_1: tl.constexpr = 1
pid_s = tl.program_id(1)
if pid_s >= num_segs:
return
......@@ -70,7 +72,7 @@ def _chunked_lora_shrink_kernel(
cur_n = tl.minimum(N, rank * NUM_SLICES)
# Map logical sequence index to physical index
s_offset_logical = tl.arange(0, BLOCK_S) + seg_start
s_offset_logical = tl.arange(0, BLOCK_M) + seg_start
s_offset_physical = tl.load(
permutation + s_offset_logical, mask=s_offset_logical < seg_end
)
......@@ -85,7 +87,7 @@ def _chunked_lora_shrink_kernel(
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
partial_sum = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
......@@ -117,7 +119,7 @@ def chunked_sgmv_lora_shrink_forward(
x: torch.Tensor,
weights: torch.Tensor,
batch_info: LoRABatchInfo,
num_slices: int = 1,
num_slices: int,
) -> torch.Tensor:
# x: (s, input_dim)
# weights: (num_lora, num_slices * r, input_dim)
......@@ -133,7 +135,7 @@ def chunked_sgmv_lora_shrink_forward(
# Block shapes
# TODO (lifuhuang): experiment with split-k
BLOCK_S = 16
BLOCK_M = batch_info.max_len
BLOCK_N = 16
BLOCK_K = 256
......@@ -153,13 +155,6 @@ def chunked_sgmv_lora_shrink_forward(
x=x,
weights=weights,
output=output,
x_stride_0=x.stride(0),
x_stride_1=x.stride(1),
w_stride_0=weights.stride(0),
w_stride_1=weights.stride(1),
w_stride_2=weights.stride(2),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
seg_indptr=batch_info.seg_indptr,
weight_indices=batch_info.weight_indices,
lora_ranks=batch_info.lora_ranks,
......@@ -169,7 +164,7 @@ def chunked_sgmv_lora_shrink_forward(
N=N,
K=K,
NUM_SLICES=num_slices,
BLOCK_S=BLOCK_S,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
......
......@@ -19,6 +19,9 @@ class LoRABatchInfo:
# Number of segments. For the triton backend, it equals the batch size.
num_segments: int
# Maximum segment length of current batch
max_len: int
# Index pointers of each segment in shape (num_segments + 1,)
seg_indptr: torch.Tensor
......@@ -34,9 +37,6 @@ class LoRABatchInfo:
# Lengths of each segment in shape (num_segments,)
seg_lens: Optional[torch.Tensor]
# Maximum segment length of current batch
max_len: Optional[int]
# The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
permutation: Optional[torch.Tensor]
......
......@@ -1195,6 +1195,7 @@ class ModelRunner:
max_lora_rank=self.server_args.max_lora_rank,
target_modules=self.server_args.lora_target_modules,
lora_paths=self.server_args.lora_paths,
server_args=self.server_args,
)
def load_lora_adapter(self, lora_ref: LoRARef):
......
......@@ -268,6 +268,7 @@ class ServerArgs:
max_loaded_loras: Optional[int] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
max_lora_chunk_size: Optional[int] = 16
# Kernel backend
attention_backend: Optional[str] = None
......@@ -1779,6 +1780,13 @@ class ServerArgs:
default=ServerArgs.lora_backend,
help="Choose the kernel backend for multi-LoRA serving.",
)
parser.add_argument(
"--max-lora-chunk-size",
type=int,
default=ServerArgs.max_lora_chunk_size,
choices=[16, 32, 64, 128],
help="Maximum chunk size for the ChunkedSGMV LoRA backend. Only used when --lora-backend is 'csgmv'. Choosing a larger value might improve performance.",
)
# Kernel backend
parser.add_argument(
......@@ -2779,6 +2787,12 @@ class ServerArgs:
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
)
if self.max_lora_chunk_size is not None:
assert (
16 <= self.max_lora_chunk_size <= 128
and (self.max_lora_chunk_size & (self.max_lora_chunk_size - 1)) == 0
), "--max-lora-chunk-size must be a power of 2 between 16 and 128."
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
larger_tp = max(decode_tp, prefill_tp)
smaller_tp = min(decode_tp, prefill_tp)
......
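The new validation uses the standard bit trick for detecting powers of two (`n & (n - 1) == 0` for a positive `n` with a single set bit); a small self-contained check mirroring the assert above:

```python
def is_valid_chunk_size(n: int) -> bool:
    # A positive integer is a power of two iff clearing its lowest set
    # bit leaves zero; the range check matches the argparse choices.
    return 16 <= n <= 128 and (n & (n - 1)) == 0

assert [n for n in range(1, 256) if is_valid_chunk_size(n)] == [16, 32, 64, 128]
```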
......@@ -12,6 +12,8 @@ from sglang.srt.lora.triton_ops import (
)
from sglang.srt.lora.utils import LoRABatchInfo
CHUNK_SIZE = 16
def safe_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Matrix multiplication with mixed precision handling for float16"""
......@@ -343,9 +345,15 @@ class TestChunkedSGMV(unittest.TestCase):
)
# Create a minimal backend instance to access _get_segments_info
mock_backend = ChunkedSgmvLoRABackend(max_loras_per_batch=8, device=self.device)
mock_server_args = type(
"ServerArgs", (object,), {"max_lora_chunk_size": "MOCK_NEVER_USED"}
)
mock_backend = ChunkedSgmvLoRABackend(
max_loras_per_batch=8, device=self.device, server_args=mock_server_args
)
weight_indices_list, seg_indptr = mock_backend._get_segments_info(
weights_reordered
weights_reordered,
chunk_size=CHUNK_SIZE,
)
scalings = [1.0] * len(unique_loras)
......@@ -377,7 +385,7 @@ class TestChunkedSGMV(unittest.TestCase):
lora_ranks=lora_ranks_tensor,
scalings=scalings_tensor,
seg_lens=seq_lens_tensor, # Original sequence lengths for reference
max_len=max(seq_lengths) if seq_lengths else 0,
max_len=CHUNK_SIZE,
permutation=permutation_tensor, # Token reordering permutation
)
......@@ -515,6 +523,7 @@ class TestChunkedSGMV(unittest.TestCase):
batch_info,
self.slice_offsets,
self.max_slice_size,
base_output=None,
)
reference_expand = reference_sgmv_expand(
reference_shrink,
......@@ -594,6 +603,7 @@ class TestChunkedSGMV(unittest.TestCase):
batch_info,
self.slice_offsets,
self.max_slice_size,
base_output=None,
)
reference_expand = reference_sgmv_expand(
intermediate,
......