Unverified Commit 7e4be741 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Bug] Batch invariant: Fix flash attn MLA `RuntimeError: scheduler_metadata...

[Bug] Batch invariant: Fix flash attn MLA `RuntimeError: scheduler_metadata must have shape (metadata_size)` (#27884)
parent 380ba681
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib import contextlib
import functools
import os import os
from collections import namedtuple from collections import namedtuple
from collections.abc import Callable from collections.abc import Callable
...@@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize: ...@@ -846,6 +847,7 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
return AttentionBlockSize(block_m=16, block_n=16) return AttentionBlockSize(block_m=16, block_n=16)
@functools.cache
def vllm_is_batch_invariant(): def vllm_is_batch_invariant():
env_key = "VLLM_BATCH_INVARIANT" env_key = "VLLM_BATCH_INVARIANT"
is_overridden = False is_overridden = False
......
...@@ -163,6 +163,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -163,6 +163,9 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# we only set num_splits when using cuda graphs. # we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits max_num_splits = self.max_num_splits
if vllm_is_batch_invariant():
max_num_splits = 1
scheduler_metadata = self._schedule_decode( scheduler_metadata = self._schedule_decode(
num_reqs=seq_lens_cpu.numel(), num_reqs=seq_lens_cpu.numel(),
cu_query_lens=query_start_loc_device, cu_query_lens=query_start_loc_device,
...@@ -188,9 +191,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -188,9 +191,6 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.scheduler_metadata[n:] = 0 self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n] scheduler_metadata = self.scheduler_metadata[:n]
if vllm_is_batch_invariant():
max_num_splits = 1
metadata = FlashAttnMLADecodeMetadata( metadata = FlashAttnMLADecodeMetadata(
block_table=block_table_tensor, block_table=block_table_tensor,
seq_lens=seq_lens_device, seq_lens=seq_lens_device,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment