Unverified commit 94100294, authored by Lifu Huang, committed by GitHub

[1/2] Refactor LoRA to support backend-specific batch preprocessing. (#10251)

parent cda7e47c
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
     Each backend has its own implementation of Lora kernels.

     Args:
-        name: name of backend
-        batch_info: information of current batch for use
+        max_loras_per_batch: maximum number of different lora weights
+            that can be applied in a single forward batch.
+        device: the device where the backend runs.
     """

-    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-        self.name = name
-        self.batch_info = batch_info
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        self.max_loras_per_batch = max_loras_per_batch
+        self.device = device

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
         """
         pass

-    def set_batch_info(self, batch_info: LoRABatchInfo):
-        self.batch_info = batch_info
+    def init_cuda_graph_batch_info(
+        self,
+        cuda_graph_batch_info: LoRABatchInfo,
+        max_bs_in_cuda_graph: int,
+    ):
+        """Initialize the batch info for CUDA Graph mode.
+
+        This method provides a hook for each backend to conduct its own initialization
+        logic for CUDA Graph mode.
+
+        Args:
+            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        """Prepare the lora weights and batch info for current forward batch.
+
+        This method provides a hook for each backend to conduct its own preparation
+        logic for each forward batch.
+
+        Args:
+            forward_batch: the ForwardBatch object for current forward pass
+            weight_indices: list of indices of lora weights to be applied for current batch
+            lora_ranks: list of lora ranks corresponding to weight_indices
+            scalings: list of scaling factors corresponding to weight_indices
+            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+        """
+        pass


 def get_backend_from_name(name: str) -> BaseLoRABackend:
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

         return TritonLoRABackend
+    # elif name == "csgmv":
+    #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+    #     return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
...
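To make the refactored contract concrete, here is a minimal, self-contained sketch (not part of this commit) of a backend that follows the new constructor signature and `prepare_lora_batch` hook. `ToyLoRABackend` is a hypothetical stand-in and uses a plain dict where the real code uses `LoRABatchInfo`; it runs without sglang installed.

```python
import torch


class ToyLoRABackend:
    """Hypothetical backend mirroring the refactored interface (sketch only)."""

    def __init__(self, max_loras_per_batch: int, device: torch.device):
        # Mirrors the new constructor: no more `name` / `batch_info` arguments.
        self.max_loras_per_batch = max_loras_per_batch
        self.device = device
        self.batch_info = None

    def prepare_lora_batch(self, bs, weight_indices, lora_ranks, scalings,
                           batch_info=None):
        # Backend-specific preprocessing: build (or update in place) the
        # metadata buffers the kernels will read for this forward batch.
        if batch_info is None:
            batch_info = {
                "weight_indices": torch.empty(bs, dtype=torch.int32, device=self.device),
                "lora_ranks": torch.empty(self.max_loras_per_batch, dtype=torch.int64, device=self.device),
                "scalings": torch.empty(self.max_loras_per_batch, dtype=torch.float32, device=self.device),
            }
        batch_info["weight_indices"][:bs] = torch.tensor(weight_indices, dtype=torch.int32)
        batch_info["lora_ranks"][:] = torch.tensor(lora_ranks, dtype=torch.int64)
        batch_info["scalings"][:] = torch.tensor(scalings, dtype=torch.float32)
        self.batch_info = batch_info


backend = ToyLoRABackend(max_loras_per_batch=4, device=torch.device("cpu"))
backend.prepare_lora_batch(
    bs=2,
    weight_indices=[0, 1],
    lora_ranks=[8, 16, 0, 0],
    scalings=[1.0, 0.5, 0.0, 0.0],
)
print(backend.batch_info["weight_indices"])  # tensor([0, 1], dtype=torch.int32)
```

The actual `TritonLoRABackend` below follows the same shape but builds a full `LoRABatchInfo` and stages the host lists in pinned memory.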
+from typing import Optional
+
 import torch

 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import (
     sgemm_lora_b_fwd,
 )
 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class TritonLoRABackend(BaseLoRABackend):
+    name = "triton"

-    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-        super().__init__(name, batch_info)
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        super().__init__(max_loras_per_batch, device)

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend):
             base_output,
         )
         return lora_output
+
+    def init_cuda_graph_batch_info(
+        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+    ):
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+            dim=0,
+            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+        )
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        # Use pinned memory to avoid synchronizations during host-to-device transfer
+        weight_indices_tensor = torch.tensor(
+            weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        lora_ranks_tensor = torch.tensor(
+            lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        scalings_tensor = torch.tensor(
+            scalings, dtype=torch.float, pin_memory=True, device="cpu"
+        )
+
+        bs = forward_batch.batch_size
+
+        if batch_info is not None:
+            assert (
+                batch_info.use_cuda_graph
+            ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+            batch_info.bs = forward_batch.batch_size
+            batch_info.num_segments = forward_batch.batch_size
+        else:
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+            batch_info = LoRABatchInfo(
+                bs=forward_batch.batch_size,
+                num_segments=forward_batch.batch_size,
+                max_len=max_len,
+                use_cuda_graph=False,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                weight_indices=torch.empty(
+                    (bs,), dtype=torch.int32, device=self.device
+                ),
+                lora_ranks=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+                ),
+                scalings=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                ),
+                permutation=None,
+            )
+
+        # Copy to device asynchronously
+        batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+            lora_ranks_tensor, non_blocking=True
+        )
+        batch_info.scalings[: self.max_loras_per_batch].copy_(
+            scalings_tensor, non_blocking=True
+        )
+        batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
+
+        self.batch_info = batch_info
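The `prepare_lora_batch` implementation above stages the Python lists in pinned host tensors and then issues `non_blocking` copies into device-side buffers. A small standalone sketch of that pattern, guarded so it also runs on CPU-only builds where pinning is unavailable; `async_h2d` is a made-up helper name:

```python
import torch


def async_h2d(values, dtype, out: torch.Tensor):
    # Stage the Python list in (optionally pinned) host memory first.
    pinned = torch.tensor(
        values,
        dtype=dtype,
        device="cpu",
        pin_memory=torch.cuda.is_available(),  # pinning requires a CUDA build
    )
    # non_blocking=True only overlaps compute and transfer when `out` lives on
    # the GPU and the source is page-locked; otherwise PyTorch falls back to a
    # synchronous copy.
    out[: len(values)].copy_(pinned, non_blocking=True)


device = "cuda" if torch.cuda.is_available() else "cpu"
weight_indices_buf = torch.empty(8, dtype=torch.int32, device=device)
async_h2d([0, 1, 1, 0], torch.int32, weight_indices_buf)
print(weight_indices_buf[:4])  # tensor([0, 1, 1, 0], dtype=torch.int32)
```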
@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_backend: BaseLoRABackend,
     ) -> None:
         super().__init__(base_layer, lora_backend)
+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )

     def set_lora_info(
         self,
@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.A_buffer_gate_up = A_buffer
         self.B_buffer_gate_up = B_buffer

+        shard_size = self.base_layer.output_partition_sizes[0]
+        self.output_offset = torch.tensor(
+            [
+                0,
+                shard_size,
+                2 * shard_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )
+
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_output = self.lora_backend.run_gate_up_lora(
             x=x,
             gate_up_lora_a=self.A_buffer_gate_up,
             gate_up_lora_b=self.B_buffer_gate_up,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.set_lora = True
         self.A_buffer = A_buffer
         self.B_buffer = B_buffer
+        output_size = self.base_layer.output_size
+        self.output_offset = torch.tensor(
+            [
+                0,
+                output_size,
+            ],
+            dtype=torch.int32,
+            device=next(self.base_layer.parameters()).device,
+        )

     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             x=lora_a_output,
             weights=self.B_buffer,
+            output_offset=self.output_offset,
             base_output=base_output,
         )
         return lora_output
...
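Each LoRA layer above now precomputes an `output_offset` tensor that, judging by its construction (`[0, shard_size]` for single projections, `[0, shard_size, 2 * shard_size]` for the fused gate_up projection), marks the output-column boundaries of each stacked projection so the backend kernel knows where to write each result. A toy illustration of how such offsets are read; the sizes and the `fused_out` tensor are made up:

```python
import torch

shard_size = 4  # illustrative only
output_offset = torch.tensor([0, shard_size, 2 * shard_size], dtype=torch.int32)

# A fake fused (gate_up-style) output with 2 rows and 2 * shard_size columns.
fused_out = torch.arange(2 * 2 * shard_size, dtype=torch.float32).reshape(2, -1)

# Consecutive offset entries bound the column slice owned by each projection.
for i in range(output_offset.numel() - 1):
    lo, hi = int(output_offset[i]), int(output_offset[i + 1])
    print(f"projection {i}: columns [{lo}, {hi}) ->", fused_out[:, lo:hi].shape)
```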
@@ -28,6 +28,9 @@ from torch import nn
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
+
+# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
@@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module):
             gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
             if up_name not in weights:
                 weights[up_name] = torch.zeros_like(weights[weight_name])
-                assert self.lora_backend.name == "triton", (
+                assert isinstance(self.lora_backend, TritonLoRABackend), (
                     f"LoRA weight initialization currently only supported for 'triton' backend. "
                     f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                     f"or consider implementing custom initialization logic for other backends."
...
@@ -69,7 +69,10 @@ class LoRAManager:
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
         backend_type = get_backend_from_name(lora_backend)
-        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
+        self.lora_backend: BaseLoRABackend = backend_type(
+            max_loras_per_batch=max_loras_per_batch,
+            device=self.device,
+        )

         # Initialize mutable internal state of the LoRAManager.
         self.init_state(
@@ -82,29 +85,22 @@ class LoRAManager:
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
         with torch.device("cuda"):
             self.cuda_graph_batch_info = LoRABatchInfo(
-                bs=self.max_bs_in_cuda_graph,
-                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
-                seg_indptr=torch.zeros(
-                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
-                ),
+                bs=max_bs_in_cuda_graph,
+                use_cuda_graph=True,
+                num_segments=None,
+                seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
                 max_len=1,
-                weight_indices=torch.zeros(
-                    self.max_bs_in_cuda_graph, dtype=torch.int32
-                ),
+                weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
+                permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
                 lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                 scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
             )
-            # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
-            # across batches.
-            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
-            torch.cumsum(
-                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
-                dim=0,
-                out=self.cuda_graph_batch_info.seg_indptr[
-                    1 : self.max_bs_in_cuda_graph + 1
-                ],
-            )
+            self.lora_backend.init_cuda_graph_batch_info(
+                cuda_graph_batch_info=self.cuda_graph_batch_info,
+                max_bs_in_cuda_graph=max_bs_in_cuda_graph,
+            )

     def create_lora_update_result(
         self, success: bool, error_message: str = ""
@@ -232,7 +228,6 @@ class LoRAManager:
         return required_slots <= mem_pool_vacancy

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # Load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_ids)
@@ -247,102 +242,30 @@ class LoRAManager:
         # set up batch info shared by all lora modules
         bs = forward_batch.batch_size

-        def transfer_adapter_info(
-            weight_indices_out: torch.Tensor,
-            lora_ranks_out: torch.Tensor,
-            scalings_out: torch.Tensor,
-        ):
-            """
-            Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
-            to device (CUDA) asynchronously.
-            """
-            weight_indices = [0] * len(forward_batch.lora_ids)
-            lora_ranks = [0] * self.max_loras_per_batch
-            scalings = [0] * self.max_loras_per_batch
-            for i, uid in enumerate(forward_batch.lora_ids):
-                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
-                if uid is not None:
-                    lora = self.loras[uid]
-                    lora_ranks[weight_indices[i]] = lora.config.r
-                    scalings[weight_indices[i]] = lora.scaling
-
-            # Use pinned memory to avoid synchronizations during host-to-device transfer
-            weight_indices_tensor = torch.tensor(
-                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            lora_ranks_tensor = torch.tensor(
-                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
-            )
-            scalings_tensor = torch.tensor(
-                scalings, dtype=torch.float, pin_memory=True, device="cpu"
-            )
-
-            # Copy to device tensors asynchronously
-            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
-            lora_ranks_out[: self.max_loras_per_batch].copy_(
-                lora_ranks_tensor, non_blocking=True
-            )
-            scalings_out[: self.max_loras_per_batch].copy_(
-                scalings_tensor, non_blocking=True
-            )
-
-        if (
-            hasattr(self, "max_bs_in_cuda_graph")
-            and bs <= self.max_bs_in_cuda_graph
-            and forward_batch.forward_mode.is_cuda_graph()
-        ):
-            # Do in-place updates when CUDA graph is enabled and the batch forward mode
-            # could use CUDA graph.
-            transfer_adapter_info(
-                self.cuda_graph_batch_info.weight_indices,
-                self.cuda_graph_batch_info.lora_ranks,
-                self.cuda_graph_batch_info.scalings,
-            )
-
-            self.cuda_graph_batch_info.bs = bs
-            self.cuda_graph_batch_info.max_len = 1
-            batch_info = self.cuda_graph_batch_info
-        else:
-            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
-            lora_ranks = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
-            )
-            scalings = torch.zeros(
-                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
-            )
-            transfer_adapter_info(
-                weight_indices,
-                lora_ranks,
-                scalings,
-            )
-
-            seg_lens = (
-                forward_batch.extend_seq_lens
-                if forward_batch.forward_mode.is_extend()
-                else torch.ones(bs, device=self.device)
-            )
-            max_len = (
-                # Calculate max_len from the CPU copy to avoid D2H transfer.
-                max(forward_batch.extend_seq_lens_cpu)
-                if forward_batch.forward_mode.is_extend()
-                else 1
-            )
-            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-
-            batch_info = LoRABatchInfo(
-                bs=bs,
-                seg_lens=seg_lens,
-                seg_indptr=seg_indptr,
-                max_len=max_len,
-                weight_indices=weight_indices,
-                lora_ranks=lora_ranks,
-                scalings=scalings,
-            )
-        self.lora_backend.set_batch_info(batch_info)
+        use_cuda_graph = (
+            hasattr(self, "max_bs_in_cuda_graph")
+            and bs <= self.max_bs_in_cuda_graph
+            and forward_batch.forward_mode.is_cuda_graph()
+        )
+
+        weight_indices = [0] * len(forward_batch.lora_ids)
+        lora_ranks = [0] * self.max_loras_per_batch
+        scalings = [0] * self.max_loras_per_batch
+        for i, uid in enumerate(forward_batch.lora_ids):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
+                lora_ranks[weight_indices[i]] = lora.config.r
+                scalings[weight_indices[i]] = lora.scaling
+
+        # Do in-place updates when CUDA graph is enabled and the batch forward mode
+        # could use CUDA graph.
+        self.lora_backend.prepare_lora_batch(
+            forward_batch=forward_batch,
+            weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
+            batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
+        )

     def update_lora_info(self):
         """
...
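After this change, `LoRAManager.prepare_lora_batch` only packs per-request adapter metadata into plain Python lists and hands them, together with the optional CUDA-graph `LoRABatchInfo`, to the backend. A standalone sketch of that packing step; `buffer_id` and `loras` are simplified stand-ins for the memory pool and adapter registry:

```python
# Simplified stand-ins for the memory pool and adapter registry used above.
max_loras_per_batch = 4
buffer_id = {"adapter_a": 0, "adapter_b": 2, None: 3}    # uid -> pool slot
loras = {"adapter_a": (8, 1.0), "adapter_b": (16, 0.5)}  # uid -> (rank, scaling)

lora_ids = ["adapter_a", None, "adapter_b"]  # one uid per request in the batch

weight_indices = [0] * len(lora_ids)
lora_ranks = [0] * max_loras_per_batch
scalings = [0.0] * max_loras_per_batch
for i, uid in enumerate(lora_ids):
    weight_indices[i] = buffer_id[uid]
    if uid is not None:
        rank, scaling = loras[uid]
        lora_ranks[weight_indices[i]] = rank
        scalings[weight_indices[i]] = scaling

print(weight_indices)  # [0, 3, 2]      per-request slot in the LoRA memory pool
print(lora_ranks)      # [8, 0, 16, 0]  per-slot rank (0 for unused slots)
print(scalings)        # [1.0, 0.0, 0.5, 0.0]
```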
@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig
 @dataclass
 class LoRABatchInfo:
+    # The forward mode is using CUDA Graph.
+    use_cuda_graph: bool
+
     # Batch size
     bs: int

-    # Lengths of each sequence in shape (bs,)
-    seg_lens: torch.Tensor
-    # Indice pointers of each sequence in shape (bs + 1, )
+    # Number of segments. For triton backend, it is equal to batch size.
+    num_segments: int
+
+    # Indice pointers of each segment in shape (num_segments + 1, )
     seg_indptr: torch.Tensor

-    # Maximum sequence length of current batch
-    max_len: int
-    # The index of lora adapter used by each sequence, in shape (bs,)
+    # The index of lora adapter used by each segment, in shape (num_segments,)
     weight_indices: torch.Tensor

     # ranks of each lora adapter, in shape (lora_num,)
@@ -31,6 +31,15 @@ class LoRABatchInfo:
     # scaling of each lora adapter, in shape (lora_num,)
     scalings: torch.Tensor

+    # Lengths of each segments in shape (num_segments,)
+    seg_lens: Optional[torch.Tensor]
+
+    # Maximum segment length of current batch
+    max_len: Optional[int]
+
+    # The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
+    permutation: Optional[torch.Tensor]


 class LoRAType(Enum):
     LORA_A = 0
...
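The reshaped `LoRABatchInfo` above describes segments rather than sequences. A short sketch of how the segment fields fit together for an extend (prefill) batch under the triton backend, where one segment corresponds to one request; the lengths are made up:

```python
import torch

# Three segments (requests) of lengths 5, 2 and 7 tokens in an extend batch.
seg_lens = torch.tensor([5, 2, 7], dtype=torch.int32)
num_segments = seg_lens.numel()

# seg_indptr is the exclusive prefix sum over seg_lens: segment i owns the
# flattened token rows [seg_indptr[i], seg_indptr[i + 1]).
seg_indptr = torch.zeros(num_segments + 1, dtype=torch.int32)
torch.cumsum(seg_lens, dim=0, out=seg_indptr[1:])

max_len = int(seg_lens.max())  # longest segment; 1 for a pure decode batch

print(seg_indptr.tolist())  # [0, 5, 7, 14]
print(max_len)              # 7
```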