Unverified commit 54254f7a, authored by Andreas Karatzas, committed by GitHub
Browse files

[ROCm][CI] Fix spec decode logprobs flakiness and parametrize tree attention backends (#34599)


Signed-off-by: Andreas Karatzas <akaratza@amd.com>
parent cf93c1a1
...@@ -52,7 +52,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]: ...@@ -52,7 +52,7 @@ def vllm_model(vllm_runner, request) -> Generator[VllmRunner, None, None]:
# TODO: enable this once we support it for # TODO: enable this once we support it for
# prompt logprobs. # prompt logprobs.
enable_prefix_caching=request.param, enable_prefix_caching=request.param,
gpu_memory_utilization=0.4, # up to 2 alive concurrently gpu_memory_utilization=0.4,
) as vllm_model: ) as vllm_model:
yield vllm_model yield vllm_model
...@@ -366,14 +366,13 @@ def test_max_logprobs(): ...@@ -366,14 +366,13 @@ def test_max_logprobs():
Should also fail for `prompt_logprobs > max_logprobs` Should also fail for `prompt_logprobs > max_logprobs`
APC should not matter as this test checks basic request validation. APC should not matter as this test checks basic request validation.
""" """
runner = VllmRunner( with VllmRunner(
"facebook/opt-125m", "facebook/opt-125m",
max_logprobs=1, max_logprobs=1,
enable_prefix_caching=False, enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=256, max_model_len=256,
) ) as runner:
vllm_sampling_params = SamplingParams(logprobs=1) vllm_sampling_params = SamplingParams(logprobs=1)
# should pass # should pass
runner.generate(["Hello world"], sampling_params=vllm_sampling_params) runner.generate(["Hello world"], sampling_params=vllm_sampling_params)
...@@ -449,15 +448,13 @@ def test_all_logprobs(example_prompts): ...@@ -449,15 +448,13 @@ def test_all_logprobs(example_prompts):
Args: Args:
example_prompts: list of example prompts (test fixture) example_prompts: list of example prompts (test fixture)
""" """
runner = VllmRunner( with VllmRunner(
"facebook/opt-125m", "facebook/opt-125m",
max_logprobs=-1, max_logprobs=-1,
enable_prefix_caching=False, enable_prefix_caching=False,
# 2 other llms alive during whole session
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=256, max_model_len=256,
) ) as runner:
sampling_params_logprobs_all = SamplingParams( sampling_params_logprobs_all = SamplingParams(
max_tokens=5, logprobs=-1, prompt_logprobs=-1 max_tokens=5, logprobs=-1, prompt_logprobs=-1
) )
...@@ -495,6 +492,7 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): ...@@ -495,6 +492,7 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
max_model_len=16, max_model_len=16,
logprobs_mode=logprobs_mode, logprobs_mode=logprobs_mode,
) )
try:
vllm_sampling_params = SamplingParams(logprobs=1) vllm_sampling_params = SamplingParams(logprobs=1)
results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params) results = llm.generate(["Hello world"], sampling_params=vllm_sampling_params)
...@@ -512,7 +510,10 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode): ...@@ -512,7 +510,10 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode):
assert total_token_with_logprobs >= len(results[0].outputs) assert total_token_with_logprobs >= len(results[0].outputs)
if logprobs_mode in ("raw_logits", "processed_logits"): if logprobs_mode in ("raw_logits", "processed_logits"):
assert positive_values > 0 assert positive_values > 0
finally:
del llm del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
class TestCorrectDecodedToken: class TestCorrectDecodedToken:
...@@ -767,7 +768,7 @@ class TestCorrectDecodedToken: ...@@ -767,7 +768,7 @@ class TestCorrectDecodedToken:
# Simulate cases where individual tokens decode to "�" # Simulate cases where individual tokens decode to "�"
# but combinations decode correctly # but combinations decode correctly
if len(ids) == 1: if len(ids) == 1:
if ids[0] == 3 or ids[0] == 4 or ids[0] == 8 or ids[0] == 9: if ids[0] in (3, 4, 8, 9):
return "�" return "�"
elif len(ids) == 2: elif len(ids) == 2:
if ids == [2, 3]: if ids == [2, 3]:
...@@ -809,14 +810,13 @@ def test_verify_tokens_integration(): ...@@ -809,14 +810,13 @@ def test_verify_tokens_integration():
corrects tokens ending with the replacement character "�". corrects tokens ending with the replacement character "�".
Uses facebook/opt-125m which is known to produce these issues. Uses facebook/opt-125m which is known to produce these issues.
""" """
runner = VllmRunner( with VllmRunner(
"facebook/opt-125m", "facebook/opt-125m",
max_logprobs=0, max_logprobs=0,
enable_prefix_caching=False, enable_prefix_caching=False,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=256, max_model_len=256,
) ) as runner:
# Use a prompt that triggers multi-byte UTF-8 issues # Use a prompt that triggers multi-byte UTF-8 issues
# Based on user's example: "In this example," # Based on user's example: "In this example,"
test_prompts = ["In this example,"] test_prompts = ["In this example,"]
...@@ -853,14 +853,13 @@ def test_utf8_edge_cases_with_real_model(): ...@@ -853,14 +853,13 @@ def test_utf8_edge_cases_with_real_model():
Tests prompts that are likely to trigger byte-fallback tokenization Tests prompts that are likely to trigger byte-fallback tokenization
and multi-byte UTF-8 splitting. and multi-byte UTF-8 splitting.
""" """
runner = VllmRunner( with VllmRunner(
"facebook/opt-125m", "facebook/opt-125m",
max_logprobs=1, max_logprobs=1,
enable_prefix_caching=False, enable_prefix_caching=False,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=256, max_model_len=256,
) ) as runner:
# Prompts with various multi-byte UTF-8 characters # Prompts with various multi-byte UTF-8 characters
test_prompts = [ test_prompts = [
'Smart quotes: "Hello"', # Curly quotes 'Smart quotes: "Hello"', # Curly quotes
...@@ -901,14 +900,13 @@ def test_correct_decoded_token_preserves_valid_tokens(): ...@@ -901,14 +900,13 @@ def test_correct_decoded_token_preserves_valid_tokens():
ending with "�", but this test verifies the broader _verify_tokens ending with "�", but this test verifies the broader _verify_tokens
logic doesn't affect valid tokens. logic doesn't affect valid tokens.
""" """
runner = VllmRunner( with VllmRunner(
"facebook/opt-125m", "facebook/opt-125m",
max_logprobs=2, max_logprobs=2,
enable_prefix_caching=False, enable_prefix_caching=False,
gpu_memory_utilization=0.15, gpu_memory_utilization=0.15,
max_model_len=256, max_model_len=256,
) ) as runner:
# Simple prompt with standard ASCII characters # Simple prompt with standard ASCII characters
test_prompts = ["Hello world, this is a test."] test_prompts = ["Hello world, this is a test."]
...@@ -985,16 +983,33 @@ def test_correct_decoded_token_preserves_valid_tokens(): ...@@ -985,16 +983,33 @@ def test_correct_decoded_token_preserves_valid_tokens():
def test_spec_decode_logprobs( def test_spec_decode_logprobs(
logprobs_mode: LogprobsMode, logprobs_mode: LogprobsMode,
model_setup: tuple[str, str, dict, int], model_setup: tuple[str, str, dict, int],
monkeypatch,
): ):
"""Spec decode logprobs should match those of the base model. """Spec decode logprobs should match those of the base model.
Runs the base model and spec decode model sequentially, ensuring
only one LLM instance is alive at a time to avoid GPU memory
contention. Both use identical chunked prefill settings and eager
mode to control for infrastructure differences.
Args: Args:
logprobs_mode: logprobs mode. logprobs_mode: logprobs mode.
model_setup: Tuple of (method, base model name, model_setup: Tuple of (method, base model name,
speculative_config dict, top_logprobs). speculative_config dict, top_logprobs).
monkeypatch: pytest fixture for setting env vars.
""" """
from vllm import LLM from vllm import LLM
# The ROCm skinny GEMM kernels (gemm_kernels.cu) are
# non-deterministic across LLM instantiations due to persistent
# workgroup scheduling and wave-level shuffle reductions, which
# causes logprob differences that get misattributed to spec decode.
# Disable them so this test isolates spec decode correctness only.
# TODO(akaratza): Remove this workaround once the follow-up to
# https://github.com/vllm-project/vllm/pull/33493#issuecomment-3906083975
# lands with a determinism fix for wvSplitK kernels.
monkeypatch.setenv("VLLM_ROCM_USE_SKINNY_GEMM", "0")
method, model_name, spec_config, top_logprobs = model_setup method, model_name, spec_config, top_logprobs = model_setup
prompt = "Hello world " * 50 prompt = "Hello world " * 50
...@@ -1068,8 +1083,17 @@ def test_spec_decode_logprobs( ...@@ -1068,8 +1083,17 @@ def test_spec_decode_logprobs(
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs): for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose( assert math.isclose(
ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1 ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1
), (
f"Logprob mismatch: ref={ref_logprob.logprob} "
f"spec={spec_logprob.logprob} "
f"diff={abs(ref_logprob.logprob - spec_logprob.logprob)} "
f"(token={ref_logprob.decoded_token!r})"
)
assert ref_logprob.rank == spec_logprob.rank, (
f"Rank mismatch: ref={ref_logprob.rank} "
f"spec={spec_logprob.rank} "
f"(token={ref_logprob.decoded_token!r})"
) )
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token assert ref_logprob.decoded_token == spec_logprob.decoded_token
......
...@@ -13,6 +13,7 @@ from tests.v1.attention.utils import ( ...@@ -13,6 +13,7 @@ from tests.v1.attention.utils import (
try_get_attention_backend, try_get_attention_backend,
) )
from vllm.config import ParallelConfig, SpeculativeConfig from vllm.config import ParallelConfig, SpeculativeConfig
from vllm.platforms import current_platform
from vllm.v1.attention.backend import CommonAttentionMetadata from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
...@@ -23,11 +24,156 @@ if not is_flash_attn_varlen_func_available(): ...@@ -23,11 +24,156 @@ if not is_flash_attn_varlen_func_available():
allow_module_level=True, allow_module_level=True,
) )
# --------------------------------------------------------------------------- #
# KV cache layout adaptation
# --------------------------------------------------------------------------- #
# Two KV cache layouts exist across backends:
#
# Flash layout: (2, num_blocks, block_size, num_kv_heads, head_size)
# - dim 0 separates key (index 0) and value (index 1)
# - Used by: FLASH_ATTN, TREE_ATTN, ROCM_AITER_FA, ROCM_ATTN
#
# Block layout: (num_blocks, 2, block_size, num_kv_heads, head_size)
# - dim 1 separates key (index 0) and value (index 1)
# - Used by: TRITON_ATTN
#
# The test creates KV caches in flash layout (the canonical format used by
# tree attention). When a reference backend needs block layout we transpose
# dims 0 and 1.
#
# Note: ROCM_ATTN uses flash layout for storage but its forward path calls
# PagedAttention.split_kv_cache which reinterprets the raw memory as paged
# layout (num_blocks, num_kv_heads, head_size//x, block_size, x). This is
# a view-level incompatibility, not a transpose - see the TODO in
# _get_available_reference_backends for details.
#
# TODO: Replace this mapping with a `KV_CACHE_LAYOUT` class attribute on each
# AttentionImpl so the layout is self-documented by the backend itself, e.g.:
# class TritonAttentionImpl(AttentionImpl):
# KV_CACHE_LAYOUT = "block"
# --------------------------------------------------------------------------- #
# Reference backends that expect the KV cache in block layout
# ``(num_blocks, 2, ...)``. _adapt_kv_cache_for_backend() transposes the
# flash-layout cache for these; every other backend receives it unchanged.
_BLOCK_KV_LAYOUT_BACKENDS = frozenset([AttentionBackendEnum.TRITON_ATTN])

# Backends whose do_kv_cache_update depends on engine-level state (e.g.
# ForwardContext) that this test harness does not provide. Their KV cache
# is flash layout, so forward_attention() skips do_kv_cache_update for them
# and writes the cache directly with reshape_and_cache_flash instead.
_NEEDS_DIRECT_CACHE_UPDATE = frozenset([AttentionBackendEnum.ROCM_AITER_FA])

# Backends that are never offered as a reference backend because of known
# test-harness incompatibilities - the TODOs inside
# _get_available_reference_backends describe each case.
_INCOMPATIBLE_REFERENCE_BACKENDS = frozenset(
    [
        AttentionBackendEnum.ROCM_AITER_FA,
        AttentionBackendEnum.ROCM_ATTN,
    ]
)
def _adapt_kv_cache_for_backend(
    kv_cache: torch.Tensor,
    backend: AttentionBackendEnum,
) -> torch.Tensor:
    """Return ``kv_cache`` in the layout that ``backend`` expects.

    Args:
        kv_cache: KV cache in flash layout ``(2, num_blocks, ...)``.
        backend: Attention backend that will consume the cache.

    Returns:
        For backends listed in ``_BLOCK_KV_LAYOUT_BACKENDS``, a contiguous
        tensor in block layout ``(num_blocks, 2, ...)``. For every other
        backend, the very same tensor object that was passed in.
    """
    wants_block_layout = backend in _BLOCK_KV_LAYOUT_BACKENDS
    if not wants_block_layout:
        return kv_cache

    # Swap the key/value axis (dim 0) with the block axis (dim 1), then
    # force a contiguous memory layout for the consuming backend.
    block_layout_cache = kv_cache.transpose(0, 1)
    return block_layout_cache.contiguous()
def _get_platform_default_backend() -> AttentionBackendEnum:
    """Resolve the attention backend the platform picks when none is forced.

    Queries ``current_platform.get_attn_backend_cls`` with
    ``selected_backend=None`` and a fixed selector config, then maps the
    returned backend path back to its ``AttentionBackendEnum`` member.

    Returns:
        The enum member whose ``get_path()`` equals the platform's answer.

    Raises:
        RuntimeError: If no enum member has a matching path.
    """
    from vllm.v1.attention.selector import AttentionSelectorConfig

    selector_config = AttentionSelectorConfig(
        block_size=32,
        kv_cache_dtype="auto",
        use_mla=False,
        use_sparse=False,
        head_size=128,
        dtype=torch.bfloat16,
    )
    backend_path = current_platform.get_attn_backend_cls(
        selected_backend=None,
        attn_selector_config=selector_config,
    )

    for candidate in AttentionBackendEnum:
        try:
            candidate_path = candidate.get_path()
        except ValueError:
            # Members whose path lookup raises ValueError cannot match;
            # move on to the next one.
            continue
        if candidate_path == backend_path:
            return candidate

    raise RuntimeError(
        f"Platform returned backend path '{backend_path}' "
        f"that doesn't match any AttentionBackendEnum member."
    )
def _get_available_reference_backends() -> list[AttentionBackendEnum]:
    """List the reference backends to validate tree attention against.

    Returns:
        On non-ROCm platforms, ``[FLASH_ATTN]``. On ROCm, the platform's
        auto-selected backend (unless it is listed in
        ``_INCOMPATIBLE_REFERENCE_BACKENDS``) followed by ``TRITON_ATTN``,
        without duplicates.
    """
    if not current_platform.is_rocm():
        # CUDA: flash attention.
        return [AttentionBackendEnum.FLASH_ATTN]

    reference_backends: list[AttentionBackendEnum] = []

    # 1. Whatever the platform would auto-select at runtime, provided the
    #    test harness can actually drive it.
    platform_default = _get_platform_default_backend()
    if platform_default not in _INCOMPATIBLE_REFERENCE_BACKENDS:
        reference_backends.append(platform_default)

    # 2. TRITON_ATTN - always available on ROCm. Skip it if it was already
    #    added as the platform default.
    triton = AttentionBackendEnum.TRITON_ATTN
    if triton not in reference_backends:
        reference_backends.append(triton)

    # TODO: Enable ROCM_ATTN. Its forward path uses
    # PagedAttention.split_kv_cache which reinterprets the raw
    # cache memory as paged layout:
    #   key:   (num_blocks, num_kv_heads, head_size//x, block_size, x)
    #   value: (num_blocks, num_kv_heads, head_size, block_size)
    # Tree attention writes prefix data in NHD flash layout, so the
    # same bytes produce completely different values when read in
    # paged format. Supporting ROCM_ATTN would require writing
    # prefix data via PagedAttention.write_to_paged_cache into a
    # separate paged-format KV cache.

    # TODO: Enable ROCM_AITER_FA. Its metadata builder reads head
    # counts from the model config at construction time and
    # allocates extend_workspace with those dimensions. The test
    # uses independent head count parameters (num_heads=2/4,
    # num_kv_heads=2) that don't match the model config
    # (Llama-3-8B: 32 q heads, 8 kv heads), causing a head count
    # mismatch in flash_attn_varlen_func during extend_forward.
    # Fixing this requires either matching test head counts to the
    # model config or decoupling the builder from model config
    # head geometry. The direct cache update path
    # (_NEEDS_DIRECT_CACHE_UPDATE) is already in place for when
    # this is resolved.

    return reference_backends
class MockAttentionLayer(torch.nn.Module): class MockAttentionLayer(torch.nn.Module):
_q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") _q_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
_k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") _k_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
_v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda") _v_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
layer_name = "mock_layer"
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -48,6 +194,13 @@ def forward_attention( ...@@ -48,6 +194,13 @@ def forward_attention(
spec_token_tree: str | None = None, spec_token_tree: str | None = None,
num_spec_tokens: int = 0, num_spec_tokens: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
"""Run a single attention forward pass through the given backend.
``kv_cache`` is expected in **flash layout**
``(2, num_blocks, block_size, num_kv_heads, head_size)``.
It is automatically converted when the target backend needs a
different layout.
"""
batch_size, q_len, num_heads, dim_per_head = q.shape batch_size, q_len, num_heads, dim_per_head = q.shape
num_kv_heads = k.shape[-2] num_kv_heads = k.shape[-2]
# Initialize the query and KV sequence lengths. # Initialize the query and KV sequence lengths.
...@@ -116,17 +269,37 @@ def forward_attention( ...@@ -116,17 +269,37 @@ def forward_attention(
kv_cache_dtype="auto", kv_cache_dtype="auto",
) )
# Adapt KV cache layout for this backend.
adapted_kv_cache = _adapt_kv_cache_for_backend(kv_cache, backend)
# Run forward pass and return output. # Run forward pass and return output.
query = q.view(-1, num_heads, dim_per_head) query = q.view(-1, num_heads, dim_per_head)
key = k.view(-1, num_kv_heads, dim_per_head) key = k.view(-1, num_kv_heads, dim_per_head)
value = v.view(-1, num_kv_heads, dim_per_head) value = v.view(-1, num_kv_heads, dim_per_head)
output = torch.empty_like(query) output = torch.empty_like(query)
if not try_backend_includes_kv_cache_update(backend): if not try_backend_includes_kv_cache_update(backend):
if backend in _NEEDS_DIRECT_CACHE_UPDATE:
# This backend's do_kv_cache_update requires engine-level
# ForwardContext that isn't available in this test harness.
# Write directly using reshape_and_cache_flash since the
# KV cache layout is identical (flash layout, unbind on dim 0).
key_cache, value_cache = adapted_kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
"auto",
layer._k_scale,
layer._v_scale,
)
else:
instance.do_kv_cache_update( instance.do_kv_cache_update(
layer=layer, layer=layer,
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache, kv_cache=adapted_kv_cache,
slot_mapping=attn_metadata.slot_mapping, slot_mapping=attn_metadata.slot_mapping,
) )
return instance.forward( return instance.forward(
...@@ -134,13 +307,20 @@ def forward_attention( ...@@ -134,13 +307,20 @@ def forward_attention(
query=query, query=query,
key=key, key=key,
value=value, value=value,
kv_cache=kv_cache.clone(), kv_cache=adapted_kv_cache.clone(),
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
output=output, output=output,
) )
def test_tree_attn_correctness() -> None: @pytest.mark.parametrize(
"reference_backend",
_get_available_reference_backends(),
ids=lambda b: b.name,
)
def test_tree_attn_correctness(
reference_backend: AttentionBackendEnum,
) -> None:
torch.manual_seed(42) torch.manual_seed(42)
torch.cuda.manual_seed_all(42) torch.cuda.manual_seed_all(42)
...@@ -205,7 +385,9 @@ def test_tree_attn_correctness() -> None: ...@@ -205,7 +385,9 @@ def test_tree_attn_correctness() -> None:
dtype=torch.bfloat16, dtype=torch.bfloat16,
) )
# Set up the block table and KV cache for paged KV. # KV cache in flash layout - the canonical format for
# tree attention. forward_attention() handles conversion
# when needed.
assert max_sequence_length % block_size == 0 assert max_sequence_length % block_size == 0
max_blocks_per_batch = max_sequence_length // block_size max_blocks_per_batch = max_sequence_length // block_size
kv_cache = torch.randn( kv_cache = torch.randn(
...@@ -263,9 +445,7 @@ def test_tree_attn_correctness() -> None: ...@@ -263,9 +445,7 @@ def test_tree_attn_correctness() -> None:
num_spec_tokens=tree_size_q - 1, num_spec_tokens=tree_size_q - 1,
).view(batch_size, -1, num_heads, dim_per_head) ).view(batch_size, -1, num_heads, dim_per_head)
# Verify that the chain attention output for each # Verify each branch against the reference backend.
# branch of the tree (computed using FA3) matches
# the tree attention output.
for q_index in range(tree_size_q): for q_index in range(tree_size_q):
# Get the q, k, and v for the branch. # Get the q, k, and v for the branch.
branch_mask = tree_attn_mask[q_index, :] branch_mask = tree_attn_mask[q_index, :]
...@@ -286,8 +466,8 @@ def test_tree_attn_correctness() -> None: ...@@ -286,8 +466,8 @@ def test_tree_attn_correctness() -> None:
branch_positions, block_table, block_size branch_positions, block_table, block_size
) )
# Compute flash attention for the branch. # Reference attention for this branch.
flash_attn_output = forward_attention( ref_output = forward_attention(
q=q_branch, q=q_branch,
k=k_branch, k=k_branch,
v=v_branch, v=v_branch,
...@@ -295,16 +475,17 @@ def test_tree_attn_correctness() -> None: ...@@ -295,16 +475,17 @@ def test_tree_attn_correctness() -> None:
block_table=block_table, block_table=block_table,
slot_mapping=branch_slot_mapping, slot_mapping=branch_slot_mapping,
seqlen_k=sequence_position + q_len, seqlen_k=sequence_position + q_len,
backend=AttentionBackendEnum.FLASH_ATTN, backend=reference_backend,
).view(batch_size, -1, num_heads, dim_per_head) ).view(batch_size, -1, num_heads, dim_per_head)
# Compare the outputs. # Compare the outputs.
assert torch.allclose( assert torch.allclose(
tree_attn_output[:, branch_indices], tree_attn_output[:, branch_indices],
flash_attn_output, ref_output,
atol=7.81e-3, atol=7.81e-3,
), ( ), (
f"outputs are not close for " f"outputs are not close for "
f"reference_backend: {reference_backend.name}, "
f"batch_size: {batch_size}, " f"batch_size: {batch_size}, "
f"num_heads: {num_heads}, " f"num_heads: {num_heads}, "
f"sequence_position: {sequence_position}, " f"sequence_position: {sequence_position}, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment