Unverified Commit ffe1fc7a authored by yugong333's avatar yugong333 Committed by GitHub
Browse files

Reduce the kernel overhead when num of active loras is smaller than max...


  Reduce kernel overhead when the number of active LoRAs is smaller than max_loras. Separate CUDA graphs are captured for different numbers of active LoRAs. (#32005)
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
parent 8b7346d5
...@@ -181,6 +181,10 @@ def use_fused_moe_lora_kernel( ...@@ -181,6 +181,10 @@ def use_fused_moe_lora_kernel(
expert_ids = expert_ids.view(max_loras, -1) expert_ids = expert_ids.view(max_loras, -1)
sorted_token_ids = sorted_token_ids.view(max_loras, -1) sorted_token_ids = sorted_token_ids.view(max_loras, -1)
# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1
fused_moe_lora( fused_moe_lora(
output, output,
hidden_states, hidden_states,
...@@ -194,6 +198,7 @@ def use_fused_moe_lora_kernel( ...@@ -194,6 +198,7 @@ def use_fused_moe_lora_kernel(
max_lora_rank, max_lora_rank,
top_k_num, top_k_num,
lora_ids, lora_ids,
num_active_loras,
adapter_enabled, adapter_enabled,
config["BLOCK_SIZE_M"], config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"], config["BLOCK_SIZE_N"],
...@@ -376,6 +381,10 @@ def use_fused_moe_lora_kernel_naive( ...@@ -376,6 +381,10 @@ def use_fused_moe_lora_kernel_naive(
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32) lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1
fused_moe_lora( fused_moe_lora(
output, output,
hidden_states, hidden_states,
...@@ -389,6 +398,7 @@ def use_fused_moe_lora_kernel_naive( ...@@ -389,6 +398,7 @@ def use_fused_moe_lora_kernel_naive(
max_lora_rank, max_lora_rank,
top_k_num, top_k_num,
lora_ids, lora_ids,
num_active_loras,
adapter_enabled, adapter_enabled,
config["BLOCK_SIZE_M"], config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"], config["BLOCK_SIZE_N"],
......
...@@ -161,7 +161,7 @@ def check_lora_shrink_kernel( ...@@ -161,7 +161,7 @@ def check_lora_shrink_kernel(
data.inputs_tensor, data.inputs_tensor,
data.lora_weights, data.lora_weights,
out_tensor, out_tensor,
*lora_meta.meta_args(token_nums=token_nums), *lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
scaling, scaling,
) )
...@@ -234,7 +234,7 @@ def check_lora_expand_kernel( ...@@ -234,7 +234,7 @@ def check_lora_expand_kernel(
data.inputs_tensor, data.inputs_tensor,
data.lora_weights, data.lora_weights,
out_tensor, out_tensor,
*lora_meta.meta_args(token_nums=token_nums), *lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
offset_start=0, offset_start=0,
add_inputs=add_inputs, add_inputs=add_inputs,
) )
......
...@@ -17,6 +17,7 @@ from vllm.config import ( ...@@ -17,6 +17,7 @@ from vllm.config import (
SchedulerConfig, SchedulerConfig,
VllmConfig, VllmConfig,
) )
from vllm.config.lora import LoRAConfig
from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
...@@ -47,6 +48,12 @@ def _create_vllm_config( ...@@ -47,6 +48,12 @@ def _create_vllm_config(
mock_config.speculative_config = None # No speculative decoding mock_config.speculative_config = None # No speculative decoding
if not lora_config: if not lora_config:
mock_config.lora_config = None mock_config.lora_config = None
else:
# Create a real LoRAConfig with specialize_active_lora enabled
mock_config.lora_config = LoRAConfig(
max_loras=4,
specialize_active_lora=True,
)
# Mimic the behavior of VllmConfig.__post_init__() # Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE: if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1( compilation_config.set_splitting_ops_for_v1(
...@@ -106,15 +113,19 @@ class TestCudagraphDispatcher: ...@@ -106,15 +113,19 @@ class TestCudagraphDispatcher:
) )
# Verify the key is initialized correctly # Verify the key is initialized correctly
# With LoRA specialization (max_loras=4, specialize_active_lora=True):
# - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
# - capture_sizes = [1, 8]
# - Total keys = 2 sizes × 5 lora_cases = 10
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == ( assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
4 if lora_config else 2 10 if lora_config else 2
) )
else: else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0 assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]: if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == ( assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
4 if lora_config else 2 10 if lora_config else 2
) )
else: else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0 assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
......
...@@ -60,6 +60,13 @@ class LoRAConfig: ...@@ -60,6 +60,13 @@ class LoRAConfig:
of multimodal models will be enabled. This is an experimental feature and of multimodal models will be enabled. This is an experimental feature and
currently only supports some MM models such as the Qwen VL series. The default currently only supports some MM models such as the Qwen VL series. The default
is False.""" is False."""
specialize_active_lora: bool = False
"""Whether to construct lora kernel grid by the number of active LoRA adapters.
When set to True, separate cuda graphs will be captured for different counts
of active LoRAs (powers of 2 up to max_loras, plus max_loras + 1), which can improve performance
for variable LoRA usage patterns at the cost of increased startup time and
memory usage. Only takes effect when cudagraph_specialize_lora is True.
"""
def compute_hash(self) -> str: def compute_hash(self) -> str:
""" """
......
...@@ -485,6 +485,7 @@ class EngineArgs: ...@@ -485,6 +485,7 @@ class EngineArgs:
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
specialize_active_lora: bool = LoRAConfig.specialize_active_lora
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
...@@ -1026,6 +1027,9 @@ class EngineArgs: ...@@ -1026,6 +1027,9 @@ class EngineArgs:
"--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"] "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
) )
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
lora_group.add_argument(
"--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
)
# Observability arguments # Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig) observability_kwargs = get_kwargs(ObservabilityConfig)
...@@ -1657,6 +1661,7 @@ class EngineArgs: ...@@ -1657,6 +1661,7 @@ class EngineArgs:
fully_sharded_loras=self.fully_sharded_loras, fully_sharded_loras=self.fully_sharded_loras,
lora_dtype=self.lora_dtype, lora_dtype=self.lora_dtype,
enable_tower_connector_lora=self.enable_tower_connector_lora, enable_tower_connector_lora=self.enable_tower_connector_lora,
specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras max_cpu_loras=self.max_cpu_loras
if self.max_cpu_loras and self.max_cpu_loras > 0 if self.max_cpu_loras and self.max_cpu_loras > 0
else None, else None,
......
...@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple): ...@@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple):
""" """
Whether this batch has active LoRA adapters. Whether this batch has active LoRA adapters.
""" """
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
When LoRAConfig.specialize_active_lora is enabled, separate CUDA graphs
are captured for different num_active_loras values. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
""" """
...@@ -54,7 +62,11 @@ class BatchDescriptor(NamedTuple): ...@@ -54,7 +62,11 @@ class BatchDescriptor(NamedTuple):
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs). with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
""" """
return BatchDescriptor( return BatchDescriptor(
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora self.num_tokens,
num_reqs=None,
uniform=False,
has_lora=self.has_lora,
num_active_loras=self.num_active_loras,
) )
......
...@@ -95,7 +95,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): ...@@ -95,7 +95,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
def _adjust_kernel_inputs( def _adjust_kernel_inputs(
max_loras: int, num_active_loras: int,
sorted_token_ids: torch.Tensor | None, sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
): ):
...@@ -109,7 +109,7 @@ def _adjust_kernel_inputs( ...@@ -109,7 +109,7 @@ def _adjust_kernel_inputs(
else: else:
stride_tl = sorted_token_ids.stride(0) stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0) stride_el = expert_ids.stride(0)
grid_lora_dim = max_loras + 1 grid_lora_dim = num_active_loras
return grid_lora_dim, stride_tl, stride_el return grid_lora_dim, stride_tl, stride_el
...@@ -354,6 +354,7 @@ def _fused_moe_lora_shrink( ...@@ -354,6 +354,7 @@ def _fused_moe_lora_shrink(
num_warps: int, num_warps: int,
num_stages: int, num_stages: int,
split_k: int, split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
use_gdc: bool = False, use_gdc: bool = False,
) -> None: ) -> None:
...@@ -373,7 +374,7 @@ def _fused_moe_lora_shrink( ...@@ -373,7 +374,7 @@ def _fused_moe_lora_shrink(
b_ptr = _get_ptr(lora_a_stacked, device) b_ptr = _get_ptr(lora_a_stacked, device)
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs( grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
w1_lora_a_stacked.shape[0], sorted_token_ids, expert_ids num_active_loras, sorted_token_ids, expert_ids
) )
grid = lambda META: ( grid = lambda META: (
split_k split_k
...@@ -457,6 +458,7 @@ def _fused_moe_lora_expand( ...@@ -457,6 +458,7 @@ def _fused_moe_lora_expand(
num_warps: int, num_warps: int,
num_stages: int, num_stages: int,
split_k: int, split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
offset: int = 0, offset: int = 0,
use_gdc: bool = False, use_gdc: bool = False,
...@@ -484,7 +486,7 @@ def _fused_moe_lora_expand( ...@@ -484,7 +486,7 @@ def _fused_moe_lora_expand(
} }
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs( grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
w1_lora_b_stacked.shape[0], sorted_token_ids, expert_ids num_active_loras, sorted_token_ids, expert_ids
) )
grid = lambda META: ( grid = lambda META: (
...@@ -557,6 +559,7 @@ def _fused_moe_lora( ...@@ -557,6 +559,7 @@ def _fused_moe_lora(
max_lora_rank: int, max_lora_rank: int,
top_k_num: int, top_k_num: int,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor, adapter_enabled: torch.Tensor,
shrink_block_size_m: int, shrink_block_size_m: int,
shrink_block_size_n: int, shrink_block_size_n: int,
...@@ -648,6 +651,7 @@ def _fused_moe_lora( ...@@ -648,6 +651,7 @@ def _fused_moe_lora(
shrink_num_warps, shrink_num_warps,
shrink_num_stages, shrink_num_stages,
shrink_split_k, shrink_split_k,
num_active_loras,
mul_routed_weight, mul_routed_weight,
use_gdc=use_gdc, use_gdc=use_gdc,
) )
...@@ -695,6 +699,7 @@ def _fused_moe_lora( ...@@ -695,6 +699,7 @@ def _fused_moe_lora(
expand_num_warps, expand_num_warps,
expand_num_stages, expand_num_stages,
expand_split_k, expand_split_k,
num_active_loras,
mul_routed_weight, mul_routed_weight,
offset, offset,
use_gdc=use_gdc, use_gdc=use_gdc,
...@@ -714,6 +719,7 @@ def _fused_moe_lora_fake( ...@@ -714,6 +719,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int, max_lora_rank: int,
top_k_num: int, top_k_num: int,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor, adapter_enabled: torch.Tensor,
shrink_block_size_m: int, shrink_block_size_m: int,
shrink_block_size_n: int, shrink_block_size_n: int,
...@@ -730,6 +736,8 @@ def _fused_moe_lora_fake( ...@@ -730,6 +736,8 @@ def _fused_moe_lora_fake(
expand_num_stages: int, expand_num_stages: int,
expand_split_k: int, expand_split_k: int,
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
fully_sharded: bool = False,
offset: int = 0,
) -> None: ) -> None:
return return
...@@ -761,6 +769,7 @@ def _fused_moe_lora_shrink_fake( ...@@ -761,6 +769,7 @@ def _fused_moe_lora_shrink_fake(
num_warps: int, num_warps: int,
num_stages: int, num_stages: int,
split_k: int, split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
use_gdc: bool = False, use_gdc: bool = False,
) -> None: ) -> None:
...@@ -770,6 +779,7 @@ def _fused_moe_lora_shrink_fake( ...@@ -770,6 +779,7 @@ def _fused_moe_lora_shrink_fake(
def _fused_moe_lora_expand_fake( def _fused_moe_lora_expand_fake(
output: torch.Tensor, output: torch.Tensor,
a_intermediate_cache1: torch.Tensor, a_intermediate_cache1: torch.Tensor,
b_intermediate_cache1: torch.Tensor,
lora_b_stacked: list[torch.Tensor], lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None, sorted_token_ids: torch.Tensor | None,
...@@ -796,7 +806,9 @@ def _fused_moe_lora_expand_fake( ...@@ -796,7 +806,9 @@ def _fused_moe_lora_expand_fake(
num_warps: int, num_warps: int,
num_stages: int, num_stages: int,
split_k: int, split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False, use_gdc: bool = False,
) -> None: ) -> None:
return return
......
...@@ -138,6 +138,7 @@ def _lora_expand( ...@@ -138,6 +138,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1] lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1] no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs; used as the LoRA dimension of the launch grid
offset_start: int = 0, offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
...@@ -234,10 +235,7 @@ def _lora_expand( ...@@ -234,10 +235,7 @@ def _lora_expand(
grid = ( grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES, NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output num_active_loras,
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks simply exit.
MAX_LORAS,
) )
# We disable PDL temporarily because LoRA kernels are not launching back-to-back, # We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance. # making PDL invalid and affecting the kernel performance.
...@@ -291,6 +289,7 @@ def _lora_expand_fake( ...@@ -291,6 +289,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor, lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor, no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
offset_start: int = 0, offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
......
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
LoRA kernels metadata preparation utilities. LoRA kernels metadata preparation utilities.
""" """
from dataclasses import dataclass import bisect
from dataclasses import dataclass, field
import torch import torch
...@@ -28,9 +29,22 @@ class LoRAKernelMeta: ...@@ -28,9 +29,22 @@ class LoRAKernelMeta:
# to early exit from inside the lora_expand / lora_shrink torch operation. # to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor no_lora_flag_cpu: torch.Tensor
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
# Stored as a Python int to avoid GPU->CPU sync during forward pass
num_active_loras: int = 0
# Captured LoRA counts for cudagraph specialization (sorted list).
# When specialize_active_lora is enabled, num_active_loras is rounded up
# to the nearest value in this list to match cudagraph capture keys.
# Empty list means no specialization (use actual count).
captured_lora_counts: list[int] = field(default_factory=list)
@staticmethod @staticmethod
def make( def make(
max_loras: int, max_num_tokens: int, device: torch.device | str max_loras: int,
max_num_tokens: int,
device: torch.device | str,
captured_lora_counts: list[int] | None = None,
) -> "LoRAKernelMeta": ) -> "LoRAKernelMeta":
token_lora_mapping = torch.empty( token_lora_mapping = torch.empty(
max_num_tokens, dtype=torch.int32, device=device max_num_tokens, dtype=torch.int32, device=device
...@@ -66,6 +80,9 @@ class LoRAKernelMeta: ...@@ -66,6 +80,9 @@ class LoRAKernelMeta:
num_tokens_per_lora=num_tokens_per_lora, num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc, lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu, no_lora_flag_cpu=no_lora_flag_cpu,
captured_lora_counts=sorted(captured_lora_counts)
if captured_lora_counts
else [],
) )
def _reset(self): def _reset(self):
...@@ -73,6 +90,8 @@ class LoRAKernelMeta: ...@@ -73,6 +90,8 @@ class LoRAKernelMeta:
self.num_tokens_per_lora.fill_(0) self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0) self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False) self.no_lora_flag_cpu.fill_(False)
self.num_active_loras = 0
self.captured_lora_counts = []
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
""" """
...@@ -118,6 +137,15 @@ class LoRAKernelMeta: ...@@ -118,6 +137,15 @@ class LoRAKernelMeta:
num_tokens_per_lora, non_blocking=True num_tokens_per_lora, non_blocking=True
) )
self.num_active_loras = lora_ids.size(0)
# Round up num_active_loras to match cudagraph capture keys.
# This ensures the kernel grid dimension matches the captured graph.
if self.captured_lora_counts and self.num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
if idx < len(self.captured_lora_counts):
self.num_active_loras = self.captured_lora_counts[idx]
# lora_token_start_loc # lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_( self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_(
...@@ -125,7 +153,9 @@ class LoRAKernelMeta: ...@@ -125,7 +153,9 @@ class LoRAKernelMeta:
) )
def meta_args( def meta_args(
self, token_nums: int self,
token_nums: int,
specialize_active_lora: bool,
) -> tuple[ ) -> tuple[
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
...@@ -133,6 +163,7 @@ class LoRAKernelMeta: ...@@ -133,6 +163,7 @@ class LoRAKernelMeta:
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
int,
]: ]:
""" """
This function returns the kernel metadata required for the current This function returns the kernel metadata required for the current
...@@ -144,6 +175,7 @@ class LoRAKernelMeta: ...@@ -144,6 +175,7 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward token_nums (int): Number of input tokens in the current forward
pass of the kernel. pass of the kernel.
""" """
max_loras = self.active_lora_ids.size(0) - 1
return ( return (
self.token_lora_mapping[:token_nums], self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums], self.token_indices_sorted_by_lora_ids[:token_nums],
...@@ -151,4 +183,5 @@ class LoRAKernelMeta: ...@@ -151,4 +183,5 @@ class LoRAKernelMeta:
self.lora_token_start_loc, self.lora_token_start_loc,
self.active_lora_ids, self.active_lora_ids,
self.no_lora_flag_cpu, self.no_lora_flag_cpu,
self.num_active_loras if specialize_active_lora else max_loras + 1,
) )
...@@ -134,6 +134,7 @@ def _lora_shrink( ...@@ -134,6 +134,7 @@ def _lora_shrink(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1] lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1] no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs; used as the LoRA dimension of the launch grid
scaling: float, scaling: float,
) -> None: ) -> None:
""" """
...@@ -214,10 +215,7 @@ def _lora_shrink( ...@@ -214,10 +215,7 @@ def _lora_shrink(
grid = ( grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES, NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output num_active_loras,
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks exit early.
MAX_LORAS,
) )
# We disable PDL temporarily because LoRA kernels are not launching back-to-back, # We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance. # making PDL invalid and affecting the kernel performance.
...@@ -269,6 +267,7 @@ def _lora_shrink_fake( ...@@ -269,6 +267,7 @@ def _lora_shrink_fake(
lora_token_start_loc: torch.Tensor, lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor, no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
scaling: float, scaling: float,
) -> None: ) -> None:
return return
......
...@@ -12,6 +12,7 @@ from typing import final ...@@ -12,6 +12,7 @@ from typing import final
import torch import torch
from vllm.lora.layers import LoRAMapping from vllm.lora.layers import LoRAMapping
from vllm.lora.utils import get_captured_lora_counts
from vllm.triton_utils import HAS_TRITON, triton from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
...@@ -48,8 +49,16 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -48,8 +49,16 @@ class PunicaWrapperGPU(PunicaWrapperBase):
self.lora_config = kwargs["lora_config"] self.lora_config = kwargs["lora_config"]
self.max_loras = self.lora_config.max_loras self.max_loras = self.lora_config.max_loras
# Compute captured LoRA counts for cudagraph specialization.
captured_lora_counts = get_captured_lora_counts(
self.max_loras, self.lora_config.specialize_active_lora
)
self.token_mapping_meta = LoRAKernelMeta.make( self.token_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_num_batched_tokens, device=device self.max_loras,
max_num_batched_tokens,
device=device,
captured_lora_counts=captured_lora_counts,
) )
# When speculative decoding is enabled, max_num_samples is # When speculative decoding is enabled, max_num_samples is
...@@ -57,7 +66,10 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -57,7 +66,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
# This line can be optimized by replacing max_num_batched_tokens # This line can be optimized by replacing max_num_batched_tokens
# to max_batches * (num_speculative_decoding_tokens + 1). # to max_batches * (num_speculative_decoding_tokens + 1).
self.prompt_mapping_meta = LoRAKernelMeta.make( self.prompt_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_num_batched_tokens, device=device self.max_loras,
max_num_batched_tokens,
device=device,
captured_lora_counts=captured_lora_counts,
) )
def update_metadata( def update_metadata(
...@@ -102,7 +114,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -102,7 +114,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x, x,
lora_a_stacked, lora_a_stacked,
y, y,
*self.token_mapping_meta.meta_args(x.size(0)), *self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
scale, scale,
) )
...@@ -143,7 +157,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -143,7 +157,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x, x,
lora_b_stacked, lora_b_stacked,
y, y,
*self.token_mapping_meta.meta_args(num_tokens), *self.token_mapping_meta.meta_args(
num_tokens, self.lora_config.specialize_active_lora
),
offset_start=offset_start, offset_start=offset_start,
add_inputs=True, add_inputs=True,
) )
...@@ -175,7 +191,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -175,7 +191,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x.unsqueeze(dim=0), x.unsqueeze(dim=0),
(lora_b_stacked,), (lora_b_stacked,),
y, y,
*self.token_mapping_meta.meta_args(x.size(0)), *self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
offset_start=0, offset_start=0,
add_inputs=add_inputs, add_inputs=add_inputs,
) )
...@@ -287,7 +305,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -287,7 +305,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x, x,
[lora_a_stacked], [lora_a_stacked],
buffer.unsqueeze(dim=0), buffer.unsqueeze(dim=0),
*self.prompt_mapping_meta.meta_args(x.size(0)), *self.prompt_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
scale, scale,
) )
...@@ -295,7 +315,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -295,7 +315,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
buffer.unsqueeze(dim=0), buffer.unsqueeze(dim=0),
[lora_b_stacked], [lora_b_stacked],
y, y,
*self.prompt_mapping_meta.meta_args(buffer.size(0)), *self.prompt_mapping_meta.meta_args(
buffer.size(0), self.lora_config.specialize_active_lora
),
add_inputs=True, add_inputs=True,
) )
y = y.view_as(y_org) y = y.view_as(y_org)
...@@ -316,8 +338,10 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -316,8 +338,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
Aligns tokens and experts into block-sized chunks for LoRA-based Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution. mixture-of-experts (MoE) execution.
""" """
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args( (token_lora_mapping, _, _, _, lora_ids, _, _) = (
num_tokens self.token_mapping_meta.meta_args(
num_tokens, self.lora_config.specialize_active_lora
)
) )
if naive_block_assignment: if naive_block_assignment:
expert_ids = topk_ids.reshape(-1) expert_ids = topk_ids.reshape(-1)
...@@ -392,7 +416,10 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -392,7 +416,10 @@ class PunicaWrapperGPU(PunicaWrapperBase):
_, _,
lora_ids, lora_ids,
_, _,
) = self.token_mapping_meta.meta_args(x.size(0)) num_active_loras,
) = self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
)
if token_lora_mapping is None: if token_lora_mapping is None:
token_lora_mapping = token_lora_mapping_meta token_lora_mapping = token_lora_mapping_meta
fused_moe_lora( fused_moe_lora(
...@@ -408,6 +435,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -408,6 +435,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
max_lora_rank, max_lora_rank,
top_k_num, top_k_num,
lora_ids, lora_ids,
num_active_loras,
adapter_enabled, adapter_enabled,
shrink_config.get("BLOCK_SIZE_M", 64), shrink_config.get("BLOCK_SIZE_M", 64),
shrink_config.get("BLOCK_SIZE_N", 64), shrink_config.get("BLOCK_SIZE_N", 64),
......
...@@ -44,6 +44,25 @@ if TYPE_CHECKING: ...@@ -44,6 +44,25 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
    """
    Compute the num_active_loras values that cudagraphs are captured for.

    Without specialization the only value is ``max_loras + 1``. With
    specialization, every power of two in ``[1, max_loras + 1]`` is
    returned as well, in ascending order and without duplicates.

    Both CudagraphDispatcher and PunicaWrapperGPU call this helper, so the
    captured graph keys and the kernel grid sizes are derived from the
    same list.
    """
    upper_bound = max_loras + 1
    if specialize:
        counts = []
        for candidate in range(1, upper_bound + 1):
            # A positive integer is a power of two iff it has one bit set.
            is_power_of_two = candidate & (candidate - 1) == 0
            if is_power_of_two or candidate == upper_bound:
                counts.append(candidate)
        return counts
    return [upper_bound]
_GLOBAL_LORA_ID = 0 _GLOBAL_LORA_ID = 0
......
...@@ -5,6 +5,7 @@ from itertools import product ...@@ -5,6 +5,7 @@ from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
from vllm.forward_context import BatchDescriptor from vllm.forward_context import BatchDescriptor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.utils import get_captured_lora_counts
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -57,6 +58,11 @@ class CudagraphDispatcher: ...@@ -57,6 +58,11 @@ class CudagraphDispatcher:
) )
self.keys_initialized = False self.keys_initialized = False
self.specialize_lora_count = (
self.vllm_config.lora_config.specialize_active_lora
if self.vllm_config.lora_config is not None
else False
)
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE self.cudagraph_mode = CUDAGraphMode.NONE
...@@ -92,8 +98,33 @@ class CudagraphDispatcher: ...@@ -92,8 +98,33 @@ class CudagraphDispatcher:
"Use values from cudagraph_capture_sizes." "Use values from cudagraph_capture_sizes."
) )
def _get_lora_cases(self) -> list[int]:
    """
    Returns the num_active_loras values to capture CUDA graphs for.

    A value of 0 stands for the no-LoRA case; any positive value is a
    LoRA-active case with that many active LoRAs.
    """
    lora_config = self.vllm_config.lora_config
    # LoRA is not configured: the only case is "no LoRA".
    if lora_config is None:
        return [0]

    # Without cudagraph_specialize_lora, only the LoRA-active case is
    # captured, as a single entry of max_loras + 1.
    if not self.compilation_config.cudagraph_specialize_lora:
        return [lora_config.max_loras + 1]

    # With it, capture the no-LoRA case (0) followed by every count
    # reported by get_captured_lora_counts.
    lora_counts = get_captured_lora_counts(
        lora_config.max_loras, self.specialize_lora_count
    )
    return [0, *lora_counts]
def _create_padded_batch_descriptor( def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
num_active_loras: int = 0,
) -> BatchDescriptor: ) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len uniform_decode_query_len = self.uniform_decode_query_len
...@@ -111,6 +142,7 @@ class CudagraphDispatcher: ...@@ -111,6 +142,7 @@ class CudagraphDispatcher:
num_reqs=num_reqs, num_reqs=num_reqs,
uniform=uniform_decode, uniform=uniform_decode,
has_lora=has_lora, has_lora=has_lora,
num_active_loras=num_active_loras,
) )
def add_cudagraph_key( def add_cudagraph_key(
...@@ -135,26 +167,23 @@ class CudagraphDispatcher: ...@@ -135,26 +167,23 @@ class CudagraphDispatcher:
self._compute_bs_to_padded_graph_size() self._compute_bs_to_padded_graph_size()
# LoRA activation cases to specialize the cuda graphs on # Get LoRA cases to capture
if self.vllm_config.lora_config: lora_cases = self._get_lora_cases()
if self.compilation_config.cudagraph_specialize_lora: self.captured_lora_counts = [
lora_cases = [True, False] lora_count for lora_count in lora_cases if lora_count
else: ]
lora_cases = [True]
else:
lora_cases = [False]
# Note: we create all valid keys for cudagraph here but do not # Note: we create all valid keys for cudagraph here but do not
# guarantee all keys would be used. For example, if we allow lazy # guarantee all keys would be used. For example, if we allow lazy
# capturing in future PR, some keys may never be triggered. # capturing in future PR, some keys may never be triggered.
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
for bs, has_lora in product( for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases self.compilation_config.cudagraph_capture_sizes, lora_cases
): ):
self.add_cudagraph_key( self.add_cudagraph_key(
cudagraph_mode.mixed_mode(), cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor( self._create_padded_batch_descriptor(
bs, False, has_lora bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(), ).relax_for_mixed_batch_cudagraphs(),
) )
...@@ -173,10 +202,14 @@ class CudagraphDispatcher: ...@@ -173,10 +202,14 @@ class CudagraphDispatcher:
for x in self.compilation_config.cudagraph_capture_sizes for x in self.compilation_config.cudagraph_capture_sizes
if x <= max_num_tokens and x >= uniform_decode_query_len if x <= max_num_tokens and x >= uniform_decode_query_len
] ]
for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): for bs, num_active_loras in product(
cudagraph_capture_sizes_for_decode, lora_cases
):
self.add_cudagraph_key( self.add_cudagraph_key(
CUDAGraphMode.FULL, CUDAGraphMode.FULL,
self._create_padded_batch_descriptor(bs, True, has_lora), self._create_padded_batch_descriptor(
bs, True, num_active_loras > 0, num_active_loras
),
) )
self.keys_initialized = True self.keys_initialized = True
...@@ -187,6 +220,7 @@ class CudagraphDispatcher: ...@@ -187,6 +220,7 @@ class CudagraphDispatcher:
uniform_decode: bool = False, uniform_decode: bool = False,
has_lora: bool = False, has_lora: bool = False,
disable_full: bool = False, disable_full: bool = False,
num_active_loras: int = 0,
) -> tuple[CUDAGraphMode, BatchDescriptor]: ) -> tuple[CUDAGraphMode, BatchDescriptor]:
""" """
Given conditions(e.g.,batch descriptor and if using piecewise only), Given conditions(e.g.,batch descriptor and if using piecewise only),
...@@ -202,6 +236,7 @@ class CudagraphDispatcher: ...@@ -202,6 +236,7 @@ class CudagraphDispatcher:
disable_full: If True, skip FULL cudagraph checks and disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs) cascade attention that are not supported by full cudagraphs)
num_active_loras: Number of distinct active LoRA adapters.
""" """
if ( if (
not self.keys_initialized not self.keys_initialized
...@@ -210,8 +245,24 @@ class CudagraphDispatcher: ...@@ -210,8 +245,24 @@ class CudagraphDispatcher:
): ):
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
effective_num_active_loras = num_active_loras
if has_lora and num_active_loras > 0:
if self.specialize_lora_count:
# Find the smallest captured `num_active_loras` that is >= the current
# `num_active_loras`. This is because we only capture graphs for
# a subset of possible `num_active_loras` values (powers of 2).
import bisect
idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
if idx < len(self.captured_lora_counts):
effective_num_active_loras = self.captured_lora_counts[idx]
else:
# When not specializing, graphs are captured only with max_loras + 1,
# so we must use max_loras + 1 for dispatch to find a matching graph.
effective_num_active_loras = self.vllm_config.lora_config.max_loras + 1
batch_desc = self._create_padded_batch_descriptor( batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora num_tokens, uniform_decode, has_lora, effective_num_active_loras
) )
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
......
...@@ -3082,6 +3082,7 @@ class GPUModelRunner( ...@@ -3082,6 +3082,7 @@ class GPUModelRunner(
# be improved in model runner v2) # be improved in model runner v2)
force_uniform_decode: bool | None = None, force_uniform_decode: bool | None = None,
force_has_lora: bool | None = None, force_has_lora: bool | None = None,
force_num_active_loras: int | None = None,
num_encoder_reqs: int = 0, num_encoder_reqs: int = 0,
) -> tuple[ ) -> tuple[
CUDAGraphMode, CUDAGraphMode,
...@@ -3103,11 +3104,13 @@ class GPUModelRunner( ...@@ -3103,11 +3104,13 @@ class GPUModelRunner(
self.model_config.is_encoder_decoder and num_encoder_reqs > 0 self.model_config.is_encoder_decoder and num_encoder_reqs > 0
) )
has_lora = ( # Compute LoRA state for cudagraph dispatch
len(self.input_batch.lora_id_to_lora_request) > 0 num_active_loras = (
if force_has_lora is None force_num_active_loras
else force_has_lora if force_num_active_loras is not None
else len(self.input_batch.lora_id_to_lora_request)
) )
has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora
num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens)
dispatch_cudagraph = ( dispatch_cudagraph = (
...@@ -3116,6 +3119,7 @@ class GPUModelRunner( ...@@ -3116,6 +3119,7 @@ class GPUModelRunner(
has_lora=has_lora, has_lora=has_lora,
uniform_decode=uniform_decode, uniform_decode=uniform_decode,
disable_full=disable_full, disable_full=disable_full,
num_active_loras=num_active_loras,
) )
if not force_eager if not force_eager
else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded)) else (CUDAGraphMode.NONE, BatchDescriptor(num_tokens_padded))
...@@ -4606,8 +4610,8 @@ class GPUModelRunner( ...@@ -4606,8 +4610,8 @@ class GPUModelRunner(
is_profile: bool = False, is_profile: bool = False,
create_mixed_batch: bool = False, create_mixed_batch: bool = False,
remove_lora: bool = True, remove_lora: bool = True,
activate_lora: bool = False,
is_graph_capturing: bool = False, is_graph_capturing: bool = False,
num_active_loras: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Run a dummy forward pass to warm up/profile run or capture the Run a dummy forward pass to warm up/profile run or capture the
...@@ -4630,7 +4634,8 @@ class GPUModelRunner( ...@@ -4630,7 +4634,8 @@ class GPUModelRunner(
create_mixed_batch: If True, create a mixed batch with both decode create_mixed_batch: If True, create a mixed batch with both decode
(1 token) and prefill (multiple tokens) requests. (1 token) and prefill (multiple tokens) requests.
remove_lora: If False, dummy LoRAs are not destroyed after the run remove_lora: If False, dummy LoRAs are not destroyed after the run
activate_lora: If False, dummy_run is performed without LoRAs. num_active_loras: Number of distinct active LoRAs to capture for.
LoRA is activated when num_active_loras > 0.
""" """
mm_config = self.vllm_config.model_config.multimodal_config mm_config = self.vllm_config.model_config.multimodal_config
if mm_config and mm_config.mm_encoder_only: if mm_config and mm_config.mm_encoder_only:
...@@ -4712,7 +4717,10 @@ class GPUModelRunner( ...@@ -4712,7 +4717,10 @@ class GPUModelRunner(
# `force_has_lora` is used for cudagraph capture; because LoRA is # `force_has_lora` is used for cudagraph capture; because LoRA is
# activated later in the context manager, but we need to know the # activated later in the context manager, but we need to know the
# LoRA state when determining the batch descriptor for capture # LoRA state when determining the batch descriptor for capture
force_has_lora=activate_lora, force_has_lora=num_active_loras > 0,
# `force_num_active_loras` is used for cudagraph capture; because we
# need to capture graphs for specific num_active_loras counts
force_num_active_loras=num_active_loras,
) )
) )
...@@ -4782,8 +4790,8 @@ class GPUModelRunner( ...@@ -4782,8 +4790,8 @@ class GPUModelRunner(
self.lora_config, self.lora_config,
num_scheduled_tokens, num_scheduled_tokens,
num_sampled_tokens, num_sampled_tokens,
activate_lora,
remove_lora, remove_lora,
num_active_loras,
): ):
# Make sure padding doesn't exceed max_num_tokens # Make sure padding doesn't exceed max_num_tokens
assert num_tokens_padded <= self.max_num_tokens assert num_tokens_padded <= self.max_num_tokens
...@@ -4884,7 +4892,10 @@ class GPUModelRunner( ...@@ -4884,7 +4892,10 @@ class GPUModelRunner(
# lora cases when cudagraph_specialize_lora is enabled. This is a # lora cases when cudagraph_specialize_lora is enabled. This is a
# short term mitigation for issue mentioned in # short term mitigation for issue mentioned in
# https://github.com/vllm-project/vllm/issues/28334 # https://github.com/vllm-project/vllm/issues/28334
if self.compilation_config.cudagraph_specialize_lora and activate_lora: if (
self.compilation_config.cudagraph_specialize_lora
and num_active_loras > 0
):
use_cudagraphs = False use_cudagraphs = False
self.drafter.dummy_run( self.drafter.dummy_run(
...@@ -5259,7 +5270,7 @@ class GPUModelRunner( ...@@ -5259,7 +5270,7 @@ class GPUModelRunner(
# We skip EPLB here since we don't want to record dummy metrics # We skip EPLB here since we don't want to record dummy metrics
for batch_desc in batch_descriptors: for batch_desc in batch_descriptors:
num_tokens = batch_desc.num_tokens num_tokens = batch_desc.num_tokens
activate_lora = batch_desc.has_lora num_active_loras = batch_desc.num_active_loras
# We currently only capture ubatched graphs when its a FULL # We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens # cudagraph, a uniform decode batch, and the number of tokens
...@@ -5286,7 +5297,7 @@ class GPUModelRunner( ...@@ -5286,7 +5297,7 @@ class GPUModelRunner(
num_tokens, num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
allow_microbatching=allow_microbatching, allow_microbatching=allow_microbatching,
activate_lora=activate_lora, num_active_loras=num_active_loras,
) )
# Capture run # Capture run
...@@ -5294,7 +5305,7 @@ class GPUModelRunner( ...@@ -5294,7 +5305,7 @@ class GPUModelRunner(
num_tokens, num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
allow_microbatching=allow_microbatching, allow_microbatching=allow_microbatching,
activate_lora=activate_lora, num_active_loras=num_active_loras,
is_graph_capturing=True, is_graph_capturing=True,
) )
self.maybe_remove_all_loras(self.lora_config) self.maybe_remove_all_loras(self.lora_config)
......
...@@ -133,11 +133,23 @@ class LoRAModelRunnerMixin: ...@@ -133,11 +133,23 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens: np.ndarray, num_scheduled_tokens: np.ndarray,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE, mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
num_sampled_tokens: np.ndarray | None = None, num_sampled_tokens: np.ndarray | None = None,
activate_lora: bool = True, num_active_loras: int = 0,
): ):
"""
Context manager to select dummy LoRAs for capture/warmup.
Args:
lora_config: LoRA configuration, or None if LoRA is disabled.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
num_active_loras: Number of distinct active LoRAs to use.
- 0: No LoRA active (set up zero mappings).
- >0: Use exactly this many distinct LoRAs.
"""
if num_sampled_tokens is None: if num_sampled_tokens is None:
num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32) num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)
# Skip LoRA setup entirely only if no LoRA config
if lora_config is None: if lora_config is None:
yield yield
else: else:
...@@ -145,15 +157,52 @@ class LoRAModelRunnerMixin: ...@@ -145,15 +157,52 @@ class LoRAModelRunnerMixin:
assert self.lora_manager is not None, "LoRA is not enabled" assert self.lora_manager is not None, "LoRA is not enabled"
num_reqs = len(num_scheduled_tokens) num_reqs = len(num_scheduled_tokens)
num_loras = lora_config.max_loras max_loras = lora_config.max_loras
# Determine how many distinct LoRAs to use and whether to include
# no-LoRA tokens (-1 entries).
# When num_active_loras > max_loras (e.g., max_loras + 1), we need
# to include -1 entries to simulate batches with both LoRA and
# no-LoRA tokens. This ensures prepare_tensors computes the correct
# num_active_loras that matches the cudagraph capture key.
if num_active_loras == 0:
# No LoRA active - use 0 mappings like the original code
effective_num_loras = 0
include_no_lora = False
elif num_active_loras > max_loras:
# num_active_loras > max_loras means we want max_loras adapters
# PLUS no-LoRA tokens (-1). This is the max_loras + 1 case.
effective_num_loras = max_loras
include_no_lora = True
else:
# Specific number of active LoRAs requested
effective_num_loras = min(num_active_loras, max_loras)
include_no_lora = False
# Make prompt lora mapping # Make prompt lora mapping
# Assign LoRA IDs cyclically to simulate a worst-case scenario. # Assign LoRA IDs cyclically to simulate a worst-case scenario.
if activate_lora: # LoRA IDs are 1-indexed (1 to max_loras) as required by LoRARequest.
prompt_lora_mapping = ( # convert_mapping() will convert these to 0-indexed slot indices.
np.arange(num_reqs, dtype=np.int32) % num_loras if effective_num_loras > 0:
) + 1 if include_no_lora:
# Include -1 (no-LoRA) entries by cycling through
# -1, 1, 2, ..., effective_num_loras.
# This ensures prepare_tensors sees both LoRA and no-LoRA
# tokens, computing num_active_loras = effective_num_loras+1.
# The leading -1 is what distinguishes this branch from the LoRA-only
# branch: without it the cycle is just 1..effective_num_loras and no
# request is ever mapped to the no-LoRA case.
cycle_values = np.array(
    [-1, *range(1, effective_num_loras + 1)],
    dtype=np.int32,
)
prompt_lora_mapping = cycle_values[
np.arange(num_reqs, dtype=np.int32) % len(cycle_values)
]
else:
# Use 1 to effective_num_loras (1-indexed lora IDs)
prompt_lora_mapping = (
np.arange(num_reqs, dtype=np.int32) % effective_num_loras
) + 1
else: else:
# No LoRA active - use 0 for all tokens (original behavior)
prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32) prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)
# Make sample lora mapping # Make sample lora mapping
...@@ -162,14 +211,14 @@ class LoRAModelRunnerMixin: ...@@ -162,14 +211,14 @@ class LoRAModelRunnerMixin:
# Make token lora mapping # Make token lora mapping
token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)
# Make dummy lora requests # Make dummy lora requests (only for the active LoRAs)
lora_requests: set[LoRARequest] = { lora_requests: set[LoRARequest] = {
LoRARequest( LoRARequest(
lora_name=f"warmup_{lora_id}", lora_name=f"warmup_{lora_id}",
lora_int_id=lora_id, lora_int_id=lora_id,
lora_path="/not/a/real/path", lora_path="/not/a/real/path",
) )
for lora_id in range(1, num_loras + 1) for lora_id in range(1, effective_num_loras + 1)
} }
self._set_active_loras( self._set_active_loras(
...@@ -187,10 +236,21 @@ class LoRAModelRunnerMixin: ...@@ -187,10 +236,21 @@ class LoRAModelRunnerMixin:
lora_config: LoRAConfig | None, lora_config: LoRAConfig | None,
num_scheduled_tokens: np.ndarray, num_scheduled_tokens: np.ndarray,
num_sampled_tokens: np.ndarray, num_sampled_tokens: np.ndarray,
activate_lora: bool = True,
remove_lora: bool = True, remove_lora: bool = True,
num_active_loras: int = 0,
mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE, mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
): ):
"""
Context manager for dummy runs with LoRA.
Args:
lora_config: LoRA configuration.
num_scheduled_tokens: Array of scheduled token counts per request.
num_sampled_tokens: Array of sampled token counts per request.
remove_lora: Whether to remove LoRAs after the context exits.
num_active_loras: Number of distinct active LoRAs to use.
LoRA is activated when num_active_loras > 0.
"""
with ( with (
self.maybe_setup_dummy_loras(lora_config, remove_lora), self.maybe_setup_dummy_loras(lora_config, remove_lora),
self.maybe_select_dummy_loras( self.maybe_select_dummy_loras(
...@@ -198,7 +258,7 @@ class LoRAModelRunnerMixin: ...@@ -198,7 +258,7 @@ class LoRAModelRunnerMixin:
num_scheduled_tokens, num_scheduled_tokens,
mapping_type, mapping_type,
num_sampled_tokens, num_sampled_tokens,
activate_lora, num_active_loras,
), ),
): ):
yield yield
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment