Unverified Commit d559979c authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[bugfix] fix cpu tests (#10585)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent d345f409
......@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union
import torch
from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
......@@ -64,6 +65,7 @@ class CPUEmbeddingModelRunner(
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_states = model_executable(**execute_model_kwargs)
# Only perform pooling in the driver worker.
......
......@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
import torch
from vllm.attention import AttentionMetadata
from vllm.forward_context import set_forward_context
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
......@@ -303,6 +304,7 @@ class CPUEncoderDecoderModelRunner(
intermediate_tensors,
}
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_states = model_executable(**execute_model_kwargs)
# Compute the logits.
......
......@@ -10,6 +10,7 @@ from torch import nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
......@@ -487,6 +488,7 @@ class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device)
with set_forward_context(model_input.attn_metadata, self.vllm_config):
hidden_states = model_executable(
input_ids=model_input.input_tokens,
positions=model_input.input_positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment