Unverified Commit 4ac22722 authored by sungsoo ha, committed by GitHub
Browse files

[Bugfix][DCP] Fix CUDA graph capture for Decode Context Parallelism (#36070)


Signed-off-by: Sungsoo Ha <sungsooh@nvidia.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent bb51d5b4
...@@ -22,9 +22,11 @@ from vllm.v1.attention.backends.fa_utils import ( ...@@ -22,9 +22,11 @@ from vllm.v1.attention.backends.fa_utils import (
get_flash_attn_version, get_flash_attn_version,
is_flash_attn_varlen_func_available, is_flash_attn_varlen_func_available,
) )
from vllm.v1.attention.backends.utils import get_dcp_local_seq_lens
from vllm.v1.attention.ops.common import cp_lse_ag_out_rs from vllm.v1.attention.ops.common import cp_lse_ag_out_rs
from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce from vllm.v1.attention.ops.dcp_alltoall import dcp_a2a_lse_reduce
from vllm.v1.attention.ops.merge_attn_states import merge_attn_states from vllm.v1.attention.ops.merge_attn_states import merge_attn_states
from vllm.v1.worker.workspace import current_workspace_manager
if is_flash_attn_varlen_func_available(): if is_flash_attn_varlen_func_available():
from vllm.v1.attention.backends.fa_utils import ( from vllm.v1.attention.backends.fa_utils import (
...@@ -52,7 +54,6 @@ from vllm.v1.attention.backend import ( ...@@ -52,7 +54,6 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata, CommonAttentionMetadata,
) )
from vllm.v1.attention.backends.utils import ( from vllm.v1.attention.backends.utils import (
get_dcp_local_seq_lens,
get_kv_cache_layout, get_kv_cache_layout,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
...@@ -356,6 +357,14 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -356,6 +357,14 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.attention_config.flash_attn_max_num_splits_for_cuda_graph self.attention_config.flash_attn_max_num_splits_for_cuda_graph
) )
if self.dcp_world_size > 1:
max_num_reqs = vllm_config.scheduler_config.max_num_seqs
self._dcp_context_kv_lens = torch.zeros(
max_num_reqs,
dtype=torch.int32,
device=self.device,
)
# Sliding window size to be used with the AOT scheduler will be # Sliding window size to be used with the AOT scheduler will be
# populated on first build() call. # populated on first build() call.
self.aot_sliding_window: tuple[int, int] | None = None self.aot_sliding_window: tuple[int, int] | None = None
...@@ -452,15 +461,18 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -452,15 +461,18 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
prefix_scheduler_metadata = None prefix_scheduler_metadata = None
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
query_kv_lens = query_start_loc[1:] - query_start_loc[:-1] query_lens = query_start_loc[1:] - query_start_loc[:-1]
dcp_context_kv_lens = seq_lens - query_kv_lens context_kv_lens = seq_lens - query_lens
local_context_kv_lens = get_dcp_local_seq_lens(
dcp_context_kv_lens = get_dcp_local_seq_lens( context_kv_lens,
dcp_context_kv_lens,
self.dcp_world_size, self.dcp_world_size,
self.dcp_rank, self.dcp_rank,
self.cp_kv_cache_interleave_size, self.cp_kv_cache_interleave_size,
) )
self._dcp_context_kv_lens[:num_reqs] = local_context_kv_lens
self._dcp_context_kv_lens[num_reqs:] = 0
dcp_context_kv_lens = self._dcp_context_kv_lens[:num_reqs]
# After DCP distribution, the maximum number of tokens for any rank is # After DCP distribution, the maximum number of tokens for any rank is
# ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size, # ceil(L / (N * I)) * I, where L is max_seq_len, N is dcp_world_size,
# and I is cp_kv_cache_interleave_size. # and I is cp_kv_cache_interleave_size.
...@@ -637,6 +649,10 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -637,6 +649,10 @@ class FlashAttentionImpl(AttentionImpl):
) )
self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs self.dcp_combine = dcp_a2a_lse_reduce if dcp_a2a else cp_lse_ag_out_rs
self._dcp_dtype: torch.dtype | None = None
if vllm_config is not None and self.dcp_world_size > 1:
self._dcp_dtype = vllm_config.model_config.dtype
def forward( def forward(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -862,11 +878,18 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -862,11 +878,18 @@ class FlashAttentionImpl(AttentionImpl):
sliding_window_size = ( sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None list(self.sliding_window) if self.sliding_window is not None else None
) )
n = query_across_dcp.shape[0]
(dcp_context_out,) = current_workspace_manager().get_simultaneous(
(
(n, self.num_heads * self.dcp_world_size, self.head_size),
self._dcp_dtype,
),
)
context_attn_out, context_lse = flash_attn_varlen_func( context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp, q=query_across_dcp,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
out=None, out=dcp_context_out,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
seqused_k=attn_metadata.dcp_context_kv_lens, seqused_k=attn_metadata.dcp_context_kv_lens,
...@@ -894,11 +917,14 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -894,11 +917,14 @@ class FlashAttentionImpl(AttentionImpl):
) )
context_lse_cor = context_lse_cor.transpose(0, 1).contiguous() context_lse_cor = context_lse_cor.transpose(0, 1).contiguous()
(dcp_query_out,) = current_workspace_manager().get_simultaneous(
((query.shape[0], self.num_heads, self.head_size), self._dcp_dtype),
)
query_attn_out, query_lse = flash_attn_varlen_func( query_attn_out, query_lse = flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
out=None, out=dcp_query_out,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
cu_seqlens_k=cu_seqlens_q, cu_seqlens_k=cu_seqlens_q,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment