"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "ad8d696a99ca1eee19f1404e16e8e82df592ff85"
Unverified commit 8b141ed8, authored by shunting314 and committed by GitHub
Browse files

full cudagraph for flex-attn (#36298)


Signed-off-by: shunting314 <shunting@meta.com>
parent 2ad7c033
...@@ -170,14 +170,3 @@ class TestFullCUDAGraph: ...@@ -170,14 +170,3 @@ class TestFullCUDAGraph:
piecewise_res.outputs[0].text.lower() piecewise_res.outputs[0].text.lower()
== full_res.outputs[0].text.lower() == full_res.outputs[0].text.lower()
) )
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    """Expect a RuntimeError when FLEX_ATTENTION is combined with FULL cudagraphs.

    The test passes only if building the engine raises ``RuntimeError``.
    """
    model_id = "Qwen/Qwen2-1.5B-Instruct"
    flex_attention = {"backend": "FLEX_ATTENTION"}

    # Flex_Attention is not supported with full cuda graph
    with pytest.raises(RuntimeError):
        # The compilation config is created inside the context manager on
        # purpose, so a RuntimeError raised while constructing it is
        # treated the same way as one raised by LLM itself.
        full_graph_config = CompilationConfig(cudagraph_mode="FULL")
        LLM(
            model=model_id,
            compilation_config=full_graph_config,
            attention_config=flex_attention,
        )
...@@ -26,6 +26,59 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0") ...@@ -26,6 +26,59 @@ MINIMUM_TORCH_VERSION = version.parse("2.7.0")
DIRECT_BUILD_VERSION = version.parse("2.9.dev0") DIRECT_BUILD_VERSION = version.parse("2.9.dev0")
@pytest.mark.skipif(
    not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
    reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_full_cudagraphs(vllm_runner):
    """Test the numerics for flex attention full cudagraphs support.

    The same prompts are generated twice with the FLEX_ATTENTION backend:
    once with ``enforce_eager=True`` and once with ``enforce_eager=False``.
    The greedy logprobs of both runs are then compared with
    ``check_logprobs_close``.

    NOTE(review): the non-eager run relies on the engine's default
    compilation settings to exercise full cudagraphs; no cudagraph mode is
    requested explicitly here -- confirm that this is intended.
    """
    model_name = "Qwen/Qwen2.5-1.5B-Instruct"
    seed = 42
    max_tokens = 24
    num_logprobs = 5
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
    ]

    # Per-run overrides, in execution order: eager first, then compiled.
    run_overrides = {
        "eager": {"enforce_eager": True},
        "compile": {"enforce_eager": False, "gpu_memory_utilization": 0.85},
    }

    logprobs_by_run = {}
    for label, overrides in run_overrides.items():
        # Re-seed before every run so both start from the same RNG state.
        set_random_seed(seed)
        with vllm_runner(
            model_name,
            runner="generate",
            tensor_parallel_size=1,
            num_gpu_blocks_override=128,
            # A fresh dict per run; nothing is shared between engines.
            attention_config={"backend": "FLEX_ATTENTION"},
            **overrides,
        ) as llm:
            logprobs_by_run[label] = llm.generate_greedy_logprobs(
                prompts, max_tokens, num_logprobs
            )

    check_logprobs_close(
        outputs_0_lst=logprobs_by_run["eager"],
        outputs_1_lst=logprobs_by_run["compile"],
        name_0="eager",
        name_1="compile",
    )
@pytest.mark.skipif( @pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION, not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7", reason="CUDA not available or PyTorch version < 2.7",
......
...@@ -30,6 +30,7 @@ from vllm.utils.math_utils import cdiv ...@@ -30,6 +30,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer from vllm.utils.torch_utils import is_quantized_kv_cache, is_torch_equal_or_newer
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport,
AttentionImpl, AttentionImpl,
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
...@@ -315,6 +316,18 @@ class BlockSparsityHint(NamedTuple): ...@@ -315,6 +316,18 @@ class BlockSparsityHint(NamedTuple):
hint_fn: _block_sparsity_hint_signature hint_fn: _block_sparsity_hint_signature
def copy_to_persistent(dst, src):
    """Copy ``src`` into the storage of ``dst`` and return the resulting view.

    ``dst`` is re-viewed with the shape and strides of ``src`` (via
    ``as_strided``), ``src`` is copied into that view, and the view is
    returned. The returned tensor therefore aliases ``dst``'s storage while
    presenting ``src``'s layout.

    Args:
        dst: Pre-allocated tensor whose storage receives the data.
        src: Tensor providing the values, shape and strides.

    Returns:
        A view over ``dst``'s storage with ``src``'s shape and strides,
        holding the copied values.

    Raises:
        RuntimeError: If ``dst`` cannot be re-strided to ``src``'s layout
            (for example because its storage is too small). The original
            error is chained as the cause.
    """
    try:
        persistent_view = dst.as_strided(src.shape, src.stride())
    except RuntimeError as err:
        raise RuntimeError(
            f"Fail to re-stride a persistent tensor of shape {dst.shape} "
            f"for a tensor of shape {src.shape}"
        ) from err
    else:
        # copy_ works in place and returns the tensor it was called on.
        return persistent_view.copy_(src)
@dataclass @dataclass
class FlexAttentionMetadata: class FlexAttentionMetadata:
causal: bool causal: bool
...@@ -340,6 +353,9 @@ class FlexAttentionMetadata: ...@@ -340,6 +353,9 @@ class FlexAttentionMetadata:
physical_to_logical: torch.Tensor physical_to_logical: torch.Tensor
decode_offset: torch.Tensor decode_offset: torch.Tensor
num_blocks_per_seq: torch.Tensor num_blocks_per_seq: torch.Tensor
persistent_kv_indices: torch.Tensor
persistent_kv_num_blocks: torch.Tensor
persistent_doc_ids: torch.Tensor
# For logging. # For logging.
num_input_tokens: int = 0 # Number of tokens including padding. num_input_tokens: int = 0 # Number of tokens including padding.
...@@ -656,8 +672,11 @@ class FlexAttentionMetadata: ...@@ -656,8 +672,11 @@ class FlexAttentionMetadata:
kv_indices = unique_static_unsorted( kv_indices = unique_static_unsorted(
(used_pages_padded.long()), M=self.num_blocks (used_pages_padded.long()), M=self.num_blocks
).to(torch.int32) ).to(torch.int32)
kv_indices = copy_to_persistent(self.persistent_kv_indices, kv_indices)
kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32) kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
kv_num_blocks = copy_to_persistent(self.persistent_kv_num_blocks, kv_num_blocks)
block_mask_kwargs = { block_mask_kwargs = {
"seq_lengths": (self.num_actual_tokens, self.total_cache_tokens), "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
"kv_num_blocks": kv_num_blocks[None, None], "kv_num_blocks": kv_num_blocks[None, None],
...@@ -694,6 +713,7 @@ class FlexAttentionMetadata: ...@@ -694,6 +713,7 @@ class FlexAttentionMetadata:
assert self.suffix_kv_lens is None, "Not implemented yet." assert self.suffix_kv_lens is None, "Not implemented yet."
# Create a lookup mapping from query indices -> request number # Create a lookup mapping from query indices -> request number
self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc) self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
self.doc_ids = copy_to_persistent(self.persistent_doc_ids, self.doc_ids)
self.num_blocks = self.total_cache_tokens // self.block_size self.num_blocks = self.total_cache_tokens // self.block_size
self.mask_mod = self.get_mask_mod() self.mask_mod = self.get_mask_mod()
...@@ -701,6 +721,8 @@ class FlexAttentionMetadata: ...@@ -701,6 +721,8 @@ class FlexAttentionMetadata:
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]): class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS
def __init__( def __init__(
self, self,
kv_cache_spec: AttentionSpec, kv_cache_spec: AttentionSpec,
...@@ -726,6 +748,38 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -726,6 +748,38 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
self.q_block_size: int = 16 if supports_small_blocks else 128 self.q_block_size: int = 16 if supports_small_blocks else 128
self.kv_block_size: int = self.block_size if supports_small_blocks else 128 self.kv_block_size: int = self.block_size if supports_small_blocks else 128
self.max_model_len = self.model_config.max_model_len
max_num_seqs = vllm_config.scheduler_config.max_num_seqs
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
self.max_num_q_block = (
self.max_model_len + self.q_block_size - 1
) // self.q_block_size
self.persistent_kv_num_blocks = torch.empty(
self.max_num_q_block, dtype=torch.int32, device=device
)
self.persistent_offset_tensor = torch.empty(
max_num_seqs, dtype=torch.int32, device=device
)
self.persistent_doc_ids = torch.empty(
max_num_batched_tokens, dtype=torch.int32, device=device
)
# initialize later when we can access block_table
self.persistent_physical_to_logical = None
self.persistent_kv_indices = None
def build_for_cudagraph_capture(
    self, common_attn_metadata: CommonAttentionMetadata
) -> FlexAttentionMetadata:
    """Build attention metadata for a CUDA graph capture run.

    Overwrites ``common_attn_metadata.max_seq_len`` in place with the
    largest value found in ``seq_lens_cpu`` and then delegates to
    ``build`` with ``common_prefix_len=0``.

    Args:
        common_attn_metadata: Shared attention metadata for the batch;
            its ``max_seq_len`` field is mutated by this call.

    Returns:
        The metadata object produced by ``build``.
    """
    # Use actual max_seq_len instead of max_model_len to avoid
    # torch.compile recompilation during CUDA graph capture.
    longest_seq_len = common_attn_metadata.seq_lens_cpu.max().item()
    common_attn_metadata.max_seq_len = longest_seq_len

    capture_metadata = self.build(
        common_prefix_len=0, common_attn_metadata=common_attn_metadata
    )
    return capture_metadata
def build( def build(
self, self,
common_prefix_len: int, common_prefix_len: int,
...@@ -765,8 +819,32 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -765,8 +819,32 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
inverse_block_table = physical_to_logical_mapping( inverse_block_table = physical_to_logical_mapping(
block_table_tensor, seq_lens, block_size, num_gpu_blocks block_table_tensor, seq_lens, block_size, num_gpu_blocks
) )
if self.persistent_physical_to_logical is None:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
self.persistent_physical_to_logical = torch.empty(
max_num_seqs,
num_gpu_blocks,
dtype=torch.long,
device=self.device,
)
if self.persistent_kv_indices is None:
max_num_kv_block = (
self.max_model_len + self.kv_block_size - 1
) // self.kv_block_size
self.persistent_kv_indices = torch.empty(
self.max_model_len,
max_num_kv_block,
dtype=torch.int32,
device=self.device,
)
inverse_block_table = copy_to_persistent(
self.persistent_physical_to_logical, inverse_block_table
)
offset_tensor = common_attn_metadata.compute_num_computed_tokens() offset_tensor = common_attn_metadata.compute_num_computed_tokens()
offset_tensor = copy_to_persistent(self.persistent_offset_tensor, offset_tensor)
out = FlexAttentionMetadata( out = FlexAttentionMetadata(
causal=common_attn_metadata.causal, causal=common_attn_metadata.causal,
...@@ -795,7 +873,20 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -795,7 +873,20 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
direct_build=(self.direct_build and common_attn_metadata.causal), direct_build=(self.direct_build and common_attn_metadata.causal),
q_block_size=self.q_block_size, q_block_size=self.q_block_size,
kv_block_size=self.kv_block_size, kv_block_size=self.kv_block_size,
persistent_kv_indices=self.persistent_kv_indices,
persistent_kv_num_blocks=self.persistent_kv_num_blocks,
persistent_doc_ids=self.persistent_doc_ids,
) )
# Pre-build block_mask so it is ready before CUDA graph capture.
# Without this, the lazy build in forward() would run non-graph-safe
# ops (e.g. torch.nonzero) inside capture.
if out.block_mask is None:
if out.direct_build:
out.block_mask = out._build_block_mask_direct()
else:
out.block_mask = out.build_block_mask()
return out return out
def use_cascade_attention(self, *args, **kwargs) -> bool: def use_cascade_attention(self, *args, **kwargs) -> bool:
......
...@@ -6077,6 +6077,7 @@ class GPUModelRunner( ...@@ -6077,6 +6077,7 @@ class GPUModelRunner(
skip_eplb=True, skip_eplb=True,
remove_lora=False, remove_lora=False,
num_active_loras=desc.num_active_loras, num_active_loras=desc.num_active_loras,
profile_seq_lens=profile_seq_lens,
) )
self._dummy_run( self._dummy_run(
desc.num_tokens, desc.num_tokens,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment