Unverified Commit 3bb6356d authored by Patrick von Platen, committed by GitHub

[From pretrained] Allow download from subfolder inside model repo (#18184)



* add first generation tutorial

* [from_pretrained] Allow loading models from subfolders

* remove gen file

* add doc strings

* allow download from subfolder

* add tests

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply comments

* correct doc string
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent ce015281
@@ -494,6 +494,9 @@ class PretrainedConfig(PushToHubMixin):
                 If `True`, then this function returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
                 dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
                 part of `kwargs` which has not been used to update `config` and is otherwise ignored.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you
+                can specify the folder name here.
             kwargs (`Dict[str, Any]`, *optional*):
                 The values in kwargs of any keys which are configuration attributes will be used to override the loaded
                 values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
@@ -577,6 +580,7 @@ class PretrainedConfig(PushToHubMixin):
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", "")
         from_pipeline = kwargs.pop("_from_pipeline", None)
         from_auto_class = kwargs.pop("_from_auto", False)
@@ -589,16 +593,22 @@ class PretrainedConfig(PushToHubMixin):
             local_files_only = True

         pretrained_model_name_or_path = str(pretrained_model_name_or_path)
-        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+        if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
+            pretrained_model_name_or_path
+        ):
             config_file = pretrained_model_name_or_path
         else:
             configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)

-            if os.path.isdir(pretrained_model_name_or_path):
-                config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
+            if os.path.isdir(os.path.join(pretrained_model_name_or_path, subfolder)):
+                config_file = os.path.join(pretrained_model_name_or_path, subfolder, configuration_file)
             else:
-                config_file = hf_bucket_url(
-                    pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
-                )
+                config_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=configuration_file,
+                    revision=revision,
+                    subfolder=subfolder if len(subfolder) > 0 else None,
+                    mirror=None,
+                )

         try:
...
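A minimal usage sketch for the new argument (the repo id below is hypothetical). Note that the empty-string default makes `os.path.join` a no-op, which is how existing call sites keep their old behavior:

import os
from transformers import BertConfig

# The default subfolder="" leaves paths untouched:
os.path.join("", "config.json")      # -> "config.json"
os.path.join("bert", "config.json")  # -> "bert/config.json"

# Load a config stored at <repo>/bert/config.json instead of the repo root
# ("my-user/my-model" is a hypothetical repo id):
config = BertConfig.from_pretrained("my-user/my-model", subfolder="bert")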
@@ -1691,6 +1691,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
                 RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
                 `True` when there is some disk offload.
+            subfolder (`str`, *optional*, defaults to `""`):
+                In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you
+                can specify the folder name here.
             kwargs (remaining dictionary of keyword arguments, *optional*):
                 Can be used to update the configuration object (after it has been loaded) and initiate the model (e.g.,
@@ -1777,6 +1780,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         max_memory = kwargs.pop("max_memory", None)
         offload_folder = kwargs.pop("offload_folder", None)
         offload_state_dict = kwargs.pop("offload_state_dict", None)
+        subfolder = kwargs.pop("subfolder", "")

         if device_map is not None:
             if low_cpu_mem_usage is None:
@@ -1820,6 +1824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             local_files_only=local_files_only,
             use_auth_token=use_auth_token,
             revision=revision,
+            subfolder=subfolder,
             _from_auto=from_auto_class,
             _from_pipeline=from_pipeline,
             **kwargs,
@@ -1837,32 +1842,38 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if pretrained_model_name_or_path is not None:
             pretrained_model_name_or_path = str(pretrained_model_name_or_path)
             if os.path.isdir(pretrained_model_name_or_path):
-                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
+                if from_tf and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                ):
                     # Load from a TF 1.0 checkpoint in priority if from_tf
-                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
-                elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                elif from_tf and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+                ):
                     # Load from a TF 2.0 checkpoint in priority if from_tf
-                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
-                elif from_flax and os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
+                elif from_flax and os.path.isfile(
+                    os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                ):
                     # Load from a Flax checkpoint in priority if from_flax
-                    archive_file = os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)
-                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)):
                     # Load from a PyTorch checkpoint
-                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
-                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)):
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_NAME)
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)):
                     # Load from a sharded PyTorch checkpoint
-                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_INDEX_NAME)
+                    archive_file = os.path.join(pretrained_model_name_or_path, subfolder, WEIGHTS_INDEX_NAME)
                     is_sharded = True
                 # At this stage we don't have a weight file so we will raise an error.
                 elif os.path.isfile(
-                    os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
-                ) or os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
+                    os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
+                ) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
                     raise EnvironmentError(
                         f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
                         "there is a file for TensorFlow weights. Use `from_tf=True` to load this model from those "
                         "weights."
                     )
-                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME)):
+                elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)):
                     raise EnvironmentError(
                         f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path} but "
                         "there is a file for Flax weights. Use `from_flax=True` to load this model from those "
@@ -1873,15 +1884,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                         f"Error no file named {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or "
                         f"{FLAX_WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
                     )
-            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
+            elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)) or is_remote_url(
+                pretrained_model_name_or_path
+            ):
                 archive_file = pretrained_model_name_or_path
-            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
+            elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
                 if not from_tf:
                     raise ValueError(
                         f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
                         "from_tf to True to load from this checkpoint."
                     )
-                archive_file = pretrained_model_name_or_path + ".index"
+                archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
             else:
                 # set correct filename
                 if from_tf:
@@ -1892,7 +1905,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                     filename = WEIGHTS_NAME

-                archive_file = hf_bucket_url(
-                    pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
-                )
+                archive_file = hf_bucket_url(
+                    pretrained_model_name_or_path,
+                    filename=filename,
+                    revision=revision,
+                    mirror=mirror,
+                    subfolder=subfolder if len(subfolder) > 0 else None,
+                )

         try:
@@ -1930,6 +1947,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 filename=WEIGHTS_INDEX_NAME,
                 revision=revision,
                 mirror=mirror,
+                subfolder=subfolder if len(subfolder) > 0 else None,
             )
         resolved_archive_file = cached_path(
             archive_file,
@@ -2016,6 +2034,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
                 user_agent=user_agent,
                 revision=revision,
                 mirror=mirror,
+                subfolder=subfolder,
             )

         # load pt weights early so that we know which dtype to init the model under
...
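The same argument now applies to model weights; a short sketch (the repo id is hypothetical, and the local form passes the parent directory and the subfolder separately, mirroring the tests further down):

from transformers import BertModel

# Hub repo with weights at <repo>/bert/pytorch_model.bin ("my-user/my-model" is hypothetical):
model = BertModel.from_pretrained("my-user/my-model", subfolder="bert")

# Local checkpoint saved under /path/to/checkpoint/bert/:
local_model = BertModel.from_pretrained("/path/to/checkpoint", subfolder="bert")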
@@ -1142,6 +1142,7 @@ def get_checkpoint_shard_files(
     user_agent=None,
     revision=None,
     mirror=None,
+    subfolder="",
 ):
     """
     For a given model:
@@ -1167,14 +1168,18 @@ def get_checkpoint_shard_files(
     # First, let's deal with local folder.
     if os.path.isdir(pretrained_model_name_or_path):
-        shard_filenames = [os.path.join(pretrained_model_name_or_path, f) for f in shard_filenames]
+        shard_filenames = [os.path.join(pretrained_model_name_or_path, subfolder, f) for f in shard_filenames]
         return shard_filenames, sharded_metadata

     # At this stage pretrained_model_name_or_path is a model identifier on the Hub
     cached_filenames = []
     for shard_filename in shard_filenames:
-        shard_url = hf_bucket_url(
-            pretrained_model_name_or_path, filename=shard_filename, revision=revision, mirror=mirror
-        )
+        shard_url = hf_bucket_url(
+            pretrained_model_name_or_path,
+            filename=shard_filename,
+            revision=revision,
+            mirror=mirror,
+            subfolder=subfolder if len(subfolder) > 0 else None,
+        )

         try:
...
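Because `get_checkpoint_shard_files` now forwards `subfolder`, sharded checkpoints nested in a subfolder resolve as well. A minimal local round trip, assuming a deliberately tiny config so that `max_shard_size="10KB"` actually produces several shards:

import os
import tempfile

from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
model = BertModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Writes the shards plus pytorch_model.bin.index.json under <tmp_dir>/bert/.
    model.save_pretrained(os.path.join(tmp_dir, "bert"), max_shard_size="10KB")
    reloaded = BertModel.from_pretrained(tmp_dir, subfolder="bert")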
@@ -157,6 +157,17 @@ class ConfigTester(object):
         self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())

+    def create_and_test_config_from_and_save_pretrained_subfolder(self):
+        config_first = self.config_class(**self.inputs_dict)
+
+        subfolder = "test"
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            sub_tmpdirname = os.path.join(tmpdirname, subfolder)
+            config_first.save_pretrained(sub_tmpdirname)
+            config_second = self.config_class.from_pretrained(tmpdirname, subfolder=subfolder)
+
+        self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
+
     def create_and_test_config_with_num_labels(self):
         config = self.config_class(**self.inputs_dict, num_labels=5)
         self.parent.assertEqual(len(config.id2label), 5)
@@ -197,6 +208,7 @@ class ConfigTester(object):
         self.create_and_test_config_to_json_string()
         self.create_and_test_config_to_json_file()
         self.create_and_test_config_from_and_save_pretrained()
+        self.create_and_test_config_from_and_save_pretrained_subfolder()
         self.create_and_test_config_with_num_labels()
         self.check_config_can_be_init_without_params()
         self.check_config_arguments_init()
@@ -308,6 +320,15 @@ class ConfigTestUtils(unittest.TestCase):
             f" {', '.join(keys_with_defaults)}."
         )

+    def test_from_pretrained_subfolder(self):
+        with self.assertRaises(OSError):
+            # config is in subfolder, the following should not work without specifying the subfolder
+            _ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
+
+        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
+
+        self.assertIsNotNone(config)
+
     def test_cached_files_are_used_when_internet_is_down(self):
         # A mock response for an HTTP head request to emulate server down
         response_mock = mock.Mock()
...
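For context, the `hf-internal-testing/tiny-random-bert-subfolder` test repo presumably has a layout like the following (illustrative; only the nesting matters for the test):

hf-internal-testing/tiny-random-bert-subfolder
└── bert/
    ├── config.json
    └── pytorch_model.bin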
@@ -2503,6 +2503,15 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
     return torch.tensor(data=values, dtype=torch.float, device=torch_device).view(shape).contiguous()

+def check_models_equal(model1, model2):
+    models_are_equal = True
+    for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
+        if model1_p.data.ne(model2_p.data).sum() > 0:
+            models_are_equal = False
+
+    return models_are_equal
+
+
 @require_torch
 class ModelUtilsTest(TestCasePlus):
     @slow
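The helper keeps iterating after the first mismatch; an equivalent short-circuiting variant (a sketch, not part of this PR) could lean on `torch.equal`:

import torch

def check_models_equal_fast(model1, model2):
    # Same parameter-wise comparison as check_models_equal, but stops at the first difference.
    return all(
        torch.equal(p1.data, p2.data)
        for p1, p2 in zip(model1.parameters(), model2.parameters())
    )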
@@ -2531,6 +2540,56 @@ class ModelUtilsTest(TestCasePlus):
         self.assertEqual(model.config.output_hidden_states, True)
         self.assertEqual(model.config, config)

+    def test_model_from_pretrained_subfolder(self):
+        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model = BertModel(config)
+
+        subfolder = "bert"
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(os.path.join(tmp_dir, subfolder))
+
+            with self.assertRaises(OSError):
+                _ = BertModel.from_pretrained(tmp_dir)
+
+            model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)
+
+        self.assertTrue(check_models_equal(model, model_loaded))
+
+    def test_model_from_pretrained_subfolder_sharded(self):
+        config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
+        model = BertModel(config)
+
+        subfolder = "bert"
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
+
+            with self.assertRaises(OSError):
+                _ = BertModel.from_pretrained(tmp_dir)
+
+            model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)
+
+        self.assertTrue(check_models_equal(model, model_loaded))
+
+    def test_model_from_pretrained_hub_subfolder(self):
+        subfolder = "bert"
+        model_id = "hf-internal-testing/tiny-random-bert-subfolder"
+        with self.assertRaises(OSError):
+            _ = BertModel.from_pretrained(model_id)
+
+        model = BertModel.from_pretrained(model_id, subfolder=subfolder)
+
+        self.assertIsNotNone(model)
+
+    def test_model_from_pretrained_hub_subfolder_sharded(self):
+        subfolder = "bert"
+        model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
+        with self.assertRaises(OSError):
+            _ = BertModel.from_pretrained(model_id)
+
+        model = BertModel.from_pretrained(model_id, subfolder=subfolder)
+
+        self.assertIsNotNone(model)
+
     def test_model_from_pretrained_with_different_pretrained_model_name(self):
         model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
         self.assertIsNotNone(model)
...
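Outside the test harness, the Hub behavior the last two tests assert looks roughly like this (using the same public test repo):

from transformers import BertModel

try:
    BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
except OSError:
    pass  # no weights at the repo root, as the tests expect

model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")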