Unverified Commit 7ba5ad57 authored by DarkSharpness, committed by GitHub

[Fix] Fix flashinfer cpu <-> gpu synchronization (#8340)

parent 19bc77f0
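The underlying issue: reading a CUDA tensor from the host (via `.item()`, `.cpu()`, or plan-time copies inside flashinfer) stalls the CPU until the GPU stream catches up. The fix threads a host-side copy of `seq_lens` through the index updaters so the planning arithmetic stays on the CPU. A minimal standalone sketch of the difference (assumes a CUDA device; illustrative code, not from this repo):

    import torch

    seq_lens_gpu = torch.randint(1, 4096, (256,), device="cuda")

    # Slow path: .item() on a CUDA tensor blocks the host until the
    # stream drains; this is the hidden synchronization being removed.
    total = seq_lens_gpu.sum().item()

    # Fast path: keep a CPU mirror of the lengths (the scheduler already
    # knows them) and do the reduction purely on the host.
    seq_lens_cpu = seq_lens_gpu.cpu()       # one up-front copy at a safe point
    total_fast = seq_lens_cpu.sum().item()  # no CUDA sync afterwards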
@@ -66,6 +66,10 @@ class PrefillMetadata:
 # Reuse this workspace buffer across all flashinfer wrappers
 global_workspace_buffer = None
 
+# Use as a fast path to override the indptr in flashinfer's plan function
+# This is used to remove some host-to-device copy overhead.
+global_override_indptr_cpu = None
+
 
 class FlashInferAttnBackend(AttentionBackend):
     """Flashinfer attention kernels."""
@@ -205,6 +209,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 decode_wrappers=self.decode_wrappers,
                 encoder_lens=forward_batch.encoder_lens,
@@ -215,6 +220,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -229,6 +235,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_wrappers_verify,
@@ -252,6 +259,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
+                forward_batch.seq_lens_cpu,
                 forward_batch.seq_lens_sum,
                 prefix_lens,
                 prefill_wrappers=self.prefill_wrappers_paged,
@@ -327,6 +335,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 decode_wrappers=decode_wrappers,
                 encoder_lens=encoder_lens,
@@ -358,6 +367,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
@@ -387,6 +397,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices,
                 seq_lens,
+                seq_lens.cpu(),  # may add a little overhead in capture stage
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=prefill_wrappers,
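The `# may add a little overhead in capture stage` comments mark a deliberate trade: `seq_lens.cpu()` is a blocking device-to-host copy, but CUDA graph capture runs once per batch size, while the replay hunks below reuse preallocated host buffers. A hedged sketch of that asymmetry (illustrative names, assumes a CUDA device):

    import torch

    def capture_once(seq_lens: torch.Tensor) -> torch.Tensor:
        # One-off blocking copy; acceptable because capture is rare.
        return seq_lens.cpu()

    def replay(host_buf: torch.Tensor, new_lens_cpu: torch.Tensor, bs: int) -> torch.Tensor:
        # Per-step path: host-to-host copy into a preallocated buffer,
        # so no CUDA synchronization is triggered.
        host_buf[: new_lens_cpu.numel()].copy_(new_lens_cpu)
        return host_buf[:bs]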
@@ -414,6 +425,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 decode_wrappers=self.decode_cuda_graph_metadata[bs],
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
@@ -423,6 +435,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -434,6 +447,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.indices_updater_prefill.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
+                seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
                 seq_lens_sum,
                 prefix_lens=None,
                 prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
@@ -581,7 +595,7 @@ class FlashInferAttnBackend(AttentionBackend):
 
 class FlashInferIndicesUpdaterDecode:
-    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -614,6 +628,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -626,6 +641,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -640,30 +656,39 @@ class FlashInferIndicesUpdaterDecode:
             self.kv_indptr[0],
             None,
             spec_info,
+            seq_lens_cpu,
         )
 
     def update_sliding_window(
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
+        assert self.sliding_window_size is not None
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-                paged_kernel_lens_tmp = torch.minimum(  # TODO: replace this with clamp
-                    seq_lens,
-                    torch.tensor(self.sliding_window_size + 1),
-                )
-                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+                paged_kernel_lens_tmp = torch.clamp(
+                    seq_lens, max=self.sliding_window_size + 1
+                )
+                if seq_lens_cpu is not None:
+                    seq_lens_cpu_tmp = torch.clamp(
+                        seq_lens_cpu, max=self.sliding_window_size + 1
+                    )
+                    paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
+                else:
+                    paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
                 kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
                 paged_kernel_lens_tmp = seq_lens
                 paged_kernel_lens_sum_tmp = seq_lens_sum
+                seq_lens_cpu_tmp = seq_lens_cpu
                 kv_start_idx_tmp = None
 
             use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
@@ -678,6 +703,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx_tmp,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu_tmp,
                 use_sliding_window_kv_pool=use_sliding_window_kv_pool,
             )
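The sliding-window branch previously summed a GPU tensor and called `.item()`, a hidden device sync on every update; clamping the CPU mirror keeps that reduction on the host (and the `torch.clamp` rewrite also resolves the old TODO). Numeric sketch of the window math, with made-up values:

    import torch

    sliding_window_size = 4                    # illustrative value
    seq_lens_cpu = torch.tensor([3, 9, 7, 2])  # host mirror of seq_lens

    # Each request attends to at most window+1 pages of KV cache.
    clamped = torch.clamp(seq_lens_cpu, max=sliding_window_size + 1)
    print(clamped)               # tensor([3, 5, 5, 2])
    print(clamped.sum().item())  # 15, computed without touching the GPU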
@@ -685,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
         encoder_lens: Optional[torch.Tensor],
@@ -709,6 +736,7 @@ class FlashInferIndicesUpdaterDecode:
                 self.kv_indptr[wrapper_id],
                 kv_start_idx,
                 spec_info,
+                seq_lens_cpu=seq_lens_cpu,
             )
 
     def call_begin_forward(
@@ -720,6 +748,7 @@ class FlashInferIndicesUpdaterDecode:
         kv_indptr: torch.Tensor,
         kv_start_idx: torch.Tensor,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
         use_sliding_window_kv_pool: bool = False,
     ):
         if spec_info is None:
@@ -756,6 +785,14 @@ class FlashInferIndicesUpdaterDecode:
                 )
             )
 
+        global global_override_indptr_cpu
+        locally_override = False
+        if seq_lens_cpu is not None and global_override_indptr_cpu is None:
+            locally_override = True
+            global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
+            global_override_indptr_cpu[0] = 0
+            global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
+
         wrapper.begin_forward(
             kv_indptr,
             kv_indices,
@@ -769,9 +806,12 @@ class FlashInferIndicesUpdaterDecode:
             non_blocking=True,
         )
 
+        if locally_override:
+            global_override_indptr_cpu = None
 
 class FlashInferIndicesUpdaterPrefill:
-    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+    def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
         # Parse Constants
         self.num_qo_heads = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -806,6 +846,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -820,6 +861,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -853,6 +895,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -898,6 +941,7 @@ class FlashInferIndicesUpdaterPrefill:
         self,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: Optional[torch.Tensor],
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
@@ -1020,11 +1064,6 @@ class FlashInferIndicesUpdaterPrefill:
         )
 
-
-# Use as a fast path to override the indptr in flashinfer's plan function
-# This is used to remove some host-to-device copy overhead.
-global global_override_indptr_cpu
-
 
 class FlashInferMultiStepDraftBackend:
     """
     Wrap multiple flashinfer attention backends as one for multiple consecutive
@@ -1056,7 +1095,7 @@ class FlashInferMultiStepDraftBackend:
         self.kv_last_page_len = torch.ones(
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
-        self.attn_backends = []
+        self.attn_backends: List[FlashInferAttnBackend] = []
         for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 FlashInferAttnBackend(
@@ -1176,7 +1215,7 @@ class FlashInferMultiStepDraftBackend:
             encoder_lens=None,
             forward_mode=ForwardMode.DECODE,
             spec_info=forward_batch.spec_info,
-            seq_lens_cpu=None,
+            seq_lens_cpu=forward_batch.seq_lens_cpu,
         )
 
         self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
...
@@ -1714,16 +1714,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            attention_backend_str == "fa3"
-            or (
-                global_server_args_dict["use_mla_backend"]
-                and attention_backend_str == "flashinfer"
-            )
-            or attention_backend_str == "flashmla"
-            or attention_backend_str == "cutlass_mla"
-            or attention_backend_str == "ascend"
-            or attention_backend_str == "trtllm_mha"
-            or attention_backend_str == "aiter"
+            attention_backend_str
+            in [
+                "fa3",
+                "flashinfer",
+                "flashmla",
+                "cutlass_mla",
+                "ascend",
+                "trtllm_mha",
+                "aiter",
+            ]
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
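The chained `==`/`or` checks collapse into one membership test. Note the rewrite also drops the `use_mla_backend` qualifier, so `seq_lens_cpu` is now created for every flashinfer batch, which is exactly what the updater changes above rely on. An equivalent standalone check (backend names copied from the hunk; the set is an illustrative variant of the list):

    # A set gives O(1) lookup and makes the intent explicit.
    BACKENDS_NEEDING_SEQ_LENS_CPU = {
        "fa3", "flashinfer", "flashmla", "cutlass_mla",
        "ascend", "trtllm_mha", "aiter",
    }

    attention_backend_str = "flashinfer"  # example value
    needs_seq_lens_cpu = attention_backend_str in BACKENDS_NEEDING_SEQ_LENS_CPU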
...
@@ -729,10 +729,12 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
 
+        seq_lens_cpu = None
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(self.seq_len_fill_value)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
+            seq_lens_cpu = self.seq_lens_cpu[:bs]
 
         if pp_proxy_tensors:
             for key in self.pp_proxy_tensors.keys():
@@ -766,7 +768,7 @@ class CudaGraphRunner:
             self.encoder_lens[:bs] if self.is_encoder_decoder else None,
             self.capture_forward_mode,
             forward_batch.spec_info,
-            seq_lens_cpu=self.seq_lens_cpu[:bs],
+            seq_lens_cpu=seq_lens_cpu,
         )
 
         # Store fields
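Replay previously passed `self.seq_lens_cpu[:bs]` unconditionally, even when the incoming batch carried no host lengths, handing downstream code a possibly stale buffer; the new local propagates `None` so updaters fall back to the GPU path. The guard in isolation (illustrative helper; the real buffers live on `CudaGraphRunner`):

    from typing import Optional

    import torch

    def prepare_seq_lens_cpu(
        host_buf: torch.Tensor,
        batch_lens_cpu: Optional[torch.Tensor],
        bs: int,
        fill_value: int,
    ) -> Optional[torch.Tensor]:
        if batch_lens_cpu is None:
            return None  # nothing valid to hand downstream
        raw_bs = batch_lens_cpu.numel()
        if bs != raw_bs:
            host_buf.fill_(fill_value)  # pad slots beyond the live batch
        host_buf[:raw_bs].copy_(batch_lens_cpu)
        return host_buf[:bs]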
...