Unverified Commit e5db3e27 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[CI/Build] Fix broken mm processor test Mistral-3-large (#30597)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 64251f48
...@@ -8,6 +8,7 @@ from typing import Any, TypeAlias ...@@ -8,6 +8,7 @@ from typing import Any, TypeAlias
import numpy as np import numpy as np
import pytest import pytest
import torch
import torch.nn as nn import torch.nn as nn
from PIL import Image from PIL import Image
...@@ -35,6 +36,7 @@ from vllm.tokenizers import cached_tokenizer_from_config ...@@ -35,6 +36,7 @@ from vllm.tokenizers import cached_tokenizer_from_config
from vllm.utils.collection_utils import is_list_of from vllm.utils.collection_utils import is_list_of
from vllm.utils.torch_utils import set_default_torch_dtype from vllm.utils.torch_utils import set_default_torch_dtype
from ....utils import create_new_process_for_each_test
from ...registry import HF_EXAMPLE_MODELS from ...registry import HF_EXAMPLE_MODELS
from ...utils import dummy_hf_overrides from ...utils import dummy_hf_overrides
from .test_common import get_model_ids_to_test, get_text_token_prompts from .test_common import get_model_ids_to_test, get_text_token_prompts
...@@ -136,6 +138,7 @@ def create_batched_mm_kwargs( ...@@ -136,6 +138,7 @@ def create_batched_mm_kwargs(
) )
# TODO(Isotr0py): Don't initalize model during test
@contextmanager @contextmanager
def initialize_dummy_model( def initialize_dummy_model(
model_cls: type[nn.Module], model_cls: type[nn.Module],
...@@ -150,16 +153,21 @@ def initialize_dummy_model( ...@@ -150,16 +153,21 @@ def initialize_dummy_model(
backend="nccl", backend="nccl",
) )
initialize_model_parallel(tensor_model_parallel_size=1) initialize_model_parallel(tensor_model_parallel_size=1)
current_device = torch.get_default_device()
vllm_config = VllmConfig(model_config=model_config) vllm_config = VllmConfig(model_config=model_config)
with set_current_vllm_config(vllm_config=vllm_config): with set_current_vllm_config(vllm_config=vllm_config):
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
torch.set_default_device(current_platform.device_type)
model = model_cls(vllm_config=vllm_config) model = model_cls(vllm_config=vllm_config)
torch.set_default_device(current_device)
yield model yield model
del model del model
cleanup_dist_env_and_memory() cleanup_dist_env_and_memory()
@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", get_model_ids_to_test()) @pytest.mark.parametrize("model_id", get_model_ids_to_test())
def test_model_tensor_schema(model_id: str): def test_model_tensor_schema(model_id: str):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id) model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment