"vscode:/vscode.git/clone" did not exist on "bb59c902480ddb054e7f3f0762b386e0d4e269bd"
Unverified Commit ada4f4fa authored by Runkai Tao's avatar Runkai Tao Committed by GitHub
Browse files

[Fix Bug]`num_active_loras` always equals to zero (#34119)


Signed-off-by: default avatarRunkai Tao <rt572@physics.rutgers.edu>
Co-authored-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent 7e9149d9
...@@ -187,7 +187,8 @@ def use_fused_moe_lora_kernel( ...@@ -187,7 +187,8 @@ def use_fused_moe_lora_kernel(
# num_active_loras is the number of active LoRAs # num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case) # (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1 # Stored as CPU tensor to match the kernel API (torch.compile compatibility)
num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu")
fused_moe_lora( fused_moe_lora(
output, output,
...@@ -399,7 +400,8 @@ def use_fused_moe_lora_kernel_naive( ...@@ -399,7 +400,8 @@ def use_fused_moe_lora_kernel_naive(
# num_active_loras is the number of active LoRAs # num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case) # (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1 # Stored as CPU tensor to match the kernel API (torch.compile compatibility)
num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu")
fused_moe_lora( fused_moe_lora(
output, output,
......
...@@ -70,8 +70,12 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None: ...@@ -70,8 +70,12 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
@pytest.mark.parametrize("mxfp4_use_marlin", [True, False]) @pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
@pytest.mark.parametrize("specialize_active_lora", [True, False])
def test_gpt_oss_lora( def test_gpt_oss_lora(
monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin monkeypatch: pytest.MonkeyPatch,
gptoss20b_lora_files,
mxfp4_use_marlin,
specialize_active_lora,
): ):
with monkeypatch.context() as m: with monkeypatch.context() as m:
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0") m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
...@@ -83,6 +87,7 @@ def test_gpt_oss_lora( ...@@ -83,6 +87,7 @@ def test_gpt_oss_lora(
max_lora_rank=8, max_lora_rank=8,
max_num_seqs=2, max_num_seqs=2,
max_num_batched_tokens=2048, max_num_batched_tokens=2048,
specialize_active_lora=specialize_active_lora,
compilation_config=vllm.config.CompilationConfig( # Avoid OOM compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False, cudagraph_specialize_lora=False,
), ),
......
...@@ -127,7 +127,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): ...@@ -127,7 +127,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
def _adjust_kernel_inputs( def _adjust_kernel_inputs(
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
sorted_token_ids: torch.Tensor | None, sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
): ):
...@@ -141,7 +141,7 @@ def _adjust_kernel_inputs( ...@@ -141,7 +141,7 @@ def _adjust_kernel_inputs(
else: else:
stride_tl = sorted_token_ids.stride(0) stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0) stride_el = expert_ids.stride(0)
grid_lora_dim = num_active_loras grid_lora_dim = num_active_loras.item()
return grid_lora_dim, stride_tl, stride_el return grid_lora_dim, stride_tl, stride_el
...@@ -444,7 +444,7 @@ def _fused_moe_lora_shrink( ...@@ -444,7 +444,7 @@ def _fused_moe_lora_shrink(
num_warps: int, num_warps: int,
num_stages: int, num_stages: int,
split_k: int, split_k: int,
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
use_gdc: bool = False, use_gdc: bool = False,
use_tma: bool = False, use_tma: bool = False,
...@@ -562,7 +562,7 @@ def _fused_moe_lora_expand( ...@@ -562,7 +562,7 @@ def _fused_moe_lora_expand(
num_warps: int, num_warps: int,
num_stages: int, num_stages: int,
split_k: int, split_k: int,
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
offset: int = 0, offset: int = 0,
use_gdc: bool = False, use_gdc: bool = False,
...@@ -683,7 +683,7 @@ def _fused_moe_lora( ...@@ -683,7 +683,7 @@ def _fused_moe_lora(
max_lora_rank: int, max_lora_rank: int,
top_k_num: int, top_k_num: int,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
adapter_enabled: torch.Tensor, adapter_enabled: torch.Tensor,
shrink_block_size_m: int, shrink_block_size_m: int,
shrink_block_size_n: int, shrink_block_size_n: int,
...@@ -871,7 +871,7 @@ def _fused_moe_lora_fake( ...@@ -871,7 +871,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int, max_lora_rank: int,
top_k_num: int, top_k_num: int,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
adapter_enabled: torch.Tensor, adapter_enabled: torch.Tensor,
shrink_block_size_m: int, shrink_block_size_m: int,
shrink_block_size_n: int, shrink_block_size_n: int,
...@@ -921,7 +921,7 @@ def _fused_moe_lora_shrink_fake( ...@@ -921,7 +921,7 @@ def _fused_moe_lora_shrink_fake(
num_warps: int, num_warps: int,
num_stages: int, num_stages: int,
split_k: int, split_k: int,
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
use_gdc: bool = False, use_gdc: bool = False,
use_tma: bool = False, use_tma: bool = False,
...@@ -958,7 +958,7 @@ def _fused_moe_lora_expand_fake( ...@@ -958,7 +958,7 @@ def _fused_moe_lora_expand_fake(
num_warps: int, num_warps: int,
num_stages: int, num_stages: int,
split_k: int, split_k: int,
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False, mul_routed_weight: bool = False,
offset: int = 0, offset: int = 0,
use_gdc: bool = False, use_gdc: bool = False,
......
...@@ -138,7 +138,7 @@ def _lora_expand( ...@@ -138,7 +138,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1] lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1] no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat) num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
offset_start: int = 0, offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
...@@ -235,7 +235,7 @@ def _lora_expand( ...@@ -235,7 +235,7 @@ def _lora_expand(
grid = ( grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES, NUM_SLICES,
num_active_loras, num_active_loras.item(),
) )
# We disable PDL temporarily because LoRA kernels are not launching back-to-back, # We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance. # making PDL invalid and affecting the kernel performance.
...@@ -289,7 +289,7 @@ def _lora_expand_fake( ...@@ -289,7 +289,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor, lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor, no_lora_flag_cpu: torch.Tensor,
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
offset_start: int = 0, offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
......
...@@ -29,9 +29,16 @@ class LoRAKernelMeta: ...@@ -29,9 +29,16 @@ class LoRAKernelMeta:
# to early exit from inside the lora_expand / lora_shrink torch operation. # to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor no_lora_flag_cpu: torch.Tensor
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping) # Number of active LoRAs (unique non-(-1) values in token_lora_mapping).
# Stored as a Python int to avoid GPU->CPU sync during forward pass # Stored as a CPU tensor (not a Python int) so that torch.compile treats
num_active_loras: int = 0 # it as a dynamic value rather than baking it as a constant at trace time.
# This follows the same pattern as no_lora_flag_cpu above.
num_active_loras_cpu: torch.Tensor
# Default num_active_loras value (max_loras + 1) as a CPU tensor,
# used when specialize_active_lora is False to avoid allocating a
# new tensor on every meta_args() call.
default_num_active_loras_cpu: torch.Tensor
# Captured LoRA counts for cudagraph specialization (sorted list). # Captured LoRA counts for cudagraph specialization (sorted list).
# When specialize_active_lora is enabled, num_active_loras is rounded up # When specialize_active_lora is enabled, num_active_loras is rounded up
...@@ -73,6 +80,11 @@ class LoRAKernelMeta: ...@@ -73,6 +80,11 @@ class LoRAKernelMeta:
no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu") no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu")
default_num_active_loras_cpu = torch.tensor(
[max_loras + 1], dtype=torch.int32, device="cpu"
)
return LoRAKernelMeta( return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping, token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
...@@ -80,6 +92,8 @@ class LoRAKernelMeta: ...@@ -80,6 +92,8 @@ class LoRAKernelMeta:
num_tokens_per_lora=num_tokens_per_lora, num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc, lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu, no_lora_flag_cpu=no_lora_flag_cpu,
num_active_loras_cpu=num_active_loras_cpu,
default_num_active_loras_cpu=default_num_active_loras_cpu,
captured_lora_counts=sorted(captured_lora_counts) captured_lora_counts=sorted(captured_lora_counts)
if captured_lora_counts if captured_lora_counts
else [], else [],
...@@ -90,8 +104,7 @@ class LoRAKernelMeta: ...@@ -90,8 +104,7 @@ class LoRAKernelMeta:
self.num_tokens_per_lora.fill_(0) self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0) self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False) self.no_lora_flag_cpu.fill_(False)
self.num_active_loras = 0 self.num_active_loras_cpu.fill_(0)
self.captured_lora_counts = []
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
""" """
...@@ -137,14 +150,16 @@ class LoRAKernelMeta: ...@@ -137,14 +150,16 @@ class LoRAKernelMeta:
num_tokens_per_lora, non_blocking=True num_tokens_per_lora, non_blocking=True
) )
self.num_active_loras = lora_ids.size(0) num_active_loras = lora_ids.size(0)
# Round up num_active_loras to match cudagraph capture keys. # Round up num_active_loras to match cudagraph capture keys.
# This ensures the kernel grid dimension matches the captured graph. # This ensures the kernel grid dimension matches the captured graph.
if self.captured_lora_counts and self.num_active_loras > 0: if self.captured_lora_counts and num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras) idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
if idx < len(self.captured_lora_counts): if idx < len(self.captured_lora_counts):
self.num_active_loras = self.captured_lora_counts[idx] num_active_loras = self.captured_lora_counts[idx]
self.num_active_loras_cpu[0] = num_active_loras
# lora_token_start_loc # lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
...@@ -163,7 +178,7 @@ class LoRAKernelMeta: ...@@ -163,7 +178,7 @@ class LoRAKernelMeta:
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
torch.Tensor, torch.Tensor,
int, torch.Tensor,
]: ]:
""" """
This function returns the kernel metadata required for the current This function returns the kernel metadata required for the current
...@@ -175,7 +190,10 @@ class LoRAKernelMeta: ...@@ -175,7 +190,10 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward token_nums (int): Number of input tokens in the current forward
pass of the kernel. pass of the kernel.
""" """
max_loras = self.active_lora_ids.size(0) - 1 if specialize_active_lora:
num_active_loras = self.num_active_loras_cpu
else:
num_active_loras = self.default_num_active_loras_cpu
return ( return (
self.token_lora_mapping[:token_nums], self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums], self.token_indices_sorted_by_lora_ids[:token_nums],
...@@ -183,5 +201,5 @@ class LoRAKernelMeta: ...@@ -183,5 +201,5 @@ class LoRAKernelMeta:
self.lora_token_start_loc, self.lora_token_start_loc,
self.active_lora_ids, self.active_lora_ids,
self.no_lora_flag_cpu, self.no_lora_flag_cpu,
self.num_active_loras if specialize_active_lora else max_loras + 1, num_active_loras,
) )
...@@ -134,7 +134,7 @@ def _lora_shrink( ...@@ -134,7 +134,7 @@ def _lora_shrink(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1] lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1] no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat) num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
scaling: float, scaling: float,
) -> None: ) -> None:
""" """
...@@ -157,6 +157,9 @@ def _lora_shrink( ...@@ -157,6 +157,9 @@ def _lora_shrink(
lora_ids (torch.Tensor): LoRA ids to process. lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA. if there are any requests that require LoRA.
num_active_loras (torch.Tensor): A CPU tensor of size 1, containing the
number of active LoRAs. Stored as a tensor (not int) so
torch.compile treats it as dynamic rather than a constant.
scaling (float): Scaling factor. scaling (float): Scaling factor.
""" """
...@@ -215,7 +218,7 @@ def _lora_shrink( ...@@ -215,7 +218,7 @@ def _lora_shrink(
grid = ( grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES, NUM_SLICES,
num_active_loras, num_active_loras.item(),
) )
# We disable PDL temporarily because LoRA kernels are not launching back-to-back, # We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance. # making PDL invalid and affecting the kernel performance.
...@@ -267,7 +270,7 @@ def _lora_shrink_fake( ...@@ -267,7 +270,7 @@ def _lora_shrink_fake(
lora_token_start_loc: torch.Tensor, lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor, no_lora_flag_cpu: torch.Tensor,
num_active_loras: int, num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
scaling: float, scaling: float,
) -> None: ) -> None:
return return
......
...@@ -5379,6 +5379,7 @@ class GPUModelRunner( ...@@ -5379,6 +5379,7 @@ class GPUModelRunner(
# if we want to warm up attention or not. This is # if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture # different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention. # attention while `PIECEWISE` implies no attention.
dummy_run( dummy_run(
num_tokens, num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment