Unverified Commit a9bd5df1 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add a progress bar for the total download of shards (#22062)

* Add a progress bar for the total download of shards

* Check for no cache at all

* Fix check
parent 1a5fc300
......@@ -239,6 +239,7 @@ class DetrConfig(PretrainedConfig):
@classmethod
def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
"""Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
Args:
backbone_config ([`PretrainedConfig`]):
The backbone configuration.
......
......@@ -390,7 +390,7 @@ def cached_file(
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if _commit_hash is not None:
if _commit_hash is not None and not force_download:
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
......@@ -913,7 +913,13 @@ def get_checkpoint_shard_files(
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames = []
for shard_filename in shard_filenames:
# Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of
# downloaded (if interrupted).
last_shard = try_to_load_from_cache(
pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash
)
show_progress_bar = last_shard is None or force_download
for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
try:
# Load from URL
cached_filename = cached_file(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment