"vscode:/vscode.git/clone" did not exist on "9352cdb56d70bd52d4e6ea88d991bf5f4cc93393"
Unverified commit d43ad5a7, authored by Lucas Wilkinson, committed by GitHub
Browse files

[BugFix] Fix DCP Assert (AssertionError: DCP not support reorder_batch_threshold > 1 now.) (#28100)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 0ff05e37
...@@ -545,6 +545,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -545,6 +545,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
vllm_config: VllmConfig, vllm_config: VllmConfig,
device: torch.device, device: torch.device,
metadata_cls: type[M] | None = None, metadata_cls: type[M] | None = None,
supports_dcp_with_varlen: bool = False,
): ):
self.metadata_cls = ( self.metadata_cls = (
metadata_cls if metadata_cls is not None else MLACommonMetadata metadata_cls if metadata_cls is not None else MLACommonMetadata
...@@ -638,7 +639,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -638,7 +639,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
self._init_reorder_batch_threshold( self._init_reorder_batch_threshold(
self.reorder_batch_threshold, supports_spec_decode self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
) )
# Validate consistency between query_len_support and reorder_batch_threshold # Validate consistency between query_len_support and reorder_batch_threshold
......
...@@ -81,7 +81,12 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -81,7 +81,12 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
device: torch.device, device: torch.device,
): ):
super().__init__( super().__init__(
kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata kv_cache_spec,
layer_names,
vllm_config,
device,
FlashAttnMLAMetadata,
supports_dcp_with_varlen=True,
) )
self.max_num_splits = 0 # No upper bound on the number of splits. self.max_num_splits = 0 # No upper bound on the number of splits.
self.fa_aot_schedule = get_flash_attn_version() == 3 self.fa_aot_schedule = get_flash_attn_version() == 3
......
...@@ -264,7 +264,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -264,7 +264,10 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
self.device = device self.device = device
def _init_reorder_batch_threshold( def _init_reorder_batch_threshold(
self, reorder_batch_threshold: int = 1, supports_spec_as_decode: bool = False self,
reorder_batch_threshold: int = 1,
supports_spec_as_decode: bool = False,
supports_dcp_with_varlen: bool = False,
) -> None: ) -> None:
self.reorder_batch_threshold = reorder_batch_threshold self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None and supports_spec_as_decode: if self.reorder_batch_threshold is not None and supports_spec_as_decode:
...@@ -281,6 +284,12 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -281,6 +284,12 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
1 + speculative_config.num_speculative_tokens, 1 + speculative_config.num_speculative_tokens,
) )
if (
self.vllm_config.parallel_config.decode_context_parallel_size > 1
and not supports_dcp_with_varlen
):
self.reorder_batch_threshold = 1
@abstractmethod @abstractmethod
def build( def build(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment