Unverified commit 6e783bc5, authored by Benjamin Chislett, committed by GitHub
Browse files

[Bugfix] Fix CUDA graph selection bug in FlashInfer at high concurrency (#26499)


Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
parent c9d33c60
...@@ -296,6 +296,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -296,6 +296,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
) )
max_num_reqs = vllm_config.scheduler_config.max_num_seqs max_num_reqs = vllm_config.scheduler_config.max_num_seqs
max_num_pages = max_num_reqs * max_num_pages_per_req max_num_pages = max_num_reqs * max_num_pages_per_req
speculative_config = vllm_config.speculative_config
num_spec_tokens = (
speculative_config.num_speculative_tokens
if speculative_config is not None
else 0
)
self.enable_cuda_graph = ( self.enable_cuda_graph = (
self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL self.compilation_config.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
) )
...@@ -306,7 +312,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -306,7 +312,8 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
int, BatchDecodeWithPagedKVCacheWrapper int, BatchDecodeWithPagedKVCacheWrapper
] = {} ] = {}
self._decode_cudagraph_max_bs = min( self._decode_cudagraph_max_bs = min(
max_num_reqs, self.compilation_config.max_capture_size (1 + num_spec_tokens) * max_num_reqs,
self.compilation_config.max_capture_size,
) )
self.num_qo_heads = self.model_config.get_num_attention_heads( self.num_qo_heads = self.model_config.get_num_attention_heads(
...@@ -679,7 +686,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -679,7 +686,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
use_cudagraph = ( use_cudagraph = (
self.enable_cuda_graph self.enable_cuda_graph
and pure_decode and pure_decode
and num_decodes <= self._decode_cudagraph_max_bs and num_decode_tokens <= self._decode_cudagraph_max_bs
) )
if use_cudagraph: if use_cudagraph:
num_input_tokens = self.vllm_config.pad_for_cudagraph( num_input_tokens = self.vllm_config.pad_for_cudagraph(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment