Unverified commit f088a831, authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Use unpadded num_tokens for PW CUDA graph attn metadata (#36626)


Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
parent f83b933b
...@@ -384,6 +384,7 @@ def prepare_inputs_to_capture( ...@@ -384,6 +384,7 @@ def prepare_inputs_to_capture(
attn_metadata = model_state.prepare_attn( attn_metadata = model_state.prepare_attn(
input_batch, input_batch,
CUDAGraphMode.NONE,
input_block_tables, input_block_tables,
slot_mappings, slot_mappings,
attn_groups, attn_groups,
......
...@@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -936,6 +936,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert block_tables is not None assert block_tables is not None
attn_metadata = self.model_state.prepare_attn( attn_metadata = self.model_state.prepare_attn(
input_batch, input_batch,
batch_desc.cg_mode,
block_tables, block_tables,
slot_mappings, slot_mappings,
self.attn_groups, self.attn_groups,
......
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.core.sched.output import NewRequestData from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
...@@ -140,14 +141,20 @@ class DefaultModelState(ModelState): ...@@ -140,14 +141,20 @@ class DefaultModelState(ModelState):
def prepare_attn( def prepare_attn(
self, self,
input_batch: InputBatch, input_batch: InputBatch,
cudagraph_mode: CUDAGraphMode,
block_tables: tuple[torch.Tensor, ...], block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor, slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> dict[str, Any]: ) -> dict[str, Any]:
# Use padded sizes - padding is handled by model_runner.prepare_attn. if cudagraph_mode == CUDAGraphMode.FULL:
num_reqs = input_batch.num_reqs_after_padding # Use padded sizes - padding is handled by model_runner.prepare_attn.
num_tokens = input_batch.num_tokens_after_padding num_reqs = input_batch.num_reqs_after_padding
num_tokens = input_batch.num_tokens_after_padding
else:
# For piecewise cudagraphs and eager, use unpadded sizes.
num_reqs = input_batch.num_reqs
num_tokens = input_batch.num_tokens
query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np) query_start_loc_cpu = torch.from_numpy(input_batch.query_start_loc_np)
max_query_len = input_batch.num_scheduled_tokens.max().item() max_query_len = input_batch.num_scheduled_tokens.max().item()
attn_metadata = build_attn_metadata( attn_metadata = build_attn_metadata(
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.v1.core.sched.output import NewRequestData from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.input_batch import InputBatch
...@@ -59,6 +60,7 @@ class ModelState(ABC): ...@@ -59,6 +60,7 @@ class ModelState(ABC):
def prepare_attn( def prepare_attn(
self, self,
input_batch: InputBatch, input_batch: InputBatch,
cudagraph_mode: CUDAGraphMode,
block_tables: tuple[torch.Tensor, ...], block_tables: tuple[torch.Tensor, ...],
slot_mappings: torch.Tensor, slot_mappings: torch.Tensor,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment