Unverified commit edf927bc, authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Fix slot_mapping after #25954 (#33046)


Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
parent 22aeb430
...@@ -101,9 +101,6 @@ class CudaGraphManager: ...@@ -101,9 +101,6 @@ class CudaGraphManager:
kv_cache_config, kv_cache_config,
) )
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config
)
# Warm up. # Warm up.
with set_forward_context( with set_forward_context(
...@@ -112,7 +109,7 @@ class CudaGraphManager: ...@@ -112,7 +109,7 @@ class CudaGraphManager:
num_tokens=num_tokens, num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer, slot_mapping=slot_mappings,
): ):
hidden_states = model( hidden_states = model(
input_ids=input_ids, input_ids=input_ids,
...@@ -132,7 +129,7 @@ class CudaGraphManager: ...@@ -132,7 +129,7 @@ class CudaGraphManager:
num_tokens=num_tokens, num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer, slot_mapping=slot_mappings,
), ),
torch.cuda.graph(graph, self.pool), torch.cuda.graph(graph, self.pool),
): ):
...@@ -252,7 +249,7 @@ def prepare_inputs_to_capture( ...@@ -252,7 +249,7 @@ def prepare_inputs_to_capture(
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int, max_model_len: int,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> tuple[dict[str, Any], torch.Tensor]: ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]:
num_tokens_per_req = num_tokens // num_reqs num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
...@@ -269,6 +266,9 @@ def prepare_inputs_to_capture( ...@@ -269,6 +266,9 @@ def prepare_inputs_to_capture(
input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables]
slot_mappings = block_tables.slot_mappings[:, :num_tokens] slot_mappings = block_tables.slot_mappings[:, :num_tokens]
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config
)
attn_metadata = build_attn_metadata( attn_metadata = build_attn_metadata(
attn_metadata_builders=attn_metadata_builders, attn_metadata_builders=attn_metadata_builders,
...@@ -282,4 +282,4 @@ def prepare_inputs_to_capture( ...@@ -282,4 +282,4 @@ def prepare_inputs_to_capture(
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
) )
return attn_metadata, slot_mappings return attn_metadata, slot_mappings_by_layer
...@@ -66,6 +66,8 @@ class InputBatch: ...@@ -66,6 +66,8 @@ class InputBatch:
# layer_name -> Metadata # layer_name -> Metadata
attn_metadata: dict[str, Any] attn_metadata: dict[str, Any]
# layer_name -> slot_mapping
slot_mappings: dict[str, torch.Tensor]
# [total_num_logits] # [total_num_logits]
logits_indices: torch.Tensor logits_indices: torch.Tensor
...@@ -133,6 +135,7 @@ class InputBatch: ...@@ -133,6 +135,7 @@ class InputBatch:
mrope_positions=None, mrope_positions=None,
inputs_embeds=None, inputs_embeds=None,
attn_metadata=None, # type: ignore attn_metadata=None, # type: ignore
slot_mappings=None, # type: ignore
logits_indices=logits_indices, logits_indices=logits_indices,
cu_num_logits=cu_num_logits, cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np, cu_num_logits_np=cu_num_logits_np,
......
...@@ -269,6 +269,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -269,6 +269,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mappings = self.block_tables.get_dummy_slot_mappings( slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens input_batch.num_tokens
) )
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
attn_metadata = build_attn_metadata( attn_metadata = build_attn_metadata(
attn_metadata_builders=self.attn_metadata_builders, attn_metadata_builders=self.attn_metadata_builders,
num_reqs=input_batch.num_reqs, num_reqs=input_batch.num_reqs,
...@@ -282,6 +285,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -282,6 +285,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
) )
input_batch.attn_metadata = attn_metadata input_batch.attn_metadata = attn_metadata
input_batch.slot_mappings = slot_mappings_by_layer
@torch.inference_mode() @torch.inference_mode()
def _dummy_run( def _dummy_run(
...@@ -345,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -345,6 +349,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculator.run_model( self.speculator.run_model(
self.max_num_tokens, self.max_num_tokens,
attn_metadata=None, attn_metadata=None,
slot_mappings=None,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -615,6 +620,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -615,6 +620,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
query_start_loc, query_start_loc,
self.input_buffers.positions[:num_tokens], self.input_buffers.positions[:num_tokens],
) )
# Layer name -> slot mapping.
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
# Layer name -> attention metadata. # Layer name -> attention metadata.
attn_metadata = build_attn_metadata( attn_metadata = build_attn_metadata(
...@@ -655,6 +664,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -655,6 +664,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
mrope_positions=mrope_positions, mrope_positions=mrope_positions,
inputs_embeds=None, inputs_embeds=None,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
logits_indices=logits_indices, logits_indices=logits_indices,
cu_num_logits=cu_num_logits, cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np, cu_num_logits_np=cu_num_logits_np,
...@@ -882,14 +892,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -882,14 +892,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.uses_mrope: if self.uses_mrope:
assert input_batch.mrope_positions is not None assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions positions = input_batch.mrope_positions
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions[: input_batch.num_tokens],
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
with set_forward_context( with set_forward_context(
input_batch.attn_metadata, input_batch.attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -897,7 +899,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -897,7 +899,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): Support piecewise CUDA graph. # TODO(woosuk): Support piecewise CUDA graph.
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer, slot_mapping=input_batch.slot_mappings,
): ):
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.model( hidden_states = self.model(
......
...@@ -13,7 +13,10 @@ from vllm.model_executor.model_loader import get_model ...@@ -13,7 +13,10 @@ from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
...@@ -108,7 +111,8 @@ class EagleSpeculator: ...@@ -108,7 +111,8 @@ class EagleSpeculator:
def run_model( def run_model(
self, self,
num_tokens: int, num_tokens: int,
attn_metadata: dict[str, Any], attn_metadata: dict[str, Any] | None,
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
with set_forward_context( with set_forward_context(
...@@ -117,6 +121,7 @@ class EagleSpeculator: ...@@ -117,6 +121,7 @@ class EagleSpeculator:
num_tokens=num_tokens, num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings,
): ):
ret_hidden_states = self.model( ret_hidden_states = self.model(
input_ids=self.input_buffers.input_ids[:num_tokens], input_ids=self.input_buffers.input_ids[:num_tokens],
...@@ -134,6 +139,7 @@ class EagleSpeculator: ...@@ -134,6 +139,7 @@ class EagleSpeculator:
self, self,
num_reqs: int, num_reqs: int,
attn_metadata: dict[str, Any], attn_metadata: dict[str, Any],
slot_mappings: dict[str, torch.Tensor],
num_tokens_across_dp: torch.Tensor | None, num_tokens_across_dp: torch.Tensor | None,
) -> None: ) -> None:
pos = self.input_buffers.positions[:num_reqs] pos = self.input_buffers.positions[:num_reqs]
...@@ -142,7 +148,7 @@ class EagleSpeculator: ...@@ -142,7 +148,7 @@ class EagleSpeculator:
for step in range(1, self.num_speculative_steps): for step in range(1, self.num_speculative_steps):
# Run the eagle model. # Run the eagle model.
last_hidden_states, hidden_states = self.run_model( last_hidden_states, hidden_states = self.run_model(
num_reqs, attn_metadata, num_tokens_across_dp num_reqs, attn_metadata, slot_mappings, num_tokens_across_dp
) )
logits = self.model.compute_logits(last_hidden_states) logits = self.model.compute_logits(last_hidden_states)
...@@ -235,6 +241,7 @@ class EagleSpeculator: ...@@ -235,6 +241,7 @@ class EagleSpeculator:
last_hidden_states, hidden_states = self.run_model( last_hidden_states, hidden_states = self.run_model(
num_tokens, num_tokens,
input_batch.attn_metadata, input_batch.attn_metadata,
input_batch.slot_mappings,
num_tokens_across_dp=None, # FIXME num_tokens_across_dp=None, # FIXME
) )
sample_hidden_states = last_hidden_states[last_token_indices] sample_hidden_states = last_hidden_states[last_token_indices]
...@@ -311,7 +318,12 @@ class EagleSpeculator: ...@@ -311,7 +318,12 @@ class EagleSpeculator:
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=self.kv_cache_config, kv_cache_config=self.kv_cache_config,
) )
self.generate_draft(num_reqs, attn_metadata, num_tokens_across_dp=None) # FIXME slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
self.generate_draft(
num_reqs, attn_metadata, slot_mappings_by_layer, num_tokens_across_dp=None
) # FIXME
return self.draft_tokens[:num_reqs] return self.draft_tokens[:num_reqs]
......
...@@ -69,7 +69,7 @@ class EagleCudaGraphManager: ...@@ -69,7 +69,7 @@ class EagleCudaGraphManager:
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> None: ) -> None:
num_reqs = min(num_tokens, self.max_num_reqs) num_reqs = min(num_tokens, self.max_num_reqs)
attn_metadata = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
num_tokens, num_tokens,
input_buffers, input_buffers,
...@@ -81,13 +81,13 @@ class EagleCudaGraphManager: ...@@ -81,13 +81,13 @@ class EagleCudaGraphManager:
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
# Warm up. # Warm up.
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp) generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)
# Capture the graph. # Capture the graph.
assert num_tokens not in self.graphs assert num_tokens not in self.graphs
graph = torch.cuda.CUDAGraph() graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, self.pool): with torch.cuda.graph(graph, self.pool):
generate_fn(num_tokens, attn_metadata, num_tokens_across_dp) generate_fn(num_tokens, attn_metadata, slot_mappings, num_tokens_across_dp)
self.graphs[num_tokens] = graph self.graphs[num_tokens] = graph
@torch.inference_mode() @torch.inference_mode()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment