Unverified commit e568cf88, authored by Isotr0py and committed by GitHub
Browse files

[UX] Infer dtype for local checkpoint (#36218)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 098d8447
......@@ -1116,7 +1116,7 @@ def get_safetensors_params_metadata(
revision: str | None = None,
) -> dict[str, Any]:
"""
Get the safetensors metadata for remote model repository.
Get the safetensors parameters metadata for remote/local model repository.
"""
full_metadata = {}
if (model_path := Path(model)).exists():
......
......@@ -18,7 +18,7 @@ from vllm.config.utils import getattr_iter
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
ConfigFormat,
try_get_safetensors_metadata,
get_safetensors_params_metadata,
)
from vllm.utils.torch_utils import common_broadcastable_dtype
......@@ -165,14 +165,14 @@ class ModelArchConfigConvertorBase:
# Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None:
with _maybe_patch_hf_hub_constants(config_format):
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
param_mt = get_safetensors_params_metadata(model_id, revision=revision)
if repo_mt and (files_mt := repo_mt.files_metadata):
if param_mt:
param_dtypes: set[torch.dtype] = {
_SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
for file_mt in files_mt.values()
for dtype_str in file_mt.parameter_count
if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
_SAFETENSORS_TO_TORCH_DTYPE[dtype]
for info in param_mt.values()
if (dtype := info.get("dtype", None))
and dtype in _SAFETENSORS_TO_TORCH_DTYPE
}
if param_dtypes:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment