Unverified Commit 18961c5e authored by Thomas Parnell, committed by GitHub
Browse files

[Hybrid] Pass kernel block size to builders (#27753)


Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent 470ad118
...@@ -62,7 +62,11 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -62,7 +62,11 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod @staticmethod
def get_supported_kernel_block_size() -> list[int | MultipleOf]: def get_supported_kernel_block_size() -> list[int | MultipleOf]:
return [MultipleOf(16)] # NOTE(tdoublep): while in principle, FA supports
# MultipleOf(16), these are the block sizes that do not
# suffer from the NaN propagation problem described here:
# https://github.com/Dao-AILab/flash-attention/issues/1974
return [16, 32, 64]
@classmethod @classmethod
def validate_head_size(cls, head_size: int) -> None: def validate_head_size(cls, head_size: int) -> None:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy import copy
from dataclasses import dataclass, fields from dataclasses import dataclass, fields, replace
from math import prod from math import prod
import torch import torch
...@@ -44,6 +44,12 @@ class KVCacheSpec: ...@@ -44,6 +44,12 @@ class KVCacheSpec:
""" """
raise NotImplementedError raise NotImplementedError
def copy_with_new_block_size(self, block_size: int) -> Self:
"""
Create a new KVCacheSpec from self but replacing the block size.
"""
return replace(self, block_size=block_size)
@classmethod @classmethod
def merge(cls, specs: list[Self]) -> Self: def merge(cls, specs: list[Self]) -> Self:
""" """
......
...@@ -4039,16 +4039,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4039,16 +4039,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) -> list[AttentionGroup]: ) -> list[AttentionGroup]:
attn_groups: list[AttentionGroup] = [] attn_groups: list[AttentionGroup] = []
for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items(): for (attn_backend, kv_cache_spec), layer_names in attn_backends_map.items():
attn_group = AttentionGroup.create_with_metadata_builders( attn_group = AttentionGroup(
attn_backend, attn_backend,
layer_names, layer_names,
kv_cache_spec, kv_cache_spec,
self.vllm_config,
self.device,
kv_cache_group_id, kv_cache_group_id,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
) )
attn_groups.append(attn_group) attn_groups.append(attn_group)
...@@ -4067,7 +4062,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4067,7 +4062,27 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for i, attn_backend_map in enumerate(attention_backend_maps): for i, attn_backend_map in enumerate(attention_backend_maps):
self.attn_groups.append(create_attn_groups(attn_backend_map, i)) self.attn_groups.append(create_attn_groups(attn_backend_map, i))
def initialize_metadata_builders(
self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int]
) -> None:
"""
Create the metadata builders for all KV cache groups and attn groups.
"""
for kv_cache_group_id in range(len(kv_cache_config.kv_cache_groups)):
for attn_group in self.attn_groups[kv_cache_group_id]:
attn_group.create_metadata_builders(
self.vllm_config,
self.device,
kernel_block_sizes[kv_cache_group_id]
if kv_cache_group_id < len(kernel_block_sizes)
else None,
num_metadata_builders=1
if not self.parallel_config.enable_dbo
else 2,
)
# Calculate reorder batch threshold (if needed) # Calculate reorder batch threshold (if needed)
# Note (tdoublep): do this *after* constructing builders,
# because some of them change the threshold at init time.
self.calculate_reorder_batch_threshold() self.calculate_reorder_batch_threshold()
def _check_and_update_cudagraph_mode( def _check_and_update_cudagraph_mode(
...@@ -4633,6 +4648,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -4633,6 +4648,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# kernel_block_size 64 and split the 256-token-block to 4 blocks with 64 # kernel_block_size 64 and split the 256-token-block to 4 blocks with 64
# tokens each. # tokens each.
kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config) kernel_block_sizes = self._prepare_kernel_block_sizes(kv_cache_config)
# create metadata builders
self.initialize_metadata_builders(kv_cache_config, kernel_block_sizes)
# Reinitialize need to after initialize_attn_backend # Reinitialize need to after initialize_attn_backend
self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes) self.may_reinitialize_input_batch(kv_cache_config, kernel_block_sizes)
kv_caches = self.initialize_kv_cache_tensors( kv_caches = self.initialize_kv_cache_tensors(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import torch import torch
...@@ -134,31 +134,37 @@ class MultiModalBudget: ...@@ -134,31 +134,37 @@ class MultiModalBudget:
@dataclass @dataclass
class AttentionGroup: class AttentionGroup:
backend: type[AttentionBackend] backend: type[AttentionBackend]
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder]
layer_names: list[str] layer_names: list[str]
kv_cache_spec: KVCacheSpec kv_cache_spec: KVCacheSpec
kv_cache_group_id: int kv_cache_group_id: int
# When ubatching is enabled we will have a metadata builder for each ubatch
# so that if they use internal persistant buffers for cudagraphs, and they
# won't have to worry about conflicting with the other ubatches.
metadata_builders: list[AttentionMetadataBuilder] = field(
default_factory=lambda: []
)
@staticmethod def create_metadata_builders(
def create_with_metadata_builders( self,
backend: type[AttentionBackend], vllm_config,
layer_names: list[str], device,
kv_cache_spec: KVCacheSpec, kernel_block_size: int | None,
vllm_config: VllmConfig,
device: torch.device,
kv_cache_group_id: int,
num_metadata_builders: int = 1, num_metadata_builders: int = 1,
) -> "AttentionGroup": ):
metadata_builders = [ kv_cache_spec_builder = (
backend.get_builder_cls()(kv_cache_spec, layer_names, vllm_config, device) self.kv_cache_spec.copy_with_new_block_size(kernel_block_size)
if kernel_block_size is not None
else self.kv_cache_spec
)
self.metadata_builders = [
self.backend.get_builder_cls()(
kv_cache_spec_builder,
self.layer_names,
vllm_config,
device,
)
for _ in range(num_metadata_builders) for _ in range(num_metadata_builders)
] ]
return AttentionGroup(
backend, metadata_builders, layer_names, kv_cache_spec, kv_cache_group_id
)
def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder: def get_metadata_builder(self, ubatch_id: int = 0) -> AttentionMetadataBuilder:
assert len(self.metadata_builders) > ubatch_id assert len(self.metadata_builders) > ubatch_id
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment