Unverified Commit 84275504 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[CI/Build] Update pixtral tests to use JSON (#8436)

parent 3f79bc3d
...@@ -76,7 +76,7 @@ exclude = [ ...@@ -76,7 +76,7 @@ exclude = [
[tool.codespell] [tool.codespell]
ignore-words-list = "dout, te, indicies, subtile" ignore-words-list = "dout, te, indicies, subtile"
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort] [tool.isort]
use_parentheses = true use_parentheses = true
......
This diff is collapsed.
This diff is collapsed.
...@@ -2,9 +2,10 @@ ...@@ -2,9 +2,10 @@
Run `pytest tests/models/test_mistral.py`. Run `pytest tests/models/test_mistral.py`.
""" """
import pickle import json
import uuid import uuid
from typing import Any, Dict, List from dataclasses import asdict
from typing import Any, Dict, List, Optional, Tuple
import pytest import pytest
from mistral_common.protocol.instruct.messages import ImageURLChunk from mistral_common.protocol.instruct.messages import ImageURLChunk
...@@ -14,6 +15,7 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk ...@@ -14,6 +15,7 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal import MultiModalDataBuiltins
from vllm.sequence import Logprob, SampleLogprobs
from .utils import check_logprobs_close from .utils import check_logprobs_close
...@@ -81,13 +83,33 @@ SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5) ...@@ -81,13 +83,33 @@ SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = dict(image=4) LIMIT_MM_PER_PROMPT = dict(image=4)
MAX_MODEL_LEN = [8192, 65536] MAX_MODEL_LEN = [8192, 65536]
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle" FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle" FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"
OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
def load_logprobs(filename: str) -> Any:
with open(filename, 'rb') as f: # For the test author to store golden output in JSON
return pickle.load(f) def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
json_data = [(tokens, text,
[{k: asdict(v)
for k, v in token_logprobs.items()}
for token_logprobs in (logprobs or [])])
for tokens, text, logprobs in outputs]
with open(filename, "w") as f:
json.dump(json_data, f)
def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
with open(filename, "rb") as f:
json_data = json.load(f)
return [(tokens, text,
[{int(k): Logprob(**v)
for k, v in token_logprobs.items()}
for token_logprobs in logprobs])
for tokens, text, logprobs in json_data]
@pytest.mark.skip( @pytest.mark.skip(
...@@ -103,7 +125,7 @@ def test_chat( ...@@ -103,7 +125,7 @@ def test_chat(
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT) EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
...@@ -120,10 +142,10 @@ def test_chat( ...@@ -120,10 +142,10 @@ def test_chat(
outputs.extend(output) outputs.extend(output)
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=logprobs, check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
outputs_1_lst=EXPECTED_CHAT_LOGPROBS, outputs_1_lst=logprobs,
name_0="output", name_0="h100_ref",
name_1="h100_ref") name_1="output")
@pytest.mark.skip( @pytest.mark.skip(
...@@ -133,7 +155,7 @@ def test_chat( ...@@ -133,7 +155,7 @@ def test_chat(
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(vllm_runner, model: str, dtype: str) -> None: def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE) EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
args = EngineArgs( args = EngineArgs(
model=model, model=model,
tokenizer_mode="mistral", tokenizer_mode="mistral",
...@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None: ...@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
break break
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=logprobs, check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
outputs_1_lst=EXPECTED_ENGINE_LOGPROBS, outputs_1_lst=logprobs,
name_0="output", name_0="h100_ref",
name_1="h100_ref") name_1="output")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment