Unverified commit 90cde2e9, authored by Kevin Canwen Xu and committed by GitHub
Browse files

Add Mirror Option for Downloads (#6679)

* Add Tuna Mirror for Downloads from China

* format fix

* Use preset instead of hardcoding URL

* Fix

* make style

* update the mirror option doc

* update the mirror
parent e0e0675a
...@@ -337,7 +337,9 @@ class PretrainedConfig(object): ...@@ -337,7 +337,9 @@ class PretrainedConfig(object):
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path config_file = pretrained_model_name_or_path
else: else:
config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False) config_file = hf_bucket_url(
pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False, mirror=None
)
try: try:
# Load from URL or cache if already cached # Load from URL or cache if already cached
......
...@@ -141,6 +141,10 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] ...@@ -141,6 +141,10 @@ DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
PRESET_MIRROR_DICT = {
"tuna": "https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models",
"bfsu": "https://mirrors.bfsu.edu.cn/hugging-face-models",
}
def is_torch_available(): def is_torch_available():
...@@ -570,7 +574,7 @@ def is_remote_url(url_or_filename): ...@@ -570,7 +574,7 @@ def is_remote_url(url_or_filename):
return parsed.scheme in ("http", "https") return parsed.scheme in ("http", "https")
def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str: def hf_bucket_url(model_id: str, filename: str, use_cdn=True, mirror=None) -> str:
""" """
Resolve a model identifier, and a file name, to a HF-hosted url Resolve a model identifier, and a file name, to a HF-hosted url
on either S3 or Cloudfront (a Content Delivery Network, or CDN). on either S3 or Cloudfront (a Content Delivery Network, or CDN).
...@@ -586,7 +590,13 @@ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str: ...@@ -586,7 +590,13 @@ def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
are not shared between the two because the cached file's name contains are not shared between the two because the cached file's name contains
a hash of the url. a hash of the url.
""" """
endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX endpoint = (
PRESET_MIRROR_DICT.get(mirror, mirror)
if mirror
else CLOUDFRONT_DISTRIB_PREFIX
if use_cdn
else S3_BUCKET_PREFIX
)
legacy_format = "/" not in model_id legacy_format = "/" not in model_id
if legacy_format: if legacy_format:
return f"{endpoint}/{model_id}-{filename}" return f"{endpoint}/{model_id}-{filename}"
......
...@@ -139,7 +139,9 @@ class ModelCard: ...@@ -139,7 +139,9 @@ class ModelCard:
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
model_card_file = pretrained_model_name_or_path model_card_file = pretrained_model_name_or_path
else: else:
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False) model_card_file = hf_bucket_url(
pretrained_model_name_or_path, filename=MODEL_CARD_NAME, use_cdn=False, mirror=None
)
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP: if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME) model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
......
...@@ -484,6 +484,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): ...@@ -484,6 +484,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, `optional`): kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
...@@ -522,6 +526,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): ...@@ -522,6 +526,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
output_loading_info = kwargs.pop("output_loading_info", False) output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True) use_cdn = kwargs.pop("use_cdn", True)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration # Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
...@@ -564,6 +569,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): ...@@ -564,6 +569,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
pretrained_model_name_or_path, pretrained_model_name_or_path,
filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME), filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
use_cdn=use_cdn, use_cdn=use_cdn,
mirror=mirror,
) )
try: try:
......
...@@ -784,6 +784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -784,6 +784,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`): use_cdn(:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB. our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
mirror(:obj:`str`, `optional`, defaults to :obj:`None`):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility problem,
you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. Please
refer to the mirror site for more information.
kwargs (remaining dictionary of keyword arguments, `optional`): kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
...@@ -822,6 +826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -822,6 +826,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
output_loading_info = kwargs.pop("output_loading_info", False) output_loading_info = kwargs.pop("output_loading_info", False)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_cdn = kwargs.pop("use_cdn", True) use_cdn = kwargs.pop("use_cdn", True)
mirror = kwargs.pop("mirror", None)
# Load config if we don't provide a configuration # Load config if we don't provide a configuration
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
...@@ -873,6 +878,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin): ...@@ -873,6 +878,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
pretrained_model_name_or_path, pretrained_model_name_or_path,
filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME), filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME),
use_cdn=use_cdn, use_cdn=use_cdn,
mirror=mirror,
) )
try: try:
......
...@@ -1483,7 +1483,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ...@@ -1483,7 +1483,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
full_file_name = None full_file_name = None
else: else:
full_file_name = hf_bucket_url( full_file_name = hf_bucket_url(
pretrained_model_name_or_path, filename=file_name, use_cdn=False pretrained_model_name_or_path, filename=file_name, use_cdn=False, mirror=None
) )
vocab_files[file_id] = full_file_name vocab_files[file_id] = full_file_name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment