Unverified commit c7914d30, authored by Lucas Wilkinson, committed by GitHub
Browse files

Reapply [Attention][FA3] Update FA3 to include new swizzle optimization (#34043)


Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
parent 1b875656
......@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2
GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
from unittest.mock import MagicMock, patch
import pytest
......@@ -132,36 +133,39 @@ class TestCudagraphDispatcher:
# Test dispatch logic
# 1. non-uniform batch, size in cudagraph size list
desc_full_exact = BatchDescriptor(
num_tokens=8,
uniform=False,
)
# FULL mode uses exact keys with num_reqs set
desc_full_with_reqs = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
# PIECEWISE mode uses relaxed keys with num_reqs=None
desc_piecewise = BatchDescriptor(num_tokens=8, num_reqs=None, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False
)
if cudagraph_mode_str == "FULL":
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_full_exact
assert key == desc_full_with_reqs
elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact
assert key == desc_piecewise
else:
assert rt_mode == CUDAGraphMode.NONE
# 2. uniform decode batch, size in cudagraph size list
desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
desc_non_uniform = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=True, has_lora=False
)
if cudagraph_mode_str == "FULL":
# Pure FULL mode uses non-uniform keys for all batches
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
assert key == desc_non_uniform
elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
# These modes have separate uniform decode keys
assert rt_mode == CUDAGraphMode.FULL
assert key == desc_uniform_exact
elif cudagraph_mode_str == "PIECEWISE":
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
assert key == replace(desc_uniform_exact, num_reqs=None, uniform=False)
else:
assert rt_mode == CUDAGraphMode.NONE
......@@ -180,7 +184,7 @@ class TestCudagraphDispatcher:
if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
assert key == replace(desc_full_exact, num_reqs=None, uniform=False)
else:
assert rt_mode == CUDAGraphMode.NONE
......
......@@ -5,7 +5,7 @@ import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Any, NamedTuple
from typing import Any
import torch
......@@ -26,7 +26,8 @@ batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
batchsize_forward_time: defaultdict = defaultdict(list)
class BatchDescriptor(NamedTuple):
@dataclass(frozen=True)
class BatchDescriptor:
"""
Batch descriptor for cudagraph dispatching. We should keep the num of
items as minimal as possible to properly and uniquely describe the padded
......@@ -56,19 +57,6 @@ class BatchDescriptor(NamedTuple):
to be properly captured.
"""
def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a relaxed version of current batch descriptor that is still compatible
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens,
num_reqs=None,
uniform=False,
has_lora=self.has_lora,
num_active_loras=self.num_active_loras,
)
def _compute_sp_num_tokens(
num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int
......
......@@ -40,7 +40,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv
from vllm.utils.math_utils import cdiv, round_up
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionMetadataBuilder,
......@@ -310,8 +310,17 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32,
device=self.device,
)
......
......@@ -21,6 +21,7 @@ from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import round_up
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
......@@ -129,8 +130,17 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
if self.use_full_cuda_graph and self.fa_aot_schedule:
# FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4
# The +1 is for the tile_count_semaphore (synchronization).
# The 4 slots per batch element (num_prepare_batch_vectors) are:
# prepare_varlen + dynamic_split + sort_batches + head_swizzle
# See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501
max_batch_size = max(
vllm_config.scheduler_config.max_num_seqs,
self.max_cudagraph_size or 0,
)
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
1 + round_up(max_batch_size, 4) * 4,
dtype=torch.int32,
device=self.device,
)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import replace
from itertools import product
from vllm.config import CUDAGraphMode, VllmConfig
......@@ -180,12 +181,14 @@ class CudagraphDispatcher:
for bs, num_active_loras in product(
self.compilation_config.cudagraph_capture_sizes, lora_cases
):
self.add_cudagraph_key(
cudagraph_mode.mixed_mode(),
self._create_padded_batch_descriptor(
bs, False, num_active_loras > 0, num_active_loras
).relax_for_mixed_batch_cudagraphs(),
batch_desc = self._create_padded_batch_descriptor(
bs, False, num_active_loras > 0, num_active_loras
)
# Only relax for PIECEWISE mode. FULL mode needs exact num_reqs
# because FA3's scheduler_metadata computation depends on it.
if cudagraph_mode.mixed_mode() == CUDAGraphMode.PIECEWISE:
batch_desc = replace(batch_desc, num_reqs=None, uniform=False)
self.add_cudagraph_key(cudagraph_mode.mixed_mode(), batch_desc)
# if decode cudagraph mode is FULL, and we don't already have mixed
# mode full cudagraphs then add them here.
......@@ -264,21 +267,23 @@ class CudagraphDispatcher:
batch_desc = self._create_padded_batch_descriptor(
num_tokens, uniform_decode, has_lora, effective_num_active_loras
)
relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs()
if not disable_full:
# check if key exists for full cudagraph
if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, batch_desc
# otherwise, check if the relaxed key exists
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]:
return CUDAGraphMode.FULL, relaxed_batch_desc
# check if key exists for full cudagraph
# For pure FULL mode, keys are registered with uniform=False.
batch_desc_to_check = batch_desc
if self.cudagraph_mode == CUDAGraphMode.FULL:
batch_desc_to_check = replace(batch_desc, uniform=False)
if (
not disable_full
and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL]
):
return CUDAGraphMode.FULL, batch_desc_to_check
# also check if the relaxed key exists for more "general"
# piecewise cudagraph
if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, relaxed_batch_desc
batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False)
if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
return CUDAGraphMode.PIECEWISE, batch_desc_to_check
# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment