Unverified Commit ae5b7aff authored by Julien Denize's avatar Julien Denize Committed by GitHub
Browse files

Improve Mistral format checks. (#33253)


Signed-off-by: default avatarJulien Denize <julien.denize@mistral.ai>
Signed-off-by: default avatarjuliendenize <julien.denize@mistral.ai>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
parent a11bc12d
...@@ -83,7 +83,10 @@ def _assert_model_arch_config( ...@@ -83,7 +83,10 @@ def _assert_model_arch_config(
assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"] assert model_arch_config.is_deepseek_mla == expected["is_deepseek_mla"]
torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype( torch_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
model_config.hf_config, model_config.model, revision=model_config.revision model_config.hf_config,
model_config.model,
revision=model_config.revision,
config_format="hf",
) )
assert str(torch_dtype) == expected["dtype"] assert str(torch_dtype) == expected["dtype"]
......
...@@ -365,6 +365,7 @@ class HfRunner: ...@@ -365,6 +365,7 @@ class HfRunner:
self.config, self.config,
dtype=dtype, dtype=dtype,
is_pooling_model=is_sentence_transformer or is_cross_encoder, is_pooling_model=is_sentence_transformer or is_cross_encoder,
config_format="hf",
) )
model_kwargs = model_kwargs if model_kwargs is not None else {} model_kwargs = model_kwargs if model_kwargs is not None else {}
......
...@@ -8,7 +8,11 @@ from unittest.mock import MagicMock, call, patch ...@@ -8,7 +8,11 @@ from unittest.mock import MagicMock, call, patch
import pytest import pytest
from vllm.transformers_utils.repo_utils import list_filtered_repo_files from vllm.transformers_utils.repo_utils import (
any_pattern_in_repo_files,
is_mistral_model_repo,
list_filtered_repo_files,
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
...@@ -60,3 +64,95 @@ def test_list_filtered_repo_files( ...@@ -60,3 +64,95 @@ def test_list_filtered_repo_files(
repo_type="model", repo_type="model",
token="token", token="token",
) )
@pytest.mark.parametrize(
("allow_patterns", "expected_bool"),
[
(["*.json", "correct*.txt"], True),
(
["*.jpeg"],
True,
),
(
["not_found.jpeg"],
False,
),
],
)
def test_one_filtered_repo_files(allow_patterns: list[str], expected_bool: bool):
with tempfile.TemporaryDirectory() as tmp_dir:
# Prep folder and files
path_tmp_dir = Path(tmp_dir)
subfolder = path_tmp_dir / "subfolder"
subfolder.mkdir()
(path_tmp_dir / "uncorrect.jpeg").touch()
(subfolder / "correct.txt").touch()
def _glob_path() -> list[str]:
return [
str(file.relative_to(path_tmp_dir))
for file in path_tmp_dir.glob("**/*")
if file.is_file()
]
# Patch list_repo_files called by fn
with patch(
"vllm.transformers_utils.repo_utils.list_repo_files",
MagicMock(return_value=_glob_path()),
) as mock_list_repo_files:
assert (
any_pattern_in_repo_files(
tmp_dir, allow_patterns, "revision", "model", "token"
)
) is expected_bool
assert mock_list_repo_files.call_count == 1
assert mock_list_repo_files.call_args_list[0] == call(
repo_id=tmp_dir,
revision="revision",
repo_type="model",
token="token",
)
@pytest.mark.parametrize(
("files", "expected_bool"),
[
(["consolidated.safetensors", "incorrect.txt"], True),
(["consolidated-1.safetensors", "incorrect.txt"], True),
(
["consolidated-1.json"],
False,
),
],
)
def test_is_mistral_model_repo(files: list[str], expected_bool: bool):
with tempfile.TemporaryDirectory() as tmp_dir:
# Prep folder and files
path_tmp_dir = Path(tmp_dir)
for file in files:
(path_tmp_dir / file).touch()
def _glob_path() -> list[str]:
return [
str(file.relative_to(path_tmp_dir))
for file in path_tmp_dir.glob("**/*")
if file.is_file()
]
# Patch list_repo_files called by fn
with patch(
"vllm.transformers_utils.repo_utils.list_repo_files",
MagicMock(return_value=_glob_path()),
) as mock_list_repo_files:
assert (
is_mistral_model_repo(tmp_dir, "revision", "model", "token")
is expected_bool
)
assert mock_list_repo_files.call_count == 1
assert mock_list_repo_files.call_args_list[0] == call(
repo_id=tmp_dir,
revision="revision",
repo_type="model",
token="token",
)
...@@ -565,6 +565,7 @@ class ModelConfig: ...@@ -565,6 +565,7 @@ class ModelConfig:
self.dtype, self.dtype,
is_pooling_model=self.runner_type == "pooling", is_pooling_model=self.runner_type == "pooling",
revision=self.revision, revision=self.revision,
config_format=self.config_format,
) )
self.original_max_model_len = self.max_model_len self.original_max_model_len = self.max_model_len
...@@ -1844,9 +1845,10 @@ def _get_and_verify_dtype( ...@@ -1844,9 +1845,10 @@ def _get_and_verify_dtype(
*, *,
is_pooling_model: bool, is_pooling_model: bool,
revision: str | None = None, revision: str | None = None,
config_format: ConfigFormat = "hf",
) -> torch.dtype: ) -> torch.dtype:
config_dtype = ModelArchConfigConvertorBase.get_torch_dtype( config_dtype = ModelArchConfigConvertorBase.get_torch_dtype(
config, model_id, revision=revision config, model_id, revision=revision, config_format=config_format
) )
model_type = config.model_type model_type = config.model_type
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib.util
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
...@@ -18,7 +17,10 @@ from vllm.transformers_utils.gguf_utils import ( ...@@ -18,7 +17,10 @@ from vllm.transformers_utils.gguf_utils import (
is_remote_gguf, is_remote_gguf,
split_remote_gguf, split_remote_gguf,
) )
from vllm.transformers_utils.repo_utils import list_filtered_repo_files from vllm.transformers_utils.repo_utils import (
any_pattern_in_repo_files,
is_mistral_model_repo,
)
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from .protocol import TokenizerLike from .protocol import TokenizerLike
...@@ -142,26 +144,26 @@ def resolve_tokenizer_args( ...@@ -142,26 +144,26 @@ def resolve_tokenizer_args(
kwargs["use_fast"] = False kwargs["use_fast"] = False
# Try to use official Mistral tokenizer if possible # Try to use official Mistral tokenizer if possible
if tokenizer_mode == "auto" and importlib.util.find_spec("mistral_common"): if (
allow_patterns = ["tekken.json", "tokenizer.model.v*"] tokenizer_mode == "auto"
files_list = list_filtered_repo_files( and is_mistral_model_repo(
model_name_or_path=str(tokenizer_name), revision=revision
)
and any_pattern_in_repo_files(
model_name_or_path=str(tokenizer_name), model_name_or_path=str(tokenizer_name),
allow_patterns=allow_patterns, allow_patterns=["tekken.json", "tokenizer.model.v*"],
revision=revision, revision=revision,
) )
if len(files_list) > 0: ):
tokenizer_mode = "mistral" tokenizer_mode = "mistral"
# Try to use Grok2 tiktoken tokenizer if possible # Try to use Grok2 tiktoken tokenizer if possible
if tokenizer_mode == "auto": if tokenizer_mode == "auto" and any_pattern_in_repo_files(
allow_patterns = ["tokenizer.tok.json"] model_name_or_path=str(tokenizer_name),
files_list = list_filtered_repo_files( allow_patterns=["tokenizer.tok.json"],
model_name_or_path=str(tokenizer_name), revision=revision,
allow_patterns=allow_patterns, ):
revision=revision, tokenizer_mode = "grok2"
)
if len(files_list) > 0:
tokenizer_mode = "grok2"
# Fallback to HF tokenizer # Fallback to HF tokenizer
if tokenizer_mode == "auto": if tokenizer_mode == "auto":
......
...@@ -23,6 +23,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME ...@@ -23,6 +23,7 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs from vllm import envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.repo_utils import is_mistral_model_repo
from vllm.transformers_utils.utils import parse_safetensors_file_metadata from vllm.transformers_utils.utils import parse_safetensors_file_metadata
from .config_parser_base import ConfigParserBase from .config_parser_base import ConfigParserBase
...@@ -49,7 +50,6 @@ except ImportError: ...@@ -49,7 +50,6 @@ except ImportError:
ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES, ALLOWED_LAYER_TYPES as ALLOWED_ATTENTION_LAYER_TYPES,
) )
if envs.VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig from modelscope import AutoConfig
else: else:
...@@ -581,7 +581,11 @@ def get_config( ...@@ -581,7 +581,11 @@ def get_config(
try: try:
# First check for Mistral to avoid defaulting to # First check for Mistral to avoid defaulting to
# Transformers implementation. # Transformers implementation.
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): if is_mistral_model_repo(
model_name_or_path=str(model), revision=revision
) and file_or_path_exists(
model=model, config_name=MISTRAL_CONFIG_NAME, revision=revision
):
config_format = "mistral" config_format = "mistral"
elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists( elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision model, HF_CONFIG_NAME, revision=revision
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from contextlib import contextmanager
from typing import final from typing import final
import torch import torch
from huggingface_hub import constants
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -14,6 +17,7 @@ from vllm.config.model_arch import ( ...@@ -14,6 +17,7 @@ from vllm.config.model_arch import (
from vllm.config.utils import getattr_iter from vllm.config.utils import getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
ConfigFormat,
try_get_safetensors_metadata, try_get_safetensors_metadata,
) )
from vllm.utils.torch_utils import common_broadcastable_dtype from vllm.utils.torch_utils import common_broadcastable_dtype
...@@ -21,6 +25,22 @@ from vllm.utils.torch_utils import common_broadcastable_dtype ...@@ -21,6 +25,22 @@ from vllm.utils.torch_utils import common_broadcastable_dtype
logger = init_logger(__name__) logger = init_logger(__name__)
@contextmanager
def _maybe_patch_hf_hub_constants(config_format: ConfigFormat) -> Iterator[None]:
if config_format == "mistral":
hf_safetensors_single_file = constants.SAFETENSORS_SINGLE_FILE
hf_safetensors_index_file = constants.SAFETENSORS_INDEX_FILE
constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
try:
yield
finally:
constants.SAFETENSORS_SINGLE_FILE = hf_safetensors_single_file
constants.SAFETENSORS_INDEX_FILE = hf_safetensors_index_file
else:
yield
class ModelArchConfigConvertorBase: class ModelArchConfigConvertorBase:
def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig): def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig):
self.hf_config = hf_config self.hf_config = hf_config
...@@ -123,7 +143,11 @@ class ModelArchConfigConvertorBase: ...@@ -123,7 +143,11 @@ class ModelArchConfigConvertorBase:
@final @final
@classmethod @classmethod
def get_torch_dtype( def get_torch_dtype(
cls, hf_config: PretrainedConfig, model_id: str, revision: str | None cls,
hf_config: PretrainedConfig,
model_id: str,
revision: str | None,
config_format: ConfigFormat,
): ):
# NOTE: getattr(config, "dtype", torch.float32) is not correct # NOTE: getattr(config, "dtype", torch.float32) is not correct
# because config.dtype can be None. # because config.dtype can be None.
...@@ -140,7 +164,8 @@ class ModelArchConfigConvertorBase: ...@@ -140,7 +164,8 @@ class ModelArchConfigConvertorBase:
# Try to read the dtype of the weights if they are in safetensors format # Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None: if config_dtype is None:
repo_mt = try_get_safetensors_metadata(model_id, revision=revision) with _maybe_patch_hf_hub_constants(config_format):
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
if repo_mt and (files_mt := repo_mt.files_metadata): if repo_mt and (files_mt := repo_mt.files_metadata):
param_dtypes: set[torch.dtype] = { param_dtypes: set[torch.dtype] = {
......
...@@ -127,6 +127,42 @@ def list_filtered_repo_files( ...@@ -127,6 +127,42 @@ def list_filtered_repo_files(
return file_list return file_list
def any_pattern_in_repo_files(
model_name_or_path: str,
allow_patterns: list[str],
revision: str | None = None,
repo_type: str | None = None,
token: str | bool | None = None,
):
return (
len(
list_filtered_repo_files(
model_name_or_path=model_name_or_path,
allow_patterns=allow_patterns,
revision=revision,
repo_type=repo_type,
token=token,
)
)
> 0
)
def is_mistral_model_repo(
model_name_or_path: str,
revision: str | None = None,
repo_type: str | None = None,
token: str | bool | None = None,
) -> bool:
return any_pattern_in_repo_files(
model_name_or_path=model_name_or_path,
allow_patterns=["consolidated*.safetensors"],
revision=revision,
repo_type=repo_type,
token=token,
)
def file_exists( def file_exists(
repo_id: str, repo_id: str,
file_name: str, file_name: str,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment