Unverified Commit 3f41b48c authored by Lifu Huang, committed by GitHub

[2/2] Introduce Chunked-SGMV kernels and corresponding LoRA backend for improved performance (#10286)
parent 2689f0bf
@@ -35,7 +35,7 @@
"\n",
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
"\n",
-"* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we only support Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
+"* `lora_backend`: The backend for running GEMM kernels for LoRA modules. Currently we support the Triton LoRA backend (`triton`) and the Chunked SGMV backend (`csgmv`). In the future, faster backends built upon Cutlass or CUDA kernels will be added.\n",
"\n",
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
"\n",
@@ -79,7 +79,7 @@
"python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
-" --max-loras-per-batch 1 --lora-backend triton \\\n",
+" --max-loras-per-batch 1 \\\n",
" --log-level warning \\\n",
"\"\"\"\n",
")\n",
@@ -139,7 +139,7 @@
" --enable-lora \\\n",
" --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
" lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
-" --max-loras-per-batch 2 --lora-backend triton \\\n",
+" --max-loras-per-batch 2 \\\n",
" --log-level warning \\\n",
"\"\"\"\n",
")\n",
@@ -214,7 +214,7 @@
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --cuda-graph-max-bs 2 \\\n",
-" --max-loras-per-batch 2 --lora-backend triton \\\n",
+" --max-loras-per-batch 2 \\\n",
" --max-lora-rank 256\n",
" --lora-target-modules all\n",
" --log-level warning\n",
@@ -413,7 +413,7 @@
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --cuda-graph-max-bs 8 \\\n",
-" --max-loras-per-batch 3 --lora-backend triton \\\n",
+" --max-loras-per-batch 3 \\\n",
" --max-lora-rank 256 \\\n",
" --lora-target-modules all \\\n",
" --lora-paths \\\n",
@@ -501,6 +501,48 @@
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Choosing LoRA Backend\n",
"\n",
"SGLang supports two LoRA backends that you can choose from using the `--lora-backend` argument:\n",
"\n",
"- `triton`: Default basic Triton-based backend.\n",
"- `csgmv`: Chunked SGMV backend optimized for high concurrency scenarios.\n",
"\n",
"The `csgmv` backend was recently introduced to improve performance especially at high-concurrency scenarios. Our benchmark shows that it achieves 20% to 80% latency improvements over the basic triton backend.\n",
"Currently it is at preview phase, we expect to make it our the default LoRA backend in future release. Before that, you can adopt it by manually setting the `--lora-backend` server config."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
" python3 -m sglang.launch_server \\\n",
" --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
" --enable-lora \\\n",
" --lora-backend csgmv \\\n",
" --max-loras-per-batch 16 \\\n",
" --lora-paths lora1=path/to/lora1 lora2=path/to/lora2\n",
" \"\"\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
...
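A minimal client-side sketch of how a request selects an adapter once a server is running with the `csgmv` backend (not part of this commit; the port value and the adapter name "lora0" are placeholders for whatever was registered via `--lora-paths`):

import requests

# Placeholder values: the notebook obtains `port` from launch_server_cmd, and
# "lora0" must match an adapter name registered via --lora-paths at launch time.
port = 30000
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "List three facts about the Roman Empire.",
        "sampling_params": {"max_new_tokens": 64},
        "lora_path": "lora0",  # which LoRA adapter to apply for this request
    },
)
print(response.json())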
@@ -143,10 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
return TritonLoRABackend
-# elif name == "csgmv":
-#     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
-#     return ChunkedSgmvLoRABackend
+elif name == "csgmv":
+    from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+    return ChunkedSgmvLoRABackend
elif name == "flashinfer":
raise ValueError(
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."
...
from typing import Optional
import torch
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.triton_ops import (
chunked_sgmv_lora_expand_forward,
chunked_sgmv_lora_shrink_forward,
)
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class ChunkedSgmvLoRABackend(BaseLoRABackend):
"""
Chunked LoRA backend using segmented matrix-vector multiplication.
This backend is largely based on the SGMV (Segmented Gather Matrix-Vector multiplication) algorithm
introduced in the Punica paper (https://arxiv.org/pdf/2310.18547). One main variation made here is to
segment the input sequences into fixed-size chunks, which reduces excessive kernel launches especially
when the LoRA distribution is skewed.
"""
name = "csgmv"
def __init__(self, max_loras_per_batch: int, device: torch.device):
super().__init__(max_loras_per_batch, device)
self.segment_size = 16 # TODO (lifuhuang): make it configurable?
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
) -> torch.Tensor:
return chunked_sgmv_lora_shrink_forward(
x,
weights,
self.batch_info,
)
def run_lora_b_sgemm(
self,
x: torch.Tensor,
weights: torch.Tensor,
output_offset: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs
) -> torch.Tensor:
# For simple lora B, we use slice offsets [0, output_dim]
output_dim = weights.shape[-2]
max_slice_size = output_dim
return chunked_sgmv_lora_expand_forward(
x=x,
lora_weight_b=weights,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=max_slice_size,
base_output=base_output,
)
def run_qkv_lora(
self,
x: torch.Tensor,
qkv_lora_a: torch.Tensor,
qkv_lora_b: torch.Tensor,
output_offset: torch.Tensor,
max_qkv_out_dim: int,
base_output: torch.Tensor = None,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# qkv_lora_a: (num_lora, 3 * r, input_dim)
# qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
assert isinstance(qkv_lora_b, torch.Tensor)
lora_a_output = chunked_sgmv_lora_shrink_forward(
x,
qkv_lora_a,
self.batch_info,
num_slices=3,
)
lora_output = chunked_sgmv_lora_expand_forward(
x=lora_a_output,
lora_weight_b=qkv_lora_b,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=max_qkv_out_dim,
base_output=base_output,
)
return lora_output
def run_gate_up_lora(
self,
x: torch.Tensor,
gate_up_lora_a: torch.Tensor,
gate_up_lora_b: torch.Tensor,
output_offset: torch.Tensor,
base_output: torch.Tensor = None,
*args,
**kwargs
) -> torch.Tensor:
# x: (s, input_dim)
# gate_up_lora_a: (num_lora, 2 * r, input_dim)
# gate_up_lora_b: (num_lora, 2 * output_dim, r)
assert isinstance(gate_up_lora_b, torch.Tensor)
output_dim = gate_up_lora_b.shape[-2] // 2
# lora_a_output: (s, 2 * r)
lora_a_output = chunked_sgmv_lora_shrink_forward(
x,
gate_up_lora_a,
self.batch_info,
num_slices=2,
)
lora_output = chunked_sgmv_lora_expand_forward(
x=lora_a_output,
lora_weight_b=gate_up_lora_b,
batch_info=self.batch_info,
slice_offsets=output_offset,
max_slice_size=output_dim,
base_output=base_output,
)
return lora_output
def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
):
permutation, weight_indices_reordered = ChunkedSgmvLoRABackend._get_permutation(
weight_indices, forward_batch
)
seg_weight_indices, seg_indptr = self._get_segments_info(
weight_indices_reordered
)
num_segments = len(seg_weight_indices)
lora_ranks_tensor = torch.tensor(
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
)
scalings_tensor = torch.tensor(
scalings, dtype=torch.float, pin_memory=True, device="cpu"
)
if batch_info is None:
batch_info = LoRABatchInfo(
bs=forward_batch.batch_size,
num_segments=num_segments,
use_cuda_graph=False,
seg_indptr=torch.empty(
(num_segments + 1,), dtype=torch.int32, device=self.device
),
weight_indices=torch.empty(
(num_segments,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int32, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
),
permutation=torch.empty(
(len(permutation),), dtype=torch.int32, device=self.device
),
# Not used in chunked kernels
max_len=None,
seg_lens=None,
)
else:
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = num_segments
# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:num_segments].copy_(
seg_weight_indices, non_blocking=True
)
batch_info.seg_indptr[: num_segments + 1].copy_(seg_indptr, non_blocking=True)
batch_info.permutation[: len(permutation)].copy_(permutation, non_blocking=True)
self.batch_info = batch_info
@staticmethod
def _get_permutation(seq_weight_indices, forward_batch: ForwardBatch):
"""
Computes permutation indices for reordering tokens by their LoRA adapter assignments.
This function implements the "gather" step in Chunked Segmented Gather Matrix Vector
multiplication by creating a permutation that groups tokens by their LoRA adapter.
Tokens using the same LoRA adapter are placed together to enable efficient batched
computation.
Example:
seq_weight_indices = [0, 1, 0] # 3 sequences using adapters [0, 1, 0]
extend_seq_lens = [2, 1, 3] # sequence lengths [2, 1, 3 tokens]
# Creates row_weight_indices: [0, 0, 1, 0, 0, 0] (6 tokens total)
# Returns permutation: [0, 1, 3, 4, 5, 2] (groups adapter 0 tokens together)
# weights_reordered: [0, 0, 0, 0, 0, 1] (sorted by adapter)
Args:
seq_weight_indices: List of LoRA adapter indices for each sequence
forward_batch (ForwardBatch): Batch information containing sequence lengths
Returns:
tuple: (permutation, weights_reordered) where:
- permutation: Token reordering indices to group by adapter
- weights_reordered: Sorted adapter indices for each token
"""
with torch.device("cpu"):
seq_weight_indices = torch.tensor(seq_weight_indices, dtype=torch.int32)
seg_lens_cpu = (
torch.tensor(
forward_batch.extend_seq_lens_cpu,
dtype=torch.int32,
)
if forward_batch.forward_mode.is_extend()
else torch.ones(forward_batch.batch_size, dtype=torch.int32)
)
row_weight_indices = torch.repeat_interleave(
seq_weight_indices, seg_lens_cpu
)
permutation = torch.empty(
(len(row_weight_indices),), dtype=torch.long, pin_memory=True
)
torch.argsort(row_weight_indices, stable=True, out=permutation)
weights_reordered = row_weight_indices[permutation]
return permutation, weights_reordered
def _get_segments_info(self, weights_reordered: torch.Tensor):
"""
Computes segment information for chunked SGMV operations.
This function takes the reordered weight indices and creates segments of fixed size
(self.segment_size) for efficient kernel execution. Each segment contains tokens
that use the same LoRA adapter, enabling vectorized computation.
The segmentation is necessary because:
1. GPU kernels work efficiently on fixed-size blocks
2. Large groups of tokens using the same adapter are split into manageable chunks
3. Each segment can be processed independently in parallel
Example:
weights_reordered = [0, 0, 0, 0, 0, 1] # 5 tokens with adapter 0, 1 with adapter 1
segment_size = 3
# Creates segments:
# Segment 0: tokens 0-2 (adapter 0), length=3
# Segment 1: tokens 3-4 (adapter 0), length=2
# Segment 2: token 5 (adapter 1), length=1
# Returns:
# weight_indices_list: [0, 0, 1] (adapter for each segment)
# seg_indptr: [0, 3, 5, 6] (cumulative segment boundaries)
Args:
weights_reordered (torch.Tensor): Sorted adapter indices for each token
Returns:
tuple: (weight_indices_list, seg_indptr) where:
- weight_indices_list: LoRA adapter index for each segment
- seg_indptr: Cumulative segment boundaries (CSR-style indptr)
"""
with torch.device("cpu"):
unique_weights, counts = torch.unique_consecutive(
weights_reordered, return_counts=True
)
weight_indices_list = []
seg_lens_list = []
for weight_idx, group_len in zip(unique_weights, counts):
group_len = group_len.item()
num_segs = (group_len + self.segment_size - 1) // self.segment_size
weight_indices_list.extend([weight_idx.item()] * num_segs)
seg_lens_list.extend([self.segment_size] * (num_segs - 1))
seg_lens_list.append(group_len - (num_segs - 1) * self.segment_size)
seg_lens = torch.tensor(seg_lens_list, dtype=torch.int32)
weight_indices_list = torch.tensor(
weight_indices_list, dtype=torch.int32, pin_memory=True
)
seg_indptr = torch.empty(
(len(seg_lens) + 1,), dtype=torch.int32, pin_memory=True
)
seg_indptr[0] = 0
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
return weight_indices_list, seg_indptr
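To make the gather-and-segment bookkeeping above concrete, here is a small standalone sketch that mirrors _get_permutation and _get_segments_info on the toy batch from the docstrings (assumptions: a segment size of 3 instead of the backend's 16; this is an illustration, not the SGLang implementation):

import torch

seq_weight_indices = [0, 1, 0]   # LoRA adapter id per sequence
seq_lens = [2, 1, 3]             # tokens per sequence
segment_size = 3                 # fixed chunk size (the backend uses 16)

# Gather step: one adapter id per token, then group tokens by adapter.
row_weight_indices = torch.repeat_interleave(
    torch.tensor(seq_weight_indices), torch.tensor(seq_lens)
)
permutation = torch.argsort(row_weight_indices, stable=True)
weights_reordered = row_weight_indices[permutation]

# Segmentation step: split each adapter group into fixed-size chunks.
unique_weights, counts = torch.unique_consecutive(weights_reordered, return_counts=True)
seg_weights, seg_lens = [], []
for w, n in zip(unique_weights.tolist(), counts.tolist()):
    full, rem = divmod(n, segment_size)
    seg_weights += [w] * (full + (rem > 0))
    seg_lens += [segment_size] * full + ([rem] if rem else [])
seg_indptr = [0]
for length in seg_lens:
    seg_indptr.append(seg_indptr[-1] + length)

print(permutation.tolist())       # [0, 1, 3, 4, 5, 2]
print(seg_weights, seg_indptr)    # [0, 0, 1] [0, 3, 5, 6]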
@@ -28,14 +28,15 @@ from torch import nn
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
-# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader
logger = logging.getLogger(__name__)
SUPPORTED_BACKENDS = (TritonLoRABackend, ChunkedSgmvLoRABackend)
class LoRALayer(nn.Module):
def __init__(self, config: LoRAConfig, base_hf_config: AutoConfig):
@@ -48,6 +49,7 @@ class LoRALayer(nn.Module):
class LoRAAdapter(nn.Module):
def __init__(
self,
uid: str,
@@ -159,8 +161,8 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if up_name not in weights:
weights[up_name] = torch.zeros_like(weights[weight_name])
-assert isinstance(self.lora_backend, TritonLoRABackend), (
-    f"LoRA weight initialization currently only supported for 'triton' backend. "
+assert isinstance(self.lora_backend, SUPPORTED_BACKENDS), (
+    f"LoRA weight initialization is currently only supported for LoRA backends: {', '.join(b.name for b in SUPPORTED_BACKENDS)}. "
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
f"or consider implementing custom initialization logic for other backends."
)
...
from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward
from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward
from .gate_up_lora_b import gate_up_lora_b_fwd
from .qkv_lora_b import qkv_lora_b_fwd
from .sgemm_lora_a import sgemm_lora_a_fwd
@@ -8,4 +10,6 @@ __all__ = [
"qkv_lora_b_fwd",
"sgemm_lora_a_fwd",
"sgemm_lora_b_fwd",
"chunked_sgmv_lora_shrink_forward",
"chunked_sgmv_lora_expand_forward",
]
from typing import Optional
import torch
import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
def _chunked_lora_expand_kernel(
# Pointers to matrices
x,
weights,
output,
# Parameters of size
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths and weight id
seg_indptr,
weight_indices,
lora_ranks,
permutation,
num_segs,
# For fused output scaling
scalings,
# Offsets of q/k/v slice on output dimension
slice_offsets,
# Meta parameters
NUM_SLICES: tl.constexpr,
MAX_RANK: tl.constexpr, # K = R
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""
Computes a chunked SGMV for LoRA expand operations.
When a sequence's rank is 0, the kernel is essentially a no-op, following
the convention in pytorch where the product of two matrices of shape (m, 0)
and (0, n) is an all-zero matrix of shape (m, n).
Args:
x (Tensor): The input tensor, which is the result of the LoRA A projection.
Shape: (s, num_slices * K), where s is the sum of all sequence lengths in the
batch and K is the maximum LoRA rank.
weights (Tensor): The LoRA B weights for all adapters.
Shape: (num_lora, output_dim, K).
output (Tensor): The output tensor where the result is stored.
Shape: (s, output_dim).
"""
tl.static_assert(NUM_SLICES <= 3)
pid_s = tl.program_id(axis=2)
if pid_s >= num_segs:
return
# This program instance handles segment pid_s, which covers logical rows
# [seg_start, seg_end) of x (after permutation); slice_id selects which
# output slice (e.g., q, k, or v) this instance computes.
w_index = tl.load(weight_indices + pid_s)
cur_rank = tl.load(lora_ranks + w_index)
# If rank is 0, this kernel is a no-op.
if cur_rank == 0:
return
seg_start = tl.load(seg_indptr + pid_s)
seg_end = tl.load(seg_indptr + pid_s + 1)
slice_id = tl.program_id(axis=1)
slice_start = tl.load(slice_offsets + slice_id)
slice_end = tl.load(slice_offsets + slice_id + 1)
scaling = tl.load(scalings + w_index)
# Adjust K (rank) according to the specific LoRA adapter
cur_rank = tl.minimum(MAX_RANK, cur_rank)
# Map logical sequence index to physical index
s_offset_logical = tl.arange(0, BLOCK_S) + seg_start
s_offset_physical = tl.load(
permutation + s_offset_logical, mask=s_offset_logical < seg_end
)
# Create pointers for the first block of x and weights[batch_id][n_start: n_end][:]
# The pointers will be advanced as we move in the K direction
# and accumulate
pid_n = tl.program_id(axis=0)
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_start
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = (
x
+ slice_id * cur_rank * x_stride_1
+ (s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1)
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(cur_rank, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset_logical[:, None] < seg_end)
& (k_offset[None, :] < cur_rank - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < cur_rank - k * BLOCK_K)
& (n_offset[None, :] < slice_end),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum *= scaling
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = output + (
s_offset_physical[:, None] * output_stride_0
+ n_offset[None, :] * output_stride_1
)
output_mask = (s_offset_logical[:, None] < seg_end) & (
n_offset[None, :] < slice_end
)
partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0)
tl.store(output_ptr, partial_sum, mask=output_mask)
def chunked_sgmv_lora_expand_forward(
x: torch.Tensor,
lora_weight_b: torch.Tensor,
batch_info: LoRABatchInfo,
slice_offsets: torch.Tensor,
max_slice_size: int,
base_output: torch.Tensor = None,
) -> torch.Tensor:
# x: (s, slice_num * r)
# lora_weight_b: (num_lora, output_dim, r)
# slice_offsets: boundaries for different slices in the output dimension
# output: (s, output_dim)
# Compute lora_output with shape (s, output_dim) as follows:
# For each slice i, accumulates:
# lora_output[:, slice_offsets[i]:slice_offsets[i+1]] += scaling * sgemm(x[:, i*cur_rank:(i+1)*cur_rank], lora_weight_b[:, slice_offsets[i]:slice_offsets[i+1], :])
# Get dims
s = x.shape[0]
input_dim = x.shape[1]
max_lora_rank = lora_weight_b.shape[-1]
output_dim = lora_weight_b.shape[-2]
num_slices = len(slice_offsets) - 1
assert input_dim == num_slices * max_lora_rank
# TODO (lifuhuang): fine-tune per operation
BLOCK_M = 16
BLOCK_K = 16
BLOCK_N = 64
num_segments = batch_info.num_segments
grid = (
triton.cdiv(max_slice_size, BLOCK_N),
num_slices, # number of slices in the input/output
batch_info.bs if batch_info.use_cuda_graph else num_segments,
)
if base_output is None:
output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype)
else:
output = base_output
_chunked_lora_expand_kernel[grid](
x=x,
weights=lora_weight_b,
output=output,
x_stride_0=x.stride(0),
x_stride_1=x.stride(1),
w_stride_0=lora_weight_b.stride(0),
w_stride_1=lora_weight_b.stride(1),
w_stride_2=lora_weight_b.stride(2),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
seg_indptr=batch_info.seg_indptr,
weight_indices=batch_info.weight_indices,
lora_ranks=batch_info.lora_ranks,
permutation=batch_info.permutation,
num_segs=num_segments,
scalings=batch_info.scalings,
slice_offsets=slice_offsets,
# constants
NUM_SLICES=num_slices,
MAX_RANK=max_lora_rank,
BLOCK_S=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
return output
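As a sanity reference for what the expand kernel computes, here is a dense PyTorch sketch of the per-slice accumulation described in the comments above, for a single adapter and without segmentation or permutation (all sizes are illustrative, not taken from this commit):

import torch

s, r = 8, 4                                   # tokens, LoRA rank
q_dim, kv_dim = 16, 8                         # toy q/k/v output dims
slice_offsets = [0, q_dim, q_dim + kv_dim, q_dim + 2 * kv_dim]
scaling = 0.5

x = torch.randn(s, 3 * r)                     # LoRA-A ("shrink") output, 3 slices
lora_b = torch.randn(q_dim + 2 * kv_dim, r)   # one adapter's stacked B weights
base_output = torch.zeros(s, q_dim + 2 * kv_dim)

# Each slice of x is projected into its own column range of the output and
# accumulated into base_output, scaled by the adapter's scaling factor.
for i in range(3):
    lo, hi = slice_offsets[i], slice_offsets[i + 1]
    base_output[:, lo:hi] += scaling * x[:, i * r:(i + 1) * r] @ lora_b[lo:hi].T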
import torch
import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
@triton.jit
def _chunked_lora_shrink_kernel(
# Pointers to matrices
x,
weights,
output,
# Strides
x_stride_0,
x_stride_1,
w_stride_0,
w_stride_1,
w_stride_2,
output_stride_0,
output_stride_1,
# Information on sequence lengths,ranks and weight id
seg_indptr,
weight_indices,
lora_ranks,
permutation,
num_segs,
# Meta parameters
N: tl.constexpr, # num_slices * r
K: tl.constexpr, # input_dim
NUM_SLICES: tl.constexpr,
BLOCK_S: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""
Computes a chunked SGMV for LoRA shrink operations.
The kernel ensures that output[seg_start:seg_start + seg_len, :rank * num_slices]
stores the product of the input `x` and the LoRA weights for the corresponding
sequence. This implies that when rank is 0, the kernel is essentially a no-op,
as output[seg_start:seg_start + seg_len, :0] is trivially correct (empty).
Args:
x (torch.Tensor): The input activations tensor of shape `(s, K)`, where `s`
is the sum of all sequence lengths in the batch.
weights (torch.Tensor): The LoRA A weights for all available adapters,
with shape `(num_lora, N, K)` where N = num_slices * r.
output (torch.Tensor): The output tensor of shape `(s, N)`.
"""
pid_s = tl.program_id(1)
if pid_s >= num_segs:
return
pid_n = tl.program_id(0)
# This program instance handles segment pid_s, which covers logical rows
# [seg_start, seg_end) of x (after permutation).
w_index = tl.load(weight_indices + pid_s)
rank = tl.load(lora_ranks + w_index)
# If rank is 0, this kernel becomes a no-op as the output is always trivially correct.
if rank == 0:
return
seg_start = tl.load(seg_indptr + pid_s)
seg_end = tl.load(seg_indptr + pid_s + 1)
# Adjust N dim according to the specific LoRA adapter
cur_n = tl.minimum(N, rank * NUM_SLICES)
# Map logical sequence index to physical index
s_offset_logical = tl.arange(0, BLOCK_S) + seg_start
s_offset_physical = tl.load(
permutation + s_offset_logical, mask=s_offset_logical < seg_end
)
n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
k_offset = tl.arange(0, BLOCK_K)
x_ptrs = x + (
s_offset_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1
)
w_ptrs = (weights + w_index * w_stride_0) + (
k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
)
# Iterate to compute the block in output matrix
partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
x_tile = tl.load(
x_ptrs,
mask=(s_offset_logical[:, None] < seg_end)
& (k_offset[None, :] < K - k * BLOCK_K),
other=0.0,
)
w_tile = tl.load(
w_ptrs,
mask=(k_offset[:, None] < K - k * BLOCK_K) & (n_offset[None, :] < cur_n),
other=0.0,
)
partial_sum += tl.dot(x_tile, w_tile)
x_ptrs += BLOCK_K * x_stride_1
w_ptrs += BLOCK_K * w_stride_2
# Store result to output matrix
partial_sum = partial_sum.to(x.dtype.element_ty)
output_ptr = output + (
s_offset_physical[:, None] * output_stride_0
+ n_offset[None, :] * output_stride_1
)
output_mask = (s_offset_logical[:, None] < seg_end) & (n_offset[None, :] < cur_n)
tl.store(output_ptr, partial_sum, mask=output_mask)
def chunked_sgmv_lora_shrink_forward(
x: torch.Tensor,
weights: torch.Tensor,
batch_info: LoRABatchInfo,
num_slices: int = 1,
) -> torch.Tensor:
# x: (s, input_dim)
# weights: (num_lora, num_slices * r, input_dim)
# output: (s, num_slices * r)
# num_slices: qkv=3, gate_up=2, others=1
# when called with multiple slices, the weights.shape[-2] will be num_slices * r
# input_dim is much larger than r
assert x.is_contiguous()
assert weights.is_contiguous()
assert len(x.shape) == 2
assert len(weights.shape) == 3
# Block shapes
# TODO (lifuhuang): experiment with split-k
BLOCK_S = 16
BLOCK_N = 16
BLOCK_K = 256
S = x.shape[0]
N = weights.shape[1]
K = weights.shape[2]
assert x.shape[-1] == K
num_segments = batch_info.num_segments
grid = (
triton.cdiv(N, BLOCK_N),
batch_info.bs if batch_info.use_cuda_graph else num_segments,
)
output = torch.empty((S, N), device=x.device, dtype=x.dtype)
_chunked_lora_shrink_kernel[grid](
x=x,
weights=weights,
output=output,
x_stride_0=x.stride(0),
x_stride_1=x.stride(1),
w_stride_0=weights.stride(0),
w_stride_1=weights.stride(1),
w_stride_2=weights.stride(2),
output_stride_0=output.stride(0),
output_stride_1=output.stride(1),
seg_indptr=batch_info.seg_indptr,
weight_indices=batch_info.weight_indices,
lora_ranks=batch_info.lora_ranks,
permutation=batch_info.permutation,
num_segs=num_segments,
# constants
N=N,
K=K,
NUM_SLICES=num_slices,
BLOCK_S=BLOCK_S,
BLOCK_N=BLOCK_N,
BLOCK_K=BLOCK_K,
)
return output
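The shrink and expand kernels compose as follows; a dense PyTorch sketch for a single adapter and a gate/up projection (num_slices = 2), with toy sizes and no segmentation, mirroring run_gate_up_lora in the backend above (illustrative only):

import torch

s, input_dim, r, out_dim = 8, 64, 4, 32
x = torch.randn(s, input_dim)
gate_up_a = torch.randn(2 * r, input_dim)   # one adapter's stacked A weights (gate, up)
gate_up_b = torch.randn(2 * out_dim, r)     # one adapter's stacked B weights (gate, up)
scaling = 1.0

# Shrink: (s, input_dim) @ (input_dim, 2r) -> (s, 2r); slice i occupies
# columns [i*r, (i+1)*r), matching chunked_sgmv_lora_shrink_forward's output.
shrunk = x @ gate_up_a.T

# Expand: each slice is projected back up into its own region of the output,
# matching chunked_sgmv_lora_expand_forward with slice_offsets [0, out_dim, 2*out_dim].
out = torch.zeros(s, 2 * out_dim)
for i in range(2):
    lo, hi = i * out_dim, (i + 1) * out_dim
    out[:, lo:hi] += scaling * shrunk[:, i * r:(i + 1) * r] @ gate_up_b[lo:hi].T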
@@ -110,6 +110,8 @@ ATTENTION_BACKEND_CHOICES = [
"ascend",
]
LORA_BACKEND_CHOICES = ["triton", "csgmv"]
DISAGG_TRANSFER_BACKEND_CHOICES = ["mooncake", "nixl", "ascend", "fake"]
GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
@@ -1601,7 +1603,8 @@ class ServerArgs:
parser.add_argument(
"--lora-backend",
type=str,
-default="triton",
+choices=LORA_BACKEND_CHOICES,
+default=ServerArgs.lora_backend,
help="Choose the kernel backend for multi-LoRA serving.",
)
...
@@ -24,6 +24,7 @@ suites = {
TestFile("lora/test_lora_update.py", 400),
TestFile("lora/test_lora_qwen3.py", 97),
TestFile("lora/test_lora_radix_cache.py", 100),
TestFile("lora/test_chunked_sgmv_backend.py", 30),
TestFile("models/test_embedding_models.py", 73),
# TestFile("models/test_clip_models.py", 52),
TestFile("models/test_encoder_embedding_models.py", 100),
...