Unverified Commit a0a5178a authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Use ModelState.prepare_attn() for cuda graph capture [5/N] (#35774)


Signed-off-by: default avatarWoosuk Kwon <woosuk@inferact.ai>
parent 8ea8ba27
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
...@@ -15,13 +14,11 @@ from vllm.forward_context import BatchDescriptor, set_forward_context ...@@ -15,13 +14,11 @@ from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.model_executor.offloader.base import get_offloader from vllm.model_executor.offloader.base import get_offloader
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import build_slot_mappings_by_layer
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.cp_utils import prepare_dcp_local_seq_lens
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
...@@ -123,14 +120,11 @@ class CudaGraphManager: ...@@ -123,14 +120,11 @@ class CudaGraphManager:
attn_metadata, slot_mappings = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
num_tokens, num_tokens,
model_state,
input_buffers, input_buffers,
block_tables, block_tables,
attn_groups, attn_groups,
self.max_model_len,
kv_cache_config, kv_cache_config,
uniform_decode_query_len=(
self.uniform_decode_query_len if uniform_decode else 0
),
) )
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
...@@ -393,51 +387,36 @@ def capture_graphs( ...@@ -393,51 +387,36 @@ def capture_graphs(
def prepare_inputs_to_capture( def prepare_inputs_to_capture(
num_reqs: int, num_reqs: int,
num_tokens: int, num_tokens: int,
model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
max_model_len: int,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
uniform_decode_query_len: int = 0,
) -> tuple[dict[str, Any], dict[str, torch.Tensor]]: ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
if uniform_decode_query_len > 0: input_batch = InputBatch.make_dummy(num_reqs, num_tokens, input_buffers)
num_tokens_per_req = uniform_decode_query_len
else:
num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
query_start_loc_np[-1] = num_tokens
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu
input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens
query_start_loc = input_buffers.query_start_loc[: num_reqs + 1]
# HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens
# rather than max_model_len.
input_buffers.seq_lens[:num_reqs] = num_tokens
input_buffers.seq_lens[num_reqs:] = 0
input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens
input_buffers.dcp_local_seq_lens[num_reqs:] = 0
input_block_tables = block_tables.get_dummy_block_tables(num_reqs) input_block_tables = block_tables.get_dummy_block_tables(num_reqs)
slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens) slot_mappings = block_tables.get_dummy_slot_mappings(num_tokens)
slot_mappings_by_layer = build_slot_mappings_by_layer( slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config slot_mappings, kv_cache_config
) )
attn_metadata = build_attn_metadata( # HACK(woosuk): Special handling for DCP.
attn_groups=attn_groups, if block_tables.cp_size > 1:
num_reqs=num_reqs, prepare_dcp_local_seq_lens(
num_tokens=num_tokens, input_buffers.dcp_local_seq_lens,
query_start_loc_gpu=query_start_loc, input_batch.seq_lens,
query_start_loc_cpu=query_start_loc_cpu, num_reqs,
max_query_len=num_tokens_per_req, block_tables.cp_size,
seq_lens=input_buffers.seq_lens, block_tables.cp_rank,
max_seq_len=max_model_len, block_tables.cp_interleave,
block_tables=input_block_tables, )
slot_mappings=slot_mappings, input_batch.dcp_local_seq_lens = input_buffers.dcp_local_seq_lens[:num_reqs]
kv_cache_config=kv_cache_config,
dcp_local_seq_lens=input_buffers.dcp_local_seq_lens, attn_metadata = model_state.prepare_attn(
input_batch,
input_block_tables,
slot_mappings,
attn_groups,
kv_cache_config,
) )
return attn_metadata, slot_mappings_by_layer return attn_metadata, slot_mappings_by_layer
...@@ -82,14 +82,16 @@ class InputBatch: ...@@ -82,14 +82,16 @@ class InputBatch:
num_reqs: int, num_reqs: int,
num_tokens: int, num_tokens: int,
input_buffers: InputBuffers, input_buffers: InputBuffers,
device: torch.device,
) -> "InputBatch": ) -> "InputBatch":
assert 0 < num_reqs <= num_tokens assert 0 < num_reqs <= num_tokens
device = input_buffers.device
req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)] req_ids = [f"req_{i}_{random_uuid()}" for i in range(num_reqs)]
idx_mapping_np = np.arange(num_reqs, dtype=np.int32) idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
expanded_idx_mapping = idx_mapping expanded_idx_mapping = idx_mapping
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device) expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32) num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens assert int(num_scheduled_tokens.sum()) == num_tokens
...@@ -115,7 +117,6 @@ class InputBatch: ...@@ -115,7 +117,6 @@ class InputBatch:
input_ids = input_buffers.input_ids[:num_tokens].zero_() input_ids = input_buffers.input_ids[:num_tokens].zero_()
positions = input_buffers.positions[:num_tokens].zero_() positions = input_buffers.positions[:num_tokens].zero_()
# attn_metadata = defaultdict(lambda: None)
logits_indices = query_start_loc[1:] - 1 logits_indices = query_start_loc[1:] - 1
cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32) cu_num_logits = torch.arange(num_reqs + 1, device=device, dtype=torch.int32)
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32) cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
......
...@@ -311,6 +311,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -311,6 +311,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None: if self.speculator is not None:
# HACK(woosuk) # HACK(woosuk)
self.speculator.set_attn( self.speculator.set_attn(
self.model_state,
self.kv_cache_config, self.kv_cache_config,
self.attn_groups, self.attn_groups,
self.block_tables, self.block_tables,
...@@ -880,10 +881,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -880,10 +881,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# No actual tokens to run. A dummy run for DP or memory profiling. # No actual tokens to run. A dummy run for DP or memory profiling.
num_reqs = min(num_tokens_after_padding, self.max_num_reqs) num_reqs = min(num_tokens_after_padding, self.max_num_reqs)
input_batch = InputBatch.make_dummy( input_batch = InputBatch.make_dummy(
num_reqs=num_reqs, num_reqs, num_tokens_after_padding, self.input_buffers
num_tokens=num_tokens_after_padding,
input_buffers=self.input_buffers,
device=self.device,
) )
if not skip_attn_for_dummy_run: if not skip_attn_for_dummy_run:
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch) block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
......
...@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import ( ...@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.cudagraph_utils import (
) )
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.input_batch import InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.utils import AttentionGroup from vllm.v1.worker.utils import AttentionGroup
...@@ -59,6 +60,7 @@ class EagleCudaGraphManager: ...@@ -59,6 +60,7 @@ class EagleCudaGraphManager:
num_tokens: int, num_tokens: int,
capture_cg_mode: CUDAGraphMode, capture_cg_mode: CUDAGraphMode,
generate_fn: Callable, generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
...@@ -76,12 +78,11 @@ class EagleCudaGraphManager: ...@@ -76,12 +78,11 @@ class EagleCudaGraphManager:
attn_metadata, slot_mappings = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
num_tokens, num_tokens,
model_state,
input_buffers, input_buffers,
block_tables, block_tables,
attn_groups, attn_groups,
self.max_model_len,
kv_cache_config, kv_cache_config,
uniform_decode_query_len=1,
) )
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
...@@ -158,6 +159,7 @@ class EagleCudaGraphManager: ...@@ -158,6 +159,7 @@ class EagleCudaGraphManager:
def capture( def capture(
self, self,
generate_fn: Callable, generate_fn: Callable,
model_state: ModelState,
input_buffers: InputBuffers, input_buffers: InputBuffers,
block_tables: BlockTables, block_tables: BlockTables,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
...@@ -173,6 +175,7 @@ class EagleCudaGraphManager: ...@@ -173,6 +175,7 @@ class EagleCudaGraphManager:
capture_cudagraph_mode=self.cudagraph_mode, capture_cudagraph_mode=self.cudagraph_mode,
desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})", desc=f"Capturing eagle CUDA graphs ({self.cudagraph_mode.name})",
generate_fn=generate_fn, generate_fn=generate_fn,
model_state=model_state,
input_buffers=input_buffers, input_buffers=input_buffers,
block_tables=block_tables, block_tables=block_tables,
attn_groups=attn_groups, attn_groups=attn_groups,
......
...@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.attn_utils import ( ...@@ -17,6 +17,7 @@ from vllm.v1.worker.gpu.attn_utils import (
) )
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.model_states.interface import ModelState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
...@@ -76,10 +77,12 @@ class EagleSpeculator: ...@@ -76,10 +77,12 @@ class EagleSpeculator:
def set_attn( def set_attn(
self, self,
model_state: ModelState,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
attn_groups: list[list[AttentionGroup]], attn_groups: list[list[AttentionGroup]],
block_tables: BlockTables, block_tables: BlockTables,
) -> None: ) -> None:
self.model_state = model_state
self.kv_cache_config = kv_cache_config self.kv_cache_config = kv_cache_config
self.attn_groups = attn_groups self.attn_groups = attn_groups
self.block_tables = block_tables self.block_tables = block_tables
...@@ -171,6 +174,7 @@ class EagleSpeculator: ...@@ -171,6 +174,7 @@ class EagleSpeculator:
logger.info("Capturing model for Eagle speculator...") logger.info("Capturing model for Eagle speculator...")
self.cudagraph_manager.capture( self.cudagraph_manager.capture(
self.generate_draft, self.generate_draft,
self.model_state,
self.input_buffers, self.input_buffers,
self.block_tables, self.block_tables,
self.attn_groups, self.attn_groups,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment