"docs/vscode:/vscode.git/clone" did not exist on "c5362c739fb31c171fd345ed4a83fb0127804aa3"
Unverified Commit 2267cb1c authored by Lucas Wilkinson, committed by GitHub
Browse files

[Attention][FA3] Update FA3 to include new swizzle optimization (#23465)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 0d6ccf68
...@@ -38,7 +38,7 @@ else() ...@@ -38,7 +38,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2 GIT_TAG 2adfc8c2177c5b0e8ddeedfd5a8990d80eb496ff
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
......
...@@ -308,10 +308,15 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -308,10 +308,15 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.compilation_config.cudagraph_mode.has_full_cudagraphs() self.compilation_config.cudagraph_mode.has_full_cudagraphs()
) )
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
if self.use_full_cuda_graph and self.aot_schedule: if self.use_full_cuda_graph and self.aot_schedule:
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
self.scheduler_metadata = torch.zeros( self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1, max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
......
...@@ -127,10 +127,15 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -127,10 +127,15 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.compilation_config.cudagraph_mode.has_full_cudagraphs() self.compilation_config.cudagraph_mode.has_full_cudagraphs()
) )
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
if self.use_full_cuda_graph and self.fa_aot_schedule: if self.use_full_cuda_graph and self.fa_aot_schedule:
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
self.scheduler_metadata = torch.zeros( self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1, max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment