Unverified commit 794029f0, authored by Injae Ryou and committed by GitHub
Browse files

[Feature]: Improve GGUF loading from HuggingFace user experience like repo_id:quant_type (#29137)


Signed-off-by: Injae Ryou <injaeryou@gmail.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 0231ce83
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch
import pytest
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.weight_utils import download_gguf
class TestGGUFDownload:
    """Exercise ``download_gguf`` with the HF download and ``glob.glob`` mocked."""

    @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
    def test_download_gguf_single_file(self, mock_download):
        """A single matching file is returned as a file path, not the folder."""
        cache_folder = "/tmp/mock_cache"
        mock_download.return_value = cache_folder

        def fake_glob(pattern, **kwargs):
            # Only patterns that mention the requested quant type yield a hit.
            if "IQ1_S" in pattern:
                return [f"{cache_folder}/model-IQ1_S.gguf"]
            return []

        with patch("glob.glob") as mock_glob:
            mock_glob.side_effect = fake_glob
            result = download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")

            # The download helper must receive the four quant-specific patterns.
            mock_download.assert_called_once_with(
                model_name_or_path="unsloth/Qwen3-0.6B-GGUF",
                cache_dir=None,
                allow_patterns=[
                    "*-IQ1_S.gguf",
                    "*-IQ1_S-*.gguf",
                    "*/*-IQ1_S.gguf",
                    "*/*-IQ1_S-*.gguf",
                ],
                revision=None,
                ignore_patterns=None,
            )

            # The result is the file itself rather than the cache folder.
            assert result == f"{cache_folder}/model-IQ1_S.gguf"

    @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
    def test_download_gguf_sharded_files(self, mock_download):
        """For sharded checkpoints the first shard (after sorting) is returned."""
        cache_folder = "/tmp/mock_cache"
        mock_download.return_value = cache_folder

        def fake_glob(pattern, **kwargs):
            if "Q2_K" in pattern:
                return [
                    f"{cache_folder}/model-Q2_K-00001-of-00002.gguf",
                    f"{cache_folder}/model-Q2_K-00002-of-00002.gguf",
                ]
            return []

        with patch("glob.glob") as mock_glob:
            mock_glob.side_effect = fake_glob
            result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")

            assert result == f"{cache_folder}/model-Q2_K-00001-of-00002.gguf"

    @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
    def test_download_gguf_subdir(self, mock_download):
        """A GGUF file located in a sub-directory of the snapshot is found."""
        cache_folder = "/tmp/mock_cache"
        mock_download.return_value = cache_folder

        def fake_glob(pattern, **kwargs):
            if "Q2_K" in pattern or "**/*.gguf" in pattern:
                return [f"{cache_folder}/Q2_K/model-Q2_K.gguf"]
            return []

        with patch("glob.glob") as mock_glob:
            mock_glob.side_effect = fake_glob
            result = download_gguf("unsloth/gpt-oss-120b-GGUF", "Q2_K")

            assert result == f"{cache_folder}/Q2_K/model-Q2_K.gguf"

    @patch("vllm.model_executor.model_loader.weight_utils.download_weights_from_hf")
    @patch("glob.glob", return_value=[])
    def test_download_gguf_no_files_found(self, mock_glob, mock_download):
        """An empty glob result after download raises ``ValueError``."""
        mock_download.return_value = "/tmp/mock_cache"

        with pytest.raises(ValueError, match="Downloaded GGUF files not found"):
            download_gguf("unsloth/Qwen3-0.6B-GGUF", "IQ1_S")
class TestGGUFModelLoader:
    """Exercise ``GGUFModelLoader._prepare_weights`` for each reference form."""

    @staticmethod
    def _fake_qwen3_hf_config():
        """Build the mocked HF config shared by the ``ModelConfig`` based tests."""

        class MockTextConfig:
            max_position_embeddings = 4096
            sliding_window = None
            model_type = "qwen3"
            num_attention_heads = 32

        hf_config = MagicMock()
        hf_config.architectures = ["Qwen3ForCausalLM"]
        hf_config.get_text_config.return_value = MockTextConfig()
        hf_config.dtype = "bfloat16"
        return hf_config

    @patch("os.path.isfile", return_value=True)
    def test_prepare_weights_local_file(self, mock_isfile):
        """An existing local file is returned unchanged."""
        loader = GGUFModelLoader(LoadConfig(load_format="gguf"))

        # Only the ``model`` attribute is read on this code path.
        model_config = MagicMock()
        model_config.model = "/path/to/model.gguf"

        assert loader._prepare_weights(model_config) == "/path/to/model.gguf"
        mock_isfile.assert_called_once_with("/path/to/model.gguf")

    @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
    @patch("os.path.isfile", return_value=False)
    def test_prepare_weights_https_url(self, mock_isfile, mock_hf_download):
        """A raw HTTPS link is forwarded to ``hf_hub_download`` as ``url``."""
        mock_hf_download.return_value = "/downloaded/model.gguf"
        loader = GGUFModelLoader(LoadConfig(load_format="gguf"))

        model_config = MagicMock()
        model_config.model = "https://huggingface.co/model.gguf"

        assert loader._prepare_weights(model_config) == "/downloaded/model.gguf"
        mock_hf_download.assert_called_once_with(
            url="https://huggingface.co/model.gguf"
        )

    @patch("vllm.model_executor.model_loader.gguf_loader.hf_hub_download")
    @patch("os.path.isfile", return_value=False)
    def test_prepare_weights_repo_filename(self, mock_isfile, mock_hf_download):
        """``repo_id/filename.gguf`` is split into repo id and file name."""
        mock_hf_download.return_value = "/downloaded/model.gguf"
        loader = GGUFModelLoader(LoadConfig(load_format="gguf"))

        model_config = MagicMock()
        model_config.model = "unsloth/Qwen3-0.6B-GGUF/model.gguf"

        assert loader._prepare_weights(model_config) == "/downloaded/model.gguf"
        mock_hf_download.assert_called_once_with(
            repo_id="unsloth/Qwen3-0.6B-GGUF", filename="model.gguf"
        )

    @patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
    @patch("vllm.transformers_utils.config.file_or_path_exists", return_value=True)
    @patch("vllm.config.model.get_config")
    @patch("vllm.config.model.is_gguf", return_value=True)
    @patch("vllm.model_executor.model_loader.gguf_loader.download_gguf")
    @patch("os.path.isfile", return_value=False)
    def test_prepare_weights_repo_quant_type(
        self,
        mock_isfile,
        mock_download_gguf,
        mock_is_gguf,
        mock_get_config,
        mock_file_exists,
        mock_get_image_config,
    ):
        """``repo_id:quant_type`` is resolved through ``download_gguf``."""
        mock_get_config.return_value = self._fake_qwen3_hf_config()
        mock_download_gguf.return_value = "/downloaded/model-IQ1_S.gguf"

        loader = GGUFModelLoader(LoadConfig(load_format="gguf"))
        model_config = ModelConfig(
            model="unsloth/Qwen3-0.6B-GGUF:IQ1_S", tokenizer="Qwen/Qwen3-0.6B"
        )

        # The loader hands back whatever path the mocked download produced.
        assert loader._prepare_weights(model_config) == "/downloaded/model-IQ1_S.gguf"
        mock_download_gguf.assert_called_once_with(
            "unsloth/Qwen3-0.6B-GGUF",
            "IQ1_S",
            cache_dir=None,
            revision=None,
            ignore_patterns=["original/**/*"],
        )

    @patch("vllm.config.model.get_hf_image_processor_config", return_value=None)
    @patch("vllm.config.model.get_config")
    @patch("vllm.config.model.is_gguf", return_value=False)
    @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
    @patch("os.path.isfile", return_value=False)
    def test_prepare_weights_invalid_format(
        self,
        mock_isfile,
        mock_check_gguf,
        mock_is_gguf,
        mock_get_config,
        mock_get_image_config,
    ):
        """An unrecognised reference raises ``ValueError``."""
        mock_get_config.return_value = self._fake_qwen3_hf_config()
        loader = GGUFModelLoader(LoadConfig(load_format="gguf"))

        # Construct with a valid repo id so ModelConfig validation passes, then
        # swap in the malformed reference that _prepare_weights must reject.
        model_config = ModelConfig(model="unsloth/Qwen3-0.6B")
        model_config.model = "invalid-format"

        with pytest.raises(ValueError, match="Unrecognised GGUF reference"):
            loader._prepare_weights(model_config)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from pathlib import Path
from unittest.mock import patch
import pytest
from vllm.transformers_utils.utils import ( from vllm.transformers_utils.utils import (
is_cloud_storage, is_cloud_storage,
is_gcs, is_gcs,
is_gguf,
is_remote_gguf,
is_s3, is_s3,
split_remote_gguf,
) )
...@@ -28,3 +34,143 @@ def test_is_cloud_storage(): ...@@ -28,3 +34,143 @@ def test_is_cloud_storage():
assert is_cloud_storage("s3://model-path/path-to-model") assert is_cloud_storage("s3://model-path/path-to-model")
assert not is_cloud_storage("/unix/local/path") assert not is_cloud_storage("/unix/local/path")
assert not is_cloud_storage("nfs://nfs-fqdn.local") assert not is_cloud_storage("nfs://nfs-fqdn.local")
class TestIsRemoteGGUF:
    """Checks for the ``is_remote_gguf`` helper."""

    def test_is_remote_gguf_with_colon_and_slash(self):
        """``repo_id:quant_type`` is accepted only for known quant types."""
        valid_refs = (
            "unsloth/Qwen3-0.6B-GGUF:IQ1_S",
            "user/repo:Q2_K",
            "repo/model:Q4_K",
            "repo/model:Q8_0",
        )
        for ref in valid_refs:
            assert is_remote_gguf(ref)

        # Suffixes that are not GGUF quant types are rejected.
        invalid_refs = (
            "repo/model:quant",
            "repo/model:INVALID",
            "repo/model:invalid_type",
        )
        for ref in invalid_refs:
            assert not is_remote_gguf(ref)

    def test_is_remote_gguf_without_colon(self):
        """A plain repo id without ``:quant_type`` is not a remote GGUF."""
        for ref in ("repo/model", "unsloth/Qwen3-0.6B-GGUF"):
            assert not is_remote_gguf(ref)

    def test_is_remote_gguf_without_slash(self):
        """References lacking a slash are rejected, whatever the suffix."""
        # "model:IQ1_S" carries a valid quant type but still has no slash.
        for ref in ("model.gguf", "model:IQ1_S", "model:quant"):
            assert not is_remote_gguf(ref)

    def test_is_remote_gguf_local_path(self):
        """Local file paths are not remote GGUF references."""
        for ref in ("/path/to/model.gguf", "./model.gguf"):
            assert not is_remote_gguf(ref)

    def test_is_remote_gguf_with_path_object(self):
        """``Path`` inputs are handled like their string form."""
        assert is_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
        assert not is_remote_gguf(Path("repo/model"))

    def test_is_remote_gguf_with_http_https(self):
        """HTTP(S) URLs are rejected even with a valid quant type suffix."""
        urls = (
            "http://example.com/repo/model:IQ1_S",
            "https://huggingface.co/repo/model:Q2_K",
            "http://repo/model:Q4_K",
            "https://repo/model:Q8_0",
        )
        for url in urls:
            assert not is_remote_gguf(url)

    def test_is_remote_gguf_with_cloud_storage(self):
        """Cloud storage URIs are rejected even with a valid quant type suffix."""
        uris = (
            "s3://bucket/repo/model:IQ1_S",
            "gs://bucket/repo/model:Q2_K",
            "s3://repo/model:Q4_K",
            "gs://repo/model:Q8_0",
        )
        for uri in uris:
            assert not is_remote_gguf(uri)
class TestSplitRemoteGGUF:
    """Checks for the ``split_remote_gguf`` helper."""

    def test_split_remote_gguf_valid(self):
        """A valid reference splits into ``(repo_id, quant_type)``."""
        expectations = (
            ("unsloth/Qwen3-0.6B-GGUF:IQ1_S", "unsloth/Qwen3-0.6B-GGUF", "IQ1_S"),
            ("repo/model:Q2_K", "repo/model", "Q2_K"),
        )
        for reference, expected_repo, expected_quant in expectations:
            repo_id, quant_type = split_remote_gguf(reference)
            assert repo_id == expected_repo
            assert quant_type == expected_quant

    def test_split_remote_gguf_with_path_object(self):
        """``Path`` inputs are split like their string form."""
        repo_id, quant_type = split_remote_gguf(Path("unsloth/Qwen3-0.6B-GGUF:IQ1_S"))
        assert repo_id == "unsloth/Qwen3-0.6B-GGUF"
        assert quant_type == "IQ1_S"

    def test_split_remote_gguf_invalid(self):
        """Every reference rejected by ``is_remote_gguf`` raises ``ValueError``."""
        rejected = (
            "repo/model",  # no colon
            "repo/model:INVALID_TYPE",  # unknown quant type
            "http://repo/model:IQ1_S",  # HTTP URL
            "s3://bucket/repo/model:Q2_K",  # cloud storage
        )
        for reference in rejected:
            with pytest.raises(ValueError, match="Wrong GGUF model"):
                split_remote_gguf(reference)
class TestIsGGUF:
    """Checks for the ``is_gguf`` helper."""

    @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=True)
    def test_is_gguf_with_local_file(self, mock_check_gguf):
        """Paths reported as GGUF files by ``check_gguf_file`` are GGUF."""
        for path in ("/path/to/model.gguf", "./model.gguf"):
            assert is_gguf(path)

    def test_is_gguf_with_remote_gguf(self):
        """``repo_id:quant_type`` counts as GGUF only for valid quant types."""
        for ref in ("unsloth/Qwen3-0.6B-GGUF:IQ1_S", "repo/model:Q2_K", "repo/model:Q4_K"):
            assert is_gguf(ref)

        # Unknown quant types are rejected.
        for ref in ("repo/model:quant", "repo/model:INVALID"):
            assert not is_gguf(ref)

    @patch("vllm.transformers_utils.utils.check_gguf_file", return_value=False)
    def test_is_gguf_false(self, mock_check_gguf):
        """Ordinary model names are not GGUF."""
        for ref in ("unsloth/Qwen3-0.6B", "repo/model", "model"):
            assert not is_gguf(ref)

    def test_is_gguf_edge_cases(self):
        """Empty, partial, URL and cloud-storage references are not GGUF."""
        edge_cases = (
            "",  # empty string
            "model:IQ1_S",  # colon but no slash, even with a valid quant type
            "repo/model",  # slash but no colon
            "http://repo/model:IQ1_S",  # HTTP URL
            "https://repo/model:Q2_K",  # HTTPS URL
            "s3://bucket/repo/model:IQ1_S",  # cloud storage
            "gs://bucket/repo/model:Q2_K",  # cloud storage
        )
        for ref in edge_cases:
            assert not is_gguf(ref)
...@@ -39,7 +39,12 @@ from vllm.transformers_utils.gguf_utils import ( ...@@ -39,7 +39,12 @@ from vllm.transformers_utils.gguf_utils import (
maybe_patch_hf_config_from_gguf, maybe_patch_hf_config_from_gguf,
) )
from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri
from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect from vllm.transformers_utils.utils import (
is_gguf,
is_remote_gguf,
maybe_model_redirect,
split_remote_gguf,
)
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype from vllm.utils.torch_utils import common_broadcastable_dtype
...@@ -440,7 +445,8 @@ class ModelConfig: ...@@ -440,7 +445,8 @@ class ModelConfig:
self.model = maybe_model_redirect(self.model) self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default. # The tokenizer is consistent with the model by default.
if self.tokenizer is None: if self.tokenizer is None:
if check_gguf_file(self.model): # Check if this is a GGUF model (either local file or remote GGUF)
if is_gguf(self.model):
raise ValueError( raise ValueError(
"Using a tokenizer is mandatory when loading a GGUF model. " "Using a tokenizer is mandatory when loading a GGUF model. "
"Please specify the tokenizer path or name using the " "Please specify the tokenizer path or name using the "
...@@ -832,7 +838,10 @@ class ModelConfig: ...@@ -832,7 +838,10 @@ class ModelConfig:
self.tokenizer = object_storage_tokenizer.dir self.tokenizer = object_storage_tokenizer.dir
def _get_encoder_config(self): def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config(self.model, self.revision) model = self.model
if is_remote_gguf(model):
model, _ = split_remote_gguf(model)
return get_sentence_transformer_tokenizer_config(model, self.revision)
def _verify_tokenizer_mode(self) -> None: def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower()) tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
......
...@@ -86,7 +86,7 @@ from vllm.transformers_utils.config import ( ...@@ -86,7 +86,7 @@ from vllm.transformers_utils.config import (
is_interleaved, is_interleaved,
maybe_override_with_speculators, maybe_override_with_speculators,
) )
from vllm.transformers_utils.utils import check_gguf_file, is_cloud_storage from vllm.transformers_utils.utils import is_cloud_storage, is_gguf
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.network_utils import get_ip from vllm.utils.network_utils import get_ip
...@@ -1148,8 +1148,8 @@ class EngineArgs: ...@@ -1148,8 +1148,8 @@ class EngineArgs:
return engine_args return engine_args
def create_model_config(self) -> ModelConfig: def create_model_config(self) -> ModelConfig:
# gguf file needs a specific model loader and doesn't use hf_repo # gguf file needs a specific model loader
if check_gguf_file(self.model): if is_gguf(self.model):
self.quantization = self.load_format = "gguf" self.quantization = self.load_format = "gguf"
# NOTE(woosuk): In V1, we use separate processes for workers (unless # NOTE(woosuk): In V1, we use separate processes for workers (unless
......
...@@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import ( ...@@ -18,6 +18,7 @@ from vllm.model_executor.model_loader.utils import (
process_weights_after_loading, process_weights_after_loading,
) )
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_gguf,
get_gguf_extra_tensor_names, get_gguf_extra_tensor_names,
get_gguf_weight_type_map, get_gguf_weight_type_map,
gguf_quant_weights_iterator, gguf_quant_weights_iterator,
...@@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -43,7 +44,8 @@ class GGUFModelLoader(BaseModelLoader):
f"load format {load_config.load_format}" f"load format {load_config.load_format}"
) )
def _prepare_weights(self, model_name_or_path: str): def _prepare_weights(self, model_config: ModelConfig):
model_name_or_path = model_config.model
if os.path.isfile(model_name_or_path): if os.path.isfile(model_name_or_path):
return model_name_or_path return model_name_or_path
# for raw HTTPS link # for raw HTTPS link
...@@ -55,12 +57,23 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -55,12 +57,23 @@ class GGUFModelLoader(BaseModelLoader):
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"): if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1) repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename) return hf_hub_download(repo_id=repo_id, filename=filename)
else: # repo_id:quant_type
raise ValueError( elif "/" in model_name_or_path and ":" in model_name_or_path:
f"Unrecognised GGUF reference: {model_name_or_path} " repo_id, quant_type = model_name_or_path.rsplit(":", 1)
"(expected local file, raw URL, or <repo_id>/<filename>.gguf)" return download_gguf(
repo_id,
quant_type,
cache_dir=self.load_config.download_dir,
revision=model_config.revision,
ignore_patterns=self.load_config.ignore_patterns,
) )
raise ValueError(
f"Unrecognised GGUF reference: {model_name_or_path} "
"(expected local file, raw URL, <repo_id>/<filename>.gguf, "
"or <repo_id>:<quant_type>)"
)
def _get_gguf_weights_map(self, model_config: ModelConfig): def _get_gguf_weights_map(self, model_config: ModelConfig):
""" """
GGUF uses this naming convention for their tensors from HF checkpoint: GGUF uses this naming convention for their tensors from HF checkpoint:
...@@ -244,7 +257,7 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -244,7 +257,7 @@ class GGUFModelLoader(BaseModelLoader):
gguf_to_hf_name_map: dict[str, str], gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]: ) -> dict[str, str]:
weight_type_map = get_gguf_weight_type_map( weight_type_map = get_gguf_weight_type_map(
model_config.model, gguf_to_hf_name_map model_name_or_path, gguf_to_hf_name_map
) )
is_multimodal = hasattr(model_config.hf_config, "vision_config") is_multimodal = hasattr(model_config.hf_config, "vision_config")
if is_multimodal: if is_multimodal:
...@@ -290,10 +303,10 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -290,10 +303,10 @@ class GGUFModelLoader(BaseModelLoader):
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None: def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model) self._prepare_weights(model_config)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
local_model_path = self._prepare_weights(model_config.model) local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights( model.load_weights(
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map) self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
...@@ -303,7 +316,7 @@ class GGUFModelLoader(BaseModelLoader): ...@@ -303,7 +316,7 @@ class GGUFModelLoader(BaseModelLoader):
self, vllm_config: VllmConfig, model_config: ModelConfig self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module: ) -> nn.Module:
device_config = vllm_config.device_config device_config = vllm_config.device_config
local_model_path = self._prepare_weights(model_config.model) local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights # we can only know if tie word embeddings after mapping weights
if "lm_head.weight" in get_gguf_extra_tensor_names( if "lm_head.weight" in get_gguf_extra_tensor_names(
......
...@@ -369,6 +369,52 @@ def get_sparse_attention_config( ...@@ -369,6 +369,52 @@ def get_sparse_attention_config(
return config return config
def download_gguf(
    repo_id: str,
    quant_type: str,
    cache_dir: str | None = None,
    revision: str | None = None,
    ignore_patterns: str | list[str] | None = None,
) -> str:
    """Download the GGUF file(s) of one quant type and return a local path.

    Args:
        repo_id: Repository to download from; passed to
            ``download_weights_from_hf`` as ``model_name_or_path``.
        quant_type: Quantization tag expected in the file name, e.g. ``Q2_K``.
        cache_dir: Forwarded to ``download_weights_from_hf``.
        revision: Forwarded to ``download_weights_from_hf``.
        ignore_patterns: Forwarded to ``download_weights_from_hf``.

    Returns:
        The matching local file with the fewest ``-`` characters in its path,
        ties broken by path order.

    Raises:
        ValueError: If no local file matches any of the patterns.
    """
    # File name layouts covered, relative to the snapshot folder:
    #   root file, root shards, sub-directory file, sub-directory shards.
    allow_patterns = [
        template.format(quant_type=quant_type)
        for template in (
            "*-{quant_type}.gguf",
            "*-{quant_type}-*.gguf",
            "*/*-{quant_type}.gguf",
            "*/*-{quant_type}-*.gguf",
        )
    ]

    # The same patterns restrict what gets downloaded.
    folder = download_weights_from_hf(
        model_name_or_path=repo_id,
        cache_dir=cache_dir,
        allow_patterns=allow_patterns,
        revision=revision,
        ignore_patterns=ignore_patterns,
    )

    # Re-apply each pattern on the local filesystem to locate the result.
    candidates = [
        path
        for pattern in allow_patterns
        for path in glob.glob(os.path.join(folder, pattern))
    ]
    if not candidates:
        raise ValueError(
            f"Downloaded GGUF files not found in {folder} for quant_type {quant_type}"
        )

    # Fewest dashes first (prefers non-sharded names), then lexicographic;
    # taking the minimum equals sorting with this key and picking index 0.
    return min(candidates, key=lambda path: (path.count("-"), path))
def download_weights_from_hf( def download_weights_from_hf(
model_name_or_path: str, model_name_or_path: str,
cache_dir: str | None, cache_dir: str | None,
......
...@@ -42,7 +42,10 @@ from vllm.logger import init_logger ...@@ -42,7 +42,10 @@ from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import ( from vllm.transformers_utils.utils import (
check_gguf_file, check_gguf_file,
is_gguf,
is_remote_gguf,
parse_safetensors_file_metadata, parse_safetensors_file_metadata,
split_remote_gguf,
) )
if envs.VLLM_USE_MODELSCOPE: if envs.VLLM_USE_MODELSCOPE:
...@@ -629,10 +632,12 @@ def maybe_override_with_speculators( ...@@ -629,10 +632,12 @@ def maybe_override_with_speculators(
Returns: Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config) Tuple of (resolved_model, resolved_tokenizer, speculative_config)
""" """
is_gguf = check_gguf_file(model) if check_gguf_file(model):
if is_gguf:
kwargs["gguf_file"] = Path(model).name kwargs["gguf_file"] = Path(model).name
gguf_model_repo = Path(model).parent gguf_model_repo = Path(model).parent
elif is_remote_gguf(model):
repo_id, _ = split_remote_gguf(model)
gguf_model_repo = Path(repo_id)
else: else:
gguf_model_repo = None gguf_model_repo = None
kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE
...@@ -678,10 +683,18 @@ def get_config( ...@@ -678,10 +683,18 @@ def get_config(
) -> PretrainedConfig: ) -> PretrainedConfig:
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
is_gguf = check_gguf_file(model) _is_gguf = is_gguf(model)
if is_gguf: _is_remote_gguf = is_remote_gguf(model)
kwargs["gguf_file"] = Path(model).name if _is_gguf:
model = Path(model).parent if check_gguf_file(model):
# Local GGUF file
kwargs["gguf_file"] = Path(model).name
model = Path(model).parent
elif _is_remote_gguf:
# Remote GGUF - extract repo_id from repo_id:quant_type format
# The actual GGUF file will be downloaded later by GGUFModelLoader
# Keep model as repo_id:quant_type for download, but use repo_id for config
model, _ = split_remote_gguf(model)
if config_format == "auto": if config_format == "auto":
try: try:
...@@ -689,10 +702,25 @@ def get_config( ...@@ -689,10 +702,25 @@ def get_config(
# Transformers implementation. # Transformers implementation.
if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): if file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision):
config_format = "mistral" config_format = "mistral"
elif is_gguf or file_or_path_exists( elif (_is_gguf and not _is_remote_gguf) or file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision model, HF_CONFIG_NAME, revision=revision
): ):
config_format = "hf" config_format = "hf"
# Remote GGUF models must have config.json in repo,
# otherwise the config can't be parsed correctly.
# FIXME(Isotr0py): Support remote GGUF repos without config.json
elif _is_remote_gguf and not file_or_path_exists(
model, HF_CONFIG_NAME, revision=revision
):
err_msg = (
"Could not find config.json for remote GGUF model repo. "
"To load remote GGUF model through `<repo_id>:<quant_type>`, "
"ensure your model has config.json (HF format) file. "
"Otherwise please specify --hf-config-path <original_repo> "
"in engine args to fetch config from unquantized hf model."
)
logger.error(err_msg)
raise ValueError(err_msg)
else: else:
raise ValueError( raise ValueError(
"Could not detect config format for no config file found. " "Could not detect config format for no config file found. "
...@@ -713,9 +741,6 @@ def get_config( ...@@ -713,9 +741,6 @@ def get_config(
"'config.json'.\n" "'config.json'.\n"
" - For Mistral models: ensure the presence of a " " - For Mistral models: ensure the presence of a "
"'params.json'.\n" "'params.json'.\n"
"3. For GGUF: pass the local path of the GGUF checkpoint.\n"
" Loading GGUF from a remote repo directly is not yet "
"supported.\n"
).format(model=model) ).format(model=model)
raise ValueError(error_message) from e raise ValueError(error_message) from e
...@@ -729,7 +754,7 @@ def get_config( ...@@ -729,7 +754,7 @@ def get_config(
**kwargs, **kwargs,
) )
# Special architecture mapping check for GGUF models # Special architecture mapping check for GGUF models
if is_gguf: if _is_gguf:
if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
raise RuntimeError(f"Can't get gguf config for {config.model_type}.") raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
...@@ -889,6 +914,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None ...@@ -889,6 +914,8 @@ def get_pooling_config(model: str, revision: str | None = "main") -> dict | None
A dictionary containing the pooling type and whether A dictionary containing the pooling type and whether
normalization is used, or None if no pooling configuration is found. normalization is used, or None if no pooling configuration is found.
""" """
if is_remote_gguf(model):
model, _ = split_remote_gguf(model)
modules_file_name = "modules.json" modules_file_name = "modules.json"
...@@ -1108,6 +1135,8 @@ def get_hf_image_processor_config( ...@@ -1108,6 +1135,8 @@ def get_hf_image_processor_config(
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
if check_gguf_file(model): if check_gguf_file(model):
model = Path(model).parent model = Path(model).parent
elif is_remote_gguf(model):
model, _ = split_remote_gguf(model)
return get_image_processor_config( return get_image_processor_config(
model, token=hf_token, revision=revision, **kwargs model, token=hf_token, revision=revision, **kwargs
) )
......
...@@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin ...@@ -18,7 +18,7 @@ from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor from transformers.video_processing_utils import BaseVideoProcessor
from typing_extensions import TypeVar from typing_extensions import TypeVar
from vllm.transformers_utils.utils import check_gguf_file, convert_model_repo_to_path from vllm.transformers_utils.utils import convert_model_repo_to_path, is_gguf
from vllm.utils.func_utils import get_allowed_kwarg_only_overrides from vllm.utils.func_utils import get_allowed_kwarg_only_overrides
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -236,8 +236,8 @@ def cached_processor_from_config( ...@@ -236,8 +236,8 @@ def cached_processor_from_config(
processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin, processor_cls: type[_P] | tuple[type[_P], ...] = ProcessorMixin,
**kwargs: Any, **kwargs: Any,
) -> _P: ) -> _P:
if check_gguf_file(model_config.model): if is_gguf(model_config.model):
assert not check_gguf_file(model_config.tokenizer), ( assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer " "For multimodal GGUF models, the original tokenizer "
"should be used to correctly load processor." "should be used to correctly load processor."
) )
...@@ -350,8 +350,8 @@ def cached_image_processor_from_config( ...@@ -350,8 +350,8 @@ def cached_image_processor_from_config(
model_config: "ModelConfig", model_config: "ModelConfig",
**kwargs: Any, **kwargs: Any,
): ):
if check_gguf_file(model_config.model): if is_gguf(model_config.model):
assert not check_gguf_file(model_config.tokenizer), ( assert not is_gguf(model_config.tokenizer), (
"For multimodal GGUF models, the original tokenizer " "For multimodal GGUF models, the original tokenizer "
"should be used to correctly load image processor." "should be used to correctly load image processor."
) )
......
...@@ -20,7 +20,12 @@ from vllm.transformers_utils.config import ( ...@@ -20,7 +20,12 @@ from vllm.transformers_utils.config import (
list_filtered_repo_files, list_filtered_repo_files,
) )
from vllm.transformers_utils.tokenizers import MistralTokenizer from vllm.transformers_utils.tokenizers import MistralTokenizer
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import (
check_gguf_file,
is_gguf,
is_remote_gguf,
split_remote_gguf,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -180,10 +185,12 @@ def get_tokenizer( ...@@ -180,10 +185,12 @@ def get_tokenizer(
kwargs["truncation_side"] = "left" kwargs["truncation_side"] = "left"
# Separate model folder from file path for GGUF models # Separate model folder from file path for GGUF models
is_gguf = check_gguf_file(tokenizer_name) if is_gguf(tokenizer_name):
if is_gguf: if check_gguf_file(tokenizer_name):
kwargs["gguf_file"] = Path(tokenizer_name).name kwargs["gguf_file"] = Path(tokenizer_name).name
tokenizer_name = Path(tokenizer_name).parent tokenizer_name = Path(tokenizer_name).parent
elif is_remote_gguf(tokenizer_name):
tokenizer_name, _ = split_remote_gguf(tokenizer_name)
# if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format # if `tokenizer_mode` == "auto", check if tokenizer can be loaded via Mistral format
# first to use official Mistral tokenizer if possible. # first to use official Mistral tokenizer if possible.
......
...@@ -9,6 +9,8 @@ from os import PathLike ...@@ -9,6 +9,8 @@ from os import PathLike
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from gguf import GGMLQuantizationType
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool: ...@@ -46,6 +48,57 @@ def check_gguf_file(model: str | PathLike) -> bool:
return False return False
@cache
def is_remote_gguf(model: str | Path) -> bool:
    """Return whether ``model`` looks like a ``repo_id:quant_type`` reference.

    Cloud storage URIs and HTTP(S) URLs are rejected. Otherwise the reference
    must contain both ``/`` and ``:``, and the text after the last ``:`` must
    be accepted by ``is_valid_gguf_quant_type``.
    """
    reference = str(model)
    if is_cloud_storage(reference):
        return False
    if reference.startswith(("http://", "https://")):
        return False
    if "/" not in reference or ":" not in reference:
        return False
    # A colon is guaranteed here, so the tail is the candidate quant type.
    _, _, quant_type = reference.rpartition(":")
    return is_valid_gguf_quant_type(quant_type)
def is_valid_gguf_quant_type(gguf_quant_type: str) -> bool:
    """Return whether ``gguf_quant_type`` names a ``GGMLQuantizationType`` member.

    Membership is tested against the enum's ``__members__`` mapping rather
    than with ``getattr``: attribute lookup also succeeds for non-member
    attributes of the enum class (``__members__``, ``__doc__``, ``mro`` ...),
    which would wrongly be treated as valid quant types.
    """
    return gguf_quant_type in GGMLQuantizationType.__members__
def split_remote_gguf(model: str | Path) -> tuple[str, str]:
    """Split a ``repo_id:quant_type`` reference into its two parts.

    Args:
        model: Remote GGUF reference as a string or ``Path``.

    Returns:
        ``(repo_id, quant_type)``, split at the last ``:``.

    Raises:
        ValueError: If ``is_remote_gguf`` rejects the reference.
    """
    model = str(model)
    if is_remote_gguf(model):
        repo_id, quant_type = model.rsplit(":", 1)
        return (repo_id, quant_type)
    # Exceptions do not interpolate logging-style "%s" arguments, so the
    # message is formatted here before being raised.
    raise ValueError(
        f"Wrong GGUF model or invalid GGUF quant type: {model}.\n"
        "- It should be in repo_id:quant_type format.\n"
        "- Valid GGMLQuantizationType values: "
        f"{GGMLQuantizationType._member_names_}"
    )
def is_gguf(model: str | Path) -> bool:
    """Tell whether ``model`` refers to a GGUF model.

    A reference qualifies when ``check_gguf_file`` accepts it (local GGUF
    file) or when ``is_remote_gguf`` accepts it (``repo_id:quant_type``).

    Args:
        model: Model name, path, or Path object to check.

    Returns:
        True if the model is a GGUF model, False otherwise.
    """
    reference = str(model)
    # The local-file check runs first; the remote form is only consulted
    # when it fails.
    return bool(check_gguf_file(reference)) or is_remote_gguf(reference)
def modelscope_list_repo_files( def modelscope_list_repo_files(
repo_id: str, repo_id: str,
revision: str | None = None, revision: str | None = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment