Unverified Commit e568cf88 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[UX] Infer dtype for local checkpoint (#36218)


Signed-off-by: default avatarIsotr0py <mozf@mail2.sysu.edu.cn>
parent 098d8447
...@@ -1116,7 +1116,7 @@ def get_safetensors_params_metadata( ...@@ -1116,7 +1116,7 @@ def get_safetensors_params_metadata(
revision: str | None = None, revision: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Get the safetensors metadata for remote model repository. Get the safetensors parameters metadata for remote/local model repository.
""" """
full_metadata = {} full_metadata = {}
if (model_path := Path(model)).exists(): if (model_path := Path(model)).exists():
......
...@@ -18,7 +18,7 @@ from vllm.config.utils import getattr_iter ...@@ -18,7 +18,7 @@ from vllm.config.utils import getattr_iter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.config import ( from vllm.transformers_utils.config import (
ConfigFormat, ConfigFormat,
try_get_safetensors_metadata, get_safetensors_params_metadata,
) )
from vllm.utils.torch_utils import common_broadcastable_dtype from vllm.utils.torch_utils import common_broadcastable_dtype
...@@ -165,14 +165,14 @@ class ModelArchConfigConvertorBase: ...@@ -165,14 +165,14 @@ class ModelArchConfigConvertorBase:
# Try to read the dtype of the weights if they are in safetensors format # Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None: if config_dtype is None:
with _maybe_patch_hf_hub_constants(config_format): with _maybe_patch_hf_hub_constants(config_format):
repo_mt = try_get_safetensors_metadata(model_id, revision=revision) param_mt = get_safetensors_params_metadata(model_id, revision=revision)
if repo_mt and (files_mt := repo_mt.files_metadata): if param_mt:
param_dtypes: set[torch.dtype] = { param_dtypes: set[torch.dtype] = {
_SAFETENSORS_TO_TORCH_DTYPE[dtype_str] _SAFETENSORS_TO_TORCH_DTYPE[dtype]
for file_mt in files_mt.values() for info in param_mt.values()
for dtype_str in file_mt.parameter_count if (dtype := info.get("dtype", None))
if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE and dtype in _SAFETENSORS_TO_TORCH_DTYPE
} }
if param_dtypes: if param_dtypes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment