Unverified commit f863ffc9, authored by Patrick von Platen, committed by GitHub
Browse files

[Mistral-Small 3.1] Update docs and tests (#14977)


Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 400d483e
...@@ -879,7 +879,7 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -879,7 +879,7 @@ See [this page](#generative-models) for more information on how to use generativ
- * `PixtralForConditionalGeneration` - * `PixtralForConditionalGeneration`
* Pixtral * Pixtral
* T + I<sup>+</sup> * T + I<sup>+</sup>
* `mistralai/Pixtral-12B-2409`, `mistral-community/pixtral-12b`, etc. * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc.
* *
* ✅︎ * ✅︎
* ✅︎ * ✅︎
......
...@@ -6,14 +6,14 @@ import argparse ...@@ -6,14 +6,14 @@ import argparse
from vllm import LLM from vllm import LLM
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
# This script is an offline demo for running Pixtral. # This script is an offline demo for running Mistral-Small-3
# #
# If you want to run a server/client setup, please follow this code: # If you want to run a server/client setup, please follow this code:
# #
# - Server: # - Server:
# #
# ```bash # ```bash
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384 # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 --tokenizer-mode mistral --limit-mm-per-prompt 'image=4' --max-model-len 16384
# ``` # ```
# #
# - Client: # - Client:
...@@ -23,7 +23,7 @@ from vllm.sampling_params import SamplingParams ...@@ -23,7 +23,7 @@ from vllm.sampling_params import SamplingParams
# --header 'Content-Type: application/json' \ # --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \ # --header 'Authorization: Bearer token' \
# --data '{ # --data '{
# "model": "mistralai/Pixtral-12B-2409", # "model": "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
# "messages": [ # "messages": [
# { # {
# "role": "user", # "role": "user",
...@@ -44,7 +44,7 @@ from vllm.sampling_params import SamplingParams ...@@ -44,7 +44,7 @@ from vllm.sampling_params import SamplingParams
def run_simple_demo(args: argparse.Namespace): def run_simple_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409" model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
sampling_params = SamplingParams(max_tokens=8192) sampling_params = SamplingParams(max_tokens=8192)
# Lower max_model_len and/or max_num_seqs on low-VRAM GPUs. # Lower max_model_len and/or max_num_seqs on low-VRAM GPUs.
...@@ -83,7 +83,7 @@ def run_simple_demo(args: argparse.Namespace): ...@@ -83,7 +83,7 @@ def run_simple_demo(args: argparse.Namespace):
def run_advanced_demo(args: argparse.Namespace): def run_advanced_demo(args: argparse.Namespace):
model_name = "mistralai/Pixtral-12B-2409" model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
max_img_per_msg = 5 max_img_per_msg = 5
max_tokens_per_img = 4096 max_tokens_per_img = 4096
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
Run `pytest tests/models/test_mistral.py`. Run `pytest tests/models/test_mistral.py`.
""" """
import json import json
import uuid
from dataclasses import asdict from dataclasses import asdict
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Any, Optional
...@@ -16,8 +15,7 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer ...@@ -16,8 +15,7 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from transformers import AutoProcessor from transformers import AutoProcessor
from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams, from vllm import RequestOutput, SamplingParams, TextPrompt, TokensPrompt
TextPrompt, TokensPrompt)
from vllm.multimodal import MultiModalDataBuiltins from vllm.multimodal import MultiModalDataBuiltins
from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.inputs import PlaceholderRange
from vllm.sequence import Logprob, SampleLogprobs from vllm.sequence import Logprob, SampleLogprobs
...@@ -28,7 +26,11 @@ from ...utils import check_logprobs_close ...@@ -28,7 +26,11 @@ from ...utils import check_logprobs_close
if TYPE_CHECKING: if TYPE_CHECKING:
from _typeshed import StrPath from _typeshed import StrPath
MODELS = ["mistralai/Pixtral-12B-2409"] PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]
IMG_URLS = [ IMG_URLS = [
"https://picsum.photos/id/237/400/300", "https://picsum.photos/id/237/400/300",
"https://picsum.photos/id/231/200/300", "https://picsum.photos/id/231/200/300",
...@@ -125,8 +127,10 @@ MAX_MODEL_LEN = [8192, 65536] ...@@ -125,8 +127,10 @@ MAX_MODEL_LEN = [8192, 65536]
FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures" FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures"
assert FIXTURES_PATH.exists() assert FIXTURES_PATH.exists()
FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json" FIXTURE_LOGPROBS_CHAT = {
FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json" PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
}
OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]] OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]]
...@@ -166,12 +170,12 @@ def test_chat( ...@@ -166,12 +170,12 @@ def test_chat(
model: str, model: str,
dtype: str, dtype: str,
) -> None: ) -> None:
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT) EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(
FIXTURE_LOGPROBS_CHAT[model])
with vllm_runner( with vllm_runner(
model, model,
dtype=dtype, dtype=dtype,
tokenizer_mode="mistral", tokenizer_mode="mistral",
enable_chunked_prefill=False,
max_model_len=max_model_len, max_model_len=max_model_len,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT, limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
) as vllm_model: ) as vllm_model:
...@@ -183,70 +187,40 @@ def test_chat( ...@@ -183,70 +187,40 @@ def test_chat(
outputs.extend(output) outputs.extend(output)
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs) logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
# Remove last `None` prompt_logprobs to compare with fixture
for i in range(len(logprobs)):
assert logprobs[i][-1] is None
logprobs[i] = logprobs[i][:-1]
check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS, check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
outputs_1_lst=logprobs, outputs_1_lst=logprobs,
name_0="h100_ref", name_0="h100_ref",
name_1="output") name_1="output")
@large_gpu_test(min_gb=80)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
args = EngineArgs(
model=model,
tokenizer_mode="mistral",
enable_chunked_prefill=False,
limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
dtype=dtype,
)
engine = LLMEngine.from_engine_args(args)
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
outputs = []
count = 0
while True:
out = engine.step()
count += 1
for request_output in out:
if request_output.finished:
outputs.append(request_output)
if count == 2:
engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
SAMPLING_PARAMS)
if not engine.has_unfinished_requests():
break
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
outputs_1_lst=logprobs,
name_0="h100_ref",
name_1="output")
@large_gpu_test(min_gb=48) @large_gpu_test(min_gb=48)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"prompt,expected_ranges", "prompt,expected_ranges",
[(_create_engine_inputs_hf(IMG_URLS[:1]), [{ [(_create_engine_inputs_hf(IMG_URLS[:1]), [{
"offset": 10, "offset": 11,
"length": 494 "length": 494
}]), }]),
(_create_engine_inputs_hf(IMG_URLS[1:4]), [{ (_create_engine_inputs_hf(IMG_URLS[1:4]), [{
"offset": 10, "offset": 11,
"length": 266 "length": 266
}, { }, {
"offset": 276, "offset": 277,
"length": 1056 "length": 1056
}, { }, {
"offset": 1332, "offset": 1333,
"length": 418 "length": 418
}])]) }])])
def test_multi_modal_placeholders( def test_multi_modal_placeholders(vllm_runner, prompt,
vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None: expected_ranges: list[PlaceholderRange],
monkeypatch) -> None:
# This placeholder checking test only works with V0 engine
# where `multi_modal_placeholders` is returned with `RequestOutput`
monkeypatch.setenv("VLLM_USE_V1", "0")
with vllm_runner( with vllm_runner(
"mistral-community/pixtral-12b", "mistral-community/pixtral-12b",
max_model_len=8192, max_model_len=8192,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment