# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re
import tempfile
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, List, Optional, Union

from swift.utils.logger import get_logger
from .api import HubApi, ModelScopeConfig
from .constants import FILE_HASH, MODELSCOPE_DOWNLOAD_PARALLELS, MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB
from .file_download import get_file_download_url, http_get_file, parallel_download
from .utils.caching import ModelFileSystemCache
from .utils.utils import file_integrity_validation, get_cache_dir, model_id_to_group_owner_name

logger = get_logger()


def snapshot_download(model_id: str,
                      revision: Optional[str] = None,
                      cache_dir: Union[str, Path, None] = None,
                      user_agent: Optional[Union[Dict, str]] = None,
                      local_files_only: Optional[bool] = False,
                      cookies: Optional[CookieJar] = None,
                      ignore_file_pattern: Optional[Union[str, List[str]]] = None) -> str:
    """Download all files of a repo.

    Downloads a whole snapshot of a repo's files at the specified revision.
    This is useful when you want all files from a repo, because you don't know
    which ones you will need a priori. All files are nested inside a folder in
    order to keep their actual filename relative to that folder.

    An alternative would be to just clone a repo but this would require that
    the user always has git and git-lfs installed, and properly configured.

    Files already present in the local cache are skipped; every other file is
    first downloaded into a temporary directory under ``<cache_dir>/temp``,
    validated against its hash when the file listing provides one, and then
    moved into the cache.

    Args:
        model_id (str): A user or an organization name and a repo name
            separated by a `/`.
        revision (str, optional): An optional Git revision id which can be a
            branch name, a tag, or a commit hash.
            NOTE: currently only branch and tag name is supported
        cache_dir (str, Path, optional): Path to the folder where cached files
            are stored. Defaults to the value returned by `get_cache_dir()`.
        user_agent (str, dict, optional): The user-agent info in the form of a
            dictionary or a string.
        local_files_only (bool, optional): If `True`, avoid downloading and
            return the path of the local cache if it contains any file.
        cookies (CookieJar, optional): The cookie of the request, default
            None. When None, the cookies returned by
            `ModelScopeConfig.get_cookies()` are used.
        ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
            Any file pattern to be ignored in downloading, like exact file
            names or file extensions. Each pattern is applied with
            `re.search` against the file name.

    Raises:
        ValueError: If `local_files_only` is True and the cache holds no file
            for this model.

    Returns:
        str: Local folder path (string) of repo snapshot

    Note:
        Errors raised by the hub API, the download helpers or the file
        integrity validation are not caught here and propagate to the caller.
    """
    if cache_dir is None:
        cache_dir = get_cache_dir()
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    group_or_owner, name = model_id_to_group_owner_name(model_id)
    cache = ModelFileSystemCache(cache_dir, group_or_owner, name)

    if local_files_only:
        if not cache.cached_files:
            raise ValueError('Cannot find the requested files in the cached path and outgoing'
                             ' traffic has been disabled. To enable model look-ups and downloads'
                             " online, set 'local_files_only' to False.")
        # The cache is not keyed by revision here, so we can not confirm the
        # cached files belong to snapshot 'revision'.
        logger.warning('We can not confirm the cached file is for revision: %s', revision)
        return cache.get_root_location()

    # The staging directory is only needed when we actually download, so it is
    # not created in the offline (local_files_only) path above.
    temporary_cache_dir = os.path.join(cache_dir, 'temp')
    os.makedirs(temporary_cache_dir, exist_ok=True)

    # make headers
    headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent)}
    _api = HubApi()
    if cookies is None:
        cookies = ModelScopeConfig.get_cookies()
    revision = _api.get_valid_revision(model_id, revision=revision, cookies=cookies)

    # The extra 'Snapshot' header is only sent for the file listing request,
    # and is omitted when the CI_TEST environment variable is set.
    snapshot_header = headers if 'CI_TEST' in os.environ else {**headers, 'Snapshot': 'True'}
    model_files = _api.get_model_files(
        model_id=model_id,
        revision=revision,
        recursive=True,
        use_cookies=False if cookies is None else cookies,
        headers=snapshot_header,
    )

    # Normalise ignore_file_pattern to a list of regex patterns.
    if ignore_file_pattern is None:
        ignore_file_pattern = []
    if isinstance(ignore_file_pattern, str):
        ignore_file_pattern = [ignore_file_pattern]

    with tempfile.TemporaryDirectory(dir=temporary_cache_dir) as temp_cache_dir:
        for model_file in model_files:
            # Skip directory entries and anything matching an ignore pattern.
            if model_file['Type'] == 'tree' or any(
                    re.search(pattern, model_file['Name']) is not None for pattern in ignore_file_pattern):
                continue

            # check model_file is exist in cache, if existed, skip download, otherwise download
            if cache.exists(model_file):
                file_name = os.path.basename(model_file['Name'])
                logger.debug(f'File {file_name} already in cache, skip downloading!')
                continue

            # get download url
            url = get_file_download_url(model_id=model_id, file_path=model_file['Path'], revision=revision)

            # Files above the configured threshold are fetched in parallel
            # (presumably 'Size' is in bytes, hence MB * 1000 * 1000 -- confirm
            # against the hub API).
            use_parallel = (
                MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB * 1000 * 1000 < model_file['Size']
                and MODELSCOPE_DOWNLOAD_PARALLELS > 1)
            if use_parallel:
                # NOTE(review): get_dict() is not part of http.cookiejar.CookieJar;
                # this relies on the jar being a requests-style cookie jar -- confirm
                # what ModelScopeConfig.get_cookies() and callers provide.
                parallel_download(
                    url,
                    temp_cache_dir,
                    model_file['Name'],
                    headers=headers,
                    cookies=None if cookies is None else cookies.get_dict(),
                    file_size=model_file['Size'])
            else:
                http_get_file(url, temp_cache_dir, model_file['Name'], headers=headers, cookies=cookies)

            # check file integrity
            temp_file = os.path.join(temp_cache_dir, model_file['Name'])
            if FILE_HASH in model_file:
                file_integrity_validation(temp_file, model_file[FILE_HASH])
            # put file to cache
            cache.put_file(model_file, temp_file)

    return cache.get_root_location()