Unverified Commit 90763256 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[BugFix] Don't scan entire cache dir when loading model (#13302)

parent 97a3d6d9
......@@ -15,8 +15,7 @@ import gguf
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import (HfFileSystem, hf_hub_download, scan_cache_dir,
snapshot_download)
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
......@@ -239,7 +238,8 @@ def download_weights_from_hf(
Returns:
str: The path to the downloaded model weights.
"""
if not huggingface_hub.constants.HF_HUB_OFFLINE:
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
if not local_only:
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
......@@ -255,7 +255,6 @@ def download_weights_from_hf(
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
start_size = scan_cache_dir().size_on_disk
start_time = time.perf_counter()
hf_folder = snapshot_download(
model_name_or_path,
......@@ -264,13 +263,12 @@ def download_weights_from_hf(
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
local_files_only=local_only,
)
end_time = time.perf_counter()
end_size = scan_cache_dir().size_on_disk
if end_size != start_size:
logger.info("Time took to download weights for %s: %.6f seconds",
model_name_or_path, end_time - start_time)
time_taken = time.perf_counter() - start_time
if time_taken > 0.5:
logger.info("Time spent downloading weights for %s: %.6f seconds",
model_name_or_path, time_taken)
return hf_folder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment