Unverified Commit 90cde2e9 authored by Kevin Canwen Xu's avatar Kevin Canwen Xu Committed by GitHub
Browse files

Add Mirror Option for Downloads (#6679)

* Add Tuna Mirror for Downloads from China

* format fix

* Use preset instead of hardcoding URL

* Fix

* make style

* update the mirror option doc

* update the mirror
parent e0e0675a
......@@ -337,7 +337,9 @@ class PretrainedConfig(object):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)
config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
)
try:
# Load from URL or cache if already cached
......
......@@ -141,6 +141,10 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
PRESET_MIRROR_DICT = {
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
def is_torch_available():
......@@ -570,7 +574,7 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
"""
Resolve a model identifier, and a file name, to a HF-hosted url
on either S3 or Cloudfront (a Content Delivery Network, or CDN).
......@@ -586,7 +590,13 @@ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
are not shared between the two because the cached file's name contains
a hash of the url.
"""
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
endpoint = (
PRESET_MIRROR_DICT.get(mirror, mirror)
if mirror
else CLOUDFRONT_DISTRIB_PREFIX
if use_cdn
else S3_BUCKET_PREFIX
)
legacy_format = "/" not in model_id
if legacy_format:
return f"{endpoint}/{model_id}-{filename}"
......
......@@ -139,7 +139,9 @@ class ModelCard:
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
model_card_file = pretrained_model_name_or_path
else:
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False)
model_card_file = hf_bucket_url(
pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
)
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
......
......@@ -484,6 +484,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
......@@ -522,6 +526,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
......@@ -564,6 +569,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
pretrained_model_name_or_path,
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
use_cdn=use_cdn,
mirror=mirror,
)
try:
......
......@@ -784,6 +784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
......@@ -822,6 +826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig):
......@@ -873,6 +878,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
pretrained_model_name_or_path,
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
use_cdn=use_cdn,
mirror=mirror,
)
try:
......
......@@ -1483,7 +1483,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
full_file_name = None
else:
full_file_name = hf_bucket_url(
pretrained_model_name_or_path, filename=file_name, use_cdn=False
pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
)
vocab_files[file_id] = full_file_name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment