Unverified commit c7914d30, authored by Lucas Wilkinson, committed by GitHub
Browse files

Reapply [Attention][FA3] Update FA3 to include new swizzle optimization (#34043)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 1b875656
...@@ -38,7 +38,7 @@ else() ...@@ -38,7 +38,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2 GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
...@@ -132,36 +133,39 @@ class TestCudagraphDispatcher: ...@@ -132,36 +133,39 @@ class TestCudagraphDispatcher:
# Test dispatch logic # Test dispatch logic
# 1. non-uniform batch, size in cudagraph size list # 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor( # FULL mode uses exact keys with num_reqs set
num_tokens=8, desc_full_with_reqs = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
uniform=False, # PIECEWISE mode uses relaxed keys with num_reqs=None
) desc_piecewise = BatchDescriptor(num_tokens=8, num_reqs=None, uniform=False)
rt_mode, key = dispatcher.dispatch( rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False num_tokens=8, uniform_decode=False, has_lora=False
) )
if cudagraph_mode_str == "FULL": if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact assert key == desc_full_with_reqs
elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert rt_mode == CUDAGraphMode.PIECEWISE assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact assert key == desc_piecewise
else: else:
assert rt_mode == CUDAGraphMode.NONE assert rt_mode == CUDAGraphMode.NONE
# 2. uniform decode batch, size in cudagraph size list # 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True) desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
desc_non_uniform = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
rt_mode, key = dispatcher.dispatch( rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=True, has_lora=False num_tokens=8, uniform_decode=True, has_lora=False
) )
if cudagraph_mode_str == "FULL": if cudagraph_mode_str == "FULL":
# Pure FULL mode uses non-uniform keys for all batches
assert rt_mode == CUDAGraphMode.FULL assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs() assert key == desc_non_uniform
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]: elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
# These modes have separate uniform decode keys
assert rt_mode == CUDAGraphMode.FULL assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact assert key == desc_uniform_exact
elif cudagraph_mode_str == "PIECEWISE": elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs() assert key == replace(desc_uniform_exact, num_reqs=None, uniform=False)
else: else:
assert rt_mode == CUDAGraphMode.NONE assert rt_mode == CUDAGraphMode.NONE
...@@ -180,7 +184,7 @@ class TestCudagraphDispatcher: ...@@ -180,7 +184,7 @@ class TestCudagraphDispatcher:
if "PIECEWISE" in cudagraph_mode_str: # string contains check if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs() assert key == replace(desc_full_exact, num_reqs=None, uniform=False)
else: else:
assert rt_mode == CUDAGraphMode.NONE assert rt_mode == CUDAGraphMode.NONE
......
...@@ -5,7 +5,7 @@ import time ...@@ -5,7 +5,7 @@ import time
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, NamedTuple from typing import Any
import torch import torch
...@@ -26,7 +26,8 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL ...@@ -26,7 +26,8 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list) batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple): @dataclass(frozen=True)
class BatchDescriptor:
""" """
Batch descriptor for cudagraph dispatching. We should keep the num of Batch descriptor for cudagraph dispatching. We should keep the num of
items as minimal as possible to properly and uniquely describe the padded items as minimal as possible to properly and uniquely describe the padded
...@@ -56,19 +57,6 @@ class BatchDescriptor(NamedTuple): ...@@ -56,19 +57,6 @@ class BatchDescriptor(NamedTuple):
to be properly captured. to be properly captured.
""" """
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a relaxed version of current batch descriptor that is still compatible
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens,
num_reqs=None,
uniform=False,
has_lora=self.has_lora,
num_active_loras=self.num_active_loras,
)
def _compute_sp_num_tokens( def _compute_sp_num_tokens(
num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
......
...@@ -40,7 +40,7 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -40,7 +40,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionMetadataBuilder, AttentionMetadataBuilder,
...@@ -310,8 +310,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad ...@@ -310,8 +310,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.aot_schedule: if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros( self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1, 1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
......
...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import ( ...@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionCGSupport, AttentionCGSupport,
AttentionLayer, AttentionLayer,
...@@ -129,8 +130,17 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata] ...@@ -129,8 +130,17 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.fa_aot_schedule: if self.use_full_cuda_graph and self.fa_aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros( self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1, 1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32, dtype=torch.int32,
device=self.device, device=self.device,
) )
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
from itertools import product from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig from vllm.config import CUDAGraphMode, VllmConfig
...@@ -180,12 +181,14 @@ class CudagraphDispatcher: ...@@ -180,12 +181,14 @@ class CudagraphDispatcher:
for bs, num_active_loras in product( for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases self.compilation_config.cudagraph_capture_sizes, lora_cases
): ):
self.add_cudagraph_key( batch_desc = self._create_padded_batch_descriptor(
cudagraph_mode.mixed_mode(), bs, False, num_active_loras > 0, num_active_loras
self._create_padded_batch_descriptor(
bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(),
) )
# Only relax for PIECEWISE mode. FULL mode needs exact num_reqs
# because FA3's scheduler_metadata computation depends on it.
if cudagraph_mode.mixed_mode() == CUDAGraphMode.PIECEWISE:
batch_desc = replace(batch_desc, num_reqs=None, uniform=False)
self.add_cudagraph_key(cudagraph_mode.mixed_mode(), batch_desc)
# if decode cudagraph mode is FULL, and we don't already have mixed # if decode cudagraph mode is FULL, and we don't already have mixed
# mode full cudagraphs then add them here. # mode full cudagraphs then add them here.
...@@ -264,21 +267,23 @@ class CudagraphDispatcher: ...@@ -264,21 +267,23 @@ class CudagraphDispatcher:
batch_desc = self._create_padded_batch_descriptor( batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora, effective_num_active_loras num_tokens, uniform_decode, has_lora, effective_num_active_loras
) )
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
if not disable_full:
# check if key exists for full cudagraph
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc
# otherwise, check if the relaxed key exists # check if key exists for full cudagraph
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: # For pure FULL mode, keys are registered with uniform=False.
return CUDAGraphMode.FULL, relaxed_batch_desc batch_desc_to_check = batch_desc
if self.cudagraph_mode == CUDAGraphMode.FULL:
batch_desc_to_check = replace(batch_desc, uniform=False)
if (
not disable_full
and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]
):
return CUDAGraphMode.FULL, batch_desc_to_check
# also check if the relaxed key exists for more "general" # also check if the relaxed key exists for more "general"
# piecewise cudagraph # piecewise cudagraph
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]: batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, batch_desc_to_check
# finally, just return no cudagraphs and a trivial batch descriptor # finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment