Unverified commit 12817a8a, authored by Nicolò Lucchesi and committed by GitHub
Browse files

[CI] Fix `tests/v1/e2e/test_kv_sharing_fast_prefill.py` import on test (#22815)


Signed-off-by: NickLucche <nlucches@redhat.com>
parent c9232d41
@@ -11,7 +11,8 @@ from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationLevel
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
+from vllm.model_executor.models.gemma3n_mm import (
+    Gemma3nForConditionalGeneration)
 from vllm.model_executor.models.registry import ModelRegistry
 from vllm.model_executor.models.utils import extract_layer_index
 from vllm.sequence import IntermediateTensors
@@ -32,12 +33,13 @@ class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds, **kwargs)
+        hidden_states = super().forward(input_ids, positions,
+                                        intermediate_tensors, inputs_embeds,
+                                        **kwargs)
         attn_metadata = get_forward_context().attn_metadata
         # attn_metadata is None during dummy runs
         if (attn_metadata is not None
-                and self.cache_config.kv_sharing_fast_prefill):
+                and self.language_model.cache_config.kv_sharing_fast_prefill):
             assert isinstance(attn_metadata, dict)  # true in V1
             # Gemma3n-E2B has 30 layers, with last 20 layers being
             # cross-decoder layers. Check attention metadata is correct
@@ -52,7 +54,7 @@ class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
             # Last layer will be a KV sharing layer
             layer_attn_metadata = attn_metadata[
-                self.model.language_model.layers[-1].self_attn.attn.layer_name]
+                self.language_model.model.layers[-1].self_attn.attn.layer_name]
             logits_indices_padded = (layer_attn_metadata.logits_indices_padded)
             assert logits_indices_padded is not None
             num_logits_indices = layer_attn_metadata.num_logits_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment