Unverified Commit a28b94e6 authored by ElizaWszola's avatar ElizaWszola Committed by GitHub
Browse files

[Performance] Split FlashAttn attention and cache update (#25954)


Signed-off-by: default avatarElizaWszola <ewszola@redhat.com>
Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Signed-off-by: default avatarLuka Govedič <luka.govedic@gmail.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <luka.govedic@gmail.com>
Co-authored-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarLuka Govedič <lgovedic@redhat.com>
parent 0118cdcc
...@@ -13,6 +13,7 @@ from tests.v1.attention.utils import ( ...@@ -13,6 +13,7 @@ from tests.v1.attention.utils import (
create_common_attn_metadata, create_common_attn_metadata,
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
create_vllm_config, create_vllm_config,
try_backend_includes_kv_cache_update,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -295,6 +296,10 @@ def run_attention_backend( ...@@ -295,6 +296,10 @@ def run_attention_backend(
# Run forward pass # Run forward pass
# NOTE: The query, key, and value are already shaped correctly # NOTE: The query, key, and value are already shaped correctly
# in the calling test function. # in the calling test function.
if not try_backend_includes_kv_cache_update(actual_backend):
impl.do_kv_cache_update(
mock_layer, key, value, kv_cache, attn_metadata.slot_mapping
)
output = impl.forward( output = impl.forward(
mock_layer, query, key, value, kv_cache, attn_metadata, output=output mock_layer, query, key, value, kv_cache, attn_metadata, output=output
) )
......
...@@ -130,6 +130,18 @@ def try_get_attention_backend( ...@@ -130,6 +130,18 @@ def try_get_attention_backend(
raise AssertionError("unreachable") from None raise AssertionError("unreachable") from None
def try_backend_includes_kv_cache_update(
backend: AttentionBackendEnum,
) -> bool:
"""Try to get the attention backend class, skipping test if not found."""
try:
backend_class = backend.get_class()
return backend_class.forward_includes_kv_cache_update
except ImportError as e:
pytest.skip(f"{backend.name} not available: {e}")
raise AssertionError("unreachable") from None
def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec: def create_standard_kv_cache_spec(vllm_config: VllmConfig) -> FullAttentionSpec:
"""Create a FullAttentionSpec from ModelParams only.""" """Create a FullAttentionSpec from ModelParams only."""
return FullAttentionSpec( return FullAttentionSpec(
......
...@@ -86,7 +86,7 @@ class DecodeBenchTestRunner: ...@@ -86,7 +86,7 @@ class DecodeBenchTestRunner:
self._block_hasher = get_request_block_hasher(block_size, sha256) self._block_hasher = get_request_block_hasher(block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext( self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0 no_compile_layers={}, attn_metadata={}, virtual_engine=0, slot_mapping={}
) )
def new_request(self, token_ids: list[int]) -> Request: def new_request(self, token_ids: list[int]) -> Request:
......
...@@ -548,6 +548,7 @@ class TestNixlHandshake: ...@@ -548,6 +548,7 @@ class TestNixlHandshake:
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0, virtual_engine=0,
slot_mapping={},
) )
_before_load = time.perf_counter() _before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -618,6 +619,7 @@ class TestNixlHandshake: ...@@ -618,6 +619,7 @@ class TestNixlHandshake:
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0, virtual_engine=0,
slot_mapping={},
) )
_before_load = time.perf_counter() _before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -844,6 +846,7 @@ class TestNixlHandshake: ...@@ -844,6 +846,7 @@ class TestNixlHandshake:
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0, virtual_engine=0,
slot_mapping={},
) )
_before_load = time.perf_counter() _before_load = time.perf_counter()
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -1006,6 +1009,7 @@ def test_kv_connector_stats(default_vllm_config, dist_init): ...@@ -1006,6 +1009,7 @@ def test_kv_connector_stats(default_vllm_config, dist_init):
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0, virtual_engine=0,
slot_mapping={},
) )
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -1767,6 +1771,7 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_ ...@@ -1767,6 +1771,7 @@ def test_aborted_request_removed_from_worker_in_batch(default_vllm_config, dist_
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0, virtual_engine=0,
slot_mapping={},
) )
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -1917,6 +1922,7 @@ def test_transfer_failure_logging( ...@@ -1917,6 +1922,7 @@ def test_transfer_failure_logging(
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0, virtual_engine=0,
slot_mapping={},
) )
# Capture logs from the nixl_connector logger specifically # Capture logs from the nixl_connector logger specifically
...@@ -2017,6 +2023,7 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init): ...@@ -2017,6 +2023,7 @@ def test_handshake_failure_returns_finished(default_vllm_config, dist_init):
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0, virtual_engine=0,
slot_mapping={},
) )
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
...@@ -2067,6 +2074,7 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init) ...@@ -2067,6 +2074,7 @@ def test_transfer_setup_failure_returns_finished(default_vllm_config, dist_init)
no_compile_layers={}, no_compile_layers={},
attn_metadata={}, attn_metadata={},
virtual_engine=0, virtual_engine=0,
slot_mapping={},
) )
connector.start_load_kv(dummy_ctx) connector.start_load_kv(dummy_ctx)
......
...@@ -209,7 +209,10 @@ class RequestRunner: ...@@ -209,7 +209,10 @@ class RequestRunner:
self._block_hasher = get_request_block_hasher(gpu_block_size, sha256) self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)
self._dummy_ctx: ForwardContext = ForwardContext( self._dummy_ctx: ForwardContext = ForwardContext(
no_compile_layers={}, attn_metadata={}, virtual_engine=0 no_compile_layers={},
attn_metadata={},
virtual_engine=0,
slot_mapping={},
) )
def new_request(self, token_ids: list[int]): def new_request(self, token_ids: list[int]):
......
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
from tests.v1.attention.utils import ( from tests.v1.attention.utils import (
create_standard_kv_cache_spec, create_standard_kv_cache_spec,
create_vllm_config, create_vllm_config,
try_backend_includes_kv_cache_update,
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
...@@ -120,6 +121,14 @@ def forward_attention( ...@@ -120,6 +121,14 @@ def forward_attention(
key = k.view(-1, num_kv_heads, dim_per_head) key = k.view(-1, num_kv_heads, dim_per_head)
value = v.view(-1, num_kv_heads, dim_per_head) value = v.view(-1, num_kv_heads, dim_per_head)
output = torch.empty_like(query) output = torch.empty_like(query)
if not try_backend_includes_kv_cache_update(backend):
instance.do_kv_cache_update(
layer=layer,
key=key,
value=value,
kv_cache=kv_cache,
slot_mapping=attn_metadata.slot_mapping,
)
return instance.forward( return instance.forward(
layer=layer, layer=layer,
query=query, query=query,
......
...@@ -390,29 +390,43 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -390,29 +390,43 @@ class Attention(nn.Module, AttentionLayerBase):
if value is not None: if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v) value = value.view(-1, self.num_kv_heads, self.head_size_v)
if self.use_direct_call: if self.use_direct_call:
forward_context: ForwardContext = get_forward_context() kv_cache_dummy_dep = None
attn_metadata = forward_context.attn_metadata if not self.attn_backend.forward_includes_kv_cache_update:
if isinstance(attn_metadata, dict): kv_cache_dummy_dep = unified_kv_cache_update(
attn_metadata = attn_metadata[self.layer_name] key, value, self.layer_name
self_kv_cache = self.kv_cache[forward_context.virtual_engine] )
self.impl.forward( unified_attention_with_output(
self, query, key, value, self_kv_cache, attn_metadata, output=output query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
else: else:
kv_cache_dummy_dep = None
if not self.attn_backend.forward_includes_kv_cache_update and (
# torch can only dispatch custom op if a tensor is passed
key is not None or value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name
)
torch.ops.vllm.unified_attention_with_output( torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
) )
return output.view(-1, hidden_size) return output.view(-1, hidden_size)
else: else:
assert self.attn_backend.forward_includes_kv_cache_update, (
"Split KV cache update not supported when output tensor not provided."
)
if self.use_direct_call: if self.use_direct_call:
forward_context = get_forward_context() return unified_attention(query, key, value, self.layer_name)
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(
self, query, key, value, self_kv_cache, attn_metadata
)
else: else:
return torch.ops.vllm.unified_attention( return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name query, key, value, self.layer_name
...@@ -802,6 +816,55 @@ direct_register_custom_op( ...@@ -802,6 +816,55 @@ direct_register_custom_op(
) )
def unified_kv_cache_update(
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
"""
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
forward_context = get_forward_context()
attn_layer = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
layer_slot_mapping = slot_mapping.get(layer_name)
if layer_slot_mapping is not None:
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
)
attn_layer.impl.do_kv_cache_update(
attn_layer,
key,
value,
kv_cache,
layer_slot_mapping,
)
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
def unified_kv_cache_update_fake(
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty(0, device=key.device, dtype=key.dtype)
direct_register_custom_op(
op_name="unified_kv_cache_update",
op_func=unified_kv_cache_update,
fake_impl=unified_kv_cache_update_fake,
mutates_args=[],
)
@maybe_transfer_kv_layer @maybe_transfer_kv_layer
def unified_attention_with_output( def unified_attention_with_output(
query: torch.Tensor, query: torch.Tensor,
...@@ -811,7 +874,12 @@ def unified_attention_with_output( ...@@ -811,7 +874,12 @@ def unified_attention_with_output(
layer_name: str, layer_name: str,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None: ) -> None:
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
attn_metadata, self, kv_cache = get_attention_context(layer_name) attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward( self.impl.forward(
...@@ -835,6 +903,7 @@ def unified_attention_with_output_fake( ...@@ -835,6 +903,7 @@ def unified_attention_with_output_fake(
layer_name: str, layer_name: str,
output_scale: torch.Tensor | None = None, output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None, output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None: ) -> None:
return return
......
...@@ -189,6 +189,7 @@ class ForwardContext: ...@@ -189,6 +189,7 @@ class ForwardContext:
# copy from vllm_config.compilation_config.static_forward_context # copy from vllm_config.compilation_config.static_forward_context
no_compile_layers: dict[str, Any] no_compile_layers: dict[str, Any]
attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]] attn_metadata: dict[str, AttentionMetadata] | list[dict[str, AttentionMetadata]]
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]]
""" """
Type Dict[str, AttentionMetadata] for v1, map from layer_name of each Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
attention layer to its attention metadata attention layer to its attention metadata
...@@ -266,6 +267,7 @@ def create_forward_context( ...@@ -266,6 +267,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None, batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None, ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | None = None,
additional_kwargs: dict[str, Any] | None = None, additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
): ):
...@@ -282,6 +284,7 @@ def create_forward_context( ...@@ -282,6 +284,7 @@ def create_forward_context(
remaining_moe_layers=remaining_moe_layers, remaining_moe_layers=remaining_moe_layers,
virtual_engine=virtual_engine, virtual_engine=virtual_engine,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
slot_mapping=slot_mapping or {},
dp_metadata=dp_metadata, dp_metadata=dp_metadata,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
...@@ -316,6 +319,7 @@ def set_forward_context( ...@@ -316,6 +319,7 @@ def set_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None, batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None, ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
): ):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
...@@ -374,6 +378,7 @@ def set_forward_context( ...@@ -374,6 +378,7 @@ def set_forward_context(
cudagraph_runtime_mode, cudagraph_runtime_mode,
batch_descriptor, batch_descriptor,
ubatch_slices, ubatch_slices,
slot_mapping,
additional_kwargs, additional_kwargs,
skip_compiled, skip_compiled,
) )
......
...@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import ( ...@@ -15,7 +15,7 @@ from vllm.v1.attention.backend import (
AttentionMetadata, AttentionMetadata,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
subclass_attention_backend, subclass_attention_backend_with_overrides,
) )
from vllm.v1.attention.selector import get_attn_backend from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
...@@ -72,6 +72,7 @@ def create_cross_attention_backend( ...@@ -72,6 +72,7 @@ def create_cross_attention_backend(
) -> type[AttentionBackend]: ) -> type[AttentionBackend]:
prefix = "CrossAttention_" prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls() underlying_builder = underlying_attn_backend.get_builder_cls()
underlying_impl = underlying_attn_backend.get_impl_cls()
class CrossAttentionBuilder(underlying_builder): # type: ignore class CrossAttentionBuilder(underlying_builder): # type: ignore
def build( def build(
...@@ -106,18 +107,60 @@ def create_cross_attention_backend( ...@@ -106,18 +107,60 @@ def create_cross_attention_backend(
) )
# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
new_metadata.slot_mapping = _get_cross_slot_mapping( slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens_cpu, new_metadata.encoder_seq_lens_cpu,
new_metadata.block_table_tensor, new_metadata.block_table_tensor,
self.kv_cache_spec, self.kv_cache_spec,
self.device, self.device,
) )
return super().build(common_prefix_len, new_metadata, fast_build) attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
attn_metadata.slot_mapping = slot_mapping
return attn_metadata
# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
# `CrossAttentionBuilder` instead of the one computed by `BlockTable`
# (gpu_model_runner)
class CrossAttentionImpl(underlying_impl): # type: ignore[valid-type,misc]
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
)
return super().forward(
layer,
query,
key,
value,
kv_cache,
attn_metadata,
output,
output_scale,
output_block_scale,
)
attn_backend = subclass_attention_backend( attn_backend = subclass_attention_backend_with_overrides(
name_prefix=prefix, name_prefix=prefix,
attention_backend_cls=underlying_attn_backend, attention_backend_cls=underlying_attn_backend,
builder_cls=CrossAttentionBuilder, overrides={
"get_builder_cls": lambda: CrossAttentionBuilder,
"get_impl_cls": lambda: CrossAttentionImpl,
"forward_includes_kv_cache_update": True,
},
) )
return attn_backend return attn_backend
......
...@@ -613,8 +613,9 @@ def weak_ref_tensor(tensor: Any) -> Any: ...@@ -613,8 +613,9 @@ def weak_ref_tensor(tensor: Any) -> Any:
Create a weak reference to a tensor. Create a weak reference to a tensor.
The new tensor will share the same data as the original tensor, The new tensor will share the same data as the original tensor,
but will not keep the original tensor alive. but will not keep the original tensor alive.
This ignores 0-size tensors as those don't allocate any memory.
""" """
if isinstance(tensor, torch.Tensor): if isinstance(tensor, torch.Tensor) and tensor.numel() > 0:
return torch.ops._C.weak_ref_tensor(tensor) return torch.ops._C.weak_ref_tensor(tensor)
else: else:
return tensor return tensor
......
...@@ -53,6 +53,9 @@ class AttentionBackend(ABC): ...@@ -53,6 +53,9 @@ class AttentionBackend(ABC):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"] supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto", "bfloat16"]
# Does attention's forward() include kv cache update?
forward_includes_kv_cache_update: bool = True
@staticmethod @staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(1)] return [MultipleOf(1)]
......
...@@ -79,6 +79,8 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -79,6 +79,8 @@ class FlashAttentionBackend(AttentionBackend):
return [16, 32, 64] return [16, 32, 64]
return [MultipleOf(16)] return [MultipleOf(16)]
forward_includes_kv_cache_update: bool = False
@staticmethod @staticmethod
def get_name() -> str: def get_name() -> str:
return "FLASH_ATTN" return "FLASH_ATTN"
...@@ -652,32 +654,6 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -652,32 +654,6 @@ class FlashAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before # For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"): if self.kv_cache_dtype.startswith("fp8"):
# queries are quantized in the attention layer # queries are quantized in the attention layer
dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
...@@ -774,6 +750,49 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -774,6 +750,49 @@ class FlashAttentionImpl(AttentionImpl):
) )
return output return output
def do_kv_cache_update(
self,
layer: torch.nn.Module,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (
self.kv_sharing_target_layer_name is not None
or key is None
or value is None
):
return
key_cache, value_cache = kv_cache.unbind(0)
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
def _forward_with_dcp( def _forward_with_dcp(
self, self,
query: torch.Tensor, query: torch.Tensor,
......
...@@ -159,6 +159,10 @@ class SpecDecodeBaseProposer: ...@@ -159,6 +159,10 @@ class SpecDecodeBaseProposer:
with_numpy=True, with_numpy=True,
) )
self._slot_mapping_buffer = torch.zeros(
self.max_num_tokens, dtype=torch.int64, device=device
)
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple | None = None self.allowed_attn_types: tuple | None = None
if current_platform.is_rocm(): if current_platform.is_rocm():
...@@ -230,6 +234,24 @@ class SpecDecodeBaseProposer: ...@@ -230,6 +234,24 @@ class SpecDecodeBaseProposer:
positions = positions[0] positions = positions[0]
self.positions[:num_tokens] = positions self.positions[:num_tokens] = positions
def _get_slot_mapping(
self,
num_tokens: int,
slot_mapping: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return slot_mapping dict for EAGLE layers.
If slot_mapping is provided, copies it into the buffer first.
"""
if slot_mapping is not None:
num_actual = slot_mapping.shape[0]
self._slot_mapping_buffer[:num_actual].copy_(slot_mapping)
if num_tokens > num_actual:
self._slot_mapping_buffer[num_actual:num_tokens].fill_(PADDING_SLOT_ID)
view = self._slot_mapping_buffer[:num_tokens]
return {name: view for name in self.attn_layer_names + self.indexer_layer_names}
def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None: def initialize_cudagraph_keys(self, cudagraph_mode: CUDAGraphMode) -> None:
"""Initialize cudagraph dispatcher keys for eagle. """Initialize cudagraph dispatcher keys for eagle.
...@@ -262,6 +284,9 @@ class SpecDecodeBaseProposer: ...@@ -262,6 +284,9 @@ class SpecDecodeBaseProposer:
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None, mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None,
num_rejected_tokens_gpu: torch.Tensor | None = None, num_rejected_tokens_gpu: torch.Tensor | None = None,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> torch.Tensor: ) -> torch.Tensor:
batch_size = common_attn_metadata.batch_size() batch_size = common_attn_metadata.batch_size()
...@@ -358,6 +383,9 @@ class SpecDecodeBaseProposer: ...@@ -358,6 +383,9 @@ class SpecDecodeBaseProposer:
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
): ):
ret_hidden_states = self.model(**model_kwargs) ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple(): if not self.model_returns_tuple():
...@@ -396,6 +424,7 @@ class SpecDecodeBaseProposer: ...@@ -396,6 +424,7 @@ class SpecDecodeBaseProposer:
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
slot_mappings=slot_mappings,
) )
# [batch_size, num_tree_tokens] # [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1) return torch.cat(draft_token_ids_list, dim=1)
...@@ -553,6 +582,9 @@ class SpecDecodeBaseProposer: ...@@ -553,6 +582,9 @@ class SpecDecodeBaseProposer:
num_tokens=input_batch_size, num_tokens=input_batch_size,
num_tokens_across_dp=batch_size_across_dp, num_tokens_across_dp=batch_size_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
input_batch_size, common_attn_metadata.slot_mapping
),
): ):
ret_hidden_states = self.model(**model_kwargs) ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple(): if not self.model_returns_tuple():
...@@ -760,6 +792,9 @@ class SpecDecodeBaseProposer: ...@@ -760,6 +792,9 @@ class SpecDecodeBaseProposer:
# [num_tokens, hidden_size] # [num_tokens, hidden_size]
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
tree_attn_metadata_builder = self.runner.attn_groups[0][ tree_attn_metadata_builder = self.runner.attn_groups[0][
0 0
...@@ -881,6 +916,9 @@ class SpecDecodeBaseProposer: ...@@ -881,6 +916,9 @@ class SpecDecodeBaseProposer:
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=self._get_slot_mapping(
num_input_tokens, attn_metadata.slot_mapping
),
): ):
last_hidden_states, hidden_states = self.model( last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens], input_ids=self.input_ids[:num_input_tokens],
...@@ -1212,6 +1250,7 @@ class SpecDecodeBaseProposer: ...@@ -1212,6 +1250,7 @@ class SpecDecodeBaseProposer:
num_tokens: int, num_tokens: int,
use_cudagraphs: bool = True, use_cudagraphs: bool = True,
is_graph_capturing: bool = False, is_graph_capturing: bool = False,
slot_mappings: dict[str, torch.Tensor] | None = None,
) -> None: ) -> None:
# FIXME: when using tree-based specdec, adjust number of forward-passes # FIXME: when using tree-based specdec, adjust number of forward-passes
# according to the depth of the tree. # according to the depth of the tree.
...@@ -1233,12 +1272,23 @@ class SpecDecodeBaseProposer: ...@@ -1233,12 +1272,23 @@ class SpecDecodeBaseProposer:
if num_tokens_across_dp is not None: if num_tokens_across_dp is not None:
num_tokens_across_dp[self.dp_rank] = num_input_tokens num_tokens_across_dp[self.dp_rank] = num_input_tokens
# Make sure to use EAGLE's own buffer during cudagraph capture.
if (
self.attn_layer_names
and slot_mappings is not None
and self.attn_layer_names[0] in slot_mappings
):
slot_mapping_dict = self._get_slot_mapping(num_input_tokens)
else:
slot_mapping_dict = slot_mappings or {}
with set_forward_context( with set_forward_context(
None, None,
self.vllm_config, self.vllm_config,
num_tokens=num_input_tokens, num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
slot_mapping=slot_mapping_dict,
): ):
if self.supports_mm_inputs: if self.supports_mm_inputs:
input_ids = None input_ids = None
......
...@@ -38,6 +38,9 @@ class MedusaProposer: ...@@ -38,6 +38,9 @@ class MedusaProposer:
self, self,
target_hidden_states: torch.Tensor, target_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None, # unused
) -> torch.Tensor: ) -> torch.Tensor:
# Generate blocks and compute logits # Generate blocks and compute logits
blocks = self.model(target_hidden_states) blocks = self.model(target_hidden_states)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import os import os
import numpy as np import numpy as np
import torch
from numba import get_num_threads, jit, njit, prange, set_num_threads from numba import get_num_threads, jit, njit, prange, set_num_threads
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -132,6 +133,9 @@ class NgramProposer: ...@@ -132,6 +133,9 @@ class NgramProposer:
sampled_token_ids: list[list[int]], sampled_token_ids: list[list[int]],
num_tokens_no_spec: np.ndarray, num_tokens_no_spec: np.ndarray,
token_ids_cpu: np.ndarray, token_ids_cpu: np.ndarray,
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None, # unused
) -> list[list[int]]: ) -> list[list[int]]:
# find which requests need ngram proposals # find which requests need ngram proposals
valid_ngram_requests = [] valid_ngram_requests = []
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_input_batch import InputBatch
...@@ -33,6 +35,9 @@ class SuffixDecodingProposer: ...@@ -33,6 +35,9 @@ class SuffixDecodingProposer:
self, self,
input_batch: InputBatch, input_batch: InputBatch,
sampled_token_ids: list[list[int]], sampled_token_ids: list[list[int]],
slot_mappings: dict[str, torch.Tensor]
| list[dict[str, torch.Tensor]]
| None = None, # unused
) -> list[list[int]]: ) -> list[list[int]]:
""" """
Propose speculative tokens for each request in the input batch. Suffix Decoding Propose speculative tokens for each request in the input batch. Suffix Decoding
......
...@@ -140,6 +140,18 @@ def init_kv_cache( ...@@ -140,6 +140,18 @@ def init_kv_cache(
return kv_caches return kv_caches
def build_slot_mappings_by_layer(
slot_mappings: torch.Tensor,
kv_cache_config: KVCacheConfig,
) -> dict[str, torch.Tensor]:
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
slot_mapping = slot_mappings[i]
for layer_name in kv_cache_group.layer_names:
slot_mappings_by_layer[layer_name] = slot_mapping
return slot_mappings_by_layer
def build_attn_metadata( def build_attn_metadata(
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
num_reqs: int, num_reqs: int,
......
...@@ -14,7 +14,10 @@ from vllm.distributed.parallel_state import graph_capture, is_global_first_rank ...@@ -14,7 +14,10 @@ from vllm.distributed.parallel_state import graph_capture, is_global_first_rank
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.v1.attention.backend import AttentionMetadataBuilder from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.worker.gpu.attn_utils import build_attn_metadata from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
build_slot_mappings_by_layer,
)
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.gpu.input_batch import InputBuffers
...@@ -88,7 +91,7 @@ class CudaGraphManager: ...@@ -88,7 +91,7 @@ class CudaGraphManager:
positions = mrope_positions[:, :num_tokens] positions = mrope_positions[:, :num_tokens]
if inputs_embeds is not None: if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:num_tokens] inputs_embeds = inputs_embeds[:num_tokens]
attn_metadata = prepare_inputs_to_capture( attn_metadata, slot_mappings = prepare_inputs_to_capture(
num_reqs, num_reqs,
num_tokens, num_tokens,
input_buffers, input_buffers,
...@@ -98,6 +101,9 @@ class CudaGraphManager: ...@@ -98,6 +101,9 @@ class CudaGraphManager:
kv_cache_config, kv_cache_config,
) )
num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, kv_cache_config
)
# Warm up. # Warm up.
with set_forward_context( with set_forward_context(
...@@ -106,6 +112,7 @@ class CudaGraphManager: ...@@ -106,6 +112,7 @@ class CudaGraphManager:
num_tokens=num_tokens, num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer,
): ):
hidden_states = model( hidden_states = model(
input_ids=input_ids, input_ids=input_ids,
...@@ -125,6 +132,7 @@ class CudaGraphManager: ...@@ -125,6 +132,7 @@ class CudaGraphManager:
num_tokens=num_tokens, num_tokens=num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer,
), ),
torch.cuda.graph(graph, self.pool), torch.cuda.graph(graph, self.pool),
): ):
...@@ -244,7 +252,7 @@ def prepare_inputs_to_capture( ...@@ -244,7 +252,7 @@ def prepare_inputs_to_capture(
attn_metadata_builders: list[AttentionMetadataBuilder], attn_metadata_builders: list[AttentionMetadataBuilder],
max_model_len: int, max_model_len: int,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
) -> dict[str, Any]: ) -> tuple[dict[str, Any], torch.Tensor]:
num_tokens_per_req = num_tokens // num_reqs num_tokens_per_req = num_tokens // num_reqs
query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req
...@@ -274,4 +282,4 @@ def prepare_inputs_to_capture( ...@@ -274,4 +282,4 @@ def prepare_inputs_to_capture(
slot_mappings=slot_mappings, slot_mappings=slot_mappings,
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
) )
return attn_metadata return attn_metadata, slot_mappings
...@@ -24,6 +24,7 @@ from vllm.v1.outputs import ModelRunnerOutput ...@@ -24,6 +24,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import ( from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata, build_attn_metadata,
build_slot_mappings_by_layer,
get_kv_cache_spec, get_kv_cache_spec,
init_attn_backend, init_attn_backend,
init_kv_cache, init_kv_cache,
...@@ -881,6 +882,14 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -881,6 +882,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.uses_mrope: if self.uses_mrope:
assert input_batch.mrope_positions is not None assert input_batch.mrope_positions is not None
positions = input_batch.mrope_positions positions = input_batch.mrope_positions
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions[: input_batch.num_tokens],
)
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
with set_forward_context( with set_forward_context(
input_batch.attn_metadata, input_batch.attn_metadata,
self.vllm_config, self.vllm_config,
...@@ -888,6 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -888,6 +897,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(woosuk): Support piecewise CUDA graph. # TODO(woosuk): Support piecewise CUDA graph.
cudagraph_runtime_mode=CUDAGraphMode.NONE, cudagraph_runtime_mode=CUDAGraphMode.NONE,
num_tokens_across_dp=num_tokens_across_dp, num_tokens_across_dp=num_tokens_across_dp,
slot_mapping=slot_mappings_by_layer,
): ):
self.kv_connector.pre_forward(scheduler_output) self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.model( hidden_states = self.model(
......
...@@ -314,6 +314,7 @@ class ExecuteModelState(NamedTuple): ...@@ -314,6 +314,7 @@ class ExecuteModelState(NamedTuple):
aux_hidden_states: list[torch.Tensor] | None aux_hidden_states: list[torch.Tensor] | None
ec_connector_output: ECConnectorOutput | None ec_connector_output: ECConnectorOutput | None
cudagraph_stats: CUDAGraphStat | None cudagraph_stats: CUDAGraphStat | None
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None
class GPUModelRunner( class GPUModelRunner(
...@@ -1595,6 +1596,7 @@ class GPUModelRunner( ...@@ -1595,6 +1596,7 @@ class GPUModelRunner(
for_cudagraph_capture: bool = False, for_cudagraph_capture: bool = False,
num_scheduled_tokens: dict[str, int] | None = None, num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None, cascade_attn_prefix_lens: list[list[int]] | None = None,
slot_mappings: dict[int, torch.Tensor] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: ) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
""" """
:return: tuple[attn_metadata, spec_decode_common_attn_metadata] :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
...@@ -1628,7 +1630,7 @@ class GPUModelRunner( ...@@ -1628,7 +1630,7 @@ class GPUModelRunner(
kv_cache_groups = self.kv_cache_config.kv_cache_groups kv_cache_groups = self.kv_cache_config.kv_cache_groups
def _get_block_table_and_slot_mapping(kv_cache_gid: int): def _get_block_table(kv_cache_gid: int):
assert num_reqs_padded is not None and num_tokens_padded is not None assert num_reqs_padded is not None and num_tokens_padded is not None
kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec kv_cache_spec = kv_cache_groups[kv_cache_gid].kv_cache_spec
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec): if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
...@@ -1637,24 +1639,19 @@ class GPUModelRunner( ...@@ -1637,24 +1639,19 @@ class GPUModelRunner(
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
slot_mapping = torch.zeros(
(num_tokens_padded,),
dtype=torch.int64,
device=self.device,
)
else: else:
blk_table = self.input_batch.block_table[kv_cache_gid] blk_table = self.input_batch.block_table[kv_cache_gid]
blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded) blk_table_tensor = blk_table.get_device_tensor(num_reqs_padded)
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
# Fill unused with -1. Needed for reshape_and_cache in full cuda # Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID # graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
slot_mapping[num_tokens:num_tokens_padded].fill_(-1)
blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1) blk_table_tensor[num_reqs:num_reqs_padded].fill_(-1)
return blk_table_tensor
return blk_table_tensor, slot_mapping assert slot_mappings is not None
block_table_gid_0 = _get_block_table(0)
slot_mapping_gid_0 = slot_mappings[0]
block_table_gid_0, slot_mapping_gid_0 = _get_block_table_and_slot_mapping(0)
if self.model_config.enable_return_routed_experts: if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
cm_base = CommonAttentionMetadata( cm_base = CommonAttentionMetadata(
...@@ -1779,9 +1776,8 @@ class GPUModelRunner( ...@@ -1779,9 +1776,8 @@ class GPUModelRunner(
num_reqs_padded, num_reqs_padded,
) )
if kv_cache_gid > 0: if kv_cache_gid > 0:
cm.block_table_tensor, cm.slot_mapping = ( cm.block_table_tensor = _get_block_table(kv_cache_gid)
_get_block_table_and_slot_mapping(kv_cache_gid) cm.slot_mapping = slot_mappings[kv_cache_gid]
)
if self.speculative_config and spec_decode_common_attn_metadata is None: if self.speculative_config and spec_decode_common_attn_metadata is None:
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, EagleProposer):
...@@ -3119,6 +3115,80 @@ class GPUModelRunner( ...@@ -3119,6 +3115,80 @@ class GPUModelRunner(
pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) pyt_hooks.register_hooks(self.model, self.model.__class__.__name__)
self.layerwise_nvtx_hooks_registered = True self.layerwise_nvtx_hooks_registered = True
def _get_slot_mappings(
self,
num_tokens_padded: int,
num_reqs_padded: int,
num_tokens_unpadded: int,
ubatch_slices: "UBatchSlices | None" = None,
) -> tuple[
dict[int, torch.Tensor] | None,
dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
]:
"""
Build slot mappings in both formats needed by the system.
Args:
num_tokens_padded: Total number of tokens (padded)
num_reqs_padded: Total number of requests (padded)
num_tokens_unpadded: Actual number of tokens (unpadded)
ubatch_slices: Optional ubatch slicing info for DBO
Returns:
A tuple of:
- slot_mappings_by_gid: dict[int, torch.Tensor] for attention metadata
- slot_mappings_by_layer: dict[str, torch.Tensor] or list for ForwardContext
"""
if not (
hasattr(self, "kv_cache_config")
and self.kv_cache_config is not None
and len(self.kv_cache_config.kv_cache_groups) > 0
):
return None, None
def _get_slot_mapping(kv_cache_gid: int):
assert num_reqs_padded is not None and num_tokens_padded is not None
kv_cache_spec = self.kv_cache_config.kv_cache_groups[
kv_cache_gid
].kv_cache_spec
if isinstance(kv_cache_spec, EncoderOnlyAttentionSpec):
slot_mapping = torch.zeros(
(num_tokens_padded,),
dtype=torch.int64,
device=self.device,
)
else:
blk_table = self.input_batch.block_table[kv_cache_gid]
slot_mapping = blk_table.slot_mapping.gpu[:num_tokens_padded]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
slot_mapping[num_tokens_unpadded:num_tokens_padded].fill_(-1)
return slot_mapping
slot_mappings_by_gid = {
gid: _get_slot_mapping(gid)
for gid, _ in enumerate(self.kv_cache_config.kv_cache_groups)
}
slot_mappings_by_layer: dict[str, torch.Tensor] = {}
for gid, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups):
slot_mapping = slot_mappings_by_gid[gid]
for layer_name in kv_cache_group.layer_names:
slot_mappings_by_layer[layer_name] = slot_mapping
if ubatch_slices is not None:
result: list[dict[str, torch.Tensor]] = []
for ubatch in ubatch_slices:
sliced_mappings: dict[str, torch.Tensor] = {}
for layer_name, slot_mapping in slot_mappings_by_layer.items():
sliced_mappings[layer_name] = slot_mapping[ubatch.token_slice]
result.append(sliced_mappings)
return slot_mappings_by_gid, result
return slot_mappings_by_gid, slot_mappings_by_layer
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
...@@ -3248,6 +3318,17 @@ class GPUModelRunner( ...@@ -3248,6 +3318,17 @@ class GPUModelRunner(
ubatch_slices_padded, ubatch_slices_padded,
) )
# True if any attention backend handles KV cache update separately
# from forward() (i.e., forward_includes_kv_cache_update=False). When true,
# slot_mappings must use padded dimensions to match the key/value tensors.
has_separate_kv_update = not all(
all(
g.backend.forward_includes_kv_cache_update
for g in self.attn_groups[id]
)
for id, spec in enumerate(self.kv_cache_config.kv_cache_groups)
if not isinstance(spec.kv_cache_spec, EncoderOnlyAttentionSpec)
)
pad_attn = cudagraph_mode == CUDAGraphMode.FULL pad_attn = cudagraph_mode == CUDAGraphMode.FULL
if self.cache_config.mamba_cache_mode == "align": if self.cache_config.mamba_cache_mode == "align":
...@@ -3265,6 +3346,17 @@ class GPUModelRunner( ...@@ -3265,6 +3346,17 @@ class GPUModelRunner(
use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0 use_spec_decode = len(scheduler_output.scheduled_spec_decode_tokens) > 0
ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices ubatch_slices_attn = ubatch_slices_padded if pad_attn else ubatch_slices
slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
num_tokens_padded=num_tokens_padded
if pad_attn or has_separate_kv_update
else num_tokens_unpadded,
num_reqs_padded=(
num_reqs_padded if pad_attn or has_separate_kv_update else num_reqs
),
num_tokens_unpadded=num_tokens_unpadded,
ubatch_slices=ubatch_slices_padded,
)
attn_metadata, spec_decode_common_attn_metadata = ( attn_metadata, spec_decode_common_attn_metadata = (
self._build_attention_metadata( self._build_attention_metadata(
num_tokens=num_tokens_unpadded, num_tokens=num_tokens_unpadded,
...@@ -3277,6 +3369,7 @@ class GPUModelRunner( ...@@ -3277,6 +3369,7 @@ class GPUModelRunner(
use_spec_decode=use_spec_decode, use_spec_decode=use_spec_decode,
num_scheduled_tokens=scheduler_output.num_scheduled_tokens, num_scheduled_tokens=scheduler_output.num_scheduled_tokens,
cascade_attn_prefix_lens=cascade_attn_prefix_lens, cascade_attn_prefix_lens=cascade_attn_prefix_lens,
slot_mappings=slot_mappings_by_group,
) )
) )
...@@ -3317,6 +3410,7 @@ class GPUModelRunner( ...@@ -3317,6 +3410,7 @@ class GPUModelRunner(
cudagraph_runtime_mode=cudagraph_mode, cudagraph_runtime_mode=cudagraph_mode,
batch_descriptor=batch_desc, batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings,
skip_compiled=has_encoder_input, skip_compiled=has_encoder_input,
), ),
record_function_or_nullcontext("gpu_model_runner: forward"), record_function_or_nullcontext("gpu_model_runner: forward"),
...@@ -3399,6 +3493,7 @@ class GPUModelRunner( ...@@ -3399,6 +3493,7 @@ class GPUModelRunner(
aux_hidden_states, aux_hidden_states,
ec_connector_output, ec_connector_output,
cudagraph_stats, cudagraph_stats,
slot_mappings,
) )
self.kv_connector_output = kv_connector_output self.kv_connector_output = kv_connector_output
return None return None
...@@ -3435,6 +3530,7 @@ class GPUModelRunner( ...@@ -3435,6 +3530,7 @@ class GPUModelRunner(
aux_hidden_states, aux_hidden_states,
ec_connector_output, ec_connector_output,
cudagraph_stats, cudagraph_stats,
slot_mappings,
) = self.execute_model_state ) = self.execute_model_state
# Clear ephemeral state. # Clear ephemeral state.
self.execute_model_state = None self.execute_model_state = None
...@@ -3468,6 +3564,7 @@ class GPUModelRunner( ...@@ -3468,6 +3564,7 @@ class GPUModelRunner(
aux_hidden_states, aux_hidden_states,
spec_decode_metadata, spec_decode_metadata,
spec_decode_common_attn_metadata, spec_decode_common_attn_metadata,
slot_mappings,
) )
self._copy_draft_token_ids_to_cpu(scheduler_output) self._copy_draft_token_ids_to_cpu(scheduler_output)
...@@ -3676,6 +3773,7 @@ class GPUModelRunner( ...@@ -3676,6 +3773,7 @@ class GPUModelRunner(
aux_hidden_states: list[torch.Tensor] | None, aux_hidden_states: list[torch.Tensor] | None,
spec_decode_metadata: SpecDecodeMetadata | None, spec_decode_metadata: SpecDecodeMetadata | None,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
slot_mappings: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None,
) -> list[list[int]] | torch.Tensor: ) -> list[list[int]] | torch.Tensor:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
spec_config = self.speculative_config spec_config = self.speculative_config
...@@ -3687,11 +3785,14 @@ class GPUModelRunner( ...@@ -3687,11 +3785,14 @@ class GPUModelRunner(
sampled_token_ids, sampled_token_ids,
self.input_batch.num_tokens_no_spec, self.input_batch.num_tokens_no_spec,
self.input_batch.token_ids_cpu, self.input_batch.token_ids_cpu,
slot_mappings=slot_mappings,
) )
elif spec_config.method == "suffix": elif spec_config.method == "suffix":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, SuffixDecodingProposer) assert isinstance(self.drafter, SuffixDecodingProposer)
draft_token_ids = self.drafter.propose(self.input_batch, sampled_token_ids) draft_token_ids = self.drafter.propose(
self.input_batch, sampled_token_ids, slot_mappings=slot_mappings
)
elif spec_config.method == "medusa": elif spec_config.method == "medusa":
assert isinstance(sampled_token_ids, list) assert isinstance(sampled_token_ids, list)
assert isinstance(self.drafter, MedusaProposer) assert isinstance(self.drafter, MedusaProposer)
...@@ -3716,6 +3817,7 @@ class GPUModelRunner( ...@@ -3716,6 +3817,7 @@ class GPUModelRunner(
draft_token_ids = self.drafter.propose( draft_token_ids = self.drafter.propose(
target_hidden_states=hidden_states, target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
slot_mappings=slot_mappings,
) )
elif spec_config.use_eagle() or spec_config.uses_draft_model(): elif spec_config.use_eagle() or spec_config.uses_draft_model():
assert isinstance(self.drafter, EagleProposer | DraftModelProposer) assert isinstance(self.drafter, EagleProposer | DraftModelProposer)
...@@ -3826,6 +3928,7 @@ class GPUModelRunner( ...@@ -3826,6 +3928,7 @@ class GPUModelRunner(
common_attn_metadata=common_attn_metadata, common_attn_metadata=common_attn_metadata,
mm_embed_inputs=mm_embed_inputs, mm_embed_inputs=mm_embed_inputs,
num_rejected_tokens_gpu=num_rejected_tokens_gpu, num_rejected_tokens_gpu=num_rejected_tokens_gpu,
slot_mappings=slot_mappings,
) )
return draft_token_ids return draft_token_ids
...@@ -4406,6 +4509,13 @@ class GPUModelRunner( ...@@ -4406,6 +4509,13 @@ class GPUModelRunner(
attn_metadata: PerLayerAttnMetadata | None = None attn_metadata: PerLayerAttnMetadata | None = None
slot_mappings_by_group, slot_mappings = self._get_slot_mappings(
num_tokens_padded=num_tokens,
num_reqs_padded=num_reqs_padded,
num_tokens_unpadded=num_tokens_unpadded,
ubatch_slices=ubatch_slices_padded,
)
# If force_attention is True, we always capture attention. Otherwise, # If force_attention is True, we always capture attention. Otherwise,
# it only happens for cudagraph_runtime_mode=FULL. # it only happens for cudagraph_runtime_mode=FULL.
if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL: if force_attention or cudagraph_runtime_mode == CUDAGraphMode.FULL:
...@@ -4431,6 +4541,7 @@ class GPUModelRunner( ...@@ -4431,6 +4541,7 @@ class GPUModelRunner(
max_query_len=max_query_len, max_query_len=max_query_len,
ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices, ubatch_slices=ubatch_slices_padded if pad_attn else ubatch_slices,
for_cudagraph_capture=is_graph_capturing, for_cudagraph_capture=is_graph_capturing,
slot_mappings=slot_mappings_by_group,
) )
with self.maybe_dummy_run_with_lora( with self.maybe_dummy_run_with_lora(
...@@ -4499,6 +4610,7 @@ class GPUModelRunner( ...@@ -4499,6 +4610,7 @@ class GPUModelRunner(
cudagraph_runtime_mode=cudagraph_runtime_mode, cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_desc, batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings,
), ),
): ):
outputs = self.model( outputs = self.model(
...@@ -4545,6 +4657,7 @@ class GPUModelRunner( ...@@ -4545,6 +4657,7 @@ class GPUModelRunner(
num_tokens, num_tokens,
use_cudagraphs=use_cudagraphs, use_cudagraphs=use_cudagraphs,
is_graph_capturing=is_graph_capturing, is_graph_capturing=is_graph_capturing,
slot_mappings=slot_mappings,
) )
# We register layerwise NVTX hooks here after the first dynamo tracing is # We register layerwise NVTX hooks here after the first dynamo tracing is
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment