Unverified Commit a9bd5df1 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add a progress bar for the total download of shards (#22062)

* Add a progress bar for the total download of shards

* Check for no cache at all

* Fix check
parent 1a5fc300
...@@ -239,6 +239,7 @@ class DetrConfig(PretrainedConfig): ...@@ -239,6 +239,7 @@ class DetrConfig(PretrainedConfig):
@classmethod @classmethod
def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs): def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
"""Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration. """Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
Args: Args:
backbone_config ([`PretrainedConfig`]): backbone_config ([`PretrainedConfig`]):
The backbone configuration. The backbone configuration.
......
...@@ -390,7 +390,7 @@ def cached_file( ...@@ -390,7 +390,7 @@ def cached_file(
if isinstance(cache_dir, Path): if isinstance(cache_dir, Path):
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
if _commit_hash is not None: if _commit_hash is not None and not force_download:
# If the file is cached under that commit hash, we return it directly. # If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache( resolved_file = try_to_load_from_cache(
path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash path_or_repo_id, full_filename, cache_dir=cache_dir, revision=_commit_hash
...@@ -913,7 +913,13 @@ def get_checkpoint_shard_files( ...@@ -913,7 +913,13 @@ def get_checkpoint_shard_files(
# At this stage pretrained_model_name_or_path is a model identifier on the Hub # At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames = [] cached_filenames = []
for shard_filename in shard_filenames: # Check if the model is already cached or not. We only try the last checkpoint, this should cover most cases of
# downloaded (if interrupted).
last_shard = try_to_load_from_cache(
pretrained_model_name_or_path, shard_filenames[-1], cache_dir=cache_dir, revision=_commit_hash
)
show_progress_bar = last_shard is None or force_download
for shard_filename in tqdm(shard_filenames, desc="Downloading shards", disable=not show_progress_bar):
try: try:
# Load from URL # Load from URL
cached_filename = cached_file( cached_filename = cached_file(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment