Unverified Commit 94100294 authored by Lifu Huang, committed by GitHub

[1/2] Refactor LoRA to support backend-specific batch preprocessing. (#10251)

parent cda7e47c
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import torch
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class BaseLoRABackend:
......@@ -10,13 +11,14 @@ class BaseLoRABackend:
Each backend has its own implementation of Lora kernels.
Args:
name: name of backend
batch_info: information of current batch for use
max_loras_per_batch: maximum number of different lora weights
that can be applied in a single forward batch.
device: the device where the backend runs.
"""
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
self.name = name
self.batch_info = batch_info
def __init__(self, max_loras_per_batch: int, device: torch.device):
self.max_loras_per_batch = max_loras_per_batch
self.device = device
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
......@@ -93,8 +95,44 @@ class BaseLoRABackend:
"""
pass
def set_batch_info(self, batch_info: LoRABatchInfo):
self.batch_info = batch_info
def init_cuda_graph_batch_info(
self,
cuda_graph_batch_info: LoRABatchInfo,
max_bs_in_cuda_graph: int,
):
"""Initialize the batch info for CUDA Graph mode.
This method provides a hook for each backend to conduct its own initialization
logic for CUDA Graph mode.
Args:
cuda_graph_batch_info: the LoRABatchInfo object created in LoRAManager
max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
"""
pass
def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
):
"""Prepare the lora weights and batch info for current forward batch.
This method provides a hook for each backend to conduct its own preparation
logic for each forward batch.
Args:
forward_batch: the ForwardBatch object for the current forward pass
weight_indices: list of indices of lora weights to be applied for current batch
lora_ranks: list of lora ranks corresponding to weight_indices
scalings: list of scaling factors corresponding to weight_indices
batch_info: optional LoRABatchInfo object; if not provided, the backend should use its own
internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
"""
pass
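# Illustrative sketch (not part of this diff): how a hypothetical backend subclass could
# implement the two hooks above. `MyLoRABackend` and its body are made up for exposition;
# TritonLoRABackend below is the actual implementation shipped with this change.
class MyLoRABackend(BaseLoRABackend):
    name = "my_backend"

    def init_cuda_graph_batch_info(
        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
    ):
        # One-time setup of the fields that stay constant across CUDA Graph replays,
        # e.g. unit-length segments for decode-only batches.
        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: Optional[LoRABatchInfo] = None,
    ):
        if batch_info is not None:
            # CUDA Graph path: refresh the pre-allocated, graph-captured buffers in place.
            batch_info.bs = forward_batch.batch_size
        else:
            # Eager path: build a fresh LoRABatchInfo for this forward batch (omitted).
            batch_info = ...
        self.batch_info = batch_info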
def get_backend_from_name(name: str) -> BaseLoRABackend:
......@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
return TritonLoRABackend
# elif name == "csgmv":
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
# return ChunkedSgmvLoRABackend
elif name == "flashinfer":
raise ValueError(
"FlashInfer LoRA backend has been deprecated, please use `triton` instead."
......
from typing import Optional
import torch
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
......@@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import (
sgemm_lora_b_fwd,
)
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class TritonLoRABackend(BaseLoRABackend):
name = "triton"
def __init__(self, name: str, batch_info: LoRABatchInfo = None):
super().__init__(name, batch_info)
def __init__(self, max_loras_per_batch: int, device: torch.device):
super().__init__(max_loras_per_batch, device)
def run_lora_a_sgemm(
self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
......@@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend):
base_output,
)
return lora_output
def init_cuda_graph_batch_info(
self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
):
# Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
# across batches.
cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
torch.cumsum(
cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
dim=0,
out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
)
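# Illustrative values (not part of this change): with max_bs_in_cuda_graph = 4 and the
# zero-initialized buffers allocated by LoRAManager, the two statements above leave
#   seg_lens[:4]   == [1, 1, 1, 1]
#   seg_indptr[:5] == [0, 1, 2, 3, 4]
# i.e. one token per sequence, which is what CUDA-Graph decode batches use.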
def prepare_lora_batch(
self,
forward_batch: ForwardBatch,
weight_indices: list[int],
lora_ranks: list[int],
scalings: list[float],
batch_info: Optional[LoRABatchInfo] = None,
):
# Use pinned memory to avoid synchronizations during host-to-device transfer
weight_indices_tensor = torch.tensor(
weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
)
lora_ranks_tensor = torch.tensor(
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
)
scalings_tensor = torch.tensor(
scalings, dtype=torch.float, pin_memory=True, device="cpu"
)
bs = forward_batch.batch_size
if batch_info is not None:
assert (
batch_info.use_cuda_graph
), "batch_info.use_cuda_graph must be True when batch_info is provided"
batch_info.bs = forward_batch.batch_size
batch_info.num_segments = forward_batch.batch_size
else:
max_len = (
# Calculate max_len from the CPU copy to avoid D2H transfer.
max(forward_batch.extend_seq_lens_cpu)
if forward_batch.forward_mode.is_extend()
else 1
)
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
batch_info = LoRABatchInfo(
bs=forward_batch.batch_size,
num_segments=forward_batch.batch_size,
max_len=max_len,
use_cuda_graph=False,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
weight_indices=torch.empty(
(bs,), dtype=torch.int32, device=self.device
),
lora_ranks=torch.empty(
(self.max_loras_per_batch,), dtype=torch.int64, device=self.device
),
scalings=torch.empty(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
),
permutation=None,
)
# Copy to device asynchronously
batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
lora_ranks_tensor, non_blocking=True
)
batch_info.scalings[: self.max_loras_per_batch].copy_(
scalings_tensor, non_blocking=True
)
batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
self.batch_info = batch_info
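# Illustrative values (hypothetical, not part of this diff): for an eager extend batch of two
# sequences with lengths 3 and 5 that both use the adapter in buffer slot 0 (rank 8), the
# branch above builds a LoRABatchInfo roughly like
#   seg_lens       = [3, 5]
#   seg_indptr     = [0, 3, 8]
#   max_len        = 5            # taken from extend_seq_lens_cpu to avoid a D2H transfer
#   weight_indices = [0, 0]
#   lora_ranks[0]  = 8
#   use_cuda_graph = False
# while the adapter metadata is copied in asynchronously from the pinned host tensors.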
......@@ -66,6 +66,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
lora_backend: BaseLoRABackend,
) -> None:
super().__init__(base_layer, lora_backend)
shard_size = self.base_layer.output_partition_sizes[0]
self.output_offset = torch.tensor(
[
0,
shard_size,
],
dtype=torch.int32,
device=next(self.base_layer.parameters()).device,
)
def set_lora_info(
self,
......@@ -81,6 +90,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
lora_output = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output,
weights=self.B_buffer,
output_offset=self.output_offset,
base_output=base_output,
)
return lora_output
......@@ -130,11 +140,23 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self.A_buffer_gate_up = A_buffer
self.B_buffer_gate_up = B_buffer
shard_size = self.base_layer.output_partition_sizes[0]
self.output_offset = torch.tensor(
[
0,
shard_size,
2 * shard_size,
],
dtype=torch.int32,
device=next(self.base_layer.parameters()).device,
)
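# Illustrative values (not part of this change): with a per-shard output size of 4096, the
# tensor above is [0, 4096, 8192]; the non-merged column- and row-parallel layers build the
# two-entry analogue [0, output_size]. These offsets appear to tell the LoRA-B kernel where
# each stacked output slice (here gate and up) starts inside the fused projection output.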
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_output = self.lora_backend.run_gate_up_lora(
x=x,
gate_up_lora_a=self.A_buffer_gate_up,
gate_up_lora_b=self.B_buffer_gate_up,
output_offset=self.output_offset,
base_output=base_output,
)
return lora_output
......@@ -243,12 +265,22 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self.set_lora = True
self.A_buffer = A_buffer
self.B_buffer = B_buffer
output_size = self.base_layer.output_size
self.output_offset = torch.tensor(
[
0,
output_size,
],
dtype=torch.int32,
device=next(self.base_layer.parameters()).device,
)
def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
lora_output = self.lora_backend.run_lora_b_sgemm(
x=lora_a_output,
weights=self.B_buffer,
output_offset=self.output_offset,
base_output=base_output,
)
return lora_output
......
......@@ -28,6 +28,9 @@ from torch import nn
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend.base_backend import BaseLoRABackend
# from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_loader.loader import DefaultModelLoader
......@@ -156,7 +159,7 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if up_name not in weights:
weights[up_name] = torch.zeros_like(weights[weight_name])
assert self.lora_backend.name == "triton", (
assert isinstance(self.lora_backend, TritonLoRABackend), (
f"LoRA weight initialization currently only supported for 'triton' backend. "
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
f"or consider implementing custom initialization logic for other backends."
......
......@@ -69,7 +69,10 @@ class LoRAManager:
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
backend_type = get_backend_from_name(lora_backend)
self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
self.lora_backend: BaseLoRABackend = backend_type(
max_loras_per_batch=max_loras_per_batch,
device=self.device,
)
# Initialize mutable internal state of the LoRAManager.
self.init_state(
......@@ -82,29 +85,22 @@ class LoRAManager:
self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
with torch.device("cuda"):
self.cuda_graph_batch_info = LoRABatchInfo(
bs=self.max_bs_in_cuda_graph,
seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
seg_indptr=torch.zeros(
self.max_bs_in_cuda_graph + 1, dtype=torch.int32
),
bs=max_bs_in_cuda_graph,
use_cuda_graph=True,
num_segments=None,
seg_lens=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
seg_indptr=torch.zeros(max_bs_in_cuda_graph + 1, dtype=torch.int32),
max_len=1,
weight_indices=torch.zeros(
self.max_bs_in_cuda_graph, dtype=torch.int32
),
weight_indices=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
permutation=torch.zeros(max_bs_in_cuda_graph, dtype=torch.int32),
lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
)
# Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
# across batches.
self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
torch.cumsum(
self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
dim=0,
out=self.cuda_graph_batch_info.seg_indptr[
1 : self.max_bs_in_cuda_graph + 1
],
)
self.lora_backend.init_cuda_graph_batch_info(
cuda_graph_batch_info=self.cuda_graph_batch_info,
max_bs_in_cuda_graph=max_bs_in_cuda_graph,
)
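# Illustrative consequence (not part of this change): for the Triton backend, the hook above
# fills seg_lens[:max_bs_in_cuda_graph] with 1 and writes their running sum into seg_indptr
# once, so prepare_lora_batch only has to refresh bs, num_segments, weight_indices,
# lora_ranks, and scalings in place for each CUDA-Graph-eligible batch.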
def create_lora_update_result(
self, success: bool, error_message: str = ""
......@@ -232,7 +228,6 @@ class LoRAManager:
return required_slots <= mem_pool_vacancy
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# Load active loras into lora memory pool
cur_uids = set(forward_batch.lora_ids)
......@@ -247,102 +242,30 @@ class LoRAManager:
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
def transfer_adapter_info(
weight_indices_out: torch.Tensor,
lora_ranks_out: torch.Tensor,
scalings_out: torch.Tensor,
):
"""
Transfer adapter metadata (weight indices, LoRA rank, scalings) from host
to device (CUDA) asynchronously.
"""
weight_indices = [0] * len(forward_batch.lora_ids)
lora_ranks = [0] * self.max_loras_per_batch
scalings = [0] * self.max_loras_per_batch
for i, uid in enumerate(forward_batch.lora_ids):
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
if uid is not None:
lora = self.loras[uid]
lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling
# Use pinned memory to avoid synchronizations during host-to-device transfer
weight_indices_tensor = torch.tensor(
weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
)
lora_ranks_tensor = torch.tensor(
lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
)
scalings_tensor = torch.tensor(
scalings, dtype=torch.float, pin_memory=True, device="cpu"
)
# Copy to device tensors asynchronously
weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
lora_ranks_out[: self.max_loras_per_batch].copy_(
lora_ranks_tensor, non_blocking=True
)
scalings_out[: self.max_loras_per_batch].copy_(
scalings_tensor, non_blocking=True
)
if (
use_cuda_graph = (
hasattr(self, "max_bs_in_cuda_graph")
and bs <= self.max_bs_in_cuda_graph
and forward_batch.forward_mode.is_cuda_graph()
):
# Do in-place updates when CUDA graph is enabled and the batch forward mode
# could use CUDA graph.
transfer_adapter_info(
self.cuda_graph_batch_info.weight_indices,
self.cuda_graph_batch_info.lora_ranks,
self.cuda_graph_batch_info.scalings,
)
self.cuda_graph_batch_info.bs = bs
self.cuda_graph_batch_info.max_len = 1
batch_info = self.cuda_graph_batch_info
else:
weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
lora_ranks = torch.zeros(
(self.max_loras_per_batch,), dtype=torch.int64, device=self.device
)
scalings = torch.zeros(
(self.max_loras_per_batch,), dtype=torch.float, device=self.device
)
transfer_adapter_info(
weight_indices,
lora_ranks,
scalings,
)
seg_lens = (
forward_batch.extend_seq_lens
if forward_batch.forward_mode.is_extend()
else torch.ones(bs, device=self.device)
)
max_len = (
# Calculate max_len from the CPU copy to avoid D2H transfer.
max(forward_batch.extend_seq_lens_cpu)
if forward_batch.forward_mode.is_extend()
else 1
)
)
seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
batch_info = LoRABatchInfo(
bs=bs,
seg_lens=seg_lens,
seg_indptr=seg_indptr,
max_len=max_len,
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
)
self.lora_backend.set_batch_info(batch_info)
weight_indices = [0] * len(forward_batch.lora_ids)
lora_ranks = [0] * self.max_loras_per_batch
scalings = [0] * self.max_loras_per_batch
for i, uid in enumerate(forward_batch.lora_ids):
weight_indices[i] = self.memory_pool.get_buffer_id(uid)
if uid is not None:
lora = self.loras[uid]
lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling
# Delegate backend-specific batch preprocessing to the LoRA backend. When the batch is
# eligible for CUDA Graph, pass the pre-allocated batch info so it is updated in place.
self.lora_backend.prepare_lora_batch(
forward_batch=forward_batch,
weight_indices=weight_indices,
lora_ranks=lora_ranks,
scalings=scalings,
batch_info=self.cuda_graph_batch_info if use_cuda_graph else None,
)
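# Illustrative values (hypothetical, not part of this diff): with max_loras_per_batch = 2,
# forward_batch.lora_ids = ["adapter_a", None, "adapter_a"], and a memory pool that maps
# "adapter_a" to buffer slot 1 and None to slot 0, the loop above yields
#   weight_indices = [1, 0, 1]     # buffer slot per sequence
#   lora_ranks     = [0, 16]       # rank per buffer slot (0 where no adapter is loaded)
#   scalings       = [0, 0.5]      # scaling per buffer slot
# which the backend then copies to the device through pinned host tensors.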
def update_lora_info(self):
"""
......
......@@ -10,19 +10,19 @@ from sglang.srt.hf_transformers_utils import AutoConfig
@dataclass
class LoRABatchInfo:
# Whether the current forward pass is running in CUDA Graph mode.
use_cuda_graph: bool
# Batch size
bs: int
# Lengths of each sequence in shape (bs,)
seg_lens: torch.Tensor
# Number of segments. For the Triton backend, this equals the batch size.
num_segments: int
# Indice pointers of each sequence in shape (bs + 1, )
# Index pointers of each segment, in shape (num_segments + 1,)
seg_indptr: torch.Tensor
# Maximum sequence length of current batch
max_len: int
# The index of lora adapter used by each sequence, in shape (bs,)
# The index of lora adapter used by each segment, in shape (num_segments,)
weight_indices: torch.Tensor
# ranks of each lora adapter, in shape (lora_num,)
......@@ -31,6 +31,15 @@ class LoRABatchInfo:
# scaling of each lora adapter, in shape (lora_num,)
scalings: torch.Tensor
# Lengths of each segment, in shape (num_segments,)
seg_lens: Optional[torch.Tensor]
# Maximum segment length of current batch
max_len: Optional[int]
# The logical (re)ordering of input rows (tokens), in shape (num_tokens,)
permutation: Optional[torch.Tensor]
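# Illustrative construction (hypothetical values, not part of this diff): a three-sequence
# extend batch where sequences 0 and 2 use the adapter in buffer slot 0 and sequence 1 uses
# the adapter in buffer slot 1.
_example_batch_info = LoRABatchInfo(
    use_cuda_graph=False,
    bs=3,
    num_segments=3,
    seg_lens=torch.tensor([2, 4, 1], dtype=torch.int32),
    seg_indptr=torch.tensor([0, 2, 6, 7], dtype=torch.int32),
    max_len=4,
    weight_indices=torch.tensor([0, 1, 0], dtype=torch.int32),
    lora_ranks=torch.tensor([8, 16], dtype=torch.int64),
    scalings=torch.tensor([2.0, 0.5]),
    permutation=None,
)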
class LoRAType(Enum):
LORA_A = 0
......