Unverified commit 3d66502e, authored by Woosuk Kwon and committed by GitHub
Browse files

[Model Runner V2] Prepare attn metadata in ModelState [2/N] (#35383)


Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
parent c66aa48e
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
import numpy as np import numpy as np
import torch import torch
...@@ -60,6 +59,8 @@ class InputBatch: ...@@ -60,6 +59,8 @@ class InputBatch:
query_start_loc_np: np.ndarray query_start_loc_np: np.ndarray
# [num_reqs] # [num_reqs]
seq_lens: torch.Tensor seq_lens: torch.Tensor
# [num_reqs]
dcp_local_seq_lens: torch.Tensor | None
# [num_tokens_after_padding] # [num_tokens_after_padding]
input_ids: torch.Tensor input_ids: torch.Tensor
...@@ -68,11 +69,6 @@ class InputBatch: ...@@ -68,11 +69,6 @@ class InputBatch:
# [num_tokens_after_padding, hidden_size] # [num_tokens_after_padding, hidden_size]
inputs_embeds: torch.Tensor | None inputs_embeds: torch.Tensor | None
# layer_name -> Metadata
attn_metadata: dict[str, Any]
# layer_name -> slot_mapping
slot_mappings: dict[str, torch.Tensor]
# [total_num_logits] # [total_num_logits]
logits_indices: torch.Tensor logits_indices: torch.Tensor
# [num_reqs + 1] # [num_reqs + 1]
...@@ -139,11 +135,10 @@ class InputBatch: ...@@ -139,11 +135,10 @@ class InputBatch:
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np, query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens, seq_lens=seq_lens,
dcp_local_seq_lens=None,
input_ids=input_ids, input_ids=input_ids,
positions=positions, positions=positions,
inputs_embeds=None, inputs_embeds=None,
attn_metadata=None, # type: ignore
slot_mappings=None, # type: ignore
logits_indices=logits_indices, logits_indices=logits_indices,
cu_num_logits=cu_num_logits, cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np, cu_num_logits_np=cu_num_logits_np,
......
...@@ -46,7 +46,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput ...@@ -46,7 +46,6 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.cp_utils import check_attention_cp_compatibility from vllm.v1.worker.cp_utils import check_attention_cp_compatibility
from vllm.v1.worker.gpu.async_utils import AsyncOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer, build_slot_mappings_by_layer,
get_kv_cache_spec, get_kv_cache_spec,
init_attn_backend, init_attn_backend,
...@@ -317,31 +316,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -317,31 +316,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict) self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
    """Attach dummy attention metadata and slot mappings to ``input_batch``.

    Dummy block tables and slot mappings are requested from
    ``self.block_tables``, turned into per-layer slot mappings and
    attention metadata, and stored on ``input_batch`` in place.

    Args:
        input_batch: Batch whose ``attn_metadata`` and ``slot_mappings``
            attributes are overwritten.
    """
    req_count = input_batch.num_reqs
    token_count = input_batch.num_tokens

    dummy_tables = self.block_tables.get_dummy_block_tables(req_count)
    dummy_slots = self.block_tables.get_dummy_slot_mappings(token_count)
    per_layer_slots = build_slot_mappings_by_layer(
        dummy_slots, self.kv_cache_config
    )

    # CPU-side tensor wrapping the NumPy query offsets.
    cpu_query_start_loc = torch.from_numpy(input_batch.query_start_loc_np)
    # Largest per-request scheduled token count in this batch.
    longest_query = input_batch.num_scheduled_tokens.max().item()

    metadata = build_attn_metadata(
        attn_groups=self.attn_groups,
        num_reqs=req_count,
        num_tokens=token_count,
        query_start_loc_gpu=input_batch.query_start_loc,
        query_start_loc_cpu=cpu_query_start_loc,
        max_query_len=longest_query,
        seq_lens=input_batch.seq_lens,
        max_seq_len=self.max_model_len,
        block_tables=dummy_tables,
        slot_mappings=dummy_slots,
        kv_cache_config=self.kv_cache_config,
        dcp_local_seq_lens=self.input_buffers.dcp_local_seq_lens,
    )

    input_batch.attn_metadata = metadata
    input_batch.slot_mappings = per_layer_slots
@torch.inference_mode() @torch.inference_mode()
def _dummy_run( def _dummy_run(
self, num_tokens: int, *args, skip_attn: bool = True, **kwargs self, num_tokens: int, *args, skip_attn: bool = True, **kwargs
...@@ -384,7 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -384,7 +358,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None return None, None
assert self.execute_model_state is not None assert self.execute_model_state is not None
hidden_states, _, input_batch, _ = self.execute_model_state input_batch, _, _, _, hidden_states, _, _ = self.execute_model_state
self.execute_model_state = None self.execute_model_state = None
assert hidden_states is not None # Last PP rank always has hidden_states assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
...@@ -546,7 +520,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -546,7 +520,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.encoder_runner.add_request(req_id, new_req_data.mm_features) self.encoder_runner.add_request(req_id, new_req_data.mm_features)
self.model_state.add_request(req_index, new_req_data) self.model_state.add_request(req_index, new_req_data)
self.block_tables.append_block_ids( self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True req_index, new_req_data.block_ids, overwrite=True
) )
...@@ -624,9 +597,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -624,9 +597,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping, total_num_logits, cu_num_logits, max_expand_len idx_mapping, total_num_logits, cu_num_logits, max_expand_len
) )
# Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
block_tables = self.block_tables.gather_block_tables(idx_mapping)
# Get query_start_loc. # Get query_start_loc.
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32) query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0 query_start_loc_np[0] = 0
...@@ -635,11 +605,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -635,11 +605,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Some attention backends like FA3 require query_start_loc to be non-decreasing. # Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc) async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs + 1] query_start_loc_np = query_start_loc_np[: num_reqs + 1]
query_start_loc_cpu = torch.from_numpy(query_start_loc_np)
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1] query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
max_query_len = num_scheduled_tokens.max().item()
# Get prefill tokens if any. # Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np): if self.req_states.any_prefills(idx_mapping_np):
...@@ -663,6 +630,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -663,6 +630,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
seq_lens = self.input_buffers.seq_lens[:num_reqs] seq_lens = self.input_buffers.seq_lens[:num_reqs]
dcp_local_seq_lens = None
if self.use_dcp: if self.use_dcp:
# Prepare dcp local seq_lens. # Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens( prepare_dcp_local_seq_lens(
...@@ -673,7 +641,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -673,7 +641,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.dcp_rank, self.dcp_rank,
self.cp_interleave, self.cp_interleave,
) )
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs] dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs]
# Some input token ids are directly read from the last sampled tokens # Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from. # and draft tokens. Also, get the logits indices to sample tokens from.
...@@ -689,35 +657,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -689,35 +657,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
total_num_logits, total_num_logits,
) )
# Compute slot mappings: [num_kv_cache_groups, num_tokens]
slot_mappings = self.block_tables.compute_slot_mappings(
idx_mapping,
query_start_loc,
self.input_buffers.positions[:num_tokens],
)
# Layer name -> slot mapping.
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
# Layer name -> attention metadata.
attn_metadata = build_attn_metadata(
attn_groups=self.attn_groups,
num_reqs=num_reqs,
num_tokens=num_tokens,
query_start_loc_gpu=query_start_loc,
query_start_loc_cpu=query_start_loc_cpu,
max_query_len=max_query_len,
seq_lens=self.input_buffers.seq_lens,
max_seq_len=self.max_model_len,
block_tables=block_tables,
slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config,
dcp_local_seq_lens=dcp_local_seq_lens,
)
input_ids = self.input_buffers.input_ids[:num_tokens_after_padding]
positions = self.input_buffers.positions[:num_tokens_after_padding]
return InputBatch( return InputBatch(
req_ids=req_ids, req_ids=req_ids,
num_reqs=num_reqs, num_reqs=num_reqs,
...@@ -732,17 +671,38 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -732,17 +671,38 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np, query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens, seq_lens=seq_lens,
input_ids=input_ids, dcp_local_seq_lens=dcp_local_seq_lens,
positions=positions, input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding],
inputs_embeds=None, inputs_embeds=None,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
logits_indices=logits_indices, logits_indices=logits_indices,
cu_num_logits=cu_num_logits, cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np, cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=scheduler_output.has_structured_output_requests, has_structured_output_reqs=scheduler_output.has_structured_output_requests,
) )
def prepare_attn(
    self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
    """Gather block tables and compute slot mappings for ``input_batch``.

    Args:
        input_batch: Batch supplying ``idx_mapping``, ``query_start_loc``
            and ``positions``.

    Returns:
        A ``(block_tables, slot_mappings)`` pair as produced by
        ``self.block_tables``.
    """
    request_indices = input_batch.idx_mapping

    # Block tables: num_kv_cache_groups x [num_reqs, max_num_blocks]
    gathered_tables = self.block_tables.gather_block_tables(request_indices)

    # Slot mappings: [num_kv_cache_groups, num_tokens]
    token_slots = self.block_tables.compute_slot_mappings(
        request_indices,
        input_batch.query_start_loc,
        input_batch.positions,
    )
    return gathered_tables, token_slots
def prepare_dummy_attn(
    self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
    """Fetch dummy block tables and slot mappings sized for ``input_batch``.

    Args:
        input_batch: Batch whose ``num_reqs`` sizes the dummy block tables
            and whose ``num_tokens`` sizes the dummy slot mappings.

    Returns:
        A ``(block_tables, slot_mappings)`` pair as produced by
        ``self.block_tables``.
    """
    table_source = self.block_tables
    dummy_tables = table_source.get_dummy_block_tables(input_batch.num_reqs)
    dummy_slots = table_source.get_dummy_slot_mappings(input_batch.num_tokens)
    return dummy_tables, dummy_slots
@torch.inference_mode() @torch.inference_mode()
def get_mm_embeddings( def get_mm_embeddings(
self, self,
...@@ -899,6 +859,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -899,6 +859,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch = self.prepare_inputs( input_batch = self.prepare_inputs(
scheduler_output, num_tokens_after_padding scheduler_output, num_tokens_after_padding
) )
block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config: if self.lora_config:
# Activate LoRA adapters. # Activate LoRA adapters.
lora_inputs = self.lora_state.make_lora_inputs( lora_inputs = self.lora_state.make_lora_inputs(
...@@ -929,9 +891,28 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -929,9 +891,28 @@ class GPUModelRunner(LoRAModelRunnerMixin):
device=self.device, device=self.device,
) )
if not skip_attn_for_dummy_run: if not skip_attn_for_dummy_run:
self.prepare_dummy_attn_metadata(input_batch) block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
else:
block_tables = None
slot_mappings = None
# FIXME(woosuk): Fix warmup for LoRA. # FIXME(woosuk): Fix warmup for LoRA.
attn_metadata = None
slot_mappings_by_layer = None
if not (dummy_run and skip_attn_for_dummy_run):
assert slot_mappings is not None
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
assert block_tables is not None
attn_metadata = self.model_state.prepare_attn(
input_batch,
block_tables,
slot_mappings,
self.attn_groups,
self.kv_cache_config,
)
model_inputs = { model_inputs = {
"input_ids": input_batch.input_ids, "input_ids": input_batch.input_ids,
"positions": input_batch.positions, "positions": input_batch.positions,
...@@ -968,13 +949,13 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -968,13 +949,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
with set_forward_context( with set_forward_context(
input_batch.attn_metadata, attn_metadata,
self.vllm_config, self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding, num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
slot_mapping=input_batch.slot_mappings, slot_mapping=slot_mappings_by_layer,
): ):
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs) model_output = self.model(**model_inputs)
...@@ -985,22 +966,23 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -985,22 +966,23 @@ class GPUModelRunner(LoRAModelRunnerMixin):
aux_hidden_states = None aux_hidden_states = None
kv_connector_output = self.kv_connector.post_forward(scheduler_output) kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = (
input_batch,
model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
)
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending. # Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors) assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output hidden_states.kv_connector_output = kv_connector_output
self.execute_model_state = (None, None, input_batch, kv_connector_output)
return hidden_states return hidden_states
# Last rank (or no PP): hidden_states is a tensor for sampling. # Last rank (or no PP): hidden_states is a tensor for sampling.
assert isinstance(hidden_states, torch.Tensor) assert isinstance(hidden_states, torch.Tensor)
self.execute_model_state = (
hidden_states,
aux_hidden_states,
input_batch,
kv_connector_output,
)
return None return None
@torch.inference_mode() @torch.inference_mode()
...@@ -1010,9 +992,15 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1010,9 +992,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.execute_model_state is None: if self.execute_model_state is None:
# The prior execute_model call must have failed. # The prior execute_model call must have failed.
return None return None
hidden_states, aux_hidden_states, input_batch, kv_connector_output = ( (
self.execute_model_state input_batch,
) model_inputs,
attn_metadata,
slot_mappings_by_layer,
hidden_states,
aux_hidden_states,
kv_connector_output,
) = self.execute_model_state
self.execute_model_state = None self.execute_model_state = None
if not self.is_last_pp_rank: if not self.is_last_pp_rank:
...@@ -1075,6 +1063,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1075,6 +1063,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.speculator is not None: if self.speculator is not None:
draft_tokens = self.speculator.propose( draft_tokens = self.speculator.propose(
input_batch, input_batch,
attn_metadata,
slot_mappings_by_layer,
hidden_states, hidden_states,
aux_hidden_states, aux_hidden_states,
num_sampled, num_sampled,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.core.sched.output import NewRequestData from vllm.v1.core.sched.output import NewRequestData
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.states import RequestState from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.utils import AttentionGroup
class ModelState: class ModelState:
...@@ -72,3 +77,29 @@ class ModelState: ...@@ -72,3 +77,29 @@ class ModelState:
return {} return {}
mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens] mrope_positions = self.mrope_state.mrope_positions[:, :num_tokens]
return {"positions": mrope_positions} return {"positions": mrope_positions}
def prepare_attn(
    self,
    input_batch: InputBatch,
    block_tables: tuple[torch.Tensor, ...],
    slot_mappings: torch.Tensor,
    attn_groups: list[list[AttentionGroup]],
    kv_cache_config: KVCacheConfig,
) -> dict[str, Any]:
    """Build attention metadata for ``input_batch``.

    Everything is forwarded to ``build_attn_metadata``; the only values
    derived here are the CPU copy of the query offsets and the maximum
    query length.

    Args:
        input_batch: Batch supplying request/token counts, query offsets,
            sequence lengths and (optionally) DCP-local sequence lengths.
        block_tables: Block tables forwarded unchanged.
        slot_mappings: Slot mappings forwarded unchanged.
        attn_groups: Attention groups forwarded unchanged.
        kv_cache_config: KV cache configuration forwarded unchanged.

    Returns:
        The result of ``build_attn_metadata`` (presumably keyed by layer
        name -- confirm against ``build_attn_metadata``).
    """
    return build_attn_metadata(
        attn_groups=attn_groups,
        num_reqs=input_batch.num_reqs,
        num_tokens=input_batch.num_tokens,
        query_start_loc_gpu=input_batch.query_start_loc,
        # CPU-side tensor wrapping the NumPy query offsets.
        query_start_loc_cpu=torch.from_numpy(input_batch.query_start_loc_np),
        # Largest per-request scheduled token count in this batch.
        max_query_len=input_batch.num_scheduled_tokens.max().item(),
        seq_lens=input_batch.seq_lens,
        max_seq_len=self.max_model_len,
        block_tables=block_tables,
        slot_mappings=slot_mappings,
        kv_cache_config=kv_cache_config,
        dcp_local_seq_lens=input_batch.dcp_local_seq_lens,
    )
...@@ -182,6 +182,8 @@ class EagleSpeculator: ...@@ -182,6 +182,8 @@ class EagleSpeculator:
def propose( def propose(
self, self,
input_batch: InputBatch, input_batch: InputBatch,
attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
# [num_tokens, hidden_size] # [num_tokens, hidden_size]
last_hidden_states: torch.Tensor, last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size] # num_layers x [num_tokens, hidden_size]
...@@ -229,8 +231,8 @@ class EagleSpeculator: ...@@ -229,8 +231,8 @@ class EagleSpeculator:
# TODO(woosuk): Support CUDA graph for prefill. # TODO(woosuk): Support CUDA graph for prefill.
last_hidden_states, hidden_states = self.run_model( last_hidden_states, hidden_states = self.run_model(
num_tokens, num_tokens,
input_batch.attn_metadata, attn_metadata,
input_batch.slot_mappings, slot_mappings,
num_tokens_across_dp=None, # FIXME num_tokens_across_dp=None, # FIXME
) )
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment