Unverified Commit c30b405b authored by Benjamin Chislett's avatar Benjamin Chislett Committed by GitHub
Browse files

[Spec Decode] Enable FlashInfer Spec Decoding (#25196)


Signed-off-by: default avatarBenjamin Chislett <benjamin.chislett@centml.ai>
Signed-off-by: default avatarBenjamin Chislett <bchislett@nvidia.com>
Co-authored-by: default avatarlhsjohn <huashuoli@tencent.com>
parent 77d90699
...@@ -9,7 +9,8 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata ...@@ -9,7 +9,8 @@ from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm.v1.attention.backends.utils import (UBatchSlice, from vllm.v1.attention.backends.utils import (UBatchSlice,
_make_metadata_with_slice, _make_metadata_with_slice,
slice_query_start_locs, slice_query_start_locs,
split_attn_metadata) split_attn_metadata,
split_decodes_and_prefills)
from vllm.v1.worker.ubatch_utils import create_ubatch_slices from vllm.v1.worker.ubatch_utils import create_ubatch_slices
...@@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata): ...@@ -158,6 +159,112 @@ def test_split_attn_metadata_decode_batch(large_decode_metadata):
assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point)) assert torch.equal(results[1].seq_lens, torch.tensor([2048] * mid_point))
def apply_split_decodes_and_prefills(query_lens: list[int],
decode_threshold: int,
require_uniform: bool):
"""Helper function to apply split_decodes_and_prefills and return
the results."""
device = torch.device("cpu")
seq_lens = [10 * (i + 1) for i in range(len(query_lens))]
common_metadata = create_common_attn_metadata(BatchSpec(
seq_lens=seq_lens, query_lens=query_lens),
block_size=16,
device=device)
return split_decodes_and_prefills(common_metadata,
decode_threshold=decode_threshold,
require_uniform=require_uniform)
def test_split_decodes_and_prefills_nonuniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, False))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0
def test_split_decodes_and_prefills_nonuniform_all_short_decodes():
query_lens = [1, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 7
assert num_prefills == 0
assert num_decode_tokens == sum(query_lens)
assert num_prefill_tokens == 0
def test_split_decodes_and_prefills_nonuniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, False))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)
def test_split_decodes_and_prefills_nonuniform_mixed_batch():
query_lens = [2, 1, 3, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, False))
assert num_decodes == 4 # 2, 1, 3, 4 are all <= 4
assert num_prefills == 4 # 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 10 # 2 + 1 + 3 + 4
assert num_prefill_tokens == 26 # 5 + 6 + 7 + 8
def test_split_decodes_and_prefills_uniform_all_ones():
query_lens = [1, 1, 1]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 1, True))
assert num_decodes == 3
assert num_prefills == 0
assert num_decode_tokens == 3
assert num_prefill_tokens == 0
def test_split_decodes_and_prefills_uniform_all_short_decodes():
query_lens = [2, 2, 1, 3, 2, 1, 2]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 2
assert num_prefills == 5
assert num_decode_tokens == 4
assert num_prefill_tokens == (1 + 3 + 2 + 1 + 2)
def test_split_decodes_and_prefills_uniform_all_prefills():
query_lens = [4, 5, 6, 7]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 3, True))
assert num_decodes == 0
assert num_prefills == 4
assert num_decode_tokens == 0
assert num_prefill_tokens == sum(query_lens)
def test_split_decodes_and_prefills_uniform_mixed_batch_all_uniform_decodes():
query_lens = [2, 2, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 3 # 2, 2, 2 are all <= 4 and uniform
assert num_prefills == 5 # 4, 5, 6, 7, 8 are all > 4
assert num_decode_tokens == 6 # 2 + 2 + 2
assert num_prefill_tokens == 30 # 4 + 5 + 6 + 7 + 8
def test_split_decodes_and_prefills_uniform_mixed_batch_non_uniform_decodes():
query_lens = [2, 1, 2, 4, 5, 6, 7, 8]
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
apply_split_decodes_and_prefills(query_lens, 4, True))
assert num_decodes == 1 # only the first 2 is taken as decode
assert num_prefills == 7 # 1, 2, 4, 5, 6, 7, 8 are all > 4 or non-uniform
assert num_decode_tokens == 2 # only the first 2
assert num_prefill_tokens == (sum(query_lens) - 2) # rest of the tokens
@pytest.mark.parametrize( @pytest.mark.parametrize(
"seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs", "seq_lens,query_lens,split_point,expected_first_reqs,expected_second_reqs",
[ [
......
...@@ -181,6 +181,12 @@ def force_use_trtllm_attention() -> Optional[bool]: ...@@ -181,6 +181,12 @@ def force_use_trtllm_attention() -> Optional[bool]:
return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION) return _force_use_trtllm_attention(envs.VLLM_USE_TRTLLM_ATTENTION)
def can_use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
"""Check if the current configuration supports TRTLLM attention."""
has_trtllm = supports_trtllm_attention()
return has_trtllm and (num_qo_heads % num_kv_heads == 0)
def use_trtllm_attention( def use_trtllm_attention(
num_qo_heads: int, num_qo_heads: int,
num_kv_heads: int, num_kv_heads: int,
...@@ -188,7 +194,9 @@ def use_trtllm_attention( ...@@ -188,7 +194,9 @@ def use_trtllm_attention(
max_seq_len: int, max_seq_len: int,
kv_cache_dtype: str, kv_cache_dtype: str,
q_dtype: torch.dtype, q_dtype: torch.dtype,
is_prefill: bool,
has_sinks: bool = False, has_sinks: bool = False,
has_spec: bool = False,
) -> bool: ) -> bool:
"""Return ``True`` if TRTLLM attention is used.""" """Return ``True`` if TRTLLM attention is used."""
force_use_trtllm = force_use_trtllm_attention() force_use_trtllm = force_use_trtllm_attention()
...@@ -214,6 +222,12 @@ def use_trtllm_attention( ...@@ -214,6 +222,12 @@ def use_trtllm_attention(
) )
return False return False
if has_spec and not is_prefill:
# Speculative decoding requires TRTLLM attention for decodes
logger.info_once(
"Using TRTLLM attention (enabled for speculative decoding).")
return True
# Must use TRTLLM attention if query is FP8 quantized # Must use TRTLLM attention if query is FP8 quantized
if q_dtype == current_platform.fp8_dtype(): if q_dtype == current_platform.fp8_dtype():
if has_sinks: if has_sinks:
...@@ -391,6 +405,7 @@ __all__ = [ ...@@ -391,6 +405,7 @@ __all__ = [
"has_flashinfer_cutlass_fused_moe", "has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory", "has_nvidia_artifactory",
"supports_trtllm_attention", "supports_trtllm_attention",
"can_use_trtllm_attention",
"use_trtllm_attention", "use_trtllm_attention",
"flashinfer_disable_q_quantization", "flashinfer_disable_q_quantization",
"flashinfer_scaled_fp4_mm", "flashinfer_scaled_fp4_mm",
......
...@@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -25,7 +25,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, is_pin_memory_available from vllm.utils import cdiv, is_pin_memory_available
from vllm.utils.flashinfer import (flashinfer_disable_q_quantization, from vllm.utils.flashinfer import (can_use_trtllm_attention,
flashinfer_disable_q_quantization,
supports_trtllm_attention, supports_trtllm_attention,
use_trtllm_attention) use_trtllm_attention)
from vllm.v1.attention.backends.flash_attn import use_cascade_attention from vllm.v1.attention.backends.flash_attn import use_cascade_attention
...@@ -223,6 +224,7 @@ class FlashInferMetadata: ...@@ -223,6 +224,7 @@ class FlashInferMetadata:
# For flashinfer trtllm batch decode # For flashinfer trtllm batch decode
max_q_len: int max_q_len: int
max_q_len_prefill: int
max_seq_len: int max_seq_len: int
seq_lens: torch.Tensor seq_lens: torch.Tensor
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
...@@ -250,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -250,7 +252,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -302,6 +304,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -302,6 +304,10 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
else: else:
self.q_data_type = self.model_config.dtype self.q_data_type = self.model_config.dtype
supports_spec_as_decode = \
can_use_trtllm_attention(self.num_qo_heads, self.num_kv_heads)
self._init_reorder_batch_threshold(1, supports_spec_as_decode)
self._cascade_wrapper = None # Wrapper for cascade attention self._cascade_wrapper = None # Wrapper for cascade attention
# Global hyperparameters shared by all attention layers # Global hyperparameters shared by all attention layers
...@@ -416,7 +422,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -416,7 +422,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_actual_tokens = common_attn_metadata.num_actual_tokens num_actual_tokens = common_attn_metadata.num_actual_tokens
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\
split_decodes_and_prefills(common_attn_metadata, split_decodes_and_prefills(common_attn_metadata,
decode_threshold=self.reorder_batch_threshold) decode_threshold=self.reorder_batch_threshold,
require_uniform=True)
page_size = self.page_size page_size = self.page_size
max_q_len = common_attn_metadata.max_query_len max_q_len = common_attn_metadata.max_query_len
...@@ -491,20 +498,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -491,20 +498,25 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
paged_kv_last_page_len_np, paged_kv_last_page_len_np,
) )
uses_spec_reorder = self.reorder_batch_threshold > 1
prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads, prefill_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads, self.num_kv_heads,
num_prefill_tokens, num_prefill_tokens,
max_seq_len, max_seq_len,
self.cache_dtype, self.cache_dtype,
self.q_data_type, self.q_data_type,
has_sinks=self.has_sinks) is_prefill=True,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder)
decode_use_trtllm = use_trtllm_attention(self.num_qo_heads, decode_use_trtllm = use_trtllm_attention(self.num_qo_heads,
self.num_kv_heads, self.num_kv_heads,
num_decode_tokens, num_decode_tokens,
max_seq_len, max_seq_len,
self.cache_dtype, self.cache_dtype,
self.q_data_type, self.q_data_type,
has_sinks=self.has_sinks) is_prefill=False,
has_sinks=self.has_sinks,
has_spec=uses_spec_reorder)
if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm): if self.has_sinks and not (prefill_use_trtllm and decode_use_trtllm):
raise NotImplementedError( raise NotImplementedError(
"FlashInfer backend currently does not support attention " "FlashInfer backend currently does not support attention "
...@@ -521,6 +533,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -521,6 +533,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
q_data_type=self.q_data_type, q_data_type=self.q_data_type,
slot_mapping=common_attn_metadata.slot_mapping, slot_mapping=common_attn_metadata.slot_mapping,
max_q_len=max_q_len, max_q_len=max_q_len,
max_q_len_prefill=max_q_len,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
seq_lens=seq_lens, seq_lens=seq_lens,
block_table_tensor=block_table_tensor, block_table_tensor=block_table_tensor,
...@@ -577,6 +590,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -577,6 +590,15 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[ qo_indptr_cpu = qo_indptr_cpu[prefill_start:] - qo_indptr_cpu[
prefill_start] prefill_start]
paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:] paged_kv_indptr_cpu = paged_kv_indptr_cpu[prefill_start:]
# Recompute max_q_len for the slice of requests we are using
# for prefills. This can be different from max_q_len when
# we have a non-uniform batch with some short decodes offloaded
# to the prefill pathway
query_lens_prefill = qo_indptr_cpu[1:] - qo_indptr_cpu[:-1]
attn_metadata.max_q_len_prefill = \
int(query_lens_prefill.max().item())
if not attn_metadata.prefill_use_trtllm: if not attn_metadata.prefill_use_trtllm:
attn_metadata.prefill_wrapper.plan( attn_metadata.prefill_wrapper.plan(
qo_indptr_cpu, qo_indptr_cpu,
...@@ -607,7 +629,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -607,7 +629,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_decodes <= self._decode_cudagraph_max_bs) num_decodes <= self._decode_cudagraph_max_bs)
if use_cudagraph: if use_cudagraph:
num_input_tokens = ( num_input_tokens = (
self.vllm_config.pad_for_cudagraph(num_decodes)) self.vllm_config.pad_for_cudagraph(num_decode_tokens))
# Carefully fulfill the padding region with reasonable value # Carefully fulfill the padding region with reasonable value
# on cpu. # on cpu.
# Make sure paged_kv_indptr_cpu is not decreasing # Make sure paged_kv_indptr_cpu is not decreasing
...@@ -621,7 +643,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -621,7 +643,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_decodes:num_input_tokens].fill_(1) num_decodes:num_input_tokens].fill_(1)
else: else:
num_input_tokens = num_decodes num_input_tokens = num_decode_tokens
attn_metadata.decode_wrapper = self._get_decode_wrapper( attn_metadata.decode_wrapper = self._get_decode_wrapper(
num_input_tokens, use_cudagraph) num_input_tokens, use_cudagraph)
...@@ -842,6 +864,9 @@ class FlashInferImpl(AttentionImpl): ...@@ -842,6 +864,9 @@ class FlashInferImpl(AttentionImpl):
output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache)) output.copy_(attn_metadata.cascade_wrapper.run(query, kv_cache))
return output return output
# When using spec decoding, num_decodes can be < num_decode_tokens
# because some decode requests may have more than one query token.
num_decodes = attn_metadata.num_decodes
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
...@@ -874,8 +899,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -874,8 +899,8 @@ class FlashInferImpl(AttentionImpl):
prefill_query = prefill_query.contiguous() prefill_query = prefill_query.contiguous()
workspace_buffer = _get_trtllm_gen_workspace_buffer() workspace_buffer = _get_trtllm_gen_workspace_buffer()
block_tables_prefill = attn_metadata.block_table_tensor[ block_tables_prefill = attn_metadata.block_table_tensor[
num_decode_tokens:] num_decodes:]
seq_lens_prefill = attn_metadata.seq_lens[num_decode_tokens:] seq_lens_prefill = attn_metadata.seq_lens[num_decodes:]
# This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND # This path needs to be enabled with VLLM_KV_CACHE_LAYOUT = HND
assert get_kv_cache_layout() == "HND" assert get_kv_cache_layout() == "HND"
...@@ -919,7 +944,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -919,7 +944,7 @@ class FlashInferImpl(AttentionImpl):
workspace_buffer=workspace_buffer, workspace_buffer=workspace_buffer,
block_tables=mock_block_table, block_tables=mock_block_table,
seq_lens=seq_lens_prefill, seq_lens=seq_lens_prefill,
max_q_len=attn_metadata.max_q_len, max_q_len=attn_metadata.max_q_len_prefill,
max_kv_len=attn_metadata.max_seq_len, max_kv_len=attn_metadata.max_seq_len,
bmm1_scale=self.bmm1_scale, bmm1_scale=self.bmm1_scale,
bmm2_scale=self.bmm2_scale, bmm2_scale=self.bmm2_scale,
...@@ -976,6 +1001,14 @@ class FlashInferImpl(AttentionImpl): ...@@ -976,6 +1001,14 @@ class FlashInferImpl(AttentionImpl):
assert self.o_sf_scale is None assert self.o_sf_scale is None
out = output[:num_decode_tokens] out = output[:num_decode_tokens]
if num_decode_tokens % attn_metadata.num_decodes != 0:
# This gets triggered when the dummy_run forces
# attention to be initialized with q_len = 0
q_len_per_req = 1
else:
q_len_per_req = \
num_decode_tokens // attn_metadata.num_decodes
trtllm_batch_decode_with_kv_cache( trtllm_batch_decode_with_kv_cache(
query=decode_query, query=decode_query,
kv_cache=kv_cache_permute, kv_cache=kv_cache_permute,
...@@ -989,7 +1022,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -989,7 +1022,7 @@ class FlashInferImpl(AttentionImpl):
sinks=self.sinks, sinks=self.sinks,
o_sf_scale=self.o_sf_scale, o_sf_scale=self.o_sf_scale,
out=out, out=out,
) q_len_per_req=q_len_per_req)
return output_padded return output_padded
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Backend for GatedDeltaNet attention.""" """Backend for GatedDeltaNet attention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional from typing import Optional
import torch import torch
...@@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder( ...@@ -62,7 +62,7 @@ class GDNAttentionMetadataBuilder(
cudagraph_support = AttentionCGSupport.UNIFORM_BATCH cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -76,7 +76,7 @@ class GDNAttentionMetadataBuilder( ...@@ -76,7 +76,7 @@ class GDNAttentionMetadataBuilder(
else: else:
self.num_spec = 0 self.num_spec = 0
self.use_spec_decode = self.num_spec > 0 self.use_spec_decode = self.num_spec > 0
self.reorder_batch_threshold = self.num_spec + 1 # type: ignore[misc] self._init_reorder_batch_threshold(1, self.use_spec_decode)
self.use_full_cuda_graph = \ self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs() self.compilation_config.cudagraph_mode.has_full_cudagraphs()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar
import torch import torch
...@@ -35,7 +34,7 @@ class LinearAttentionMetadata: ...@@ -35,7 +34,7 @@ class LinearAttentionMetadata:
class LinearAttentionMetadataBuilder( class LinearAttentionMetadataBuilder(
AttentionMetadataBuilder[LinearAttentionMetadata]): AttentionMetadataBuilder[LinearAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
......
...@@ -16,7 +16,7 @@ M = TypeVar("M") ...@@ -16,7 +16,7 @@ M = TypeVar("M")
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: int = 1
cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
......
...@@ -190,7 +190,7 @@ return curr_o @ W_O ...@@ -190,7 +190,7 @@ return curr_o @ W_O
import functools import functools
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import ClassVar, Generic, Optional, TypeVar, Union from typing import Generic, Optional, TypeVar, Union
import torch import torch
from tqdm import tqdm from tqdm import tqdm
...@@ -434,7 +434,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]): ...@@ -434,7 +434,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
""" """
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: int = 1
@staticmethod @staticmethod
def determine_chunked_prefill_workspace_size( def determine_chunked_prefill_workspace_size(
......
...@@ -64,7 +64,7 @@ class FlashAttnMLAMetadataBuilder( ...@@ -64,7 +64,7 @@ class FlashAttnMLAMetadataBuilder(
cudagraph_support: ClassVar[AttentionCGSupport] = \ cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[int] = 512 reorder_batch_threshold: int = 512
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
...@@ -99,7 +99,7 @@ class FlashAttnMLAMetadataBuilder( ...@@ -99,7 +99,7 @@ class FlashAttnMLAMetadataBuilder(
# TODO(lucas): Until we add support for the DCP custom masking we need # TODO(lucas): Until we add support for the DCP custom masking we need
# to restrict decodes to q_len == 1 when DCP is enabled. # to restrict decodes to q_len == 1 when DCP is enabled.
self.__class__.reorder_batch_threshold = 1 \ self.reorder_batch_threshold = 1 \
if get_dcp_group().world_size > 1 else self.reorder_batch_threshold if get_dcp_group().world_size > 1 else self.reorder_batch_threshold
def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import ClassVar, Optional from typing import Optional
import torch import torch
...@@ -41,7 +41,7 @@ class ShortConvAttentionMetadata: ...@@ -41,7 +41,7 @@ class ShortConvAttentionMetadata:
class ShortConvAttentionMetadataBuilder( class ShortConvAttentionMetadataBuilder(
AttentionMetadataBuilder[ShortConvAttentionMetadata]): AttentionMetadataBuilder[ShortConvAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: int = 1
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device): vllm_config: VllmConfig, device: torch.device):
......
...@@ -236,7 +236,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -236,7 +236,7 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
# Does this backend/builder reorder the batch? # Does this backend/builder reorder the batch?
# If not, set this to None. Otherwise set it to the query # If not, set this to None. Otherwise set it to the query
# length that will be pulled into the front of the batch. # length that will be pulled into the front of the batch.
reorder_batch_threshold: ClassVar[Optional[int]] = None reorder_batch_threshold: Optional[int] = None
@abstractmethod @abstractmethod
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
...@@ -246,6 +246,22 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]): ...@@ -246,6 +246,22 @@ class AttentionMetadataBuilder(abc.ABC, Generic[M]):
self.vllm_config = vllm_config self.vllm_config = vllm_config
self.device = device self.device = device
def _init_reorder_batch_threshold(
self,
reorder_batch_threshold: int = 1,
supports_spec_as_decode: bool = False) -> None:
self.reorder_batch_threshold = reorder_batch_threshold
if self.reorder_batch_threshold is not None \
and supports_spec_as_decode:
# If the backend supports spec-as-decode kernels, then we can set
# the reorder_batch_threshold based on the number of speculative
# tokens from the config.
speculative_config = self.vllm_config.speculative_config
if (speculative_config is not None
and speculative_config.num_speculative_tokens is not None):
self.reorder_batch_threshold = \
1 + speculative_config.num_speculative_tokens
@abstractmethod @abstractmethod
def build(self, def build(self,
common_prefix_len: int, common_prefix_len: int,
...@@ -705,7 +721,7 @@ def subclass_attention_backend( ...@@ -705,7 +721,7 @@ def subclass_attention_backend(
def split_decodes_and_prefills( def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
decode_threshold: int = 1, decode_threshold: int = 1,
) -> tuple[int, int, int, int]: require_uniform: bool = False) -> tuple[int, int, int, int]:
""" """
Assuming a reordered batch, finds the boundary between prefill and decode Assuming a reordered batch, finds the boundary between prefill and decode
requests. requests.
...@@ -714,6 +730,9 @@ def split_decodes_and_prefills( ...@@ -714,6 +730,9 @@ def split_decodes_and_prefills(
common_attn_metadata: CommonAttentionMetadata object containing the common_attn_metadata: CommonAttentionMetadata object containing the
batch metadata. batch metadata.
decode_threshold: The maximum query length to be considered a decode. decode_threshold: The maximum query length to be considered a decode.
require_uniform: If True, requires that all decode requests have the
same query length. When set, some queries may be considered prefills
even if they are <= decode_threshold, in order to ensure uniformity.
Returns: Returns:
num_decodes: The number of decode requests. num_decodes: The number of decode requests.
...@@ -726,11 +745,20 @@ def split_decodes_and_prefills( ...@@ -726,11 +745,20 @@ def split_decodes_and_prefills(
num_tokens = common_attn_metadata.num_actual_tokens num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc_cpu query_start_loc = common_attn_metadata.query_start_loc_cpu
if max_query_len <= decode_threshold: if max_query_len <= decode_threshold and \
(not require_uniform or decode_threshold <= 1):
return num_reqs, 0, num_tokens, 0 return num_reqs, 0, num_tokens, 0
query_lens = query_start_loc[1:] - query_start_loc[:-1] query_lens = query_start_loc[1:] - query_start_loc[:-1]
if query_lens[0].item() > decode_threshold:
# first request is not decode, so no decode requests
return 0, num_reqs, 0, num_tokens
if require_uniform:
is_prefill = query_lens != query_lens[0]
else:
is_prefill = query_lens > decode_threshold is_prefill = query_lens > decode_threshold
if not torch.any(is_prefill): if not torch.any(is_prefill):
return num_reqs, 0, num_tokens, 0 return num_reqs, 0, num_tokens, 0
...@@ -806,6 +834,38 @@ def reorder_batch_to_split_decodes_and_prefills( ...@@ -806,6 +834,38 @@ def reorder_batch_to_split_decodes_and_prefills(
return modified_batch return modified_batch
def reshape_query_for_spec_decode(query: torch.Tensor,
batch_size: int) -> torch.Tensor:
"""
Reshapes the query tensor for the specified batch size, so that
it has shape (batch_size, seq_len, num_heads, head_dim).
"""
assert query.dim() == 3, f"query must be 3D, got {query.dim()}D"
total_tokens = query.shape[0]
num_heads = query.shape[1]
head_dim = query.shape[2]
assert total_tokens % batch_size == 0, (
f"{total_tokens=} is not divisible by {batch_size=}")
seq_len = total_tokens // batch_size
return query.view(batch_size, seq_len, num_heads, head_dim)
def reshape_attn_output_for_spec_decode(
attn_output: torch.Tensor) -> torch.Tensor:
"""
Reshapes the attention output tensor, so that
the batch_size and seq_len dimensions are combined.
"""
if attn_output.dim() == 3:
# Already in the correct shape
return attn_output
assert attn_output.dim() == 4, \
f"attn_output must be 4D, got {attn_output.dim()}D"
total_tokens = attn_output.shape[0] * attn_output.shape[1]
return attn_output.view(total_tokens, attn_output.shape[2],
attn_output.shape[3])
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [ KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
('logits_indices_padded', Optional[torch.Tensor], None), ('logits_indices_padded', Optional[torch.Tensor], None),
('num_logits_indices', int, 0), ('num_logits_indices', int, 0),
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
"""Attention layer with XFormersAttention.""" """Attention layer with XFormersAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -197,7 +197,7 @@ class XFormersAttentionMetadata: ...@@ -197,7 +197,7 @@ class XFormersAttentionMetadata:
class XFormersAttentionMetadataBuilder( class XFormersAttentionMetadataBuilder(
AttentionMetadataBuilder[XFormersAttentionMetadata]): AttentionMetadataBuilder[XFormersAttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1 reorder_batch_threshold: int = 1
def __init__( def __init__(
self, self,
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import ast import ast
from dataclasses import replace from dataclasses import replace
from importlib.util import find_spec from importlib.util import find_spec
from typing import Optional, Protocol from typing import Optional
import numpy as np import numpy as np
import torch import torch
...@@ -37,17 +37,6 @@ logger = init_logger(__name__) ...@@ -37,17 +37,6 @@ logger = init_logger(__name__)
PADDING_SLOT_ID = -1 PADDING_SLOT_ID = -1
class EagleAttentionMetadata(Protocol):
# Required attributes
num_actual_tokens: int
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
class EagleProposer: class EagleProposer:
def __init__( def __init__(
...@@ -120,7 +109,7 @@ class EagleProposer: ...@@ -120,7 +109,7 @@ class EagleProposer:
with_numpy=True) with_numpy=True)
# Determine allowed attention backends once during initialization. # Determine allowed attention backends once during initialization.
self.allowed_attn_types: tuple[type, ...] self.allowed_attn_types: Optional[tuple] = None
if current_platform.is_rocm(): if current_platform.is_rocm():
rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata] rocm_types = [TritonAttentionMetadata, FlashAttentionMetadata]
# vllm.v1.attention.backends.rocm_aiter_fa is an optional backend # vllm.v1.attention.backends.rocm_aiter_fa is an optional backend
...@@ -129,9 +118,6 @@ class EagleProposer: ...@@ -129,9 +118,6 @@ class EagleProposer:
AiterFlashAttentionMetadata) AiterFlashAttentionMetadata)
rocm_types.append(AiterFlashAttentionMetadata) rocm_types.append(AiterFlashAttentionMetadata)
self.allowed_attn_types = tuple(rocm_types) self.allowed_attn_types = tuple(rocm_types)
else:
self.allowed_attn_types = (FlashAttentionMetadata,
TreeAttentionMetadata)
# Parse the speculative token tree. # Parse the speculative token tree.
spec_token_tree = self.speculative_config.speculative_token_tree spec_token_tree = self.speculative_config.speculative_token_tree
...@@ -266,7 +252,8 @@ class EagleProposer: ...@@ -266,7 +252,8 @@ class EagleProposer:
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = logits.argmax(dim=-1)
if not isinstance(attn_metadata, self.allowed_attn_types): if self.allowed_attn_types is not None and \
not isinstance(attn_metadata, self.allowed_attn_types):
raise ValueError( raise ValueError(
f"Unsupported attention metadata type for speculative " f"Unsupported attention metadata type for speculative "
"decoding with num_speculative_tokens > 1: " "decoding with num_speculative_tokens > 1: "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment