"vllm/executor/ray_xpu_executor.py" did not exist on "479d69fad0538f04cb22bf13e76ff91cfeb8a4e5"
Unverified commit 9324e102, authored by Yong Hoon Shin, committed by GitHub
Browse files

Fix KV sharing fast prefill with cudagraph enabled (#28537)


Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
parent 4516d44b
...@@ -4,13 +4,11 @@ ...@@ -4,13 +4,11 @@
import random import random
import pytest import pytest
import torch
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationMode from vllm.config import CompilationConfig, CompilationMode
from vllm.distributed import cleanup_dist_env_and_memory
from ...utils import fork_new_process_for_each_test from ...utils import check_answers, fork_new_process_for_each_test, prep_prompts
# global seed # global seed
SEED = 42 SEED = 42
...@@ -45,28 +43,12 @@ def test_prompts(): ...@@ -45,28 +43,12 @@ def test_prompts():
return prompts return prompts
def cleanup(llm: LLM, compilation_config: CompilationConfig):
# hacky: below lines are required to free up memory for the next test
# when setting VLLM_ENABLE_V1_MULTIPROCESSING=0, del llm is not sufficient
# TODO(sarckk): when enforce_eager=False, memory is not freed:
# find out why and re-enable test for enforce_eager=False case
llm_engine = llm.llm_engine.engine_core.engine_core
model_runner = llm_engine.model_executor.driver_worker.worker.model_runner
del model_runner.model
del model_runner.kv_caches
del compilation_config.static_forward_context
compilation_config.static_forward_context = {}
del llm
torch.cuda.empty_cache()
cleanup_dist_env_and_memory()
@fork_new_process_for_each_test @fork_new_process_for_each_test
@pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("kv_sharing_fast_prefill", [False, True])
@pytest.mark.skip(reason="Disable until Gemma3n supports fast prefill") @pytest.mark.parametrize("enforce_eager", [True, False])
def test_kv_sharing_fast_prefill( def test_kv_sharing_fast_prefill(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
kv_sharing_fast_prefill: bool,
enforce_eager: bool, enforce_eager: bool,
test_prompts: list[str], test_prompts: list[str],
): ):
...@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill( ...@@ -79,36 +61,25 @@ def test_kv_sharing_fast_prefill(
if not enforce_eager if not enforce_eager
else CompilationMode.NONE, else CompilationMode.NONE,
) )
batch_size = 10
with monkeypatch.context() as m: with monkeypatch.context() as m:
# Make scheduling deterministic for reproducibility # Make scheduling deterministic for reproducibility
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
llm = LLM( prompts, answer, indices = prep_prompts(batch_size)
model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager,
compilation_config=compilation_config,
seed=SEED,
)
ref_responses = llm.generate(test_prompts, sampling_params)
cleanup(llm, compilation_config)
llm = LLM( llm = LLM(
model="google/gemma-3n-E2B-it", model="google/gemma-3n-E2B-it",
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
compilation_config=compilation_config, compilation_config=compilation_config,
seed=SEED, seed=SEED,
kv_sharing_fast_prefill=True, kv_sharing_fast_prefill=kv_sharing_fast_prefill,
)
responses = llm.generate(prompts, sampling_params)
check_answers(
indices,
answer,
[response.outputs[0].text for response in responses],
accept_rate=1.0,
) )
optimized_responses = llm.generate(test_prompts, sampling_params)
cleanup(llm, compilation_config)
misses = 0
for ref_response, optimized_response in zip(ref_responses, optimized_responses):
if ref_response.outputs[0].text != optimized_response.outputs[0].text:
misses += 1
assert misses == 0
...@@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tens ...@@ -965,12 +965,6 @@ def reshape_attn_output_for_spec_decode(attn_output: torch.Tensor) -> torch.Tens
return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3]) return attn_output.view(total_tokens, attn_output.shape[2], attn_output.shape[3])
KV_SHARING_FAST_PREFILL_METADATA_FIELDS = [
("logits_indices_padded", torch.Tensor | None, None),
("num_logits_indices", int, 0),
]
def subclass_attention_metadata( def subclass_attention_metadata(
name_prefix: str, name_prefix: str,
metadata_cls: Any, metadata_cls: Any,
...@@ -986,8 +980,8 @@ def subclass_attention_metadata( ...@@ -986,8 +980,8 @@ def subclass_attention_metadata(
@runtime_checkable @runtime_checkable
class KVSharingFastPrefillMetadata(Protocol): class KVSharingFastPrefillMetadata(Protocol):
logits_indices_padded: torch.Tensor logits_indices_padded: torch.Tensor | None = None
num_logits_indices: int num_logits_indices: int | None = None
def create_fast_prefill_custom_backend( def create_fast_prefill_custom_backend(
...@@ -1019,11 +1013,6 @@ def create_fast_prefill_custom_backend( ...@@ -1019,11 +1013,6 @@ def create_fast_prefill_custom_backend(
for _field in fields(metadata.__class__): for _field in fields(metadata.__class__):
setattr(self, _field.name, getattr(metadata, _field.name)) setattr(self, _field.name, getattr(metadata, _field.name))
# Set additional fields that will be used in model code
assert (
common_attn_metadata.logits_indices_padded is not None
and common_attn_metadata.num_logits_indices is not None
)
self.logits_indices_padded = ( self.logits_indices_padded = (
common_attn_metadata.logits_indices_padded common_attn_metadata.logits_indices_padded
) )
......
...@@ -1314,7 +1314,7 @@ class GPUModelRunner( ...@@ -1314,7 +1314,7 @@ class GPUModelRunner(
:return: tuple[attn_metadata, spec_decode_common_attn_metadata] :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
""" """
logits_indices_padded = None logits_indices_padded = None
num_logits_indices = 0 num_logits_indices = None
if logits_indices is not None: if logits_indices is not None:
num_logits_indices = logits_indices.size(0) num_logits_indices = logits_indices.size(0)
if self.cache_config.kv_sharing_fast_prefill: if self.cache_config.kv_sharing_fast_prefill:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment