Unverified commit 1d93f116, authored by Matthew Bonanni, committed by GitHub
Browse files

[Attention][CUDAGraph] Remove CG padding from attention backends (#29352)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 2d613de9
...@@ -252,7 +252,6 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -252,7 +252,6 @@ class MambaMixer(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
has_initial_states_p = attn_metadata.has_initial_states_p has_initial_states_p = attn_metadata.has_initial_states_p
num_padded_decodes = attn_metadata.num_padded_decodes
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
...@@ -281,7 +280,7 @@ class MambaMixer(MambaBase, CustomOp): ...@@ -281,7 +280,7 @@ class MambaMixer(MambaBase, CustomOp):
state_indices_tensor, state_indices_tensor,
num_prefill_tokens, num_prefill_tokens,
num_prefills, num_prefills,
num_padded_decodes, num_decode_tokens,
) )
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
...@@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode( ...@@ -470,24 +469,24 @@ def split_batch_to_prefill_and_decode(
state_indices_tensor: torch.Tensor, state_indices_tensor: torch.Tensor,
num_prefill_tokens: int, num_prefill_tokens: int,
num_prefills: int, num_prefills: int,
num_padded_decodes: int, num_decode_tokens: int,
) -> PrefillDecodeSplit: ) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes num_actual_tokens = num_prefill_tokens + num_decode_tokens
# In v1, decode tokens come first, then prefill tokens. # In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split( hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC[..., :num_actual_tokens], hidden_states_BC[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens], [num_decode_tokens, num_prefill_tokens],
dim=-1, dim=-1,
) )
gate_d, gate_p = torch.split( gate_d, gate_p = torch.split(
gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1 gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
) )
# num_padded_decodes accounts for CUDA graph padding when applicable # num_decode_tokens accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split( state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[: num_padded_decodes + num_prefills], state_indices_tensor[: num_decode_tokens + num_prefills],
[num_padded_decodes, num_prefills], [num_decode_tokens, num_prefills],
dim=0, dim=0,
) )
......
...@@ -254,17 +254,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -254,17 +254,11 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
) )
else: else:
has_initial_state = None has_initial_state = None
num_actual_tokens = (
num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
)
# prepare tensors for cudagraph # Prepare tensors for cudagraph
# # Note: m.num_actual_tokens is already padded by the model runner for CUDAGraph
# With speculative decoding, the xgrammar backend may rollback tokens batch_size = m.num_actual_tokens
# and causing some sequences has less draft tokens than self.num_spec.
#
# In above cases, the max possible batch size for n tokens, can be
# min(n, cudagraph_max_bs).
if ( if (
self.use_full_cuda_graph self.use_full_cuda_graph
and num_prefills == 0 and num_prefills == 0
...@@ -272,9 +266,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -272,9 +266,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and num_spec_decodes <= self.decode_cudagraph_max_bs and num_spec_decodes <= self.decode_cudagraph_max_bs
and num_spec_decode_tokens <= self.decode_cudagraph_max_bs and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
): ):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
self.spec_state_indices_tensor[:num_spec_decodes].copy_( self.spec_state_indices_tensor[:num_spec_decodes].copy_(
spec_state_indices_tensor, non_blocking=True spec_state_indices_tensor, non_blocking=True
) )
...@@ -319,9 +310,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -319,9 +310,6 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
and num_spec_decodes == 0 and num_spec_decodes == 0
and num_decodes <= self.decode_cudagraph_max_bs and num_decodes <= self.decode_cudagraph_max_bs
): ):
num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
batch_size = num_actual_tokens
self.non_spec_state_indices_tensor[:num_decodes].copy_( self.non_spec_state_indices_tensor[:num_decodes].copy_(
non_spec_state_indices_tensor, non_blocking=True non_spec_state_indices_tensor, non_blocking=True
) )
...@@ -344,7 +332,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata] ...@@ -344,7 +332,7 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
num_spec_decodes=num_spec_decodes, num_spec_decodes=num_spec_decodes,
num_spec_decode_tokens=num_spec_decode_tokens, num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=num_actual_tokens, num_actual_tokens=m.num_actual_tokens,
has_initial_state=has_initial_state, has_initial_state=has_initial_state,
spec_query_start_loc=spec_query_start_loc, spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc, non_spec_query_start_loc=non_spec_query_start_loc,
......
...@@ -31,7 +31,6 @@ class Mamba1AttentionMetadata: ...@@ -31,7 +31,6 @@ class Mamba1AttentionMetadata:
num_prefill_tokens: int num_prefill_tokens: int
num_decodes: int num_decodes: int
num_decode_tokens: int num_decode_tokens: int
num_padded_decodes: int
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,] block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,] block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
...@@ -68,7 +67,6 @@ class Mamba1AttentionMetadataBuilder( ...@@ -68,7 +67,6 @@ class Mamba1AttentionMetadataBuilder(
has_initial_states_p = None has_initial_states_p = None
query_start_loc_p = None query_start_loc_p = None
padded_decodes = num_decodes
num_computed_tokens, num_computed_tokens_p = None, None num_computed_tokens, num_computed_tokens_p = None, None
block_idx_first_scheduled_token = None block_idx_first_scheduled_token = None
block_idx_first_scheduled_token_p = None block_idx_first_scheduled_token_p = None
...@@ -125,11 +123,10 @@ class Mamba1AttentionMetadataBuilder( ...@@ -125,11 +123,10 @@ class Mamba1AttentionMetadataBuilder(
and num_decodes <= self.decode_cudagraph_max_bs and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
): ):
padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_( self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True state_indices_tensor, non_blocking=True
) )
state_indices_tensor = self.state_indices_tensor[:padded_decodes] state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching: if self.vllm_config.cache_config.enable_prefix_caching:
...@@ -137,17 +134,15 @@ class Mamba1AttentionMetadataBuilder( ...@@ -137,17 +134,15 @@ class Mamba1AttentionMetadataBuilder(
block_idx_last_scheduled_token, non_blocking=True block_idx_last_scheduled_token, non_blocking=True
) )
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:padded_decodes :num_decode_tokens
] ]
block_idx_last_scheduled_token[num_decodes:] = 0
self.block_idx_last_computed_token[:num_decodes].copy_( self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True block_idx_last_computed_token, non_blocking=True
) )
block_idx_last_computed_token = self.block_idx_last_computed_token[ block_idx_last_computed_token = self.block_idx_last_computed_token[
:padded_decodes :num_decode_tokens
] ]
block_idx_last_computed_token[num_decodes:] = 0
return Mamba1AttentionMetadata( return Mamba1AttentionMetadata(
query_start_loc_p=query_start_loc_p, query_start_loc_p=query_start_loc_p,
...@@ -157,7 +152,6 @@ class Mamba1AttentionMetadataBuilder( ...@@ -157,7 +152,6 @@ class Mamba1AttentionMetadataBuilder(
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes, num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
num_padded_decodes=padded_decodes,
block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_last_scheduled_token=block_idx_last_scheduled_token,
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
block_idx_last_computed_token=block_idx_last_computed_token, block_idx_last_computed_token=block_idx_last_computed_token,
......
...@@ -10,7 +10,6 @@ from vllm.config import VllmConfig ...@@ -10,7 +10,6 @@ from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
CommonAttentionMetadata, CommonAttentionMetadata,
compute_causal_conv1d_metadata, compute_causal_conv1d_metadata,
split_decodes_and_prefills, split_decodes_and_prefills,
...@@ -304,30 +303,25 @@ class Mamba2AttentionMetadataBuilder( ...@@ -304,30 +303,25 @@ class Mamba2AttentionMetadataBuilder(
num_decodes <= self.decode_cudagraph_max_bs num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
): ):
# Pad state tensor for CUDA graph
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_( self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True state_indices_tensor, non_blocking=True
) )
state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
if self.vllm_config.cache_config.enable_prefix_caching: if self.vllm_config.cache_config.enable_prefix_caching:
self.block_idx_last_scheduled_token[:num_decodes].copy_( self.block_idx_last_scheduled_token[:num_decodes].copy_(
block_idx_last_scheduled_token, non_blocking=True block_idx_last_scheduled_token, non_blocking=True
) )
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
:num_input_tokens :num_decode_tokens
] ]
block_idx_last_scheduled_token[num_decodes:] = 0
self.block_idx_last_computed_token[:num_decodes].copy_( self.block_idx_last_computed_token[:num_decodes].copy_(
block_idx_last_computed_token, non_blocking=True block_idx_last_computed_token, non_blocking=True
) )
block_idx_last_computed_token = self.block_idx_last_computed_token[ block_idx_last_computed_token = self.block_idx_last_computed_token[
:num_input_tokens :num_decode_tokens
] ]
block_idx_last_computed_token[num_decodes:] = 0
attn_metadata = Mamba2AttentionMetadata( attn_metadata = Mamba2AttentionMetadata(
num_prefills=num_prefills, num_prefills=num_prefills,
......
...@@ -83,11 +83,10 @@ class ShortConvAttentionMetadataBuilder( ...@@ -83,11 +83,10 @@ class ShortConvAttentionMetadataBuilder(
and num_decodes <= self.decode_cudagraph_max_bs and num_decodes <= self.decode_cudagraph_max_bs
and self.compilation_config.cudagraph_mode.has_full_cudagraphs() and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
): ):
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
self.state_indices_tensor[:num_decodes].copy_( self.state_indices_tensor[:num_decodes].copy_(
state_indices_tensor, non_blocking=True state_indices_tensor, non_blocking=True
) )
state_indices_tensor = self.state_indices_tensor[:num_input_tokens] state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID state_indices_tensor[num_decodes:] = PAD_SLOT_ID
attn_metadata = ShortConvAttentionMetadata( attn_metadata = ShortConvAttentionMetadata(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment